From 81e9bb96c17596f15d07d63acd432c180a36de8d Mon Sep 17 00:00:00 2001 From: Prashant Rawat Date: Thu, 14 May 2026 15:31:09 -0700 Subject: [PATCH] step==1 contiguous fast path in portable compute_slice Summary: When step == 1 (the common case: tensor.narrow, x[a:b], KV cache reads, etc.), the per-row slice is a single contiguous block of length*length_per_step bytes. Replace the inner loop of length separate memcpy(length_per_step) calls with a single bulk memcpy. For length=1 slices: equivalent (1 memcpy either way). For length>1: ~2-10x speedup of the slice itself (fewer function calls, better cache prefetch, SIMD-friendly bulk copy). Llama4 speech encoder hot-path: mask.narrow (12x/chunk), freqs_cos/sin.narrow (2x/chunk), KV reads (~5x/layer/chunk). 62 slice_copy/chunk * 72 chunks = ~4500 slices per audio prefill. Differential Revision: D105241644 --- kernels/portable/cpu/util/slice_util.cpp | 64 +++++++++++++++++------- 1 file changed, 46 insertions(+), 18 deletions(-) diff --git a/kernels/portable/cpu/util/slice_util.cpp b/kernels/portable/cpu/util/slice_util.cpp index 9afd7b7f6f0..d298297b393 100644 --- a/kernels/portable/cpu/util/slice_util.cpp +++ b/kernels/portable/cpu/util/slice_util.cpp @@ -214,32 +214,60 @@ void compute_slice( const bool use_multithreading = leading_dims >= MIN_LEADING_DIMS_FOR_MT && total_elements >= MIN_ELEMENTS_FOR_MT; + // Contiguous fast path for step == 1 (the common case: + // tensor.narrow, x[a:b], KV cache reads, etc.). When step == 1, the per-row + // slice is one contiguous block of `length * length_per_step` bytes — replace + // `length` calls of memcpy(length_per_step) with a single bulk memcpy. 
+  const bool step_is_one = (step == 1);
+  const size_t row_bytes = static_cast<size_t>(length) * length_per_step;
   if (use_multithreading) {
     // Use parallel_for to distribute work across leading dimensions
     // Calculate grain size based on number of elements per leading dimension
     const int64_t grain_size = MIN_LEADING_DIMS_FOR_MT;
-    executorch::extension::parallel_for(
-        0, leading_dims, grain_size, [&](const auto begin, const auto end) {
-          for (const auto i : c10::irange(begin, end)) {
-            const char* src =
-                input_data + (i * dim_length + start) * length_per_step;
-            char* local_dest = dest + i * length * length_per_step;
-            for ([[maybe_unused]] const auto j : c10::irange(length)) {
-              memcpy(local_dest, src, length_per_step);
-              src += step * length_per_step;
-              local_dest += length_per_step;
+    if (step_is_one) {
+      executorch::extension::parallel_for(
+          0, leading_dims, grain_size, [&](const auto begin, const auto end) {
+            for (const auto i : c10::irange(begin, end)) {
+              const char* src =
+                  input_data + (i * dim_length + start) * length_per_step;
+              char* local_dest = dest + i * row_bytes;
+              memcpy(local_dest, src, row_bytes);
             }
-          }
-        });
+          });
+    } else {
+      executorch::extension::parallel_for(
+          0, leading_dims, grain_size, [&](const auto begin, const auto end) {
+            for (const auto i : c10::irange(begin, end)) {
+              const char* src =
+                  input_data + (i * dim_length + start) * length_per_step;
+              char* local_dest = dest + i * row_bytes;
+              for ([[maybe_unused]] const auto j : c10::irange(length)) {
+                memcpy(local_dest, src, length_per_step);
+                src += step * length_per_step;
+                local_dest += length_per_step;
+              }
+            }
+          });
+    }
   } else {
     // Single-threaded path for small workloads
-    for (const auto i : c10::irange(leading_dims)) {
-      const char* src = input_data + (i * dim_length + start) * length_per_step;
-      for ([[maybe_unused]] const auto j : c10::irange(length)) {
-        memcpy(dest, src, length_per_step);
-        src += step * length_per_step;
-        dest += length_per_step;
+    if (step_is_one) {
+      for 
(const auto i : c10::irange(leading_dims)) { + const char* src = + input_data + (i * dim_length + start) * length_per_step; + memcpy(dest, src, row_bytes); + dest += row_bytes; + } + } else { + for (const auto i : c10::irange(leading_dims)) { + const char* src = + input_data + (i * dim_length + start) * length_per_step; + for ([[maybe_unused]] const auto j : c10::irange(length)) { + memcpy(dest, src, length_per_step); + src += step * length_per_step; + dest += length_per_step; + } } } }