
Add SDPO (Self-Distillation Policy Optimization) trainer #4935

Open
MengAiDev wants to merge 64 commits into huggingface:main from MengAiDev:4929

Conversation

@MengAiDev
Contributor

@MengAiDev MengAiDev commented Jan 30, 2026

Implements SDPO algorithm from arxiv.org/abs/2601.20802. SDPO augments on-policy optimization with self-distillation from the model's own high-reward trajectories, converting tokenized feedback into a dense learning signal.

  • Add SDPOConfig with distillation parameters (alpha, topk, ema_update_rate, etc.)
  • Add SDPOTrainer extending GRPOTrainer with self-distillation loss
  • Add comprehensive tests for SDPOConfig and SDPOTrainer
  • Add example script demonstrating SDPO usage

Fixes #4929


Note

Medium Risk
Adds new experimental training algorithms that change how rollouts, rewards, and loss are computed, including new EMA teacher synchronization paths; failures could affect training stability/metrics but are largely isolated to experimental APIs.

Overview
Introduces a new experimental self-distillation training stack (SelfDistillationConfig, SelfDistillationMixin, OnlineRolloutMixin, BaseSelfDistillationTrainer) that buffers on-policy generations, scores them with reward functions/models, and computes an optional self-distillation loss (token-level or full-logit, with top-k and importance-sampling clipping).

Adds two new experimental trainers: SDPOTrainer/SDPOConfig for Self-Distillation Policy Optimization (reprompting a teacher context from successful rollouts and optional privileged_context, with EMA teacher regularization and diagnostics/warnings), and SDFTTrainer/SDFTConfig for on-policy self-distilled fine-tuning using an explicit teacher prompt built from prompt + privileged_context (optionally generating from the teacher-conditioned prompt and skipping initial loss tokens).

Extends the public API with PEFTAdapterEMACallback (EMA “teacher” adapter for PEFT runs) and adds example scripts, documentation pages, paper index entries, and comprehensive tests covering training flows, callback hooks, PEFT/EMA behavior, and diagnostic edge cases.

Written by Cursor Bugbot for commit 3ffcb16. This will update automatically on new commits.

MengAiDev and others added 9 commits January 30, 2026 10:04
Implements SDPO algorithm from arxiv.org/abs/2601.20802.
SDPO augments on-policy optimization with self-distillation from
the model's own high-reward trajectories, converting tokenized
feedback into a dense learning signal.

- Add SDPOConfig with distillation parameters (alpha, topk, ema_update_rate, etc.)
- Add SDPOTrainer extending GRPOTrainer with self-distillation loss
- Add comprehensive tests for SDPOConfig and SDPOTrainer
- Add example script demonstrating SDPO usage
@kashif
Collaborator

kashif commented Feb 2, 2026

@MengAiDev I have cleaned up the structure and docs and tests. Next we need to address the main TODOs regarding the teacher logits.

@kashif
Collaborator

kashif commented Feb 2, 2026

cc @jonhue here is a port of SDPO for TRL

@jonhue

jonhue commented Feb 2, 2026

@MengAiDev @kashif Thanks so much for implementing this!! Let's coordinate with @Shekswess and #4941. It might be cleanest to have one implementation for SDFT & SDPO ("self-distillation") since both are algorithmically the same and they differ only in whether data is offline or online.

@kashif
Collaborator

kashif commented Feb 2, 2026

Agree! Let's try that if it's OK for you @MengAiDev

@Shekswess

Wohoo!
This is really awesome, bravo legends @kashif @jonhue @MengAiDev. Maybe we should then also have an offline version of the trainer, so that folks who are GPU-poor (like me, hahahaha) can experiment with the approaches.

@LeonEricsson
Collaborator

Regarding the discussion on how to combine SDFT/SDPO PRs:

This PR inherits from GRPOTrainer, while the SDFT PR modifies it in place. Both approaches carry baggage from GRPOTrainer that isn’t necessarily applicable to SDPO/SDFT — but this also provides a nice playground for experimentation.

The tradeoff with inheritance is less control, but I like how it nicely isolates SDPO’s key contributions and exposes relevant hparams clearly. If future research demands more flexibility, we can revisit and consider breaking out SDPO into its own trainer.

If we proceed with this PR’s approach, extending it to cover the offline case should, at first glance, just require modifying the _build_teacher_inputs function.

@qgallouedec
Member

That's a good point, Leon. I need to review the PR carefully, but in general I'd rather isolate first and abstract later, if needed (abstractions are easy to do, hard to undo).

@Shekswess

@qgallouedec @LeonEricsson if you look at my implementation of the offline SDFT in #4941 (comment), I think it can be improved a lot. I tried to follow the official code from the authors with small modifications; feel free to ping us on how we can make this better. Can't wait to start experimenting!

@niksdagr8

Any progress on this is much appreciated

)
kl_teacher = F.kl_div(mixture, teacher_log_probs, reduction="none", log_target=True)
kl_student = F.kl_div(mixture, student_log_probs, reduction="none", log_target=True)
kl = torch.lerp(kl_student, kl_teacher, alpha)

JSD divergence uses wrong interpolation direction via lerp

Medium Severity

In _compute_divergence, the JSD branch computes kl_teacher = F.kl_div(mixture, teacher_log_probs, ...) which gives KL(teacher || mixture), and kl_student = F.kl_div(mixture, student_log_probs, ...) which gives KL(student || mixture). Then torch.lerp(kl_student, kl_teacher, alpha) computes (1-alpha)*kl_student + alpha*kl_teacher. For alpha=0.5 (symmetric JSD) this is correct, but for asymmetric alpha values the weighting is inverted relative to the standard skew-JSD definition where alpha weights the teacher distribution in the mixture. The mixture is built with (1-alpha)*student + alpha*teacher, so the divergence terms are weighted backwards — the teacher KL should be weighted by (1-alpha) and the student KL by alpha.
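The weighting issue can be made concrete with a small pure-Python sketch (a hypothetical `skew_jsd` helper, not the PR's code), following the convention described above in which alpha weights the teacher in the mixture:

```python
import math

def skew_jsd(student, teacher, alpha):
    """Skew-JSD with mixture m = (1 - alpha) * student + alpha * teacher.

    Convention assumed here (matching the review comment): alpha weights the
    teacher in the mixture, so KL(student || m) carries weight alpha and
    KL(teacher || m) carries weight (1 - alpha).
    """
    def kl(p, q):
        # KL(p || q) over discrete distributions, skipping zero-mass entries
        return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

    m = [(1 - alpha) * s + alpha * t for s, t in zip(student, teacher)]
    return alpha * kl(student, m) + (1 - alpha) * kl(teacher, m)
```

At alpha=0.5 both weightings coincide (symmetric JSD), which is why the bug only surfaces for asymmetric alpha values.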


    return base_policy_loss + self.args.distillation_weight * sdpo_loss

if self.args.distillation_weight <= 0.0:
    return super()._compute_loss(model, inputs)

Distillation-only mode falls back to policy loss unexpectedly

Medium Severity

When sdpo_policy_loss_mode is "distillation_only" and distillation_weight <= 0.0, the code falls through to super()._compute_loss(model, inputs), which returns the base online policy loss. This contradicts the "distillation_only" mode semantics — the trainer silently switches to a pure policy-gradient objective when distillation weight is zero, instead of returning a zero loss or raising an error.
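A minimal sketch of the fail-fast guard suggested here (hypothetical validator; the name and message are illustrative, not the PR's API):

```python
def validate_sdpo_loss_config(policy_loss_mode, distillation_weight):
    """Raise instead of silently falling back to the base policy loss."""
    if policy_loss_mode == "distillation_only" and distillation_weight <= 0.0:
        raise ValueError(
            "sdpo_policy_loss_mode='distillation_only' requires "
            f"distillation_weight > 0; got {distillation_weight}"
        )
```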


if has_solution or use_feedback:
    self_distillation_mask[i] = 1.0
    if has_solution:
        num_with_solution += 1

Successful rollout count double-computes for redundant global loop

Medium Severity

The build method in SuccessfulRolloutTeacherContextBuilder iterates over ALL total_samples (global) in the first loop to build self_distillation_mask and metrics, then iterates over only the local process slice in a second loop to build teacher messages. In multi-GPU settings, the first loop sets self_distillation_mask for all global indices, but only local_self_distillation_mask = self_distillation_mask[process_slice] is returned. The success_group_count metric counts groups across all processes, but uses incomplete local data for all_prompts gathered via gather_object, which may not align with global reward indices on each process.



per_token_loss = self._compute_divergence(
    topk_student_log_probs, topk_teacher_log_probs, self.args.distillation_alpha
)

Duplicated top-k distillation logic in two branches

Medium Severity

The top-k logit distillation code block is nearly identically duplicated in _compute_self_distillation_loss. The block under full_logit_distillation=True with distillation_topk is not None (lines ~209–227) is functionally identical to the block under not full_logit_distillation with distillation_topk is not None and _allow_topk_without_full_logit_distillation() (lines ~234–252). Both compute logsumexp, gather top-k indices, handle distillation_add_tail vs renorm, and call _compute_divergence. This duplication increases the risk of a future fix being applied to one branch but not the other.


self._buffered_inputs = None
self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
self.prompt_tokenizer = PromptTokenizer(self)
self.teacher_context_builder = DemonstrationTeacherContextBuilder(self)

SDFTTrainer silently ignores chat_template_kwargs from config

Medium Severity

SDFTTrainer.__init__ never assigns self.chat_template_kwargs, but PromptTokenizer.apply_prompt_template reads it via getattr(self.trainer, "chat_template_kwargs", {}). This means any chat_template_kwargs specified in the SDFTConfig are silently ignored during prompt template application, teacher prompt construction, and generation. BaseSelfDistillationTrainer correctly sets self.chat_template_kwargs = args.chat_template_kwargs or {}, but SDFTTrainer (which doesn't inherit from it) omits this.


content = last_message.get("content", "")
if isinstance(content, list):
    return " ".join(part.get("text", "") for part in content if part.get("type") == "text")
return content

Duplicated _extract_last_user_text in two builder classes

Low Severity

_extract_last_user_text is identically implemented in both DemonstrationTeacherContextBuilder and SuccessfulRolloutTeacherContextBuilder. This duplication increases maintenance burden — a future bug fix to one copy could easily miss the other.
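One way to deduplicate is a single module-level helper that both builders import, sketched here from the snippet above (illustrative placement, not the PR's final layout):

```python
def extract_last_user_text(messages):
    """Return the text of the last message, flattening multimodal content lists.

    Shared by both teacher-context builders so a future fix lands in one place.
    """
    content = messages[-1].get("content", "")
    if isinstance(content, list):
        return " ".join(
            part.get("text", "") for part in content if part.get("type") == "text"
        )
    return content
```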


)

per_token_loss = per_token_loss * response_mask
loss = self._aggregate_self_distillation_loss(per_token_loss, response_mask)

Distillation loss double-applies response mask before aggregation

Low Severity

per_token_loss is multiplied by response_mask on line 282, then _aggregate_self_distillation_loss multiplies by response_mask again internally (e.g., (per_token_loss * response_mask).sum()). The double masking is functionally harmless since masked positions are already zero, but it's confusing and suggests the aggregation helper's interface contract is unclear — callers must decide whether to pre-mask or let the aggregator handle it, not both.
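A sketch of the clearer contract: the aggregator owns the masking, and callers pass the raw per-token loss (hypothetical helper over flat lists, not the PR's exact signature):

```python
def aggregate_masked_loss(per_token_loss, response_mask):
    """Mean loss over unmasked tokens; masking happens here and only here."""
    total = sum(l * m for l, m in zip(per_token_loss, response_mask))
    num_tokens = sum(response_mask)
    return total / max(num_tokens, 1.0)
```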



per_token_loss = self._compute_divergence(
    topk_student_log_probs, topk_teacher_log_probs, self.args.distillation_alpha
)

Top-k distillation code block duplicated across branches

Low Severity

The top-k logit distillation logic (computing student_logsumexp, topk_indices, gathering teacher logits, tail/renorm, and calling _compute_divergence) is copy-pasted almost identically between the full_logit_distillation=True branch (lines 209-227) and the elif distillation_topk is not None branch (lines 234-252). This redundancy increases maintenance burden and risks inconsistent fixes if either branch is modified.


model = kwargs["model"]
if self.accelerator is not None:
    model = self.accelerator.unwrap_model(model)
self.sync_target_model(model, self.ref_model, self.update_rate)

EMA teacher sync uses inverted update rate

High Severity

EMATeacherSyncCallback passes self.update_rate (default 0.05) directly as alpha to sync_target_model, where alpha controls the retention of the target model (target = alpha * target + (1-alpha) * source). With update_rate=0.05, the teacher retains only 5% of its weights each step, becoming a near-copy of the student. The intended EMA behavior (teacher slowly tracks the student) requires passing 1.0 - self.update_rate as alpha, so the teacher retains 95% and incorporates 5% of the student.
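The intended semantics can be sketched in plain Python (a hypothetical `ema_update` over flat weight lists; the real code operates on model parameters):

```python
def ema_update(teacher_weights, student_weights, update_rate):
    """teacher <- (1 - update_rate) * teacher + update_rate * student.

    With update_rate=0.05 the teacher retains 95% of its weights and absorbs
    5% of the student each sync, i.e. it slowly tracks the student. If the
    sync helper's alpha argument means the *retention* factor, the callback
    must pass 1.0 - update_rate, not update_rate.
    """
    return [
        (1.0 - update_rate) * t + update_rate * s
        for t, s in zip(teacher_weights, student_weights)
    ]
```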


}
if old_per_token_logps is not None:
    output["old_per_token_logps"] = old_per_token_logps


Globally gathered rewards tensor split with local tensors

Low Severity

The rewards tensor in the output dict of _generate_and_score_completions has global size (gathered across all processes), while all other tensors (prompt_ids, completion_ids, advantages, etc.) are local. When _prepare_inputs passes this dict to split_tensor_dict, the rewards tensor is split into chunks with dimensions inconsistent with the local tensors. Currently harmless since rewards isn't accessed after splitting, but this is a latent bug that would surface if any downstream code tries to use rewards from the split inputs in multi-GPU settings.


"teacher_input_ids": teacher_input_ids,
"teacher_attention_mask": teacher_attention_mask,
"self_distillation_mask": local_self_distillation_mask,
}

Multi-GPU index mismatch causes crash in teacher context builder

High Severity

In SuccessfulRolloutTeacherContextBuilder.build(), the rewards tensor passed in is already LOCAL (sliced per-process in OnlineRolloutMixin._generate_and_score_completions), so total_samples = rewards.shape[0] is the local batch size. Helper arrays like has_solution_flags, successful_demo_indices, use_feedback_flags, and self_distillation_mask are sized by this local count. However, the second loop iterates using global_idx in range(process_start, process_start + num_local), and on any GPU with process_index > 0, global_idx exceeds the local array bounds, causing an IndexError. Similarly, self_distillation_mask[process_slice] slices out of bounds on non-zero ranks.
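The fix direction described here, in a minimal pure-Python sketch (hypothetical names): index local arrays only with local indices, and derive the global index solely when a global identifier is actually needed:

```python
def build_local_mask(local_rewards, process_index, success_threshold=0.0):
    """All per-sample arrays here are LOCAL; never index them globally."""
    num_local = len(local_rewards)
    mask = [0.0] * num_local
    for local_idx in range(num_local):
        if local_rewards[local_idx] > success_threshold:
            mask[local_idx] = 1.0
        # global id, e.g. for logging; assumes equally sized per-process shards
        _global_idx = process_index * num_local + local_idx
    return mask
```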


next-token predictions back into the policy.
"""

config_cls = SDPOConfig

SDPOTrainer missing _name causes wrong metric prefix

Low Severity

SDPOTrainer does not override _name from BaseSelfDistillationTrainer, which defaults to "SelfDistillation". The inherited _log_self_distillation_metric uses _name.lower() to build a metric prefix, so SDPO distillation metrics get logged under selfdistillation/distillation_loss instead of the expected sdpo/distillation_loss. Compare with SDFTTrainer which correctly sets _name = "SDFT" and overrides the logging method. The _tag_names attribute is also not overridden, so hub model cards would show "self-distillation" instead of "sdpo".


feedback_template: str = field(
    default="\nThe following is feedback from your unsuccessful earlier attempt:\n\n{feedback_raw}\n\n",
    metadata={"help": "Template for formatting environment feedback for reprompting."},
)

Template formatting breaks with curly braces in content

Medium Severity

The reprompt_template, solution_template, and feedback_template in SDPOConfig all use Python's str.format() for interpolation. If model completions, environment feedback, or privileged context contain literal curly braces (common in code generation tasks like the paper's LiveCodeBench benchmark), .format() will raise a KeyError or ValueError. The SDFT teacher_prompt_template has the same issue. Using a safer interpolation method would prevent runtime crashes on code-containing content.
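A sketch of the failure mode and a stdlib alternative. The crash surfaces when an already-formatted string is formatted again with brace-containing content; `string.Template` with `$`-placeholders leaves literal braces untouched (the `render` helper is illustrative, not the PR's API):

```python
from string import Template

def render(template, **fields):
    """$-style substitution: literal {braces} in content are left alone."""
    return Template(template).safe_substitute(**fields)

feedback = 'Wrap it like f"{x}" in your code.'

# With str.format, a second formatting pass over already-inserted content
# re-interprets the braces and raises KeyError('x'):
stage1 = "Feedback:\n{feedback_raw}\n".format(feedback_raw=feedback)
try:
    stage1.format(solution="...")
except KeyError:
    pass  # crashes on code-containing content

# $-templates survive repeated rendering without touching literal braces:
out = render("Feedback:\n$feedback_raw\n", feedback_raw=feedback)
out = render(out, solution="...")  # no placeholders left, no crash
```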



@cursor cursor bot left a comment


Cursor Bugbot has reviewed your changes and found 4 potential issues.

There are 8 total unresolved issues (including 4 from previous reviews).


    return super()._compute_loss(model, inputs)

sdpo_loss = self._compute_self_distillation_loss(model, inputs) / accumulation_scale
return self.args.distillation_weight * sdpo_loss

Double gradient accumulation scaling in hybrid SDPO loss

High Severity

In hybrid mode, SDPOTrainer._compute_loss calls super()._compute_loss(model, inputs), i.e. OnlineRolloutMixin._compute_loss, which already divides the policy loss by accumulation_scale; the distillation term is divided by accumulation_scale as well, so the combined loss ends up scaled consistently and no double scaling occurs in either mode. The real issue is that OnlineRolloutMixin._compute_loss logs the already-scaled policy loss as self_distillation/policy_loss, so the logged metric depends on gradient_accumulation_steps rather than reflecting the true loss magnitude.
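The suggested separation, as a small sketch (hypothetical helper with illustrative names): scale the combined objective once for gradient accumulation, and log the unscaled terms:

```python
def hybrid_loss(policy_loss, distill_loss, distillation_weight, accumulation_scale):
    """Return (loss_for_backward, metrics): scaling applied once, logs unscaled."""
    combined = policy_loss + distillation_weight * distill_loss
    metrics = {
        # logged values are independent of gradient_accumulation_steps
        "self_distillation/policy_loss": policy_loss,
        "self_distillation/distillation_loss": distill_loss,
    }
    return combined / accumulation_scale, metrics
```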


)
)

self.model_accepts_loss_kwargs = False

SDFTTrainer missing _set_signature_columns_if_needed override

Medium Severity

BaseSelfDistillationTrainer overrides _set_signature_columns_if_needed to preserve the prompt and privileged_context columns in the dataset. SDFTTrainer does not inherit from BaseSelfDistillationTrainer — it inherits from SelfDistillationMixin and _BaseTrainer directly. Since SelfDistillationMixin does not define _set_signature_columns_if_needed, the default Trainer implementation will be used, which may strip the privileged_context column from the dataset before training.


    callbacks=callbacks,
    optimizers=optimizers,
    compute_loss_func="non-None value to disable scaling",
)

SDFTTrainer missing _diagnostic_counters attribute initialization

Low Severity

SDFTTrainer inherits SelfDistillationMixin which uses self._diagnostic_counters in _warn_on_degenerate_diagnostics (via OnlineRolloutMixin), but SDFTTrainer.__init__ never initializes _diagnostic_counters. The BaseSelfDistillationTrainer initializes it at line 139-142, but SDFTTrainer doesn't inherit from that class. If any diagnostic warning path is triggered, an AttributeError would occur.



per_token_loss = self._compute_divergence(
    topk_student_log_probs, topk_teacher_log_probs, self.args.distillation_alpha
)

Top-k distillation selects indices from student not teacher

Medium Severity

In top-k distillation, the top-k indices are selected from the student logits (torch.topk(student_logits, ...)), and the same indices are gathered from teacher logits. This means the distillation focuses on tokens the student already assigns high probability to, rather than tokens the teacher considers important. For self-distillation where the teacher is conditioned on privileged context, the teacher's top tokens may differ significantly from the student's, causing the loss to miss the most informative signal from the teacher distribution.
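The alternative discussed here, as a small pure-Python sketch: pick the support from the teacher distribution and gather the student at the same indices (illustrative helper; real code would use torch.topk on the teacher logits):

```python
def teacher_topk_pairs(student_probs, teacher_probs, k):
    """Top-k indices taken from the TEACHER, so distillation covers the tokens
    the teacher (conditioned on privileged context) considers important."""
    idx = sorted(range(len(teacher_probs)), key=teacher_probs.__getitem__, reverse=True)[:k]
    return [(i, student_probs[i], teacher_probs[i]) for i in idx]
```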



In the current TRL implementation:

- SDFT uses an explicit `ref_model` teacher
Member


Is ref_model the teacher, or are these two separate models?

Collaborator


Yes, it's the teacher.

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>


Successfully merging this pull request may close these issues.

SDPO: Reinforcement Learning via Self-Distillation

8 participants