Add SDPO (Self-Distillation Policy Optimization) trainer#4935
MengAiDev wants to merge 64 commits into huggingface:main
Conversation
Implements the SDPO algorithm from arxiv.org/abs/2601.20802. SDPO augments on-policy optimization with self-distillation from the model's own high-reward trajectories, converting tokenized feedback into a dense learning signal.

- Add `SDPOConfig` with distillation parameters (alpha, topk, ema_update_rate, etc.)
- Add `SDPOTrainer`, extending `GRPOTrainer` with a self-distillation loss
- Add comprehensive tests for `SDPOConfig` and `SDPOTrainer`
- Add an example script demonstrating SDPO usage
|
@MengAiDev I have cleaned up the structure, docs, and tests. Next we need to address the main TODOs regarding the teacher logits.
|
cc @jonhue, here is a port of SDPO for TRL.
|
@MengAiDev @kashif Thanks so much for implementing this! Let's coordinate with @Shekswess and #4941. It might be cleanest to have a single implementation for SDFT & SDPO ("self-distillation"), since the two are algorithmically the same and differ only in whether the data is offline or online.
|
Agreed! Let's try that, if it's OK with you @MengAiDev.
|
Wohoo!
|
Regarding the discussion on how to combine the SDFT/SDPO PRs: this PR inherits from GRPOTrainer, while the SDFT PR modifies it in place. Both approaches carry baggage from GRPOTrainer that isn't necessarily applicable to SDPO/SDFT, but this also provides a nice playground for experimentation. The tradeoff with inheritance is less control, but I like how it nicely isolates SDPO's key contributions and exposes the relevant hparams clearly. If future research demands more flexibility, we can revisit and consider breaking SDPO out into its own trainer. If we proceed with this PR's approach, extending it to cover the offline case should, at first glance, only require modifying the `_build_teacher_inputs` function.
|
That's a good point, Leon. I need to review the PR carefully, but in general I'd rather isolate first and abstract later if needed (abstractions are easy to do, hard to undo).
|
@qgallouedec @LeonEricsson if you look at my implementation of the offline SDFT in #4941 (comment), I think it can be substantially improved. I tried to follow the official code from the authors with small modifications; feel free to ping us on how we can make this better. Can't wait to start experimenting hehehe
|
Any progress on this is much appreciated |
trl/experimental/self_distillation/base_self_distillation_trainer.py
```python
)
kl_teacher = F.kl_div(mixture, teacher_log_probs, reduction="none", log_target=True)
kl_student = F.kl_div(mixture, student_log_probs, reduction="none", log_target=True)
kl = torch.lerp(kl_student, kl_teacher, alpha)
```
JSD divergence uses wrong interpolation direction via lerp
Medium Severity
In _compute_divergence, the JSD branch computes kl_teacher = F.kl_div(mixture, teacher_log_probs, ...) which gives KL(teacher || mixture), and kl_student = F.kl_div(mixture, student_log_probs, ...) which gives KL(student || mixture). Then torch.lerp(kl_student, kl_teacher, alpha) computes (1-alpha)*kl_student + alpha*kl_teacher. For alpha=0.5 (symmetric JSD) this is correct, but for asymmetric alpha values the weighting is inverted relative to the standard skew-JSD definition where alpha weights the teacher distribution in the mixture. The mixture is built with (1-alpha)*student + alpha*teacher, so the divergence terms are weighted backwards — the teacher KL should be weighted by (1-alpha) and the student KL by alpha.
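The correction the review suggests can be sketched as a standalone function. Names mirror the snippet above; the `alpha` convention follows the review's description of skew-JSD, where `alpha` weights the teacher in the mixture, so the `lerp` endpoints are swapped relative to the current code:

```python
import math

import torch
import torch.nn.functional as F


def skew_jsd(student_log_probs, teacher_log_probs, alpha):
    """Skew-JSD sketch per the review: with mixture
    M = (1 - alpha) * p_student + alpha * p_teacher, the divergence is
    alpha * KL(student || M) + (1 - alpha) * KL(teacher || M)."""
    # Log-space mixture: log((1 - alpha) * p_student + alpha * p_teacher)
    mixture = torch.logsumexp(
        torch.stack(
            [student_log_probs + math.log(1.0 - alpha), teacher_log_probs + math.log(alpha)]
        ),
        dim=0,
    )
    # F.kl_div(input, target, log_target=True) is pointwise KL(target || input)
    kl_teacher = F.kl_div(mixture, teacher_log_probs, reduction="none", log_target=True)
    kl_student = F.kl_div(mixture, student_log_probs, reduction="none", log_target=True)
    # Swapped endpoints vs. the current code:
    # (1 - alpha) * kl_teacher + alpha * kl_student
    return torch.lerp(kl_teacher, kl_student, alpha)
```

At `alpha=0.5` this matches the current implementation (symmetric JSD); the two only differ for asymmetric `alpha`.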
```python
return base_policy_loss + self.args.distillation_weight * sdpo_loss
```

```python
if self.args.distillation_weight <= 0.0:
    return super()._compute_loss(model, inputs)
```
Distillation-only mode falls back to policy loss unexpectedly
Medium Severity
When sdpo_policy_loss_mode is "distillation_only" and distillation_weight <= 0.0, the code falls through to super()._compute_loss(model, inputs), which returns the base online policy loss. This contradicts the "distillation_only" mode semantics — the trainer silently switches to a pure policy-gradient objective when distillation weight is zero, instead of returning a zero loss or raising an error.
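A minimal sketch of the guard this implies, as a hypothetical dispatch helper (the helper name is illustrative; the mode and field names mirror the config described above):

```python
def resolve_loss_mode(policy_loss_mode: str, distillation_weight: float) -> str:
    """Hypothetical dispatcher: fail fast instead of silently switching
    objectives when distillation is disabled in distillation-only mode."""
    if policy_loss_mode == "distillation_only":
        if distillation_weight <= 0.0:
            raise ValueError(
                "sdpo_policy_loss_mode='distillation_only' requires "
                f"distillation_weight > 0; got {distillation_weight}"
            )
        return "distillation"
    # Hybrid mode may legitimately fall back to the pure policy loss.
    if distillation_weight <= 0.0:
        return "policy"
    return "hybrid"
```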
```python
if has_solution or use_feedback:
    self_distillation_mask[i] = 1.0
if has_solution:
    num_with_solution += 1
```
Successful rollout count computed in a redundant global loop
Medium Severity
The build method in SuccessfulRolloutTeacherContextBuilder iterates over ALL total_samples (global) in the first loop to build self_distillation_mask and metrics, then iterates over only the local process slice in a second loop to build teacher messages. In multi-GPU settings, the first loop sets self_distillation_mask for all global indices, but only local_self_distillation_mask = self_distillation_mask[process_slice] is returned. The success_group_count metric counts groups across all processes, but uses incomplete local data for all_prompts gathered via gather_object, which may not align with global reward indices on each process.
Additional Locations (1)
```python
per_token_loss = self._compute_divergence(
    topk_student_log_probs, topk_teacher_log_probs, self.args.distillation_alpha
)
```
Duplicated top-k distillation logic in two branches
Medium Severity
The top-k logit distillation code block is nearly identically duplicated in _compute_self_distillation_loss. The block under full_logit_distillation=True with distillation_topk is not None (lines ~209–227) is functionally identical to the block under not full_logit_distillation with distillation_topk is not None and _allow_topk_without_full_logit_distillation() (lines ~234–252). Both compute logsumexp, gather top-k indices, handle distillation_add_tail vs renorm, and call _compute_divergence. This duplication increases the risk of a future fix being applied to one branch but not the other.
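One way to de-duplicate, sketched below: extract the shared tail/renorm handling into a helper both branches call. The function name and `add_tail` flag are illustrative (mirroring `distillation_add_tail`):

```python
import torch


def gather_topk_support(log_probs, topk_indices, add_tail=True):
    """Restrict full log-probs to a top-k support, shared by both branches.

    add_tail=True appends a bucket holding the leftover probability mass;
    add_tail=False renormalizes over the k selected tokens.
    """
    topk = log_probs.gather(-1, topk_indices)
    if add_tail:
        # log(1 - sum(exp(topk))): probability mass outside the top-k support
        tail_mass = (1.0 - topk.exp().sum(-1, keepdim=True)).clamp_min(1e-8)
        return torch.cat([topk, tail_mass.log()], dim=-1)
    # Renormalize over the selected support only
    return topk - torch.logsumexp(topk, dim=-1, keepdim=True)
```

Both branches would then differ only in how they obtain `log_probs` and `topk_indices`, so a future fix lands in one place.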
```python
self._buffered_inputs = None
self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
self.prompt_tokenizer = PromptTokenizer(self)
self.teacher_context_builder = DemonstrationTeacherContextBuilder(self)
```
SDFTTrainer silently ignores chat_template_kwargs from config
Medium Severity
SDFTTrainer.__init__ never assigns self.chat_template_kwargs, but PromptTokenizer.apply_prompt_template reads it via getattr(self.trainer, "chat_template_kwargs", {}). This means any chat_template_kwargs specified in the SDFTConfig are silently ignored during prompt template application, teacher prompt construction, and generation. BaseSelfDistillationTrainer correctly sets self.chat_template_kwargs = args.chat_template_kwargs or {}, but SDFTTrainer (which doesn't inherit from it) omits this.
Additional Locations (1)
```python
content = last_message.get("content", "")
if isinstance(content, list):
    return " ".join(part.get("text", "") for part in content if part.get("type") == "text")
return content
```
Duplicated _extract_last_user_text in two builder classes
Low Severity
_extract_last_user_text is identically implemented in both DemonstrationTeacherContextBuilder and SuccessfulRolloutTeacherContextBuilder. This duplication increases maintenance burden — a future bug fix to one copy could easily miss the other.
Additional Locations (1)
```python
)

per_token_loss = per_token_loss * response_mask
loss = self._aggregate_self_distillation_loss(per_token_loss, response_mask)
```
Distillation loss double-applies response mask before aggregation
Low Severity
per_token_loss is multiplied by response_mask on line 282, then _aggregate_self_distillation_loss multiplies by response_mask again internally (e.g., (per_token_loss * response_mask).sum()). The double masking is functionally harmless since masked positions are already zero, but it's confusing and suggests the aggregation helper's interface contract is unclear — callers must decide whether to pre-mask or let the aggregator handle it, not both.
Additional Locations (1)
```python
per_token_loss = self._compute_divergence(
    topk_student_log_probs, topk_teacher_log_probs, self.args.distillation_alpha
)
```
Top-k distillation code block duplicated across branches
Low Severity
The top-k logit distillation logic (computing student_logsumexp, topk_indices, gathering teacher logits, tail/renorm, and calling _compute_divergence) is copy-pasted almost identically between the full_logit_distillation=True branch (lines 209-227) and the elif distillation_topk is not None branch (lines 234-252). This redundancy increases maintenance burden and risks inconsistent fixes if either branch is modified.
```python
model = kwargs["model"]
if self.accelerator is not None:
    model = self.accelerator.unwrap_model(model)
self.sync_target_model(model, self.ref_model, self.update_rate)
```
EMA teacher sync uses inverted update rate
High Severity
EMATeacherSyncCallback passes self.update_rate (default 0.05) directly as alpha to sync_target_model, where alpha controls the retention of the target model (target = alpha * target + (1-alpha) * source). With update_rate=0.05, the teacher retains only 5% of its weights each step, becoming a near-copy of the student. The intended EMA behavior (teacher slowly tracks the student) requires passing 1.0 - self.update_rate as alpha, so the teacher retains 95% and incorporates 5% of the student.
Additional Locations (1)
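The intended update can be sketched as a minimal standalone version (the real `sync_target_model` signature may differ; the point is only which factor the teacher retains):

```python
import torch


@torch.no_grad()
def ema_sync(teacher, student, update_rate=0.05):
    """EMA teacher update: the teacher retains (1 - update_rate) of its own
    weights and mixes in update_rate of the student. Equivalently, pass
    alpha = 1.0 - update_rate to a sync_target_model-style helper whose
    alpha is the retention factor."""
    for t, s in zip(teacher.parameters(), student.parameters()):
        t.mul_(1.0 - update_rate).add_(s, alpha=update_rate)
```

With `update_rate=0.05` the teacher keeps 95% of its weights per sync and slowly tracks the student, rather than becoming a near-copy of it.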
```python
}
if old_per_token_logps is not None:
    output["old_per_token_logps"] = old_per_token_logps
```
Globally gathered rewards tensor split with local tensors
Low Severity
The rewards tensor in the output dict of _generate_and_score_completions has global size (gathered across all processes), while all other tensors (prompt_ids, completion_ids, advantages, etc.) are local. When _prepare_inputs passes this dict to split_tensor_dict, the rewards tensor is split into chunks with dimensions inconsistent with the local tensors. Currently harmless since rewards isn't accessed after splitting, but this is a latent bug that would surface if any downstream code tries to use rewards from the split inputs in multi-GPU settings.
Additional Locations (1)
```python
"teacher_input_ids": teacher_input_ids,
"teacher_attention_mask": teacher_attention_mask,
"self_distillation_mask": local_self_distillation_mask,
}
```
Multi-GPU index mismatch causes crash in teacher context builder
High Severity
In SuccessfulRolloutTeacherContextBuilder.build(), the rewards tensor passed in is already LOCAL (sliced per-process in OnlineRolloutMixin._generate_and_score_completions), so total_samples = rewards.shape[0] is the local batch size. Helper arrays like has_solution_flags, successful_demo_indices, use_feedback_flags, and self_distillation_mask are sized by this local count. However, the second loop iterates using global_idx in range(process_start, process_start + num_local), and on any GPU with process_index > 0, global_idx exceeds the local array bounds, causing an IndexError. Similarly, self_distillation_mask[process_slice] slices out of bounds on non-zero ranks.
Additional Locations (1)
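A sketch of the indexing fix, assuming `rewards` arrives pre-sliced per process (function and variable names are illustrative, not the PR's):

```python
import torch


def build_local_distillation_mask(local_rewards, success_threshold=0.0):
    """All helper arrays are sized and indexed by the LOCAL batch size.
    No process_start offset is applied, because `local_rewards` is already
    this rank's slice of the global batch."""
    num_local = local_rewards.shape[0]
    mask = torch.zeros(num_local)
    # Iterate local indices directly,
    # not range(process_start, process_start + num_local)
    for local_idx in range(num_local):
        if local_rewards[local_idx] > success_threshold:
            mask[local_idx] = 1.0
    return mask  # already local: no mask[process_slice] needed
```

On ranks with `process_index > 0`, the original global-offset loop would index past the end of these locally sized arrays; keeping every index local removes both the IndexError and the out-of-bounds slice.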
```python
next-token predictions back into the policy.
"""

config_cls = SDPOConfig
```
SDPOTrainer missing _name causes wrong metric prefix
Low Severity
SDPOTrainer does not override _name from BaseSelfDistillationTrainer, which defaults to "SelfDistillation". The inherited _log_self_distillation_metric uses _name.lower() to build a metric prefix, so SDPO distillation metrics get logged under selfdistillation/distillation_loss instead of the expected sdpo/distillation_loss. Compare with SDFTTrainer which correctly sets _name = "SDFT" and overrides the logging method. The _tag_names attribute is also not overridden, so hub model cards would show "self-distillation" instead of "sdpo".
Additional Locations (1)
```python
feedback_template: str = field(
    default="\nThe following is feedback from your unsuccessful earlier attempt:\n\n{feedback_raw}\n\n",
    metadata={"help": "Template for formatting environment feedback for reprompting."},
)
```
Template formatting breaks with curly braces in content
Medium Severity
The reprompt_template, solution_template, and feedback_template in SDPOConfig all use Python's str.format() for interpolation. If model completions, environment feedback, or privileged context contain literal curly braces (common in code generation tasks like the paper's LiveCodeBench benchmark), .format() will raise a KeyError or ValueError. The SDFT teacher_prompt_template has the same issue. Using a safer interpolation method would prevent runtime crashes on code-containing content.
Additional Locations (1)
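The failure mode and a replace-based alternative can be sketched as follows. The staging step is hypothetical; the crash occurs whenever completion text containing braces ends up inside a string that is later passed through `.format()`:

```python
def render(template: str, **fields) -> str:
    """Brace-safe interpolation: substitutes {key} placeholders literally,
    so braces inside the substituted values are never parsed as fields."""
    for key, value in fields.items():
        template = template.replace("{" + key + "}", str(value))
    return template


reprompt = "Earlier attempt:\n{solution_raw}\nFeedback:\n{feedback_raw}\n"
solution = 'return {"answer": 42}'  # completion containing literal braces

# If the completion is spliced into the template first and .format() runs
# afterward, its braces are parsed as format fields and raise KeyError:
staged = reprompt.replace("{solution_raw}", solution)
try:
    staged.format(feedback_raw="too slow")
    crashed = False
except (KeyError, ValueError, IndexError):
    crashed = True

# The replace-based renderer handles the same content without crashing:
safe = render(reprompt, solution_raw=solution, feedback_raw="too slow")
```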
Cursor Bugbot has reviewed your changes and found 4 potential issues.
There are 8 total unresolved issues (including 4 from previous reviews).
```python
return super()._compute_loss(model, inputs)

sdpo_loss = self._compute_self_distillation_loss(model, inputs) / accumulation_scale
return self.args.distillation_weight * sdpo_loss
```
Policy-loss metric logged with gradient-accumulation scaling applied
High Severity
In hybrid mode, SDPOTrainer._compute_loss calls super()._compute_loss(model, inputs) (OnlineRolloutMixin._compute_loss), which already divides the policy loss by accumulation_scale, so base_policy_loss arrives pre-scaled. The distillation term is likewise divided by accumulation_scale exactly once, and compute_loss does not rescale the returned sum, so the combined loss itself is scaled correctly; the same holds for the default distillation_only path. The real issue is that OnlineRolloutMixin._compute_loss logs the already-scaled policy loss as self_distillation/policy_loss, making the logged metric depend on gradient_accumulation_steps rather than reflecting the true loss magnitude.
Additional Locations (1)
```python
    )
)

self.model_accepts_loss_kwargs = False
```
SDFTTrainer missing _set_signature_columns_if_needed override
Medium Severity
BaseSelfDistillationTrainer overrides _set_signature_columns_if_needed to preserve the prompt and privileged_context columns in the dataset. SDFTTrainer does not inherit from BaseSelfDistillationTrainer — it inherits from SelfDistillationMixin and _BaseTrainer directly. Since SelfDistillationMixin does not define _set_signature_columns_if_needed, the default Trainer implementation will be used, which may strip the privileged_context column from the dataset before training.
Additional Locations (1)
```python
callbacks=callbacks,
optimizers=optimizers,
compute_loss_func="non-None value to disable scaling",
)
```
SDFTTrainer missing _diagnostic_counters attribute initialization
Low Severity
SDFTTrainer inherits SelfDistillationMixin which uses self._diagnostic_counters in _warn_on_degenerate_diagnostics (via OnlineRolloutMixin), but SDFTTrainer.__init__ never initializes _diagnostic_counters. The BaseSelfDistillationTrainer initializes it at line 139-142, but SDFTTrainer doesn't inherit from that class. If any diagnostic warning path is triggered, an AttributeError would occur.
```python
per_token_loss = self._compute_divergence(
    topk_student_log_probs, topk_teacher_log_probs, self.args.distillation_alpha
)
```
Top-k distillation selects indices from student not teacher
Medium Severity
In top-k distillation, the top-k indices are selected from the student logits (torch.topk(student_logits, ...)), and the same indices are gathered from teacher logits. This means the distillation focuses on tokens the student already assigns high probability to, rather than tokens the teacher considers important. For self-distillation where the teacher is conditioned on privileged context, the teacher's top tokens may differ significantly from the student's, causing the loss to miss the most informative signal from the teacher distribution.
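The review's suggestion, sketched as a standalone helper (names mirror the snippet; taking the union of student and teacher top-k supports would be another option):

```python
import torch


def topk_from_teacher(student_logits, teacher_logits, k):
    """Select the support from the TEACHER's top-k tokens, then evaluate
    both distributions on those same indices for the divergence, so the
    loss covers the tokens the teacher considers most probable."""
    topk_indices = torch.topk(teacher_logits, k, dim=-1).indices
    teacher_lse = torch.logsumexp(teacher_logits, dim=-1, keepdim=True)
    student_lse = torch.logsumexp(student_logits, dim=-1, keepdim=True)
    topk_teacher_log_probs = teacher_logits.gather(-1, topk_indices) - teacher_lse
    topk_student_log_probs = student_logits.gather(-1, topk_indices) - student_lse
    return topk_student_log_probs, topk_teacher_log_probs
```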
In the current TRL implementation:

- SDFT uses an explicit `ref_model` teacher
is ref_model the teacher? or are these two separate models?
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>


Implements SDPO algorithm from arxiv.org/abs/2601.20802. SDPO augments on-policy optimization with self-distillation from the model's own high-reward trajectories, converting tokenized feedback into a dense learning signal.
Fixes #4929
Note
Medium Risk
Adds new experimental training algorithms that change how rollouts, rewards, and loss are computed, including new EMA teacher synchronization paths; failures could affect training stability/metrics but are largely isolated to experimental APIs.
Overview
Introduces a new experimental self-distillation training stack (`SelfDistillationConfig`, `SelfDistillationMixin`, `OnlineRolloutMixin`, `BaseSelfDistillationTrainer`) that buffers on-policy generations, scores them with reward functions/models, and computes an optional self-distillation loss (token-level or full-logit, with top-k and importance-sampling clipping).

Adds two new experimental trainers: `SDPOTrainer`/`SDPOConfig` for Self-Distillation Policy Optimization (reprompting a teacher context from successful rollouts and optional `privileged_context`, with EMA teacher regularization and diagnostics/warnings), and `SDFTTrainer`/`SDFTConfig` for on-policy self-distilled fine-tuning using an explicit teacher prompt built from `prompt` + `privileged_context` (optionally generating from the teacher-conditioned prompt and skipping initial loss tokens).

Extends the public API with `PEFTAdapterEMACallback` (an EMA "teacher" adapter for PEFT runs) and adds example scripts, documentation pages, paper index entries, and comprehensive tests covering training flows, callback hooks, PEFT/EMA behavior, and diagnostic edge cases.

Written by Cursor Bugbot for commit 3ffcb16. This will update automatically on new commits.