Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ def setup_common_extension() -> CMakeExtension:
ck_path = Path(os.getenv("NVTE_CK_FUSED_ATTN_PATH"))
cmake_flags.append(f"-DAITER_MHA_PATH={ck_path}")

if os.getenv("NVTE_AITER_STATIC_LINK") is not None:
aiter_static_link = "ON" if int(os.getenv("NVTE_AITER_STATIC_LINK", "1")) else "OFF"
cmake_flags.append(f"-DAITER_STATIC_LINK={aiter_static_link}")

if int(os.getenv("NVTE_FUSED_ATTN_AOTRITON", "1"))==0 or int(os.getenv("NVTE_FUSED_ATTN", "1"))==0:
cmake_flags.append("-DUSE_FUSED_ATTN_AOTRITON=OFF")
elif os.getenv("NVTE_FUSED_ATTN_AOTRITON") or os.getenv("NVTE_FUSED_ATTN"):
Expand Down
29 changes: 23 additions & 6 deletions transformer_engine/common/ck_fused_attn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ project(ck_fused_attn LANGUAGES HIP CXX)


set(AITER_MHA_INSTALL_PREFIX "transformer_engine" CACHE STRING "aiter mha shared lib install prefix in TE")
option(AITER_STATIC_LINK "Statically link AITER MHA libs into ck_fused_attn" OFF)

