diff --git a/src/liger_kernel/chunked_loss/__init__.py b/src/liger_kernel/chunked_loss/__init__.py index d3624adbb..85653cd48 100644 --- a/src/liger_kernel/chunked_loss/__init__.py +++ b/src/liger_kernel/chunked_loss/__init__.py @@ -2,6 +2,7 @@ from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401 from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401 from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss # noqa: F401 +from liger_kernel.chunked_loss.grpo_loss import DapoConfig # noqa: F401 from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDLoss # noqa: F401 from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOLoss # noqa: F401 from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss # noqa: F401 diff --git a/src/liger_kernel/chunked_loss/fused_linear_ppo.py b/src/liger_kernel/chunked_loss/fused_linear_ppo.py index 503af93a0..0a0f1333a 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_ppo.py +++ b/src/liger_kernel/chunked_loss/fused_linear_ppo.py @@ -29,6 +29,7 @@ def forward( ref_input=None, ref_weight=None, ref_bias=None, + sampling_ratio=None, epsilon_low=0.2, epsilon_high=0.2, beta=0.04, @@ -39,6 +40,7 @@ def forward( compiled=True, use_ref_model=False, chunk_size=1, + dapo_config=None, ): # TODO: check torch compile matmul """Chunked forward pass for PPO loss computation. 
@@ -97,6 +99,7 @@ def forward( temperature=temperature, use_ref_model=use_ref_model, ppo_loss_fn=cls.ppo_loss_fn, + dapo_config=dapo_config, ) def fused_fwd_bwd( @@ -107,6 +110,7 @@ def fused_fwd_bwd( ref_per_token_logps_chunk, old_per_token_logps_chunk, ref_input_chunk, + sampling_ratio_chunk, ): """Fused forward and backward for a chunk.""" argnums = (0, 1, 5) if bias is not None else (0, 1) @@ -120,6 +124,7 @@ def fused_fwd_bwd( ref_per_token_logps_chunk=ref_per_token_logps_chunk, # arg 6 old_per_token_logps_chunk=old_per_token_logps_chunk, # arg 7 ref_input_chunk=ref_input_chunk, # arg 8 + sampling_ratio_chunk=sampling_ratio_chunk, # arg 9 ) def accumulate_chunk( @@ -130,6 +135,7 @@ def accumulate_chunk( ref_per_token_logps_chunk=None, old_per_token_logps_chunk=None, ref_input_chunk=None, + sampling_ratio_chunk=None, ): (chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd( input_chunk, @@ -139,6 +145,7 @@ def accumulate_chunk( ref_per_token_logps_chunk, old_per_token_logps_chunk, ref_input_chunk, + sampling_ratio_chunk, ) if bias is not None: grad_bias.add_(chunk_grad_bias[0]) @@ -189,6 +196,11 @@ def accumulate_chunk( if use_ref_model and ref_per_token_logps is None else [None] * chunks ) + _sampling_ratio_chunks = ( + torch.chunk(sampling_ratio, chunks=chunks, dim=0) + if sampling_ratio is not None + else [None] * chunks + ) for ( input_chunk, @@ -198,6 +210,7 @@ def accumulate_chunk( ref_per_token_logps_chunk, old_per_token_logps_chunk, ref_input_chunk, + sampling_ratio_chunk, ) in zip( _input_chunks, _selected_token_ids_chunks, @@ -206,6 +219,7 @@ def accumulate_chunk( _ref_per_token_logps_chunks, _old_per_token_logps_chunks, _ref_input_chunks, + _sampling_ratio_chunks, ): # Mark dynamic dimensions torch._dynamo.mark_dynamic(input_chunk, 1) @@ -217,6 +231,8 @@ def accumulate_chunk( torch._dynamo.mark_dynamic(ref_input_chunk, 1) if old_per_token_logps_chunk is not None: 
torch._dynamo.mark_dynamic(old_per_token_logps_chunk, 1) + if sampling_ratio_chunk is not None: + torch._dynamo.mark_dynamic(sampling_ratio_chunk, 1) accumulate_chunk( input_chunk, @@ -226,6 +242,7 @@ def accumulate_chunk( ref_per_token_logps_chunk, old_per_token_logps_chunk, ref_input_chunk, + sampling_ratio_chunk, ) # Combine gradients @@ -257,6 +274,7 @@ def _compute_chunk_loss( ref_input_chunk=None, ref_weight=None, ref_bias=None, + sampling_ratio_chunk=None, full_attention_mask=None, epsilon_low=0.2, epsilon_high=0.2, @@ -267,6 +285,7 @@ def _compute_chunk_loss( temperature=1.0, use_ref_model=False, ppo_loss_fn=None, + dapo_config=None, ): """Compute loss for a single chunk.""" # Get policy log probabilities using chunk_forward @@ -290,12 +309,14 @@ def _compute_chunk_loss( ref_per_token_logps=ref_per_token_logps_chunk.float() if ref_per_token_logps_chunk is not None else None, old_per_token_logps=old_per_token_logps_chunk.float() if old_per_token_logps_chunk is not None else None, ref_log_probs=ref_log_probs, # used when ref_per_token_logps is None + sampling_ratio=sampling_ratio_chunk.float() if sampling_ratio_chunk is not None else None, epsilon_low=epsilon_low, epsilon_high=epsilon_high, beta=beta, loss_type=loss_type, max_completion_length=max_completion_length, importance_sampling_level=importance_sampling_level, + dapo_config=dapo_config ) return chunk_loss, chunk_metrics @@ -338,6 +359,7 @@ def backward(ctx, grad_output, *grad_metrics): None, # grad_ref_input None, # grad_ref_weight None, # grad_ref_bias + None, # grad_sampling_ratio None, # grad_epsilon_low None, # grad_epsilon_high None, # grad_beta @@ -347,4 +369,5 @@ def backward(ctx, grad_output, *grad_metrics): None, # grad_chunk_size None, # grad_loss_type None, # grad_max_completion_length + None, # grad_dapo_config ) diff --git a/src/liger_kernel/chunked_loss/grpo_loss.py b/src/liger_kernel/chunked_loss/grpo_loss.py index 1dbee9abf..5337761fd 100644 --- 
from dataclasses import dataclass
from typing import Optional

import torch


# DAPO compute_loss semantics follow the reference implementation:
# https://github.com/MotifTechnologies/trl/blob/5e512d71e0f642ea5ac0d901cec364d3a3d55c08/trl/trainer/dapo_trainer.py#L1813
@dataclass
class DapoConfig:
    """Optional knobs for DAPO-style loss shaping.

    Every field defaults to ``None``, which disables the corresponding feature.

    Attributes:
        normalizer: Fixed denominator used to normalize the summed per-token
            loss when ``loss_type == "dapo"``. Required for that loss type.
        entropy_adv_alpha: Scale applied to per-token entropy when adding an
            entropy-based advantage bonus (used together with
            ``entropy_adv_kappa``).
        entropy_adv_kappa: Cap factor for the entropy advantage bonus; the
            bonus is clipped at ``|advantage| / entropy_adv_kappa``.
        entropy_coeff: Coefficient of the entropy regularization term
            subtracted from the final loss.
    """

    # Fields are optional by design; the original `float = None` annotations
    # mis-declared the type, so they are Optional[float] here.
    normalizer: Optional[float] = None
    entropy_adv_alpha: Optional[float] = None
    entropy_adv_kappa: Optional[float] = None
    entropy_coeff: Optional[float] = None


def k3_loss_fn(log_p, log_q):
    # Computes the k3 estimate of KL[q, p] (http://joschu.net/blog/kl-approx.html):
    # exp(log_p - log_q) - (log_p - log_q) - 1, which is unbiased and non-negative.
    return torch.exp(log_p - log_q) - (log_p - log_q) - 1.0


def clip_coef_fn(coef, epsilon_low, epsilon_high):
    # PPO-style clipping of the importance-sampling ratio to [1 - eps_low, 1 + eps_high].
    return torch.clamp(coef, 1 - epsilon_low, 1 + epsilon_high)


def get_grpo_loss(
    per_token_loss,
    attention_mask,
    full_attention_mask,
    loss_type="bnpo",
    max_completion_length=None,
    dapo_config=None,
):
    """
    Normalize per-token loss based on the loss type.

    Args:
        per_token_loss: Per-token loss tensor. Shape: (batch_size, seq_len)
        attention_mask: Attention mask tensor. Shape: (batch_size, seq_len)
        full_attention_mask: Full attention mask tensor. Shape: (batch_size, seq_len)
        loss_type: Type of loss calculation ("grpo", "bnpo", "dr_grpo", "dapo").
            Defaults to "bnpo".
        max_completion_length: Maximum completion length, required for "dr_grpo".
            Defaults to None.
        dapo_config: DapoConfig whose ``normalizer`` is required for "dapo".
            Defaults to None.

    Returns:
        torch.Tensor: Normalized loss scalar.

    Raises:
        ValueError: If ``loss_type`` is unknown, or a required argument for the
            selected loss type is missing.
    """
    if loss_type == "grpo":
        # Average per-sequence loss, then average over the (full) batch.
        loss = (
            (per_token_loss * attention_mask).sum(-1) / torch.clamp(attention_mask.sum(-1), min=1.0)
        ).sum() / full_attention_mask.shape[0]
    elif loss_type == "bnpo":
        # Batch Normalized Per-token loss (original implementation): divide the
        # masked token-loss sum by the total token count of the full batch.
        loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)
    elif loss_type == "dr_grpo":
        # Dimension-Reduced GRPO (normalize by batch_size * max_completion_length).
        if max_completion_length is None:
            raise ValueError("max_completion_length must be provided for loss_type 'dr_grpo'")
        loss = (per_token_loss * attention_mask).sum() / (full_attention_mask.shape[0] * max_completion_length)
    elif loss_type == "dapo":
        # DAPO normalizes by an externally supplied (global) token count.
        norm = getattr(dapo_config, "normalizer", None)
        if norm is None:
            raise ValueError("DapoConfig and normalizer must be provided for loss_type 'dapo'")
        loss = (per_token_loss * attention_mask).sum() / norm
    else:
        raise ValueError(f"Unknown loss type: {loss_type}")
    return loss
ppo_loss_fn( else: ref_per_token_logps = per_token_logps.detach() + alpha = getattr(dapo_config, "entropy_adv_alpha", None) + kappa = getattr(dapo_config, "entropy_adv_kappa", None) + if alpha is not None and kappa is not None: + entropy_adv_alpha = alpha + entropy_adv_kappa = kappa + entropy_term = entropy_adv_alpha * entropies.detach() + adv_kappa_term = advantages.abs() / entropy_adv_kappa + entropy_term = torch.min(entropy_term, adv_kappa_term) + advantages = advantages + entropy_term + # Compute policy gradient loss with importance sampling ratio old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach() log_ratio = per_token_logps - old_per_token_logps @@ -77,25 +145,41 @@ def ppo_loss_fn( # Combine losses per_token_loss = per_token_loss + beta * kl_div + if sampling_ratio is not None: + per_token_loss = per_token_loss * sampling_ratio + # Note: We normalize by the number of tokens in the batch (using full_attention_mask), # which is consistent with the DAPO loss implementation (https://arxiv.org/html/2503.14476v1) # and TRL GRPO implementation # (https://github.com/huggingface/trl/blob/e751a16df56e70190fb94bed4a2035eec3303777/trl/trainer/grpo_trainer.py#L966) - if loss_type == "grpo": - # Average per-sequence loss - loss = ( - (per_token_loss * attention_mask).sum(-1) / torch.clamp(attention_mask.sum(-1), min=1.0) - ).sum() / full_attention_mask.shape[0] - elif loss_type == "bnpo": - # Batch Normalized Per-token loss (original implementation) - loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0) - elif loss_type == "dr_grpo": - # Dimension-Reduced GRPO (normalize by batch_size * max_completion_length) - if max_completion_length is None: - raise ValueError("max_completion_length must be provided for loss_type 'dr_grpo'") - loss = (per_token_loss * attention_mask).sum() / (full_attention_mask.shape[0] * max_completion_length) - else: - raise ValueError(f"Unknown loss 
type: {loss_type}") + loss = get_grpo_loss( + per_token_loss=per_token_loss, + attention_mask=attention_mask, + full_attention_mask=full_attention_mask, + loss_type=loss_type, + max_completion_length=max_completion_length, + dapo_config=dapo_config, + ) + completion_token_count = full_attention_mask.sum().clamp(min=1.0) + + def masked_batch_mean(x): + if x.shape[1] == 1: # when importance_sampling_level == "sequence" + return x.mean() + else: + return (x * attention_mask).sum() / completion_token_count + mean_entropy = masked_batch_mean(entropies) + + entropy_coeff = getattr(dapo_config, "entropy_coeff", None) + if entropy_coeff is not None: + entropy_loss = get_grpo_loss( + per_token_loss=entropies, + attention_mask=attention_mask, + full_attention_mask=full_attention_mask, + loss_type=loss_type, + max_completion_length=max_completion_length, + dapo_config=dapo_config, + ) + loss = loss - (entropy_coeff * entropy_loss) # Calculate metrics metrics = [] @@ -114,7 +198,11 @@ def ppo_loss_fn( ) is_clipped = is_clipped.unsqueeze(1).expand_as(attention_mask) + if entropy_coeff is not None: + metrics.append(entropy_loss) + metrics.append((is_clipped * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)) + metrics.append(mean_entropy) return loss, metrics @classmethod @@ -132,6 +220,7 @@ def forward( ref_input=None, ref_weight=None, ref_bias=None, + sampling_ratio=None, beta=0.04, epsilon_low=0.2, epsilon_high=0.2, @@ -142,6 +231,7 @@ def forward( compiled=True, use_ref_model=True, chunk_size=1, + dapo_config=None, ): """ Fused linear layer with GRPO loss. 
@@ -181,6 +271,7 @@ def forward( ref_input=ref_input, ref_weight=ref_weight, ref_bias=ref_bias, + sampling_ratio=sampling_ratio, beta=beta, epsilon_low=epsilon_low, epsilon_high=epsilon_high, @@ -191,6 +282,7 @@ def forward( use_ref_model=use_ref_model, chunk_size=chunk_size, importance_sampling_level=importance_sampling_level, + dapo_config=dapo_config, ) @staticmethod @@ -211,6 +303,7 @@ def backward(ctx, grad_output, *grad_metrics): None, # grad_ref_input None, # grad_ref_weight None, # grad_ref_bias + None, # grad_sampling_ratio None, # grad_beta None, # grad_epsilon_low None, # grad_epsilon_high @@ -221,6 +314,7 @@ def backward(ctx, grad_output, *grad_metrics): None, # grad_compiled None, # grad_use_ref_model None, # grad_chunk_size + None, # grad_dapo_config ) @@ -278,6 +372,8 @@ def forward( ref_input=None, ref_weight=None, ref_bias=None, + sampling_ratio=None, + dapo_config=None, ): return LigerFusedLinearGRPOFunction.apply( _input, @@ -291,6 +387,7 @@ def forward( ref_input, ref_weight, ref_bias, + sampling_ratio, self.beta, self.epsilon_low, self.epsilon_high, @@ -301,4 +398,5 @@ def forward( self.compiled, self.use_ref_model, self.chunk_size, + dapo_config, )