diff --git a/transformer_engine/common/ck_fused_attn/CMakeLists.txt b/transformer_engine/common/ck_fused_attn/CMakeLists.txt index 1c61596c0..beae0e9ce 100644 --- a/transformer_engine/common/ck_fused_attn/CMakeLists.txt +++ b/transformer_engine/common/ck_fused_attn/CMakeLists.txt @@ -32,8 +32,8 @@ message(STATUS "AITER V3_ASM_ARCHS: ${V3_ASM_ARCHS}") list(JOIN V3_ASM_ARCHS ";" V3_ASM_ARCHS_STR) if(DEFINED AITER_MHA_PATH) - message(STATUS "[AITER-PREBUILT] Using AITER_MHA_PATH=${AITER_MHA_PATH}") - # use pre-built libmha_fwd.so libmha_bwd.so + message(STATUS "[AITER-BUILD] Using AITER_MHA_PATH=${AITER_MHA_PATH}") + # use pre-built libraries set(__AITER_MHA_PATH ${AITER_MHA_PATH}) else() set(AITER_CACHE_VALID FALSE) @@ -46,10 +46,17 @@ else() if(NOT AITER_CACHE_VALID) # Try downloading prebuilt files if NVTE_AITER_PREBUILT_BASE_URL is set. download_aiter_prebuilt(AITER_PREBUILT_DOWNLOAD_SUCCESS) + if(AITER_PREBUILT_DOWNLOAD_SUCCESS) + is_aiter_cache_valid(AITER_DOWNLOAD_CACHE_VALID) + if(NOT AITER_DOWNLOAD_CACHE_VALID) + message(STATUS "[AITER-PREBUILT] Downloaded prebuilt cache invalid.") + set(AITER_PREBUILT_DOWNLOAD_SUCCESS FALSE) + endif() + endif() # If not downloaded, Fallback: Build from source if(NOT AITER_PREBUILT_DOWNLOAD_SUCCESS) - message(STATUS " [AITER-PREBUILT] Building aiter from source.") + message(STATUS " [AITER-BUILD] Building aiter from source.") execute_process( COMMAND bash ${CMAKE_CURRENT_LIST_DIR}/aiter_build.sh --aiter-dir ${__AITER_SOURCE_DIR} @@ -62,7 +69,7 @@ else() endif() endif() set(__AITER_MHA_PATH "${EXTRACT_DIR}") - message(STATUS "[AITER-PREBUILT] Using __AITER_MHA_PATH=${__AITER_MHA_PATH}") + message(STATUS "[AITER-BUILD] Using __AITER_MHA_PATH=${__AITER_MHA_PATH}") endif() set(ck_fused_attn_SOURCES) @@ -108,12 +115,15 @@ endif() target_include_directories(ck_fused_attn PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include") target_include_directories(ck_fused_attn PRIVATE ${CK_INCLUDE_DIR} ${__CK_SOURCE_DIR}/example/ck_tile/01_fmha) target_include_directories(ck_fused_attn PRIVATE ${AITER_INCLUDE_DIR}) +target_link_options(ck_fused_attn PRIVATE -Wl,--exclude-libs,ALL) +set(__AITER_MHA_FWD_LIB "${__AITER_MHA_PATH}/libmha_fwd.a") +set(__AITER_MHA_BWD_LIB "${__AITER_MHA_PATH}/libmha_bwd.a") find_package(hip) -list(APPEND ck_fused_attn_LINKER_LIBS hip::host hip::device roctx64 ${__AITER_MHA_PATH}/libmha_fwd.so ${__AITER_MHA_PATH}/libmha_bwd.so) +list(APPEND ck_fused_attn_LINKER_LIBS hip::host hip::device roctx64 ${__AITER_MHA_FWD_LIB} ${__AITER_MHA_BWD_LIB}) + target_link_libraries(ck_fused_attn PUBLIC ${ck_fused_attn_LINKER_LIBS}) target_compile_options(ck_fused_attn PRIVATE ${CK_FUSED_ATTN_COMPILE_OPTIONS}) set_target_properties(ck_fused_attn PROPERTIES INSTALL_RPATH "$ORIGIN") -install(FILES ${__AITER_MHA_PATH}/libmha_fwd.so ${__AITER_MHA_PATH}/libmha_bwd.so DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib) install(TARGETS ck_fused_attn DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib) diff --git a/transformer_engine/common/ck_fused_attn/aiter_build.sh b/transformer_engine/common/ck_fused_attn/aiter_build.sh index 3ccf2979c..ec3b71071 100644 --- a/transformer_engine/common/ck_fused_attn/aiter_build.sh +++ b/transformer_engine/common/ck_fused_attn/aiter_build.sh @@ -32,7 +32,7 @@ while [[ $# -gt 0 ]]; do done if [[ -z "${AITER_DIR}" || -z "${AITER_TEST_DIR}" || -z "${GPU_ARCHS_VAL}" ]]; then - echo "[AITER-PREBUILT] --aiter-dir, --aiter-test-dir, and --gpu-archs are required." >&2 + echo "[AITER-BUILD] --aiter-dir, --aiter-test-dir, and --gpu-archs are required." >&2 exit 1 fi @@ -42,3 +42,48 @@ CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT="${CK_TILE_BF16_DEFAULT}" \ GPU_ARCHS="${GPU_ARCHS_VAL}" \ python3 "${AITER_TEST_DIR}/compile.py" +# Generate static archives from the built object files only if NVTE_AITER_STATIC_LINK=1 +if [[ "${NVTE_AITER_STATIC_LINK:-1}" -ne 1 ]]; then + exit 0 +fi + +# Check for ar and ranlib +AR_BIN="${AR:-$(command -v ar || true)}" +RANLIB_BIN="${RANLIB:-$(command -v ranlib || true)}" +if [[ -z "${AR_BIN}" ]]; then + echo "[AITER-BUILD] Could not find ar for static archive generation." >&2 + exit 1 +fi +if [[ -z "${RANLIB_BIN}" ]]; then + echo "[AITER-BUILD] Could not find ranlib for static archive generation." >&2 + exit 1 +fi + +# Create static archives for both forward and backward passes +for lib in fwd bwd; do + src_obj_dir="${AITER_DIR}/aiter/jit/build/libmha_${lib}/build" + out_archive="${AITER_TEST_DIR}/libmha_${lib}.a" + + if [[ ! -d "${src_obj_dir}" ]]; then + echo "[AITER-BUILD] Missing object directory: ${src_obj_dir}" >&2 + exit 1 + fi + + mapfile -d '' obj_files < <(find "${src_obj_dir}" -type f -name '*.o' -print0) + if [[ ${#obj_files[@]} -eq 0 ]]; then + echo "[AITER-BUILD] No object files found under ${src_obj_dir}" >&2 + exit 1 + fi + + total_objs=${#obj_files[@]} + + rm -f "${out_archive}" + "${AR_BIN}" q "${out_archive}" "${obj_files[@]}" + + if [[ -n "${RANLIB_BIN}" ]]; then + "${RANLIB_BIN}" "${out_archive}" + fi + + echo "[AITER-BUILD] Created static archive: ${out_archive} (${#obj_files[@]} objects)" +done + diff --git a/transformer_engine/common/ck_fused_attn/aiter_prebuilt.cmake b/transformer_engine/common/ck_fused_attn/aiter_prebuilt.cmake index a59605e00..6aa7a9d47 100644 --- a/transformer_engine/common/ck_fused_attn/aiter_prebuilt.cmake +++ b/transformer_engine/common/ck_fused_attn/aiter_prebuilt.cmake @@ -29,14 +29,20 @@ set(EXTRACT_DIR "${CACHE_ROOT}/${KEY}") # Validate existing cache path function(is_aiter_cache_valid CACHE_VALID) - if(EXISTS "${EXTRACT_DIR}/libmha_fwd.so" AND EXISTS "${EXTRACT_DIR}/libmha_bwd.so") + set(_AITER_CACHE_VALID TRUE) + + if(NOT (EXISTS "${EXTRACT_DIR}/libmha_fwd.a" AND EXISTS "${EXTRACT_DIR}/libmha_bwd.a")) + set(_AITER_CACHE_VALID FALSE) + endif() + + if(_AITER_CACHE_VALID) set(${CACHE_VALID} TRUE PARENT_SCOPE) message(STATUS "[AITER-PREBUILT] Found Cached build files at ${EXTRACT_DIR}") return() endif() # Cache is invalid/outdated - clean it - file(REMOVE_RECURSE "${CACHE_ROOT}") + file(REMOVE_RECURSE "${EXTRACT_DIR}") file(REMOVE_RECURSE "${CMAKE_BINARY_DIR}/_deps") endfunction() @@ -44,7 +50,10 @@ endfunction() function(cache_local_aiter_build SOURCE_DIR) file(MAKE_DIRECTORY "${EXTRACT_DIR}") message(STATUS "[AITER-PREBUILT] Caching locally built libs to ${EXTRACT_DIR}") - file(COPY "${SOURCE_DIR}/libmha_fwd.so" "${SOURCE_DIR}/libmha_bwd.so" DESTINATION "${EXTRACT_DIR}") + if(NOT EXISTS "${SOURCE_DIR}/libmha_fwd.a" OR NOT EXISTS "${SOURCE_DIR}/libmha_bwd.a") + message(FATAL_ERROR "Expected libmha_fwd.a and libmha_bwd.a under ${SOURCE_DIR}") + endif() + file(COPY "${SOURCE_DIR}/libmha_fwd.a" "${SOURCE_DIR}/libmha_bwd.a" DESTINATION "${EXTRACT_DIR}") endfunction() # Download prebuilt tgz file