diff --git a/docs/superpowers/plans/2026-05-22-primitive-layer-deepening.md b/docs/superpowers/plans/2026-05-22-primitive-layer-deepening.md new file mode 100644 index 0000000..1a3336c --- /dev/null +++ b/docs/superpowers/plans/2026-05-22-primitive-layer-deepening.md @@ -0,0 +1,272 @@ +# Primitive Layer Deepening Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** 收敛 `online_softmax`、`matmul`、`tile_io` 的 shallow standalone primitive layer,把 host-side validation、kernel launch、type conversion 统一到更深的内部 module。 + +**Architecture:** 保持 public header 暂时稳定,不在这一轮同时引入 spec/API breakage。先把重复的 host-side validation、dynamic shared-memory launch、kernel launch error handling 折叠进一个内部 helper seam,再把 FP32/FP16 转换收敛到单一 type adapter policy;这样 wrapper 继续兼容,但 implementation 更深、locality 更好。 + +**Tech Stack:** CUDA C++17, CMake presets, GoogleTest + +--- + +## File Map + +- Create: `src/kernels/primitive_api_utils.cuh` +- Create: `src/kernels/impl/type_adapter.cuh` +- Modify: `src/kernels/online_softmax.cu` +- Modify: `src/kernels/matmul.cu` +- Modify: `src/kernels/tile_io.cu` +- Modify: `src/kernels/impl/tile_io.cuh` +- Modify: `src/forward/flash_attention_forward_typed.cu` +- Modify: `src/backward/flash_attention_backward_typed.cu` +- Test: `tests/unit/test_online_softmax.cu` +- Test: `tests/unit/test_matmul.cu` +- Test: `tests/unit/test_dtype.cu` + +### Task 1: Centralize primitive wrapper validation and launch helpers + +**Files:** +- Create: `src/kernels/primitive_api_utils.cuh` +- Modify: `src/kernels/online_softmax.cu` +- Modify: `src/kernels/matmul.cu` +- Modify: `src/kernels/tile_io.cu` + +- [ ] **Step 1: Write the failing test** + +Use existing characterization tests as the guard surface: + +```cpp +EXPECT_EQ(kernels::online_softmax_forward(nullptr, d_valid, d_valid, 4, 4, 2, 
stream), + FlashAttentionError::NULL_POINTER); +EXPECT_EQ(kernels::matmul_ABt<64, 64, 32>(nullptr, d_valid, d_valid, 1.0f, stream), + FlashAttentionError::NULL_POINTER); +EXPECT_EQ(kernels::load_tile<64, 64>(d_src, d_dst, -1, 0, 128, 128, 128, stream), + FlashAttentionError::INVALID_DIMENSION); +``` + +- [ ] **Step 2: Run test to verify it fails** + +Run: + +```bash +ctest --preset release --output-on-failure -R "OnlineSoftmaxTest|MatmulTest|TileIOTest" +``` + +Expected: 当前机器无 `nvcc`,会卡在 configure/build;在有 CUDA 环境时,这一步用于证明 refactor 前后行为回归可观测。 + +- [ ] **Step 3: Write minimal implementation** + +Create one internal helper seam: + +```cpp +inline FlashAttentionError validate_non_null(std::initializer_list ptrs); +inline FlashAttentionError validate_positive_dimensions(std::initializer_list dims); +inline FlashAttentionError validate_tile_window(int row_start, int col_start, + int max_rows, int max_cols, int stride); +template +inline FlashAttentionError prepare_kernel_launch(KernelFunc kernel, size_t smem_size); +inline FlashAttentionError finish_kernel_launch(); +``` + +Then replace file-local helpers like: + +```cpp +FlashAttentionError err = detail::validate_non_null({A, B, C}); +if (err != FlashAttentionError::SUCCESS) return err; +``` + +- [ ] **Step 4: Run test to verify it passes** + +Run: + +```bash +ctest --preset release --output-on-failure -R "OnlineSoftmaxTest|MatmulTest|TileIOTest" +``` + +Expected: 在有 CUDA 环境时,既有 null/dimension tests 继续 PASS。 + +- [ ] **Step 5: Commit** + +```bash +git add src/kernels/primitive_api_utils.cuh src/kernels/online_softmax.cu src/kernels/matmul.cu src/kernels/tile_io.cu +git commit -m "refactor(kernels): centralize primitive wrapper helpers" +``` + +### Task 2: Unify FP32/FP16 conversion policy + +**Files:** +- Create: `src/kernels/impl/type_adapter.cuh` +- Modify: `src/kernels/impl/tile_io.cuh` +- Modify: `src/forward/flash_attention_forward_typed.cu` +- Modify: `src/backward/flash_attention_backward_typed.cu` +- Test: 
`tests/unit/test_dtype.cu` + +- [ ] **Step 1: Write the failing test** + +Use the current FP16 characterization tests: + +```cpp +TEST(DTypeTest, FP16ForwardMatchesFP32); +TEST(DTypeTest, FP16BackwardMatchesFP32); +``` + +Keep the finite-gradient checks as explicit acceptance criteria: + +```cpp +EXPECT_TRUE(std::isfinite(__half2float(h_dQ[i]))); +EXPECT_TRUE(std::isfinite(__half2float(h_dK[i]))); +EXPECT_TRUE(std::isfinite(__half2float(h_dV[i]))); +``` + +- [ ] **Step 2: Run test to verify it fails** + +Run: + +```bash +ctest --preset release --output-on-failure -R DTypeTest +``` + +Expected: 当前机器依旧被 `nvcc` 阻塞;在有 CUDA 环境时,这一步用于确认 refactor 没把 FP16/FP32 路径改坏。 + +- [ ] **Step 3: Write minimal implementation** + +Create one trait: + +```cpp +template +struct TypeAdapter; + +template <> +struct TypeAdapter { + __device__ static float to_compute(float value) { return value; } + __device__ static float from_compute(float value) { return value; } +}; + +template <> +struct TypeAdapter { + __device__ static float to_compute(half value) { return __half2float(value); } + __device__ static half from_compute(float value) { return __float2half(value); } +}; +``` + +Then replace scattered conversions like: + +```cpp +sum += impl::TypeAdapter::to_compute(dO_row[d]) * + impl::TypeAdapter::to_compute(O_row[d]); +L_ptr[global_row] = + impl::TypeAdapter::from_compute(m_tile[row] + logf(l_tile[row])); +``` + +- [ ] **Step 4: Run test to verify it passes** + +Run: + +```bash +ctest --preset release --output-on-failure -R DTypeTest +``` + +Expected: 在有 CUDA 环境时,FP16/FP32 tests PASS。 + +- [ ] **Step 5: Commit** + +```bash +git add src/kernels/impl/type_adapter.cuh src/kernels/impl/tile_io.cuh src/forward/flash_attention_forward_typed.cu src/backward/flash_attention_backward_typed.cu +git commit -m "refactor(kernels): unify primitive type conversion" +``` + +### Task 3: Preserve and extend the regression surface + +**Files:** +- Modify: `tests/unit/test_online_softmax.cu` +- Modify: 
`tests/unit/test_matmul.cu` + +- [ ] **Step 1: Write the failing test** + +Keep the three key regressions: + +```cpp +TEST_F(OnlineSoftmaxTest, Forward_MultiBlockCrossWarpMatchesReference); +TEST_F(OnlineSoftmaxTest, FinalizeNullNormalizerReturnsError); +TEST_F(MatmulTest, ABt_HeadDim128_LargeTile); +``` + +- [ ] **Step 2: Run test to verify it fails** + +Run: + +```bash +ctest --preset release --output-on-failure -R "OnlineSoftmaxTest|MatmulTest" +``` + +Expected: 在有 CUDA 环境时,旧实现会在 cross-warp/multi-block correctness 和 large-smem launch 上失败。 + +- [ ] **Step 3: Write minimal implementation** + +No new production API. Just keep tests aligned with the bug fixes: + +```cpp +EXPECT_NEAR(h_output[i], h_output_expected[i], 1e-4f); +ASSERT_EQ(err, FlashAttentionError::SUCCESS); +``` + +- [ ] **Step 4: Run test to verify it passes** + +Run: + +```bash +ctest --preset release --output-on-failure -R "OnlineSoftmaxTest|MatmulTest" +``` + +Expected: 在有 CUDA 环境时,new regressions PASS。 + +- [ ] **Step 5: Commit** + +```bash +git add tests/unit/test_online_softmax.cu tests/unit/test_matmul.cu +git commit -m "test(kernels): lock down primitive wrapper regressions" +``` + +### Task 4: Final hygiene and verification + +**Files:** +- Modify: all touched files + +- [ ] **Step 1: Format modified files** + +Run: + +```bash +clang-format -i src/kernels/primitive_api_utils.cuh src/kernels/impl/type_adapter.cuh src/kernels/impl/tile_io.cuh src/kernels/online_softmax.cu src/kernels/matmul.cu src/kernels/tile_io.cu src/forward/flash_attention_forward_typed.cu src/backward/flash_attention_backward_typed.cu tests/unit/test_online_softmax.cu tests/unit/test_matmul.cu +``` + +- [ ] **Step 2: Run diff sanity checks** + +Run: + +```bash +git --no-pager diff --check +git --no-pager diff --stat +``` + +Expected: no whitespace errors; diff concentrated in primitive-layer files. 
+ +- [ ] **Step 3: Run repo verification** + +Run: + +```bash +cmake --preset release +cmake --build --preset release +ctest --preset release --output-on-failure +``` + +Expected: 当前机器因缺 `nvcc` 仍会阻塞在 configure;有 CUDA 环境后,这三步必须全绿。 + +- [ ] **Step 4: Commit** + +```bash +git add src tests docs/superpowers/plans/2026-05-22-primitive-layer-deepening.md +git commit -m "refactor(kernels): deepen standalone primitive layer" +``` diff --git a/src/backward/flash_attention_backward_typed.cu b/src/backward/flash_attention_backward_typed.cu index 3102458..fa71c70 100644 --- a/src/backward/flash_attention_backward_typed.cu +++ b/src/backward/flash_attention_backward_typed.cu @@ -28,7 +28,8 @@ __global__ void __launch_bounds__(128) float sum = 0.0f; #pragma unroll for (int d = 0; d < HEAD_DIM; d++) { - sum += impl::to_float(dO_row[d]) * impl::to_float(O_row[d]); + sum += impl::TypeAdapter::to_compute(dO_row[d]) * + impl::TypeAdapter::to_compute(O_row[d]); } D[batch_head_idx * seq_len + row_idx] = sum; @@ -81,7 +82,9 @@ __global__ void __launch_bounds__(128) } for (int i = tid; i < BLOCK_M; i += num_threads) { int global_idx = q_start + i; - L_tile[i] = (global_idx < seq_len) ? impl::to_float(L_ptr[global_idx]) : 0.0f; + L_tile[i] = (global_idx < seq_len) + ? impl::TypeAdapter::to_compute(L_ptr[global_idx]) + : 0.0f; D_tile[i] = (global_idx < seq_len) ? D_ptr[global_idx] : 0.0f; } __syncthreads(); @@ -205,7 +208,9 @@ __global__ void __launch_bounds__(128) for (int i = tid; i < BLOCK_M; i += num_threads) { int global_idx = q_start + i; - L_tile[i] = (global_idx < seq_len) ? impl::to_float(L_ptr[global_idx]) : 0.0f; + L_tile[i] = (global_idx < seq_len) + ? impl::TypeAdapter::to_compute(L_ptr[global_idx]) + : 0.0f; D_tile[i] = (global_idx < seq_len) ? 
D_ptr[global_idx] : 0.0f; } __syncthreads(); diff --git a/src/forward/flash_attention_forward_typed.cu b/src/forward/flash_attention_forward_typed.cu index ac6c564..0761824 100644 --- a/src/forward/flash_attention_forward_typed.cu +++ b/src/forward/flash_attention_forward_typed.cu @@ -163,11 +163,12 @@ __global__ void __launch_bounds__(128) float l_inv = 1.0f / l_tile[row]; for (int d = 0; d < HEAD_DIM; d++) { O_ptr[global_row * HEAD_DIM + d] = - InputT(O_tile[row * HEAD_DIM + d] * l_inv); // Implicit float->half if needed + impl::TypeAdapter::from_compute(O_tile[row * HEAD_DIM + d] * l_inv); } // Store logsumexp for backward pass - L_ptr[global_row] = InputT(m_tile[row] + logf(l_tile[row])); + L_ptr[global_row] = + impl::TypeAdapter::from_compute(m_tile[row] + logf(l_tile[row])); } } diff --git a/src/kernels/impl/online_softmax.cuh b/src/kernels/impl/online_softmax.cuh index cb4067f..2a0994e 100644 --- a/src/kernels/impl/online_softmax.cuh +++ b/src/kernels/impl/online_softmax.cuh @@ -120,9 +120,12 @@ __device__ __forceinline__ float block_reduce_max(float val, float* shared) { for (int offset = 16; offset > 0; offset /= 2) { val = fmaxf(val, __shfl_xor_sync(0xffffffff, val, offset)); } + if (lane == 0) { + shared[0] = val; + } } - - return val; + __syncthreads(); + return shared[0]; } /// Block-level sum reduction using shared memory. 
@@ -147,9 +150,12 @@ __device__ __forceinline__ float block_reduce_sum(float val, float* shared) { for (int offset = 16; offset > 0; offset /= 2) { val += __shfl_xor_sync(0xffffffff, val, offset); } + if (lane == 0) { + shared[0] = val; + } } - - return val; + __syncthreads(); + return shared[0]; } } // namespace impl diff --git a/src/kernels/impl/tile_io.cuh b/src/kernels/impl/tile_io.cuh index 71ff0c0..d9ab4ad 100644 --- a/src/kernels/impl/tile_io.cuh +++ b/src/kernels/impl/tile_io.cuh @@ -4,24 +4,11 @@ #include #include +#include "type_adapter.cuh" + namespace cuflash { namespace impl { -// Type conversion utilities for unified FP32/FP16 kernels -__device__ __forceinline__ float to_float(float val) { - return val; -} -__device__ __forceinline__ float to_float(half val) { - return __half2float(val); -} - -__device__ __forceinline__ void store_float(float* ptr, float val) { - *ptr = val; -} -__device__ __forceinline__ void store_float(half* ptr, float val) { - *ptr = __float2half(val); -} - // Check alignment for float4 vectorized loads (16-byte alignment) __device__ __forceinline__ bool is_aligned_16(const void* ptr) { return (reinterpret_cast(ptr) & 0xF) == 0; @@ -137,16 +124,18 @@ __device__ __forceinline__ void load_tile_to_shared(const half* __restrict__ src if (global_row < max_rows && global_col + 1 < max_cols) { half2 val = *reinterpret_cast(&src[global_row * src_stride + global_col]); - dst[local_row * BLOCK_COLS + local_col] = __half2float(val.x); - dst[local_row * BLOCK_COLS + local_col + 1] = __half2float(val.y); + dst[local_row * BLOCK_COLS + local_col] = TypeAdapter::to_compute(val.x); + dst[local_row * BLOCK_COLS + local_col + 1] = + TypeAdapter::to_compute(val.y); } else if (global_row < max_rows) { dst[local_row * BLOCK_COLS + local_col] = - (global_col < max_cols) - ? __half2float(src[global_row * src_stride + global_col]) - : 0.0f; + (global_col < max_cols) ? 
TypeAdapter::to_compute( + src[global_row * src_stride + global_col]) + : 0.0f; dst[local_row * BLOCK_COLS + local_col + 1] = (global_col + 1 < max_cols) - ? __half2float(src[global_row * src_stride + global_col + 1]) + ? TypeAdapter::to_compute( + src[global_row * src_stride + global_col + 1]) : 0.0f; } else { dst[local_row * BLOCK_COLS + local_col] = 0.0f; @@ -162,7 +151,7 @@ __device__ __forceinline__ void load_tile_to_shared(const half* __restrict__ src if (global_row < max_rows && global_col < max_cols) { dst[local_row * BLOCK_COLS + local_col] = - __half2float(src[global_row * src_stride + global_col]); + TypeAdapter::to_compute(src[global_row * src_stride + global_col]); } else { dst[local_row * BLOCK_COLS + local_col] = 0.0f; } @@ -177,7 +166,7 @@ __device__ __forceinline__ void load_tile_to_shared(const half* __restrict__ src if (global_row < max_rows && global_col < max_cols) { dst[local_row * BLOCK_COLS + local_col] = - __half2float(src[global_row * src_stride + global_col]); + TypeAdapter::to_compute(src[global_row * src_stride + global_col]); } else { dst[local_row * BLOCK_COLS + local_col] = 0.0f; } @@ -285,16 +274,19 @@ __device__ __forceinline__ void store_tile_from_shared(const float* __restrict__ if (global_row < max_rows && global_col + 1 < max_cols) { half2 val; - val.x = __float2half(src[local_row * BLOCK_COLS + local_col]); - val.y = __float2half(src[local_row * BLOCK_COLS + local_col + 1]); + val.x = + TypeAdapter::from_compute(src[local_row * BLOCK_COLS + local_col]); + val.y = TypeAdapter::from_compute( + src[local_row * BLOCK_COLS + local_col + 1]); *reinterpret_cast(&dst[global_row * dst_stride + global_col]) = val; } else if (global_row < max_rows) { if (global_col < max_cols) - dst[global_row * dst_stride + global_col] = - __float2half(src[local_row * BLOCK_COLS + local_col]); + dst[global_row * dst_stride + global_col] = TypeAdapter::from_compute( + src[local_row * BLOCK_COLS + local_col]); if (global_col + 1 < max_cols) dst[global_row 
* dst_stride + global_col + 1] = - __float2half(src[local_row * BLOCK_COLS + local_col + 1]); + TypeAdapter::from_compute( + src[local_row * BLOCK_COLS + local_col + 1]); } } } else { @@ -306,7 +298,7 @@ __device__ __forceinline__ void store_tile_from_shared(const float* __restrict__ if (global_row < max_rows && global_col < max_cols) { dst[global_row * dst_stride + global_col] = - __float2half(src[local_row * BLOCK_COLS + local_col]); + TypeAdapter::from_compute(src[local_row * BLOCK_COLS + local_col]); } } } @@ -319,7 +311,7 @@ __device__ __forceinline__ void store_tile_from_shared(const float* __restrict__ if (global_row < max_rows && global_col < max_cols) { dst[global_row * dst_stride + global_col] = - __float2half(src[local_row * BLOCK_COLS + local_col]); + TypeAdapter::from_compute(src[local_row * BLOCK_COLS + local_col]); } } } diff --git a/src/kernels/impl/type_adapter.cuh b/src/kernels/impl/type_adapter.cuh new file mode 100644 index 0000000..2d10beb --- /dev/null +++ b/src/kernels/impl/type_adapter.cuh @@ -0,0 +1,24 @@ +#pragma once + +#include + +namespace cuflash { +namespace impl { + +template +struct TypeAdapter; + +template<> +struct TypeAdapter { + __device__ __forceinline__ static float to_compute(float value) { return value; } + __device__ __forceinline__ static float from_compute(float value) { return value; } +}; + +template<> +struct TypeAdapter { + __device__ __forceinline__ static float to_compute(half value) { return __half2float(value); } + __device__ __forceinline__ static half from_compute(float value) { return __float2half(value); } +}; + +} // namespace impl +} // namespace cuflash diff --git a/src/kernels/matmul.cu b/src/kernels/matmul.cu index 29224ef..6e95bfe 100644 --- a/src/kernels/matmul.cu +++ b/src/kernels/matmul.cu @@ -3,6 +3,7 @@ #include "cuflash/kernels/matmul.cuh" #include "impl/tile_io.cuh" +#include "primitive_api_utils.cuh" namespace cuflash { namespace kernels { @@ -149,88 +150,84 @@ __global__ void 
matmul_AtB_kernel(const float* __restrict__ A, const float* __re // Host Entry Points // ============================================================================= -// Validation helper -static FlashAttentionError validate_matmul_params(const float* A, const float* B, const float* C) { - if (!A || !B || !C) { - return FlashAttentionError::NULL_POINTER; - } - return FlashAttentionError::SUCCESS; -} - // C = A @ B^T template FlashAttentionError matmul_ABt(const float* A, const float* B, float* C, float scale, cudaStream_t stream) { - FlashAttentionError err = validate_matmul_params(A, B, C); + FlashAttentionError err = detail::validate_non_null({A, B, C}); if (err != FlashAttentionError::SUCCESS) { return err; } size_t smem_size = (M * K + N * K + M * N) * sizeof(float); + FlashAttentionError status = + detail::prepare_kernel_launch(matmul_ABt_kernel, smem_size); + if (status != FlashAttentionError::SUCCESS) { + return status; + } matmul_ABt_kernel<<<1, MATMUL_THREADS, smem_size, stream>>>(A, B, C, scale); - cudaError_t cuda_err = cudaGetLastError(); - if (cuda_err != cudaSuccess) { - return FlashAttentionError::CUDA_ERROR; - } - return FlashAttentionError::SUCCESS; + return detail::finish_kernel_launch(); } // C = A @ B template FlashAttentionError matmul_AB(const float* A, const float* B, float* C, float scale, cudaStream_t stream) { - FlashAttentionError err = validate_matmul_params(A, B, C); + FlashAttentionError err = detail::validate_non_null({A, B, C}); if (err != FlashAttentionError::SUCCESS) { return err; } size_t smem_size = (M * K + K * N + M * N) * sizeof(float); + FlashAttentionError status = + detail::prepare_kernel_launch(matmul_AB_kernel, smem_size); + if (status != FlashAttentionError::SUCCESS) { + return status; + } matmul_AB_kernel<<<1, MATMUL_THREADS, smem_size, stream>>>(A, B, C, scale); - cudaError_t cuda_err = cudaGetLastError(); - if (cuda_err != cudaSuccess) { - return FlashAttentionError::CUDA_ERROR; - } - return 
FlashAttentionError::SUCCESS; + return detail::finish_kernel_launch(); } // C += A @ B template FlashAttentionError matmul_AB_acc(const float* A, const float* B, float* C, float scale, cudaStream_t stream) { - FlashAttentionError err = validate_matmul_params(A, B, C); + FlashAttentionError err = detail::validate_non_null({A, B, C}); if (err != FlashAttentionError::SUCCESS) { return err; } size_t smem_size = (M * K + K * N + M * N) * sizeof(float); + FlashAttentionError status = + detail::prepare_kernel_launch(matmul_AB_acc_kernel, smem_size); + if (status != FlashAttentionError::SUCCESS) { + return status; + } matmul_AB_acc_kernel<<<1, MATMUL_THREADS, smem_size, stream>>>(A, B, C, scale); - cudaError_t cuda_err = cudaGetLastError(); - if (cuda_err != cudaSuccess) { - return FlashAttentionError::CUDA_ERROR; - } - return FlashAttentionError::SUCCESS; + return detail::finish_kernel_launch(); } // C = A^T @ B template FlashAttentionError matmul_AtB(const float* A, const float* B, float* C, float scale, cudaStream_t stream) { - FlashAttentionError err = validate_matmul_params(A, B, C); + FlashAttentionError err = detail::validate_non_null({A, B, C}); if (err != FlashAttentionError::SUCCESS) { return err; } size_t smem_size = (K * M + K * N + M * N) * sizeof(float); + FlashAttentionError status = + detail::prepare_kernel_launch(matmul_AtB_kernel, smem_size); + if (status != FlashAttentionError::SUCCESS) { + return status; + } matmul_AtB_kernel<<<1, MATMUL_THREADS, smem_size, stream>>>(A, B, C, scale); - cudaError_t cuda_err = cudaGetLastError(); - if (cuda_err != cudaSuccess) { - return FlashAttentionError::CUDA_ERROR; - } - return FlashAttentionError::SUCCESS; + return detail::finish_kernel_launch(); } // ============================================================================= diff --git a/src/kernels/online_softmax.cu b/src/kernels/online_softmax.cu index 957567a..d7f366e 100644 --- a/src/kernels/online_softmax.cu +++ b/src/kernels/online_softmax.cu @@ -3,6 +3,7 
@@ #include "cuflash/kernels/online_softmax.cuh" #include "impl/online_softmax.cuh" +#include "primitive_api_utils.cuh" namespace cuflash { namespace kernels { @@ -79,10 +80,14 @@ __global__ void online_softmax_forward_kernel(const float* __restrict__ input, // Use shared memory for reductions float* reduce_smem = smem; + float* state_smem = reduce_smem + SOFTMAX_THREADS / 32; // Process blocks - impl::OnlineSoftmaxState state; - state.init(); + if (threadIdx.x == 0) { + state_smem[0] = -INFINITY; + state_smem[1] = 0.0f; + } + __syncthreads(); const float* row_input = input + row * cols; float* row_output = output + row * cols; @@ -104,8 +109,6 @@ __global__ void online_softmax_forward_kernel(const float* __restrict__ input, block_max = fmaxf(block_max, val); } - block_max = impl::block_reduce_sum(block_max, reduce_smem); - // Actually need max reduction - reuse shared memory block_max = impl::block_reduce_max(block_max, reduce_smem); // Compute exp sum @@ -117,13 +120,18 @@ __global__ void online_softmax_forward_kernel(const float* __restrict__ input, // Update state if (threadIdx.x == 0) { + impl::OnlineSoftmaxState state; + state.m = state_smem[0]; + state.l = state_smem[1]; state.update(block_max, block_sum); + state_smem[0] = state.m; + state_smem[1] = state.l; } __syncthreads(); } // Final normalization - float l_inv = state.get_normalizer(); + float l_inv = 1.0f / state_smem[1]; for (int b = 0; b < num_blocks; b++) { int start = b * BLOCK_SIZE; @@ -139,7 +147,7 @@ __global__ void online_softmax_forward_kernel(const float* __restrict__ input, block_max = impl::block_reduce_max(block_max, reduce_smem); // Compute and store output - float rescale = expf(block_max - state.m); + float rescale = expf(block_max - state_smem[0]); for (int i = threadIdx.x; i < block_len; i += blockDim.x) { float val = row_input[start + i]; row_output[start + i] = expf(val - block_max) * rescale * l_inv; @@ -148,7 +156,7 @@ __global__ void online_softmax_forward_kernel(const float* 
__restrict__ input, // Store logsumexp if (threadIdx.x == 0) { - logsumexp[row] = state.logsumexp(); + logsumexp[row] = state_smem[0] + logf(state_smem[1]); } } @@ -156,84 +164,70 @@ __global__ void online_softmax_forward_kernel(const float* __restrict__ input, // Host Entry Points // ============================================================================= -// Validation helper -static FlashAttentionError validate_online_softmax_params(const float* ptr, int rows) { - if (!ptr) { - return FlashAttentionError::NULL_POINTER; - } - if (rows <= 0) { - return FlashAttentionError::INVALID_DIMENSION; - } - return FlashAttentionError::SUCCESS; -} - // Init FlashAttentionError online_softmax_init(float* state_m, float* state_l, int rows, cudaStream_t stream) { - FlashAttentionError err = validate_online_softmax_params(state_m, rows); + FlashAttentionError err = detail::validate_non_null({state_m, state_l}); + if (err != FlashAttentionError::SUCCESS) + return err; + err = detail::validate_positive_dimensions({rows}); if (err != FlashAttentionError::SUCCESS) return err; - if (!state_l) - return FlashAttentionError::NULL_POINTER; int blocks = (rows + SOFTMAX_THREADS - 1) / SOFTMAX_THREADS; online_softmax_init_kernel<<>>(state_m, state_l, rows); - cudaError_t cuda_err = cudaGetLastError(); - return (cuda_err == cudaSuccess) ? 
FlashAttentionError::SUCCESS - : FlashAttentionError::CUDA_ERROR; + return detail::finish_kernel_launch(); } // Update FlashAttentionError online_softmax_update(const float* block_max, const float* block_sum, float* state_m, float* state_l, int rows, cudaStream_t stream) { - FlashAttentionError err = validate_online_softmax_params(block_max, rows); + FlashAttentionError err = detail::validate_non_null({block_max, block_sum, state_m, state_l}); + if (err != FlashAttentionError::SUCCESS) + return err; + err = detail::validate_positive_dimensions({rows}); if (err != FlashAttentionError::SUCCESS) return err; - if (!block_sum || !state_m || !state_l) - return FlashAttentionError::NULL_POINTER; int blocks = (rows + SOFTMAX_THREADS - 1) / SOFTMAX_THREADS; online_softmax_update_kernel<<>>(block_max, block_sum, state_m, state_l, rows); - cudaError_t cuda_err = cudaGetLastError(); - return (cuda_err == cudaSuccess) ? FlashAttentionError::SUCCESS - : FlashAttentionError::CUDA_ERROR; + return detail::finish_kernel_launch(); } // Finalize FlashAttentionError online_softmax_finalize(const float* state_m, const float* state_l, float* logsumexp, float* normalizer, int rows, cudaStream_t stream) { - FlashAttentionError err = validate_online_softmax_params(state_m, rows); + FlashAttentionError err = detail::validate_non_null({state_m, state_l, logsumexp, normalizer}); + if (err != FlashAttentionError::SUCCESS) + return err; + err = detail::validate_positive_dimensions({rows}); if (err != FlashAttentionError::SUCCESS) return err; - if (!state_l || !logsumexp) - return FlashAttentionError::NULL_POINTER; int blocks = (rows + SOFTMAX_THREADS - 1) / SOFTMAX_THREADS; online_softmax_finalize_kernel<<>>( state_m, state_l, logsumexp, normalizer, rows); - cudaError_t cuda_err = cudaGetLastError(); - return (cuda_err == cudaSuccess) ? 
FlashAttentionError::SUCCESS - : FlashAttentionError::CUDA_ERROR; + return detail::finish_kernel_launch(); } // Forward (convenience) FlashAttentionError online_softmax_forward(const float* input, float* output, float* logsumexp, int rows, int cols, int block_size, cudaStream_t stream) { - if (!input || !output || !logsumexp) { - return FlashAttentionError::NULL_POINTER; - } - if (rows <= 0 || cols <= 0 || block_size <= 0) { - return FlashAttentionError::INVALID_DIMENSION; - } + FlashAttentionError err = detail::validate_non_null({input, output, logsumexp}); + if (err != FlashAttentionError::SUCCESS) + return err; + err = detail::validate_positive_dimensions({rows, cols, block_size}); + if (err != FlashAttentionError::SUCCESS) + return err; - size_t smem_size = SOFTMAX_THREADS / 32 * sizeof(float); + size_t smem_size = (SOFTMAX_THREADS / 32 + 2) * sizeof(float); // Dispatch based on block size if (block_size <= 32) { @@ -247,9 +241,7 @@ FlashAttentionError online_softmax_forward(const float* input, float* output, fl <<>>(input, output, logsumexp, rows, cols); } - cudaError_t cuda_err = cudaGetLastError(); - return (cuda_err == cudaSuccess) ? 
FlashAttentionError::SUCCESS - : FlashAttentionError::CUDA_ERROR; + return detail::finish_kernel_launch(); } // Explicit template instantiations diff --git a/src/kernels/primitive_api_utils.cuh b/src/kernels/primitive_api_utils.cuh new file mode 100644 index 0000000..0b0daa6 --- /dev/null +++ b/src/kernels/primitive_api_utils.cuh @@ -0,0 +1,56 @@ +#pragma once + +#include + +#include "cuflash/flash_attention.h" +#include "kernel_launch_utils.cuh" + +namespace cuflash { +namespace kernels { +namespace detail { + +inline FlashAttentionError validate_non_null(std::initializer_list pointers) { + for (const void* pointer : pointers) { + if (pointer == nullptr) { + return FlashAttentionError::NULL_POINTER; + } + } + return FlashAttentionError::SUCCESS; +} + +inline FlashAttentionError validate_positive_dimensions(std::initializer_list dimensions) { + for (int dimension : dimensions) { + if (dimension <= 0) { + return FlashAttentionError::INVALID_DIMENSION; + } + } + return FlashAttentionError::SUCCESS; +} + +inline FlashAttentionError validate_tile_window(int row_start, int col_start, int max_rows, + int max_cols, int stride) { + FlashAttentionError status = validate_positive_dimensions({max_rows, max_cols, stride}); + if (status != FlashAttentionError::SUCCESS) { + return status; + } + + if (row_start < 0 || col_start < 0 || row_start >= max_rows || col_start >= max_cols) { + return FlashAttentionError::INVALID_DIMENSION; + } + + return FlashAttentionError::SUCCESS; +} + +template +inline FlashAttentionError prepare_kernel_launch(KernelFunc kernel, size_t smem_size) { + return prepare_dynamic_smem_launch(reinterpret_cast(kernel), smem_size); +} + +inline FlashAttentionError finish_kernel_launch() { + return (cudaGetLastError() == cudaSuccess) ? 
FlashAttentionError::SUCCESS + : FlashAttentionError::CUDA_ERROR; +} + +} // namespace detail +} // namespace kernels +} // namespace cuflash diff --git a/src/kernels/tile_io.cu b/src/kernels/tile_io.cu index f4abe9a..6cf7939 100644 --- a/src/kernels/tile_io.cu +++ b/src/kernels/tile_io.cu @@ -3,6 +3,7 @@ #include "cuflash/kernels/tile_io.cuh" #include "impl/tile_io.cuh" +#include "primitive_api_utils.cuh" namespace cuflash { namespace kernels { @@ -107,31 +108,16 @@ __global__ void load_store_roundtrip_kernel(const float* __restrict__ src, float // Host Entry Points // ============================================================================= -// Common validation -static FlashAttentionError validate_tile_params(int row_start, int col_start, int max_rows, - int max_cols, int stride) { - if (max_rows <= 0 || max_cols <= 0 || stride <= 0) { - return FlashAttentionError::INVALID_DIMENSION; - } - if (row_start < 0 || col_start < 0) { - return FlashAttentionError::INVALID_DIMENSION; - } - if (row_start >= max_rows || col_start >= max_cols) { - return FlashAttentionError::INVALID_DIMENSION; - } - return FlashAttentionError::SUCCESS; -} - // FP32 Load template FlashAttentionError load_tile(const float* src, float* dst, int row_start, int col_start, int max_rows, int max_cols, int src_stride, cudaStream_t stream) { - if (!src || !dst) { - return FlashAttentionError::NULL_POINTER; + FlashAttentionError err = detail::validate_non_null({src, dst}); + if (err != FlashAttentionError::SUCCESS) { + return err; } - FlashAttentionError err = - validate_tile_params(row_start, col_start, max_rows, max_cols, src_stride); + err = detail::validate_tile_window(row_start, col_start, max_rows, max_cols, src_stride); if (err != FlashAttentionError::SUCCESS) { return err; } @@ -140,23 +126,19 @@ FlashAttentionError load_tile(const float* src, float* dst, int row_start, int c load_tile_kernel_fp32<<<1, 128, smem_size, stream>>>( src, dst, row_start, col_start, max_rows, max_cols, 
src_stride); - cudaError_t cuda_err = cudaGetLastError(); - if (cuda_err != cudaSuccess) { - return FlashAttentionError::CUDA_ERROR; - } - return FlashAttentionError::SUCCESS; + return detail::finish_kernel_launch(); } // FP16 Load template FlashAttentionError load_tile(const half* src, float* dst, int row_start, int col_start, int max_rows, int max_cols, int src_stride, cudaStream_t stream) { - if (!src || !dst) { - return FlashAttentionError::NULL_POINTER; + FlashAttentionError err = detail::validate_non_null({src, dst}); + if (err != FlashAttentionError::SUCCESS) { + return err; } - FlashAttentionError err = - validate_tile_params(row_start, col_start, max_rows, max_cols, src_stride); + err = detail::validate_tile_window(row_start, col_start, max_rows, max_cols, src_stride); if (err != FlashAttentionError::SUCCESS) { return err; } @@ -165,23 +147,19 @@ FlashAttentionError load_tile(const half* src, float* dst, int row_start, int co load_tile_kernel_fp16<<<1, 128, smem_size, stream>>>( src, dst, row_start, col_start, max_rows, max_cols, src_stride); - cudaError_t cuda_err = cudaGetLastError(); - if (cuda_err != cudaSuccess) { - return FlashAttentionError::CUDA_ERROR; - } - return FlashAttentionError::SUCCESS; + return detail::finish_kernel_launch(); } // FP32 Store template FlashAttentionError store_tile(const float* src, float* dst, int row_start, int col_start, int max_rows, int max_cols, int dst_stride, cudaStream_t stream) { - if (!src || !dst) { - return FlashAttentionError::NULL_POINTER; + FlashAttentionError err = detail::validate_non_null({src, dst}); + if (err != FlashAttentionError::SUCCESS) { + return err; } - FlashAttentionError err = - validate_tile_params(row_start, col_start, max_rows, max_cols, dst_stride); + err = detail::validate_tile_window(row_start, col_start, max_rows, max_cols, dst_stride); if (err != FlashAttentionError::SUCCESS) { return err; } @@ -190,23 +168,19 @@ FlashAttentionError store_tile(const float* src, float* dst, int 
row_start, int store_tile_kernel_fp32<<<1, 128, smem_size, stream>>>( src, dst, row_start, col_start, max_rows, max_cols, dst_stride); - cudaError_t cuda_err = cudaGetLastError(); - if (cuda_err != cudaSuccess) { - return FlashAttentionError::CUDA_ERROR; - } - return FlashAttentionError::SUCCESS; + return detail::finish_kernel_launch(); } // FP16 Store template FlashAttentionError store_tile(const float* src, half* dst, int row_start, int col_start, int max_rows, int max_cols, int dst_stride, cudaStream_t stream) { - if (!src || !dst) { - return FlashAttentionError::NULL_POINTER; + FlashAttentionError err = detail::validate_non_null({src, dst}); + if (err != FlashAttentionError::SUCCESS) { + return err; } - FlashAttentionError err = - validate_tile_params(row_start, col_start, max_rows, max_cols, dst_stride); + err = detail::validate_tile_window(row_start, col_start, max_rows, max_cols, dst_stride); if (err != FlashAttentionError::SUCCESS) { return err; } @@ -215,11 +189,7 @@ FlashAttentionError store_tile(const float* src, half* dst, int row_start, int c store_tile_kernel_fp16<<<1, 128, smem_size, stream>>>( src, dst, row_start, col_start, max_rows, max_cols, dst_stride); - cudaError_t cuda_err = cudaGetLastError(); - if (cuda_err != cudaSuccess) { - return FlashAttentionError::CUDA_ERROR; - } - return FlashAttentionError::SUCCESS; + return detail::finish_kernel_launch(); } // Round-trip @@ -227,12 +197,12 @@ template FlashAttentionError load_store_tile_roundtrip(const float* src, float* dst, int row_start, int col_start, int max_rows, int max_cols, int stride, cudaStream_t stream) { - if (!src || !dst) { - return FlashAttentionError::NULL_POINTER; + FlashAttentionError err = detail::validate_non_null({src, dst}); + if (err != FlashAttentionError::SUCCESS) { + return err; } - FlashAttentionError err = - validate_tile_params(row_start, col_start, max_rows, max_cols, stride); + err = detail::validate_tile_window(row_start, col_start, max_rows, max_cols, stride); if 
(err != FlashAttentionError::SUCCESS) { return err; } @@ -241,11 +211,7 @@ FlashAttentionError load_store_tile_roundtrip(const float* src, float* dst, int load_store_roundtrip_kernel <<<1, 128, smem_size, stream>>>(src, dst, row_start, col_start, max_rows, max_cols, stride); - cudaError_t cuda_err = cudaGetLastError(); - if (cuda_err != cudaSuccess) { - return FlashAttentionError::CUDA_ERROR; - } - return FlashAttentionError::SUCCESS; + return detail::finish_kernel_launch(); } // ============================================================================= diff --git a/tests/unit/test_matmul.cu b/tests/unit/test_matmul.cu index 56fade0..58fb217 100644 --- a/tests/unit/test_matmul.cu +++ b/tests/unit/test_matmul.cu @@ -197,6 +197,42 @@ TEST_F(MatmulTest, ABt_HeadDim128_SmallTile) { cudaFree(d_C); } +TEST_F(MatmulTest, ABt_HeadDim128_LargeTile) { + constexpr int M = 64, N = 64, K = 128; + + std::vector h_A(M * K), h_B(N * K); + std::mt19937 gen(9); + std::uniform_real_distribution dist(-0.25f, 0.25f); + for (auto& v : h_A) + v = dist(gen); + for (auto& v : h_B) + v = dist(gen); + + float scale = 0.08839f; // ~1/sqrt(128) + auto h_C_expected = matmul_ABt_cpu(h_A, h_B, M, N, K, scale); + + float *d_A, *d_B, *d_C; + cudaMalloc(&d_A, M * K * sizeof(float)); + cudaMalloc(&d_B, N * K * sizeof(float)); + cudaMalloc(&d_C, M * N * sizeof(float)); + cudaMemcpy(d_A, h_A.data(), M * K * sizeof(float), cudaMemcpyHostToDevice); + cudaMemcpy(d_B, h_B.data(), N * K * sizeof(float), cudaMemcpyHostToDevice); + + FlashAttentionError err = kernels::matmul_ABt(d_A, d_B, d_C, scale, stream); + ASSERT_EQ(err, FlashAttentionError::SUCCESS); + + std::vector h_C(M * N); + cudaMemcpy(h_C.data(), d_C, M * N * sizeof(float), cudaMemcpyDeviceToHost); + + for (size_t i = 0; i < h_C.size(); i++) { + EXPECT_NEAR(h_C[i], h_C_expected[i], 1e-3f) << "Mismatch at index " << i; + } + + cudaFree(d_A); + cudaFree(d_B); + cudaFree(d_C); +} + // 
============================================================================= // matmul_AB Tests (Attention Output Computation) // ============================================================================= diff --git a/tests/unit/test_online_softmax.cu b/tests/unit/test_online_softmax.cu index f05475a..b3e9ccd 100644 --- a/tests/unit/test_online_softmax.cu +++ b/tests/unit/test_online_softmax.cu @@ -150,6 +150,48 @@ TEST_F(OnlineSoftmaxTest, Forward_LargeInput) { cudaFree(d_logsumexp); } +TEST_F(OnlineSoftmaxTest, Forward_MultiBlockCrossWarpMatchesReference) { + constexpr int ROWS = 3; + constexpr int COLS = 96; + + std::vector h_input(ROWS * COLS); + std::mt19937 gen(7); + std::uniform_real_distribution dist(-8.0f, 8.0f); + for (auto& v : h_input) { + v = dist(gen); + } + + auto h_output_expected = reference_softmax(h_input, ROWS, COLS); + auto h_logsumexp_expected = reference_logsumexp(h_input, ROWS, COLS); + + float *d_input, *d_output, *d_logsumexp; + cudaMalloc(&d_input, ROWS * COLS * sizeof(float)); + cudaMalloc(&d_output, ROWS * COLS * sizeof(float)); + cudaMalloc(&d_logsumexp, ROWS * sizeof(float)); + cudaMemcpy(d_input, h_input.data(), ROWS * COLS * sizeof(float), cudaMemcpyHostToDevice); + + FlashAttentionError err = + kernels::online_softmax_forward(d_input, d_output, d_logsumexp, ROWS, COLS, 64, stream); + ASSERT_EQ(err, FlashAttentionError::SUCCESS); + + std::vector h_output(ROWS * COLS), h_logsumexp(ROWS); + cudaMemcpy(h_output.data(), d_output, ROWS * COLS * sizeof(float), cudaMemcpyDeviceToHost); + cudaMemcpy(h_logsumexp.data(), d_logsumexp, ROWS * sizeof(float), cudaMemcpyDeviceToHost); + + for (size_t i = 0; i < h_output.size(); i++) { + EXPECT_NEAR(h_output[i], h_output_expected[i], 1e-4f) << "Softmax mismatch at index " << i; + } + + for (int r = 0; r < ROWS; r++) { + EXPECT_NEAR(h_logsumexp[r], h_logsumexp_expected[r], 1e-4f) + << "Logsumexp mismatch at row " << r; + } + + cudaFree(d_input); + cudaFree(d_output); + cudaFree(d_logsumexp); +} 
+ TEST_F(OnlineSoftmaxTest, Forward_NumericalStability) { // Test with large values that would cause overflow in naive implementation constexpr int ROWS = 2; @@ -301,6 +343,21 @@ TEST_F(OnlineSoftmaxTest, InvalidDimensionReturnsError) { cudaFree(d_valid); } +TEST_F(OnlineSoftmaxTest, FinalizeNullNormalizerReturnsError) { + float *d_state_m, *d_state_l, *d_logsumexp; + cudaMalloc(&d_state_m, 4 * sizeof(float)); + cudaMalloc(&d_state_l, 4 * sizeof(float)); + cudaMalloc(&d_logsumexp, 4 * sizeof(float)); + + FlashAttentionError err = + kernels::online_softmax_finalize(d_state_m, d_state_l, d_logsumexp, nullptr, 4, stream); + EXPECT_EQ(err, FlashAttentionError::NULL_POINTER); + + cudaFree(d_state_m); + cudaFree(d_state_l); + cudaFree(d_logsumexp); +} + // ============================================================================= // Sum-to-One Property Test // =============================================================================