Add optional static linking of the AITER MHA libraries into ck_fused_attn.

* setup.py: when NVTE_AITER_STATIC_LINK is set, forward it to CMake as
  -DAITER_STATIC_LINK=ON/OFF (any non-zero integer means ON).
* CMakeLists.txt: new AITER_STATIC_LINK option (default OFF). Links
  libmha_fwd/libmha_bwd as .a (with -Wl,--exclude-libs,ALL) or as .so, and
  installs the .so files only in the dynamic case.
* aiter_build.sh: after building AITER, packs the object files found under
  aiter/jit/build/libmha_{fwd,bwd}/build into libmha_{fwd,bwd}.a unless
  NVTE_AITER_STATIC_LINK is 0. The AITER build step itself stays enabled.
* aiter_prebuilt.cmake: cache validation checks the library flavour selected by
  AITER_STATIC_LINK, removes only the stale cache entry instead of the whole
  cache root, and caches the .a files in static mode.

NOTE(review): line breaks, indentation and blank context lines were
reconstructed from a whitespace-collapsed copy of this patch; verify the
context lines against the tree before applying. The post-image blob id on the
aiter_build.sh index line predates the edits to that hunk.

diff --git a/setup.py b/setup.py
index 2f3e3c2ab..91e522f40 100644
--- a/setup.py
+++ b/setup.py
@@ -79,6 +79,10 @@ def setup_common_extension() -> CMakeExtension:
             ck_path = Path(os.getenv("NVTE_CK_FUSED_ATTN_PATH"))
             cmake_flags.append(f"-DAITER_MHA_PATH={ck_path}")
 
+        if os.getenv("NVTE_AITER_STATIC_LINK") is not None:
+            aiter_static_link = "ON" if int(os.getenv("NVTE_AITER_STATIC_LINK", "1")) else "OFF"
+            cmake_flags.append(f"-DAITER_STATIC_LINK={aiter_static_link}")
+
         if int(os.getenv("NVTE_FUSED_ATTN_AOTRITON", "1"))==0 or int(os.getenv("NVTE_FUSED_ATTN", "1"))==0:
             cmake_flags.append("-DUSE_FUSED_ATTN_AOTRITON=OFF")
         elif os.getenv("NVTE_FUSED_ATTN_AOTRITON") or os.getenv("NVTE_FUSED_ATTN"):
diff --git a/transformer_engine/common/ck_fused_attn/CMakeLists.txt b/transformer_engine/common/ck_fused_attn/CMakeLists.txt
index 1c61596c0..cbe7ae7ef 100644
--- a/transformer_engine/common/ck_fused_attn/CMakeLists.txt
+++ b/transformer_engine/common/ck_fused_attn/CMakeLists.txt
@@ -7,6 +7,7 @@ project(ck_fused_attn LANGUAGES HIP CXX)
 
 
 set(AITER_MHA_INSTALL_PREFIX "transformer_engine" CACHE STRING "aiter mha shared lib install prefix in TE")
