Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,3 @@ venv/

# Benchmark outputs
benchmark_results/

# Local AI preferences
CLAUDE.local.md
.claude/settings.local.json
.claude/skills/
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Simplified contributor guidance, Copilot instructions, and pull request metadata to match the real repository workflow.
- Reduced GitHub Pages scope to product documentation; changelog history now lives only in the root `CHANGELOG.md`.
- Removed docs-site spec mirrors, release-note mirrors, and AI planning artifacts that duplicated repository content.
- Removed stale README references to the deleted `AGENTS.md` workflow document and replaced leftover SDD branding with lean repository wording.
- Removed `.claude`/`CLAUDE.local.md` ignore rules that only served deleted AI tooling overlays.
- Simplified the GitHub Pages landing page links so the docs site no longer surfaces changelog navigation.
- Fixed CUDA preset validation on fresh Ubuntu environments by documenting and working through the required host-compiler/toolkit alignment.
- Fixed backward kernel dispatch for `head_dim=64` by using a smaller shared-memory tiling path that fits current CUDA limits.
- Fixed package-smoke consumption by exposing CUDA headers through the exported target and simplifying the downstream smoke project to a pure C++ consumer.
- Fixed the FP16 tile store test to validate half-rounding semantics instead of comparing against the original float with an unrealistically tight tolerance.
- Fixed the optional PyTorch comparison script to skip cleanly when `torch` is not installed.

---

Expand Down
6 changes: 5 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ add_library(cuflash_attn ${SOURCES})
target_link_libraries(cuflash_attn PUBLIC CUDA::cudart)
target_include_directories(cuflash_attn PUBLIC
$<BUILD_INTERFACE:${CMAKE_SOURCE_DIR}/include>
$<BUILD_INTERFACE:${CUDAToolkit_INCLUDE_DIRS}>
$<INSTALL_INTERFACE:${CUDAToolkit_INCLUDE_DIRS}>
$<INSTALL_INTERFACE:include>
)
# Private include directories for internal kernels
Expand Down Expand Up @@ -229,6 +231,8 @@ if(BUILD_TESTING AND BUILD_SHARED_LIBS)
-Dinstall_dir=${CUFLASH_PACKAGE_SMOKE_INSTALL_DIR}
-Dgenerator=${CMAKE_GENERATOR}
-Dbuild_type=${CMAKE_BUILD_TYPE}
-Dc_compiler=${CMAKE_C_COMPILER}
-Dcxx_compiler=${CMAKE_CXX_COMPILER}
-Dcuda_architectures=${CMAKE_CUDA_ARCHITECTURES}
-Dctest_command=${CMAKE_CTEST_COMMAND}
-P ${CMAKE_SOURCE_DIR}/cmake/run_package_smoke.cmake
Expand All @@ -244,7 +248,7 @@ if(BUILD_TESTING AND BUILD_SHARED_LIBS)
)
set_tests_properties(cuflash_attn_pytorch_comparison PROPERTIES
LABELS "integration;pytorch"
SKIP_RETURN_CODE 0
SKIP_RETURN_CODE 77
)
endif()
endif()
Expand Down
4 changes: 1 addition & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -383,8 +383,6 @@ clang-tidy src/api/flash_attention_api.cu -- -Iinclude

📋 **Detailed guidelines**: See [CONTRIBUTING.md](CONTRIBUTING.md)

🤖 **For AI Contributors**: Read [AGENTS.md](AGENTS.md) for SDD workflow instructions.

---

## 📄 License
Expand All @@ -408,5 +406,5 @@ See [CHANGELOG.md](CHANGELOG.md) for detailed version history and updates.

<p align="center">
<sub>Built with ❤️ for efficient attention computation</sub><br>
<sub>Spec-Driven Development · CUDA C++ · Open Source</sub>
<sub>Lean Reference Implementation · CUDA C++ · Open Source</sub>
</p>
4 changes: 1 addition & 3 deletions README.zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -383,8 +383,6 @@ clang-tidy src/api/flash_attention_api.cu -- -Iinclude

