Support max_length in DPO VLM training #5284
albertvillanova wants to merge 9 commits into huggingface:main from
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Cursor Bugbot has reviewed your changes and found 1 potential issue.
trl/trainer/dpo_trainer.py (outdated)

    input_ids, attention_mask, completion_mask = self._truncate_inputs(input_ids, attention_mask, completion_mask)
    # token_type_ids is sequence-length-aligned: truncate to match input_ids
    # in keep_end mode, token_type_ids participates in flush_right/flush_left
    extra = (inputs["token_type_ids"],) if "token_type_ids" in inputs else ()
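The idea behind `extra` can be sketched in isolation. Below is a hypothetical, list-based analogue of `_truncate_inputs` (the real method operates on batched tensors and additionally performs flush_right/flush_left for keep_end); the names and signature here are illustrative, not the trainer's actual API:

```python
# Hypothetical, simplified analogue of the truncation logic: every
# sequence-length-aligned field must be cut with the same mode as
# input_ids, or the model's forward pass sees mismatched shapes.
def truncate_inputs(input_ids, attention_mask, *extra, max_length, mode="keep_start"):
    seqs = (input_ids, attention_mask) + extra
    if mode == "keep_start":
        return tuple(s[:max_length] for s in seqs)  # drop the tail
    if mode == "keep_end":
        return tuple(s[-max_length:] for s in seqs)  # drop the head
    raise ValueError(f"unknown truncation mode: {mode!r}")
```

Passing token_type_ids (or mm_token_type_ids) through extra guarantees it stays the same length as input_ids in either mode.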
Why not have mm_token_type_ids in extra as well?
The tensors in extra can be truncated with both "keep_start" and "keep_end", and I think it is semantically wrong to use "keep_end" for VLM mm_token_type_ids, but I'm addressing that in a follow-up PR:
So, let's treat mm_token_type_ids and token_type_ids symmetrically to be internally consistent, and leave the semantic correction to the other PR.
OK. I think you should align compute_ref_log_probs with compute_loss, i.e., have mm_token_type_ids in extra in both cases.
Support max_length in DPO VLM training.
Fix #5283.
This PR addresses a regression affecting vision-language model (VLM) training when using sequence truncation. The main fix ensures that auxiliary token fields (mm_token_type_ids and token_type_ids) are truncated in sync with input_ids, preventing shape mismatches and crashes during the model's forward pass. Additionally, a regression test is added to verify this behavior.

Changes
Bug fix for sequence truncation in VLMs:
token_type_ids and mm_token_type_ids are truncated to match the length of input_ids in both the compute_ref_log_probs and _compute_loss methods of DPOTrainer, preventing shape mismatch errors during training.

Testing improvements:
Added test_train_vlm_with_max_length in tests/test_dpo_trainer.py to verify that truncation with max_length does not crash the model and that image tokens are handled correctly.

Follow-up
If this approach is approved, I will implement a similar fix for other trainers.
Related
See related discussion in: #5279 (comment)
Note
Medium Risk
Touches core DPO training/inference paths by changing how batches are truncated and how token_type_ids/mm_token_type_ids are passed to the model; risk is mostly around unintended shape/flush behavior across truncation modes.

Overview
Fixes a VLM regression where enabling max_length could crash DPOTrainer due to sequence-aligned side tensors not being truncated with input_ids. DPOTrainer._truncate_inputs now supports truncating additional sequence-aligned tensors (and includes them in flush_right/flush_left for keep_end), and both compute_ref_log_probs and _compute_loss now ensure token_type_ids/mm_token_type_ids match the truncated sequence length.
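The second half of the fix, making compute_ref_log_probs and _compute_loss hand the model length-matched auxiliary fields, can be sketched as follows; align_aux_fields is a hypothetical helper over plain lists, not the trainer's real code:

```python
# Hypothetical sketch: before the forward pass, slice any auxiliary
# sequence-aligned fields so they match the truncated input_ids length.
# The field names come from this PR; the helper itself is illustrative.
def align_aux_fields(inputs, seq_len):
    """Return model kwargs with aux fields cut to seq_len (keep_start-style slice)."""
    return {
        key: inputs[key][:seq_len]
        for key in ("token_type_ids", "mm_token_type_ids")
        if key in inputs
    }
```

Fields absent from the batch are simply skipped, so text-only models are unaffected.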
Adds a test_train_vlm_with_max_length regression test to confirm VLM DPO training runs successfully with truncation enabled.

Written by Cursor Bugbot for commit 79d28c4. This will update automatically on new commits.