1 change: 1 addition & 0 deletions src/liger_kernel/chunked_loss/__init__.py
@@ -2,6 +2,7 @@
from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401
from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401
from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss # noqa: F401
from liger_kernel.chunked_loss.grpo_loss import DapoConfig # noqa: F401
from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDLoss # noqa: F401
from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOLoss # noqa: F401
from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss # noqa: F401
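Note (not part of the diff): with the export added above, DapoConfig resolves from the package root alongside the fused loss classes, for example:

from liger_kernel.chunked_loss import DapoConfig, LigerFusedLinearGRPOLoss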
23 changes: 23 additions & 0 deletions src/liger_kernel/chunked_loss/fused_linear_ppo.py
@@ -29,6 +29,7 @@ def forward(
ref_input=None,
ref_weight=None,
ref_bias=None,
sampling_ratio=None,
epsilon_low=0.2,
epsilon_high=0.2,
beta=0.04,
@@ -39,6 +40,7 @@
compiled=True,
use_ref_model=False,
chunk_size=1,
dapo_config=None,
):
# TODO: check torch compile matmul
"""Chunked forward pass for PPO loss computation.
@@ -97,6 +99,7 @@ def forward(
temperature=temperature,
use_ref_model=use_ref_model,
ppo_loss_fn=cls.ppo_loss_fn,
dapo_config=dapo_config,
)

def fused_fwd_bwd(
@@ -107,6 +110,7 @@ def fused_fwd_bwd(
ref_per_token_logps_chunk,
old_per_token_logps_chunk,
ref_input_chunk,
sampling_ratio_chunk,
):
"""Fused forward and backward for a chunk."""
argnums = (0, 1, 5) if bias is not None else (0, 1)
@@ -120,6 +124,7 @@ def fused_fwd_bwd(
ref_per_token_logps_chunk=ref_per_token_logps_chunk, # arg 6
old_per_token_logps_chunk=old_per_token_logps_chunk, # arg 7
ref_input_chunk=ref_input_chunk, # arg 8
sampling_ratio_chunk=sampling_ratio_chunk, # arg 9
)

def accumulate_chunk(
@@ -130,6 +135,7 @@ def accumulate_chunk(
ref_per_token_logps_chunk=None,
old_per_token_logps_chunk=None,
ref_input_chunk=None,
sampling_ratio_chunk=None,
):
(chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
input_chunk,
@@ -139,6 +145,7 @@ def accumulate_chunk(
ref_per_token_logps_chunk,
old_per_token_logps_chunk,
ref_input_chunk,
sampling_ratio_chunk,
)
if bias is not None:
grad_bias.add_(chunk_grad_bias[0])
@@ -189,6 +196,11 @@ def accumulate_chunk(
if use_ref_model and ref_per_token_logps is None
else [None] * chunks
)
_sampling_ratio_chunks = (
torch.chunk(sampling_ratio, chunks=chunks, dim=0)
if sampling_ratio is not None
else [None] * chunks
)

for (
input_chunk,
@@ -198,6 +210,7 @@ def accumulate_chunk(
ref_per_token_logps_chunk,
old_per_token_logps_chunk,
ref_input_chunk,
sampling_ratio_chunk,
) in zip(
_input_chunks,
_selected_token_ids_chunks,
@@ -206,6 +219,7 @@ def accumulate_chunk(
_ref_per_token_logps_chunks,
_old_per_token_logps_chunks,
_ref_input_chunks,
_sampling_ratio_chunks,
):
# Mark dynamic dimensions
torch._dynamo.mark_dynamic(input_chunk, 1)
@@ -217,6 +231,8 @@ def accumulate_chunk(
torch._dynamo.mark_dynamic(ref_input_chunk, 1)
if old_per_token_logps_chunk is not None:
torch._dynamo.mark_dynamic(old_per_token_logps_chunk, 1)
if sampling_ratio_chunk is not None:
torch._dynamo.mark_dynamic(sampling_ratio_chunk, 1)

accumulate_chunk(
input_chunk,
@@ -226,6 +242,7 @@ def accumulate_chunk(
ref_per_token_logps_chunk,
old_per_token_logps_chunk,
ref_input_chunk,
sampling_ratio_chunk,
)

# Combine gradients
@@ -257,6 +274,7 @@ def _compute_chunk_loss(
ref_input_chunk=None,
ref_weight=None,
ref_bias=None,
sampling_ratio_chunk=None,
full_attention_mask=None,
epsilon_low=0.2,
epsilon_high=0.2,
@@ -267,6 +285,7 @@ def _compute_chunk_loss(
temperature=1.0,
use_ref_model=False,
ppo_loss_fn=None,
dapo_config=None,
):
"""Compute loss for a single chunk."""
# Get policy log probabilities using chunk_forward
@@ -290,12 +309,14 @@ def _compute_chunk_loss(
ref_per_token_logps=ref_per_token_logps_chunk.float() if ref_per_token_logps_chunk is not None else None,
old_per_token_logps=old_per_token_logps_chunk.float() if old_per_token_logps_chunk is not None else None,
ref_log_probs=ref_log_probs, # used when ref_per_token_logps is None
sampling_ratio=sampling_ratio_chunk.float() if sampling_ratio_chunk is not None else None,
epsilon_low=epsilon_low,
epsilon_high=epsilon_high,
beta=beta,
loss_type=loss_type,
max_completion_length=max_completion_length,
importance_sampling_level=importance_sampling_level,
dapo_config=dapo_config,
)

return chunk_loss, chunk_metrics
@@ -338,6 +359,7 @@ def backward(ctx, grad_output, *grad_metrics):
None, # grad_ref_input
None, # grad_ref_weight
None, # grad_ref_bias
None, # grad_sampling_ratio
None, # grad_epsilon_low
None, # grad_epsilon_high
None, # grad_beta
@@ -347,4 +369,5 @@ def backward(ctx, grad_output, *grad_metrics):
None, # grad_chunk_size
None, # grad_loss_type
None, # grad_max_completion_length
None, # grad_dapo_config
)
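Note (not part of the diff): a minimal sketch of the chunking pattern this change adds for sampling_ratio. The tensor shapes below are illustrative assumptions; only the torch.chunk / [None] * chunks pattern mirrors the code above, where sampling_ratio is split along the batch dimension so each chunk stays aligned with its input chunk.

import torch

# Illustrative shapes (assumed): 4 sequences, 8 completion tokens, hidden size 16.
batch, seq_len, chunks = 4, 8, 2
_input = torch.randn(batch, seq_len, 16)      # policy hidden states
sampling_ratio = torch.rand(batch, seq_len)   # per-token sampling ratios (optional)

_input_chunks = torch.chunk(_input, chunks=chunks, dim=0)
_sampling_ratio_chunks = (
    torch.chunk(sampling_ratio, chunks=chunks, dim=0)
    if sampling_ratio is not None
    else [None] * chunks
)

for input_chunk, sampling_ratio_chunk in zip(_input_chunks, _sampling_ratio_chunks):
    # Each per-token ratio chunk covers the same sequences as its input chunk.
    assert sampling_ratio_chunk.shape[0] == input_chunk.shape[0]

Because sampling_ratio and dapo_config are passed into the autograd Function's forward but are not differentiated, backward() returns a None placeholder for each of them, keeping the returned tuple aligned with forward()'s argument list.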