📋 **详细指南**: 请参见 [CONTRIBUTING.md](CONTRIBUTING.md)

🤖 **AI 贡献者**: 阅读 [AGENTS.md](AGENTS.md) 了解 SDD 工作流说明。

---

## 📄 许可证
Expand All @@ -408,5 +406,5 @@ clang-tidy src/api/flash_attention_api.cu -- -Iinclude

<p align="center">
<sub>用 ❤️ 打造的高效注意力计算</sub><br>
<sub>规范驱动开发 · CUDA C++ · 开源</sub>
<sub>精简参考实现 · CUDA C++ · 开源</sub>
</p>
6 changes: 6 additions & 0 deletions cmake/run_package_smoke.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ endif()
if(DEFINED build_type AND NOT build_type STREQUAL "")
list(APPEND configure_args -DCMAKE_BUILD_TYPE=${build_type})
endif()
if(DEFINED c_compiler AND NOT c_compiler STREQUAL "")
list(APPEND configure_args -DCMAKE_C_COMPILER=${c_compiler})
endif()
if(DEFINED cxx_compiler AND NOT cxx_compiler STREQUAL "")
list(APPEND configure_args -DCMAKE_CXX_COMPILER=${cxx_compiler})
endif()
if(DEFINED cuda_architectures AND NOT cuda_architectures STREQUAL "")
list(APPEND configure_args -DCMAKE_CUDA_ARCHITECTURES=${cuda_architectures})
endif()
Expand Down
6 changes: 0 additions & 6 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,6 @@ function setLanguage(lang) {
</svg>
<span>Releases</span>
</a>
<a href="https://github.com/AICL-Lab/cuflash-attn/blob/master/CHANGELOG.md" class="quick-link">
<svg viewBox="0 0 24 24" width="18" height="18" fill="currentColor">
<path d="M14 2H6c-1.1 0-2 .9-2 2v16c0 1.1.9 2 2 2h12c1.1 0 2-.9 2-2V8l-6-6zm4 18H6V4h7v5h5v11z"/>
</svg>
<span>Changelog</span>
</a>
</div>
</div>

Expand Down
84 changes: 64 additions & 20 deletions src/backward/flash_attention_backward_typed.cu
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,9 @@ template __global__ void flash_attention_backward_dq_kernel<float, 64, 64, 32>(
template __global__ void flash_attention_backward_dq_kernel<float, 64, 64, 64>(
const float*, const float*, const float*, const float*, const float*, const float*, float*, int,
float, bool);
template __global__ void flash_attention_backward_dq_kernel<float, 32, 32, 64>(
const float*, const float*, const float*, const float*, const float*, const float*, float*, int,
float, bool);
template __global__ void flash_attention_backward_dq_kernel<float, 64, 64, 128>(
const float*, const float*, const float*, const float*, const float*, const float*, float*, int,
float, bool);
Expand All @@ -315,6 +318,9 @@ template __global__ void flash_attention_backward_dkdv_kernel<float, 64, 64, 32>
template __global__ void flash_attention_backward_dkdv_kernel<float, 64, 64, 64>(
const float*, const float*, const float*, const float*, const float*, const float*, float*,
float*, int, float, bool);
template __global__ void flash_attention_backward_dkdv_kernel<float, 32, 32, 64>(
const float*, const float*, const float*, const float*, const float*, const float*, float*,
float*, int, float, bool);
template __global__ void flash_attention_backward_dkdv_kernel<float, 64, 64, 128>(
const float*, const float*, const float*, const float*, const float*, const float*, float*,
float*, int, float, bool);
Expand All @@ -333,6 +339,9 @@ template __global__ void flash_attention_backward_dq_kernel<half, 64, 64, 32>(
template __global__ void flash_attention_backward_dq_kernel<half, 64, 64, 64>(
const half*, const half*, const half*, const half*, const half*, const float*, half*, int,
float, bool);
template __global__ void flash_attention_backward_dq_kernel<half, 32, 32, 64>(
const half*, const half*, const half*, const half*, const half*, const float*, half*, int,
float, bool);
template __global__ void flash_attention_backward_dq_kernel<half, 64, 64, 128>(
const half*, const half*, const half*, const half*, const half*, const float*, half*, int,
float, bool);
Expand All @@ -346,6 +355,9 @@ template __global__ void flash_attention_backward_dkdv_kernel<half, 64, 64, 32>(
template __global__ void flash_attention_backward_dkdv_kernel<half, 64, 64, 64>(
const half*, const half*, const half*, const half*, const half*, const float*, half*, half*,
int, float, bool);
template __global__ void flash_attention_backward_dkdv_kernel<half, 32, 32, 64>(
const half*, const half*, const half*, const half*, const half*, const float*, half*, half*,
int, float, bool);
template __global__ void flash_attention_backward_dkdv_kernel<half, 64, 64, 128>(
const half*, const half*, const half*, const half*, const half*, const float*, half*, half*,
int, float, bool);
Expand Down Expand Up @@ -380,6 +392,8 @@ FlashAttentionError launch_flash_attention_backward_typed<float>(
using Config = impl::BackwardTilingConfig;
constexpr int BLOCK_M = Config::BLOCK_M;
constexpr int BLOCK_N = Config::BLOCK_N;
constexpr int BLOCK_M_HD64 = Config::BLOCK_M_HD64;
constexpr int BLOCK_N_HD64 = Config::BLOCK_N_HD64;
constexpr int BLOCK_M_HD128 = Config::BLOCK_M_HD128;
constexpr int BLOCK_N_HD128 = Config::BLOCK_N_HD128;

Expand Down Expand Up @@ -416,10 +430,14 @@ FlashAttentionError launch_flash_attention_backward_typed<float>(

int num_q_blocks = (seq_len + BLOCK_M - 1) / BLOCK_M;
int num_kv_blocks = (seq_len + BLOCK_N - 1) / BLOCK_N;
int num_q_blocks_hd64 = (seq_len + BLOCK_M_HD64 - 1) / BLOCK_M_HD64;
int num_kv_blocks_hd64 = (seq_len + BLOCK_N_HD64 - 1) / BLOCK_N_HD64;
int num_q_blocks_hd128 = (seq_len + BLOCK_M_HD128 - 1) / BLOCK_M_HD128;
int num_kv_blocks_hd128 = (seq_len + BLOCK_N_HD128 - 1) / BLOCK_N_HD128;
dim3 dq_grid(num_q_blocks, batch_heads);
dim3 dkdv_grid(num_kv_blocks, batch_heads);
dim3 dq_grid_hd64(num_q_blocks_hd64, batch_heads);
dim3 dkdv_grid_hd64(num_kv_blocks_hd64, batch_heads);
dim3 dq_grid_hd128(num_q_blocks_hd128, batch_heads);
dim3 dkdv_grid_hd128(num_kv_blocks_hd128, batch_heads);
dim3 block(128);
Expand All @@ -432,6 +450,16 @@ FlashAttentionError launch_flash_attention_backward_typed<float>(
(BLOCK_N * head_dim + BLOCK_N * head_dim + BLOCK_M * head_dim + BLOCK_M * head_dim +
BLOCK_M * BLOCK_N + BLOCK_N * head_dim + BLOCK_N * head_dim + BLOCK_M + BLOCK_M) *
sizeof(float);
size_t dq_smem_size_hd64 =
(BLOCK_M_HD64 * head_dim + BLOCK_M_HD64 * head_dim + BLOCK_N_HD64 * head_dim +
BLOCK_N_HD64 * head_dim + BLOCK_M_HD64 * BLOCK_N_HD64 +
BLOCK_M_HD64 * head_dim + BLOCK_M_HD64 + BLOCK_M_HD64) *
sizeof(float);
size_t dkdv_smem_size_hd64 =
(BLOCK_N_HD64 * head_dim + BLOCK_N_HD64 * head_dim + BLOCK_M_HD64 * head_dim +
BLOCK_M_HD64 * head_dim + BLOCK_M_HD64 * BLOCK_N_HD64 +
BLOCK_N_HD64 * head_dim + BLOCK_N_HD64 * head_dim + BLOCK_M_HD64 + BLOCK_M_HD64) *
sizeof(float);
size_t dq_smem_size_hd128 =
(BLOCK_M_HD128 * head_dim + BLOCK_M_HD128 * head_dim + BLOCK_N_HD128 * head_dim +
BLOCK_N_HD128 * head_dim + BLOCK_M_HD128 * BLOCK_N_HD128 + BLOCK_M_HD128 * head_dim +
Expand Down Expand Up @@ -472,27 +500,27 @@ FlashAttentionError launch_flash_attention_backward_typed<float>(
} else if (head_dim == 64) {
status = prepare_dynamic_smem_launch(
reinterpret_cast<const void*>(
flash_attention_backward_dq_kernel<float, BLOCK_M, BLOCK_N, 64>),
dq_smem_size);
flash_attention_backward_dq_kernel<float, BLOCK_M_HD64, BLOCK_N_HD64, 64>),
dq_smem_size_hd64);
if (status != FlashAttentionError::SUCCESS)
return status;
status = prepare_dynamic_smem_launch(
reinterpret_cast<const void*>(
flash_attention_backward_dkdv_kernel<float, BLOCK_M, BLOCK_N, 64>),
dkdv_smem_size);
flash_attention_backward_dkdv_kernel<float, BLOCK_M_HD64, BLOCK_N_HD64, 64>),
dkdv_smem_size_hd64);
if (status != FlashAttentionError::SUCCESS)
return status;

flash_attention_backward_dq_kernel<float, BLOCK_M, BLOCK_N, 64>
<<<dq_grid, block, dq_smem_size, stream>>>(Q, K, V, L, dO, D, dQ, seq_len, scale,
causal);
flash_attention_backward_dq_kernel<float, BLOCK_M_HD64, BLOCK_N_HD64, 64>
<<<dq_grid_hd64, block, dq_smem_size_hd64, stream>>>(Q, K, V, L, dO, D, dQ, seq_len,
scale, causal);
err = cudaGetLastError();
if (err != cudaSuccess)
return FlashAttentionError::CUDA_ERROR;

flash_attention_backward_dkdv_kernel<float, BLOCK_M, BLOCK_N, 64>
<<<dkdv_grid, block, dkdv_smem_size, stream>>>(Q, K, V, L, dO, D, dK, dV, seq_len,
scale, causal);
flash_attention_backward_dkdv_kernel<float, BLOCK_M_HD64, BLOCK_N_HD64, 64>
<<<dkdv_grid_hd64, block, dkdv_smem_size_hd64, stream>>>(Q, K, V, L, dO, D, dK, dV,
seq_len, scale, causal);
} else if (head_dim == 128) {
status = prepare_dynamic_smem_launch(
reinterpret_cast<const void*>(
Expand Down Expand Up @@ -536,6 +564,8 @@ FlashAttentionError launch_flash_attention_backward_typed<half>(
using Config = impl::BackwardTilingConfig;
constexpr int BLOCK_M = Config::BLOCK_M;
constexpr int BLOCK_N = Config::BLOCK_N;
constexpr int BLOCK_M_HD64 = Config::BLOCK_M_HD64;
constexpr int BLOCK_N_HD64 = Config::BLOCK_N_HD64;
constexpr int BLOCK_M_HD128 = Config::BLOCK_M_HD128;
constexpr int BLOCK_N_HD128 = Config::BLOCK_N_HD128;

Expand Down Expand Up @@ -571,10 +601,14 @@ FlashAttentionError launch_flash_attention_backward_typed<half>(

int num_q_blocks = (seq_len + BLOCK_M - 1) / BLOCK_M;
int num_kv_blocks = (seq_len + BLOCK_N - 1) / BLOCK_N;
int num_q_blocks_hd64 = (seq_len + BLOCK_M_HD64 - 1) / BLOCK_M_HD64;
int num_kv_blocks_hd64 = (seq_len + BLOCK_N_HD64 - 1) / BLOCK_N_HD64;
int num_q_blocks_hd128 = (seq_len + BLOCK_M_HD128 - 1) / BLOCK_M_HD128;
int num_kv_blocks_hd128 = (seq_len + BLOCK_N_HD128 - 1) / BLOCK_N_HD128;
dim3 dq_grid(num_q_blocks, batch_heads);
dim3 dkdv_grid(num_kv_blocks, batch_heads);
dim3 dq_grid_hd64(num_q_blocks_hd64, batch_heads);
dim3 dkdv_grid_hd64(num_kv_blocks_hd64, batch_heads);
dim3 dq_grid_hd128(num_q_blocks_hd128, batch_heads);
dim3 dkdv_grid_hd128(num_kv_blocks_hd128, batch_heads);
dim3 block(128);
Expand All @@ -587,6 +621,16 @@ FlashAttentionError launch_flash_attention_backward_typed<half>(
(BLOCK_N * head_dim + BLOCK_N * head_dim + BLOCK_M * head_dim + BLOCK_M * head_dim +
BLOCK_M * BLOCK_N + BLOCK_N * head_dim + BLOCK_N * head_dim + BLOCK_M + BLOCK_M) *
sizeof(float);
size_t dq_smem_size_hd64 =
(BLOCK_M_HD64 * head_dim + BLOCK_M_HD64 * head_dim + BLOCK_N_HD64 * head_dim +
BLOCK_N_HD64 * head_dim + BLOCK_M_HD64 * BLOCK_N_HD64 +
BLOCK_M_HD64 * head_dim + BLOCK_M_HD64 + BLOCK_M_HD64) *
sizeof(float);
size_t dkdv_smem_size_hd64 =
(BLOCK_N_HD64 * head_dim + BLOCK_N_HD64 * head_dim + BLOCK_M_HD64 * head_dim +
BLOCK_M_HD64 * head_dim + BLOCK_M_HD64 * BLOCK_N_HD64 +
BLOCK_N_HD64 * head_dim + BLOCK_N_HD64 * head_dim + BLOCK_M_HD64 + BLOCK_M_HD64) *
sizeof(float);
size_t dq_smem_size_hd128 =
(BLOCK_M_HD128 * head_dim + BLOCK_M_HD128 * head_dim + BLOCK_N_HD128 * head_dim +
BLOCK_N_HD128 * head_dim + BLOCK_M_HD128 * BLOCK_N_HD128 + BLOCK_M_HD128 * head_dim +
Expand Down Expand Up @@ -627,27 +671,27 @@ FlashAttentionError launch_flash_attention_backward_typed<half>(
} else if (head_dim == 64) {
status = prepare_dynamic_smem_launch(
reinterpret_cast<const void*>(
flash_attention_backward_dq_kernel<half, BLOCK_M, BLOCK_N, 64>),
dq_smem_size);
flash_attention_backward_dq_kernel<half, BLOCK_M_HD64, BLOCK_N_HD64, 64>),
dq_smem_size_hd64);
if (status != FlashAttentionError::SUCCESS)
return status;
status = prepare_dynamic_smem_launch(
reinterpret_cast<const void*>(
flash_attention_backward_dkdv_kernel<half, BLOCK_M, BLOCK_N, 64>),
dkdv_smem_size);
flash_attention_backward_dkdv_kernel<half, BLOCK_M_HD64, BLOCK_N_HD64, 64>),
dkdv_smem_size_hd64);
if (status != FlashAttentionError::SUCCESS)
return status;

flash_attention_backward_dq_kernel<half, BLOCK_M, BLOCK_N, 64>
<<<dq_grid, block, dq_smem_size, stream>>>(Q, K, V, L, dO, D, dQ, seq_len, scale,
causal);
flash_attention_backward_dq_kernel<half, BLOCK_M_HD64, BLOCK_N_HD64, 64>
<<<dq_grid_hd64, block, dq_smem_size_hd64, stream>>>(Q, K, V, L, dO, D, dQ, seq_len,
scale, causal);
err = cudaGetLastError();
if (err != cudaSuccess)
return FlashAttentionError::CUDA_ERROR;

flash_attention_backward_dkdv_kernel<half, BLOCK_M, BLOCK_N, 64>
<<<dkdv_grid, block, dkdv_smem_size, stream>>>(Q, K, V, L, dO, D, dK, dV, seq_len,
scale, causal);
flash_attention_backward_dkdv_kernel<half, BLOCK_M_HD64, BLOCK_N_HD64, 64>
<<<dkdv_grid_hd64, block, dkdv_smem_size_hd64, stream>>>(Q, K, V, L, dO, D, dK, dV,
seq_len, scale, causal);
} else if (head_dim == 128) {
status = prepare_dynamic_smem_launch(
reinterpret_cast<const void*>(
Expand Down
6 changes: 5 additions & 1 deletion src/kernels/impl/tile_io.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -439,10 +439,14 @@ struct ForwardTilingConfig {
/// Tiling configuration for backward pass.
/// Uses smaller blocks to accommodate additional gradient tensors in shared memory.
struct BackwardTilingConfig {
// Standard block sizes for head_dim 32 and 64
// Standard block sizes for head_dim 32
static constexpr int BLOCK_M = 64; // Q block rows
static constexpr int BLOCK_N = 64; // K/V block rows

// Smaller blocks for head_dim 64 to stay within dynamic shared memory limits.
static constexpr int BLOCK_M_HD64 = 32;
static constexpr int BLOCK_N_HD64 = 32;

// Smaller blocks for head_dim 128 (more aggressive due to dQ, dK, dV)
static constexpr int BLOCK_M_HD128 = 16;
static constexpr int BLOCK_N_HD128 = 32;
Expand Down
20 changes: 20 additions & 0 deletions src/kernels/matmul.cu
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,16 @@ template FlashAttentionError matmul_AB_acc<64, 64, 32>(const float*, const float
template FlashAttentionError matmul_AtB<64, 64, 32>(const float*, const float*, float*, float,
cudaStream_t);

// 32x32x32 (head_dim=32, compact tiles)
template FlashAttentionError matmul_ABt<32, 32, 32>(const float*, const float*, float*, float,
cudaStream_t);
template FlashAttentionError matmul_AB<32, 32, 32>(const float*, const float*, float*, float,
cudaStream_t);
template FlashAttentionError matmul_AB_acc<32, 32, 32>(const float*, const float*, float*, float,
cudaStream_t);
template FlashAttentionError matmul_AtB<32, 32, 32>(const float*, const float*, float*, float,
cudaStream_t);

// 64x64x64 (head_dim=64)
template FlashAttentionError matmul_ABt<64, 64, 64>(const float*, const float*, float*, float,
cudaStream_t);
Expand Down Expand Up @@ -278,6 +288,16 @@ template FlashAttentionError matmul_AtB<64, 64, 128>(const float*, const float*,
cudaStream_t);

// Additional sizes for flexibility
// 32x64x32 variants
template FlashAttentionError matmul_ABt<32, 64, 32>(const float*, const float*, float*, float,
cudaStream_t);
template FlashAttentionError matmul_AB<32, 64, 32>(const float*, const float*, float*, float,
cudaStream_t);
template FlashAttentionError matmul_AB_acc<32, 64, 32>(const float*, const float*, float*, float,
cudaStream_t);
template FlashAttentionError matmul_AtB<32, 64, 32>(const float*, const float*, float*, float,
cudaStream_t);

// 32x64 variants
template FlashAttentionError matmul_ABt<32, 64, 64>(const float*, const float*, float*, float,
cudaStream_t);
Expand Down
Loading
Loading