Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 30 additions & 67 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ option(BUILD_TESTS "Compile the tests" OFF)
option(BUILD_SHARED_LIBS "Build shared libraries" ON)
option(WITH_TENSOR_PARALLEL "Compile with NCCL and MPI backend" OFF)
option(WITH_FLASH_ATTN "Compile with Flash Attention 2" OFF)
set(FLASH_ATTN_HDIMS "" CACHE STRING "Head dimensions to compile for flash attention (e.g. '32;64'). Empty means all.")
option(ENABLE_ADDRESS_SANITIZER "ASAN" OFF)

MESSAGE(STATUS "Compiler Id: ${CMAKE_CXX_COMPILER_ID}")
Expand Down Expand Up @@ -606,74 +607,36 @@ if (WITH_CUDA)
endif()
if (WITH_FLASH_ATTN)
  add_definitions(-DCT2_WITH_FLASH_ATTN)

  # Every head dimension a flash-attention kernel source file exists for.
  set(_ALL_FLASH_HDIMS 32 64 96 128 160 192 224 256)
  if(FLASH_ATTN_HDIMS)
    # Fail fast on typos: each user-requested head dimension must be one of
    # the dimensions we actually have kernel sources for.
    foreach(_hdim ${FLASH_ATTN_HDIMS})
      if(NOT _hdim IN_LIST _ALL_FLASH_HDIMS)
        message(FATAL_ERROR
          "FLASH_ATTN_HDIMS contains unsupported head dimension '${_hdim}'. "
          "Supported values: ${_ALL_FLASH_HDIMS}")
      endif()
    endforeach()
    set(_FLASH_HDIMS ${FLASH_ATTN_HDIMS})
  else()
    # Empty FLASH_ATTN_HDIMS means compile every head dimension.
    set(_FLASH_HDIMS ${_ALL_FLASH_HDIMS})
  endif()

  message(STATUS "Flash attention head dimensions: ${_FLASH_HDIMS}")

  # Define which hdims are compiled so HEADDIM_SWITCH (static_switch.h) can
  # limit kernel instantiation to the selected dimensions.
  foreach(_hdim ${_FLASH_HDIMS})
    add_definitions(-DCT2_FLASH_ATTN_HDIM_${_hdim})
  endforeach()
  if(FLASH_ATTN_HDIMS)
    add_definitions(-DCT2_FLASH_ATTN_HDIMS_RESTRICTED)
  endif()

  # Each head dimension contributes four kernels: {regular, split-KV} x
  # {bf16, fp16}, all targeting sm80.
  set(_FLASH_ATTN_SOURCES "")
  foreach(_hdim ${_FLASH_HDIMS})
    list(APPEND _FLASH_ATTN_SOURCES
      src/ops/flash-attention/flash_fwd_hdim${_hdim}_bf16_sm80.cu
      src/ops/flash-attention/flash_fwd_hdim${_hdim}_fp16_sm80.cu
      src/ops/flash-attention/flash_fwd_split_hdim${_hdim}_bf16_sm80.cu
      src/ops/flash-attention/flash_fwd_split_hdim${_hdim}_fp16_sm80.cu
    )
  endforeach()

  list(APPEND SOURCES ${_FLASH_ATTN_SOURCES})
  set_source_files_properties(${_FLASH_ATTN_SOURCES}
    PROPERTIES COMPILE_FLAGS "--use_fast_math")
endif()
set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE)
Expand Down
99 changes: 83 additions & 16 deletions include/ctranslate2/ops/flash-attention/static_switch.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,31 +78,98 @@
} \
}()

// When FLASH_ATTN_HDIMS is restricted via cmake, only instantiate selected
// head dimensions. Others throw at runtime instead of generating link-time
// symbol references. CT2_FLASH_ATTN_HDIM_N is defined per compiled hdim.

// Expands to a case body that binds kHeadDim to the compiled dimension DIM
// and invokes the caller-supplied lambda (the kernel launch) with it.
#define _HEADDIM_DISPATCH(DIM, ...) \
constexpr static int kHeadDim = DIM; \
return __VA_ARGS__();

