diff --git a/.gitignore b/.gitignore index a667900..678d63e 100644 --- a/.gitignore +++ b/.gitignore @@ -77,8 +77,3 @@ venv/ # Benchmark outputs benchmark_results/ - -# Local AI preferences -CLAUDE.local.md -.claude/settings.local.json -.claude/skills/ diff --git a/CHANGELOG.md b/CHANGELOG.md index 91d669a..50683e0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Simplified contributor guidance, Copilot instructions, and pull request metadata to match the real repository workflow. - Reduced GitHub Pages scope to product documentation; changelog history now lives only in the root `CHANGELOG.md`. - Removed docs-site spec mirrors, release-note mirrors, and AI planning artifacts that duplicated repository content. +- Removed stale README references to the deleted `AGENTS.md` workflow document and replaced leftover SDD branding with lean repository wording. +- Removed `.claude`/`CLAUDE.local.md` ignore rules that only served deleted AI tooling overlays. +- Simplified the GitHub Pages landing page links so the docs site no longer surfaces changelog navigation. +- Fixed CUDA preset validation on fresh Ubuntu environments by documenting and working through the required host-compiler/toolkit alignment. +- Fixed backward kernel dispatch for `head_dim=64` by using a smaller shared-memory tiling path that fits current CUDA limits. +- Fixed package-smoke consumption by exposing CUDA headers through the exported target and simplifying the downstream smoke project to a pure C++ consumer. +- Fixed the FP16 tile store test to validate half-rounding semantics instead of comparing against the original float with an unrealistically tight tolerance. +- Fixed the optional PyTorch comparison script to skip cleanly when `torch` is not installed. --- diff --git a/CMakeLists.txt b/CMakeLists.txt index 0b2d2ba..bd81218 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -47,6 +47,8 @@ add_library(cuflash_attn ${SOURCES}) target_link_libraries(cuflash_attn PUBLIC CUDA::cudart) target_include_directories(cuflash_attn PUBLIC $ + $ + $ $ ) # Private include directories for internal kernels @@ -229,6 +231,8 @@ if(BUILD_TESTING AND BUILD_SHARED_LIBS) -Dinstall_dir=${CUFLASH_PACKAGE_SMOKE_INSTALL_DIR} -Dgenerator=${CMAKE_GENERATOR} -Dbuild_type=${CMAKE_BUILD_TYPE} + -Dc_compiler=${CMAKE_C_COMPILER} + -Dcxx_compiler=${CMAKE_CXX_COMPILER} -Dcuda_architectures=${CMAKE_CUDA_ARCHITECTURES} -Dctest_command=${CMAKE_CTEST_COMMAND} -P ${CMAKE_SOURCE_DIR}/cmake/run_package_smoke.cmake @@ -244,7 +248,7 @@ if(BUILD_TESTING AND BUILD_SHARED_LIBS) ) set_tests_properties(cuflash_attn_pytorch_comparison PROPERTIES LABELS "integration;pytorch" - SKIP_RETURN_CODE 0 + SKIP_RETURN_CODE 77 ) endif() endif() diff --git a/README.md b/README.md index 8207870..e9df5da 100644 --- a/README.md +++ b/README.md @@ -383,8 +383,6 @@ clang-tidy src/api/flash_attention_api.cu -- -Iinclude 📋 **Detailed guidelines**: See [CONTRIBUTING.md](CONTRIBUTING.md) -🤖 **For AI Contributors**: Read [AGENTS.md](AGENTS.md) for SDD workflow instructions. - --- ## 📄 License @@ -408,5 +406,5 @@ See [CHANGELOG.md](CHANGELOG.md) for detailed version history and updates.

Built with ❤️ for efficient attention computation
- Spec-Driven Development · CUDA C++ · Open Source + Lean Reference Implementation · CUDA C++ · Open Source

