Support max_length in DPO VLM training #5284

Open
albertvillanova wants to merge 9 commits into huggingface:main from albertvillanova:fix-5283

Conversation

@albertvillanova (Member) commented on Mar 13, 2026

Support max_length in DPO VLM training.

  • Truncate sequence-aligned side-inputs (token_type_ids, mm_token_type_ids) with input_ids in DPO VLM training

Fix #5283.

This PR addresses a regression affecting vision-language model (VLM) training when using sequence truncation. The main fix ensures that auxiliary token fields (mm_token_type_ids and token_type_ids) are truncated in sync with input_ids, preventing shape mismatches and crashes during the model's forward pass. Additionally, a regression test is added to verify this behavior.

Changes

Bug fix for sequence truncation in VLMs:

  • Ensured that token_type_ids and mm_token_type_ids are truncated to match the length of input_ids in both the compute_ref_log_probs and _compute_loss methods of DPOTrainer, preventing shape mismatch errors during training (see the sketch after this list).
    • Note that pixel_values, image_grid_thw, image_sizes, and pixel_attention_mask are patch-level or image-level tensors and should not be truncated.
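
For illustration, a minimal sketch of the truncation sync described above (this is not the actual DPOTrainer code, which is organized around its own helpers; `seq_len` is a name introduced here):

```python
# Minimal sketch: after input_ids has been truncated to max_length, cut the
# sequence-aligned side tensors to the same length so the forward pass sees
# consistent shapes.
seq_len = input_ids.shape[1]
for key in ("token_type_ids", "mm_token_type_ids"):
    if key in inputs:
        inputs[key] = inputs[key][:, :seq_len]
# pixel_values, image_grid_thw, image_sizes, and pixel_attention_mask are
# patch- or image-level tensors, so they are deliberately left untouched.
```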

Testing improvements:

  • Added a regression test test_train_vlm_with_max_length in tests/test_dpo_trainer.py to verify that truncation with max_length does not crash the model and that image tokens are handled correctly.
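
A rough outline of what such a regression test could look like (the model and dataset identifiers below are placeholders, not necessarily the fixtures used in tests/test_dpo_trainer.py):

```python
from datasets import load_dataset
from transformers import AutoProcessor
from trl import DPOConfig, DPOTrainer

def test_train_vlm_with_max_length(tmp_path):
    # Placeholder tiny VLM checkpoint and preference dataset with an "images" column.
    model_id = "trl-internal-testing/tiny-vlm"                               # hypothetical checkpoint
    dataset = load_dataset("trl-internal-testing/zen-image", split="train")  # hypothetical dataset
    processor = AutoProcessor.from_pretrained(model_id)

    training_args = DPOConfig(
        output_dir=str(tmp_path),
        max_length=64,  # the setting that previously triggered the IndexError
        per_device_train_batch_size=2,
        max_steps=2,
        report_to="none",
    )
    trainer = DPOTrainer(
        model=model_id,
        args=training_args,
        processing_class=processor,
        train_dataset=dataset,
    )
    trainer.train()  # should complete without shape-mismatch errors
```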

Follow-up

If this approach is approved, I will implement a similar fix for other trainers.

Related

See related discussion in: #5279 (comment)


Note

Medium Risk
Touches core DPO training/inference paths by changing how batches are truncated and how token_type_ids/mm_token_type_ids are passed to the model; risk is mostly around unintended shape/flush behavior across truncation modes.

Overview
Fixes a VLM regression where enabling max_length could crash DPOTrainer due to sequence-aligned side tensors not being truncated with input_ids.

DPOTrainer._truncate_inputs now supports truncating additional sequence-aligned tensors (and includes them in flush_right/flush_left for keep_end), and both compute_ref_log_probs and _compute_loss now ensure token_type_ids/mm_token_type_ids match the truncated sequence length.
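
As a stand-alone illustration of that signature change (not the actual _truncate_inputs implementation; the keep_end branch glosses over the flush_right/flush_left padding handling mentioned above):

```python
def truncate_inputs_sketch(input_ids, attention_mask, completion_mask, *extra,
                           max_length, truncation_mode="keep_start"):
    # All arguments are (batch, seq_len) tensors aligned on the token dimension,
    # so they must all be sliced identically.
    tensors = (input_ids, attention_mask, completion_mask, *extra)
    if truncation_mode == "keep_start":
        # Keep the first max_length positions of every sequence-aligned tensor.
        return tuple(t[:, :max_length] for t in tensors)
    # keep_end: the real code first flushes padding to the right, slices the
    # trailing max_length positions, then flushes back to the left; here we
    # simply keep the last positions for brevity.
    return tuple(t[:, -max_length:] for t in tensors)
```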

Adds a test_train_vlm_with_max_length regression test to confirm VLM DPO training runs successfully with truncation enabled.

Written by Cursor Bugbot for commit 79d28c4.

@albertvillanova changed the title from "Truncate token_type_ids and mm_token_type_ids with input_ids in DPO VLM training" to "Support max_length in DPO VLM training" on Mar 13, 2026
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@cursor bot left a comment


Cursor Bugbot has reviewed your changes and found 1 potential issue.


```python
input_ids, attention_mask, completion_mask = self._truncate_inputs(input_ids, attention_mask, completion_mask)
# token_type_ids is sequence-length-aligned: truncate to match input_ids
# in keep_end mode, token_type_ids participates in flush_right/flush_left
extra = (inputs["token_type_ids"],) if "token_type_ids" in inputs else ()
```
Member

Why not have mm_token_type_ids in extra?

Member Author

The tensors in extra can be truncated with both "keep_start" and "keep_end", and I think it is semantically wrong to use "keep_end" on a VLM's mm_token_type_ids, but I'm addressing that in a follow-up PR:

So, let's treat mm_token_type_ids and token_type_ids symmetrically to stay internally consistent, and leave the semantic correction to the other PR.

Member

OK, I think you should align compute_ref_log_probs with _compute_loss, i.e. have mm_token_type_ids in extra in both cases.
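
For concreteness, a hedged sketch of that suggestion (assuming _truncate_inputs accepts and returns the extra tensors positionally, which the PR description implies but whose exact signature is not shown here):

```python
# Build "extra" identically in compute_ref_log_probs and _compute_loss so both
# sequence-aligned side tensors are truncated together with input_ids.
keys = [k for k in ("token_type_ids", "mm_token_type_ids") if k in inputs]
extra = tuple(inputs[k] for k in keys)
input_ids, attention_mask, completion_mask, *extra = self._truncate_inputs(
    input_ids, attention_mask, completion_mask, *extra
)
inputs.update(dict(zip(keys, extra)))
```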



Development

Successfully merging this pull request may close these issues.

DPOTrainer crashes when max_length is set with VLMs: IndexError