// Expands to a case body that raises a runtime error for a head dimension
// whose kernels were excluded from this build.
// NOTE(review): std::runtime_error needs <stdexcept> in scope at every
// expansion site — presumably pulled in by the including .cu files; confirm.
#define _HEADDIM_UNSUPPORTED(DIM) \
throw std::runtime_error( \
"Flash attention head dim " #DIM " not compiled. " \
"Rebuild CTranslate2 with FLASH_ATTN_HDIMS including " #DIM);

// One case macro per supported head dimension. If CMake emitted the matching
// CT2_FLASH_ATTN_HDIM_<N> definition, the case dispatches into the compiled
// kernel; otherwise it throws at runtime. The eight groups below are
// identical except for the dimension (the preprocessor cannot loop).
#ifndef CT2_FLASH_ATTN_HDIM_32
#define _HEADDIM_CASE_32(...) _HEADDIM_UNSUPPORTED(32)
#else
#define _HEADDIM_CASE_32(...) _HEADDIM_DISPATCH(32, __VA_ARGS__)
#endif
#ifndef CT2_FLASH_ATTN_HDIM_64
#define _HEADDIM_CASE_64(...) _HEADDIM_UNSUPPORTED(64)
#else
#define _HEADDIM_CASE_64(...) _HEADDIM_DISPATCH(64, __VA_ARGS__)
#endif
#ifndef CT2_FLASH_ATTN_HDIM_96
#define _HEADDIM_CASE_96(...) _HEADDIM_UNSUPPORTED(96)
#else
#define _HEADDIM_CASE_96(...) _HEADDIM_DISPATCH(96, __VA_ARGS__)
#endif
#ifndef CT2_FLASH_ATTN_HDIM_128
#define _HEADDIM_CASE_128(...) _HEADDIM_UNSUPPORTED(128)
#else
#define _HEADDIM_CASE_128(...) _HEADDIM_DISPATCH(128, __VA_ARGS__)
#endif
#ifndef CT2_FLASH_ATTN_HDIM_160
#define _HEADDIM_CASE_160(...) _HEADDIM_UNSUPPORTED(160)
#else
#define _HEADDIM_CASE_160(...) _HEADDIM_DISPATCH(160, __VA_ARGS__)
#endif
#ifndef CT2_FLASH_ATTN_HDIM_192
#define _HEADDIM_CASE_192(...) _HEADDIM_UNSUPPORTED(192)
#else
#define _HEADDIM_CASE_192(...) _HEADDIM_DISPATCH(192, __VA_ARGS__)
#endif
#ifndef CT2_FLASH_ATTN_HDIM_224
#define _HEADDIM_CASE_224(...) _HEADDIM_UNSUPPORTED(224)
#else
#define _HEADDIM_CASE_224(...) _HEADDIM_DISPATCH(224, __VA_ARGS__)
#endif
#ifndef CT2_FLASH_ATTN_HDIM_256
#define _HEADDIM_CASE_256(...) _HEADDIM_UNSUPPORTED(256)
#else
#define _HEADDIM_CASE_256(...) _HEADDIM_DISPATCH(256, __VA_ARGS__)
#endif

// Unrestricted builds define CT2_FLASH_ATTN_HDIM_* for every dimension, but
// this header can also be compiled in a translation unit that received none
// of those CMake definitions. Whenever the build was NOT explicitly
// restricted (CT2_FLASH_ATTN_HDIMS_RESTRICTED is undefined), force every
// case to dispatch so all head dimensions remain usable by default.
#ifndef CT2_FLASH_ATTN_HDIMS_RESTRICTED
#undef _HEADDIM_CASE_32
#undef _HEADDIM_CASE_64
#undef _HEADDIM_CASE_96
#undef _HEADDIM_CASE_128
#undef _HEADDIM_CASE_160
#undef _HEADDIM_CASE_192
#undef _HEADDIM_CASE_224
#undef _HEADDIM_CASE_256
#define _HEADDIM_CASE_32(...) _HEADDIM_DISPATCH(32, __VA_ARGS__)
#define _HEADDIM_CASE_64(...) _HEADDIM_DISPATCH(64, __VA_ARGS__)
#define _HEADDIM_CASE_96(...) _HEADDIM_DISPATCH(96, __VA_ARGS__)
#define _HEADDIM_CASE_128(...) _HEADDIM_DISPATCH(128, __VA_ARGS__)
#define _HEADDIM_CASE_160(...) _HEADDIM_DISPATCH(160, __VA_ARGS__)
#define _HEADDIM_CASE_192(...) _HEADDIM_DISPATCH(192, __VA_ARGS__)
#define _HEADDIM_CASE_224(...) _HEADDIM_DISPATCH(224, __VA_ARGS__)
#define _HEADDIM_CASE_256(...) _HEADDIM_DISPATCH(256, __VA_ARGS__)
#endif

// Selects the kernel instantiation for a runtime head dimension by rounding
// up to the next compiled size (32, 64, ..., 256). Each case expands to
// _HEADDIM_CASE_<N>, which either binds kHeadDim and invokes the supplied
// lambda, or throws if that dimension was excluded via FLASH_ATTN_HDIMS.
// (The diff rendering had fused the removed inline case bodies with the new
// _HEADDIM_CASE_* lines; this is the coherent post-change macro.)
#define HEADDIM_SWITCH(HEADDIM, ...)   \
  [&] { \
    if (HEADDIM <= 32) { \
      _HEADDIM_CASE_32(__VA_ARGS__) \
    } else if (HEADDIM <= 64) { \
      _HEADDIM_CASE_64(__VA_ARGS__) \
    } else if (HEADDIM <= 96) { \
      _HEADDIM_CASE_96(__VA_ARGS__) \
    } else if (HEADDIM <= 128) { \
      _HEADDIM_CASE_128(__VA_ARGS__) \
    } else if (HEADDIM <= 160) { \
      _HEADDIM_CASE_160(__VA_ARGS__) \
    } else if (HEADDIM <= 192) { \
      _HEADDIM_CASE_192(__VA_ARGS__) \
    } else if (HEADDIM <= 224) { \
      _HEADDIM_CASE_224(__VA_ARGS__) \
    } else if (HEADDIM <= 256) { \
      _HEADDIM_CASE_256(__VA_ARGS__) \
    } \
  }()
Loading