diff --git a/README.zh-CN.md b/README.zh-CN.md index eac3c03..e326440 100644 --- a/README.zh-CN.md +++ b/README.zh-CN.md @@ -383,8 +383,6 @@ clang-tidy src/api/flash_attention_api.cu -- -Iinclude 📋 **详细指南**: 请参见 [CONTRIBUTING.md](CONTRIBUTING.md) -🤖 **AI 贡献者**: 阅读 [AGENTS.md](AGENTS.md) 了解 SDD 工作流说明。 - --- ## 📄 许可证 @@ -408,5 +406,5 @@ clang-tidy src/api/flash_attention_api.cu -- -Iinclude

用 ❤️ 打造的高效注意力计算
- 规范驱动开发 · CUDA C++ · 开源 + 精简参考实现 · CUDA C++ · 开源

diff --git a/cmake/run_package_smoke.cmake b/cmake/run_package_smoke.cmake index c2ab187..ddc846e 100644 --- a/cmake/run_package_smoke.cmake +++ b/cmake/run_package_smoke.cmake @@ -36,6 +36,12 @@ endif() if(DEFINED build_type AND NOT build_type STREQUAL "") list(APPEND configure_args -DCMAKE_BUILD_TYPE=${build_type}) endif() +if(DEFINED c_compiler AND NOT c_compiler STREQUAL "") + list(APPEND configure_args -DCMAKE_C_COMPILER=${c_compiler}) +endif() +if(DEFINED cxx_compiler AND NOT cxx_compiler STREQUAL "") + list(APPEND configure_args -DCMAKE_CXX_COMPILER=${cxx_compiler}) +endif() if(DEFINED cuda_architectures AND NOT cuda_architectures STREQUAL "") list(APPEND configure_args -DCMAKE_CUDA_ARCHITECTURES=${cuda_architectures}) endif() diff --git a/docs/index.md b/docs/index.md index 965ce32..0cea1d6 100644 --- a/docs/index.md +++ b/docs/index.md @@ -127,12 +127,6 @@ function setLanguage(lang) { Releases - - - - - Changelog - diff --git a/src/backward/flash_attention_backward_typed.cu b/src/backward/flash_attention_backward_typed.cu index fa71c70..be12d7f 100644 --- a/src/backward/flash_attention_backward_typed.cu +++ b/src/backward/flash_attention_backward_typed.cu @@ -302,6 +302,9 @@ template __global__ void flash_attention_backward_dq_kernel( template __global__ void flash_attention_backward_dq_kernel( const float*, const float*, const float*, const float*, const float*, const float*, float*, int, float, bool); +template __global__ void flash_attention_backward_dq_kernel( + const float*, const float*, const float*, const float*, const float*, const float*, float*, int, + float, bool); template __global__ void flash_attention_backward_dq_kernel( const float*, const float*, const float*, const float*, const float*, const float*, float*, int, float, bool); @@ -315,6 +318,9 @@ template __global__ void flash_attention_backward_dkdv_kernel template __global__ void flash_attention_backward_dkdv_kernel( const float*, const float*, const float*, const float*, const float*, const float*, float*, float*, int, float, bool); +template __global__ void flash_attention_backward_dkdv_kernel( + const float*, const float*, const float*, const float*, const float*, const float*, float*, + float*, int, float, bool); template __global__ void flash_attention_backward_dkdv_kernel( const float*, const float*, const float*, const float*, const float*, const float*, float*, float*, int, float, bool); @@ -333,6 +339,9 @@ template __global__ void flash_attention_backward_dq_kernel( template __global__ void flash_attention_backward_dq_kernel( const half*, const half*, const half*, const half*, const half*, const float*, half*, int, float, bool); +template __global__ void flash_attention_backward_dq_kernel( + const half*, const half*, const half*, const half*, const half*, const float*, half*, int, + float, bool); template __global__ void flash_attention_backward_dq_kernel( const half*, const half*, const half*, const half*, const half*, const float*, half*, int, float, bool); @@ -346,6 +355,9 @@ template __global__ void flash_attention_backward_dkdv_kernel( template __global__ void flash_attention_backward_dkdv_kernel( const half*, const half*, const half*, const half*, const half*, const float*, half*, half*, int, float, bool); +template __global__ void flash_attention_backward_dkdv_kernel( + const half*, const half*, const half*, const half*, const half*, const float*, half*, half*, + int, float, bool); template __global__ void flash_attention_backward_dkdv_kernel( const half*, const half*, const half*, const half*, const half*, const float*, half*, half*, int, float, bool); @@ -380,6 +392,8 @@ FlashAttentionError launch_flash_attention_backward_typed( using Config = impl::BackwardTilingConfig; constexpr int BLOCK_M = Config::BLOCK_M; constexpr int BLOCK_N = Config::BLOCK_N; + constexpr int BLOCK_M_HD64 = Config::BLOCK_M_HD64; + constexpr int BLOCK_N_HD64 = Config::BLOCK_N_HD64; constexpr int BLOCK_M_HD128 = Config::BLOCK_M_HD128; constexpr int BLOCK_N_HD128 = Config::BLOCK_N_HD128; @@ -416,10 +430,14 @@ FlashAttentionError launch_flash_attention_backward_typed( int num_q_blocks = (seq_len + BLOCK_M - 1) / BLOCK_M; int num_kv_blocks = (seq_len + BLOCK_N - 1) / BLOCK_N; + int num_q_blocks_hd64 = (seq_len + BLOCK_M_HD64 - 1) / BLOCK_M_HD64; + int num_kv_blocks_hd64 = (seq_len + BLOCK_N_HD64 - 1) / BLOCK_N_HD64; int num_q_blocks_hd128 = (seq_len + BLOCK_M_HD128 - 1) / BLOCK_M_HD128; int num_kv_blocks_hd128 = (seq_len + BLOCK_N_HD128 - 1) / BLOCK_N_HD128; dim3 dq_grid(num_q_blocks, batch_heads); dim3 dkdv_grid(num_kv_blocks, batch_heads); + dim3 dq_grid_hd64(num_q_blocks_hd64, batch_heads); + dim3 dkdv_grid_hd64(num_kv_blocks_hd64, batch_heads); dim3 dq_grid_hd128(num_q_blocks_hd128, batch_heads); dim3 dkdv_grid_hd128(num_kv_blocks_hd128, batch_heads); dim3 block(128); @@ -432,6 +450,16 @@ FlashAttentionError launch_flash_attention_backward_typed( (BLOCK_N * head_dim + BLOCK_N * head_dim + BLOCK_M * head_dim + BLOCK_M * head_dim + BLOCK_M * BLOCK_N + BLOCK_N * head_dim + BLOCK_N * head_dim + BLOCK_M + BLOCK_M) * sizeof(float); + size_t dq_smem_size_hd64 = + (BLOCK_M_HD64 * head_dim + BLOCK_M_HD64 * head_dim + BLOCK_N_HD64 * head_dim + + BLOCK_N_HD64 * head_dim + BLOCK_M_HD64 * BLOCK_N_HD64 + + BLOCK_M_HD64 * head_dim + BLOCK_M_HD64 + BLOCK_M_HD64) * + sizeof(float); + size_t dkdv_smem_size_hd64 = + (BLOCK_N_HD64 * head_dim + BLOCK_N_HD64 * head_dim + BLOCK_M_HD64 * head_dim + + BLOCK_M_HD64 * head_dim + BLOCK_M_HD64 * BLOCK_N_HD64 + + BLOCK_N_HD64 * head_dim + BLOCK_N_HD64 * head_dim + BLOCK_M_HD64 + BLOCK_M_HD64) * + sizeof(float); size_t dq_smem_size_hd128 = (BLOCK_M_HD128 * head_dim + BLOCK_M_HD128 * head_dim + BLOCK_N_HD128 * head_dim + BLOCK_N_HD128 * head_dim + BLOCK_M_HD128 * BLOCK_N_HD128 + BLOCK_M_HD128 * head_dim + @@ -472,27 +500,27 @@ FlashAttentionError launch_flash_attention_backward_typed( } else if (head_dim == 64) { status = prepare_dynamic_smem_launch( reinterpret_cast( - flash_attention_backward_dq_kernel), - dq_smem_size); + flash_attention_backward_dq_kernel), + dq_smem_size_hd64); if (status != FlashAttentionError::SUCCESS) return status; status = prepare_dynamic_smem_launch( reinterpret_cast( - flash_attention_backward_dkdv_kernel), - dkdv_smem_size); + flash_attention_backward_dkdv_kernel), + dkdv_smem_size_hd64); if (status != FlashAttentionError::SUCCESS) return status; - flash_attention_backward_dq_kernel - <<>>(Q, K, V, L, dO, D, dQ, seq_len, scale, - causal); + flash_attention_backward_dq_kernel + <<>>(Q, K, V, L, dO, D, dQ, seq_len, + scale, causal); err = cudaGetLastError(); if (err != cudaSuccess) return FlashAttentionError::CUDA_ERROR; - flash_attention_backward_dkdv_kernel - <<>>(Q, K, V, L, dO, D, dK, dV, seq_len, - scale, causal); + flash_attention_backward_dkdv_kernel + <<>>(Q, K, V, L, dO, D, dK, dV, + seq_len, scale, causal); } else if (head_dim == 128) { status = prepare_dynamic_smem_launch( reinterpret_cast( @@ -536,6 +564,8 @@ FlashAttentionError launch_flash_attention_backward_typed( using Config = impl::BackwardTilingConfig; constexpr int BLOCK_M = Config::BLOCK_M; constexpr int BLOCK_N = Config::BLOCK_N; + constexpr int BLOCK_M_HD64 = Config::BLOCK_M_HD64; + constexpr int BLOCK_N_HD64 = Config::BLOCK_N_HD64; constexpr int BLOCK_M_HD128 = Config::BLOCK_M_HD128; constexpr int BLOCK_N_HD128 = Config::BLOCK_N_HD128; @@ -571,10 +601,14 @@ FlashAttentionError launch_flash_attention_backward_typed( int num_q_blocks = (seq_len + BLOCK_M - 1) / BLOCK_M; int num_kv_blocks = (seq_len + BLOCK_N - 1) / BLOCK_N; + int num_q_blocks_hd64 = (seq_len + BLOCK_M_HD64 - 1) / BLOCK_M_HD64; + int num_kv_blocks_hd64 = (seq_len + BLOCK_N_HD64 - 1) / BLOCK_N_HD64; int num_q_blocks_hd128 = (seq_len + BLOCK_M_HD128 - 1) / BLOCK_M_HD128; int num_kv_blocks_hd128 = (seq_len + BLOCK_N_HD128 - 1) / BLOCK_N_HD128; dim3 dq_grid(num_q_blocks, batch_heads); dim3 dkdv_grid(num_kv_blocks, batch_heads); + dim3 dq_grid_hd64(num_q_blocks_hd64, batch_heads); + dim3 dkdv_grid_hd64(num_kv_blocks_hd64, batch_heads); dim3 dq_grid_hd128(num_q_blocks_hd128, batch_heads); dim3 dkdv_grid_hd128(num_kv_blocks_hd128, batch_heads); dim3 block(128); @@ -587,6 +621,16 @@ FlashAttentionError launch_flash_attention_backward_typed( (BLOCK_N * head_dim + BLOCK_N * head_dim + BLOCK_M * head_dim + BLOCK_M * head_dim + BLOCK_M * BLOCK_N + BLOCK_N * head_dim + BLOCK_N * head_dim + BLOCK_M + BLOCK_M) * sizeof(float); + size_t dq_smem_size_hd64 = + (BLOCK_M_HD64 * head_dim + BLOCK_M_HD64 * head_dim + BLOCK_N_HD64 * head_dim + + BLOCK_N_HD64 * head_dim + BLOCK_M_HD64 * BLOCK_N_HD64 + + BLOCK_M_HD64 * head_dim + BLOCK_M_HD64 + BLOCK_M_HD64) * + sizeof(float); + size_t dkdv_smem_size_hd64 = + (BLOCK_N_HD64 * head_dim + BLOCK_N_HD64 * head_dim + BLOCK_M_HD64 * head_dim + + BLOCK_M_HD64 * head_dim + BLOCK_M_HD64 * BLOCK_N_HD64 + + BLOCK_N_HD64 * head_dim + BLOCK_N_HD64 * head_dim + BLOCK_M_HD64 + BLOCK_M_HD64) * + sizeof(float); size_t dq_smem_size_hd128 = (BLOCK_M_HD128 * head_dim + BLOCK_M_HD128 * head_dim + BLOCK_N_HD128 * head_dim + BLOCK_N_HD128 * head_dim + BLOCK_M_HD128 * BLOCK_N_HD128 + BLOCK_M_HD128 * head_dim + @@ -627,27 +671,27 @@ FlashAttentionError launch_flash_attention_backward_typed( } else if (head_dim == 64) { status = prepare_dynamic_smem_launch( reinterpret_cast( - flash_attention_backward_dq_kernel), - dq_smem_size); + flash_attention_backward_dq_kernel), + dq_smem_size_hd64); if (status != FlashAttentionError::SUCCESS) return status; status = prepare_dynamic_smem_launch( reinterpret_cast( - flash_attention_backward_dkdv_kernel), - dkdv_smem_size); + flash_attention_backward_dkdv_kernel), + dkdv_smem_size_hd64); if (status != FlashAttentionError::SUCCESS) return status; - flash_attention_backward_dq_kernel - <<>>(Q, K, V, L, dO, D, dQ, seq_len, scale, - causal); + flash_attention_backward_dq_kernel + <<>>(Q, K, V, L, dO, D, dQ, seq_len, + scale, causal); err = cudaGetLastError(); if (err != cudaSuccess) return FlashAttentionError::CUDA_ERROR; - flash_attention_backward_dkdv_kernel - <<>>(Q, K, V, L, dO, D, dK, dV, seq_len, - scale, causal); + flash_attention_backward_dkdv_kernel + <<>>(Q, K, V, L, dO, D, dK, dV, + seq_len, scale, causal); } else if (head_dim == 128) { status = prepare_dynamic_smem_launch( reinterpret_cast( diff --git a/src/kernels/impl/tile_io.cuh b/src/kernels/impl/tile_io.cuh index d9ab4ad..b29315e 100644 --- a/src/kernels/impl/tile_io.cuh +++ b/src/kernels/impl/tile_io.cuh @@ -439,10 +439,14 @@ struct ForwardTilingConfig { /// Tiling configuration for backward pass. /// Uses smaller blocks to accommodate additional gradient tensors in shared memory. struct BackwardTilingConfig { - // Standard block sizes for head_dim 32 and 64 + // Standard block sizes for head_dim 32 static constexpr int BLOCK_M = 64; // Q block rows static constexpr int BLOCK_N = 64; // K/V block rows + // Smaller blocks for head_dim 64 to stay within dynamic shared memory limits. + static constexpr int BLOCK_M_HD64 = 32; + static constexpr int BLOCK_N_HD64 = 32; + // Smaller blocks for head_dim 128 (more aggressive due to dQ, dK, dV) static constexpr int BLOCK_M_HD128 = 16; static constexpr int BLOCK_N_HD128 = 32; diff --git a/src/kernels/matmul.cu b/src/kernels/matmul.cu index 6e95bfe..42e3d83 100644 --- a/src/kernels/matmul.cu +++ b/src/kernels/matmul.cu @@ -247,6 +247,16 @@ template FlashAttentionError matmul_AB_acc<64, 64, 32>(const float*, const float template FlashAttentionError matmul_AtB<64, 64, 32>(const float*, const float*, float*, float, cudaStream_t); +// 32x32x32 (head_dim=32, compact tiles) +template FlashAttentionError matmul_ABt<32, 32, 32>(const float*, const float*, float*, float, + cudaStream_t); +template FlashAttentionError matmul_AB<32, 32, 32>(const float*, const float*, float*, float, + cudaStream_t); +template FlashAttentionError matmul_AB_acc<32, 32, 32>(const float*, const float*, float*, float, + cudaStream_t); +template FlashAttentionError matmul_AtB<32, 32, 32>(const float*, const float*, float*, float, + cudaStream_t); + // 64x64x64 (head_dim=64) template FlashAttentionError matmul_ABt<64, 64, 64>(const float*, const float*, float*, float, cudaStream_t); @@ -278,6 +288,16 @@ template FlashAttentionError matmul_AtB<64, 64, 128>(const float*, const float*, cudaStream_t); // Additional sizes for flexibility +// 32x64x32 variants +template FlashAttentionError matmul_ABt<32, 64, 32>(const float*, const float*, float*, float, + cudaStream_t); +template FlashAttentionError matmul_AB<32, 64, 32>(const float*, const float*, float*, float, + cudaStream_t); +template FlashAttentionError matmul_AB_acc<32, 64, 32>(const float*, const float*, float*, float, + cudaStream_t); +template FlashAttentionError matmul_AtB<32, 64, 32>(const float*, const float*, float*, float, + cudaStream_t); + // 32x64 variants template FlashAttentionError matmul_ABt<32, 64, 64>(const float*, const float*, float*, float, cudaStream_t); diff --git a/tests/integration/test_pytorch_comparison.py b/tests/integration/test_pytorch_comparison.py index 43921e7..728bfdc 100644 --- a/tests/integration/test_pytorch_comparison.py +++ b/tests/integration/test_pytorch_comparison.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +from __future__ import annotations + """ PyTorch Comparison Tests for CuFlash-Attn Feature: cuflash-attn @@ -15,12 +17,21 @@ from typing import Optional, Sequence, Tuple import numpy as np -import torch -import torch.nn.functional as F + +try: + import torch + import torch.nn.functional as F +except ModuleNotFoundError as import_error: + torch = None + F = None + TORCH_IMPORT_ERROR = import_error +else: + TORCH_IMPORT_ERROR = None SUCCESS = 0 UNSUPPORTED_DTYPE = 7 +SKIP = 77 class CuFlashLibrary: @@ -104,8 +115,12 @@ def _script_dir() -> str: return os.path.dirname(os.path.abspath(__file__)) +def _repo_root() -> str: + return os.path.abspath(os.path.join(_script_dir(), "..", "..")) + + def _candidate_library_paths() -> Sequence[str]: - root = os.path.abspath(os.path.join(_script_dir(), "..")) + root = _repo_root() build_dir = os.path.join(root, "build") candidates = [] @@ -504,13 +519,17 @@ def main(): print("CuFlash-Attn PyTorch Comparison Tests") print("=" * 60) + if TORCH_IMPORT_ERROR is not None: + print(f"PyTorch not available, skipping tests: {TORCH_IMPORT_ERROR}") + return SKIP + if not torch.cuda.is_available(): print("CUDA not available, skipping tests") - return + return SKIP library = load_library() if library is None: - return + return SKIP print(f"Using CUDA device: {torch.cuda.get_device_name(0)}") print() @@ -539,7 +558,8 @@ def main(): print("=" * 60) print(f"Results: {passed} passed, {failed} failed") print("=" * 60) + return 1 if failed else 0 if __name__ == "__main__": - main() + raise SystemExit(main()) diff --git a/tests/package_smoke/CMakeLists.txt b/tests/package_smoke/CMakeLists.txt index c7b10c9..740a109 100644 --- a/tests/package_smoke/CMakeLists.txt +++ b/tests/package_smoke/CMakeLists.txt @@ -1,5 +1,5 @@ cmake_minimum_required(VERSION 3.18) -project(cuflash_attn_package_smoke LANGUAGES CXX CUDA) +project(cuflash_attn_package_smoke LANGUAGES CXX) find_package(cuflash_attn REQUIRED CONFIG) diff --git a/tests/unit/test_tile_io.cu b/tests/unit/test_tile_io.cu index 175c573..e078e98 100644 --- a/tests/unit/test_tile_io.cu +++ b/tests/unit/test_tile_io.cu @@ -277,8 +277,9 @@ TEST_F(TileIOTest, StoreTileFP16_32x128) { // Verify conversion and storage for (int i = 0; i < BLOCK_ROWS * BLOCK_COLS; i++) { - float expected = __half2float(h_dst[i]); - EXPECT_NEAR(expected, h_tile[i], 1e-3f) << "Mismatch at index " << i; + float expected = __half2float(__float2half(h_tile[i])); + float actual = __half2float(h_dst[i]); + EXPECT_FLOAT_EQ(actual, expected) << "Mismatch at index " << i; } cudaFree(d_tile);