Skip to content

Add Tensor.unfold op support for PyTorch frontend#2672

Merged
TobyRoseman merged 2 commits intoapple:mainfrom
jathinsn27:add-tensor-unfold-op
Apr 14, 2026
Merged

Add Tensor.unfold op support for PyTorch frontend#2672
TobyRoseman merged 2 commits intoapple:mainfrom
jathinsn27:add-tensor-unfold-op

Conversation

@jathinsn27
Copy link
Copy Markdown
Contributor

Summary

  • Add converter for aten::unfold (Tensor.unfold(dimension, size, step)), which was previously unsupported
  • Maps to MIL sliding_windows op with a transpose to match PyTorch's output layout

Fixes #2599

Test

  • Added TestTensorUnfold with 5 parametrized cases covering different axes, step sizes, negative indexing, and rank-4 vision input
  • Added test_chained_unfold for back-to-back unfold calls
  • All 24 test cases pass across all (compute_unit, backend, frontend) combinations
  • Verified tests fail without the fix (op 'unfold' not implemented) and pass with it

@TobyRoseman
Copy link
Copy Markdown
Collaborator

@jathinsn27
Copy link
Copy Markdown
Contributor Author

CI Run: https://gitlab.com/coremltools1/coremltools/-/pipelines/2450037346

Thanks for triggering the CI. The failures is from ExecuTorch frontend. Since aten.unfold is not in the Core ATen opset, so ExecuTorch's verifier rejects it before coremltools runs. I've pushed a fix to solve this issue. Could you retrigger CI when you get a chance?

@TobyRoseman
Copy link
Copy Markdown
Collaborator

Updated CI: https://gitlab.com/coremltools1/coremltools/-/pipelines/2452550699

@TobyRoseman TobyRoseman merged commit 8147ec1 into apple:main Apr 14, 2026
@TobyRoseman
Copy link
Copy Markdown
Collaborator

@jathinsn27 - thanks for your contribution.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[PyTorch Conversion] SmolVLM model fails due to unsupported 'unfold' op in Core ML

2 participants