Summary
OMEinsum.jl's key performance trick: before GEMM, fuse all dimensions within each group (lo/sum/ro/batch) into a single dimension, then permute the fused tensor (rank 3-4) instead of the original (rank 24). This eliminates most permutation overhead.
The trick (from OMEinsum.jl's `analyze_binary` + `tensorpermute!`)
OMEinsum's approach for a binary contraction:
1. Classify indices into 4 groups: lo (left-output), sum (contracted), ro (right-output), batch
2. Determine which groups are contiguous in the current memory layout
3. Group contiguous dims in the permutation: `tensorpermute!` detects consecutive dims in the permutation order and fuses them via `reshape` before calling `permutedims!`
4. Fuse groups after permutation: reshape to `(m, k, nb)` for A and `(k, n, nb)` for B — always rank 3
5. GEMM on the fused rank-3 tensors
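For concreteness, the index classification above can be sketched as follows (a language-neutral Python sketch; `classify_indices` is an illustrative name, not OMEinsum's API):

```python
# Sketch of the lo/sum/ro/batch classification for a binary contraction
# C = A * B, given the index labels of each tensor.
def classify_indices(a_ix, b_ix, c_ix):
    a, b, c = set(a_ix), set(b_ix), set(c_ix)
    lo    = [i for i in a_ix if i in c and i not in b]   # left-output
    sum_  = [i for i in a_ix if i in b and i not in c]   # contracted
    ro    = [i for i in b_ix if i in c and i not in a]   # right-output
    batch = [i for i in a_ix if i in b and i in c]       # shared by all three
    return lo, sum_, ro, batch

# batched matmul C[i,k,a] = sum_j A[i,j,a] * B[j,k,a]
print(classify_indices(list("ija"), list("jka"), list("ika")))
# → (['i'], ['j'], ['k'], ['a'])
```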
Step 3 is critical. From OMEinsum's `tensorpermute!` (simplified excerpt):

```julia
for i = 2:N
    if perm[i] == perm[i-1] + 1                  # consecutive dims in the permutation
        newshape[perm[i-1]] *= size(A, perm[i])  # fuse into the previous group
    else
        push!(newperm, perm[i])                  # start a new group
    end
end
A_ = reshape(A, newshape...)   # zero-cost reshape (metadata only)
permutedims!(C_, A_, newperm)  # permute at the reduced rank
```

Impact: step 408 example
For tensor B with 24 binary dims (all size 2, col-major) and natural labels:
- sum dims at positions 1-8: strides [1,2,4,...,128] → contiguous → fuse to size 256
- ro dims at positions 9-21: strides [256,512,...] → contiguous → fuse to size 8192
- batch dims at positions 22-24: strides [...] → contiguous → fuse to size 8
- After fusion: rank 3, shape [256, 8192, 8], already in [sum, ro, batch] order → NO permutation needed
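The fusion arithmetic above can be checked with a short sketch (Python for brevity; group boundaries are taken from the example, helper names are illustrative):

```python
# Verify the fused extents for 24 binary dims (all size 2, column-major).
shape = [2] * 24
strides = [1] * 24
for i in range(1, 24):
    strides[i] = strides[i - 1] * shape[i - 1]   # col-major: 1, 2, 4, ..., 2^23

def contiguous(dims):
    """A group fuses iff each stride is the previous stride times its extent."""
    return all(strides[d] == strides[d - 1] * shape[d - 1] for d in dims[1:])

def fused_size(dims):
    out = 1
    for d in dims:
        out *= shape[d]
    return out

groups = {
    "sum":   list(range(0, 8)),    # positions 1-8  (0-based here)
    "ro":    list(range(8, 21)),   # positions 9-21
    "batch": list(range(21, 24)),  # positions 22-24
}
for name, dims in groups.items():
    print(name, contiguous(dims), fused_size(dims))
# sum True 256 / ro True 8192 / batch True 8
```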
Result: OMEinsum achieves 19 ms (natural labels) vs Rust's 300 ms (scrambled labels) for the same contraction.
Even with scrambled labels, OMEinsum takes 108 ms because tensorpermute! still reduces the effective rank before calling permutedims!.
Current behavior in strided-rs
`einsum2_into_owned` does:
- `a.permuted(&left_perm)` — metadata-only reorder (scrambles strides)
- `prepare_input_owned` → `try_fuse_group` on the permuted dims → fails (strides scattered) → full copy
The issue: permutation is applied BEFORE fusion. After permuting, contiguous-in-memory dims are no longer adjacent, so fusion fails.
Proposed fix: fuse BEFORE permute (OMEinsum approach)
Adopt OMEinsum's `tensorpermute!` strategy in `einsum2`:
- Compute the target permutation as today (`left_perm`, `right_perm`)
- Before applying the permutation, scan for consecutive indices in the permutation that correspond to contiguous dims in memory
- Fuse those groups via metadata-only reshape (zero cost)
- Compute the reduced-rank permutation on the fused tensor
- Apply the permutation on the fused tensor (lower rank → much cheaper `copy_into`)
- Reshape to the final GEMM shape
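A minimal sketch of the fuse-before-permute pass (Python for brevity; `fuse_perm` is an illustrative name, not strided-rs API, and contiguity of consecutive original dims is assumed, as for a dense column-major tensor):

```python
# Detect runs of consecutive dims in the permutation, fuse each run into one
# dim, and emit the reduced-rank shape plus the reduced-rank permutation.
def fuse_perm(shape, perm):
    groups = [[perm[0]]]
    for p in perm[1:]:
        if p == groups[-1][-1] + 1:   # consecutive in permutation order
            groups[-1].append(p)      # fuse into the current group
        else:
            groups.append([p])        # start a new group
    # fused extents, ordered by each group's first original dim (memory order)
    starts = sorted(g[0] for g in groups)
    fused_shape = []
    for s in starts:
        g = next(g for g in groups if g[0] == s)
        size = 1
        for d in g:
            size *= shape[d]
        fused_shape.append(size)
    # re-express the permutation in terms of fused-group ranks
    reduced_perm = [starts.index(g[0]) for g in groups]
    return fused_shape, reduced_perm

# a rank-6 permutation collapses to rank 2
print(fuse_perm([2, 3, 4, 5, 6, 7], [3, 4, 5, 0, 1, 2]))
# → ([24, 210], [1, 0])
```

The same test applies to the step 408 case: with all 24 dims already grouped as [sum, ro, batch], every run is consecutive, the fused rank is 3, and the reduced permutation is the identity, so no `copy_into` is needed at all.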
This is the highest-impact optimization because it addresses the root cause: rank-24 permutations with scattered strides are inherently expensive, but most of those 24 dims can be fused.
Reproduction
```shell
# Shows OMEinsum natural vs scrambled labels
julia --project=. micro_bench/step408_fair.jl
```

Context
This is the most impactful of the three related issues. With effective dim fusion, the permutation cost drops from 236 ms to near zero, and GEMM operates on rank-3 tensors with simple layouts.
Related: #114 (copy_into base performance), #115 (GEMM dispatch overhead).