diff --git a/CMakeLists.txt b/CMakeLists.txt index cf80e37b5..7c05e34d1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -23,6 +23,7 @@ option(BUILD_TESTS "Compile the tests" OFF) option(BUILD_SHARED_LIBS "Build shared libraries" ON) option(WITH_TENSOR_PARALLEL "Compile with NCCL and MPI backend" OFF) option(WITH_FLASH_ATTN "Compile with Flash Attention 2" OFF) +set(FLASH_ATTN_HDIMS "" CACHE STRING "Head dimensions to compile for flash attention (e.g. '32;64'). Empty means all.") option(ENABLE_ADDRESS_SANITIZER "ASAN" OFF) MESSAGE(STATUS "Compiler Id: ${CMAKE_CXX_COMPILER_ID}") @@ -606,74 +607,36 @@ if (WITH_CUDA) endif() if (WITH_FLASH_ATTN) add_definitions(-DCT2_WITH_FLASH_ATTN) - list(APPEND SOURCES - src/ops/flash-attention/flash_fwd_hdim32_bf16_sm80.cu - src/ops/flash-attention/flash_fwd_hdim32_fp16_sm80.cu - src/ops/flash-attention/flash_fwd_hdim64_bf16_sm80.cu - src/ops/flash-attention/flash_fwd_hdim64_fp16_sm80.cu - src/ops/flash-attention/flash_fwd_hdim96_bf16_sm80.cu - src/ops/flash-attention/flash_fwd_hdim96_fp16_sm80.cu - src/ops/flash-attention/flash_fwd_hdim128_bf16_sm80.cu - src/ops/flash-attention/flash_fwd_hdim128_fp16_sm80.cu - src/ops/flash-attention/flash_fwd_hdim160_bf16_sm80.cu - src/ops/flash-attention/flash_fwd_hdim160_fp16_sm80.cu - src/ops/flash-attention/flash_fwd_hdim192_bf16_sm80.cu - src/ops/flash-attention/flash_fwd_hdim192_fp16_sm80.cu - src/ops/flash-attention/flash_fwd_hdim224_bf16_sm80.cu - src/ops/flash-attention/flash_fwd_hdim224_fp16_sm80.cu - src/ops/flash-attention/flash_fwd_hdim256_bf16_sm80.cu - src/ops/flash-attention/flash_fwd_hdim256_fp16_sm80.cu - src/ops/flash-attention/flash_fwd_split_hdim32_bf16_sm80.cu - src/ops/flash-attention/flash_fwd_split_hdim32_fp16_sm80.cu - src/ops/flash-attention/flash_fwd_split_hdim64_bf16_sm80.cu - src/ops/flash-attention/flash_fwd_split_hdim64_fp16_sm80.cu - src/ops/flash-attention/flash_fwd_split_hdim96_bf16_sm80.cu - src/ops/flash-attention/flash_fwd_split_hdim96_fp16_sm80.cu - 
src/ops/flash-attention/flash_fwd_split_hdim128_bf16_sm80.cu - src/ops/flash-attention/flash_fwd_split_hdim128_fp16_sm80.cu - src/ops/flash-attention/flash_fwd_split_hdim160_bf16_sm80.cu - src/ops/flash-attention/flash_fwd_split_hdim160_fp16_sm80.cu - src/ops/flash-attention/flash_fwd_split_hdim192_bf16_sm80.cu - src/ops/flash-attention/flash_fwd_split_hdim192_fp16_sm80.cu - src/ops/flash-attention/flash_fwd_split_hdim224_bf16_sm80.cu - src/ops/flash-attention/flash_fwd_split_hdim224_fp16_sm80.cu - src/ops/flash-attention/flash_fwd_split_hdim256_bf16_sm80.cu - src/ops/flash-attention/flash_fwd_split_hdim256_fp16_sm80.cu - ) - set_source_files_properties( - src/ops/flash-attention/flash_fwd_hdim32_bf16_sm80.cu - src/ops/flash-attention/flash_fwd_hdim32_fp16_sm80.cu - src/ops/flash-attention/flash_fwd_hdim64_bf16_sm80.cu - src/ops/flash-attention/flash_fwd_hdim64_fp16_sm80.cu - src/ops/flash-attention/flash_fwd_hdim96_bf16_sm80.cu - src/ops/flash-attention/flash_fwd_hdim96_fp16_sm80.cu - src/ops/flash-attention/flash_fwd_hdim128_bf16_sm80.cu - src/ops/flash-attention/flash_fwd_hdim128_fp16_sm80.cu - src/ops/flash-attention/flash_fwd_hdim160_bf16_sm80.cu - src/ops/flash-attention/flash_fwd_hdim160_fp16_sm80.cu - src/ops/flash-attention/flash_fwd_hdim192_bf16_sm80.cu - src/ops/flash-attention/flash_fwd_hdim192_fp16_sm80.cu - src/ops/flash-attention/flash_fwd_hdim224_bf16_sm80.cu - src/ops/flash-attention/flash_fwd_hdim224_fp16_sm80.cu - src/ops/flash-attention/flash_fwd_hdim256_bf16_sm80.cu - src/ops/flash-attention/flash_fwd_hdim256_fp16_sm80.cu - src/ops/flash-attention/flash_fwd_split_hdim32_bf16_sm80.cu - src/ops/flash-attention/flash_fwd_split_hdim32_fp16_sm80.cu - src/ops/flash-attention/flash_fwd_split_hdim64_bf16_sm80.cu - src/ops/flash-attention/flash_fwd_split_hdim64_fp16_sm80.cu - src/ops/flash-attention/flash_fwd_split_hdim96_bf16_sm80.cu - src/ops/flash-attention/flash_fwd_split_hdim96_fp16_sm80.cu - 
src/ops/flash-attention/flash_fwd_split_hdim128_bf16_sm80.cu - src/ops/flash-attention/flash_fwd_split_hdim128_fp16_sm80.cu - src/ops/flash-attention/flash_fwd_split_hdim160_bf16_sm80.cu - src/ops/flash-attention/flash_fwd_split_hdim160_fp16_sm80.cu - src/ops/flash-attention/flash_fwd_split_hdim192_bf16_sm80.cu - src/ops/flash-attention/flash_fwd_split_hdim192_fp16_sm80.cu - src/ops/flash-attention/flash_fwd_split_hdim224_bf16_sm80.cu - src/ops/flash-attention/flash_fwd_split_hdim224_fp16_sm80.cu - src/ops/flash-attention/flash_fwd_split_hdim256_bf16_sm80.cu - src/ops/flash-attention/flash_fwd_split_hdim256_fp16_sm80.cu + set(_ALL_FLASH_HDIMS 32 64 96 128 160 192 224 256) + if(FLASH_ATTN_HDIMS) + set(_FLASH_HDIMS ${FLASH_ATTN_HDIMS}) + else() + set(_FLASH_HDIMS ${_ALL_FLASH_HDIMS}) + endif() + + message(STATUS "Flash attention head dimensions: ${_FLASH_HDIMS}") + + # Define which hdims are compiled so HEADDIM_SWITCH can limit instantiation + foreach(_hdim ${_FLASH_HDIMS}) + add_definitions(-DCT2_FLASH_ATTN_HDIM_${_hdim}) + endforeach() + if(FLASH_ATTN_HDIMS) + add_definitions(-DCT2_FLASH_ATTN_HDIMS_RESTRICTED) + endif() + + set(_FLASH_ATTN_SOURCES "") + foreach(_hdim ${_FLASH_HDIMS}) + list(APPEND _FLASH_ATTN_SOURCES + src/ops/flash-attention/flash_fwd_hdim${_hdim}_bf16_sm80.cu + src/ops/flash-attention/flash_fwd_hdim${_hdim}_fp16_sm80.cu + src/ops/flash-attention/flash_fwd_split_hdim${_hdim}_bf16_sm80.cu + src/ops/flash-attention/flash_fwd_split_hdim${_hdim}_fp16_sm80.cu + ) + endforeach() + + list(APPEND SOURCES ${_FLASH_ATTN_SOURCES}) + set_source_files_properties(${_FLASH_ATTN_SOURCES} PROPERTIES COMPILE_FLAGS "--use_fast_math") endif() set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE) diff --git a/include/ctranslate2/ops/flash-attention/static_switch.h b/include/ctranslate2/ops/flash-attention/static_switch.h index 7b38de2d0..e1905cb75 100644 --- a/include/ctranslate2/ops/flash-attention/static_switch.h +++ b/include/ctranslate2/ops/flash-attention/static_switch.h @@ 
-78,31 +78,99 @@ } \ }() +#include <stdexcept> // std::runtime_error thrown by _HEADDIM_UNSUPPORTED +// When FLASH_ATTN_HDIMS is restricted via cmake, only instantiate selected +// head dimensions. Others throw at runtime instead of generating link-time +// symbol references. CT2_FLASH_ATTN_HDIM_N is defined per compiled hdim. +#define _HEADDIM_DISPATCH(DIM, ...) \ + constexpr static int kHeadDim = DIM; \ + return __VA_ARGS__(); + +#define _HEADDIM_UNSUPPORTED(DIM) \ + throw std::runtime_error( \ + "Flash attention head dim " #DIM " not compiled. " \ + "Rebuild CTranslate2 with FLASH_ATTN_HDIMS including " #DIM); + +#ifndef CT2_FLASH_ATTN_HDIM_32 +#define _HEADDIM_CASE_32(...) _HEADDIM_UNSUPPORTED(32) +#else +#define _HEADDIM_CASE_32(...) _HEADDIM_DISPATCH(32, __VA_ARGS__) +#endif +#ifndef CT2_FLASH_ATTN_HDIM_64 +#define _HEADDIM_CASE_64(...) _HEADDIM_UNSUPPORTED(64) +#else +#define _HEADDIM_CASE_64(...) _HEADDIM_DISPATCH(64, __VA_ARGS__) +#endif +#ifndef CT2_FLASH_ATTN_HDIM_96 +#define _HEADDIM_CASE_96(...) _HEADDIM_UNSUPPORTED(96) +#else +#define _HEADDIM_CASE_96(...) _HEADDIM_DISPATCH(96, __VA_ARGS__) +#endif +#ifndef CT2_FLASH_ATTN_HDIM_128 +#define _HEADDIM_CASE_128(...) _HEADDIM_UNSUPPORTED(128) +#else +#define _HEADDIM_CASE_128(...) _HEADDIM_DISPATCH(128, __VA_ARGS__) +#endif +#ifndef CT2_FLASH_ATTN_HDIM_160 +#define _HEADDIM_CASE_160(...) _HEADDIM_UNSUPPORTED(160) +#else +#define _HEADDIM_CASE_160(...) _HEADDIM_DISPATCH(160, __VA_ARGS__) +#endif +#ifndef CT2_FLASH_ATTN_HDIM_192 +#define _HEADDIM_CASE_192(...) _HEADDIM_UNSUPPORTED(192) +#else +#define _HEADDIM_CASE_192(...) _HEADDIM_DISPATCH(192, __VA_ARGS__) +#endif +#ifndef CT2_FLASH_ATTN_HDIM_224 +#define _HEADDIM_CASE_224(...) _HEADDIM_UNSUPPORTED(224) +#else +#define _HEADDIM_CASE_224(...) _HEADDIM_DISPATCH(224, __VA_ARGS__) +#endif +#ifndef CT2_FLASH_ATTN_HDIM_256 +#define _HEADDIM_CASE_256(...) _HEADDIM_UNSUPPORTED(256) +#else +#define _HEADDIM_CASE_256(...) 
_HEADDIM_DISPATCH(256, __VA_ARGS__) +#endif + +// Fallback for translation units built without the CMake-provided +// CT2_FLASH_ATTN_HDIM_* definitions: when the build is not restricted, +// every head dimension dispatches (all kernels are compiled). +#ifndef CT2_FLASH_ATTN_HDIMS_RESTRICTED +#undef _HEADDIM_CASE_32 +#undef _HEADDIM_CASE_64 +#undef _HEADDIM_CASE_96 +#undef _HEADDIM_CASE_128 +#undef _HEADDIM_CASE_160 +#undef _HEADDIM_CASE_192 +#undef _HEADDIM_CASE_224 +#undef _HEADDIM_CASE_256 +#define _HEADDIM_CASE_32(...) _HEADDIM_DISPATCH(32, __VA_ARGS__) +#define _HEADDIM_CASE_64(...) _HEADDIM_DISPATCH(64, __VA_ARGS__) +#define _HEADDIM_CASE_96(...) _HEADDIM_DISPATCH(96, __VA_ARGS__) +#define _HEADDIM_CASE_128(...) _HEADDIM_DISPATCH(128, __VA_ARGS__) +#define _HEADDIM_CASE_160(...) _HEADDIM_DISPATCH(160, __VA_ARGS__) +#define _HEADDIM_CASE_192(...) _HEADDIM_DISPATCH(192, __VA_ARGS__) +#define _HEADDIM_CASE_224(...) _HEADDIM_DISPATCH(224, __VA_ARGS__) +#define _HEADDIM_CASE_256(...) _HEADDIM_DISPATCH(256, __VA_ARGS__) +#endif + #define HEADDIM_SWITCH(HEADDIM, ...) 
\ [&] { \ if (HEADDIM <= 32) { \ - constexpr static int kHeadDim = 32; \ - return __VA_ARGS__(); \ + _HEADDIM_CASE_32(__VA_ARGS__) \ } else if (HEADDIM <= 64) { \ - constexpr static int kHeadDim = 64; \ - return __VA_ARGS__(); \ + _HEADDIM_CASE_64(__VA_ARGS__) \ } else if (HEADDIM <= 96) { \ - constexpr static int kHeadDim = 96; \ - return __VA_ARGS__(); \ + _HEADDIM_CASE_96(__VA_ARGS__) \ } else if (HEADDIM <= 128) { \ - constexpr static int kHeadDim = 128; \ - return __VA_ARGS__(); \ + _HEADDIM_CASE_128(__VA_ARGS__) \ } else if (HEADDIM <= 160) { \ - constexpr static int kHeadDim = 160; \ - return __VA_ARGS__(); \ + _HEADDIM_CASE_160(__VA_ARGS__) \ } else if (HEADDIM <= 192) { \ - constexpr static int kHeadDim = 192; \ - return __VA_ARGS__(); \ + _HEADDIM_CASE_192(__VA_ARGS__) \ } else if (HEADDIM <= 224) { \ - constexpr static int kHeadDim = 224; \ - return __VA_ARGS__(); \ + _HEADDIM_CASE_224(__VA_ARGS__) \ } else if (HEADDIM <= 256) { \ - constexpr static int kHeadDim = 256; \ - return __VA_ARGS__(); \ + _HEADDIM_CASE_256(__VA_ARGS__) \ } \ }()