diff --git a/docs/permutation-optimization.md b/docs/permutation-optimization.md
index 2991349..40e8f13 100644
--- a/docs/permutation-optimization.md
+++ b/docs/permutation-optimization.md
@@ -67,82 +67,76 @@
 directly instead of going through `strided_kernel::copy_into` (which fell
 back to the non-HPTT `map_into` path). This ensures the HPTT-optimized
 permutation is used when copying GEMM results back to non-contiguous
 destinations.
 
-## Current Strategy: Still Lazy, but with Fast Permutation
+## Current Strategy: Lazy Permutation + Source-Stride-Order Copy
 
-strided-rs currently **does NOT** always materialize like OMEinsum.jl. Instead,
-it keeps the lazy (metadata-only) permutation but ensures that when a subsequent
-step needs to copy the scattered tensor, it uses the HPTT-optimized path.
+strided-rs keeps the lazy (metadata-only) permutation but uses two optimized
+copy strategies when a subsequent step needs contiguous data:
 
-This is a pragmatic middle ground:
+### 1. Source-stride-order copy in `prepare_input_owned`
 
-- **No extra copy when not needed** — if the next step's canonical order aligns,
-  `try_fuse_group` succeeds and no copy occurs (truly zero cost)
-- **Fast copy when needed** — HPTT permutation runs at ~25 GB/s instead of
-  4 GB/s for the scattered case
-
-### When lazy permutation wins
+Since einsum2 always produces col-major output, the source of
+`prepare_input_owned` is physically contiguous in memory — only the
+dims/strides metadata is permuted. The function `copy_strided_src_order`
+iterates in **source-stride order** (smallest source stride innermost), giving
+sequential reads that exploit the hardware prefetcher on cold-cache data.
+Scattered writes are absorbed by hardware write-combining buffers.
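The idea can be illustrated with a minimal standalone sketch (hypothetical simplified code, not the crate's actual implementation): copy a 2-D strided view by looping with the smallest *source* stride innermost, so reads are sequential while writes may scatter.

```rust
// Minimal sketch of a source-stride-order copy (hypothetical, simplified;
// not the crate's actual implementation). Copies a 2x3 row-major matrix
// into a column-major destination by iterating with the smallest SOURCE
// stride innermost: sequential reads, possibly scattered writes.
fn copy_src_order(
    src: &[i64],
    dst: &mut [i64],
    dims: [usize; 2],
    src_strides: [usize; 2],
    dst_strides: [usize; 2],
) {
    // Put the dimension with the smaller source stride in the inner loop.
    let (inner, outer) = if src_strides[0] <= src_strides[1] { (0, 1) } else { (1, 0) };
    for o in 0..dims[outer] {
        for i in 0..dims[inner] {
            let s = o * src_strides[outer] + i * src_strides[inner];
            let d = o * dst_strides[outer] + i * dst_strides[inner];
            dst[d] = src[s]; // sequential read, scattered write
        }
    }
}

fn main() {
    // 2x3 row-major source (strides [3, 1]) -> col-major dest (strides [1, 2]).
    let src: [i64; 6] = [1, 2, 3, 4, 5, 6];
    let mut dst = [0i64; 6];
    copy_src_order(&src, &mut dst, [2, 3], [3, 1], [1, 2]);
    assert_eq!(dst, [1, 4, 2, 5, 3, 6]); // column-major layout of [[1,2,3],[4,5,6]]
    println!("{:?}", dst);
}
```

The real implementation generalizes this to N dimensions with an odometer loop over dimensions sorted by source stride.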
-- Few dimensions with large sizes (e.g., `[1000, 1000, 1000]`) — even scattered
-  reads have good cache line utilization
-- Next step's access pattern aligns with current strides — no copy at all
-- Final output — no subsequent step pays the deferred cost
+This replaces HPTT for this specific path because HPTT iterates in
+*destination*-stride order — sequential writes, scattered reads. For large
+cold-cache data with many small dimensions (e.g. 24 binary dims), the scattered
+reads dominate performance. Additionally, HPTT's bilateral fusion can only merge
+consecutive dimensions; for 24 binary dims with scattered strides this leaves
+~17 fused dims with a 2×2 inner tile and 15 recursion levels — high per-element
+overhead.
 
-### When eager materialization would win
-
-- Many small dimensions (e.g., 24 dims of size 2) where the scattered copy,
-  even with HPTT, is slower than two contiguous-to-contiguous copies
-- Long chains of steps where the deferred cost propagates
-
-## Benchmark Results
+With `--features parallel`, a rayon-parallelized variant
+(`copy_strided_src_order_par`) splits the outer source-stride dimensions across
+threads, with automatic fallback to single-threaded when `RAYON_NUM_THREADS=1`
+or the tensor is small (< 1M elements).
 
-`tensornetwork_permutation_light_415` (415 tensors, 24 binary dims, Apple M2):
+### 2. HPTT for other copy paths
 
-| Threads | strided-rs faer | OMEinsum.jl | Ratio |
-|---------|----------------:|------------:|------:|
-| 1T | 208 ms | 166 ms (IQR 83) | 1.25x |
-| 4T | 142 ms | 172 ms (IQR 40) | **0.83x** |
+The rest of the pipeline (`finalize_into`, `bgemm_faer` pack, `single_tensor`,
+`operand`) still uses HPTT via `strided_kernel::copy_into` /
+`strided_perm::copy_into`. These paths typically operate on warm-cache data or
+have different stride patterns where HPTT's blocked approach remains effective.
-strided-rs is now competitive (faster at 4T), with dramatically lower variance
-(IQR < 4 ms vs 40-84 ms for OMEinsum.jl).
+### Why not always-materialize?
 
-## Open Questions
+- **No extra copy when not needed** — if the next step's canonical order aligns,
+  `try_fuse_group` succeeds and no copy occurs (truly zero cost)
+- **Source-order copy is fast enough** — sequential reads on contiguous source
+  achieve near-memcpy bandwidth
 
-### Always-materialize as a future option
+## Benchmark Results
 
-Issue #109 proposed always materializing output permutations (matching
-OMEinsum.jl). With HPTT-style permutation, the cost of eager materialization is
-low (~7 ms for 16M elements). The remaining question is whether the benefit
-outweighs the cost across all workload types:
+`tensornetwork_permutation_light_415` (415 tensors, 24 binary dims, AMD EPYC
+7713P):
 
-- For tensor networks with many small dimensions: likely beneficial
-- For workloads with large contiguous dimensions: likely wasteful
-- A heuristic based on dimension count/sizes could be added
+| Configuration | opt_flops (ms) | vs OMEinsum.jl (388 ms) |
+|---------------|---------------:|------------------------:|
+| Original (HPTT) 1T | 455 | 1.17x slower |
+| Source-order copy 1T | 298 | **1.30x faster** |
+| Source-order copy + parallel 4T | 228 | **1.70x faster** |
 
-### Two-stage permutation
+The source-order copy alone yields a 34% improvement over HPTT. Adding
+parallel copy with 4 threads provides a further 24% speedup.
 
-When a lazy permutation is followed by another permutation (e.g., from the next
-step's input preparation), and the combined result is still non-contiguous,
-strided-rs currently performs a single scattered-to-contiguous copy. An
-alternative would be to split this into two stages:
+## Open Questions
 
-1. First: permute the lazy tensor to contiguous (HPTT, fast)
-2.
Second: permute the contiguous result to the target layout (HPTT, fast)
 
+### Extending source-order copy to other paths
 
-Two contiguous-to-contiguous permutations can be faster than one
-scattered-to-contiguous permutation because sequential reads have full cache
-line utilization and the hardware prefetcher works effectively. This is exactly
-the pattern that makes OMEinsum.jl's eager strategy work well.
+Currently only `prepare_input_owned` uses source-stride-order copy. Other copy
+paths (`finalize_into`, etc.) could also benefit when source data is contiguous
+but strides are scattered. This would require detecting contiguous-source
+patterns at each call site.
 
-This two-stage approach could be implemented as:
-- Detect when the source has non-contiguous strides before calling
-  `copy_into_col_major` or `strided_perm::copy_into`
-- If so, first materialize to a temporary contiguous buffer, then permute from
-  the temporary to the destination
-- Only apply when the total element count exceeds a threshold (to avoid overhead
-  for small tensors)
+### Thread scaling
 
-This would give the benefits of eager materialization without changing the
-overall lazy architecture.
+With 4 threads the parallel copy helps significantly, but the improvement may
+plateau with more threads as memory bandwidth saturates. Benchmarking with
+higher thread counts on different architectures would clarify the scaling
+characteristics.
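The contiguous-source detection mentioned under "Extending source-order copy to other paths" could amount to checking that the strides are a permutation of a dense layout. A hypothetical sketch (the name `is_dense_permutation` is illustrative, not an existing strided-rs API):

```rust
// Hypothetical helper (illustrative only, not part of strided-rs): returns
// true when `strides` is a permutation of a dense (packed) layout, i.e. the
// data is physically contiguous and only the dims/strides metadata was
// permuted. Size-1 dimensions are ignored, as in the crate's copy routines.
fn is_dense_permutation(dims: &[usize], strides: &[isize]) -> bool {
    // Order dimension indices by ascending absolute stride...
    let mut order: Vec<usize> = (0..dims.len()).filter(|&i| dims[i] > 1).collect();
    order.sort_by_key(|&i| strides[i].unsigned_abs());
    // ...then each stride must equal the product of the extents of all
    // smaller-stride dimensions, exactly as in a packed layout.
    let mut expected: usize = 1;
    for &i in &order {
        if strides[i].unsigned_abs() != expected {
            return false;
        }
        expected *= dims[i];
    }
    true
}

fn main() {
    // Col-major [4, 3, 2] permuted by metadata to [2, 4, 3]: still dense.
    assert!(is_dense_permutation(&[2, 4, 3], &[12, 1, 4]));
    // A sliced view (stride gap) is not dense.
    assert!(!is_dense_permutation(&[2, 4, 3], &[24, 1, 4]));
}
```

A call site could run this cheap O(n log n) check on the view's metadata and route to the source-order copy only when it returns true.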
 
 ## Related Issues and PRs
diff --git a/strided-einsum2/Cargo.toml b/strided-einsum2/Cargo.toml
index ab1cc92..059defb 100644
--- a/strided-einsum2/Cargo.toml
+++ b/strided-einsum2/Cargo.toml
@@ -20,9 +20,11 @@ faer-traits = { version = "0.24", optional = true }
 cblas-sys = { version = "0.2", optional = true }
 cblas-inject = { version = "0.1", optional = true }
 num-complex = { version = "0.4", optional = true }
+rayon = { version = "1.10", optional = true }
 
 [features]
 default = ["faer", "faer-traits"]
+parallel = ["rayon"]
 blas = ["dep:cblas-sys", "dep:num-complex"]
 blas-inject = ["dep:cblas-inject", "dep:num-complex"]
 
diff --git a/strided-einsum2/src/contiguous.rs b/strided-einsum2/src/contiguous.rs
index e8a3c3e..99cb694 100644
--- a/strided-einsum2/src/contiguous.rs
+++ b/strided-einsum2/src/contiguous.rs
@@ -12,6 +12,9 @@ use std::collections::HashMap;
 use strided_perm::try_fuse_group;
 use strided_view::{StridedArray, StridedView, StridedViewMut};
 
+#[cfg(feature = "parallel")]
+use rayon::iter::{IntoParallelIterator, ParallelIterator};
+
 /// GEMM-ready input operand with contiguous data.
 pub struct ContiguousOperand<T> {
     ptr: *const T,
@@ -370,6 +373,173 @@ pub fn prepare_input_view(
     }
 }
 
+/// Copy elements from `src` to `dst`, iterating in source-stride order.
+///
+/// Dimensions are traversed from smallest to largest source stride, giving
+/// sequential (or near-sequential) reads. Writes to the destination may be
+/// scattered, but hardware write-combining buffers absorb much of the cost.
+///
+/// # Why not HPTT (`strided_kernel::copy_into_col_major`)?
+///
+/// HPTT iterates in *destination*-stride order (optimized for sequential
+/// writes). This is ideal when the source data is warm in cache, but in
+/// `prepare_input_owned` the source is usually a large intermediate whose
+/// L3 cache lines have been evicted by subsequent contraction steps.
+/// With cold-cache source and many small non-contiguous dimensions (e.g.
+/// 24 binary dims of size 2 after a metadata-only permutation), HPTT's
+/// bilateral fusion leaves ~17 fused dims with a 2×2 inner tile and 15
+/// levels of recursion per 4 elements — both cache-unfriendly reads AND
+/// high per-element overhead.
+///
+/// Source-stride-order iteration gives sequential reads that exploit the
+/// hardware prefetcher, which dominates performance on cold-cache,
+/// memory-bandwidth-bound copies.
+unsafe fn copy_strided_src_order<T: Copy>(
+    src_ptr: *const T,
+    dst_ptr: *mut T,
+    dims: &[usize],
+    src_strides: &[isize],
+    dst_strides: &[isize],
+) {
+    let ndim = dims.len();
+    let n: usize = dims.iter().product();
+    if n == 0 {
+        return;
+    }
+
+    // Sort dimensions by source stride (ascending) → innermost = smallest src stride
+    let mut dim_order: Vec<usize> = (0..ndim).filter(|&i| dims[i] > 1).collect();
+    dim_order.sort_by_key(|&i| src_strides[i].unsigned_abs());
+
+    let sorted_dims: Vec<usize> = dim_order.iter().map(|&i| dims[i]).collect();
+    let sorted_src: Vec<isize> = dim_order.iter().map(|&i| src_strides[i]).collect();
+    let sorted_dst: Vec<isize> = dim_order.iter().map(|&i| dst_strides[i]).collect();
+    let nd = sorted_dims.len();
+
+    let mut idx = vec![0usize; nd];
+    let mut so: isize = 0;
+    let mut do_: isize = 0;
+
+    for _ in 0..n {
+        *dst_ptr.offset(do_) = *src_ptr.offset(so);
+
+        for d in 0..nd {
+            idx[d] += 1;
+            if idx[d] < sorted_dims[d] {
+                so += sorted_src[d];
+                do_ += sorted_dst[d];
+                break;
+            } else {
+                so -= sorted_src[d] * (sorted_dims[d] as isize - 1);
+                do_ -= sorted_dst[d] * (sorted_dims[d] as isize - 1);
+                idx[d] = 0;
+            }
+        }
+    }
+}
+
+/// Parallel version of [`copy_strided_src_order`].
+///
+/// Outer dimensions (by source stride) are split across rayon threads; each
+/// thread runs a sequential odometer over the inner dimensions.
+/// Falls back to the single-threaded version for small tensors.
+#[cfg(feature = "parallel")]
+unsafe fn copy_strided_src_order_par<T: Copy>(
+    src_ptr: *const T,
+    dst_ptr: *mut T,
+    dims: &[usize],
+    src_strides: &[isize],
+    dst_strides: &[isize],
+) {
+    let ndim = dims.len();
+    let n: usize = dims.iter().product();
+    if n == 0 {
+        return;
+    }
+
+    // Fall back to sequential when parallelism would add overhead without gain.
+    const PAR_THRESHOLD: usize = 1 << 20; // 1M elements
+    if n < PAR_THRESHOLD || rayon::current_num_threads() <= 1 {
+        copy_strided_src_order(src_ptr, dst_ptr, dims, src_strides, dst_strides);
+        return;
+    }
+
+    // Sort dimensions by source stride (ascending) → innermost = smallest src stride
+    let mut dim_order: Vec<usize> = (0..ndim).filter(|&i| dims[i] > 1).collect();
+    dim_order.sort_by_key(|&i| src_strides[i].unsigned_abs());
+
+    let sorted_dims: Vec<usize> = dim_order.iter().map(|&i| dims[i]).collect();
+    let sorted_src: Vec<isize> = dim_order.iter().map(|&i| src_strides[i]).collect();
+    let sorted_dst: Vec<isize> = dim_order.iter().map(|&i| dst_strides[i]).collect();
+    let nd = sorted_dims.len();
+
+    // Peel outer dims until we have enough parallel tasks (>= 4× threads).
+    let min_tasks = rayon::current_num_threads() * 4;
+    let mut split_at = nd; // index into sorted arrays: [0..split_at) inner, [split_at..nd) outer
+    let mut par_count: usize = 1;
+    while split_at > 0 && par_count < min_tasks {
+        split_at -= 1;
+        par_count *= sorted_dims[split_at];
+    }
+
+    let inner_n: usize = sorted_dims[..split_at].iter().product::<usize>().max(1);
+
+    // Convert pointers to usize for Send (same pattern as strided-perm).
+    let src_addr = src_ptr as usize;
+    let dst_addr = dst_ptr as usize;
+
+    let outer_dims = sorted_dims[split_at..].to_vec();
+    let outer_src = sorted_src[split_at..].to_vec();
+    let outer_dst = sorted_dst[split_at..].to_vec();
+    let inner_dims = sorted_dims[..split_at].to_vec();
+    let inner_src = sorted_src[..split_at].to_vec();
+    let inner_dst = sorted_dst[..split_at].to_vec();
+
+    (0..par_count).into_par_iter().for_each(|outer_idx| {
+        // Compute base offsets from outer multi-index.
+        let mut src_off: isize = 0;
+        let mut dst_off: isize = 0;
+        let mut rem = outer_idx;
+        for d in 0..outer_dims.len() {
+            let i = rem % outer_dims[d];
+            rem /= outer_dims[d];
+            src_off += i as isize * outer_src[d];
+            dst_off += i as isize * outer_dst[d];
+        }
+
+        let sp = (src_addr as isize + src_off * std::mem::size_of::<T>() as isize) as *const T;
+        let dp = (dst_addr as isize + dst_off * std::mem::size_of::<T>() as isize) as *mut T;
+
+        if split_at == 0 {
+            // No inner dims — single element per task.
+            unsafe { *dp = *sp };
+            return;
+        }
+
+        // Sequential odometer over inner dims.
+        let mut idx = vec![0usize; split_at];
+        let mut so: isize = 0;
+        let mut do_: isize = 0;
+
+        for _ in 0..inner_n {
+            unsafe { *dp.offset(do_) = *sp.offset(so) };
+
+            for d in 0..split_at {
+                idx[d] += 1;
+                if idx[d] < inner_dims[d] {
+                    so += inner_src[d];
+                    do_ += inner_dst[d];
+                    break;
+                } else {
+                    so -= inner_src[d] * (inner_dims[d] as isize - 1);
+                    do_ -= inner_dst[d] * (inner_dims[d] as isize - 1);
+                    idx[d] = 0;
+                }
+            }
+        }
+    });
+}
+
 /// Prepare an owned input array for GEMM.
 ///
 /// Expects batch-last canonical order: `[group1..., group2..., batch...]`.
@@ -433,7 +603,37 @@ pub fn prepare_input_owned(
     if needs_copy {
         let m: usize = group1_dims.iter().product::<usize>().max(1);
         let (mut buf, buf_is_pooled) = alloc_col_major_uninit_with_pool(&dims);
-        strided_kernel::copy_into_col_major(&mut buf.view_mut(), &arr.view())?;
+        // Use source-stride-order copy instead of HPTT (strided_kernel::copy_into_col_major).
+        // einsum2 always produces col-major output and only metadata permutations
+        // are applied between steps, so the source is physically contiguous but has
+        // scattered strides. HPTT iterates in destination order → scattered reads
+        // from cold L3 cache. Source-order iteration gives sequential reads that
+        // exploit the hardware prefetcher. See doc comment on copy_strided_src_order.
+        {
+            let dst_strides = buf.strides().to_vec();
+            unsafe {
+                #[cfg(feature = "parallel")]
+                {
+                    copy_strided_src_order_par(
+                        arr.view().ptr(),
+                        buf.view_mut().as_mut_ptr(),
+                        &dims,
+                        &strides,
+                        &dst_strides,
+                    );
+                }
+                #[cfg(not(feature = "parallel"))]
+                {
+                    copy_strided_src_order(
+                        arr.view().ptr(),
+                        buf.view_mut().as_mut_ptr(),
+                        &dims,
+                        &strides,
+                        &dst_strides,
+                    );
+                }
+            }
+        }
         let ptr = buf.view().ptr();
         let batch_strides = buf.strides()[n_inner..].to_vec();
         let row_stride = if m == 0 { 0 } else { 1isize };
diff --git a/strided-opteinsum/Cargo.toml b/strided-opteinsum/Cargo.toml
index d6a84f5..70cf2c5 100644
--- a/strided-opteinsum/Cargo.toml
+++ b/strided-opteinsum/Cargo.toml
@@ -15,6 +15,7 @@ thiserror = "1.0"
 
 [features]
 default = ["faer"]
 faer = ["strided-einsum2/faer", "strided-einsum2/faer-traits"]
+parallel = ["strided-einsum2/parallel"]
 blas = ["strided-einsum2/blas"]
 blas-inject = ["strided-einsum2/blas-inject"]