+option(AITER_STATIC_LINK "Statically link AITER MHA libs into ck_fused_attn" OFF)
 
 set(__AITER_SOURCE_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/aiter")
 set(__AITER_TEST_DIR "${__AITER_SOURCE_DIR}/op_tests/cpp/mha")
@@ -32,8 +33,8 @@ message(STATUS "AITER V3_ASM_ARCHS: ${V3_ASM_ARCHS}")
 list(JOIN V3_ASM_ARCHS ";" V3_ASM_ARCHS_STR)
 
 if(DEFINED AITER_MHA_PATH)
-  message(STATUS "[AITER-PREBUILT] Using AITER_MHA_PATH=${AITER_MHA_PATH}")
-  # use pre-built libmha_fwd.so libmha_bwd.so
+  message(STATUS "[AITER-BUILD] Using AITER_MHA_PATH=${AITER_MHA_PATH}")
+  # use pre-built libraries
   set(__AITER_MHA_PATH ${AITER_MHA_PATH})
 else()
   set(AITER_CACHE_VALID FALSE)
@@ -49,7 +50,7 @@ else()
 
     # If not downloaded, Fallback: Build from source
     if(NOT AITER_PREBUILT_DOWNLOAD_SUCCESS)
-      message(STATUS " [AITER-PREBUILT] Building aiter from source.")
+      message(STATUS " [AITER-BUILD] Building aiter from source.")
       execute_process(
         COMMAND bash ${CMAKE_CURRENT_LIST_DIR}/aiter_build.sh
           --aiter-dir ${__AITER_SOURCE_DIR}
@@ -62,7 +63,7 @@ else()
     endif()
   endif()
   set(__AITER_MHA_PATH "${EXTRACT_DIR}")
-  message(STATUS "[AITER-PREBUILT] Using __AITER_MHA_PATH=${__AITER_MHA_PATH}")
+  message(STATUS "[AITER-BUILD] Using __AITER_MHA_PATH=${__AITER_MHA_PATH}")
 endif()
 
 set(ck_fused_attn_SOURCES)
@@ -110,10 +111,26 @@ target_include_directories(ck_fused_attn PRIVATE ${CK_INCLUDE_DIR} ${__CK_SOURCE
 target_include_directories(ck_fused_attn PRIVATE ${AITER_INCLUDE_DIR})
 
 find_package(hip)
-list(APPEND ck_fused_attn_LINKER_LIBS hip::host hip::device roctx64 ${__AITER_MHA_PATH}/libmha_fwd.so ${__AITER_MHA_PATH}/libmha_bwd.so)
+list(APPEND ck_fused_attn_LINKER_LIBS hip::host hip::device roctx64)
+
+if(AITER_STATIC_LINK)
+  set(_AITER_LIB_EXT "a")
+  target_link_options(ck_fused_attn PRIVATE -Wl,--exclude-libs,ALL)
+  message(STATUS "Statically linking AITER MHA libs into ck_fused_attn")
+else()
+  set(_AITER_LIB_EXT "so")
+  message(STATUS "Using dynamic AITER MHA libs for ck_fused_attn")
+endif()
+
+set(__AITER_MHA_FWD_LIB "${__AITER_MHA_PATH}/libmha_fwd.${_AITER_LIB_EXT}")
+set(__AITER_MHA_BWD_LIB "${__AITER_MHA_PATH}/libmha_bwd.${_AITER_LIB_EXT}")
+list(APPEND ck_fused_attn_LINKER_LIBS ${__AITER_MHA_FWD_LIB} ${__AITER_MHA_BWD_LIB})
+
 target_link_libraries(ck_fused_attn PUBLIC ${ck_fused_attn_LINKER_LIBS})
 target_compile_options(ck_fused_attn PRIVATE ${CK_FUSED_ATTN_COMPILE_OPTIONS})
 set_target_properties(ck_fused_attn PROPERTIES INSTALL_RPATH "$ORIGIN")
 
-install(FILES ${__AITER_MHA_PATH}/libmha_fwd.so ${__AITER_MHA_PATH}/libmha_bwd.so DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib)
+if(NOT AITER_STATIC_LINK)
+  install(FILES ${__AITER_MHA_PATH}/libmha_fwd.so ${__AITER_MHA_PATH}/libmha_bwd.so DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib)
+endif()
 install(TARGETS ck_fused_attn DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib)
diff --git a/transformer_engine/common/ck_fused_attn/aiter_build.sh b/transformer_engine/common/ck_fused_attn/aiter_build.sh
index 3ccf2979c..c8b96afbb 100644
--- a/transformer_engine/common/ck_fused_attn/aiter_build.sh
+++ b/transformer_engine/common/ck_fused_attn/aiter_build.sh
@@ -32,13 +32,56 @@ while [[ $# -gt 0 ]]; do
 done
 
 if [[ -z "${AITER_DIR}" || -z "${AITER_TEST_DIR}" || -z "${GPU_ARCHS_VAL}" ]]; then
-  echo "[AITER-PREBUILT] --aiter-dir, --aiter-test-dir, and --gpu-archs are required." >&2
+  echo "[AITER-BUILD] --aiter-dir, --aiter-test-dir, and --gpu-archs are required." >&2
   exit 1
 fi
 
 rm -rf "${AITER_DIR}/aiter/jit/build"
 AITER_LOG_MORE=1 \
 CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT="${CK_TILE_BF16_DEFAULT}" \
 GPU_ARCHS="${GPU_ARCHS_VAL}" \
 python3 "${AITER_TEST_DIR}/compile.py"
+
+# Generate static archives from the built object files. Skipped only when
+# NVTE_AITER_STATIC_LINK is 0 (unset defaults to 1); setup.py likewise treats
+# any non-zero integer as ON.
+if [[ "${NVTE_AITER_STATIC_LINK:-1}" -eq 0 ]]; then
+  exit 0
+fi
+
+# Check for ar and ranlib
+AR_BIN="${AR:-$(command -v ar || true)}"
+RANLIB_BIN="${RANLIB:-$(command -v ranlib || true)}"
+if [[ -z "${AR_BIN}" ]]; then
+  echo "[AITER-BUILD] Could not find ar for static archive generation." >&2
+  exit 1
+fi
+if [[ -z "${RANLIB_BIN}" ]]; then
+  echo "[AITER-BUILD] Could not find ranlib for static archive generation." >&2
+  exit 1
+fi
+
+# Create static archives for both forward and backward passes
+for lib in fwd bwd; do
+  src_obj_dir="${AITER_DIR}/aiter/jit/build/libmha_${lib}/build"
+  out_archive="${AITER_TEST_DIR}/libmha_${lib}.a"
+
+  if [[ ! -d "${src_obj_dir}" ]]; then
+    echo "[AITER-BUILD] Missing object directory: ${src_obj_dir}" >&2
+    exit 1
+  fi
+
+  mapfile -d '' obj_files < <(find "${src_obj_dir}" -type f -name '*.o' -print0)
+  if [[ ${#obj_files[@]} -eq 0 ]]; then
+    echo "[AITER-BUILD] No object files found under ${src_obj_dir}" >&2
+    exit 1
+  fi
+
+  # Start from a fresh archive; RANLIB_BIN was validated above.
+  rm -f "${out_archive}"
+  "${AR_BIN}" q "${out_archive}" "${obj_files[@]}"
+  "${RANLIB_BIN}" "${out_archive}"
+
+  echo "[AITER-BUILD] Created static archive: ${out_archive} (${#obj_files[@]} objects)"
+done
 
diff --git a/transformer_engine/common/ck_fused_attn/aiter_prebuilt.cmake b/transformer_engine/common/ck_fused_attn/aiter_prebuilt.cmake
index a59605e00..b19a3fa72 100644
--- a/transformer_engine/common/ck_fused_attn/aiter_prebuilt.cmake
+++ b/transformer_engine/common/ck_fused_attn/aiter_prebuilt.cmake
@@ -29,14 +29,25 @@ set(EXTRACT_DIR "${CACHE_ROOT}/${KEY}")
 
 # Validate existing cache path
 function(is_aiter_cache_valid CACHE_VALID)
-  if(EXISTS "${EXTRACT_DIR}/libmha_fwd.so" AND EXISTS "${EXTRACT_DIR}/libmha_bwd.so")
+  set(_AITER_CACHE_VALID TRUE)
+  if(AITER_STATIC_LINK)
+    set(_AITER_LIB_EXT "a")
+  else()
+    set(_AITER_LIB_EXT "so")
+  endif()
+
+  if(NOT (EXISTS "${EXTRACT_DIR}/libmha_fwd.${_AITER_LIB_EXT}" AND EXISTS "${EXTRACT_DIR}/libmha_bwd.${_AITER_LIB_EXT}"))
+    set(_AITER_CACHE_VALID FALSE)
+  endif()
+
+  if(_AITER_CACHE_VALID)
     set(${CACHE_VALID} TRUE PARENT_SCOPE)
     message(STATUS "[AITER-PREBUILT] Found Cached build files at ${EXTRACT_DIR}")
     return()
   endif()
 
   # Cache is invalid/outdated - clean it
-  file(REMOVE_RECURSE "${CACHE_ROOT}")
+  file(REMOVE_RECURSE "${EXTRACT_DIR}")
   file(REMOVE_RECURSE "${CMAKE_BINARY_DIR}/_deps")
 endfunction()
 
@@ -45,6 +56,12 @@ function(cache_local_aiter_build SOURCE_DIR)
   file(MAKE_DIRECTORY "${EXTRACT_DIR}")
   message(STATUS "[AITER-PREBUILT] Caching locally built libs to ${EXTRACT_DIR}")
   file(COPY "${SOURCE_DIR}/libmha_fwd.so" "${SOURCE_DIR}/libmha_bwd.so" DESTINATION "${EXTRACT_DIR}")
+  if(AITER_STATIC_LINK)
+    if(NOT EXISTS "${SOURCE_DIR}/libmha_fwd.a" OR NOT EXISTS "${SOURCE_DIR}/libmha_bwd.a")
+      message(FATAL_ERROR "Expected libmha_fwd.a and libmha_bwd.a under ${SOURCE_DIR} for static link mode")
+    endif()
+    file(COPY "${SOURCE_DIR}/libmha_fwd.a" "${SOURCE_DIR}/libmha_bwd.a" DESTINATION "${EXTRACT_DIR}")
+  endif()
 endfunction()
 
 # Download prebuilt tgz file