Merged
110 changes: 52 additions & 58 deletions docs/permutation-optimization.md
@@ -67,82 +67,76 @@
directly instead of going through `strided_kernel::copy_into` (which fell back
to the non-HPTT `map_into` path). This ensures the HPTT-optimized permutation
is used when copying GEMM results back to non-contiguous destinations.

-## Current Strategy: Still Lazy, but with Fast Permutation
+## Current Strategy: Lazy Permutation + Source-Stride-Order Copy

-strided-rs currently **does NOT** always materialize like OMEinsum.jl. Instead,
-it keeps the lazy (metadata-only) permutation but ensures that when a subsequent
-step needs to copy the scattered tensor, it uses the HPTT-optimized path.
+strided-rs keeps the lazy (metadata-only) permutation but uses two optimized
+copy strategies when a subsequent step needs contiguous data:

-This is a pragmatic middle ground:
+### 1. Source-stride-order copy in `prepare_input_owned`

-- **No extra copy when not needed** — if the next step's canonical order aligns,
-  `try_fuse_group` succeeds and no copy occurs (truly zero cost)
-- **Fast copy when needed** — HPTT permutation runs at ~25 GB/s instead of
-  4 GB/s for the scattered case
-
-### When lazy permutation wins
+Since einsum2 always produces col-major output, the source of
+`prepare_input_owned` is physically contiguous in memory — only the
+dims/strides metadata is permuted. The function `copy_strided_src_order`
+iterates in **source-stride order** (smallest source stride innermost), giving
+sequential reads that exploit the hardware prefetcher on cold-cache data.
+Scattered writes are absorbed by hardware write-combining buffers.
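For readers unfamiliar with lazy permutation: a metadata-only permutation reorders just the dims/strides arrays and never touches the buffer. A minimal sketch (the `permute_meta` helper is hypothetical, not part of the crate):

```rust
/// Metadata-only permutation: reorder dims/strides, leave the buffer alone.
fn permute_meta(
    dims: &[usize],
    strides: &[isize],
    perm: &[usize],
) -> (Vec<usize>, Vec<isize>) {
    (
        perm.iter().map(|&p| dims[p]).collect(),
        perm.iter().map(|&p| strides[p]).collect(),
    )
}

fn main() {
    // A col-major [2, 3, 4] tensor has strides [1, 2, 6].
    // Lazily permuting axes to (2, 0, 1) only reorders the metadata;
    // the data stays physically contiguous with scattered strides.
    let (d, s) = permute_meta(&[2, 3, 4], &[1, 2, 6], &[2, 0, 1]);
    assert_eq!(d, [4, 2, 3]);
    assert_eq!(s, [6, 1, 2]);
    println!("dims={d:?} strides={s:?}");
}
```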

-- Few dimensions with large sizes (e.g., `[1000, 1000, 1000]`) — even scattered
-  reads have good cache line utilization
-- Next step's access pattern aligns with current strides — no copy at all
-- Final output — no subsequent step pays the deferred cost
+This replaces HPTT for this specific path because HPTT iterates in
+*destination*-stride order — sequential writes, scattered reads. For large
+cold-cache data with many small dimensions (e.g., 24 binary dims), the scattered
+reads dominate performance. Additionally, HPTT's bilateral fusion can only merge
+consecutive dimensions; for 24 binary dims with scattered strides this leaves
+~17 fused dims with a 2×2 inner tile and 15 recursion levels — high per-element
+overhead.
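The source-stride-order traversal can be sketched with safe slices. This is a simplified version of the PR's `copy_strided_src_order`, assuming non-negative `usize` strides; the real implementation uses raw pointers and `isize` strides:

```rust
/// Copy `src` into `dst`, iterating with the smallest source stride innermost
/// so reads are sequential (simplified sketch of `copy_strided_src_order`).
fn copy_src_order<T: Copy>(
    src: &[T],
    dst: &mut [T],
    dims: &[usize],
    src_strides: &[usize],
    dst_strides: &[usize],
) {
    let n: usize = dims.iter().product();
    if n == 0 {
        return;
    }
    // Innermost loop = smallest source stride → sequential reads.
    let mut order: Vec<usize> = (0..dims.len()).filter(|&i| dims[i] > 1).collect();
    order.sort_by_key(|&i| src_strides[i]);

    let mut idx = vec![0usize; order.len()];
    let (mut so, mut do_) = (0usize, 0usize);
    for _ in 0..n {
        dst[do_] = src[so];
        // Odometer increment: carry into the next-larger source stride.
        for (d, &k) in order.iter().enumerate() {
            idx[d] += 1;
            if idx[d] < dims[k] {
                so += src_strides[k];
                do_ += dst_strides[k];
                break;
            }
            so -= src_strides[k] * (dims[k] - 1);
            do_ -= dst_strides[k] * (dims[k] - 1);
            idx[d] = 0;
        }
    }
}

fn main() {
    // Transpose a 2x3 row-major matrix: logical dims [2, 3],
    // src strides [3, 1] (row-major), dst strides [1, 2] (col-major).
    let src = [1, 2, 3, 4, 5, 6];
    let mut dst = [0; 6];
    copy_src_order(&src, &mut dst, &[2, 3], &[3, 1], &[1, 2]);
    assert_eq!(dst, [1, 4, 2, 5, 3, 6]);
    println!("{:?}", dst);
}
```

The inner loop walks `src` with stride 1 regardless of how scattered the destination strides are, which is the whole point of the strategy.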

-### When eager materialization would win
-
-- Many small dimensions (e.g., 24 dims of size 2) where the scattered copy,
-  even with HPTT, is slower than two contiguous-to-contiguous copies
-- Long chains of steps where the deferred cost propagates
-
-## Benchmark Results
+With `--features parallel`, a rayon-parallelized variant
+(`copy_strided_src_order_par`) splits the outer source-stride dimensions across
+threads, with automatic fallback to single-threaded execution when
+`RAYON_NUM_THREADS=1` or the tensor is small (< 1M elements).
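The outer-dimension peeling that produces the parallel tasks can be sketched as follows. The thread count is passed explicitly here for illustration; the PR's `copy_strided_src_order_par` queries `rayon::current_num_threads()`:

```rust
/// Peel outer (largest-source-stride) dims until there are at least
/// 4 * n_threads independent tasks. Returns (split_at, task_count):
/// dims[..split_at] run sequentially inside each task, dims[split_at..]
/// are distributed across threads. Assumes `dims` is already sorted by
/// ascending source stride.
fn split_for_parallel(dims: &[usize], n_threads: usize) -> (usize, usize) {
    let min_tasks = n_threads * 4;
    let mut split_at = dims.len();
    let mut tasks = 1usize;
    while split_at > 0 && tasks < min_tasks {
        split_at -= 1;
        tasks *= dims[split_at];
    }
    (split_at, tasks)
}

fn main() {
    // 24 binary dims on 4 threads: need >= 16 tasks, so peel 4 dims (2^4).
    let dims = [2usize; 24];
    assert_eq!(split_for_parallel(&dims, 4), (20, 16));
    // Two large dims: peeling one already yields 1000 tasks.
    assert_eq!(split_for_parallel(&[1000, 1000], 4), (1, 1000));
    println!("ok");
}
```

Oversubscribing by 4× gives rayon's work-stealing scheduler room to balance uneven task durations.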

-`tensornetwork_permutation_light_415` (415 tensors, 24 binary dims, Apple M2):
+### 2. HPTT for other copy paths

-| Threads | strided-rs faer | OMEinsum.jl | Ratio |
-|---------|----------------:|------------:|------:|
-| 1T | 208 ms | 166 ms (IQR 83) | 1.25x |
-| 4T | 142 ms | 172 ms (IQR 40) | **0.83x** |
+The rest of the pipeline (`finalize_into`, `bgemm_faer` pack, `single_tensor`,
+`operand`) still uses HPTT via `strided_kernel::copy_into` /
+`strided_perm::copy_into`. These paths typically operate on warm-cache data or
+have different stride patterns where HPTT's blocked approach remains effective.

-strided-rs is now competitive (faster at 4T), with dramatically lower variance
-(IQR < 4 ms vs 40-84 ms for OMEinsum.jl).
+### Why not always-materialize?

-## Open Questions
+- **No extra copy when not needed** — if the next step's canonical order aligns,
+  `try_fuse_group` succeeds and no copy occurs (truly zero cost)
+- **Source-order copy is fast enough** — sequential reads on a contiguous source
+  achieve near-memcpy bandwidth

-### Always-materialize as a future option
+## Benchmark Results

-Issue #109 proposed always materializing output permutations (matching
-OMEinsum.jl). With HPTT-style permutation, the cost of eager materialization is
-low (~7 ms for 16M elements). The remaining question is whether the benefit
-outweighs the cost across all workload types:
+`tensornetwork_permutation_light_415` (415 tensors, 24 binary dims, AMD EPYC
+7713P):

-- For tensor networks with many small dimensions: likely beneficial
-- For workloads with large contiguous dimensions: likely wasteful
-- A heuristic based on dimension count/sizes could be added
+| Configuration | opt_flops (ms) | vs OMEinsum.jl (388 ms) |
+|---------------|---------------:|------------------------:|
+| Original (HPTT) 1T | 455 | 1.17x slower |
+| Source-order copy 1T | 298 | **1.30x faster** |
+| Source-order copy + parallel 4T | 228 | **1.70x faster** |

-### Two-stage permutation
+The source-order copy alone yields a 34% improvement over HPTT. Adding the
+parallel copy with 4 threads provides a further 24% speedup.

-When a lazy permutation is followed by another permutation (e.g., from the next
-step's input preparation), and the combined result is still non-contiguous,
-strided-rs currently performs a single scattered-to-contiguous copy. An
-alternative would be to split this into two stages:
+## Open Questions

-1. First: permute the lazy tensor to contiguous (HPTT, fast)
-2. Second: permute the contiguous result to the target layout (HPTT, fast)
+### Extending source-order copy to other paths

-Two contiguous-to-contiguous permutations can be faster than one
-scattered-to-contiguous permutation because sequential reads have full cache
-line utilization and the hardware prefetcher works effectively. This is exactly
-the pattern that makes OMEinsum.jl's eager strategy work well.
+Currently only `prepare_input_owned` uses the source-stride-order copy. Other
+copy paths (`finalize_into`, etc.) could also benefit when the source data is
+contiguous but the strides are scattered. This would require detecting
+contiguous-source patterns at each call site.
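One possible shape for that detection, sketched as a hypothetical helper (not in the PR): check whether the strides are a permutation of a dense layout, i.e. sorted by ascending stride they nest exactly. Size-1 dimensions would need extra handling in a real implementation:

```rust
/// True iff the view is a permutation of a dense block: after sorting dims
/// by ascending stride, the smallest stride is 1 and each stride equals the
/// previous stride times the previous dim extent. (Hypothetical sketch;
/// assumes no size-1 dims and non-negative strides.)
fn is_permuted_contiguous(dims: &[usize], strides: &[usize]) -> bool {
    let mut order: Vec<usize> = (0..dims.len()).collect();
    order.sort_by_key(|&i| strides[i]);
    let mut expected = 1usize;
    for &i in &order {
        if strides[i] != expected {
            return false;
        }
        expected *= dims[i];
    }
    true
}

fn main() {
    // Col-major [4, 3] lazily permuted to [3, 4]: strides [4, 1], still dense.
    assert!(is_permuted_contiguous(&[3, 4], &[4, 1]));
    // A sliced view (every other element) leaves gaps: not dense.
    assert!(!is_permuted_contiguous(&[3, 4], &[8, 2]));
    println!("ok");
}
```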

-This two-stage approach could be implemented as:
-- Detect when the source has non-contiguous strides before calling
-  `copy_into_col_major` or `strided_perm::copy_into`
-- If so, first materialize to a temporary contiguous buffer, then permute from
-  the temporary to the destination
-- Only apply when the total element count exceeds a threshold (to avoid overhead
-  for small tensors)
+### Thread scaling

-This would give the benefits of eager materialization without changing the
-overall lazy architecture.
+With 4 threads the parallel copy helps significantly, but the improvement may
+plateau with more threads as memory bandwidth saturates. Benchmarking with
+higher thread counts on different architectures would clarify the scaling
+characteristics.

## Related Issues and PRs

2 changes: 2 additions & 0 deletions strided-einsum2/Cargo.toml
@@ -20,9 +20,11 @@
faer-traits = { version = "0.24", optional = true }
cblas-sys = { version = "0.2", optional = true }
cblas-inject = { version = "0.1", optional = true }
num-complex = { version = "0.4", optional = true }
rayon = { version = "1.10", optional = true }

[features]
default = ["faer", "faer-traits"]
parallel = ["rayon"]
blas = ["dep:cblas-sys", "dep:num-complex"]
blas-inject = ["dep:cblas-inject", "dep:num-complex"]

202 changes: 201 additions & 1 deletion strided-einsum2/src/contiguous.rs
@@ -12,6 +12,9 @@
use std::collections::HashMap;
use strided_perm::try_fuse_group;
use strided_view::{StridedArray, StridedView, StridedViewMut};

#[cfg(feature = "parallel")]
use rayon::iter::{IntoParallelIterator, ParallelIterator};

/// GEMM-ready input operand with contiguous data.
pub struct ContiguousOperand<T: Copy + 'static> {
ptr: *const T,
@@ -370,6 +373,173 @@ pub fn prepare_input_view<T: Scalar + 'static>(
}
}

/// Copy elements from `src` to `dst`, iterating in source-stride order.
///
/// Dimensions are traversed from smallest to largest source stride, giving
/// sequential (or near-sequential) reads. Writes to the destination may be
/// scattered, but hardware write-combining buffers absorb much of the cost.
///
/// # Why not HPTT (`strided_kernel::copy_into_col_major`)?
///
/// HPTT iterates in *destination*-stride order (optimized for sequential
/// writes). This is ideal when the source data is warm in cache, but in
/// `prepare_input_owned` the source is usually a large intermediate whose
/// L3 cache lines have been evicted by subsequent contraction steps.
/// With cold-cache source and many small non-contiguous dimensions (e.g.
/// 24 binary dims of size 2 after a metadata-only permutation), HPTT's
/// bilateral fusion leaves ~17 fused dims with a 2×2 inner tile and 15
/// levels of recursion per 4 elements — both cache-unfriendly reads AND
/// high per-element overhead.
///
/// Source-stride-order iteration gives sequential reads that exploit the
/// hardware prefetcher, which dominates performance on cold-cache,
/// memory-bandwidth-bound copies.
unsafe fn copy_strided_src_order<T: Copy>(
src_ptr: *const T,
dst_ptr: *mut T,
dims: &[usize],
src_strides: &[isize],
dst_strides: &[isize],
) {
let ndim = dims.len();
let n: usize = dims.iter().product();
if n == 0 {
return;
}

// Sort dimensions by source stride (ascending) → innermost = smallest src stride
let mut dim_order: Vec<usize> = (0..ndim).filter(|&i| dims[i] > 1).collect();
dim_order.sort_by_key(|&i| src_strides[i].unsigned_abs());

let sorted_dims: Vec<usize> = dim_order.iter().map(|&i| dims[i]).collect();
let sorted_src: Vec<isize> = dim_order.iter().map(|&i| src_strides[i]).collect();
let sorted_dst: Vec<isize> = dim_order.iter().map(|&i| dst_strides[i]).collect();
let nd = sorted_dims.len();

let mut idx = vec![0usize; nd];
let mut so: isize = 0;
let mut do_: isize = 0;

for _ in 0..n {
*dst_ptr.offset(do_) = *src_ptr.offset(so);

for d in 0..nd {
idx[d] += 1;
if idx[d] < sorted_dims[d] {
so += sorted_src[d];
do_ += sorted_dst[d];
break;
} else {
so -= sorted_src[d] * (sorted_dims[d] as isize - 1);
do_ -= sorted_dst[d] * (sorted_dims[d] as isize - 1);
idx[d] = 0;
}
}
}
}

/// Parallel version of [`copy_strided_src_order`].
///
/// Outer dimensions (by source stride) are split across rayon threads; each
/// thread runs a sequential odometer over the inner dimensions.
/// Falls back to the single-threaded version for small tensors.
#[cfg(feature = "parallel")]
unsafe fn copy_strided_src_order_par<T: Copy + Send + Sync>(
src_ptr: *const T,
dst_ptr: *mut T,
dims: &[usize],
src_strides: &[isize],
dst_strides: &[isize],
) {
let ndim = dims.len();
let n: usize = dims.iter().product();
if n == 0 {
return;
}

// Fall back to sequential when parallelism would add overhead without gain.
const PAR_THRESHOLD: usize = 1 << 20; // 1M elements
if n < PAR_THRESHOLD || rayon::current_num_threads() <= 1 {
copy_strided_src_order(src_ptr, dst_ptr, dims, src_strides, dst_strides);
return;
}

// Sort dimensions by source stride (ascending) → innermost = smallest src stride
let mut dim_order: Vec<usize> = (0..ndim).filter(|&i| dims[i] > 1).collect();
dim_order.sort_by_key(|&i| src_strides[i].unsigned_abs());

let sorted_dims: Vec<usize> = dim_order.iter().map(|&i| dims[i]).collect();
let sorted_src: Vec<isize> = dim_order.iter().map(|&i| src_strides[i]).collect();
let sorted_dst: Vec<isize> = dim_order.iter().map(|&i| dst_strides[i]).collect();
let nd = sorted_dims.len();

// Peel outer dims until we have enough parallel tasks (>= 4× threads).
let min_tasks = rayon::current_num_threads() * 4;
let mut split_at = nd; // index into sorted arrays: [0..split_at) inner, [split_at..nd) outer
let mut par_count: usize = 1;
while split_at > 0 && par_count < min_tasks {
split_at -= 1;
par_count *= sorted_dims[split_at];
}

let inner_n: usize = sorted_dims[..split_at].iter().product::<usize>().max(1);

// Convert pointers to usize for Send (same pattern as strided-perm).
let src_addr = src_ptr as usize;
let dst_addr = dst_ptr as usize;

let outer_dims = sorted_dims[split_at..].to_vec();
let outer_src = sorted_src[split_at..].to_vec();
let outer_dst = sorted_dst[split_at..].to_vec();
let inner_dims = sorted_dims[..split_at].to_vec();
let inner_src = sorted_src[..split_at].to_vec();
let inner_dst = sorted_dst[..split_at].to_vec();

(0..par_count).into_par_iter().for_each(|outer_idx| {
// Compute base offsets from outer multi-index.
let mut src_off: isize = 0;
let mut dst_off: isize = 0;
let mut rem = outer_idx;
for d in 0..outer_dims.len() {
let i = rem % outer_dims[d];
rem /= outer_dims[d];
src_off += i as isize * outer_src[d];
dst_off += i as isize * outer_dst[d];
}

let sp = (src_addr as isize + src_off * std::mem::size_of::<T>() as isize) as *const T;
let dp = (dst_addr as isize + dst_off * std::mem::size_of::<T>() as isize) as *mut T;

if split_at == 0 {
// No inner dims — single element per task.
unsafe { *dp = *sp };
return;
}

// Sequential odometer over inner dims.
let mut idx = vec![0usize; split_at];
let mut so: isize = 0;
let mut do_: isize = 0;

for _ in 0..inner_n {
unsafe { *dp.offset(do_) = *sp.offset(so) };

for d in 0..split_at {
idx[d] += 1;
if idx[d] < inner_dims[d] {
so += inner_src[d];
do_ += inner_dst[d];
break;
} else {
so -= inner_src[d] * (inner_dims[d] as isize - 1);
do_ -= inner_dst[d] * (inner_dims[d] as isize - 1);
idx[d] = 0;
}
}
}
});
}
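The pointer-as-usize trick above (capturing the base address as `usize` so the closure is `Send`) can be illustrated in isolation. This is a minimal sketch using `std::thread` rather than rayon, not the crate's code; soundness rests entirely on the per-thread index ranges being disjoint:

```rust
use std::thread;

// Raw pointers are not Send, so each thread receives the base address as a
// usize and reconstructs a pointer into its own disjoint chunk of the buffer.
fn parallel_fill(buf: &mut [u64], n_threads: usize) {
    let len = buf.len();
    let base = buf.as_mut_ptr() as usize;
    let chunk = len.div_ceil(n_threads);
    let handles: Vec<_> = (0..n_threads)
        .map(|t| {
            thread::spawn(move || {
                let start = t * chunk;
                let end = (start + chunk).min(len);
                for i in start..end {
                    // SAFETY: the [start, end) ranges are disjoint across
                    // threads, so no two threads write the same element.
                    unsafe { *(base as *mut u64).add(i) = i as u64 };
                }
            })
        })
        .collect();
    for h in handles {
        h.join().unwrap();
    }
}

fn main() {
    let mut buf = vec![0u64; 10];
    parallel_fill(&mut buf, 3);
    assert_eq!(buf, (0u64..10).collect::<Vec<u64>>());
    println!("{:?}", buf);
}
```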

/// Prepare an owned input array for GEMM.
///
/// Expects batch-last canonical order: `[group1..., group2..., batch...]`.
@@ -433,7 +603,37 @@ pub fn prepare_input_owned<T: Scalar + 'static>(
if needs_copy {
let m: usize = group1_dims.iter().product::<usize>().max(1);
let (mut buf, buf_is_pooled) = alloc_col_major_uninit_with_pool(&dims);
-        strided_kernel::copy_into_col_major(&mut buf.view_mut(), &arr.view())?;
// Use source-stride-order copy instead of HPTT (strided_kernel::copy_into_col_major).
// einsum2 always produces col-major output and only metadata permutations
// are applied between steps, so the source is physically contiguous but has
// scattered strides. HPTT iterates in destination order → scattered reads
// from cold L3 cache. Source-order iteration gives sequential reads that
// exploit the hardware prefetcher. See doc comment on copy_strided_src_order.
{
let dst_strides = buf.strides().to_vec();
unsafe {
#[cfg(feature = "parallel")]
{
copy_strided_src_order_par(
arr.view().ptr(),
buf.view_mut().as_mut_ptr(),
&dims,
&strides,
&dst_strides,
);
}
#[cfg(not(feature = "parallel"))]
{
copy_strided_src_order(
arr.view().ptr(),
buf.view_mut().as_mut_ptr(),
&dims,
&strides,
&dst_strides,
);
}
}
}
let ptr = buf.view().ptr();
let batch_strides = buf.strides()[n_inner..].to_vec();
let row_stride = if m == 0 { 0 } else { 1isize };
1 change: 1 addition & 0 deletions strided-opteinsum/Cargo.toml
@@ -15,6 +15,7 @@
thiserror = "1.0"
[features]
default = ["faer"]
faer = ["strided-einsum2/faer", "strided-einsum2/faer-traits"]
parallel = ["strided-einsum2/parallel"]
blas = ["strided-einsum2/blas"]
blas-inject = ["strided-einsum2/blas-inject"]