set(__AITER_SOURCE_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/aiter")
set(__AITER_TEST_DIR "${__AITER_SOURCE_DIR}/op_tests/cpp/mha")
Expand All @@ -32,8 +33,8 @@ message(STATUS "AITER V3_ASM_ARCHS: ${V3_ASM_ARCHS}")
list(JOIN V3_ASM_ARCHS ";" V3_ASM_ARCHS_STR)

if(DEFINED AITER_MHA_PATH)
message(STATUS "[AITER-PREBUILT] Using AITER_MHA_PATH=${AITER_MHA_PATH}")
# use pre-built libmha_fwd.so libmha_bwd.so
message(STATUS "[AITER-BUILD] Using AITER_MHA_PATH=${AITER_MHA_PATH}")
# use pre-built libraries
set(__AITER_MHA_PATH ${AITER_MHA_PATH})
else()
set(AITER_CACHE_VALID FALSE)
Expand All @@ -49,7 +50,7 @@ else()

# If not downloaded, Fallback: Build from source
if(NOT AITER_PREBUILT_DOWNLOAD_SUCCESS)
message(STATUS " [AITER-PREBUILT] Building aiter from source.")
message(STATUS " [AITER-BUILD] Building aiter from source.")
execute_process(
COMMAND bash ${CMAKE_CURRENT_LIST_DIR}/aiter_build.sh
--aiter-dir ${__AITER_SOURCE_DIR}
Expand All @@ -62,7 +63,7 @@ else()
endif()
endif()
set(__AITER_MHA_PATH "${EXTRACT_DIR}")
message(STATUS "[AITER-PREBUILT] Using __AITER_MHA_PATH=${__AITER_MHA_PATH}")
message(STATUS "[AITER-BUILD] Using __AITER_MHA_PATH=${__AITER_MHA_PATH}")
endif()

set(ck_fused_attn_SOURCES)
Expand Down Expand Up @@ -110,10 +111,26 @@ target_include_directories(ck_fused_attn PRIVATE ${CK_INCLUDE_DIR} ${__CK_SOURCE
target_include_directories(ck_fused_attn PRIVATE ${AITER_INCLUDE_DIR})

find_package(hip)
list(APPEND ck_fused_attn_LINKER_LIBS hip::host hip::device roctx64 ${__AITER_MHA_PATH}/libmha_fwd.so ${__AITER_MHA_PATH}/libmha_bwd.so)
list(APPEND ck_fused_attn_LINKER_LIBS hip::host hip::device roctx64)

if(AITER_STATIC_LINK)
set(_AITER_LIB_EXT "a")
target_link_options(ck_fused_attn PRIVATE -Wl,--exclude-libs,ALL)
message(STATUS "Statically linking AITER MHA libs into ck_fused_attn")
else()
set(_AITER_LIB_EXT "so")
message(STATUS "Using dynamic AITER MHA libs for ck_fused_attn")
endif()

set(__AITER_MHA_FWD_LIB "${__AITER_MHA_PATH}/libmha_fwd.${_AITER_LIB_EXT}")
set(__AITER_MHA_BWD_LIB "${__AITER_MHA_PATH}/libmha_bwd.${_AITER_LIB_EXT}")
list(APPEND ck_fused_attn_LINKER_LIBS ${__AITER_MHA_FWD_LIB} ${__AITER_MHA_BWD_LIB})

target_link_libraries(ck_fused_attn PUBLIC ${ck_fused_attn_LINKER_LIBS})
target_compile_options(ck_fused_attn PRIVATE ${CK_FUSED_ATTN_COMPILE_OPTIONS})
set_target_properties(ck_fused_attn PROPERTIES INSTALL_RPATH "$ORIGIN")

install(FILES ${__AITER_MHA_PATH}/libmha_fwd.so ${__AITER_MHA_PATH}/libmha_bwd.so DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib)
if(NOT AITER_STATIC_LINK)
install(FILES ${__AITER_MHA_PATH}/libmha_fwd.so ${__AITER_MHA_PATH}/libmha_bwd.so DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib)
endif()
install(TARGETS ck_fused_attn DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib)
57 changes: 51 additions & 6 deletions transformer_engine/common/ck_fused_attn/aiter_build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,58 @@ while [[ $# -gt 0 ]]; do
done

if [[ -z "${AITER_DIR}" || -z "${AITER_TEST_DIR}" || -z "${GPU_ARCHS_VAL}" ]]; then
echo "[AITER-PREBUILT] --aiter-dir, --aiter-test-dir, and --gpu-archs are required." >&2
echo "[AITER-BUILD] --aiter-dir, --aiter-test-dir, and --gpu-archs are required." >&2
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For consistency it should probably also be changed in CMakeLists.txt

exit 1
fi

rm -rf "${AITER_DIR}/aiter/jit/build"
AITER_LOG_MORE=1 \
CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT="${CK_TILE_BF16_DEFAULT}" \
GPU_ARCHS="${GPU_ARCHS_VAL}" \
python3 "${AITER_TEST_DIR}/compile.py"
# rm -rf "${AITER_DIR}/aiter/jit/build"
# AITER_LOG_MORE=1 \
# CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT="${CK_TILE_BF16_DEFAULT}" \
# GPU_ARCHS="${GPU_ARCHS_VAL}" \
# python3 "${AITER_TEST_DIR}/compile.py"

# Generate static archives from the built object files only if NVTE_AITER_STATIC_LINK=1
if [[ "${NVTE_AITER_STATIC_LINK:-1}" -ne 1 ]]; then
exit 0
fi

# Check for ar and ranlib
AR_BIN="${AR:-$(command -v ar || true)}"
RANLIB_BIN="${RANLIB:-$(command -v ranlib || true)}"
if [[ -z "${AR_BIN}" ]]; then
echo "[AITER-BUILD] Could not find ar for static archive generation." >&2
exit 1
fi
if [[ -z "${RANLIB_BIN}" ]]; then
echo "[AITER-BUILD] Could not find ranlib for static archive generation." >&2
exit 1
fi

# Create static archives for both forward and backward passes
for lib in fwd bwd; do
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since it is built and used by us, will it be more efficient to make single lib?

src_obj_dir="${AITER_DIR}/aiter/jit/build/libmha_${lib}/build"
out_archive="${AITER_TEST_DIR}/libmha_${lib}.a"

if [[ ! -d "${src_obj_dir}" ]]; then
echo "[AITER-BUILD] Missing object directory: ${src_obj_dir}" >&2
exit 1
fi

mapfile -d '' obj_files < <(find "${src_obj_dir}" -type f -name '*.o' -print0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why so complicated approach, just do what build systems do -> call ar cq @file_list

if [[ ${#obj_files[@]} -eq 0 ]]; then
echo "[AITER-BUILD] No object files found under ${src_obj_dir}" >&2
exit 1
fi

total_objs=${#obj_files[@]}

rm -f "${out_archive}"
"${AR_BIN}" q "${out_archive}" "${obj_files[@]}"

if [[ -n "${RANLIB_BIN}" ]]; then
"${RANLIB_BIN}" "${out_archive}"
fi

echo "[AITER-BUILD] Created static archive: ${out_archive} (${#obj_files[@]} objects)"
done

21 changes: 19 additions & 2 deletions transformer_engine/common/ck_fused_attn/aiter_prebuilt.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,25 @@ set(EXTRACT_DIR "${CACHE_ROOT}/${KEY}")

# Validate existing cache path
function(is_aiter_cache_valid CACHE_VALID)
if(EXISTS "${EXTRACT_DIR}/libmha_fwd.so" AND EXISTS "${EXTRACT_DIR}/libmha_bwd.so")
set(_AITER_CACHE_VALID TRUE)
if(AITER_STATIC_LINK)
set(_AITER_LIB_EXT "a")
else()
set(_AITER_LIB_EXT "so")
endif()

if(NOT (EXISTS "${EXTRACT_DIR}/libmha_fwd.${_AITER_LIB_EXT}" AND EXISTS "${EXTRACT_DIR}/libmha_bwd.${_AITER_LIB_EXT}"))
set(_AITER_CACHE_VALID FALSE)
endif()

if(_AITER_CACHE_VALID)
set(${CACHE_VALID} TRUE PARENT_SCOPE)
message(STATUS "[AITER-PREBUILT] Found Cached build files at ${EXTRACT_DIR}")
return()
endif()

# Cache is invalid/outdated - clean it
file(REMOVE_RECURSE "${CACHE_ROOT}")
file(REMOVE_RECURSE "${EXTRACT_DIR}")
file(REMOVE_RECURSE "${CMAKE_BINARY_DIR}/_deps")
endfunction()

Expand All @@ -45,6 +56,12 @@ function(cache_local_aiter_build SOURCE_DIR)
file(MAKE_DIRECTORY "${EXTRACT_DIR}")
message(STATUS "[AITER-PREBUILT] Caching locally built libs to ${EXTRACT_DIR}")
file(COPY "${SOURCE_DIR}/libmha_fwd.so" "${SOURCE_DIR}/libmha_bwd.so" DESTINATION "${EXTRACT_DIR}")
if(AITER_STATIC_LINK)
if(NOT EXISTS "${SOURCE_DIR}/libmha_fwd.a" OR NOT EXISTS "${SOURCE_DIR}/libmha_bwd.a")
message(FATAL_ERROR "Expected libmha_fwd.a and libmha_bwd.a under ${SOURCE_DIR} for static link mode")
endif()
file(COPY "${SOURCE_DIR}/libmha_fwd.a" "${SOURCE_DIR}/libmha_bwd.a" DESTINATION "${EXTRACT_DIR}")
endif()
endfunction()

# Download prebuilt tgz file
Expand Down