
Adopt OMEinsum.jl's pre-permutation dimension fusion in einsum2 #116

@shinaoka

Description


Summary

OMEinsum.jl's key performance trick: before GEMM, fuse all dimensions within each group (lo/sum/ro/batch) into a single dimension, then permute the fused tensor (rank 3-4) instead of the original (rank 24). This eliminates most permutation overhead.

The trick (from OMEinsum.jl analyze_binary + tensorpermute!)

OMEinsum's approach for a binary contraction:

  1. Classify indices into 4 groups: lo (left-output), sum (contracted), ro (right-output), batch
  2. Determine which groups are contiguous in the current memory layout
  3. Group contiguous dims in the permutation: tensorpermute! detects consecutive dims in the permutation order and fuses them via reshape before calling permutedims!
  4. Fuse groups after permutation: reshape to (m, k, nb) for A and (k, n, nb) for B — always rank 3
  5. GEMM on the fused rank-3 tensors
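
The five steps above can be sketched in NumPy terms. This is a hypothetical helper (not strided-rs or OMEinsum code); the label-list arguments and the `mkb,knb->mnb` batched-GEMM convention are assumptions for illustration:

```python
import numpy as np

def binary_einsum_rank3(A, B, a_idx, b_idx, out_idx):
    """Classify indices into lo/sum/ro/batch, fuse each group, and run one
    batched GEMM on rank-3 tensors. a_idx/b_idx/out_idx are lists of labels."""
    a_set, b_set, o_set = set(a_idx), set(b_idx), set(out_idx)
    batch = [i for i in a_idx if i in b_set and i in o_set]      # in A, B, out
    sums  = [i for i in a_idx if i in b_set and i not in o_set]  # contracted
    lo    = [i for i in a_idx if i not in b_set]                 # left-output
    ro    = [i for i in b_idx if i not in a_set]                 # right-output
    size = dict(zip(list(a_idx) + list(b_idx), tuple(A.shape) + tuple(B.shape)))
    prod = lambda idxs: int(np.prod([size[i] for i in idxs]))
    # Permute each operand into group order, then fuse every group to one dim:
    A3 = A.transpose([a_idx.index(i) for i in lo + sums + batch]) \
          .reshape(prod(lo), prod(sums), prod(batch))            # (m, k, nb)
    B3 = B.transpose([b_idx.index(i) for i in sums + ro + batch]) \
          .reshape(prod(sums), prod(ro), prod(batch))            # (k, n, nb)
    C3 = np.einsum('mkb,knb->mnb', A3, B3)                       # batched GEMM
    # Unfuse the result and permute to the requested output order:
    C = C3.reshape([size[i] for i in lo + ro + batch])
    return C.transpose([(lo + ro + batch).index(i) for i in out_idx])
```

For example, `binary_einsum_rank3(A, B, ['i','k','b'], ['k','j','b'], ['i','j','b'])` matches `np.einsum('ikb,kjb->ijb', A, B)`.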

Step 3 is critical. A simplified sketch of the loop inside OMEinsum's tensorpermute! (the real code also remaps indices and drops the absorbed dims; that bookkeeping is elided here):

newshape = collect(size(A))   # per-dim sizes; a fused dim absorbs its successors
newperm  = [perm[1]]
for i = 2:N
    if perm[i] == perm[i-1] + 1                  # consecutive dims in the permutation
        newshape[perm[i-1]] *= size(A, perm[i])  # fuse into the previous group
    else
        push!(newperm, perm[i])                  # start a new group
    end
end
A_ = reshape(A, newshape...)       # zero-cost reshape (metadata only)
permutedims!(C_, A_, newperm)      # permute at the reduced rank
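
Translated to NumPy (row-major, 0-based) terms, the same idea can be sketched as a hypothetical helper; `fused_permute` and its group bookkeeping are illustrative names, not OMEinsum or strided-rs API:

```python
import numpy as np

def fused_permute(A, perm):
    """Fuse runs of source dims that appear consecutively in `perm`
    (0-based), then permute at the reduced rank. Returns the permuted
    tensor still in fused (lower-rank) form."""
    groups = [[perm[0]]]
    for p in perm[1:]:
        if p == groups[-1][-1] + 1:   # consecutive source dims: same group
            groups[-1].append(p)
        else:
            groups.append([p])        # start a new group
    # Fused shape in source order (the groups are disjoint intervals of dims):
    leaders = sorted(g[0] for g in groups)
    by_leader = {g[0]: g for g in groups}
    fused_shape = [int(np.prod([A.shape[d] for d in by_leader[l]]))
                   for l in leaders]
    A_ = A.reshape(fused_shape)       # merging adjacent dims: view, no copy
    newperm = [leaders.index(g[0]) for g in groups]  # reduced-rank permutation
    return A_.transpose(newperm)
```

With `perm = [2, 3, 0, 1]` on a rank-4 tensor this permutes a rank-2 view instead of the original rank-4 array.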

Impact: step 408 example

For tensor B with 24 binary dims (all size 2, col-major) and natural labels:

  • sum dims at positions 1-8: strides [1,2,4,...,128] → contiguous → fuse to size 256
  • ro dims at positions 9-21: strides [256,512,...] → contiguous → fuse to size 8192
  • batch dims at positions 22-24: strides [...] → contiguous → fuse to size 8
  • After fusion: rank 3, shape [256, 8192, 8], already in [sum, ro, batch] order → NO permutation needed
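
The numbers above can be checked with a small stride calculation (hypothetical helper names; dims are 0-based here, so sum = dims 0-7, ro = 8-20, batch = 21-23):

```python
import numpy as np

# Step-408 layout arithmetic: 24 dims of size 2 with column-major strides.
sizes = [2] * 24
strides = np.cumprod([1] + sizes[:-1]).tolist()   # [1, 2, 4, ..., 2^23]

def fuse_if_contiguous(dims):
    """Fuse a group of dims when their (stride, size) chain is gap-free."""
    dims = sorted(dims, key=lambda d: strides[d])
    expected = strides[dims[0]]
    for d in dims:
        if strides[d] != expected:
            return None                 # gap: cannot fuse without a copy
        expected *= sizes[d]
    return int(np.prod([sizes[d] for d in dims]))

groups = {'sum': range(0, 8), 'ro': range(8, 21), 'batch': range(21, 24)}
fused = {name: fuse_if_contiguous(list(dims)) for name, dims in groups.items()}
# fused == {'sum': 256, 'ro': 8192, 'batch': 8}
```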

Result: OMEinsum achieves 19 ms (natural labels) vs Rust's 300 ms (scrambled labels) for the same contraction.

Even with scrambled labels, OMEinsum takes 108 ms because tensorpermute! still reduces the effective rank before calling permutedims!.

Current behavior in strided-rs

einsum2_into_owned does:

  1. a.permuted(&left_perm) — metadata-only reorder (scrambles strides)
  2. prepare_input_owned → try_fuse_group on the permuted dims → fails (strides scattered) → full copy

The issue: permutation is applied BEFORE fusion. After permuting, contiguous-in-memory dims are no longer adjacent, so fusion fails.
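
A minimal NumPy illustration of why the order matters (stand-in shapes, not the actual rank-24 tensor):

```python
import numpy as np

# Stand-in for the big tensor: the shapes here are illustrative only.
A = np.zeros((2, 2, 2, 2))
perm = (2, 0, 3, 1)                     # a stride-scrambling permutation

# Permute first (metadata-only, like a.permuted(&left_perm)): contiguous
# dims are no longer adjacent, so merging two of them forces a full copy.
Ap = A.transpose(perm)
merged = Ap.reshape(4, 2, 2)            # NumPy cannot build a view here
print(np.shares_memory(A, merged))      # False: fusion after permute failed

# Fuse first: dims (0,1) and (2,3) are adjacent in memory, so the reshape
# is a pure view and the permutation can then run at rank 2.
fused = A.reshape(4, 4)
print(np.shares_memory(A, fused))       # True: zero-cost reshape
low_rank = fused.transpose(1, 0)        # rank-2 permute instead of rank-4
```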

Proposed fix: fuse BEFORE permute (OMEinsum approach)

Adopt OMEinsum's tensorpermute! strategy in einsum2:

  1. Compute the target permutation as today (left_perm, right_perm)
  2. Before applying the permutation, scan for consecutive indices in the permutation that correspond to contiguous dims in memory
  3. Fuse those groups via metadata-only reshape (zero cost)
  4. Compute reduced-rank permutation on the fused tensor
  5. Apply permutation on the fused tensor (lower rank → much cheaper copy_into)
  6. Reshape to final GEMM shape

This is the highest-impact optimization because it addresses the root cause: rank-24 permutations with scattered strides are inherently expensive, but most of those 24 dims can be fused.

Reproduction

# Shows OMEinsum natural vs scrambled labels
julia --project=. micro_bench/step408_fair.jl

Context

This is the most impactful of the three related issues. With effective dim fusion, the permutation cost drops from 236 ms to near-zero, and GEMM operates on rank-3 tensors with simple layouts.

Related: #114 (copy_into base performance), #115 (GEMM dispatch overhead).
