add multimodal executorch support #39832
mergennachin wants to merge 3 commits into huggingface:main
Conversation
This commit enhances the ExecuTorch integration to support multimodal models like Gemma-3, LLaVA, and other vision-language models. Key changes:
- Enhanced TorchExportableModuleWithHybridCache to support the inputs_embeds parameter and multimodal configs
- Added TorchExportableModuleForImageTextLM for image-text language models
- Added ImageEncoderExportableModule for vision encoders
- Added a test for multimodal functionality

This enables ExecuTorch export for vision-language models while maintaining backward compatibility with text-only models.
zucchini-nlp left a comment:
Hey, thanks a lot for the PR! I agree that we need to export the LM and vision backbones separately and handle input merging manually. Left a few comments; imo we should make sure different types of multimodal archs are exportable (i.e. expected inputs, config attribute names).
```python
if not hasattr(model.config, "text_config") or not hasattr(model.config.text_config, "use_cache") or model.config.text_config.use_cache is False:
    raise ValueError("The model must have caching enabled to be performant.")
```
model.get_text_config() is more reliable because the sub-config is not always called text_config. And since it's accessed a lot below, we can just save it as self.text_config = model.get_text_config().
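The reviewer's point can be sketched with dependency-free stand-ins (the classes below are hypothetical; in transformers the resolver is the real PretrainedConfig.get_text_config method):

```python
# Hedged sketch: why get_text_config() beats reaching for config.text_config.
# These stub classes are illustrative stand-ins, not transformers code.
class TextConfig:
    use_cache = True


class MultimodalConfig:
    # Some multimodal configs name the LM sub-config differently
    # (e.g. llm_config), so hasattr(config, "text_config") can fail.
    llm_config = TextConfig()

    def get_text_config(self):
        # transformers' PretrainedConfig.get_text_config resolves the
        # correct sub-config regardless of its attribute name.
        return self.llm_config


config = MultimodalConfig()
text_config = config.get_text_config()
assert text_config.use_cache  # cache check works without guessing attr names
```

With the resolved sub-config saved once (e.g. as self.text_config in the wrapper's __init__), the repeated hasattr chains in the original check become unnecessary.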
```python
# This is the same as sdpa, but mask creation does not use `vmap`, which is not exportable
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
self.model.model.config._attn_implementation = "sdpa_without_vmap"
```
Let's use the public API: model.set_attn_implementation("sdpa_without_vmap").
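A minimal sketch of the difference, using a stub that only mimics the effect of the public method (set_attn_implementation is the real transformers API; the classes here are stand-ins):

```python
# Stub model showing the public-API call the reviewer suggests, instead of
# writing to the private config._attn_implementation attribute directly.
class Config:
    _attn_implementation = "sdpa"


class Model:
    def __init__(self):
        self.config = Config()

    def set_attn_implementation(self, impl):
        # The real transformers method also validates the implementation name
        # and propagates it to sub-configs; this stub just records it.
        self.config._attn_implementation = impl


model = Model()
model.set_attn_implementation("sdpa_without_vmap")
```

Going through the public method keeps the export wrapper insulated from how the attribute is stored internally.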
```python
if hasattr(self.model, "base_model_prefix"):
    base = getattr(self.model, self.model.base_model_prefix, self.model)
    model_device = base.device
elif hasattr(self.model, "model"):
    model_device = self.model.model.device
else:
    model_device = "cpu"
    logging.warning(
        "TorchExportableModuleForImageTextLM.export Can't infer device from the model. Set to CPU by default."
    )
```
Hmm, I think model.device would be fine. The model here is the language backbone.
```python
    super().__init__()
    self.model = model

def forward(self, pixel_values):
```
Most models currently require extra inputs such as num_patches, image_attn_mask, etc.
```python
vision_outputs = self.model.vision_tower(pixel_values=pixel_values).last_hidden_state
image_features = self.model.multi_modal_projector(vision_outputs)
return image_features
```
I guess self.model is the multimodal model. We should use model.get_image_features(), which handles the pipeline correctly for the given model, because some models might need extra ops on top of this.
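A dependency-free sketch of why delegating to get_image_features() is safer than hand-wiring vision_tower + projector: the toy model below applies a model-specific extra step that the naive two-call pipeline would silently skip. All names mirror the PR, but the classes are hypothetical stand-ins:

```python
# Toy VLM whose feature pipeline includes a normalization step between the
# vision tower and the projector, standing in for model-specific extra ops.
class ToyVLM:
    def vision_tower(self, pixel_values):
        return [v * 2 for v in pixel_values]        # fake encoder

    def multi_modal_projector(self, hidden):
        return [h + 1 for h in hidden]              # fake projection

    def get_image_features(self, pixel_values):
        hidden = self.vision_tower(pixel_values)
        hidden = [h / max(hidden) for h in hidden]  # model-specific extra op
        return self.multi_modal_projector(hidden)


model = ToyVLM()
naive = model.multi_modal_projector(model.vision_tower([1.0, 2.0]))
correct = model.get_image_features([1.0, 2.0])
print(naive)    # [3.0, 5.0]
print(correct)  # [1.5, 2.0]
```

Because the two results differ, an export wrapper that re-implements the pipeline would produce wrong features for any model with such extra ops; delegating to the model's own method avoids that.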
```python
    return causal_mask


class TorchExportableModuleForImageTextLM(torch.nn.Module):
```
I feel like this is the same as TorchExportableModuleForDecoderOnlyLM, with the only difference being that the input model is multimodal. We could re-use TorchExportableModuleForDecoderOnlyLM and ask users to export the language backbone explicitly, like TorchExportableModuleForDecoderOnlyLM(model.language_model).
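The re-use idea can be sketched with plain-Python stand-ins (no torch/transformers needed; in the real integration the wrapper subclasses torch.nn.Module):

```python
# Hedged sketch of the reviewer's suggestion: one decoder-only wrapper serves
# both cases if users pass the language backbone of a multimodal model
# explicitly. All classes here are illustrative stand-ins.
class LanguageModel:
    def forward(self, inputs_embeds):
        return f"lm({inputs_embeds})"


class MultimodalModel:
    def __init__(self):
        self.language_model = LanguageModel()


class TorchExportableModuleForDecoderOnlyLM:
    def __init__(self, model):
        self.model = model

    def forward(self, inputs_embeds):
        # Text-only and multimodal backbones take the same code path,
        # so no separate TorchExportableModuleForImageTextLM is needed.
        return self.model.forward(inputs_embeds)


mm = MultimodalModel()
wrapper = TorchExportableModuleForDecoderOnlyLM(mm.language_model)
print(wrapper.forward("embeds"))  # lm(embeds)
```

Exporting only the language backbone keeps the vision encoder as a separate exported graph, which matches the split the reviewer asks for at the top of the review.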
Hey @zucchini-nlp, thanks a lot for the thoughtful reviews. @jackzhxng will take this over the finish line in #39836. I'm going to close this PR for the time being, but I hope @jackzhxng can incorporate some of your suggestions and recommendations in that PR.
New Class: TorchExportableModuleForImageTextLM — dedicated wrapper for image-text language models.
New Class: ImageEncoderExportableModule — wrapper for vision encoder components.