
[WIP]Feature/dpo trainer#1190

Draft
HT-Yuan wants to merge 2 commits into inclusionAI:main from HT-Yuan:feature/dpo-trainer

Conversation

@HT-Yuan (Contributor) commented Apr 16, 2026

Description

Related Issue

Fixes #1137

Type of Change

- [ ] 🐛 Bug fix
- [x] ✨ New feature
- [ ] 💥 Breaking change
- [ ] 📝 Documentation update
- [ ] ♻️ Refactoring
- [ ] ⚡ Performance improvement
- [ ] ✅ Test coverage improvement

Checklist

- [ ] I have read the Contributing Guide
- [x] Pre-commit hooks pass (`pre-commit run --all-files`)
- [ ] Relevant tests pass; new tests added for new functionality
- [ ] Documentation updated (if applicable; built with `./docs/build_all.sh`)
- [ ] Branch is up to date with main
- [ ] Self-reviewed via `/review-pr` command
- [ ] This PR was created by a coding agent via `/create-pr`
- [ ] This PR is a breaking change

Breaking Change Details (if applicable):

Additional Context


Need help? Check the Contributing Guide or ask in
GitHub Discussions!

@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces Direct Preference Optimization (DPO) support to the framework, including the `DPOTrainer`, DPO-specific engine implementations for the FSDP, Megatron, and Archon backends, and dataset processing for the HH-RLHF dataset. The review identifies three issues: potential crashes in `DPOEngine` due to incorrect batch type handling in the training and evaluation paths; an incorrect dataset size calculation in `DPOTrainer` that would lead to improper learning-rate scheduling; and an inefficient Python loop in the DPO loss computation that should be vectorized to improve performance and numerical stability.
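For context, the objective such a trainer optimizes can be written out as a minimal scalar sketch of the standard DPO loss. This is illustrative only, not the PR's implementation; the function name and the `beta` default are assumptions:

```python
import math

def dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    # How much more the policy prefers the chosen response over the rejected
    # one, relative to the frozen reference model.
    margin = (policy_chosen - ref_chosen) - (policy_rejected - ref_rejected)
    # DPO objective: -log sigmoid(beta * margin)
    return -math.log(1.0 / (1.0 + math.exp(-beta * margin)))

# margin = (-5 - -6) - (-9 - -8) = 2.0, so the loss falls below log(2)
loss = dpo_loss(-5.0, -9.0, -6.0, -8.0)
```

At zero margin the loss is exactly `log(2)`; a positive margin (policy already prefers the chosen response more than the reference does) pushes it lower.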


```python
def _train_dpo(self, data: dict[str, Any]) -> None:
    """Train on a batch (DPO)."""
    if _dpo_loss_weight(data) == 0:
```


high

The `_dpo_loss_weight` function expects a dictionary containing `cu_seqlens`, but `data` here is a `list[dict[str, Any]]` (the raw batch from the dataloader). This will cause a `TypeError` when trying to access `data["cu_seqlens"]`. Since the goal is to skip empty batches and log placeholder stats, you should check the list length instead.

Suggested change
```diff
-    if _dpo_loss_weight(data) == 0:
+    if not data:
+        _log_empty_dpo_stats(current_platform.current_device())
+        return
```

```python
    batched_call(self._evaluate_dpo, data, unpack=False)

def _evaluate_dpo(self, data: dict[str, Any]) -> None:
    if _dpo_loss_weight(data) == 0:
```


high

Similar to the training path, `_dpo_loss_weight` will crash here because `data` is a list of dictionaries, not a packed dictionary with `cu_seqlens`.

Suggested change
```diff
-    if _dpo_loss_weight(data) == 0:
+    if not data:
+        _log_empty_dpo_stats(current_platform.current_device())
+        return
```


```python
ft_spec = FinetuneSpec(
    total_train_epochs=config.total_train_epochs,
    dataset_size=len(self.train_dataloader) * config.train_dataset.batch_size,
```


high

The `dataset_size` is calculated using the length of the sharded dataloader, which represents the number of samples per rank. This will cause `FinetuneSpec` to compute an incorrect `total_train_steps` (underestimated by a factor of `world_size`), leading to incorrect learning-rate scheduling and premature training termination. You should use the total dataset size instead.

Suggested change
```diff
-    dataset_size=len(self.train_dataloader) * config.train_dataset.batch_size,
+    dataset_size=len(train_dataset),
```
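The effect of the fix can be illustrated with a small sketch (the function and argument names are hypothetical, not the `FinetuneSpec` internals): feeding in the per-rank dataloader length instead of the global dataset size shrinks the computed step count by a factor of `world_size`.

```python
import math

def total_train_steps(dataset_size, batch_size, world_size, epochs):
    # Each optimizer step consumes batch_size samples on each of world_size ranks,
    # so one epoch covers the dataset in dataset_size / (batch_size * world_size) steps.
    steps_per_epoch = math.ceil(dataset_size / (batch_size * world_size))
    return steps_per_epoch * epochs

# Global dataset: 12800 samples, batch 8, 4 ranks, 3 epochs -> 1200 steps.
correct = total_train_steps(12_800, 8, 4, 3)
# Passing the per-rank size (12800 / 4) underestimates by 4x -> 300 steps.
buggy = total_train_steps(12_800 // 4, 8, 4, 3)
```

With the buggy count, the LR scheduler would decay to its final value after only a quarter of the intended training.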

Comment on lines +174 to +185
```python
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu()
n_seqs = seqlens.shape[0]

policy_logps = torch.zeros(n_seqs, dtype=torch.float64, device=device)
ref_logps = torch.zeros(n_seqs, dtype=torch.float64, device=device)

for i in range(n_seqs):
    start = cu_seqlens[i]
    end = cu_seqlens[i + 1]
    m = loss_mask[start:end]
    policy_logps[i] = torch.where(m, logprobs[start:end], 0.0).sum()
    ref_logps[i] = torch.where(m, ref_logprobs[start:end], 0.0).sum()
```


medium

This Python loop over the sequences in the packed batch is inefficient and causes multiple GPU-CPU synchronizations. Furthermore, summing log-probabilities in low precision (e.g., `bfloat16`) can lead to numerical instability for long sequences. It is highly recommended to vectorize this operation and perform the summation in `float64` to maintain precision.

```python
n_seqs = cu_seqlens.numel() - 1
seq_ids = torch.zeros(logprobs.shape[0], dtype=torch.long, device=device)
seq_ids.scatter_(0, cu_seqlens[1:-1].long(), 1)
seq_ids = seq_ids.cumsum(dim=0)

policy_logps = torch.zeros(n_seqs, dtype=torch.float64, device=device)
policy_logps.index_add_(0, seq_ids, torch.where(loss_mask, logprobs, 0.0).to(torch.float64))

ref_logps = torch.zeros(n_seqs, dtype=torch.float64, device=device)
ref_logps.index_add_(0, seq_ids, torch.where(loss_mask, ref_logprobs, 0.0).to(torch.float64))
```
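The boundary-marker-plus-cumsum trick in that suggestion can be checked against the original loop with a pure-Python sketch (illustrative names; plain lists stand in for tensors, and the per-element loops here model what `scatter_`, `cumsum`, and `index_add_` do in one shot on the GPU):

```python
def segment_sums_loop(values, cu_seqlens, mask):
    # Loop version: sum masked values within each [cu_seqlens[i], cu_seqlens[i+1]) span.
    return [
        sum(v for v, m in zip(values[s:e], mask[s:e]) if m)
        for s, e in zip(cu_seqlens, cu_seqlens[1:])
    ]

def segment_sums_vectorized(values, cu_seqlens, mask):
    # Mirror of the scatter_ + cumsum + index_add_ suggestion:
    # mark interior boundaries, cumsum to get each position's segment id,
    # then accumulate masked values per segment.
    n = len(values)
    seq_ids = [0] * n
    for b in cu_seqlens[1:-1]:
        seq_ids[b] += 1
    for i in range(1, n):
        seq_ids[i] += seq_ids[i - 1]
    out = [0.0] * (len(cu_seqlens) - 1)
    for v, m, sid in zip(values, mask, seq_ids):
        if m:
            out[sid] += v
    return out
```

Both versions produce identical per-sequence sums for any mask; the tensor formulation simply replaces the Python-level iteration with three device-side ops and accumulates in `float64`.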

@HT-Yuan force-pushed the feature/dpo-trainer branch from 1c85889 to 7c8f3af on Apr 16, 2026 06:58
@HT-Yuan marked this pull request as draft on Apr 16, 2026 07:03


Development

Successfully merging this pull request may close these issues.

DPO algo implementation

1 participant