Code Review
This pull request introduces Direct Preference Optimization (DPO) support to the framework, including the DPOTrainer, DPO-specific engine implementations for FSDP, Megatron, and Archon backends, and dataset processing for the HH-RLHF dataset. The review feedback identifies critical issues: potential crashes in DPOEngine due to incorrect batch type handling in training and evaluation paths, an incorrect dataset size calculation in DPOTrainer that would lead to improper learning rate scheduling, and an inefficient Python loop in the DPO loss computation that should be vectorized to improve performance and numerical stability.
In `_train_dpo`:

```python
def _train_dpo(self, data: dict[str, Any]) -> None:
    """Train on a batch (DPO)."""
    if _dpo_loss_weight(data) == 0:
```
The _dpo_loss_weight function expects a dictionary containing cu_seqlens, but data here is a list[dict[str, Any]] (the raw batch from the dataloader). This will cause a TypeError when trying to access data["cu_seqlens"]. Since the goal is to skip empty batches and log placeholder stats, you should check the list length instead.
```suggestion
    if not data:
        _log_empty_dpo_stats(current_platform.current_device())
        return
```
```python
batched_call(self._evaluate_dpo, data, unpack=False)
```

In `_evaluate_dpo`:

```python
def _evaluate_dpo(self, data: dict[str, Any]) -> None:
    if _dpo_loss_weight(data) == 0:
```

The same issue as in `_train_dpo`: `data` here is the raw `list[dict[str, Any]]` batch from `batched_call` with `unpack=False`, so `_dpo_loss_weight(data)` will raise a `TypeError`. Check the list length instead.
```python
ft_spec = FinetuneSpec(
    total_train_epochs=config.total_train_epochs,
    dataset_size=len(self.train_dataloader) * config.train_dataset.batch_size,
```
The dataset_size is calculated using the length of the sharded dataloader, which represents the number of samples per rank. This will cause FinetuneSpec to compute an incorrect total_train_steps (underestimated by a factor of world_size), leading to incorrect learning rate scheduling and premature training termination. You should use the total dataset size instead.
```suggestion
    dataset_size=len(train_dataset),
```
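To see the size of the error, here is the arithmetic with illustrative (hypothetical) numbers — a sharded dataloader only iterates over `1/world_size` of the samples:

```python
# Hypothetical numbers showing how the sharded length skews total_train_steps.
world_size = 8
total_samples = 64_000          # len(train_dataset)
batch_size = 32
total_train_epochs = 1

# Each rank's dataloader only sees its own shard:
len_sharded_dataloader = total_samples // world_size // batch_size   # 250 batches

wrong_dataset_size = len_sharded_dataloader * batch_size   # 8_000
right_dataset_size = total_samples                         # 64_000

def total_train_steps(dataset_size):
    # Assumed FinetuneSpec-style computation: steps = epochs * size / batch_size
    return total_train_epochs * dataset_size // batch_size

print(total_train_steps(wrong_dataset_size))   # 250
print(total_train_steps(right_dataset_size))   # 2000
# 8x fewer steps: the LR schedule decays too fast and training stops early.
```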
```python
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).cpu()
n_seqs = seqlens.shape[0]

policy_logps = torch.zeros(n_seqs, dtype=torch.float64, device=device)
ref_logps = torch.zeros(n_seqs, dtype=torch.float64, device=device)

for i in range(n_seqs):
    start = cu_seqlens[i]
    end = cu_seqlens[i + 1]
    m = loss_mask[start:end]
    policy_logps[i] = torch.where(m, logprobs[start:end], 0.0).sum()
    ref_logps[i] = torch.where(m, ref_logprobs[start:end], 0.0).sum()
```
This Python loop over sequences in the packed batch is inefficient and causes multiple GPU-CPU synchronizations. Furthermore, summing log-probabilities in low precision (e.g., bfloat16) can lead to numerical instability for long sequences. It is highly recommended to vectorize this operation and perform the summation in float64 to maintain precision.
```suggestion
n_seqs = cu_seqlens.numel() - 1
seq_ids = torch.zeros(logprobs.shape[0], dtype=torch.long, device=device)
seq_ids.scatter_(0, cu_seqlens[1:-1].long(), 1)
seq_ids = seq_ids.cumsum(dim=0)
policy_logps = torch.zeros(n_seqs, dtype=torch.float64, device=device)
policy_logps.index_add_(0, seq_ids, torch.where(loss_mask, logprobs, 0.0).to(torch.float64))
ref_logps = torch.zeros(n_seqs, dtype=torch.float64, device=device)
ref_logps.index_add_(0, seq_ids, torch.where(loss_mask, ref_logprobs, 0.0).to(torch.float64))
```
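The scatter-then-cumsum trick in the suggestion can be hard to read, so here are its mechanics traced in pure Python with toy numbers (list ops standing in for the tensor ops, so this is an illustration, not the suggested code itself):

```python
# Three packed sequences: tokens [0:3], [3:5], [5:9].
cu_seqlens = [0, 3, 5, 9]
logprobs   = [0.5, 0.25, 0.25, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0]
loss_mask  = [1, 1, 0, 1, 1, 1, 1, 1, 0]

# scatter_: put a 1 at each interior sequence boundary, then cumsum
# turns the boundary markers into a per-token sequence id.
seq_ids = [0] * len(logprobs)
for b in cu_seqlens[1:-1]:
    seq_ids[b] = 1
for i in range(1, len(seq_ids)):
    seq_ids[i] += seq_ids[i - 1]      # -> [0, 0, 0, 1, 1, 2, 2, 2, 2]

# index_add_: accumulate masked log-probs into per-sequence buckets.
n_seqs = len(cu_seqlens) - 1
policy_logps = [0.0] * n_seqs
for sid, lp, m in zip(seq_ids, logprobs, loss_mask):
    policy_logps[sid] += lp if m else 0.0
print(policy_logps)                   # [0.75, 2.0, 6.0]
```

One vectorized pass over the tokens replaces the per-sequence Python loop, and accumulating in float64 keeps long-sequence sums stable.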
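For completeness, the per-sequence `policy_logps` and `ref_logps` computed above feed the standard DPO objective. The pairing convention (each chosen sequence packed adjacent to its rejected counterpart) is an assumption about this trainer, not something stated in the diff:

```python
import math

def dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    # Standard DPO loss: -log sigmoid(beta * ((pi_c - pi_r) - (ref_c - ref_r)))
    logits = beta * ((policy_chosen - policy_rejected) - (ref_chosen - ref_rejected))
    return -math.log(1.0 / (1.0 + math.exp(-logits)))

# Equal margins under policy and reference -> logits = 0 -> loss = log 2
print(round(dpo_loss(-10.0, -12.0, -10.0, -12.0), 4))  # 0.6931
```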
Description

Related Issue

Fixes #1137