From e8c7ee4eae1a536213c03b2b9c3cbcd411be5533 Mon Sep 17 00:00:00 2001 From: Yankui Wang Date: Sun, 1 Feb 2026 14:54:41 +0800 Subject: [PATCH] AMD Developer Challenge 2025: Distributed Inference Co-authored-by: Yingyi Hao --- README.md | 2 + dist-infer/.clang-format | 3 + dist-infer/.clangd | 3 + dist-infer/.gitignore | 11 + dist-infer/ag-gemm/CMakeLists.txt | 21 + dist-infer/ag-gemm/Makefile | 39 + dist-infer/ag-gemm/ag_gemm.cpp | 605 ++++++++++++ dist-infer/ag-gemm/benchmark.txt | 1 + dist-infer/ag-gemm/ck_gemm.h | 352 +++++++ dist-infer/ag-gemm/eval.py | 578 ++++++++++++ dist-infer/ag-gemm/perf_gemm.cc | 1074 +++++++++++++++++++++ dist-infer/ag-gemm/perf_gemm.h | 454 +++++++++ dist-infer/ag-gemm/reference.py | 70 ++ dist-infer/ag-gemm/submit.py | 93 ++ dist-infer/ag-gemm/task.py | 14 + dist-infer/ag-gemm/template.py | 372 ++++++++ dist-infer/ag-gemm/utils.py | 176 ++++ dist-infer/all2all/CMakeLists.txt | 20 + dist-infer/all2all/Makefile | 39 + dist-infer/all2all/all2all.cpp | 692 ++++++++++++++ dist-infer/all2all/benchmark.txt | 5 + dist-infer/all2all/eval.py | 580 ++++++++++++ dist-infer/all2all/reference.py | 285 ++++++ dist-infer/all2all/submit.py | 89 ++ dist-infer/all2all/task.py | 17 + dist-infer/all2all/template.py | 335 +++++++ dist-infer/all2all/utils.py | 176 ++++ dist-infer/compose.yml | 11 + dist-infer/gemm-rs/.gitignore | 12 + dist-infer/gemm-rs/CMakeLists.txt | 36 + dist-infer/gemm-rs/README.md | 18 + dist-infer/gemm-rs/benchmark_gemm.py | 121 +++ dist-infer/gemm-rs/benchmark_rs.py | 106 +++ dist-infer/gemm-rs/gen_submission.py | 32 + dist-infer/gemm-rs/local_test.py | 344 +++++++ dist-infer/gemm-rs/requirements.txt | 12 + dist-infer/gemm-rs/src/common.h | 55 ++ dist-infer/gemm-rs/src/gemm_rs.cc | 130 +++ dist-infer/gemm-rs/src/gemm_rs_kernel.cc | 648 +++++++++++++ dist-infer/gemm-rs/src/gemm_rs_kernel.h | 20 + dist-infer/gemm-rs/src/perf_gemm.cc | 1085 ++++++++++++++++++++++ dist-infer/gemm-rs/src/perf_gemm.h | 303 ++++++ dist-infer/gemm-rs/submit.py | 49 + dist-infer/gemm-rs/task.py | 14 + dist-infer/gemm-rs/task.yml | 85 ++ dist-infer/gemm-rs/template.py | 241 +++++ 46 files changed, 9428 insertions(+) create mode 100644 dist-infer/.clang-format create mode 100644 dist-infer/.clangd create mode 100644 dist-infer/.gitignore create mode 100644 dist-infer/ag-gemm/CMakeLists.txt create mode 100644 dist-infer/ag-gemm/Makefile create mode 100644 dist-infer/ag-gemm/ag_gemm.cpp create mode 100644 dist-infer/ag-gemm/benchmark.txt create mode 100644 dist-infer/ag-gemm/ck_gemm.h create mode 100644 dist-infer/ag-gemm/eval.py create mode 100644 dist-infer/ag-gemm/perf_gemm.cc create mode 100644 dist-infer/ag-gemm/perf_gemm.h create mode 100644 dist-infer/ag-gemm/reference.py create mode 100644 dist-infer/ag-gemm/submit.py create mode 100644 dist-infer/ag-gemm/task.py create mode 100644 dist-infer/ag-gemm/template.py create mode 100644 dist-infer/ag-gemm/utils.py create mode 100644 dist-infer/all2all/CMakeLists.txt create mode 100644 dist-infer/all2all/Makefile create mode 100644 dist-infer/all2all/all2all.cpp create mode 100644 dist-infer/all2all/benchmark.txt create mode 100644 dist-infer/all2all/eval.py create mode 100644 dist-infer/all2all/reference.py create mode 100644 dist-infer/all2all/submit.py create mode 100644 dist-infer/all2all/task.py create mode 100644 dist-infer/all2all/template.py create mode 100644 dist-infer/all2all/utils.py create mode 100644 dist-infer/compose.yml create mode 100644 dist-infer/gemm-rs/.gitignore create mode 100644 
dist-infer/gemm-rs/CMakeLists.txt create mode 100644 dist-infer/gemm-rs/README.md create mode 100644 dist-infer/gemm-rs/benchmark_gemm.py create mode 100644 dist-infer/gemm-rs/benchmark_rs.py create mode 100644 dist-infer/gemm-rs/gen_submission.py create mode 100644 dist-infer/gemm-rs/local_test.py create mode 100644 dist-infer/gemm-rs/requirements.txt create mode 100644 dist-infer/gemm-rs/src/common.h create mode 100644 dist-infer/gemm-rs/src/gemm_rs.cc create mode 100644 dist-infer/gemm-rs/src/gemm_rs_kernel.cc create mode 100644 dist-infer/gemm-rs/src/gemm_rs_kernel.h create mode 100644 dist-infer/gemm-rs/src/perf_gemm.cc create mode 100644 dist-infer/gemm-rs/src/perf_gemm.h create mode 100644 dist-infer/gemm-rs/submit.py create mode 100644 dist-infer/gemm-rs/task.py create mode 100644 dist-infer/gemm-rs/task.yml create mode 100644 dist-infer/gemm-rs/template.py diff --git a/README.md b/README.md index 28d6dc3..c5c6eb6 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,9 @@ 🏆 Grand Prize Winner Project for AMD Developer Challenge 2025 +- [2026/02] All2All, GEMM-RS, and AG-GEMM kernels are now open-sourced for [AMD Developer Challenge 2025: Distributed Inference](https://amdchallenge2025.datamonsters.com/) ### News 🔥 + - [2025/11] [Easily Build and Share ROCm Kernels with Hugging Face](https://huggingface.co/blog/build-rocm-kernels) – Discover how to easily build and share RadeonFlow_Kernels using [Hugging Face's kernel-builder](https://github.com/huggingface/kernel-builder). - [2025/06] [AMD Developer Cloud](https://www.amd.com/en/developer/resources/cloud-access/amd-developer-cloud.html) now provides free AMD Instinct Accelerators, you can try out our project using their MI300X. diff --git a/dist-infer/.clang-format b/dist-infer/.clang-format new file mode 100644 index 0000000..04b00f6 --- /dev/null +++ b/dist-infer/.clang-format @@ -0,0 +1,3 @@ +IndentWidth: 4 +AlignAfterOpenBracket: BlockIndent +PackConstructorInitializers: CurrentLine diff --git a/dist-infer/.clangd b/dist-infer/.clangd new file mode 100644 index 0000000..e28adc6 --- /dev/null +++ b/dist-infer/.clangd @@ -0,0 +1,3 @@ +CompileFlags: + Add: ["-xhip", "--rocm-path=/opt/rocm"] + Remove: ["-x*"] diff --git a/dist-infer/.gitignore b/dist-infer/.gitignore new file mode 100644 index 0000000..a040d28 --- /dev/null +++ b/dist-infer/.gitignore @@ -0,0 +1,11 @@ +logs/ +build/ +.cache/ +.vscode-server/ +.popcorn.yaml +*.cu +*.hip +*.lock +submission.py +torch-build/ +__pycache__ \ No newline at end of file diff --git a/dist-infer/ag-gemm/CMakeLists.txt b/dist-infer/ag-gemm/CMakeLists.txt new file mode 100644 index 0000000..21e4bf3 --- /dev/null +++ b/dist-infer/ag-gemm/CMakeLists.txt @@ -0,0 +1,21 @@ +cmake_minimum_required(VERSION 3.21) +cmake_policy(VERSION 3.21.3...3.27) +set(PROJECT_NAME "ag_gemm") +project(${PROJECT_NAME} LANGUAGES HIP CXX) + +find_package(Python3 REQUIRED COMPONENTS Development Interpreter) +find_package(Torch CONFIG REQUIRED) +find_package(HIP CONFIG REQUIRED) + +# required for python binding +find_library(TORCH_PYTHON_LIBRARY torch_python PATH ${TORCH_INSTALL_PREFIX}/lib) + +add_library(${PROJECT_NAME} SHARED perf_gemm.cc) +set_source_files_properties(ag_gemm.cpp PROPERTIES LANGUAGE HIP) +target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_PYTHON_LIBRARY} ${TORCH_LIBRARIES} hip::device Python3::Python) +set_target_properties(${PROJECT_NAME} PROPERTIES PREFIX "") +target_compile_features(${PROJECT_NAME} PRIVATE cxx_std_20) + +# for host compile +target_compile_definitions(${PROJECT_NAME} PRIVATE 
-D__${CMAKE_HIP_ARCHITECTURES}__) +target_compile_options(${PROJECT_NAME} PRIVATE -save-temps) diff --git a/dist-infer/ag-gemm/Makefile b/dist-infer/ag-gemm/Makefile new file mode 100644 index 0000000..22cb558 --- /dev/null +++ b/dist-infer/ag-gemm/Makefile @@ -0,0 +1,39 @@ +.PHONY: all config build + +TARGET ?= gfx942 + +BUILD_TYPE ?= RelWithDebInfo +BUILD_DIR ?= build + +PYTHON_DIR ?= $(shell python -c "import site; print(site.getsitepackages()[0])") + +all: config build submit + +config: + PYTORCH_ROCM_ARCH=$(TARGET) cmake -B $(BUILD_DIR) . \ + -DCMAKE_BUILD_TYPE=$(BUILD_TYPE) \ + -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \ + -DCMAKE_PREFIX_PATH="/opt/rocm;$(PYTHON_DIR)" \ + -DCMAKE_HIP_ARCHITECTURES=$(TARGET) \ + -DGPU_TARGETS=$(TARGET) \ + -DAMDGPU_TARGETS=$(TARGET) \ + -G Ninja + +build: + cmake --build $(BUILD_DIR) -j8 + +test: build + PYTHONPATH=$(PYTHONPATH):$(realpath tools):$(realpath $(BUILD_DIR)) python tools/smoke_test.py + +clean: + rm -r $(BUILD_DIR) + +local: + python submit.py local_test + POPCORN_GPUS=2 POPCORN_FD=2 python eval.py benchmark benchmark.txt + +submit: + python submit.py + +dis: + roc-obj -d build/ag_gemm.so diff --git a/dist-infer/ag-gemm/ag_gemm.cpp b/dist-infer/ag-gemm/ag_gemm.cpp new file mode 100644 index 0000000..70a011c --- /dev/null +++ b/dist-infer/ag-gemm/ag_gemm.cpp @@ -0,0 +1,605 @@ +// remove pytorch restriction +#undef __HIP_NO_HALF_OPERATORS__ +#undef __HIP_NO_HALF_CONVERSIONS__ +#include + +#include +#include +#include +#include + +#include +#include + +#include "perf_gemm.h" +#include +#include + +#include "rocwmma/rocwmma.hpp" +#include "rocwmma/rocwmma_coop.hpp" + +namespace mma = rocwmma; +using f16 = mma::float16_t; +using b16 = mma::bfloat16_t; +using f32 = mma::float32_t; +using i32 = mma::int32_t; +using i64 = mma::int64_t; + +#define USE_DBG 0 +#define USE_ASSERT 0 + +#define DO_PRAGMA_(x) _Pragma(#x) +#define DO_PRAGMA(x) DO_PRAGMA_(x) +#define UNROLL DO_PRAGMA(unroll) +#define UNROLL_N(n) DO_PRAGMA(unroll n) +#define STR(x) #x +#define TO_STR(x) STR(x) + +#define ASSERT(cond) \ + do { \ + if (USE_ASSERT && !(cond)) { \ + __assert_fail(#cond, __FILE__, __LINE__, __PRETTY_FUNCTION__); \ + } \ + } while (0) + +#define DBG(fmt, ...) \ + do { \ + if (USE_DBG && threadIdx.x % 64 == 0) { \ + printf(fmt "\n", ##__VA_ARGS__); \ + } \ + } while (0) + +#define HOST_DBG(fmt, ...) 
\ + do { \ + fprintf(stderr, fmt "\n", ##__VA_ARGS__); \ + } while (0) + +#define HIP_CHECK(call) \ + do { \ + hipError_t err = (call); \ + if (err != hipSuccess) { \ + fprintf( \ + stderr, "HIP error: %s (%d)\n at %s:%d\n", \ + hipGetErrorString(err), err, __FILE__, __LINE__ \ + ); \ + } \ + } while (0) + +constexpr i32 WORLD_SIZE = 2; +constexpr i32 WARP_SIZE = 64; +constexpr i32 BLOCK_SIZE = 512; +constexpr i32 NUM_SMS = 304; + +template constexpr T ceil_div(T a, T b) { return (a + b - 1) / b; } + +template struct vec_t { + using type = __attribute__((__vector_size__(N))) T; + static_assert(N % sizeof(T) == 0); + constexpr static i32 nelem = N / sizeof(T); + constexpr static i32 nelem_per_warp = WARP_SIZE * nelem; + + static __device__ void copy(T *dst, const T *src) { + auto val = + __builtin_nontemporal_load(reinterpret_cast(src)); + __builtin_nontemporal_store(val, reinterpret_cast(dst)); + } + + template + __device__ static inline void warp_copy(T *dst, const T *src) { + static_assert(N_ELEM % nelem_per_warp == 0); + const auto lane_id = threadIdx.x % WARP_SIZE; + UNROLL + for (int i = 0; i < N_ELEM / nelem_per_warp; i++) { + auto src_ptr = reinterpret_cast( + src + i * nelem_per_warp + lane_id * nelem + ); + auto val = __builtin_nontemporal_load(src_ptr); + auto dst_ptr = reinterpret_cast( + dst + i * nelem_per_warp + lane_id * nelem + ); + __builtin_nontemporal_store(val, dst_ptr); + } + } + + template using accum_type = type[N_ELEM / nelem_per_warp]; + + template + __device__ static inline void + warp_accum(accum_type &acc, const T *src, f32 weight) { + static_assert(N_ELEM % nelem_per_warp == 0); + const auto lane_id = threadIdx.x % WARP_SIZE; + UNROLL + for (int i = 0; i < N_ELEM / nelem_per_warp; i++) { + auto ptr = reinterpret_cast( + src + i * nelem_per_warp + lane_id * nelem + ); + auto val = __builtin_nontemporal_load(ptr); + UNROLL + for (int j = 0; j < nelem; j++) { + acc[i][j] += val[j] * weight; + } + } + } + + template + __device__ static inline void + warp_accum_store(T *dst, accum_type &acc) { + static_assert(N_ELEM % nelem_per_warp == 0); + const auto lane_id = threadIdx.x % WARP_SIZE; + UNROLL + for (int i = 0; i < N_ELEM / nelem_per_warp; i++) { + auto ptr = reinterpret_cast( + dst + i * nelem_per_warp + lane_id * nelem + ); + __builtin_nontemporal_store(acc[i], ptr); + } + } + + template + __device__ static inline void + warp_load(accum_type &acc, const T *src) { + static_assert(N_ELEM % nelem_per_warp == 0); + const auto lane_id = threadIdx.x % WARP_SIZE; + UNROLL + for (int i = 0; i < N_ELEM / nelem_per_warp; i++) { + auto src_ptr = reinterpret_cast( + src + i * nelem_per_warp + lane_id * nelem + ); + if constexpr (NT) { + acc[i] = __builtin_nontemporal_load(src_ptr); + } else { + acc[i] = *src_ptr; + } + } + } + + template + __device__ static inline void + warp_store(T *dst, const accum_type &acc) { + static_assert(N_ELEM % nelem_per_warp == 0); + const auto lane_id = threadIdx.x % WARP_SIZE; + UNROLL + for (int i = 0; i < N_ELEM / nelem_per_warp; i++) { + auto dst_ptr = reinterpret_cast( + dst + i * nelem_per_warp + lane_id * nelem + ); + if constexpr (NT) { + __builtin_nontemporal_store(acc[i], dst_ptr); + } else { + *dst_ptr = acc; + } + } + } +}; + +constexpr i32 MAX_M = 8192; +constexpr i32 MAX_N = 29568; +constexpr i32 MAX_K = 8192; +constexpr i32 MAX_M_LOCAL = MAX_M / WORLD_SIZE; + +constexpr i32 CHUNK_K = 512; +constexpr i32 CHUNK_M = 256; +constexpr i32 MAX_NUM_CHUNKS_K = ceil_div(MAX_K, CHUNK_K); +constexpr i32 MAX_NUM_CHUNKS_M = ceil_div(MAX_M, 
CHUNK_M); + +struct workspace_t { + i32 grid_barrier; +}; + +// global variables +struct ipc_mem_t { + // FIXME: reset signals + signal_t nvl_recv_signals; + i32 nvl_barrier[WORLD_SIZE]; +}; + +struct ipc_cache_t { + b16 nvl_recv_x[MAX_M * MAX_K]; +}; + +struct global_t { + // config + i32 rank; + i32 m, n, k; + + // buffers + ipc_mem_t *ipc_mems[WORLD_SIZE] = {}; + ipc_cache_t *ipc_caches[WORLD_SIZE] = {}; + workspace_t *workspace; +}; + +template __device__ inline void st_relaxed_sys(T *ptr, T val) { + __hip_atomic_store(ptr, val, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM); +} + +template __device__ inline T ld_relaxed_sys(T *ptr) { + return __hip_atomic_load(ptr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM); +} + +template __device__ inline void st_release_global(T *ptr, T val) { + __hip_atomic_store(ptr, val, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_AGENT); +} + +template __device__ inline T ld_acquire_global(T *ptr) { + return __hip_atomic_load(ptr, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_AGENT); +} + +__device__ __forceinline__ void syncwarp() { + __builtin_amdgcn_fence(__ATOMIC_RELEASE, "wavefront"); + __builtin_amdgcn_wave_barrier(); + __builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "wavefront"); +} + +template constexpr T const_min(T a, T b) { return a > b ? b : a; } + +template +__device__ inline void warp_copy_tile(b16 *(&dst)[NUM_DST], const b16 *src) { + static_assert(TILE_K % WARP_SIZE == 0); + constexpr i32 VEC_SIZE = TILE_K % 512 == 0 ? 16 : 2; + using cp_t = typename ck::vector_type::type; + cp_t regs[2]; + + const auto src_rsrc = ck::make_wave_buffer_resource_with_default_range(src); + ck::int32x4_t dst_rsrc[NUM_DST]; + UNROLL + for (int i = 0; i < NUM_DST; i++) { + dst_rsrc[i] = ck::make_wave_buffer_resource_with_default_range(dst[i]); + } + const auto lane_id = threadIdx.x % WARP_SIZE; + + auto load_row = [&](int reg_idx, int row_idx) { + const i32 soffset = row_idx * K * sizeof(b16); + const i32 voffset = lane_id * VEC_SIZE; + regs[reg_idx] = ck::amd_buffer_load_impl_raw< + VEC_SIZE, ck::AmdBufferCoherenceEnum::SYSTEM_NT0>( + src_rsrc, voffset, soffset + ); + }; + auto store_row = [&](int reg_idx, int row_idx) { + UNROLL + for (int i = 0; i < NUM_DST; i++) { + const i32 soffset = row_idx * K * sizeof(b16); + const i32 voffset = lane_id * VEC_SIZE; + ck::amd_buffer_store_impl_raw< + VEC_SIZE, ck::AmdBufferCoherenceEnum::SYSTEM_NT0>( + regs[reg_idx], dst_rsrc[i], voffset, soffset + ); + } + }; + + auto copy_loop_body = [&](int row_idx) { + load_row(0, row_idx + 2); + // sync, 1 ld, 8 st in flight + store_row(1, row_idx + 1); + load_row(1, row_idx + 3); + // sync, 1 ld, 8 st in flight + store_row(0, row_idx + 2); + }; + + constexpr i32 NUM_STAGES = 2; + constexpr i32 UNROLL_FACTOR = const_min(TILE_M / NUM_STAGES, 8); + constexpr i32 INNER_M = UNROLL_FACTOR * NUM_STAGES; + static_assert(TILE_M % NUM_STAGES == 0); + static_assert(TILE_M >= INNER_M); + + load_row(0, 0); + load_row(1, 1); + store_row(0, 0); + for (i32 i = 0; i < TILE_M - INNER_M; i += INNER_M) { + asm(";main loop begin"); + UNROLL + for (i32 j = 0; j < INNER_M; j += NUM_STAGES) { + copy_loop_body(i + j); + } + asm(";main loop end"); + } + UNROLL + for (i32 i = TILE_M - INNER_M; i < TILE_M - NUM_STAGES; i += NUM_STAGES) { + copy_loop_body(i); + } + store_row(1, TILE_M - 1); +} +struct send_args_t { + b16 *x; +}; + +// push save an extra copy compared with pull +template +__global__ void send_kernel(send_args_t args, global_t global) { + const auto num_sms = gridDim.x; + const auto num_warps = blockDim.x / WARP_SIZE; + 
const i32 num_global_warps = num_sms * num_warps; + + const auto sm_id = blockIdx.x; + // put soffset to sgpr + const auto warp_id = + __builtin_amdgcn_readfirstlane(threadIdx.x / WARP_SIZE); + const auto global_warp_id = sm_id * num_warps + warp_id; + const auto lane_id = threadIdx.x % WARP_SIZE; + + const auto rank = global.rank; + + static_assert(M % WORLD_SIZE == 0); + constexpr auto M_LOCAL = M / WORLD_SIZE; + + constexpr auto NUM_CHUNKS_M = ceil_div(M_LOCAL, CHUNK_M); + constexpr auto NUM_CHUNKS_K = ceil_div(K, CHUNK_K); + constexpr auto TAIL_CHUNK_M = M_LOCAL - (NUM_CHUNKS_M - 1) * CHUNK_M; + constexpr auto TAIL_CHUNK_K = K - (NUM_CHUNKS_K - 1) * CHUNK_K; + + for (int i = global_warp_id; i < NUM_CHUNKS_M * NUM_CHUNKS_K * WORLD_SIZE; + i += num_global_warps) { + const auto dst_rank = i % WORLD_SIZE; + const auto chunk_id = i / WORLD_SIZE; + // TODO: maybe k first + const auto chunk_k = chunk_id / NUM_CHUNKS_M; + const auto chunk_m = chunk_id % NUM_CHUNKS_M; + + const auto m_begin = chunk_m * CHUNK_M; + const auto k_begin = chunk_k * CHUNK_K; + const auto offset = m_begin * K + k_begin; + + const auto chunk_src = args.x + offset; + b16 *chunk_dst[1] = { + global.ipc_caches[dst_rank]->nvl_recv_x + rank * M_LOCAL * K + + offset + }; + + if (TAIL_CHUNK_M != CHUNK_M && chunk_m == NUM_CHUNKS_M - 1) { + if (TAIL_CHUNK_K != CHUNK_K && chunk_k == NUM_CHUNKS_K - 1) { + warp_copy_tile( + chunk_dst, chunk_src + ); + } else { + warp_copy_tile(chunk_dst, chunk_src); + } + } else { + if (TAIL_CHUNK_K != CHUNK_K && chunk_k == NUM_CHUNKS_K - 1) { + warp_copy_tile(chunk_dst, chunk_src); + } else { + warp_copy_tile(chunk_dst, chunk_src); + } + } + + if constexpr (!SYNC) { + // __builtin_amdgcn_fence(__ATOMIC_RELEASE, ""); + st_relaxed_sys( + &global.ipc_mems[dst_rank] + ->nvl_recv_signals[rank * NUM_CHUNKS_M + chunk_m][chunk_k], + 1 + ); + } + } + + if constexpr (!SYNC) { + return; + } + + __syncthreads(); + auto &grid_barrier = global.workspace->grid_barrier; + if (warp_id == 0 && lane_id == 0) { + __atomic_fetch_add(&grid_barrier, 1, __ATOMIC_RELAXED); + } + static_assert(WORLD_SIZE <= WARP_SIZE); + if (global_warp_id == 0 && lane_id < WORLD_SIZE) { + while (ld_acquire_global(&grid_barrier) != num_sms) + ; + st_relaxed_sys(&global.ipc_mems[lane_id]->nvl_barrier[rank], 1); + st_release_global(&grid_barrier, 0); + while (!ld_relaxed_sys(&global.ipc_mems[rank]->nvl_barrier[lane_id])) + ; + __builtin_amdgcn_wave_barrier(); + // safe to reset here because there would be a barrier after + // custom kernel + st_relaxed_sys(&global.ipc_mems[rank]->nvl_barrier[lane_id], 0); + } +} + +constexpr i64 pack_mnk(i32 m, i32 n, i32 k) { + return (i64(m) << 32) | (i64(n) << 16) | i64(k); +} + +// clang-format off +#define SWITCH_MNK(m, n, k, MACRO, ...) \ + switch (pack_mnk(m, n, k)) { \ + /*case pack_mnk(64, 2304, 7168): MACRO(64, 2304, 7168, ##__VA_ARGS__); break; \ + case pack_mnk(512, 1536, 4096): MACRO(512, 1536, 4096, ##__VA_ARGS__); break; \ + case pack_mnk(2048, 360, 2880): MACRO(2048, 360, 2880, ##__VA_ARGS__); break; \ + case pack_mnk(4096, 512, 4096): MACRO(4096, 512, 4096, ##__VA_ARGS__); break; \ + case pack_mnk(8192, 1792, 4096): MACRO(8192, 1792, 4096, ##__VA_ARGS__); break; \ + case pack_mnk(8192, 3696, 8192): MACRO(8192, 3696, 8192, ##__VA_ARGS__); break; \ + */case pack_mnk(2048, 3696, 8192): MACRO(2048, 3696, 8192, ##__VA_ARGS__); break; \ + default: ASSERT(false); \ + } + +// bm, bn, bk, wm, wn +#define SWITCH_GEMM_MNK(m, n, k, MACRO, ...) 
\ + switch (pack_mnk(m, n, k)) { \ + /*case pack_mnk(64, 2304, 7168): MACRO(64, 2304, 7168, 32, 64, 64, 2, 2, ##__VA_ARGS__); break; \ + case pack_mnk(512, 1536, 4096): MACRO(512, 1536, 4096, 32, 64, 64, 2, 2, ##__VA_ARGS__); break; \ + case pack_mnk(2048, 360, 2880): MACRO(2048, 360, 2880, 128, 128, 64, 2, 2, ##__VA_ARGS__); break; \ + case pack_mnk(4096, 512, 4096): MACRO(4096, 512, 4096, 256, 128, 64, 2, 2, ##__VA_ARGS__); break; \ + case pack_mnk(8192, 1792, 4096): MACRO(8192, 1792, 4096, 256, 224, 64, 2, 2, ##__VA_ARGS__); break; \ + case pack_mnk(8192, 3696, 8192): MACRO(8192, 3696, 8192, 256, 224, 64, 2, 2, ##__VA_ARGS__); break; \ + */case pack_mnk(2048, 3696, 8192): MACRO(2048, 3696, 8192, 256, 224, 64, 2, 2, ##__VA_ARGS__); break; \ + default: ASSERT(false); \ + } +// clang-format on + +class AgGemm { + private: + global_t global{}; + + public: + AgGemm(int rank, int m, int n, int k) { + global.rank = rank; + global.m = m; + global.n = n; + global.k = k; + } + + ~AgGemm() { + for (auto i = 0; i < WORLD_SIZE; i++) { + auto ipc_mem = global.ipc_mems[i]; + if (ipc_mem && i != global.rank) { + HIP_CHECK(hipIpcCloseMemHandle(ipc_mem)); + } + auto ipc_cache = global.ipc_caches[i]; + if (ipc_cache && i != global.rank) { + HIP_CHECK(hipIpcCloseMemHandle(ipc_cache)); + } + } + auto local_mem = global.ipc_mems[global.rank]; + if (local_mem) { + HIP_CHECK(hipFree(local_mem)); + } + auto local_cache = global.ipc_caches[global.rank]; + if (local_cache) { + HIP_CHECK(hipFree(local_cache)); + } + if (global.workspace) { + HIP_CHECK(hipFree(global.workspace)); + } + } + + auto get_ipc_handle() -> pybind11::bytearray { + void *ws; + HIP_CHECK(hipMalloc(&ws, sizeof(workspace_t))); + HIP_CHECK(hipMemset(ws, 0, sizeof(workspace_t))); + global.workspace = reinterpret_cast(ws); + + void *ptr; + HIP_CHECK(hipExtMallocWithFlags( + &ptr, sizeof(ipc_mem_t), hipDeviceMallocUncached + )); + HIP_CHECK(hipMemset(ptr, 0, sizeof(ipc_mem_t))); + global.ipc_mems[global.rank] = reinterpret_cast(ptr); + + HIP_CHECK(hipMalloc(&ptr, sizeof(ipc_cache_t))); + // HIP_CHECK(hipExtMallocWithFlags( + // &ptr, sizeof(ipc_cache_t), hipDeviceMallocUncached + // )); + HIP_CHECK(hipMemset(ptr, 0, sizeof(ipc_cache_t))); + global.ipc_caches[global.rank] = reinterpret_cast(ptr); + + std::vector handles(2); + HIP_CHECK(hipIpcGetMemHandle(&handles[0], global.ipc_mems[global.rank]) + ); + HIP_CHECK( + hipIpcGetMemHandle(&handles[1], global.ipc_caches[global.rank]) + ); + return { + reinterpret_cast(handles.data()), HIP_IPC_HANDLE_SIZE * 2 + }; + } + + auto init(const std::vector &ipc_handles) { + for (int i = 0; i < WORLD_SIZE; i++) { + if (i == global.rank) { + continue; + } + hipIpcMemHandle_t handle; + auto handle_buf = std::string(ipc_handles[i]); + ASSERT(handle_buf.size() == HIP_IPC_HANDLE_SIZE * 2); + auto handles = + reinterpret_cast(handle_buf.data()); + void *ptr; + HIP_CHECK(hipIpcOpenMemHandle( + &ptr, handles[0], hipIpcMemLazyEnablePeerAccess + )); + global.ipc_mems[i] = reinterpret_cast(ptr); + HIP_CHECK(hipIpcOpenMemHandle( + &ptr, handles[1], hipIpcMemLazyEnablePeerAccess + )); + global.ipc_caches[i] = reinterpret_cast(ptr); + } + } + + void send(torch::Tensor &x, bool sync) { + send_args_t args{ + .x = reinterpret_cast(x.contiguous().data_ptr()) + }; + auto stream = at::cuda::getCurrentHIPStream().stream(); + dim3 block(256, 1, 1); + // 1 SM ~= 8 GB/s, 16 SM ~= 40 GB/s + dim3 grid(16, 1, 1); + // clang-format off + #define LAUNCH_SEND(m, n, k, sync) send_kernel<<>>(args, global) + if (sync) { + SWITCH_MNK(global.m, 
global.n, global.k, LAUNCH_SEND, true) + } else { + SWITCH_MNK(global.m, global.n, global.k, LAUNCH_SEND, false) + } + // clang-format on + } + + auto get_x_full() { + auto x_full = torch::from_blob( + global.ipc_caches[global.rank]->nvl_recv_x, {global.m, global.k}, + torch::TensorOptions().dtype(torch::kBFloat16).device(torch::kCUDA) + ); + return x_full; + } + + auto get_signal() { + auto signal = torch::from_blob( + global.ipc_mems[global.rank]->nvl_recv_signals, {128, 128}, + torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA) + ); + return signal; + } + + void reset() { + auto stream = at::cuda::getCurrentHIPStream().stream(); + auto &signal = global.ipc_mems[global.rank]->nvl_recv_signals; + HIP_CHECK(hipMemsetAsync(&signal, 0, sizeof(signal), stream)); + } + + auto + perf_gemm(torch::Tensor &x, torch::Tensor &w, torch::Tensor &b, bool sync) { + auto m = x.size(0); + auto n = w.size(0); + auto k = w.size(1); + auto out = torch::empty({m, n}, x.options()); + + auto x_ptr = reinterpret_cast(x.const_data_ptr()); + auto w_ptr = reinterpret_cast(w.const_data_ptr()); + auto b_ptr = reinterpret_cast(b.const_data_ptr()); + auto o_ptr = reinterpret_cast(out.data_ptr()); + + constexpr i32 GEMM_THREADS = 256; + constexpr i32 GEMM_SMS = NUM_SMS; + + dim3 grid(GEMM_SMS); + dim3 block(GEMM_THREADS); + + auto stream = at::cuda::getCurrentHIPStream().stream(); + auto signal = &global.ipc_mems[global.rank]->nvl_recv_signals; + if (sync) { + signal = nullptr; + } + +// TODO: tune num_gemm_sms +// clang-format off + #define LAUNCH_PERF(m, n, k, bm, bn, bk, wm, wn) gemm_kernel<<>>(x_ptr, w_ptr, b_ptr, o_ptr, signal) + SWITCH_GEMM_MNK(m, n, k, LAUNCH_PERF) + // clang-format on + return out; + } +}; + +PYBIND11_MODULE(ag_gemm, m) { + py::class_(m, "AgGemm") + .def(py::init()) + .def("get_ipc_handle", &AgGemm::get_ipc_handle) + .def("init", &AgGemm::init) + .def("send", &AgGemm::send, py::arg("x"), py::arg("sync") = true) + .def("get_x_full", &AgGemm::get_x_full) + .def("reset", &AgGemm::reset) + .def( + "perf_gemm", &AgGemm::perf_gemm, py::arg("x"), py::arg("w"), + py::arg("b"), py::arg("sync") = true + ) + .def("get_signal", &AgGemm::get_signal); + m.def("ck_version", []() { return TO_STR(CK_COMMIT_ID); }); +} diff --git a/dist-infer/ag-gemm/benchmark.txt b/dist-infer/ag-gemm/benchmark.txt new file mode 100644 index 0000000..1c5442d --- /dev/null +++ b/dist-infer/ag-gemm/benchmark.txt @@ -0,0 +1 @@ +world_size: 2; m: 2048; n: 7392; k: 8192; has_bias: False; seed: 4406 \ No newline at end of file diff --git a/dist-infer/ag-gemm/ck_gemm.h b/dist-infer/ag-gemm/ck_gemm.h new file mode 100644 index 0000000..845659b --- /dev/null +++ b/dist-infer/ag-gemm/ck_gemm.h @@ -0,0 +1,352 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
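+//
+// This header bundles two standalone GEMM paths: ck_gemm() drives a prebuilt
+// DeviceGemm_Xdl_CShuffleV3 instance, while ck_wrapper_gemm() assembles the
+// same GEMM from ck::wrapper primitives (marked FIXME below because it still
+// produces wrong results). Both appear to be kept as reference baselines for
+// the hand-tuned kernels in perf_gemm.h.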
+ +#include +#include +#include +#include +#include + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" +#include + +#include "ck/utility/common_header.hpp" +// __gfx9__ defined in the above header via ck.hpp +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/wrapper/layout.hpp" +#include "ck/wrapper/tensor.hpp" +#include "ck/wrapper/operations/copy.hpp" +#include "ck/wrapper/operations/gemm.hpp" +#include "ck/wrapper/utils/kernel_utils.hpp" +#include "ck/host_utility/device_prop.hpp" + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +template +__device__ auto ApplyPadding(const Layout& layout, const PaddingDims& padding_dims) +{ + if constexpr(DoPad) + { + return ck::wrapper::pad(layout, padding_dims); + } + else + { + return layout; + } +} + +template +__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ DeviceGemm(const void* p_a, + const void* p_b, + void* p_c, + const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const BlockShape tile_shape, + const ThreadLayout thread_layout) +{ + constexpr auto MPerBlock = ck::wrapper::size<0>(tile_shape); + constexpr auto NPerBlock = ck::wrapper::size<1>(tile_shape); + constexpr auto KPerBlock = ck::wrapper::size<2>(tile_shape); + constexpr auto K1 = GemmTraits::K1; + constexpr auto K0PerBlock = KPerBlock / K1; + const auto K0 = ck::math::integer_divide_ceil(K, K1); + + const auto tile_shape_k0_m_n_k1 = ck::make_tuple(K0PerBlock, MPerBlock, NPerBlock, K1); + // Create layouts for global memory + const auto a_global_layout = + ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1)); + const auto b_global_layout = + ck::wrapper::make_layout(ck::make_tuple(N, K), ck::make_tuple(K, 1)); + const auto c_global_layout = + ck::wrapper::make_layout(ck::make_tuple(M, N), ck::make_tuple(N, 1)); + // Apply padding + auto a_padded_global_layout = + ApplyPadding(a_global_layout, ck::make_tuple(MPerBlock, KPerBlock)); + auto b_padded_global_layout = + ApplyPadding(b_global_layout, ck::make_tuple(NPerBlock, KPerBlock)); + auto c_padded_global_layout = + ApplyPadding(c_global_layout, ck::make_tuple(MPerBlock, NPerBlock)); + // Reshape from M,K to K0,M,K1 + const auto reshaped_dims_idxs = + ck::make_tuple(ck::Number<1>{}, ck::make_tuple(ck::Number<0>{}, ck::Number<2>{})); + auto a_padded_unmerged_global_layout = + ck::wrapper::unmerge<1>(a_padded_global_layout, ck::make_tuple(K0, K1), reshaped_dims_idxs); + auto b_padded_unmerged_global_layout = + ck::wrapper::unmerge<1>(b_padded_global_layout, ck::make_tuple(K0, K1), reshaped_dims_idxs); + // Create tensors for global memory + auto a_global_tensor = ck::wrapper::make_tensor( + static_cast(p_a), a_padded_unmerged_global_layout); + auto b_global_tensor = ck::wrapper::make_tensor( + static_cast(p_b), b_padded_unmerged_global_layout); + auto c_global_tensor = ck::wrapper::make_tensor( + static_cast(p_c), c_padded_global_layout); + // Create layouts and tensors for lds memory. 
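+    // The row strides below are (MPerBlock + 1) * K1 and (NPerBlock + 1) * K1
+    // rather than MPerBlock * K1 / NPerBlock * K1, i.e. one extra K1-wide
+    // element of padding per K0 row; this padding appears to be there to
+    // stagger rows across LDS banks and avoid bank conflicts on the
+    // ds_read/ds_write accesses.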
+ constexpr auto a_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(K0PerBlock, MPerBlock, K1), + ck::make_tuple((MPerBlock + ck::Number<1>{}) * K1, K1, ck::Number<1>{})); + constexpr auto b_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(K0PerBlock, NPerBlock, K1), + ck::make_tuple((NPerBlock + ck::Number<1>{}) * K1, K1, ck::Number<1>{})); + + __shared__ DataType lds_a[ck::wrapper::size(a_tile_layout) + K0PerBlock]; + __shared__ DataType lds_b[ck::wrapper::size(b_tile_layout) + K0PerBlock]; + + auto a_lds_tensor = ck::wrapper::make_tensor( + static_cast(lds_a), a_tile_layout); + auto b_lds_tensor = ck::wrapper::make_tensor( + static_cast(lds_b), b_tile_layout); + + const auto block_idxs = ck::make_tuple(ck::wrapper::slice(), + static_cast(blockIdx.x), + static_cast(blockIdx.y), + ck::wrapper::slice()); + using DimAccessOrder = ck::Tuple, ck::Number<0>, ck::Number<2>>; + constexpr ck::index_t vector_dim = 2; + + // Create tile and partition for C global memory. Use specific gemm + // functions to get appropriate layouts. + auto c_global_local_tile = + ck::wrapper::make_local_tile(c_global_tensor, + tile_shape_k0_m_n_k1, + block_idxs, + make_tuple(ck::wrapper::slice(K0PerBlock), + ck::Number<1>{}, + ck::Number<1>{}, + ck::wrapper::slice(K1))); + auto c_global_local_partition = + ck::wrapper::make_blockwise_gemm_xdl_c_local_partition(c_global_local_tile); + // Define and clear c vgpr register + auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr(); + ck::wrapper::clear(c_vgpr_reg); + // Local partitions for lds memory + auto a_lds_tensor_local_partition = + ck::wrapper::make_local_partition(a_lds_tensor, thread_layout, threadIdx.x); + auto b_lds_tensor_local_partition = + ck::wrapper::make_local_partition(b_lds_tensor, thread_layout, threadIdx.x); + // Lamda to slice tensor, then create local tile and partition + auto make_global_partition = [&](auto tensor, auto projection, ck::index_t i) { + const auto k_slice = + ck::make_tuple(ck::wrapper::slice(i * K0PerBlock, (i + 1) * K0PerBlock), + ck::wrapper::slice(), + ck::wrapper::slice()); + auto local_tile = ck::wrapper::make_local_tile( + tensor(k_slice), tile_shape_k0_m_n_k1, block_idxs, projection); + return ck::wrapper::make_local_partition(local_tile, thread_layout, threadIdx.x); + }; + + auto a_global_local_partition = make_global_partition( + a_global_tensor, + make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}), + 0); + auto b_global_local_partition = make_global_partition( + b_global_tensor, + make_tuple(ck::Number<1>{}, ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}), + 0); + + // (row-major vgpr layout) + auto a_vgpr_tensor = + ck::wrapper::make_register_tensor( + ck::wrapper::make_layout( + shape(a_global_local_partition), + ck::make_tuple(ck::wrapper::size<1>(a_global_local_partition) * + ck::wrapper::size<2>(a_global_local_partition), + ck::wrapper::size<2>(a_global_local_partition), + ck::Number<1>{}))); + auto b_vgpr_tensor = + ck::wrapper::make_register_tensor( + ck::wrapper::make_layout( + shape(b_global_local_partition), + ck::make_tuple(ck::wrapper::size<1>(a_global_local_partition) * + ck::wrapper::size<2>(a_global_local_partition), + ck::wrapper::size<2>(a_global_local_partition), + ck::Number<1>{}))); + // Copy first values to lds + ck::wrapper::copy(a_global_local_partition, + a_vgpr_tensor); + ck::wrapper::copy(b_global_local_partition, + b_vgpr_tensor); + ck::wrapper::copy(a_vgpr_tensor, + a_lds_tensor_local_partition); + 
ck::wrapper::copy(b_vgpr_tensor, + b_lds_tensor_local_partition); + // Pipeline loop + const ck::index_t num_loop = + __builtin_amdgcn_readfirstlane(ck::math::integer_divide_ceil(K, KPerBlock)); + // Skip if only tile should be processed + if(num_loop > 1) + { + ck::index_t i = 0; + do + { + auto a_global_local_partition_i = make_global_partition( + a_global_tensor, + make_tuple( + ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}), + i + 1); + auto b_global_local_partition_i = make_global_partition( + b_global_tensor, + make_tuple( + ck::Number<1>{}, ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}), + i + 1); + // Copy data to A vgpr. + ck::wrapper::copy( + a_global_local_partition_i, a_vgpr_tensor); + // Synchronize. + ck::block_sync_lds(); + // Copy data to B vgpr. + ck::wrapper::copy( + b_global_local_partition_i, b_vgpr_tensor); + // Perform gemm. + ck::wrapper::blockwise_gemm_xdl( + a_lds_tensor, b_lds_tensor, c_vgpr_reg); + // Synchronize + ck::block_sync_lds(); + // Copy data to A and B lds tiles. + ck::wrapper::copy( + a_vgpr_tensor, a_lds_tensor_local_partition); + ck::wrapper::copy( + b_vgpr_tensor, b_lds_tensor_local_partition); + + ++i; + } while(i < (num_loop - 1)); + } + // Handle tail. + ck::block_sync_lds(); + ck::wrapper::blockwise_gemm_xdl( + a_lds_tensor, b_lds_tensor, c_vgpr_reg); + // Store data from C vgpr to C global memory. + ck::wrapper::copy(c_vgpr_reg, c_global_local_partition); +} + +template +void PerformGemm(const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const BlockShape& tile_shape, + const ThreadLayout& thread_layout, + DataType *A, DataType *B, DataType *C, + hipStream_t stream) +{ + const ck::index_t grid_size_x = + ck::math::integer_divide_ceil(M, ck::wrapper::size<0>(tile_shape)); + const ck::index_t grid_size_y = + ck::math::integer_divide_ceil(N, ck::wrapper::size<1>(tile_shape)); + + const auto kernel = + DeviceGemm; + auto grid = dim3(grid_size_x, grid_size_y, 1); + auto block = dim3(ck::wrapper::size(thread_layout)); + kernel<<>>(A, B, C, M, N, K, tile_shape, thread_layout); +} +#endif + +template using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ADataType = ck::bhalf_t; +using BDataType = ck::bhalf_t; +using CDataType = ck::bhalf_t; +using AccDataType = float; +using CShuffleDataType = ck::bhalf_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = + ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3 +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| 
ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +// ######| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, 8, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3 >; +// clang-format on + +inline void ck_gemm(torch::Tensor &x, torch::Tensor &w, torch::Tensor &out) { + auto stream = at::cuda::getCurrentHIPStream().stream(); + auto m = x.size(0), n = w.size(0), k = w.size(1); + + auto device_gemm = DeviceGemmInstance(); + auto invoker = device_gemm.MakeInvoker(); + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + // cast to bhalf_t to make pytorch happy + auto args = device_gemm.MakeArgument( + reinterpret_cast(x.contiguous().data_ptr()), + reinterpret_cast(w.contiguous().data_ptr()), + reinterpret_cast(out.contiguous().data_ptr()), m, n, k, + k, k, n, 1, a_element_op, b_element_op, c_element_op + ); + invoker.Run(args, StreamConfig{stream}); +}; + +// FIXME: wrong result +inline void ck_wrapper_gemm(torch::Tensor &x, torch::Tensor &w, torch::Tensor &out) { + auto stream = at::cuda::getCurrentHIPStream().stream(); + auto m = x.size(0), n = w.size(0), k = w.size(1); + auto a = reinterpret_cast(x.contiguous().data_ptr()); + auto b = reinterpret_cast(w.contiguous().data_ptr()); + auto c = reinterpret_cast(out.contiguous().data_ptr()); + + using DataType = ck::bhalf_t; + const auto thread_layout = ck::wrapper::make_layout( + ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}), + ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}, ck::Number<1>{}) + ); + const auto tile_shape = + ck::make_tuple(ck::Number<256>{}, ck::Number<128>{}, ck::Number<32>{}); + + PerformGemm< + DataType, ck::wrapper::BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_8K1, + 8, false>(m, n, k, tile_shape, thread_layout, a, b, c, stream); +} \ No newline at end of file diff --git a/dist-infer/ag-gemm/eval.py b/dist-infer/ag-gemm/eval.py new file mode 100644 index 0000000..c0fac8f --- /dev/null +++ b/dist-infer/ag-gemm/eval.py @@ -0,0 +1,578 @@ +import base64 +import copy +import dataclasses +import multiprocessing +import re +import time +import os +import sys +import math +from pathlib import Path +from typing import Any, Optional + +import torch.cuda + +from utils import set_seed, clear_l2_cache + +try: + from task import TestSpec +except ImportError: + TestSpec = dict + +from reference import check_implementation, generate_input + + +class PopcornOutput: + def __init__(self, fd: int): + self.file = os.fdopen(fd, 'w') + # os.set_inheritable(fd, False) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.file.close() + + def print(self, *args, **kwargs): + print(*args, **kwargs, file=self.file, flush=True) + + def log(self, key, value): + self.print(f"{key}: {value}") + + 
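+def _logger_selftest() -> None:
+    # Minimal sketch of the logging protocol (hypothetical helper, never called
+    # by the harness): each log() call becomes one flushed "key: value" line on
+    # the file descriptor that the runner hands this process via POPCORN_FD.
+    r, w = os.pipe()
+    with PopcornOutput(w) as logger:
+        logger.log("check", "pass")
+    print(os.read(r, 64).decode(), end="")  # -> "check: pass"
+    os.close(r)
+
+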
+@dataclasses.dataclass +class TestCase: + args: dict + spec: str + + +def _combine(a: int, b: int) -> int: + # combine two integers into one: + # we need this to generate a secret seed based on the test-level seed and + # the global secret seed. + # the test-level seeds are public knowledge, and typically relatively small numbers, + # so we need to make sure they don't provide any useful info for the full seed. + # This Cantor construction ensures that if the secret seed is a large number, + # then so is the overall seed. + return int(a + (a + b) * (a + b + 1) // 2) + + +def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]: + try: + content = Path(file_name).read_text() + except Exception as E: + print(f"Could not open test file`{file_name}`: {E}", file=sys.stderr) + exit(113) + + tests = [] + lines = content.splitlines() + match = r"\s*([a-zA-Z_]+):\s*([a-zA-Z]+|[+-]?[0-9]+)\s*" + for line in lines: + parts = line.split(";") + case = {} + for part in parts: + matched = re.match(match, part) + if not re.fullmatch(match, part): + print(f"invalid test case: '{line}': '{part}'", file=sys.stderr) + exit(113) + key = matched[1] + val = matched[2] + try: + val = int(val) + except ValueError: + pass + + case[key] = val + tests.append(TestCase(spec=line, args=case)) + + if seed is not None: + for test in tests: + if "seed" in test.args: + test.args["seed"] = _combine(test.args["seed"], seed) + + return tests + + +@dataclasses.dataclass +class Stats: + runs: int + mean: float + std: float + err: float + best: float + worst: float + + +def calculate_stats(durations: list[int]): + """ + Calculate statistical data from a list of durations. + + @param durations: A list of durations in nanoseconds. + @return: A Stats object containing the number of runs, mean, standard deviation, error, best, and worst durations. + """ + runs = len(durations) + total = sum(durations) + best = min(durations) + worst = max(durations) + + avg = total / runs + variance = sum(map(lambda x: (x - avg) ** 2, durations)) + std = math.sqrt(variance / (runs - 1)) + err = std / math.sqrt(runs) + + return Stats(runs=runs, mean=avg, std=std, err=err, best=float(best), + worst=float(worst)) + + +def _clone_data(data, rank: int): + """ + Recursively goes through data and clones all tensors. + """ + if isinstance(data, tuple): + return tuple(_clone_data(x, rank) for x in data) + elif isinstance(data, list): + return [_clone_data(x, rank) for x in data] + elif isinstance(data, dict): + return {k: _clone_data(v, rank) for k, v in data.items()} + elif isinstance(data, torch.Tensor): + device = f"cuda:{rank}" + return data.clone().to(device) + else: + return data + + +def wrap_check_implementation(data, submission_output): + # Old version returned just a single string, new version + # returns (bool, str); this function ensures compatibility with old + # problem definitions. + result = check_implementation(data, submission_output) + if isinstance(result, tuple): + return result + else: + return not bool(result), result + + +def _run_single_test(test: TestCase): + """ + Runs a single test case. Do not call directly + """ + from submission import custom_kernel + data = generate_input(**test.args) + torch.cuda.synchronize() + submission_output = custom_kernel(_clone_data(data, 0)) + torch.cuda.synchronize() + return wrap_check_implementation(data, submission_output) + + +def _run_distributed_test(test: TestCase, rank: int): + """ + Runs a single test case. 
Do not call directly + """ + from submission import custom_kernel + import torch.distributed as dist + world_size = test.args["world_size"] + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "12356" + dist.init_process_group("nccl", init_method="env://", rank=rank, world_size=world_size, device_id=torch.device(f'cuda:{rank}')) + try: + data = generate_input(**test.args, rank=rank) + torch.cuda.synchronize() + submission_output = custom_kernel(_clone_data(data, rank)) + torch.cuda.synchronize() + return wrap_check_implementation(data, submission_output) + finally: + dist.destroy_process_group() + + +def run_multi_gpu_test(pool: multiprocessing.Pool, test: TestCase, world_size: int): + """ + Runs a single test in another process. + """ + rets = [] + # world_size is a mandatory argument for multi-gpu tests + for i in range(world_size): + rets.append( + pool.apply_async( + _run_distributed_test, + args=(test, i), + ) + ) + # 60 seconds should be more than enough, we want tests to be fast + rets = [el.get(60) for el in rets] + + correct = all(ret[0] for ret in rets) + error_messages = str.join("\n", [f"rank {rank} - {ret[1]}" for rank, ret in enumerate(rets) if not ret[0]]) + return correct, error_messages + + +def run_single_test(pool: multiprocessing.Pool, test: TestCase): + """ + Runs a single test in another process. + """ + world_size = test.args.get("world_size", None) + if world_size is None: + return pool.apply(_run_single_test, (test,)) + else: + return run_multi_gpu_test(pool, test, world_size) + + +def run_testing(logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase]): + """ + Executes the actual test case code and checks for correctness. + + @param logger: A PopcornOutput object used for logging test results. + @param tests: A list of TestCase objects representing the test cases to be executed. + @return: An integer representing the exit status: 0 if all tests pass, otherwise 112. + """ + passed = True + logger.log("test-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"test.{idx}.spec", test.spec) + good, message = run_single_test(pool, test) + if not good: + logger.log(f"test.{idx}.status", "fail") + logger.log(f"test.{idx}.error", message) + passed = False + else: + logger.log(f"test.{idx}.status", "pass") + if message: + logger.log(f"test.{idx}.message", message) + + if passed: + logger.log("check", "pass") + return 0 + else: + logger.log("check", "fail") + return 112 + + +def _run_single_benchmark(test: TestCase, recheck: bool, max_repeats: int, max_time_ns: float) -> Stats | Any: + """ + Runs one benchmark. Do not call directly. + """ + from submission import custom_kernel + + durations = [] + # generate input data once + data = generate_input(**test.args) + check_copy = _clone_data(data, 0) + # first, one obligatory correctness check + output = custom_kernel(data) + good, message = wrap_check_implementation(check_copy, output) + if not good: + return message + + # now, do multiple timing runs without further correctness testing + # there is an upper bound of 100 runs, and a lower bound of 3 runs; + # otherwise, we repeat until we either measure at least 10 full seconds, + # or the relative error of the mean is below 1%. 
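+    # (Concretely, the convergence check below stops once the relative error of
+    # the mean drops under 0.1%, i.e. stats.err / stats.mean < 0.001, or the
+    # accumulated kernel time exceeds max_time_ns, or 120 s of wallclock have
+    # elapsed.)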
+ + bm_start_time = time.perf_counter_ns() + for i in range(max_repeats): + if recheck: + # ensure we use a different seed for every benchmark + if "seed" in test.args: + test.args["seed"] += 13 + + data = generate_input(**test.args) + check_copy = _clone_data(data, 0) + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + clear_l2_cache() + + start_event.record() + output = custom_kernel(data) + end_event.record() + torch.cuda.synchronize() + duration = start_event.elapsed_time(end_event) * 1e6 # Convert ms to ns + + if recheck: + good, message = check_implementation(check_copy, output) + if not good: + return message + + del output + durations.append(duration) + + if i > 1: + total_bm_duration = time.perf_counter_ns() - bm_start_time + stats = calculate_stats(durations) + # stop if either + # a) relative error dips below 0.1% + # b) we exceed the total time limit for benchmarking the kernel + # c) we exceed 2 minutes of total wallclock time. + if stats.err / stats.mean < 0.001 or stats.mean * stats.runs > max_time_ns or total_bm_duration > 120e9: + break + + return calculate_stats(durations) + + +def _run_distributed_benchmark(test: TestCase, rank: int, recheck: bool, max_repeats: int, + max_time_ns: float) -> Stats | Any: + """ + Runs one distributed benchmark. Do not call directly. + """ + from submission import custom_kernel + import torch.distributed as dist + + world_size = test.args["world_size"] + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "12356" + dist.init_process_group("nccl", init_method="env://", rank=rank, world_size=world_size, device_id=torch.device(f'cuda:{rank}')) + + try: + durations = [] + # generate input data once + data = generate_input(**test.args, rank=rank) + check_copy = _clone_data(data, rank) + + # first, one obligatory correctness check + output = custom_kernel(_clone_data(data, rank)) + good, message = wrap_check_implementation(check_copy, output) + if not good: + return message + + # now, do multiple timing runs with proper distributed synchronization + bm_start_time = time.perf_counter_ns() + for i in range(max_repeats): + error_message = None + if recheck: + # ensure we use a different seed for every benchmark + if "seed" in test.args: + test.args["seed"] += 13 + + data = generate_input(**test.args, rank=rank) + check_copy = _clone_data(data, rank) + + # Synchronize all ranks before timing + clear_l2_cache() + torch.cuda.synchronize() + dist.barrier() + + # Use distributed timing - only rank 0 records the overall time + if rank == 0: + start_time = time.perf_counter_ns() + + # All ranks execute the kernel + output = custom_kernel(_clone_data(data, rank)) + + # Synchronize all ranks after kernel execution + torch.cuda.synchronize() + dist.barrier() + + if rank == 0: + end_time = time.perf_counter_ns() + duration = end_time - start_time # Already in nanoseconds + durations.append(duration) + + if recheck: + good, message = check_implementation(check_copy, output) + if not good: + error_message = message + + del output + + has_error = torch.tensor(1 if error_message is not None else 0, dtype=torch.int32, device=f'cuda:{rank}') + dist.all_reduce(has_error) + if has_error.item() > 0: + return error_message + + # Only rank 0 checks convergence criteria + if rank == 0 and i > 1: + total_bm_duration = time.perf_counter_ns() - bm_start_time + stats = calculate_stats(durations) + # stop if either + # a) relative error dips below 0.1% + # b) we exceed the total 
time limit for benchmarking the kernel + # c) we exceed 2 minutes of total wallclock time. + should_stop = (stats.err / stats.mean < 0.001 or + stats.mean * stats.runs > max_time_ns or + total_bm_duration > 120e9) + else: + should_stop = False + + # Broadcast stop decision to all ranks + stop_tensor = torch.tensor(should_stop, dtype=torch.bool, device=f'cuda:{rank}') + dist.broadcast(stop_tensor, 0) + + if stop_tensor.item(): + break + + # Only rank 0 returns meaningful stats + if rank == 0: + return calculate_stats(durations) + else: + # Non-zero ranks return a dummy stats object + return Stats(runs=len(durations), mean=0.0, std=0.0, err=0.0, best=0.0, worst=0.0) + + finally: + dist.destroy_process_group() + + +def run_multi_gpu_benchmark(pool: multiprocessing.Pool, test: TestCase, recheck: bool, max_repeats: int, + max_time_ns: float, world_size: int): + """ + Runs a multi-GPU benchmark across all ranks. + """ + rets = [] + for i in range(world_size): + rets.append( + pool.apply_async( + _run_distributed_benchmark, + args=(test, i, recheck, max_repeats, max_time_ns), + ) + ) + + # 120 seconds for benchmarking + we run a pre-benchmark test and want to leave some slack + rets = [el.get(timeout=180) for el in rets] + + # For multi-GPU benchmarking, only rank 0 has meaningful stats + failed_ranks = [] + rank_0_result = None + + for rank, ret in enumerate(rets): + if isinstance(ret, Stats): + if rank == 0: + rank_0_result = ret + else: + # ret is an error message + failed_ranks.append((rank, ret)) + + if failed_ranks: + error_messages = str.join("\n", [f"rank {rank} - {msg}" for rank, msg in failed_ranks]) + return error_messages + else: + return rank_0_result if rank_0_result else "No stats returned from rank 0" + + +def run_single_benchmark(pool: multiprocessing.Pool, test: TestCase, recheck: bool, max_repeats: int, + max_time_ns: float): + """ + For a particular test case, check correctness (if applicable) and grab runtime results. + + @param pool: Process on which the benchmark will be launched. + @param test: TestCase object. + @param recheck: Flag for whether to explicitly check functional correctness. + @param max_repeats: Number of trials to repeat. + @param max_time_ns: Timeout time in nanoseconds. + @return: A Stats object for this particular benchmark case or an error if the test fails. + """ + + world_size: Optional[int] = test.args.get("world_size", None) + if world_size is None: + return pool.apply(_run_single_benchmark, (test, recheck, max_repeats, max_time_ns)) + else: + return run_multi_gpu_benchmark(pool, test, recheck, max_repeats, max_time_ns, world_size) + + +def run_benchmarking(logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase]): + """ + Executes benchmarking code for a CUDA Kernel and logs runtimes. + + @param logger: A PopcornOutput object used for logging benchmark results. + @param pool: Process on which the benchmarks will be launched. + @param tests: A list of TestCase objects representing the test cases to be benchmarked. + @return: An integer representing the exit status: 0 if all benchmarks pass, otherwise 112. 
+ """ + # warm up + run_single_benchmark(pool, tests[0], False, 100, 10e7) + + passed = True + logger.log("benchmark-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"benchmark.{idx}.spec", test.spec) + result = run_single_benchmark(pool, test, False, 100, 10e9) + if isinstance(result, Stats): + for field in dataclasses.fields(Stats): + logger.log(f"benchmark.{idx}.{field.name}", getattr(result, field.name)) + else: + passed = False + logger.log(f"benchmark.{idx}.status", "fail") + logger.log(f"benchmark.{idx}.error", result) + + if passed: + logger.log("check", "pass") + return 0 + else: + logger.log("check", "fail") + return 112 + + +def run_single_profile(test: TestCase) -> str: + """ + Runs a single test case. Do not call directly + """ + from submission import custom_kernel + from torch.profiler import profile, record_function, ProfilerActivity + data = generate_input(**test.args) + torch.cuda.synchronize() + + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: + submission_output = custom_kernel(_clone_data(data, 0)) + torch.cuda.synchronize() + return prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20) + + +def run_profiling(logger: PopcornOutput, tests: list[TestCase]): + logger.log("benchmark-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"benchmark.{idx}.spec", test.spec) + report = run_single_profile(test) + logger.log(f"benchmark.{idx}.report", base64.b64encode(report.encode("utf-8"), b"+*").decode("utf-8")) + logger.log("check", "pass") + return 0 + + +def main(): + fd = os.getenv("POPCORN_FD") + if not fd: + return 111 + + if len(sys.argv) < 3: + return 2 + + mode = sys.argv[1] + seed = os.getenv("POPCORN_SEED") + os.unsetenv("POPCORN_SEED") + n_gpus = int(os.getenv("POPCORN_GPUS", "1")) + seed = int(seed) if seed else None + set_seed(seed or 42) + tests = get_test_cases(sys.argv[2], seed) + + with PopcornOutput(int(fd)) as logger: + import multiprocessing + mp_context = multiprocessing.get_context('spawn') + with mp_context.Pool(n_gpus) as pool: + if mode == "test": + return run_testing(logger, pool, tests) + if mode == "benchmark": + return run_benchmarking(logger, pool, tests) + + if mode == "leaderboard": + # warmup + run_single_benchmark(pool, tests[0], False, 100, 1e7) + logger.log("benchmark-count", len(tests)) + passed = True + for i in range(len(tests)): + result = run_single_benchmark(pool, tests[i], True, 100, 30e9) + logger.log(f"benchmark.{i}.spec", tests[i].spec) + if isinstance(result, Stats): + for field in dataclasses.fields(Stats): + logger.log(f"benchmark.{i}.{field.name}", getattr(result, field.name)) + else: + passed = False + logger.log(f"benchmark.{i}.status", "fail") + logger.log(f"benchmark.{i}.error", str(result)) # TODO: Make sure result implements __str__? 
+ break + + logger.log("check", "pass" if passed else "fail") + elif mode == "profile": + run_profiling(logger, tests) + else: + # invalid mode + return 2 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dist-infer/ag-gemm/perf_gemm.cc b/dist-infer/ag-gemm/perf_gemm.cc new file mode 100644 index 0000000..5b830b0 --- /dev/null +++ b/dist-infer/ag-gemm/perf_gemm.cc @@ -0,0 +1,1074 @@ +#include +#include +#include +#include +#include +#include +#include +#define FAST_UNSAFE_CAST + +#define FORCE_INLINE __attribute__((always_inline)) + +#define LOCAL_TEST + + +// Perf GEMM Common Begin + + +namespace roc_isa { + constexpr int AMDGCN_WAVEFRONT_SIZE = 64; +namespace issue_latency { + constexpr int v_mfma_f32_16x16x16_bf16 = 4; + constexpr int ds_read_b128 = 2 * 4; + constexpr int ds_write_b128 = 5 * 4; + constexpr int buffer_load_dwordx2 = 1 * 2; +} // namespace issue_latency + +// Trait for different GEMM size categories +enum class GemmSizeCategory { + LARGE, // 256x224, 224x256, 256x256 + MIDDLE, // 256x128, 128x256, 128x128 + SMALL // 128x64, 64x128, 128x32, 32x128, 64x32, 32x32 +}; + +template +struct GemmSizeTrait { + // Large GEMM + // (256x224) MFMA 224, DS_READ 30, DS_WRITE 15, BUFFER_LOAD 30 + // Middle GEMM + // (256X128) MFMA 128, DS_READ_24, DS_WRITE 12, BUFFER_LOAD 24 + // (128X128) MFMA 64, DS_READ 16, DS_WRITE 8, BUFFER_LOAD 16 + // Small GEMM + // (128x64) MFMA 32, DS_WRITE 6, DS_READ 12, BUFFER_LOAD 12 + // (128x32) MFMA 16, DS_WRITE 5, DS_READ 8 , BUFFER_LOAD 10 + // (64x32) MFMA 8, DS_WRITE 3, DS_READ 6 , BUFFER_LOAD 6 + static constexpr GemmSizeCategory category = + ((BM == 256 && BN == 224) || (BM == 224 && BN == 256) || (BM == 256 && BN == 256)) ? GemmSizeCategory::LARGE : + ((BM == 256 && BN == 128) || (BM == 128 && BN == 256) || (BM == 128 && BN == 128)) ? 
GemmSizeCategory::MIDDLE : + GemmSizeCategory::SMALL; +}; + +// Schedule configuration trait based on category +template +struct ScheduleConfig; + +template<> +struct ScheduleConfig { + // Stage 1: DS_WRITE(1) -> MFMA(2) -> VMEM(1) -> MFMA(3) + static constexpr int stage1_ds_write = 1; + static constexpr int stage1_mfma_before_vmem = 2; + static constexpr int stage1_vmem = 1; + static constexpr int stage1_mfma_after_vmem = 3; + + // Stage 2: MFMA(2) -> DS_READ(1) + static constexpr int stage2_mfma = 2; + static constexpr int stage2_ds_read = 1; +}; + +template<> +struct ScheduleConfig { + // Stage 1: DS_WRITE(1) -> MFMA(2) -> VMEM(1) -> MFMA(1) + static constexpr int stage1_ds_write = 1; + static constexpr int stage1_mfma_before_vmem = 2; + static constexpr int stage1_vmem = 1; + static constexpr int stage1_mfma_after_vmem = 1; + + // Stage 2: MFMA(1) -> DS_READ(2) + static constexpr int stage2_mfma = 1; + static constexpr int stage2_ds_read = 2; +}; + +template<> +struct ScheduleConfig { + // Stage 1: DS_WRITE(1) -> MFMA(1) -> VMEM(1) -> MFMA(1) + static constexpr int stage1_ds_write = 1; + static constexpr int stage1_mfma_before_vmem = 1; + static constexpr int stage1_vmem = 1; + static constexpr int stage1_mfma_after_vmem = 1; + + // Stage 2: MFMA(1) -> DS_READ(1) + static constexpr int stage2_mfma = 1; + static constexpr int stage2_ds_read = 1; +}; + +template +struct InstCalculator { + static constexpr int v_mfma_f32_16x16x16_bf16 = (BM * BN * BK) / (WARP_M * WARP_N) / (16*16*16); + // Compiler will merge two ds_{read,write}_b64 to ds_{read,write`}2st64_b64 + static constexpr int ds_read_b128_a = (BM * BK / WARP_M) / 64 / 8; + static constexpr int ds_read_b128_b = (BN * BK / WARP_N) / 64 / 8; + static constexpr int ds_read_b128 = ds_read_b128_a + ds_read_b128_b; + static constexpr int ds_write_b128_a = (BM * BK) / NUM_THREADS / 8; + static constexpr int ds_write_b128_b = (BN * BK) / NUM_THREADS / 8; + static constexpr int ds_write_b128 = ds_write_b128_a + ds_write_b128_b; + static constexpr int buffer_load_dwordx2_a = (BM * BK) / NUM_THREADS / 4; + static constexpr int buffer_load_dwordx2_b = (BN * BK) / NUM_THREADS / 4; + static constexpr int buffer_load_dwordx2 = buffer_load_dwordx2_a + buffer_load_dwordx2_b; + + // Get schedule configuration based on BM and BN + using size_trait = GemmSizeTrait; + using schedule_config = ScheduleConfig; +}; + +} // namespace roc_isa + +namespace test { + constexpr int BM = 128, BN = 128; + constexpr int WARP_M = 2, WARP_N = 2, NUM_THREADS = 256, BK = 64; + constexpr int MFMA_NUM = roc_isa::InstCalculator::v_mfma_f32_16x16x16_bf16; + constexpr int DS_READ_NUM = roc_isa::InstCalculator::ds_read_b128; + constexpr int DS_WRITE_NUM = roc_isa::InstCalculator::ds_write_b128; + constexpr int BUFFER_LOAD_NUM = roc_isa::InstCalculator::buffer_load_dwordx2; +} + + +using bfloat16_t = __bf16; + +__device__ __host__ FORCE_INLINE constexpr int ceil_div(int a, int b) { + return (a + b - 1) / b; +} + +template +__device__ __host__ FORCE_INLINE constexpr int exact_div() { + static_assert(a % b == 0); + return a / b; +} + +__device__ __host__ FORCE_INLINE constexpr int i_min(int a, int b) { + return a < b ? a : b; +} + +__device__ __host__ FORCE_INLINE constexpr int i_max(int a, int b) { + return a > b ? 
a : b; +} + + +template +struct PackN_t { + using t = __attribute__((vector_size(N * sizeof(dtype)))) dtype; + static constexpr auto n = N; + static constexpr auto H = N / 2; + union { + dtype x[N]; + t pack; + struct { dtype low[H], high[H]; }; + }; +}; + +using bf16x4_t = PackN_t; +using fp32x4_t = PackN_t; +using bf16x8_t = PackN_t; +using fp32x8_t = PackN_t; +using i32x4_t = PackN_t; + +#define FORCE_INLINE __attribute__((always_inline)) + + +__device__ ck::int32x4_t inline make_wave_buffer_resource(const void* ptr, uint32_t size = 0xffffffff) { + ck::int32x4_t res; + + // Pack the 64-bit pointer into two 32-bit integers + uint64_t ptr_val = reinterpret_cast(ptr); + res.x = static_cast(ptr_val); + res.y = static_cast(ptr_val >> 32); + + // Set buffer size and format + res.z = size; // Buffer size in bytes + res.w = 0x00020000; // hardcoded for gfx942 + + res.x = __builtin_amdgcn_readfirstlane(res.x); + res.y = __builtin_amdgcn_readfirstlane(res.y); + res.z = __builtin_amdgcn_readfirstlane(res.z); + res.w = __builtin_amdgcn_readfirstlane(res.w); + return res; +} + +__device__ FORCE_INLINE bfloat16_t fast_f32tob16(float f) { +#ifdef FAST_UNSAFE_CAST + union { + float fp32; + unsigned int u32; + } u = {f}; + u.u32 += 0x7FFF + ((u.u32 >> 16) & 1); + auto ret = u.u32 >> 16; + return reinterpret_cast(ret); +#else + return static_cast(f); +#endif +} + + +__device__ FORCE_INLINE float fast_b16tof32(bfloat16_t bf) { +#ifdef FAST_UNSAFE_CAST + union { + float fp32; + unsigned int u32; + } u; + u.u32 = (reinterpret_cast(bf)) << 16; + return u.fp32; +#else + return static_cast(bf); +#endif +} + +__device__ void block_sync_lds() { + __builtin_amdgcn_s_waitcnt(0xc07f); + __builtin_amdgcn_s_barrier(); +} + +__device__ void block_sync_gds() { + __builtin_amdgcn_s_waitcnt(0xf70); + __builtin_amdgcn_s_barrier(); +} + + +// M-dimension grouped version for better L2 cache locality when M > N +template +__device__ FORCE_INLINE void compute_tile_indices_m_grouped( + int tile_id, + int &tile_m_id, + int &tile_n_id +) { + if constexpr (GROUP_SIZE_M == 0) { + // No swizzle + tile_m_id = tile_id % num_tile_m; + tile_n_id = tile_id / num_tile_m; + } else { + // Swizzle pattern for better L2 cache locality + // Groups tiles in blocks of GROUP_SIZE_M x num_tile_n + constexpr int num_pid_in_group = GROUP_SIZE_M * num_tile_n; + + // Which group does this tile belong to? + const int group_id = tile_id / num_pid_in_group; + + // First M-dimension tile in this group + const int first_pid_m = group_id * GROUP_SIZE_M; + + // Actual group size (handling boundary case) + const int group_size_m = min(GROUP_SIZE_M, num_tile_m - first_pid_m); + + // Position within the group + const int idx_in_group = tile_id % num_pid_in_group; + + // Swizzled tile indices: alternate M then N within group + tile_m_id = first_pid_m + (idx_in_group % group_size_m); + tile_n_id = idx_in_group / group_size_m; + } +} + +// Perf GEMM Common End + +#define DO_PRAGMA_(x) _Pragma(#x) +#define DO_PRAGMA(x) DO_PRAGMA_(x) +#define UNROLL DO_PRAGMA(unroll) +#define UNROLL_N(n) DO_PRAGMA(unroll n) + +using b16 = bfloat16_t; + +constexpr int WARP_SIZE = 64; + +template constexpr T const_min(T a, T b) { return a > b ? 
b : a; } + +template +__device__ inline void warp_copy_tile(b16 *(&dst)[NUM_DST], const b16 *src) { + static_assert(TILE_K % WARP_SIZE == 0); + constexpr int VEC_SIZE = TILE_K / WARP_SIZE * sizeof(b16); + static_assert(VEC_SIZE == 16 || VEC_SIZE == 8 || VEC_SIZE == 4 || VEC_SIZE == 2); + using cp_t = typename ck::vector_type::type; + constexpr int NUM_STAGES = const_min(8, exact_div()); + cp_t regs[NUM_STAGES]; + + const auto src_rsrc = ck::make_wave_buffer_resource_with_default_range(src); + ck::int32x4_t dst_rsrc[NUM_DST]; + UNROLL + for (int i = 0; i < NUM_DST; i++) { + dst_rsrc[i] = ck::make_wave_buffer_resource_with_default_range(dst[i]); + } + const auto lane_id = threadIdx.x % WARP_SIZE; + + auto load_row = [&](int reg_idx, int row_idx) { + const int soffset = row_idx * K * sizeof(b16); + const int voffset = lane_id * VEC_SIZE; + regs[reg_idx] = ck::amd_buffer_load_impl_raw< + VEC_SIZE, ck::AmdBufferCoherenceEnum::SYSTEM_NT0>( + src_rsrc, voffset, soffset + ); + }; + auto store_row = [&](int reg_idx, int row_idx) { + UNROLL + for (int i = 0; i < NUM_DST; i++) { + const int soffset = row_idx * K * sizeof(b16); + const int voffset = lane_id * VEC_SIZE; + ck::amd_buffer_store_impl_raw< + VEC_SIZE, ck::AmdBufferCoherenceEnum::SYSTEM_NT0>( + regs[reg_idx], dst_rsrc[i], voffset, soffset + ); + } + }; + + /* + ld 0 + ld 1 + ld 2 + ld 3 + st 0 + ld 4 + st 1 + ld 5 + st 2 + + ld 6 + st 3 + ld 7 + st 4 + ld 8 + st 5 + ld 9 + st 6 + + ld 10 + st 7 + ld 11 + st 8 + st 9 + st 10 + st 11 + */ + static_assert(TILE_M % (NUM_STAGES * 2) == 0); + auto copy_prologue = [&](int row_idx=0) { + UNROLL + for (int i = 0; i < NUM_STAGES - 1; i++) { + load_row(i % NUM_STAGES, row_idx + i); + } + UNROLL + for (int i = 0; i < NUM_STAGES - 1; i++) { + load_row((i + NUM_STAGES - 1) % NUM_STAGES, row_idx + i + NUM_STAGES - 1); + store_row(i % NUM_STAGES, row_idx + i); + } + }; + // row_idx: 0, TILE_M - (NUM_STAGES * 2), NUM_STAGES + auto copy_loop_body = [&](int row_idx) { + UNROLL + for (int i = 0; i < NUM_STAGES; i++) { + auto load_off = (NUM_STAGES - 1) * 2; + auto store_off = (NUM_STAGES - 1); + load_row((i + load_off) % NUM_STAGES, row_idx + i + load_off); + store_row((i + store_off) % NUM_STAGES, row_idx + i + store_off); + } + }; + // row_idx: TILE_M - 2 + auto copy_epilogue = [&](int row_idx=TILE_M-2) { + UNROLL + for (int i = 0; i < 2; i++) { + auto load_off = row_idx; + auto store_off = row_idx - (NUM_STAGES - 1); + load_row((i + load_off) % NUM_STAGES, i + load_off); + store_row((i + store_off) % NUM_STAGES, i + store_off); + } + UNROLL + for (int i = 0; i < NUM_STAGES - 1; i++) { + auto store_off = row_idx - (NUM_STAGES - 1) + 2; + store_row((i + store_off) % NUM_STAGES, i + store_off); + } + }; + + copy_prologue(); + for(int i = 0; i < TILE_M - (NUM_STAGES * 2); i += NUM_STAGES) { + copy_loop_body(i); + } + copy_epilogue(); + + // auto copy_loop_body = [&](int row_idx) { + // load_row(0, row_idx + 2); + // // sync, 1 ld, 8 st in flight + // store_row(1, row_idx + 1); + // load_row(1, row_idx + 3); + // // sync, 1 ld, 8 st in flight + // store_row(0, row_idx + 2); + // }; + + // constexpr int UNROLL_FACTOR = const_min(TILE_M / NUM_STAGES, 8); + // constexpr int INNER_M = UNROLL_FACTOR * NUM_STAGES; + // static_assert(TILE_M % NUM_STAGES == 0); + // static_assert(TILE_M >= INNER_M); + + // load_row(0, 0); + // load_row(1, 1); + // store_row(0, 0); + // for (int i = 0; i < TILE_M - INNER_M; i += INNER_M) { + // asm(";main loop begin"); + // UNROLL + // for (int j = 0; j < INNER_M; j += NUM_STAGES) { + // 
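warp_copy_tile keeps NUM_STAGES rows in registers and issues each buffer_load several rows ahead of the matching buffer_store, following the ld/st trace sketched in the comment above. A small host-side Python model of the prologue / steady-state / epilogue indexing (illustrative TILE_M and NUM_STAGES values, not the kernel itself) checks that every row is loaded exactly once and stored from the same register before that register is reused.

```python
def copy_schedule(tile_m: int, num_stages: int):
    """Emit the (op, register, row) trace of the pipelined warp copy."""
    assert tile_m % (num_stages * 2) == 0
    trace = []
    load = lambda reg, row: trace.append(("load", reg % num_stages, row))
    store = lambda reg, row: trace.append(("store", reg % num_stages, row))

    # prologue: fill num_stages-1 registers before the first store is issued
    for i in range(num_stages - 1):
        load(i, i)
    for i in range(num_stages - 1):
        load(i + num_stages - 1, i + num_stages - 1)
        store(i, i)
    # steady state: one load and one store per row, num_stages-1 rows apart
    for row in range(0, tile_m - num_stages * 2, num_stages):
        for i in range(num_stages):
            load(i + (num_stages - 1) * 2, row + i + (num_stages - 1) * 2)
            store(i + num_stages - 1, row + i + num_stages - 1)
    # epilogue: issue the last two loads, then drain the remaining stores
    row = tile_m - 2
    for i in range(2):
        load(i + row, i + row)
        store(i + row - (num_stages - 1), i + row - (num_stages - 1))
    for i in range(num_stages - 1):
        store(i + row - (num_stages - 1) + 2, i + row - (num_stages - 1) + 2)
    return trace


for tile_m, num_stages in [(8, 2), (16, 4)]:
    pending, stored = {}, []
    for op, reg, row in copy_schedule(tile_m, num_stages):
        if op == "load":
            assert reg not in pending          # register is free before reuse
            pending[reg] = row
        else:
            assert pending.pop(reg) == row     # store drains the matching load
            stored.append(row)
    assert not pending and stored == list(range(tile_m))
```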
copy_loop_body(i + j); + // } + // asm(";main loop end"); + // } + // UNROLL + // for (int i = TILE_M - INNER_M; i < TILE_M - NUM_STAGES; i += NUM_STAGES) { + // copy_loop_body(i); + // } + // store_row(1, TILE_M - 1); +} + +using signal_t = int[128][128]; +#ifdef LOCAL_TEST +constexpr int WORLD_SIZE = 2; +#else +constexpr int WORLD_SIZE = 8; +#endif +constexpr int NUM_SMS = 304; +constexpr int AG_SMS = 16; +constexpr int GEMM_SMS = NUM_SMS - AG_SMS; + +constexpr int CHUNK_K = 128; +constexpr int CHUNK_M = 64; + +struct workspace_t { + int grid_barrier; +}; + +constexpr int MAX_M = 8192; +constexpr int MAX_N = 29568; +constexpr int MAX_K = 8192; +constexpr int MAX_M_LOCAL = MAX_M / WORLD_SIZE; + +// global variables +struct ipc_mem_t { + // FIXME: reset signals + signal_t nvl_recv_signals; + signal_t nvl_consume_signals; + int nvl_barrier[WORLD_SIZE]; +}; + +struct ipc_cache_t { + b16 nvl_recv_x[MAX_M * MAX_K]; +}; + +struct global_t { + // config + int rank; + int m, n, k; + int next_signal; + + // buffers + ipc_mem_t *ipc_mems[WORLD_SIZE] = {}; + ipc_cache_t *ipc_caches[WORLD_SIZE] = {}; + workspace_t *workspace; +}; + +template __device__ inline void st_relaxed_sys(T *ptr, T val) { + __hip_atomic_store(ptr, val, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM); +} + +template __device__ inline T ld_relaxed_sys(T *ptr) { + return __hip_atomic_load(ptr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM); +} + +template __device__ inline void st_release_global(T *ptr, T val) { + __hip_atomic_store(ptr, val, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_AGENT); +} + +template __device__ inline T ld_acquire_global(T *ptr) { + return __hip_atomic_load(ptr, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_AGENT); +} + +template +__device__ inline void send_kernel(const int sm_id, const b16 *x, global_t global) { + const auto num_sms = NUM_AG_SMS; + const auto num_warps = blockDim.x / WARP_SIZE; + const int num_global_warps = num_sms * num_warps; + + // put soffset to sgpr + const auto warp_id = + __builtin_amdgcn_readfirstlane(threadIdx.x / WARP_SIZE); + const auto global_warp_id = sm_id * num_warps + warp_id; + const auto lane_id = threadIdx.x % WARP_SIZE; + + if (global_warp_id >= NUM_AG_WARPS) { + return; + } + + const auto rank = global.rank; + + static_assert(M % WORLD_SIZE == 0); + constexpr auto M_LOCAL = M / WORLD_SIZE; + + constexpr auto NUM_CHUNKS_M = ceil_div(M_LOCAL, CHUNK_M); + constexpr auto NUM_CHUNKS_K = ceil_div(K, CHUNK_K); + constexpr auto TAIL_CHUNK_M = M_LOCAL - (NUM_CHUNKS_M - 1) * CHUNK_M; + constexpr auto TAIL_CHUNK_K = K - (NUM_CHUNKS_K - 1) * CHUNK_K; + + for (int i = global_warp_id; i < NUM_CHUNKS_M * NUM_CHUNKS_K * WORLD_SIZE; + i += NUM_AG_WARPS) { + const auto dst_rank = i % WORLD_SIZE; + const auto chunk_id = i / WORLD_SIZE; + // TODO: maybe k first + const auto chunk_k = chunk_id / NUM_CHUNKS_M; + const auto chunk_m = chunk_id % NUM_CHUNKS_M; + + const auto m_begin = chunk_m * CHUNK_M; + const auto k_begin = chunk_k * CHUNK_K; + const auto offset = m_begin * K + k_begin; + + const auto chunk_src = x + offset; + b16 *chunk_dst[1] = { + global.ipc_caches[dst_rank]->nvl_recv_x + rank * M_LOCAL * K + + offset + }; + + if (TAIL_CHUNK_M != CHUNK_M && chunk_m == NUM_CHUNKS_M - 1) { + if (TAIL_CHUNK_K != CHUNK_K && chunk_k == NUM_CHUNKS_K - 1) { + warp_copy_tile( + chunk_dst, chunk_src + ); + } else { + warp_copy_tile(chunk_dst, chunk_src); + } + } else { + if (TAIL_CHUNK_K != CHUNK_K && chunk_k == NUM_CHUNKS_K - 1) { + warp_copy_tile(chunk_dst, chunk_src); + } else { + warp_copy_tile(chunk_dst, 
chunk_src); + } + } + + if (lane_id == 0) { + st_relaxed_sys( + &global.ipc_mems[dst_rank] + ->nvl_recv_signals[chunk_k][rank * NUM_CHUNKS_M + chunk_m], + global.next_signal + ); + + // if (lane_id == 0) { + // printf("m %d, k %d\n", chunk_m, chunk_k); + // } + + constexpr int NUM_PREFETCH = 1; + if (chunk_k >= NUM_PREFETCH) { + auto &consume = global.ipc_mems[dst_rank] + ->nvl_recv_signals[chunk_k - NUM_PREFETCH][rank * NUM_CHUNKS_M + chunk_m]; + while (global.next_signal != ld_relaxed_sys(&consume)) + __builtin_amdgcn_s_sleep(5); + } + } + } +} + +template +__launch_bounds__(NUM_THREADS) +__global__ void gemm_kernel( + const bfloat16_t *x, // M x K + const bfloat16_t *w, // N x K + const bfloat16_t *b, // N + bfloat16_t *c, // M x N + global_t global +) { + constexpr int NUM_CONCURRENT_CHUNK_K = 1; + constexpr int NUM_SEND_WARPS = ceil_div(M / WORLD_SIZE, CHUNK_M) * WORLD_SIZE * NUM_CONCURRENT_CHUNK_K; + constexpr int NUM_AG_SMS = ceil_div(NUM_SEND_WARPS, NUM_THREADS / WARP_SIZE); + constexpr int NUM_GEMM_SMS = NUM_SMS - NUM_AG_SMS; + if (blockIdx.x >= NUM_GEMM_SMS) { + // allgather + const int sm_id = blockIdx.x - NUM_GEMM_SMS; + send_kernel(sm_id, x, global); + return; + } + x = global.ipc_caches[global.rank]->nvl_recv_x; + auto signal = &global.ipc_mems[global.rank]->nvl_recv_signals; + + const int pid = __builtin_amdgcn_readfirstlane(blockIdx.x); + const int tid = threadIdx.x; + const int lane_id = __lane_id(); + const int warp_id = __builtin_amdgcn_readfirstlane(tid / roc_isa::AMDGCN_WAVEFRONT_SIZE); + __builtin_assume(pid >= 0 && pid < NUM_SMS); + __builtin_assume(tid >= 0 && tid < NUM_THREADS); + __builtin_assume(lane_id >= 0 && lane_id < 64); + + // gemm + + constexpr int COMM_K = CHUNK_K; + auto wait_signal = [&](int tile_m_id, int tile_k_id) { + constexpr int M_LOCAL = exact_div(); + constexpr int COMM_M = const_min(M_LOCAL, CHUNK_M); + + auto signal_k_id = tile_k_id / exact_div(); + + static_assert(M >= BM && M % BM == 0); + int signal_m_id_begin, signal_m_id_end; + if constexpr (COMM_M < BM) { + signal_m_id_begin = tile_m_id * exact_div(); + signal_m_id_end = signal_m_id_begin + exact_div(); + } else { + signal_m_id_begin = tile_m_id / exact_div(); + signal_m_id_end = signal_m_id_begin + 1; + } + + auto &sig = *signal; + + if(warp_id == 0 && lane_id < (signal_m_id_end - signal_m_id_begin)) { + // printf("wait m %d k %d\n", signal_m_id, signal_k_id); + while(global.next_signal != ld_relaxed_sys(&sig[signal_k_id][signal_m_id_begin + lane_id])) + __builtin_amdgcn_s_sleep(5); + if (tile_k_id % exact_div() == 0) { + auto &consume = global.ipc_mems[global.rank]->nvl_consume_signals[signal_k_id][signal_m_id_begin + lane_id]; + st_relaxed_sys(&consume, global.next_signal); + } + // __builtin_amdgcn_fence(__ATOMIC_ACQUIRE, ""); + } + __syncthreads(); + }; + + // Perf GEMM + constexpr int num_tile_m = ceil_div(M, BM); + constexpr int num_tile_n = ceil_div(N, BN); + constexpr int num_tiles = num_tile_m * num_tile_n * SPLIT_K; + // each split handles K_per_split + constexpr int K_per_split = exact_div(); + constexpr int num_tile_k = ceil_div(K_per_split, BK); + + using inst_nums = roc_isa::InstCalculator; + static_assert(BK % 4 == 0 && NUM_THREADS * 4 % BK == 0); + constexpr int WM = 16, WN = 16, WK = 16; + + constexpr int Frag_M = exact_div(); + constexpr int Frag_N = exact_div(); + constexpr int Frag_K = exact_div(); + const int warp_m = warp_id / WARP_N; + const int warp_n = warp_id % WARP_N; + using FragX = bf16x4_t; + using FragW = bf16x4_t; + using FragC = fp32x4_t; + __shared__ 
bfloat16_t s_x[BM][BK]; + __shared__ bfloat16_t s_w[BN][BK]; + bf16x4_t vgpr_x[ceil_div(BM * BK, NUM_THREADS * 4)]; + bf16x4_t vgpr_w[ceil_div(BN * BK, NUM_THREADS * 4)]; + + FragC frag_c[Frag_M][Frag_N]; + FragX frag_x[Frag_M][Frag_K]; + FragW frag_w[Frag_N][Frag_K]; + fp32x4_t out_fp32[Frag_M][Frag_N]; // AccVGPR -> VGPR Buffer + auto b_arr = ck::make_wave_buffer_resource(const_cast(b), N); + auto c_arr = ck::make_wave_buffer_resource(c, M * N); + auto x_arr = ck::make_wave_buffer_resource(const_cast(x), M * K); + auto w_arr = ck::make_wave_buffer_resource(const_cast(w), N * K); + + constexpr bool LOAD_BIAS = true; + constexpr int NUM_XCDS = 8; + for (int tile_id=pid; tile_id(tile_id / SPLIT_K, tile_m_id, tile_n_id); + int m = tile_m_id * BM; + int n = tile_n_id * BN; + + int k_offset = split_k_id * K_per_split * sizeof(bfloat16_t); + int v_offset = ((tid * 4 / BK) * K + (tid * 4 % BK)) * sizeof(bfloat16_t); + auto load_vgpr = [&](int k) FORCE_INLINE { + uint32_t src_addr_shift = ((K_per_split % BK == 0) || (k + tid * 4 % BK < K_per_split)) ? 0 : 0x80000000; + ck::static_for<0, sizeof(vgpr_x) / sizeof(vgpr_x[0]), 1>{}([&](auto t) { + int s_offset = ((m * K + k) + t * NUM_THREADS * 4 / BK * K) * sizeof(bfloat16_t); + vgpr_x[t] = __builtin_bit_cast(bf16x4_t, ck::amd_buffer_load_impl_raw( + x_arr, v_offset + src_addr_shift, s_offset + k_offset)); + }); + ck::static_for<0, sizeof(vgpr_w) / sizeof(vgpr_w[0]), 1>{}([&](auto t) { + int s_offset = ((n * K + k) + t * NUM_THREADS * 4 / BK * K) * sizeof(bfloat16_t); + vgpr_w[t] = __builtin_bit_cast(bf16x4_t, ck::amd_buffer_load_impl_raw( + w_arr, v_offset + src_addr_shift, s_offset + k_offset)); + }); + + }; + + + auto load_lds = [&]() FORCE_INLINE { + // diagonal swizzle, shape=[16, 64] dtype=bfloat16 + #pragma unroll + for (int t=0;t(&s_x[row0 + row1][col1]) = vgpr_x[t]; + } + #pragma unroll + for (int t=0;t(&s_w[row0 + row1][col1]) = vgpr_w[t]; + } + }; + + auto zero_all_frags = [&]() FORCE_INLINE { + ck::static_for<0, Frag_M, 1>{}([&](auto i) { + ck::static_for<0, Frag_N, 1>{}([&](auto j) { + ck::static_for<0, 4, 1>{}([&](auto t) { frag_c[i][j].x[t] = 0; }); + }); + }); + }; + + + + auto load_frag = [&](int tile_kk, const bfloat16_t *ptr, bf16x4_t &frag, bool permute) { + // ptr: [16][16] + const int row0 = lane_id % 16; + const int col0 = tile_kk * 16 + lane_id / 16 * 4; + const int col1 = (row0 * 4 + col0) % BK; + frag = *reinterpret_cast(&ptr[row0 * BK + col1]); + }; + + auto frags_load = [&]() { + #pragma unroll + for (int k=0;k(&s_x[row0 + row1][col1]); + } + } + #pragma unroll + for (int k=0;k(&s_w[row0 + row1][col1]); + } + } + }; + + auto frags_mfma = [&] { + #pragma unroll + for (int i=0;i VGPR + #pragma unroll + for (int i=0; i( + b_arr, b_v_offset, b_s_offset)); + #pragma unroll + for (int t = 0; t < 4; ++t) { + out_fp32[i][j].x[t] += static_cast(LOAD_BIAS ? b_vec.x[t] : 0); + } + } + } + + ck::static_for<0, Frag_M, 1>{}([&](auto i) { + ck::static_for<0, Frag_N, 1>{}([&](auto j) { + bf16x4_t c_out_bf16; + #pragma unroll + for (int t = 0; t < 4; ++t) c_out_bf16.x[t] = fast_f32tob16(out_fp32[i][j].x[t]); + int row = lane_id % 16; + int col = lane_id / 16 * 4; + uint32_t src_addr_shift = (N % BN == 0) || (n + (j + warp_n * Frag_N) * WN + col < N) ? 
0 : 0x80000000; + int b_s_offset = (n + (j + warp_n * Frag_N) * WN) * sizeof(bfloat16_t); + int c_s_offset = b_s_offset + (m + (i + warp_m * Frag_M) * WM) * N * sizeof(bfloat16_t); + int c_v_offset = col * sizeof(bfloat16_t) + src_addr_shift + (row * N) * sizeof(bfloat16_t); + ck::amd_buffer_store_impl_raw(c_out_bf16.pack, c_arr, c_v_offset, c_s_offset); + }); + }); + } else { + // SPLIT_K > 1: store FP32 partials into workspace [split_id, M, N] (row-major floats) + // TODO: workspace + split_k_id * M * N + auto ws_arr = ck::make_wave_buffer_resource(nullptr, M * N); + ck::static_for<0, Frag_M, 1>{}([&](auto i) { + ck::static_for<0, Frag_N, 1>{}([&](auto j) { + int row = lane_id % 16; + int col = lane_id / 16 * 4; + uint32_t src_addr_shift = (N % BN == 0) || (n + (j + warp_n * Frag_N) * WN + col < N) ? 0 : 0x80000000; + int b_s_offset = (n + (j + warp_n * Frag_N) * WN) * sizeof(float); + int c_s_offset = b_s_offset + (m + (i + warp_m * Frag_M) * WM) * N * sizeof(float); + int c_v_offset = col * sizeof(float) + src_addr_shift + (row * N) * sizeof(float); + ck::amd_buffer_store_impl_raw(out_fp32[i][j].pack, ws_arr, c_v_offset, c_s_offset); + }); + }); + } + }; + + + + wait_signal(tile_m_id, 0); + load_vgpr(0); // GDS -> VGPR #0 + load_lds(); // VGPR -> LDS #0 + load_vgpr(1 * BK); // GDS -> VGPR #1 + zero_all_frags(); + // __builtin_amdgcn_s_waitcnt(0x70); + // __builtin_amdgcn_s_barrier(); + block_sync_lds(); + // release_signal(); + frags_load(); // LDS -> FRAG #0 + __builtin_amdgcn_sched_barrier(0); + for (int tile_k_id = 1; tile_k_id < (num_tile_k - 1); ++tile_k_id) { + block_sync_lds(); + // Stage 1 + load_lds(); // VGPR -> LDS #1 + if ((tile_k_id + 1) % exact_div() == 0) { + wait_signal(tile_m_id, tile_k_id + 1); + } + load_vgpr((tile_k_id + 1) * BK); // GDS -> VGPR #2(k+1) + frags_mfma(); // MFMA #0(k-1) + #pragma unroll + for (int k = 0; k < inst_nums::buffer_load_dwordx2; ++k) { + __builtin_amdgcn_sched_group_barrier(0x200, inst_nums::schedule_config::stage1_ds_write, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, inst_nums::schedule_config::stage1_mfma_before_vmem, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, inst_nums::schedule_config::stage1_vmem, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, inst_nums::schedule_config::stage1_mfma_after_vmem, 0); // MFMA + } + block_sync_lds(); + // Stage 2 + frags_load(); // LDS -> FRAG #1(k) + #pragma unroll + for (int k = 0; k < inst_nums::ds_read_b128; ++k) { + __builtin_amdgcn_sched_group_barrier(0x008, inst_nums::schedule_config::stage2_mfma, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, inst_nums::schedule_config::stage2_ds_read, 0); // DS read + } + __builtin_amdgcn_sched_barrier(0); + + } + frags_mfma(); // MFMA #1(n-2) + block_sync_lds(); + load_lds(); // VGPR -> LDS #2(n-1) + block_sync_lds(); + frags_load(); // LDS -> FRAG #2(n-1) + frags_mfma(); // MFMA #2(n-1) + + // __builtin_amdgcn_sched_barrier(0); + store_frags(); + // __syncthreads(); + // if (tid == 0) { + // __hip_atomic_store(&signal_arr[tile_m_id][tile_n_id], signal_val, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM); + // } + } + // block_sync_gds(); + // release_signal(); +} + +constexpr int64_t pack_mnk(int m, int n, int k) { + return (int64_t(m) << 32) | (int64_t(n) << 16) | int64_t(k); +} + +// TODO: fix 2880 ck=512 +// clang-format off +#ifdef LOCAL_TEST +#define SWITCH_GEMM_MNK(m, n, k, MACRO, ...) 
\ + switch (pack_mnk(m, n, k)) { \ + case pack_mnk(64 , 2304, 7168): MACRO(64 , 2304, 7168, 32, 32, 128, 2, 2, 1, 8, 64, 512, ##__VA_ARGS__); break; \ + case pack_mnk(128 , 1536, 4096): MACRO(128 , 1536, 4096, 64, 64, 128, 2, 2, 1, 32, 64, 512, ##__VA_ARGS__); break; \ + case pack_mnk(512 , 360 , 2880): MACRO(512 , 360 , 2880, 64, 64, 128, 2, 2, 1, 8, 64, 128, ##__VA_ARGS__); break; \ + case pack_mnk(1024, 512 , 4096): MACRO(1024, 512 , 4096, 128, 64, 128, 2, 2, 1, 40, 64, 512, ##__VA_ARGS__); break; \ + case pack_mnk(2048, 1792, 4096): MACRO(2048, 1792, 4096, 256, 224, 64, 2, 2, 1, 32, 64, 128, ##__VA_ARGS__); break; \ + case pack_mnk(2048, 3696, 8192): MACRO(2048, 3696, 8192, 256, 224, 64, 2, 2, 1, 32, 64, 128, ##__VA_ARGS__); break; \ + default: fprintf(stderr, "invalid mnk: %d %d %d", m, n, k); \ + } +#else +// bm, bn, bk, wm, wn +#define SWITCH_GEMM_MNK(m, n, k, MACRO, ...) \ + switch (pack_mnk(m, n, k)) { \ + case pack_mnk(64, 2304, 7168): MACRO(64, 2304, 7168, 32, 32, 128, 2, 2, 1, 8, 64, 512, ##__VA_ARGS__); break; \ + case pack_mnk(512, 1536, 4096): MACRO(512, 1536, 4096, 64, 64, 128, 2, 2, 1, 32, 64, 512, ##__VA_ARGS__); break; \ + case pack_mnk(2048, 360, 2880): MACRO(2048, 360, 2880, 64, 64, 128, 2, 2, 1, 8, 64, 128, ##__VA_ARGS__); break; \ + case pack_mnk(4096, 512, 4096): MACRO(4096, 512, 4096, 128, 64, 128, 2, 2, 1, 40, 64, 512, ##__VA_ARGS__); break; \ + case pack_mnk(8192, 1792, 4096): MACRO(8192, 1792, 4096, 256, 224, 64, 2, 2, 1, 32, 64, 128, ##__VA_ARGS__); break; \ + case pack_mnk(8192, 3696, 8192): MACRO(8192, 3696, 8192, 256, 224, 64, 2, 2, 1, 32, 64, 128, ##__VA_ARGS__); break; \ + default: fprintf(stderr, "invalid mnk"); \ + } +#endif +// clang-format on + +class AgGemm { + private: + global_t global{}; + + public: + AgGemm(int rank, int m, int n, int k) { + global.rank = rank; + global.m = m; + global.n = n; + global.k = k; + global.next_signal = 0; + } + + ~AgGemm() { + for (auto i = 0; i < WORLD_SIZE; i++) { + auto ipc_mem = global.ipc_mems[i]; + if (ipc_mem && i != global.rank) { + C10_HIP_CHECK(hipIpcCloseMemHandle(ipc_mem)); + } + auto ipc_cache = global.ipc_caches[i]; + if (ipc_cache && i != global.rank) { + C10_HIP_CHECK(hipIpcCloseMemHandle(ipc_cache)); + } + } + auto local_mem = global.ipc_mems[global.rank]; + if (local_mem) { + C10_HIP_CHECK(hipFree(local_mem)); + } + auto local_cache = global.ipc_caches[global.rank]; + if (local_cache) { + C10_HIP_CHECK(hipFree(local_cache)); + } + if (global.workspace) { + C10_HIP_CHECK(hipFree(global.workspace)); + } + } + + auto get_ipc_handle() -> pybind11::bytearray { + void *ws; + C10_HIP_CHECK(hipMalloc(&ws, sizeof(workspace_t))); + C10_HIP_CHECK(hipMemset(ws, 0, sizeof(workspace_t))); + global.workspace = reinterpret_cast(ws); + + void *ptr; + C10_HIP_CHECK(hipExtMallocWithFlags( + &ptr, sizeof(ipc_mem_t), hipDeviceMallocUncached + )); + C10_HIP_CHECK(hipMemset(ptr, 0, sizeof(ipc_mem_t))); + global.ipc_mems[global.rank] = reinterpret_cast(ptr); + + // C10_HIP_CHECK(hipMalloc(&ptr, sizeof(ipc_cache_t))); + C10_HIP_CHECK(hipExtMallocWithFlags( + &ptr, sizeof(ipc_cache_t), hipDeviceMallocUncached + )); + C10_HIP_CHECK(hipMemset(ptr, 0, sizeof(ipc_cache_t))); + global.ipc_caches[global.rank] = reinterpret_cast(ptr); + + std::vector handles(2); + C10_HIP_CHECK(hipIpcGetMemHandle(&handles[0], global.ipc_mems[global.rank]) + ); + C10_HIP_CHECK( + hipIpcGetMemHandle(&handles[1], global.ipc_caches[global.rank]) + ); + return { + reinterpret_cast(handles.data()), HIP_IPC_HANDLE_SIZE * 2 + }; + } + + auto init(const 
std::vector &ipc_handles) { + for (int i = 0; i < WORLD_SIZE; i++) { + if (i == global.rank) { + continue; + } + hipIpcMemHandle_t handle; + auto handle_buf = std::string(ipc_handles[i]); + auto handles = + reinterpret_cast(handle_buf.data()); + void *ptr; + C10_HIP_CHECK(hipIpcOpenMemHandle( + &ptr, handles[0], hipIpcMemLazyEnablePeerAccess + )); + global.ipc_mems[i] = reinterpret_cast(ptr); + C10_HIP_CHECK(hipIpcOpenMemHandle( + &ptr, handles[1], hipIpcMemLazyEnablePeerAccess + )); + global.ipc_caches[i] = reinterpret_cast(ptr); + } + } + + auto get_x_full() { + auto x_full = torch::from_blob( + global.ipc_caches[global.rank]->nvl_recv_x, {global.m, global.k}, + torch::TensorOptions().dtype(torch::kBFloat16).device(torch::kCUDA) + ); + return x_full; + } + + auto get_signal() { + auto signal = torch::from_blob( + global.ipc_mems[global.rank]->nvl_recv_signals, {128, 128}, + torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA) + ); + return signal; + } + + void reset() { + auto stream = at::cuda::getCurrentHIPStream().stream(); + auto &signal = global.ipc_mems[global.rank]->nvl_recv_signals; + C10_HIP_CHECK(hipMemsetAsync(&signal, 0, sizeof(signal), stream)); + } + + auto + perf_gemm(torch::Tensor &x, torch::Tensor &w, torch::Tensor &b) { + auto m_local = x.size(0); + const auto m = m_local * WORLD_SIZE; + auto n = w.size(0); + auto k = w.size(1); + auto out = torch::empty({m, n}, x.options()); + + auto x_ptr = reinterpret_cast(x.const_data_ptr()); + auto w_ptr = reinterpret_cast(w.const_data_ptr()); + auto b_ptr = reinterpret_cast(b.const_data_ptr()); + auto o_ptr = reinterpret_cast(out.data_ptr()); + + // udpate signal + global.next_signal++; + + constexpr int NUM_THREADS = 256; + + dim3 grid(NUM_SMS); + dim3 block(NUM_THREADS); + + auto stream = at::cuda::getCurrentHIPStream().stream(); + + // clang-format off + #define LAUNCH_PERF(m, n, k, bm, bn, bk, wm, wn, splitk, group_m, cm, ck) gemm_kernel<<>>(x_ptr, w_ptr, b_ptr, o_ptr, global) + SWITCH_GEMM_MNK(m, n, k, LAUNCH_PERF) + // clang-format on + return out; + } +}; + +PYBIND11_MODULE(ag_gemm, m) { + py::class_(m, "AgGemm") + .def(py::init()) + .def("get_ipc_handle", &AgGemm::get_ipc_handle) + .def("init", &AgGemm::init) + .def("get_x_full", &AgGemm::get_x_full) + .def("reset", &AgGemm::reset) + .def( + "perf_gemm", &AgGemm::perf_gemm, py::arg("x"), py::arg("w"), + py::arg("b") + ) + .def("get_signal", &AgGemm::get_signal); +} \ No newline at end of file diff --git a/dist-infer/ag-gemm/perf_gemm.h b/dist-infer/ag-gemm/perf_gemm.h new file mode 100644 index 0000000..bf72cea --- /dev/null +++ b/dist-infer/ag-gemm/perf_gemm.h @@ -0,0 +1,454 @@ +#include +#include +#include +#include +#include +#include +#include +#define FAST_UNSAFE_CAST +// #define SWIZZLE_XCD_PID +// #define SWIZZLE_L2_TILE + +#define FORCE_INLINE __attribute__((always_inline)) + + +namespace roc_isa { + constexpr int AMDGCN_WAVEFRONT_SIZE = 64; +namespace issue_latency { + constexpr int v_mfma_f32_16x16x16_bf16 = 4; + constexpr int ds_read_b128 = 2 * 4; + constexpr int ds_write_b128 = 5 * 4; + constexpr int buffer_load_dwordx2 = 1 * 2; +} // namespace issue_latency + +template +struct InstCalculator { + static constexpr int v_mfma_f32_16x16x16_bf16 = (BM * BN * BK) / (WARP_M * WARP_N) / (16*16*16); + // Compiler will merge two ds_{read,write}_b64 to ds_{read,write`}2st64_b64 + static constexpr int ds_read_b128_a = (BM * BK / WARP_M) / 64 / 8; + static constexpr int ds_read_b128_b = (BN * BK / WARP_N) / 64 / 8; + static constexpr int ds_read_b128 = 
ds_read_b128_a + ds_read_b128_b; + static constexpr int ds_write_b128_a = (BM * BK) / NUM_THREADS / 8; + static constexpr int ds_write_b128_b = (BN * BK) / NUM_THREADS / 8; + static constexpr int ds_write_b128 = ds_write_b128_a + ds_write_b128_b; + static constexpr int buffer_load_dwordx2_a = (BM * BK) / NUM_THREADS / 4; + static constexpr int buffer_load_dwordx2_b = (BN * BK) / NUM_THREADS / 4; + static constexpr int buffer_load_dwordx2 = buffer_load_dwordx2_a + buffer_load_dwordx2_b; + + __device__ static FORCE_INLINE void schedule_loop() { + if constexpr ((BM == 256 && BN == 224) || (BM == 224 && BN == 256)) { + // Large GEMM + // MFMA 224, DS_READ 30, DS_WRITE 15, BUFFER_LOAD 30 + #pragma unroll + for (int k = 0; k < buffer_load_dwordx2; ++k) { // 150 + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA + } + #pragma unroll + for (int k = 0; k < ds_read_b128; ++k) { // 60 + __builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + } + } else if constexpr ((BM == 256 && BN == 128) || (BM == 128 && BN == 256) || (BM == 128 && BN == 128)) { + // Middle GEMM + // (256X128) MFMA 128, DS_READ_24, DS_WRITE 12, BUFFER_LOAD 24 + // (128X128) MFMA 64, DS_READ 16, DS_WRITE 8, BUFFER_LOAD 16 + #pragma unroll + for (int k = 0; k < buffer_load_dwordx2; ++k) { // 96, 48 + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + } + #pragma unroll + for (int k = 0; k < ds_read_b128; ++k) { // 24, 32 + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read + } + } else if constexpr ((BM == 128 && BN == 32) || (BM == 32 && BN == 128) + || (BM == 128 && BN == 64) || (BM == 64 && BN == 128) + || (BM == 64 && BN == 32) || (BM == 32 && BN == 32) + || (BM == 32 && BN == 32) + ) { + // Small GEMM + // (128x64) MFMA 32, DS_WRITE 6, DS_READ 12, BUFFER_LOAD 12 + // (128x32) MFMA 16, DS_WRITE 5, DS_READ 8 , BUFFER_LOAD 10 + // (64x32) MFMA 8, DS_WRITE 3, DS_READ 6 , BUFFER_LOAD 6 + #pragma unroll + for (int k = 0; k < buffer_load_dwordx2; ++k) { // 20 + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + } + #pragma unroll + for (int k = 0; k < ds_read_b128; ++k) { // 8 + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + } + } else { + static_assert((BM == 256 && BN == 224) || (BM == 224 && BN == 256) || + (BM == 256 && BN == 128) || (BM == 128 && BN == 256) || + (BM == 128 && BN == 128) || + (BM == 128 && BN == 64) || (BM == 64 && BN == 128) || + (BM == 128 && BN == 32) || (BM == 32 && BN == 128) || + (BM == 64 && BN == 32) || (BM == 32 && BN == 64) || (BM == 32 && BN == 32), + "Unsupported BM, BN"); + } + + } +}; + +} // namespace roc_isa + + + +using bfloat16_t = __bf16; + +__device__ __host__ FORCE_INLINE inline constexpr int ceil_div(int a, int b) { + return (a + b - 1) / b; +} + +template +__device__ __host__ 
FORCE_INLINE constexpr int exact_div() { + static_assert(a % b == 0); + return a / b; +} + + +template +struct PackN_t { + using t = __attribute__((vector_size(N * sizeof(dtype)))) dtype; + static constexpr auto n = N; + static constexpr auto H = N / 2; + union { + dtype x[N]; + t pack; + struct { dtype low[H], high[H]; }; + }; +}; + +using bf16x4_t = PackN_t; +using fp32x4_t = PackN_t; +using bf16x8_t = PackN_t; +using fp32x8_t = PackN_t; + + +__device__ FORCE_INLINE inline bfloat16_t fast_f32tob16(float f) { +#ifdef FAST_UNSAFE_CAST + union { + float fp32; + unsigned int u32; + } u = {f}; + u.u32 += 0x7FFF + ((u.u32 >> 16) & 1); + auto ret = u.u32 >> 16; + return reinterpret_cast(ret); +#else + return static_cast(f); +#endif +} + + +__device__ FORCE_INLINE inline float fast_b16tof32(bfloat16_t bf) { +#ifdef FAST_UNSAFE_CAST + union { + float fp32; + unsigned int u32; + } u; + u.u32 = (reinterpret_cast(bf)) << 16; + return u.fp32; +#else + return static_cast(bf); +#endif +} + + + +__device__ inline void block_sync_lds() { + __builtin_amdgcn_s_waitcnt(0xc07f); + __builtin_amdgcn_s_barrier(); +} + +template +__device__ __forceinline__ void compute_tile_indices( + int tile_id, + int &tile_m_id, + int &tile_n_id +) { + // Swizzle pattern for better L2 cache locality + // Groups tiles in blocks of GROUP_SIZE_M x num_tile_n + constexpr int num_pid_in_group = GROUP_SIZE_M * num_tile_n; + + // Which group does this tile belong to? + const int group_id = tile_id / num_pid_in_group; + + // First M-dimension tile in this group + const int first_pid_m = group_id * GROUP_SIZE_M; + + // Actual group size (handling boundary case) + const int group_size_m = min(GROUP_SIZE_M, num_tile_m - first_pid_m); + + // Position within the group + const int idx_in_group = tile_id % num_pid_in_group; + + // Swizzled tile indices: alternate M then N within group + tile_m_id = first_pid_m + (idx_in_group % group_size_m); + tile_n_id = idx_in_group / group_size_m; +} + +using signal_t = int[128][128]; + +template +__launch_bounds__(NUM_THREADS) +__global__ void gemm_kernel( + const bfloat16_t *x, // M x K + const bfloat16_t *w, // N x K + const bfloat16_t *b, // N + bfloat16_t *c, // M x N + signal_t *signal +) { + const int pid0 = blockIdx.x; + constexpr int GEMM_SMS = NUM_SMS; + constexpr int NUM_XCDS = 8; + constexpr int num_tile_m = ceil_div(M, BM); + constexpr int num_tile_n = ceil_div(N, BN); + constexpr int num_tile_k = ceil_div(K, BK); + constexpr int num_tiles = num_tile_m * num_tile_n; +#ifdef SWIZZLE_XCD_PID + const int pid = (pid0 % NUM_XCDS) * (NUM_GEMM_SMS / NUM_XCDS) + (pid0 / NUM_XCDS); +#else + const int pid = pid0; +#endif + using inst_nums = roc_isa::InstCalculator; + const int tid = threadIdx.x; + const int lane_id = __lane_id(); + __builtin_assume(pid >= 0 && pid < NUM_SMS); + __builtin_assume(tid >= 0 && tid < NUM_THREADS); + __builtin_assume(lane_id >= 0 && lane_id < 64); + // each thread load 4 elements + static_assert(BK % 4 == 0 && NUM_THREADS * 4 % BK == 0); + constexpr int WM = 16, WN = 16, WK = 16; + + constexpr int Frag_M = exact_div(); + constexpr int Frag_N = exact_div(); + constexpr int Frag_K = exact_div(); + const int warp_id = __builtin_amdgcn_readfirstlane(tid / roc_isa::AMDGCN_WAVEFRONT_SIZE); + const int warp_m = warp_id / WARP_N; + const int warp_n = warp_id % WARP_N; + using FragX = bf16x4_t; + using FragW = bf16x4_t; + using FragC = fp32x4_t; + __shared__ bfloat16_t s_x[BM][BK]; + __shared__ bfloat16_t s_w[BN][BK]; + bf16x4_t vgpr_x[ceil_div(BM * BK, NUM_THREADS * 4)]; + 
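The compute_tile_indices helper above (and compute_tile_indices_m_grouped in perf_gemm.cc) remaps the linear tile id so that GROUP_SIZE_M consecutive M-tiles are walked before advancing to the next N-tile, which keeps the set of A row bands and B column tiles touched by nearby workgroups small enough to be reused from L2. A plain-Python sketch of the same arithmetic, with illustrative tile counts rather than a real problem size from the dispatch table:

```python
def tile_indices(tile_id: int, num_tile_m: int, num_tile_n: int, group_size_m: int):
    num_pid_in_group = group_size_m * num_tile_n
    group_id = tile_id // num_pid_in_group
    first_pid_m = group_id * group_size_m
    # last group may be shorter when num_tile_m is not a multiple of the group size
    group_size = min(group_size_m, num_tile_m - first_pid_m)
    idx_in_group = tile_id % num_pid_in_group
    tile_m = first_pid_m + idx_in_group % group_size
    tile_n = idx_in_group // group_size
    return tile_m, tile_n


num_tile_m, num_tile_n, group_size_m = 8, 6, 4
order = [tile_indices(t, num_tile_m, num_tile_n, group_size_m)
         for t in range(num_tile_m * num_tile_n)]
# Consecutive tile ids walk M 0..3 for N=0, then M 0..3 for N=1, ..., so a
# group of workgroups keeps reusing the same 4 A row bands and one B column
# tile at a time instead of streaming the whole B matrix per row of tiles.
print(order[:8])   # [(0, 0), (1, 0), (2, 0), (3, 0), (0, 1), (1, 1), (2, 1), (3, 1)]
```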
bf16x4_t vgpr_w[ceil_div(BN * BK, NUM_THREADS * 4)]; + + FragC frag_c[Frag_M][Frag_N]; + FragX frag_x[Frag_M][Frag_K]; + FragW frag_w[Frag_N][Frag_K]; + + auto load_vgpr = [&](int m, int n, int k) FORCE_INLINE { + auto x_arr = ck::make_wave_buffer_resource(const_cast(x), M * K); + auto w_arr = ck::make_wave_buffer_resource(const_cast(w), N * K); + int v_offset = ((tid * 4 / BK) * K + (tid * 4 % BK)) * sizeof(bfloat16_t); + uint32_t src_addr_shift = (K % BK == 0) || (k + tid * 4 % BK < K) ? 0 : 0x80000000; + ck::static_for<0, sizeof(vgpr_x) / sizeof(vgpr_x[0]), 1>{}([&](auto t) { + int s_offset = ((m * K + k) + t * NUM_THREADS * 4 / BK * K) * sizeof(bfloat16_t); + vgpr_x[t] = __builtin_bit_cast(bf16x4_t, ck::amd_buffer_load_impl_raw( + x_arr, v_offset + src_addr_shift, s_offset)); + }); + ck::static_for<0, sizeof(vgpr_w) / sizeof(vgpr_w[0]), 1>{}([&](auto t) { + int s_offset = ((n * K + k) + t * NUM_THREADS * 4 / BK * K) * sizeof(bfloat16_t); + vgpr_w[t] = __builtin_bit_cast(bf16x4_t, ck::amd_buffer_load_impl_raw( + w_arr, v_offset + src_addr_shift, s_offset)); + }); + + }; + + auto load_lds = [&]() FORCE_INLINE { + // diagonal swizzle, shape=[16, 64] dtype=bfloat16 + #pragma unroll + for (int t=0;t(&s_x[row0 + row1][col1]) = vgpr_x[t]; + } + #pragma unroll + for (int t=0;t(&s_w[row0 + row1][col1]) = vgpr_w[t]; + } + }; + + auto zero_all_frags = [&]() FORCE_INLINE { + ck::static_for<0, Frag_M, 1>{}([&](auto i) { + ck::static_for<0, Frag_N, 1>{}([&](auto j) { + ck::static_for<0, 4, 1>{}([&](auto t) { frag_c[i][j].x[t] = 0; }); + }); + }); + }; + + + auto frags_load = [&]() FORCE_INLINE { + ck::static_for<0, Frag_K, 1>{}([&](auto k) { + ck::static_for<0, Frag_M, 1>{}([&](auto i) { + const int row1 = (warp_m * Frag_M + i) * WM; + const int row0 = lane_id % 16; + const int col0 = k * 16 + lane_id / 16 * 4; + const int col1 = (row0 * 4 + col0) % BK; + frag_x[i][k] = *reinterpret_cast(&s_x[row0 + row1][col1]); + }); + }); + ck::static_for<0, Frag_K, 1>{}([&](auto k){ + ck::static_for<0, Frag_N, 1>{}([&](auto j) { + const int row1 = (warp_n * Frag_N + j) * WN; + const int row0 = lane_id % 16; + const int col0 = k * 16 + lane_id / 16 * 4; + const int col1 = (row0 * 4 + col0) % BK; + frag_w[j][k] = *reinterpret_cast(&s_w[row0 + row1][col1]); + }); + }); + }; + + auto frags_mfma = [&]() FORCE_INLINE { + ck::static_for<0, Frag_M, 1>{}([&](auto i) { + ck::static_for<0, Frag_N, 1>{}([&](auto j) { + ck::static_for<0, Frag_K, 1>{}([&](auto k) { + // a: [16][16], b: [16][16], c: [16][16] + // mfma requires a: row-major, b: col-major, out: col-major + // so we compute w^T * x^T = c^T so we can treat out as col-major + frag_c[i][j].pack = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(frag_w[j][k].pack, frag_x[i][k].pack, frag_c[i][j].pack, 0, 0, 0); + }); + }); + }); + }; + + + + auto store_frags = [&](int m, int n) FORCE_INLINE { + auto b_arr = ck::make_wave_buffer_resource(const_cast(b), N); + auto c_arr = ck::make_wave_buffer_resource(c, M * N); + fp32x4_t c_out[Frag_M][Frag_N]; + ck::static_for<0, Frag_M, 1>{}([&](auto i) { + ck::static_for<0, Frag_N, 1>{}([&](auto j) { + ck::static_for<0, 4, 1>{}([&](auto t) { + // v_accvgpr_read_b32 + c_out[i][j].x[t] = frag_c[i][j].x[t]; + }); + // c_out: [16][16] + int row = lane_id % 16; + int col = lane_id / 16 * 4; + uint32_t src_addr_shift = (N % BN == 0) || (n + (j + warp_n * Frag_N) * WN + col < N) ? 
0 : 0x80000000; + // load b + int b_s_offset = (n + (j + warp_n * Frag_N) * WN) * sizeof(bfloat16_t); + int b_v_offset = col * sizeof(bfloat16_t) + src_addr_shift; + auto b_vec = __builtin_bit_cast(bf16x4_t, ck::amd_buffer_load_impl_raw( + b_arr, b_v_offset, b_s_offset)); + // compute c + bf16x4_t c_out_bf16; + #pragma unroll + for (int t = 0; t < 4; ++t) { + c_out_bf16.x[t] = fast_f32tob16(c_out[i][j].x[t] + b_vec.x[t]); + } + // write c + int c_s_offset = b_s_offset + (m + (i + warp_m * Frag_M) * WM) * N * sizeof(bfloat16_t); + int c_v_offset = b_v_offset + (row * N) * sizeof(bfloat16_t); + ck::amd_buffer_store_impl_raw(c_out_bf16.pack, c_arr, c_v_offset, c_s_offset); + }); + }); + }; + + auto wait_signal = [&](int tile_m_id, int tile_k_id) { + if (!signal) { + return; + } + + constexpr int M_LOCAL = M / 8; + constexpr int COMM_M = M_LOCAL < 256 ? M_LOCAL : 256; + constexpr int COMM_K = 512; + + auto signal_k_id = tile_k_id / exact_div(); + + static_assert(M >= BM && M % BM == 0); + int signal_m_id_begin, signal_m_id_end; + if constexpr (COMM_M < BM) { + signal_m_id_begin = tile_m_id * exact_div(); + signal_m_id_end = signal_m_id_begin + exact_div(); + } else { + signal_m_id_begin = tile_m_id / exact_div(); + signal_m_id_end = signal_m_id_begin + 1; + } + + auto &sig = *signal; + + if(warp_id == 0 && lane_id == 0) { + for (int signal_m_id = signal_m_id_begin; signal_m_id < signal_m_id_end; ++signal_m_id) { + while(!__hip_atomic_load(&sig[signal_m_id][signal_k_id], __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM)) + ; + } + // __builtin_amdgcn_fence(__ATOMIC_ACQUIRE, ""); + } + // __builtin_amdgcn_s_barrier(); + __syncthreads(); + __builtin_amdgcn_fence(__ATOMIC_ACQUIRE, ""); + }; + + + for (int tile_id=pid; tile_id(tile_id, tile_m_id, tile_n_id); +#else + int tile_m_id = tile_id / num_tile_n; + int tile_n_id = tile_id % num_tile_n; +#endif + int m = tile_m_id * BM; + int n = tile_n_id * BN; + wait_signal(tile_m_id, 0); + load_vgpr(m, n, 0); // GDS -> VGPR #0 + load_lds(); // VGPR -> LDS #0 + load_vgpr(m, n, 1 * BK); // GDS -> VGPR #1 + zero_all_frags(); + block_sync_lds(); + frags_load(); // LDS -> FRAG #0 + __builtin_amdgcn_sched_barrier(0); + // #pragma clang loop unroll_count(2) + // #pragma unroll 2 + // #pragma unroll + for (int tile_k_id = 1; tile_k_id < (num_tile_k - 1); ++tile_k_id) { + // asm volatile(R"( + // ; Main Loop Begin + // )" ::: "memory"); + block_sync_lds(); + // Stage 1 + load_lds(); // VGPR -> LDS #1 + if ((tile_k_id + 1) % 8 == 0) { + wait_signal(tile_m_id, tile_k_id + 1); + } + load_vgpr(m, n, (tile_k_id + 1) * BK); // GDS -> VGPR #2(k+1) + frags_mfma(); // MFMA #0(k-1) + block_sync_lds(); + // Stage 2 + frags_load(); // LDS -> FRAG #1(k) + inst_nums::schedule_loop(); + __builtin_amdgcn_sched_barrier(0); + // asm volatile(R"( + // ; Main Loop End + // )" ::: "memory"); + } + frags_mfma(); // MFMA #1(n-2) + block_sync_lds(); + load_lds(); // VGPR -> LDS #2(n-1) + block_sync_lds(); + frags_load(); // LDS -> FRAG #2(n-1) + frags_mfma(); // MFMA #2(n-1) + store_frags(m, n); + } +} diff --git a/dist-infer/ag-gemm/reference.py b/dist-infer/ag-gemm/reference.py new file mode 100644 index 0000000..63f8f1e --- /dev/null +++ b/dist-infer/ag-gemm/reference.py @@ -0,0 +1,70 @@ +from task import input_t, output_t +import torch + + +def generate_input(rank: int, world_size: int, m: int, n: int, k: int, has_bias: bool, seed: int) -> input_t: + """ + Generate random input and weights for the Allgather-Gemm operation. 
+ + Returns: + Tuple of ( + input: torch.Tensor, # [local_M, k] + weight: torch.Tensor, # [local_N, K] + bias: Optional[torch.Tensor], # [local_N] or None + ) + """ + device = torch.device(f"cuda:{rank}") + gen = torch.Generator(device=device) + gen.manual_seed(seed + rank) + + assert m % world_size == 0, "m must be divisible by world_size" + assert n % world_size == 0, "n must be divisible by world_size" + local_m = m // world_size + local_n = n // world_size + + # Generate random inputs and weights + input = (torch.rand((local_m, k), dtype=torch.bfloat16, device=device, generator=gen) * 2 - 1) * 0.01 + weight = (torch.rand((local_n, k), dtype=torch.bfloat16, device=device, generator=gen) * 2 - 1) * 0.01 + + bias = None + if has_bias: + bias = (torch.rand((local_n,), dtype=torch.bfloat16, device=device, generator=gen) * 2 - 1) * 0.01 + return (input, weight, bias) + + +def ref_kernel(data: input_t) -> output_t: + """ + Reference kernel for AG-GEMM operation. + Args: + data: Tuple of (input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]) + - input: Local input tensor of shape [local_M, K]. + - weight: Weight tensor of shape [local_N, K]. + - bias: Optional bias tensor of shape [local_N] or None. + Returns: + output: Resulting tensor of shape [local_M * world_size, local_N]. + """ + input, weight, bias = data + local_M, K = input.shape + world_size = torch.distributed.get_world_size() + full_input = torch.empty((local_M * world_size, K), dtype=input.dtype, device=input.device) + # allgather + torch.distributed.all_gather_into_tensor(full_input, input) + # matmul + output = torch.matmul(full_input, weight.T) + + if bias is not None: + output = output + bias + + return output + +custom_kernel = ref_kernel + +def check_implementation(data: input_t, output: output_t): + expected = ref_kernel(data) + if output.device != expected.device: + return False, f"Output device mismatch: {output.device} != {expected.device}" + res = torch.allclose(output, expected, rtol=1e-2, atol=1e-2) + if not res: + return False, f"Output values mismatch, {output} != {expected}" + + return True, "" diff --git a/dist-infer/ag-gemm/submit.py b/dist-infer/ag-gemm/submit.py new file mode 100644 index 0000000..d0f333f --- /dev/null +++ b/dist-infer/ag-gemm/submit.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python3 +import sys +import os +import subprocess +import datetime +import time +import re +import math + +def add_file(filename, template, placeholder): + with open(filename) as f: + code = f.read() + return template.replace("{{" + placeholder + "}}", code.replace("\\", "@")) + +def gen_submission(): + with open("template.py") as f: + template = f.read() + template = add_file("perf_gemm.cc", template, "") + template = add_file("ck_gemm.h", template, "CK") + template = add_file("perf_gemm.h", template, "PERF") + return template + +def extract_and_geom_mean(text: str): + pattern = re.compile(r'(\d+(?:\.\d+)?)\s*±\s*\d+(?:\.\d+)?') + values = [float(m.group(1)) for m in pattern.finditer(text)] + + if not values: + return None, [] + + log_sum = sum(math.log(v) for v in values) + geom_mean = math.exp(log_sum / len(values)) + return geom_mean, values + +def main(): + if len(sys.argv) > 1 and sys.argv[1] != "local_test": + pyfile = sys.argv[1] + else: + pyfile = "submission.py" + code = gen_submission() + if "local_test" not in sys.argv: + code = code.replace("#define LOCAL_TEST", "") + with open(pyfile, "w") as f: + f.write(code) + if "local_test" in sys.argv: + return + + timestamp = 
datetime.datetime.now().strftime("%m-%d-%H:%M:%S-%f") + logfile = f"logs/ag-gemm-{timestamp}.log" + + os.makedirs("logs", exist_ok=True) + + print(f"submiting, log file: {logfile}") + + cmd = [ + "popcorn-cli", "submit", + "--gpu", "MI300x8", + "--leaderboard", "amd-ag-gemm", + "--mode", "leaderboard", + pyfile, + "-o", logfile, + ] + + start = time.time() + + timeout = 1800 + try: + subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, timeout=timeout, check=True) + except subprocess.TimeoutExpired: + print(f"Error: Command timed out after {timeout}s", file=sys.stderr) + sys.exit(1) + except subprocess.CalledProcessError as e: + print(f"Error: Command failed with exit code {e.returncode}", file=sys.stderr) + sys.exit(e.returncode) + + with open(logfile) as f: + output = f.read() + geom_mean, values = extract_and_geom_mean(output) + print(output) + + extra_log = "\n" + if geom_mean: + extra_log += f"geom_mean: {geom_mean:.2f}, values: {values}\n" + + print(extra_log, end="") + with open(logfile, "a") as f: + f.write(extra_log) + + end = time.time() + print(f"submit done, time cost: {end - start:.2f}s") + +if __name__ == "__main__": + main() + diff --git a/dist-infer/ag-gemm/task.py b/dist-infer/ag-gemm/task.py new file mode 100644 index 0000000..014c974 --- /dev/null +++ b/dist-infer/ag-gemm/task.py @@ -0,0 +1,14 @@ +from typing import TypedDict, TypeVar, Tuple, Optional +import torch + +input_t = TypeVar("input_t", bound=Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]) +output_t = TypeVar("output_t", bound=torch.Tensor) + + +class TestSpec(TypedDict): + world_size: int + m: int + n: int + k: int + has_bias: bool + seed: int \ No newline at end of file diff --git a/dist-infer/ag-gemm/template.py b/dist-infer/ag-gemm/template.py new file mode 100644 index 0000000..49e189e --- /dev/null +++ b/dist-infer/ag-gemm/template.py @@ -0,0 +1,372 @@ +import torch +import torch.distributed as dist +from torch.utils.cpp_extension import load + +from task import input_t, output_t + +CUDA_SRC = r""" +{{}} +""" + +CK_GEMM = r""" +{{CK}} +""" + +PERF_GEMM = r""" +{{PERF}} +""" + +import sys +import os +import time +from filelock import FileLock +from contextlib import contextmanager +import functools + +os.environ.update( + { + "CXX": "clang++", + "PYTORCH_ROCM_ARCH": "gfx942", + "HSA_XNACK": "0", + # "NCCL_DEBUG": "WARNING", + } +) + +# don't overwrite existing source file to avoid recompile among multiple ranks +lock_path = "ag_gemm-compile.lock" +with FileLock(lock_path): + with open("ag_gemm.cu", "w") as f: + f.write(CUDA_SRC.replace("@", chr(92))) + if not os.path.exists("ck_gemm.h"): + with open("ck_gemm.h", "w") as f: + f.write(CK_GEMM.replace("@", chr(92))) + if not os.path.exists("perf_gemm.h"): + with open("perf_gemm.h", "w") as f: + f.write(PERF_GEMM.replace("@", chr(92))) + os.makedirs("torch-build", exist_ok=True) + module = load( + name="ag_gemm", + sources=["ag_gemm.cu"], + build_directory="torch-build", + verbose=False, + extra_cuda_cflags=["--offload-arch=gfx942", "-std=c++20", "-O2"], + extra_cflags=["-O2"], + ) + + +def print0(out: str, all=False): + rank = dist.get_rank() + if rank == 0 or all: + print(f"[rank {rank}] {out}", file=sys.stderr) + + +def barrier(): + dist.barrier() + torch.cuda.current_stream().synchronize() + + +comm = None +should_udpate_comm = True +comm_stream = None + +orignal_init_pg = dist.init_process_group + + +def hooked_init_pg(*args, **kwargs): + global should_update_comm + should_update_comm = True + ret = orignal_init_pg(*args, 
**kwargs) + # print0(f"init pg: {args}, {kwargs}", True) + return ret + + +dist.init_process_group = hooked_init_pg + + +def all_get_comm(rank, world_size, m, n, k): + global comm, should_update_comm, comm_stream + config = (rank, m, n, k) + if should_update_comm: + should_update_comm = False + # clean up old comm + del comm + # always set device first to avoid using wrong gpu + torch.cuda.set_device(rank) + # create a new comm + print0(f"create new comm: {config}", True) + comm_stream = torch.cuda.Stream() + comm = module.AgGemm(*config) + ipc_handle = comm.get_ipc_handle() + ipc_handles = [None] * world_size + dist.all_gather_object(ipc_handles, ipc_handle) + comm.init(ipc_handles) + barrier() + return comm + + +from reference import ref_kernel + + +# 1e-2 1e-2 +def diff_allclose(ref, other, rtol, atol, max_print=10): + diff = torch.abs(ref - other) + mask = diff > (atol + rtol * torch.abs(ref)) + if mask.any(): + idx = mask.nonzero(as_tuple=False) + print0(f"{idx.shape[0]} elements mismatch", True) + for i in range(min(max_print, idx.shape[0])): + coord = tuple(idx[i].tolist()) + print0( + f" coord={coord}: ref={ref[coord].item()}, other={other[coord].item()}", + True, + ) + + +def all_assert(exp: bool, err: str): + gathered = [False] * dist.get_world_size() + dist.all_gather_object(gathered, exp) + assert all(gathered), err + + +@contextmanager +def host_timer(): + end = None + + def wait_for_time(): + if end is None: + return 0.0 + return (end - start) * 1000.0 + + try: + start = time.perf_counter() + yield wait_for_time + finally: + end = time.perf_counter() + + +def report_host_time(name=""): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + with host_timer() as t: + result = func(*args, **kwargs) + print0(f"{name}: {t():.3f}ms", True) + return result + + return wrapper + + return decorator + + +@contextmanager +def timer(): + end = None + + def wait_for_time(): + if end is None: + return 0.0 + end.synchronize() + return start.elapsed_time(end) + + try: + start = torch.cuda.Event(enable_timing=True) + start.record() + yield wait_for_time + finally: + end = torch.cuda.Event(enable_timing=True) + end.record() + + +def custom_kernel_test(data: input_t) -> output_t: + input, weight, bias = data + rank = dist.get_rank() + tp = dist.get_world_size() + m_local, k = input.shape + m = m_local * tp + n_local, k = weight.shape + n = n_local * tp + comm = all_get_comm(rank, tp, m, n, k) + + comm_stream = torch.cuda.Stream() + with torch.cuda.stream(comm_stream): + with timer() as t_comm: + comm.send(input) + chunk_size = 512 + n_chunks = (k + chunk_size - 1) // chunk_size + # clear signals + for i in range(n_chunks): + x_full = comm.wait(i) + comm_ms = t_comm() + print0(f"{comm_ms=:.3f}") + + ref_x_full = torch.empty((m, k), device="cuda", dtype=torch.bfloat16) + dist.all_gather_into_tensor(ref_x_full, input) + diff_allclose(ref_x_full, x_full, 1e-2, 1e-2) + # x_full = ref_x_full + + output = torch.matmul(x_full, weight.T) + + if bias is not None: + output = output + bias + return output + +from collections import defaultdict +logged = defaultdict(int) +def custom_kernel_sync(data: input_t) -> output_t: + input, weight, bias = data + rank = dist.get_rank() + tp = dist.get_world_size() + m_local, k = input.shape + m = m_local * tp + n, k = weight.shape + + comm = all_get_comm(rank, tp, m, n, k) + comm.send(input) + x_full = comm.get_x_full() + output = comm.perf_gemm(x_full, weight, bias) + with timer() as t_comm: + comm.send(input) + x_full = comm.get_x_full() + 
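all_get_comm above performs a one-time IPC rendezvous: each rank exports its local buffers as an opaque handle blob, the blobs are exchanged with dist.all_gather_object, and every rank then opens its peers' handles. A hedged sketch of that pattern follows; FakeComm is a stand-in so the snippet does not need the compiled ag_gemm module, and it assumes a torch.distributed process group is already initialized.

```python
import torch
import torch.distributed as dist


class FakeComm:
    """Stand-in for module.AgGemm, for illustration only."""

    def get_ipc_handle(self) -> bytes:
        # The real AgGemm returns the concatenated hipIpcMemHandle_t blobs
        # for its signal and cache allocations; here it is just a placeholder.
        return bytes(128)

    def init(self, handles):
        # The real implementation calls hipIpcOpenMemHandle on each peer blob.
        assert len(handles) == dist.get_world_size()


def exchange_ipc_handles(comm):
    world_size = dist.get_world_size()
    local_handle = comm.get_ipc_handle()       # allocate + export local buffers
    handles = [None] * world_size
    # all_gather_object is used because the handles are opaque Python bytes,
    # not tensors; afterwards every rank holds the full, rank-ordered list.
    dist.all_gather_object(handles, local_handle)
    comm.init(handles)                         # open peers' memory mappings
    dist.barrier()
    torch.cuda.current_stream().synchronize()
```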
with timer() as t_gemm: + output = comm.perf_gemm(x_full, weight, bias) + x_full_cache = x_full.clone() + with timer() as t_gemm_cache: + output = comm.perf_gemm(x_full_cache, weight, bias) + if logged[(m, n, k)] < 3: + print0(f"{t_comm()=:.3f} {t_gemm()=:.3f} {t_gemm_cache()=:.3f}") + logged[(m, n, k)] += 1 + return output + +unique_tests = { + (64, 2880, 2880), + (64, 14336, 3584), + (512, 14336, 3584), + (512, 36864, 4608), + (2048, 7168, 4096), + (2048, 30720, 8192), + (4096, 2880, 2880), + (4096, 2048, 8192), + (8192, 14336, 3584), + (8192, 36864, 4608), + (8192, 28672, 8192), +} + +def custom_kernel_bench(data: input_t) -> output_t: + input, weight, bias = data + rank = dist.get_rank() + tp = dist.get_world_size() + m_local, k = input.shape + m = m_local * tp + n, k = weight.shape + + if (m, n * tp, k) in unique_tests: + return ref_kernel(data) + + comm = all_get_comm(rank, tp, m, n, k) + + output = comm.perf_gemm(input, weight, bias) + + # ref_x_full = torch.empty(m, k, device=input.device, dtype=input.dtype) + # dist.all_gather_into_tensor(ref_x_full, input) + # ref_output = torch.matmul(ref_x_full, weight.T) + bias + # diff_allclose(ref_output, output, 1e-2, 1e-2) + + return output + + +def custom_kernel_repeat(data: input_t, fun=custom_kernel_bench) -> output_t: + for _ in range(100): + ret = fun(data) + return ret + + +def timeit(fun, repeat=1, is_dist=True): + fun() # warmup + if is_dist: + barrier() + with timer() as t: + for _ in range(repeat): + fun() + if is_dist: + barrier() + return t() / repeat + + +def micro_benchmark(m: int, n: int, k: int): + rank = dist.get_rank() + tp = dist.get_world_size() + device = torch.device("cuda", rank) + dst_device = torch.device("cuda", (rank + 1) % tp) + x = torch.randn((m // tp, k), device=device, dtype=torch.bfloat16) + dst_x = torch.randn_like(x, device=dst_device) + w = torch.randn((n // tp, k), device=device, dtype=torch.bfloat16) + x_full = torch.empty((m, k), device=device, dtype=torch.bfloat16) + out = torch.empty((m, n // tp), device=device, dtype=torch.bfloat16) + ref_out = torch.empty((m, n // tp), device=device, dtype=torch.bfloat16) + + global should_update_comm + should_update_comm = True + comm = all_get_comm(rank, tp, m, n // tp, k) + + print0(f"{(m, n, k)=}") + + # make sure torch current stream is correct + torch.cuda.set_device(device) + + p2p_ce_ms = timeit(lambda: x.copy_(dst_x)) + bw_gb = (x.nbytes / (1 << 30)) / (p2p_ce_ms / 1e3) + print0(f" p2p-ce: {bw_gb=:.3f} {p2p_ce_ms=:.3f}") + + ag_ms = timeit(lambda: dist.all_gather_into_tensor(x_full, x)) + bw_gb = (x.nbytes / (1 << 30)) / (ag_ms / 1e3) + print0(f" ag: {bw_gb=:.3f} {ag_ms=:.3f}") + + my_ag_ms = timeit(lambda: comm.send(x, sync=False)) + bw_gb = (x.nbytes / (1 << 30)) / (my_ag_ms / 1e3) + print0(f" my-ag: {bw_gb=:.3f} {my_ag_ms=:.3f}") + my_x_full = comm.get_x_full() + # diff_allclose(x_full, my_x_full, 1e-2, 1e-2) + + if rank == 0: + with timer() as t_no_contention: + for i in range(100): + comm.send(x, sync=False) + no_contention_ms = t_no_contention() / 100 + print0(f" my-ag-no-contention: {no_contention_ms:.3f}") + + gemm_ms = timeit(lambda: torch.matmul(x_full, w.T)) + tflops = (2 * m * (n / tp) * k / 1e12) / (gemm_ms / 1e3) + print0(f" gemm: {tflops=:.1f} {gemm_ms=:.3f}") + + +def hw_benchmark(): + rank = dist.get_rank() + tp = dist.get_world_size() + device = torch.device("cuda", rank) + dst_device = torch.device("cuda", (rank + 1) % tp) + print0(f"hw bench") + if rank == 0: + x = torch.randn(1 << 30, device=device, dtype=torch.bfloat16) + dst_x = 
torch.randn_like(x, device=dst_device) + p2p_ms = timeit(lambda: x.copy_(dst_x), is_dist=False) + bw_gb = (x.nbytes / (1 << 30)) / (p2p_ms / 1e3) + print0(f" p2p-ce: {bw_gb=:.3f} {p2p_ms=:.3f}") + barrier() + + +should_run_mb = True + + +def empty_kernel(data: input_t) -> output_t: + global should_run_mb + if should_run_mb: + should_run_mb = False + hw_benchmark() + micro_benchmark(8192, 29568, 8192) + micro_benchmark(8192, 14336, 4096) + # micro_benchmark(64, 18432, 7168) + + return ref_kernel(data) + + +custom_kernel = custom_kernel_bench diff --git a/dist-infer/ag-gemm/utils.py b/dist-infer/ag-gemm/utils.py new file mode 100644 index 0000000..396c6bf --- /dev/null +++ b/dist-infer/ag-gemm/utils.py @@ -0,0 +1,176 @@ +import random +from typing import Tuple + +import numpy as np +import torch + + +def set_seed(seed=42): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_device(use_cuda: bool = True) -> torch.device: + """Get the appropriate device (GPU or CPU).""" + if use_cuda: + if torch.cuda.is_available(): + return torch.device("cuda") + elif torch.backends.mps.is_available(): + return torch.device("mps") + else: + print("No compatible GPU found. Falling back to CPU.") + return torch.device("cpu") + + +# Adapted from https://github.com/linkedin/Liger-Kernel/blob/main/test/utils.py +@torch.no_grad() +def verbose_allclose( + received: torch.Tensor, expected: torch.Tensor, rtol=1e-05, atol=1e-08, max_print=5 +) -> Tuple[bool, list[str]]: + """ + Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches. + + Parameters: + received (torch.Tensor): Tensor we actually got. + expected (torch.Tensor): Tensor we expected to receive. + rtol (float): Relative tolerance; relative to expected + atol (float): Absolute tolerance. + max_print (int): Maximum number of mismatched elements to print. + """ + # Check if the shapes of the tensors match + if received.shape != expected.shape: + return False, ["SIZE MISMATCH"] + + # Calculate the difference between the tensors + diff = torch.abs(received.to(torch.float32) - expected.to(torch.float32)) + + # Determine the tolerance + tolerance = atol + rtol * torch.abs(expected) + + # Find tolerance mismatched elements + tol_mismatched = diff > tolerance + + # Find nan mismatched elements + nan_mismatched = torch.logical_xor(torch.isnan(received), torch.isnan(expected)) + + # Find +inf mismatched elements + posinf_mismatched = torch.logical_xor( + torch.isposinf(received), torch.isposinf(expected) + ) + # Find -inf mismatched elements + neginf_mismatched = torch.logical_xor( + torch.isneginf(received), torch.isneginf(expected) + ) + + # Find all mismatched elements + mismatched = torch.logical_or( + torch.logical_or(tol_mismatched, nan_mismatched), + torch.logical_or(posinf_mismatched, neginf_mismatched), + ) + + mismatched_indices = torch.nonzero(mismatched) + + # Count the number of mismatched elements + num_mismatched = mismatched.count_nonzero().item() + + # Generate detailed information if there are mismatches + if num_mismatched >= 1: + mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] + + for index in mismatched_indices[:max_print]: + i = tuple(index.tolist()) + mismatch_details.append(f"ERROR at {i}: {received[i]} {expected[i]}") + if num_mismatched > max_print: + mismatch_details.append( + f"... 
and {num_mismatched - max_print} more mismatched elements." + ) + return False, mismatch_details + + return True, [f"Maximum error: {torch.max(diff)}"] + + +@torch.no_grad() +def verbose_allequal( + received: torch.Tensor, expected: torch.Tensor, max_print: int = 5 +) -> Tuple[bool, list[str]]: + """ + Assert that two tensors are element-wise perfectly equal, providing detailed information about mismatches. + + Parameters: + received (torch.Tensor): Tensor we actually got. + expected (torch.Tensor): Tensor we expected to receive. + max_print (int): Maximum number of mismatched elements to print. + + Returns: + Empty string if tensors are equal, otherwise detailed error information + """ + mismatched = torch.not_equal(received, expected) + mismatched_indices = torch.nonzero(mismatched) + + # Count the number of mismatched elements + num_mismatched = mismatched.count_nonzero().item() + + # Generate detailed information if there are mismatches + if num_mismatched >= 1: + mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] + + for index in mismatched_indices[:max_print]: + i = tuple(index.tolist()) + mismatch_details.append(f"ERROR at {i}: {received[i]} {expected[i]}") + if num_mismatched > max_print: + mismatch_details.append( + f"... and {num_mismatched - max_print} more mismatched elements." + ) + return False, mismatch_details + + return True, [] + + +def match_reference(data, output, reference: callable, rtol=1e-05, atol=1e-08): + """ + Convenient "default" implementation for tasks' `check_implementation` function. + """ + expected = reference(data) + good, reasons = verbose_allclose(output, expected, rtol=rtol, atol=atol) + + if len(reasons) > 0: + return good, "\\n".join(reasons) + + return good, "" + + +def make_match_reference(reference: callable, **kwargs): + def wrapped(data, output): + return match_reference(data, output, reference=reference, **kwargs) + + return wrapped + + +class DisableCuDNNTF32: + def __init__(self): + self.allow_tf32 = torch.backends.cudnn.allow_tf32 + self.deterministic = torch.backends.cudnn.deterministic + pass + + def __enter__(self): + torch.backends.cudnn.allow_tf32 = False + torch.backends.cudnn.deterministic = True + return self + + def __exit__(self, exc_type, exc_value, traceback): + torch.backends.cudnn.allow_tf32 = self.allow_tf32 + torch.backends.cudnn.deterministic = self.deterministic + + +def clear_l2_cache(): + # import cupy as cp + # cp.cuda.runtime.deviceSetLimit(cp.cuda.runtime.cudaLimitPersistingL2CacheSize, 0) + # create a large dummy tensor + dummy = torch.empty((32, 1024, 1024), dtype=torch.int64, device="cuda") + # write stuff to + dummy.fill_(42) + del dummy diff --git a/dist-infer/all2all/CMakeLists.txt b/dist-infer/all2all/CMakeLists.txt new file mode 100644 index 0000000..788be5c --- /dev/null +++ b/dist-infer/all2all/CMakeLists.txt @@ -0,0 +1,20 @@ +cmake_minimum_required(VERSION 3.21) +cmake_policy(VERSION 3.21.3...3.27) +set(PROJECT_NAME "all2all") +project(${PROJECT_NAME} LANGUAGES HIP CXX) + +find_package(Python3 REQUIRED COMPONENTS Development Interpreter) +find_package(Torch CONFIG REQUIRED) +find_package(HIP CONFIG REQUIRED) + +# required for python binding +find_library(TORCH_PYTHON_LIBRARY torch_python PATH ${TORCH_INSTALL_PREFIX}/lib) + +add_library(${PROJECT_NAME} SHARED all2all.cpp) +set_source_files_properties(all2all.cpp PROPERTIES LANGUAGE HIP) +target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_PYTHON_LIBRARY} ${TORCH_LIBRARIES} hip::device Python3::Python) 
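+# empty PREFIX drops the default "lib" prefix so the built module is all2all.so, importable from Python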
+set_target_properties(${PROJECT_NAME} PROPERTIES PREFIX "") +target_compile_features(${PROJECT_NAME} PRIVATE cxx_std_20) + +# for host compile +target_compile_definitions(${PROJECT_NAME} PRIVATE -D__${CMAKE_HIP_ARCHITECTURES}__) diff --git a/dist-infer/all2all/Makefile b/dist-infer/all2all/Makefile new file mode 100644 index 0000000..068ebaa --- /dev/null +++ b/dist-infer/all2all/Makefile @@ -0,0 +1,39 @@ +.PHONY: all config build + +TARGET ?= gfx942 + +BUILD_TYPE ?= RelWithDebInfo +BUILD_DIR ?= build + +PYTHON_DIR ?= $(shell python -c "import site; print(site.getsitepackages()[0])") + +all: config build submit + +config: + PYTORCH_ROCM_ARCH=$(TARGET) cmake -B $(BUILD_DIR) . \ + -DCMAKE_BUILD_TYPE=$(BUILD_TYPE) \ + -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \ + -DCMAKE_PREFIX_PATH="/opt/rocm;$(PYTHON_DIR)" \ + -DCMAKE_HIP_ARCHITECTURES=$(TARGET) \ + -DGPU_TARGETS=$(TARGET) \ + -DAMDGPU_TARGETS=$(TARGET) \ + -G Ninja + +build: + cmake --build $(BUILD_DIR) -j8 + +test: build + PYTHONPATH=$(PYTHONPATH):$(realpath tools):$(realpath $(BUILD_DIR)) python tools/smoke_test.py + +clean: + rm -r $(BUILD_DIR) + +local: + python submit.py local_test + POPCORN_GPUS=2 POPCORN_FD=2 python eval.py benchmark benchmark.txt + +submit: + python submit.py + +dis: + roc-obj -d build/all2all.so diff --git a/dist-infer/all2all/all2all.cpp b/dist-infer/all2all/all2all.cpp new file mode 100644 index 0000000..a740c8f --- /dev/null +++ b/dist-infer/all2all/all2all.cpp @@ -0,0 +1,692 @@ +#undef __HIP_NO_HALF_OPERATORS__ +#undef __HIP_NO_HALF_CONVERSIONS__ +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "rocwmma/rocwmma.hpp" +#include "rocwmma/rocwmma_coop.hpp" + +#define LOCAL_TEST + +namespace mma = rocwmma; +using f16 = mma::float16_t; +using f32 = mma::float32_t; +using i32 = mma::int32_t; +using i64 = mma::int64_t; +using f8 = mma::float8_fnuz_t; + +#define USE_DBG 0 +#define USE_ASSERT 0 + +#define PRAGMA_UNROLL _Pragma("unroll") + +#define ASSERT(cond) \ + do { \ + if (USE_ASSERT && !(cond)) { \ + __assert_fail(#cond, __FILE__, __LINE__, __PRETTY_FUNCTION__); \ + } \ + } while (0) + +#define DBG(fmt, ...) \ + do { \ + if (USE_DBG && threadIdx.x % 64 == 0) { \ + printf(fmt "\n", ##__VA_ARGS__); \ + } \ + } while (0) + +#define HOST_DBG(fmt, ...) 
\ + do { \ + fprintf(stderr, fmt "\n", ##__VA_ARGS__); \ + } while (0) + +#define HIP_CHECK(call) \ + do { \ + hipError_t err = (call); \ + if (err != hipSuccess) { \ + fprintf( \ + stderr, "HIP error: %s (%d)\n at %s:%d\n", \ + hipGetErrorString(err), err, __FILE__, __LINE__ \ + ); \ + } \ + } while (0) + + +#ifdef LOCAL_TEST +constexpr i32 WORLD_SIZE = 2; +constexpr i32 MAX_NUM_EXPERTS = 64; +#else +constexpr i32 WORLD_SIZE = 8; +constexpr i32 MAX_NUM_EXPERTS = 256; +#endif + +constexpr i32 MAX_TOPK = 8; +constexpr i32 MAX_HIDDEN_DIM = 7168; +constexpr i32 MAX_MAX_NUM_TOKENS = 256; + +constexpr i32 MAX_NUM_LOCAL_EXPERTS = MAX_NUM_EXPERTS / WORLD_SIZE; + +constexpr i32 WARP_SIZE = 64; + +// NOTE: 64, 1024 both result in poor performance +constexpr i32 BLOCK_SIZE = 256; +constexpr i32 NUM_SMS = 304; + +template constexpr T ceil_div(T a, T b) { return (a + b - 1) / b; } + +template struct vec_t { + using type = __attribute__((__vector_size__(N))) T; + static_assert(N % sizeof(T) == 0); + constexpr static i32 nelem = N / sizeof(T); + constexpr static i32 nelem_per_warp = WARP_SIZE * nelem; + + static __device__ void copy(T *dst, const T *src) { + auto val = + __builtin_nontemporal_load(reinterpret_cast(src)); + __builtin_nontemporal_store(val, reinterpret_cast(dst)); + } + + template + __device__ static inline void warp_mul(T *dst, const T *src, f16 weight) { + static_assert(N_ELEM % nelem_per_warp == 0); + const auto lane_id = threadIdx.x % WARP_SIZE; + PRAGMA_UNROLL + for (int i = 0; i < N_ELEM / nelem_per_warp; i++) { + auto src_ptr = reinterpret_cast( + src + i * nelem_per_warp + lane_id * nelem + ); + auto val = __builtin_nontemporal_load(src_ptr); + PRAGMA_UNROLL + for (int j = 0; j < nelem; j++) { + val[j] *= weight; + } + auto dst_ptr = reinterpret_cast( + dst + i * nelem_per_warp + lane_id * nelem + ); + __builtin_nontemporal_store(val, dst_ptr); + } + } + + template using accum_type = type[N_ELEM / (WARP_SIZE * nelem)]; + + template + __device__ static inline void + warp_accum(accum_type &acc, const T *src, f32 weight) { + static_assert(N_ELEM % (WARP_SIZE * nelem) == 0); + const auto lane_id = threadIdx.x % WARP_SIZE; + PRAGMA_UNROLL + for (int i = 0; i < N_ELEM / (WARP_SIZE * nelem); i++) { + auto ptr = reinterpret_cast( + src + i * (WARP_SIZE * nelem) + lane_id * nelem + ); + auto val = __builtin_nontemporal_load(ptr); + PRAGMA_UNROLL + for (int j = 0; j < nelem; j++) { + // order required to maintain precision + acc[i][j] += val[j] * weight; + } + } + } + + template + __device__ static inline void + warp_accum_store(T *dst, accum_type &acc) { + static_assert(N_ELEM % (WARP_SIZE * nelem) == 0); + const auto lane_id = threadIdx.x % WARP_SIZE; + PRAGMA_UNROLL + for (int i = 0; i < N_ELEM / (WARP_SIZE * nelem); i++) { + auto ptr = reinterpret_cast( + dst + i * (WARP_SIZE * nelem) + lane_id * nelem + ); + __builtin_nontemporal_store(acc[i], ptr); + } + } +}; + +struct dispatch_args_t { + i32 num_tokens; + // [num_tokens, num_experts] + i32 *topk_idx; + // [num_tokens, hidden_dim] + f16 *x; +}; + +struct combine_args_t { + i32 num_tokens; + // [num_tokens, num_topk] + i32 *topk_idx; + // [num_tokens, num_topk] + f32 *topk_weight; +}; + +struct workspace_t { + i32 grid_barrier; + i32 barrier_flag; + + // [max_num_tokens, topk] + i32 nvl_dst_idxs[MAX_MAX_NUM_TOKENS][MAX_TOPK]; + + // [max_num_tokens, hidden_dim] + f16 nvl_y[MAX_MAX_NUM_TOKENS][MAX_HIDDEN_DIM]; +}; + +// global variables +struct ipc_mem_t { + // [num_local_experts, max_num_tokens, hidden_dim] + f16 
nvl_recv_x[MAX_NUM_LOCAL_EXPERTS][WORLD_SIZE * MAX_MAX_NUM_TOKENS] + [MAX_HIDDEN_DIM]; + // [num_local_experts] + i32 nvl_recv_count[MAX_NUM_LOCAL_EXPERTS]; + + // we need 2 barriers to do ping-pong + // so that next barrier would not overwrite previous one before all ranks + // pass the previous barrier + i32 nvl_barrier[2][WORLD_SIZE]; + + i32 nvl_signal[WORLD_SIZE]; + i32 barrier; +}; +struct global_t { + // config + i32 rank; + i32 num_experts; + i32 topk; + i32 hidden_dim; + i32 max_num_tokens; + + // buffers + ipc_mem_t *ipc_mems[WORLD_SIZE] = {}; + workspace_t *workspace; + + i32 next_signal; +}; + +template __device__ inline void st_relaxed_sys(T *ptr, T val) { + __hip_atomic_store(ptr, val, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM); +} + +template __device__ inline T ld_relaxed_sys(T *ptr) { + return __hip_atomic_load(ptr, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM); +} + +template __device__ inline void st_release_global(T *ptr, T val) { + __hip_atomic_store(ptr, val, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_AGENT); +} + +template __device__ inline T ld_acquire_global(T *ptr) { + return __hip_atomic_load(ptr, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_AGENT); +} + +__device__ __forceinline__ void syncwarp() { + __builtin_amdgcn_fence(__ATOMIC_RELEASE, "wavefront"); + __builtin_amdgcn_wave_barrier(); + __builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "wavefront"); +} + +// timeout 500 ms +__device__ inline void barrier( + const global_t &global, const i32 lane_id, const i32 warp_id, + const i32 global_warp_id, const i32 num_sms, const i32 barrier_idx, + const i64 timeout = 1 << 30 +) { + const auto rank = global.rank; + + // 0 is not a valid flag + constexpr i32 barrier_flag = 1; + + __syncthreads(); + + if (warp_id == 0 && lane_id == 0) { + // FIXME: look like relaxed is fine + __hip_atomic_fetch_add( + &global.workspace->grid_barrier, 1, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT + ); + } + // TODO: use the last atomic add warp instead of first global warp + static_assert(WORLD_SIZE < WARP_SIZE); + if (global_warp_id == 0) { + if (lane_id < WORLD_SIZE) { + auto start = clock64(); + // TODO: maybe relaxed + while (ld_acquire_global(&global.workspace->grid_barrier) != num_sms + ) { + if (clock64() - start > timeout) { + DBG("grid barrier timeout, expected %d, actual %d", num_sms, + global.workspace->grid_barrier); + } + } + // here we must use acquire/release to sync multiple l2s + st_release_global(&global.workspace->grid_barrier, 0); + // reset ping-pong nvl barrier flag + st_relaxed_sys( + &global.ipc_mems[rank]->nvl_barrier[!barrier_idx][lane_id], 0 + ); + // notify other ranks + st_relaxed_sys( + &global.ipc_mems[lane_id]->nvl_barrier[barrier_idx][rank], + barrier_flag + ); + } + } + // first warp of each SM waits + if (warp_id == 0 && lane_id < WORLD_SIZE) { + auto start = clock64(); + while (ld_relaxed_sys( + &global.ipc_mems[rank]->nvl_barrier[barrier_idx][lane_id] + ) != barrier_flag) { + if (clock64() - start > timeout) { + DBG("nvl barrier timeout, expected %d, actual %d", barrier_flag, + global.ipc_mems[rank]->nvl_barrier[barrier_idx][lane_id]); + } + } + } + __syncthreads(); +} + +__global__ __launch_bounds__(BLOCK_SIZE +) void dispatch_kernel(const dispatch_args_t args, const global_t global) { + const auto num_sms = NUM_SMS; + const auto num_warps = BLOCK_SIZE / WARP_SIZE; + const auto num_global_warps = num_sms * num_warps; + + const auto sm_id = blockIdx.x; + const auto warp_id = __builtin_amdgcn_readfirstlane(threadIdx.x / WARP_SIZE); + const auto global_warp_id = sm_id * 
num_warps + warp_id; + const auto lane_id = threadIdx.x % WARP_SIZE; + + const auto rank = global.rank; + const auto hidden_dim = global.hidden_dim; + const auto num_topk = global.topk; + + const auto num_local_experts = global.num_experts / WORLD_SIZE; + const auto num_experts = global.num_experts; + + const auto num_tokens = args.num_tokens; + const auto x = args.x; + const auto topk_idxs = args.topk_idx; + + // step 1: send data + // TODO: split token to chunks + for (int i = global_warp_id; i < num_tokens * num_topk; + i += num_global_warps) { + const auto token_idx = i / num_topk; + const auto topk_idx = i % num_topk; + const auto expert_idx = topk_idxs[token_idx * num_topk + topk_idx]; + const auto dst_rank = expert_idx / num_local_experts; + const auto dst_local_expert_idx = expert_idx % num_local_experts; + + // atomic count + // 20 us + auto dst_token_idx = 0; + if (lane_id == 0) { + dst_token_idx = atomicAdd( + &global.ipc_mems[dst_rank] + ->nvl_recv_count[dst_local_expert_idx], + 1 + ); + global.workspace->nvl_dst_idxs[token_idx][topk_idx] = dst_token_idx; + } + dst_token_idx = __shfl(dst_token_idx, 0); + + // copy x to dst_rank's nvl_recv_x + f16 *src_token = x + token_idx * hidden_dim; + f16 *dst_token = global.ipc_mems[dst_rank] + ->nvl_recv_x[dst_local_expert_idx][dst_token_idx]; + // NOTE: ~30 us + using cp_t = vec_t<16, f16>; + ASSERT(hidden_dim % cp_t::nelem == 0); + for (int j = lane_id * cp_t::nelem; j < hidden_dim; + j += WARP_SIZE * cp_t::nelem) { + cp_t::copy(dst_token + j, src_token + j); + } + } + + // TODO: maybe use per-dst-expert fine-grained sync + + // step 2: wait warps to finish + // NOTE: 40-90 us + // barrier(global, lane_id, warp_id, global_warp_id, num_sms, 0); + + // 1us + __syncthreads(); + if (warp_id == 0 && lane_id == 0) { + // FIXME: look like relaxed is fine + __hip_atomic_fetch_add( + &global.ipc_mems[rank]->barrier, 1, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT + ); + } + + if (global_warp_id == 0 && lane_id < WORLD_SIZE) { + while (ld_relaxed_sys(&global.ipc_mems[rank]->barrier) != num_sms) + ; + st_relaxed_sys(&global.ipc_mems[lane_id]->nvl_signal[rank], global.next_signal); + st_relaxed_sys(&global.ipc_mems[rank]->barrier, 0); + __builtin_amdgcn_wave_barrier(); + while (ld_relaxed_sys(&global.ipc_mems[rank]->nvl_signal[lane_id]) != global.next_signal) + ; + } + // no further code beyond barrier +} + +template +__global__ __launch_bounds__(BLOCK_SIZE +) void combine_kernel(combine_args_t args, global_t global) { + const auto num_sms = NUM_SMS; + const auto num_warps = BLOCK_SIZE / WARP_SIZE; + const auto num_global_warps = num_sms * num_warps; + + const auto sm_id = blockIdx.x; + const auto warp_id = __builtin_amdgcn_readfirstlane(threadIdx.x / WARP_SIZE); + const auto global_warp_id = sm_id * num_warps + warp_id; + const auto lane_id = threadIdx.x % WARP_SIZE; + + const auto rank = global.rank; + const auto num_topk = global.topk; + const auto num_local_experts = global.num_experts / WORLD_SIZE; + const auto num_experts = global.num_experts; + + const auto num_tokens = args.num_tokens; + const auto topk_idxs = args.topk_idx; + const auto topk_weight = args.topk_weight; + + constexpr i32 WARP_HIDDEN_DIM = 1024; + constexpr auto NUM_TOKEN_CHUNKS = ceil_div(HIDDEN_DIM, WARP_HIDDEN_DIM); + constexpr auto LAST_CHUNK_LEN = + HIDDEN_DIM - (NUM_TOKEN_CHUNKS - 1) * WARP_HIDDEN_DIM; + + // step 0: reset before barrier to avoid got overwritten by other ranks + // this is unsafe, need a barrier after combine + if (global_warp_id == 0) { + 
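+        // one lane per local expert zeroes its receive counter so the next dispatch
+        // can count arriving tokens from scratch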
static_assert(WARP_SIZE >= MAX_NUM_LOCAL_EXPERTS); + if (lane_id < num_local_experts) { + st_relaxed_sys(&global.ipc_mems[rank]->nvl_recv_count[lane_id], 0); + } + } + + // barrier(global, lane_id, warp_id, global_warp_id, num_sms, 1); + + // step 1: send data + for (int i = global_warp_id; i < num_tokens * NUM_TOKEN_CHUNKS; + i += num_global_warps) { + const auto token_idx = i / NUM_TOKEN_CHUNKS; + const auto chunk_idx = i % NUM_TOKEN_CHUNKS; + + f16 *dst_chunk = + global.workspace->nvl_y[token_idx] + chunk_idx * WARP_HIDDEN_DIM; + + if (LAST_CHUNK_LEN < WARP_HIDDEN_DIM && + chunk_idx == NUM_TOKEN_CHUNKS - 1) { + // last unfull chunk + using cp_t = vec_t<2, f16>; + cp_t::accum_type acc = {}; + + for (int j = 0; j < num_topk; j++) { + const auto topk_idx = j; + const auto expert_idx = + topk_idxs[token_idx * num_topk + topk_idx]; + const auto dst_rank = expert_idx / num_local_experts; + const auto dst_local_expert_idx = + expert_idx % num_local_experts; + + const auto dst_token_idx = + global.workspace->nvl_dst_idxs[token_idx][topk_idx]; + + f32 weight = topk_weight[token_idx * num_topk + topk_idx]; + f16 *src_chunk = + global.ipc_mems[dst_rank] + ->nvl_recv_x[dst_local_expert_idx][dst_token_idx] + + chunk_idx * WARP_HIDDEN_DIM; + + cp_t::warp_accum(acc, src_chunk, weight * (dst_rank + 1)); + } + cp_t::warp_accum_store(dst_chunk, acc); + } else { + // full chunk + using cp_t = vec_t<16, f16>; + cp_t::accum_type acc = {}; + + for (int j = 0; j < num_topk; j++) { + const auto topk_idx = j; + const auto expert_idx = + topk_idxs[token_idx * num_topk + topk_idx]; + const auto dst_rank = expert_idx / num_local_experts; + const auto dst_local_expert_idx = + expert_idx % num_local_experts; + + const auto dst_token_idx = + global.workspace->nvl_dst_idxs[token_idx][topk_idx]; + + f32 weight = topk_weight[token_idx * num_topk + topk_idx]; + f16 *src_chunk = + global.ipc_mems[dst_rank] + ->nvl_recv_x[dst_local_expert_idx][dst_token_idx] + + chunk_idx * WARP_HIDDEN_DIM; + + cp_t::warp_accum(acc, src_chunk, weight * (dst_rank + 1)); + } + cp_t::warp_accum_store(dst_chunk, acc); + } + } +} + +template +__global__ __launch_bounds__(BLOCK_SIZE) void ffn_kernel(global_t global) { + const auto num_sms = gridDim.x; + const auto num_warps = blockDim.x / WARP_SIZE; + const i32 num_global_warps = num_sms * num_warps; + + const auto sm_id = blockIdx.x; + const auto warp_id = threadIdx.x / WARP_SIZE; + const auto global_warp_id = sm_id * num_warps + warp_id; + const auto lane_id = threadIdx.x % WARP_SIZE; + + const auto rank = global.rank; + const auto num_topk = global.topk; + const auto num_local_experts = global.num_experts / WORLD_SIZE; + const auto num_experts = global.num_experts; + + auto recv_x = global.ipc_mems[rank]->nvl_recv_x; + auto recv_count = global.ipc_mems[rank]->nvl_recv_count; + + // divide warps into num_local_experts groups, each group handles one + // local_expert_idx + const auto num_groups = num_local_experts; + const auto num_warps_per_group = ceil_div(num_global_warps, num_groups); + const i32 group_id = global_warp_id / num_warps_per_group; + const auto warp_id_per_group = global_warp_id % num_warps_per_group; + + const auto local_expert_idx = group_id; + const auto num_tokens = recv_count[local_expert_idx]; + + // consider the last group + const auto num_warps_in_group = std::min( + num_warps_per_group, num_global_warps - group_id * num_warps_per_group + ); + for (int i = warp_id_per_group; i < num_tokens; i += num_warps_in_group) { + // handle one token + using cp_t = vec_t<16, f16>; 
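+        // each warp pass scales 512 f16 elements (64 lanes x 8 halves); the tail of
+        // HIDDEN_DIM % 512 elements is handled with 2-byte vectors below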
+ constexpr auto LAST_HIDDEN_DIM = HIDDEN_DIM % cp_t::nelem_per_warp; + + auto token = recv_x[local_expert_idx][i]; + cp_t::warp_mul(token, token, rank + 1); + + // handle remaining part + if (LAST_HIDDEN_DIM > 0) { + auto token = + recv_x[local_expert_idx][i] + HIDDEN_DIM - LAST_HIDDEN_DIM; + vec_t<2, f16>::warp_mul(token, token, rank + 1); + } + } +} + +// clang-format off +#define SWITCH_HIDDEN(hidden, MACRO) \ + switch (hidden) { \ + case 2048: MACRO(2048); break; \ + case 2880: MACRO(2880); break; \ + case 4096: MACRO(4096); break; \ + case 6144: MACRO(6144); break; \ + case 7168: MACRO(7168); break; \ + } +// clang-format on + +class All2all { + private: + global_t global{}; + + public: + All2all( + int rank, int topk, int hidden_dim, int max_num_tokens, int num_experts + ) { + global.rank = rank; + global.topk = topk; + global.hidden_dim = hidden_dim; + global.max_num_tokens = max_num_tokens; + global.num_experts = num_experts; + + global.next_signal = 0; + } + + ~All2all() { + for (auto i = 0; i < WORLD_SIZE; i++) { + auto ipc_mem = global.ipc_mems[i]; + if (ipc_mem && i != global.rank) { + HIP_CHECK(hipIpcCloseMemHandle(ipc_mem)); + } + } + auto local_mem = global.ipc_mems[global.rank]; + if (local_mem) { + HIP_CHECK(hipFree(local_mem)); + } + if (global.workspace) { + HIP_CHECK(hipFree(global.workspace)); + } + } + + auto get_ipc_handle() -> pybind11::bytearray { + void *ws; + HIP_CHECK(hipMalloc(&ws, sizeof(workspace_t))); + HIP_CHECK(hipMemset(ws, 0, sizeof(workspace_t))); + global.workspace = reinterpret_cast(ws); + + void *ptr; + HIP_CHECK(hipExtMallocWithFlags( + &ptr, sizeof(ipc_mem_t), hipDeviceMallocUncached + )); + HIP_CHECK(hipMemset(ptr, 0, sizeof(ipc_mem_t))); + global.ipc_mems[global.rank] = reinterpret_cast(ptr); + + hipIpcMemHandle_t ipc_handle; + HIP_CHECK(hipIpcGetMemHandle(&ipc_handle, ptr)); + return {ipc_handle.reserved, HIP_IPC_HANDLE_SIZE}; + } + + auto init(const std::vector &ipc_handles) { + for (int i = 0; i < WORLD_SIZE; i++) { + if (i == global.rank) { + continue; + } + hipIpcMemHandle_t handle; + auto handle_buf = std::string(ipc_handles[i]); + ASSERT(handle_buf.size() == HIP_IPC_HANDLE_SIZE); + std::memcpy( + handle.reserved, handle_buf.data(), HIP_IPC_HANDLE_SIZE + ); + void *ptr; + HIP_CHECK( + hipIpcOpenMemHandle(&ptr, handle, hipIpcMemLazyEnablePeerAccess) + ); + global.ipc_mems[i] = reinterpret_cast(ptr); + } + } + + auto dispatch(torch::Tensor &topk_idx, torch::Tensor &x) { + dispatch_args_t args{ + .num_tokens = static_cast(topk_idx.size(0)), + .topk_idx = topk_idx.contiguous().data_ptr(), + // torch doesn't support data_ptr() + .x = reinterpret_cast(x.contiguous().data_ptr()), + }; + + global.next_signal++; + + auto stream = at::cuda::getCurrentHIPStream().stream(); + dim3 block(BLOCK_SIZE, 1, 1); + dim3 grid(NUM_SMS, 1, 1); + dispatch_kernel<<>>(args, global); + // HIP_CHECK(hipStreamSynchronize(global.stream)); + + auto recv_x = torch::from_blob( + global.ipc_mems[global.rank]->nvl_recv_x, + {global.num_experts / WORLD_SIZE, + global.max_num_tokens * WORLD_SIZE, global.hidden_dim}, + {WORLD_SIZE * MAX_MAX_NUM_TOKENS * MAX_HIDDEN_DIM, MAX_HIDDEN_DIM, 1 + }, + torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCUDA) + ); + auto recv_count = torch::from_blob( + global.ipc_mems[global.rank]->nvl_recv_count, + {global.num_experts / WORLD_SIZE}, + torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA) + ); + return std::pair{recv_x, recv_count}; + } + + auto combine(torch::Tensor &topk_idx, torch::Tensor &topk_weight) { + 
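+        // accumulate each token's per-expert outputs, weighted by topk_weight, into
+        // workspace->nvl_y via combine_kernel and return a tensor view over that buffer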
combine_args_t args{ + .num_tokens = static_cast(topk_idx.size(0)), + .topk_idx = topk_idx.contiguous().data_ptr(), + .topk_weight = topk_weight.contiguous().data_ptr(), + }; + + auto stream = at::cuda::getCurrentHIPStream().stream(); + dim3 block(BLOCK_SIZE, 1, 1); + dim3 grid(NUM_SMS, 1, 1); + // clang-format off + #define LAUNCH_COMBINE(hidden) combine_kernel<(hidden)><<>>(args, global) + SWITCH_HIDDEN(global.hidden_dim, LAUNCH_COMBINE) + // clang-format on + // HIP_CHECK(hipStreamSynchronize(global.stream)); + + auto y = torch::from_blob( + global.workspace->nvl_y, {global.max_num_tokens, global.hidden_dim}, + {MAX_HIDDEN_DIM, 1}, + torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCUDA) + ); + return y; + } + + // NOTE: 6-30 us + void ffn() { + auto stream = at::cuda::getCurrentHIPStream().stream(); + dim3 block(BLOCK_SIZE, 1, 1); + dim3 grid(NUM_SMS, 1, 1); + // clang-format off + #define LAUNCH_FFN(hidden) ffn_kernel<<>>(global) + SWITCH_HIDDEN(global.hidden_dim, LAUNCH_FFN) + // clang-format on + } +}; + +PYBIND11_MODULE(all2all, m) { + py::class_(m, "All2all") + .def(py::init()) + .def("get_ipc_handle", &All2all::get_ipc_handle) + .def("init", &All2all::init) + .def("dispatch", &All2all::dispatch) + .def("combine", &All2all::combine) + .def("ffn", &All2all::ffn); +} diff --git a/dist-infer/all2all/benchmark.txt b/dist-infer/all2all/benchmark.txt new file mode 100644 index 0000000..afa0d65 --- /dev/null +++ b/dist-infer/all2all/benchmark.txt @@ -0,0 +1,5 @@ +world_size: 2; num_experts: 2; experts_per_token: 2; hidden_dim: 6144; max_num_tokens: 16; seed: 6635 +world_size: 2; num_experts: 16; experts_per_token: 6; hidden_dim: 2048; max_num_tokens: 32; seed: 1234 +world_size: 2; num_experts: 32; experts_per_token: 4; hidden_dim: 2880; max_num_tokens: 128; seed: 51 +world_size: 2; num_experts: 32; experts_per_token: 8; hidden_dim: 4096; max_num_tokens: 256; seed: 175 +world_size: 2; num_experts: 64; experts_per_token: 8; hidden_dim: 7168; max_num_tokens: 256; seed: 4 \ No newline at end of file diff --git a/dist-infer/all2all/eval.py b/dist-infer/all2all/eval.py new file mode 100644 index 0000000..628e1b3 --- /dev/null +++ b/dist-infer/all2all/eval.py @@ -0,0 +1,580 @@ +import base64 +import copy +import dataclasses +import multiprocessing +import re +import time +import os +import sys +import math +from pathlib import Path +from typing import Any, Optional + +import torch.cuda + +from utils import set_seed, clear_l2_cache + +try: + from task import TestSpec +except ImportError: + TestSpec = dict + +from reference import check_implementation, generate_input + + +class PopcornOutput: + def __init__(self, fd: int): + self.file = os.fdopen(fd, 'w') + # os.set_inheritable(fd, False) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.file.close() + + def print(self, *args, **kwargs): + print(*args, **kwargs, file=self.file, flush=True) + + def log(self, key, value): + self.print(f"{key}: {value}") + + +@dataclasses.dataclass +class TestCase: + args: dict + spec: str + + +def _combine(a: int, b: int) -> int: + # combine two integers into one: + # we need this to generate a secret seed based on the test-level seed and + # the global secret seed. + # the test-level seeds are public knowledge, and typically relatively small numbers, + # so we need to make sure they don't provide any useful info for the full seed. + # This Cantor construction ensures that if the secret seed is a large number, + # then so is the overall seed. 
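+    # e.g. _combine(3, 7) = 3 + 10 * 11 // 2 = 58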
+ return int(a + (a + b) * (a + b + 1) // 2) + + +def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]: + try: + content = Path(file_name).read_text() + except Exception as E: + print(f"Could not open test file`{file_name}`: {E}", file=sys.stderr) + exit(113) + + tests = [] + lines = content.splitlines() + match = r"\s*([a-zA-Z_]+):\s*([a-zA-Z]+|[+-]?[0-9]+)\s*" + for line in lines: + parts = line.split(";") + case = {} + for part in parts: + matched = re.match(match, part) + if not re.fullmatch(match, part): + print(f"invalid test case: '{line}': '{part}'", file=sys.stderr) + exit(113) + key = matched[1] + val = matched[2] + try: + val = int(val) + except ValueError: + pass + + case[key] = val + tests.append(TestCase(spec=line, args=case)) + + if seed is not None: + for test in tests: + if "seed" in test.args: + test.args["seed"] = _combine(test.args["seed"], seed) + + return tests + + +@dataclasses.dataclass +class Stats: + runs: int + mean: float + std: float + err: float + best: float + worst: float + + +def calculate_stats(durations: list[int]): + """ + Calculate statistical data from a list of durations. + + @param durations: A list of durations in nanoseconds. + @return: A Stats object containing the number of runs, mean, standard deviation, error, best, and worst durations. + """ + runs = len(durations) + total = sum(durations) + best = min(durations) + worst = max(durations) + + avg = total / runs + variance = sum(map(lambda x: (x - avg) ** 2, durations)) + std = math.sqrt(variance / (runs - 1)) + err = std / math.sqrt(runs) + + return Stats(runs=runs, mean=avg, std=std, err=err, best=float(best), + worst=float(worst)) + + +def _clone_data(data, rank: int): + """ + Recursively goes through data and clones all tensors. + """ + if isinstance(data, tuple): + return tuple(_clone_data(x, rank) for x in data) + elif isinstance(data, list): + return [_clone_data(x, rank) for x in data] + elif isinstance(data, dict): + return {k: _clone_data(v, rank) for k, v in data.items()} + elif isinstance(data, torch.Tensor): + device = f"cuda:{rank}" + return data.clone().to(device) + else: + return data + + +def wrap_check_implementation(data, submission_output): + # Old version returned just a single string, new version + # returns (bool, str); this function ensures compatibility with old + # problem definitions. + result = check_implementation(data, submission_output) + if isinstance(result, tuple): + return result + else: + return not bool(result), result + + +def _run_single_test(test: TestCase): + """ + Runs a single test case. Do not call directly + """ + from submission import custom_kernel + data = generate_input(**test.args) + torch.cuda.synchronize() + submission_output = custom_kernel(_clone_data(data, 0)) + torch.cuda.synchronize() + return wrap_check_implementation(data, submission_output) + + +def _run_distributed_test(test: TestCase, rank: int): + """ + Runs a single test case. 
Do not call directly + """ + from submission import custom_kernel + import torch.distributed as dist + world_size = test.args["world_size"] + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "12356" + dist.init_process_group("nccl", init_method="env://", rank=rank, world_size=world_size, device_id=torch.device(f'cuda:{rank}')) + try: + data = generate_input(**test.args, rank=rank) + torch.cuda.synchronize() + submission_output = custom_kernel(_clone_data(data, rank)) + torch.cuda.synchronize() + return wrap_check_implementation(data, submission_output) + finally: + dist.destroy_process_group() + + +def run_multi_gpu_test(pool: multiprocessing.Pool, test: TestCase, world_size: int): + """ + Runs a single test in another process. + """ + rets = [] + # world_size is a mandatory argument for multi-gpu tests + for i in range(world_size): + rets.append( + pool.apply_async( + _run_distributed_test, + args=(test, i), + ) + ) + # 60 seconds should be more than enough, we want tests to be fast + rets = [el.get(60) for el in rets] + + correct = all(ret[0] for ret in rets) + error_messages = str.join("\n", [f"rank {rank} - {ret[1]}" for rank, ret in enumerate(rets) if not ret[0]]) + return correct, error_messages + + +def run_single_test(pool: multiprocessing.Pool, test: TestCase): + """ + Runs a single test in another process. + """ + world_size = test.args.get("world_size", None) + if world_size is None: + return pool.apply(_run_single_test, (test,)) + else: + return run_multi_gpu_test(pool, test, world_size) + + +def run_testing(logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase]): + """ + Executes the actual test case code and checks for correctness. + + @param logger: A PopcornOutput object used for logging test results. + @param tests: A list of TestCase objects representing the test cases to be executed. + @return: An integer representing the exit status: 0 if all tests pass, otherwise 112. + """ + passed = True + logger.log("test-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"test.{idx}.spec", test.spec) + good, message = run_single_test(pool, test) + if not good: + logger.log(f"test.{idx}.status", "fail") + logger.log(f"test.{idx}.error", message) + passed = False + else: + logger.log(f"test.{idx}.status", "pass") + if message: + logger.log(f"test.{idx}.message", message) + + if passed: + logger.log("check", "pass") + return 0 + else: + logger.log("check", "fail") + return 112 + + +def _run_single_benchmark(test: TestCase, recheck: bool, max_repeats: int, max_time_ns: float) -> Stats | Any: + """ + Runs one benchmark. Do not call directly. + """ + from submission import custom_kernel + + durations = [] + # generate input data once + data = generate_input(**test.args) + check_copy = _clone_data(data, 0) + # first, one obligatory correctness check + output = custom_kernel(data) + good, message = wrap_check_implementation(check_copy, output) + if not good: + return message + + # now, do multiple timing runs without further correctness testing + # there is an upper bound of 100 runs, and a lower bound of 3 runs; + # otherwise, we repeat until we either measure at least 10 full seconds, + # or the relative error of the mean is below 1%. 
+ + bm_start_time = time.perf_counter_ns() + for i in range(max_repeats): + if recheck: + # ensure we use a different seed for every benchmark + if "seed" in test.args: + test.args["seed"] += 13 + + data = generate_input(**test.args) + check_copy = _clone_data(data, 0) + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + clear_l2_cache() + + start_event.record() + output = custom_kernel(data) + end_event.record() + torch.cuda.synchronize() + duration = start_event.elapsed_time(end_event) * 1e6 # Convert ms to ns + + if recheck: + good, message = check_implementation(check_copy, output) + if not good: + return message + + del output + durations.append(duration) + + if i > 1: + total_bm_duration = time.perf_counter_ns() - bm_start_time + stats = calculate_stats(durations) + # stop if either + # a) relative error dips below 0.1% + # b) we exceed the total time limit for benchmarking the kernel + # c) we exceed 2 minutes of total wallclock time. + if stats.err / stats.mean < 0.001 or stats.mean * stats.runs > max_time_ns or total_bm_duration > 120e9: + break + + return calculate_stats(durations) + + +def _run_distributed_benchmark(test: TestCase, rank: int, recheck: bool, max_repeats: int, + max_time_ns: float) -> Stats | Any: + """ + Runs one distributed benchmark. Do not call directly. + """ + from submission import custom_kernel + import torch.distributed as dist + + world_size = test.args["world_size"] + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "12356" + dist.init_process_group("nccl", init_method="env://", rank=rank, world_size=world_size, device_id=torch.device(f'cuda:{rank}')) + + try: + durations = [] + # generate input data once + data = generate_input(**test.args, rank=rank) + check_copy = _clone_data(data, rank) + + # first, one obligatory correctness check + output = custom_kernel(_clone_data(data, rank)) + good, message = wrap_check_implementation(check_copy, output) + if not good: + return message + + # now, do multiple timing runs with proper distributed synchronization + bm_start_time = time.perf_counter_ns() + for i in range(max_repeats): + error_message = None + if recheck: + # ensure we use a different seed for every benchmark + if "seed" in test.args: + test.args["seed"] += 13 + + data = generate_input(**test.args, rank=rank) + check_copy = _clone_data(data, rank) + + # Synchronize all ranks before timing + clear_l2_cache() + torch.cuda.synchronize() + dist.barrier() + + # Use distributed timing - only rank 0 records the overall time + if rank == 0: + start_time = time.perf_counter_ns() + + # All ranks execute the kernel + output = custom_kernel(_clone_data(data, rank)) + + # Synchronize all ranks after kernel execution + torch.cuda.synchronize() + dist.barrier() + + if rank == 0: + end_time = time.perf_counter_ns() + duration = end_time - start_time # Already in nanoseconds + durations.append(duration) + + if recheck: + good, message = check_implementation(check_copy, output) + if not good: + error_message = message + + del output + + has_error = torch.tensor(1 if error_message is not None else 0, dtype=torch.int32, device=f'cuda:{rank}') + dist.all_reduce(has_error) + if has_error.item() > 0: + return error_message + + # Only rank 0 checks convergence criteria + if rank == 0 and i > 1: + total_bm_duration = time.perf_counter_ns() - bm_start_time + stats = calculate_stats(durations) + # stop if either + # a) relative error dips below 0.1% + # b) we exceed the total 
time limit for benchmarking the kernel + # c) we exceed 2 minutes of total wallclock time. + should_stop = (stats.err / stats.mean < 0.001 or + stats.mean * stats.runs > max_time_ns or + total_bm_duration > 120e9) + else: + should_stop = False + + # Broadcast stop decision to all ranks + stop_tensor = torch.tensor(should_stop, dtype=torch.bool, device=f'cuda:{rank}') + dist.broadcast(stop_tensor, 0) + + if stop_tensor.item(): + break + + # Only rank 0 returns meaningful stats + if rank == 0: + # print(durations) + return calculate_stats(durations) + else: + # Non-zero ranks return a dummy stats object + return Stats(runs=len(durations), mean=0.0, std=0.0, err=0.0, best=0.0, worst=0.0) + + finally: + dist.destroy_process_group() + + +def run_multi_gpu_benchmark(pool: multiprocessing.Pool, test: TestCase, recheck: bool, max_repeats: int, + max_time_ns: float, world_size: int): + """ + Runs a multi-GPU benchmark across all ranks. + """ + rets = [] + for i in range(world_size): + rets.append( + pool.apply_async( + _run_distributed_benchmark, + args=(test, i, recheck, max_repeats, max_time_ns), + ) + ) + + # 120 seconds for benchmarking + we run a pre-benchmark test and want to leave some slack + rets = [el.get(timeout=180) for el in rets] + + # For multi-GPU benchmarking, only rank 0 has meaningful stats + failed_ranks = [] + rank_0_result = None + + for rank, ret in enumerate(rets): + if isinstance(ret, Stats): + if rank == 0: + rank_0_result = ret + else: + # ret is an error message + failed_ranks.append((rank, ret)) + + if failed_ranks: + error_messages = str.join("\n", [f"rank {rank} - {msg}" for rank, msg in failed_ranks]) + return error_messages + else: + return rank_0_result if rank_0_result else "No stats returned from rank 0" + + +def run_single_benchmark(pool: multiprocessing.Pool, test: TestCase, recheck: bool, max_repeats: int, + max_time_ns: float): + """ + For a particular test case, check correctness (if applicable) and grab runtime results. + + @param pool: Process on which the benchmark will be launched. + @param test: TestCase object. + @param recheck: Flag for whether to explicitly check functional correctness. + @param max_repeats: Number of trials to repeat. + @param max_time_ns: Timeout time in nanoseconds. + @return: A Stats object for this particular benchmark case or an error if the test fails. + """ + + world_size: Optional[int] = test.args.get("world_size", None) + if world_size is None: + return pool.apply(_run_single_benchmark, (test, recheck, max_repeats, max_time_ns)) + else: + return run_multi_gpu_benchmark(pool, test, recheck, max_repeats, max_time_ns, world_size) + + +def run_benchmarking(logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase]): + """ + Executes benchmarking code for a CUDA Kernel and logs runtimes. + + @param logger: A PopcornOutput object used for logging benchmark results. + @param pool: Process on which the benchmarks will be launched. + @param tests: A list of TestCase objects representing the test cases to be benchmarked. + @return: An integer representing the exit status: 0 if all benchmarks pass, otherwise 112. 
+ """ + # warm up + run_single_benchmark(pool, tests[0], False, 100, 10e7) + + passed = True + logger.log("benchmark-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"benchmark.{idx}.spec", test.spec) + result = run_single_benchmark(pool, test, False, 100, 10e9) + if isinstance(result, Stats): + for field in dataclasses.fields(Stats): + logger.log(f"benchmark.{idx}.{field.name}", getattr(result, field.name)) + else: + passed = False + logger.log(f"benchmark.{idx}.status", "fail") + logger.log(f"benchmark.{idx}.error", result) + + if passed: + logger.log("check", "pass") + return 0 + else: + logger.log("check", "fail") + return 112 + + +def run_single_profile(test: TestCase) -> str: + """ + Runs a single test case. Do not call directly + """ + from submission import custom_kernel + from torch.profiler import profile, record_function, ProfilerActivity + data = generate_input(**test.args) + torch.cuda.synchronize() + + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: + submission_output = custom_kernel(_clone_data(data, 0)) + torch.cuda.synchronize() + return prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20) + + +def run_profiling(logger: PopcornOutput, tests: list[TestCase]): + logger.log("benchmark-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"benchmark.{idx}.spec", test.spec) + report = run_single_profile(test) + logger.log(f"benchmark.{idx}.report", base64.b64encode(report.encode("utf-8"), b"+*").decode("utf-8")) + logger.log("check", "pass") + return 0 + + +def main(): + fd = os.getenv("POPCORN_FD") + if not fd: + return 111 + + if len(sys.argv) < 3: + return 2 + + mode = sys.argv[1] + seed = os.getenv("POPCORN_SEED") + os.unsetenv("POPCORN_SEED") + n_gpus = int(os.getenv("POPCORN_GPUS", "1")) + seed = int(seed) if seed else None + set_seed(seed or 42) + tests = get_test_cases(sys.argv[2], seed) + + with PopcornOutput(int(fd)) as logger: + import multiprocessing + mp_context = multiprocessing.get_context('spawn') + with mp_context.Pool(n_gpus) as pool: + if mode == "test": + return run_testing(logger, pool, tests) + if mode == "benchmark": + return run_benchmarking(logger, pool, tests) + + if mode == "leaderboard": + # warmup + run_single_benchmark(pool, tests[0], False, 100, 1e7) + logger.log("benchmark-count", len(tests)) + passed = True + for i in range(len(tests)): + result = run_single_benchmark(pool, tests[i], True, 100, 30e9) + logger.log(f"benchmark.{i}.spec", tests[i].spec) + if isinstance(result, Stats): + for field in dataclasses.fields(Stats): + logger.log(f"benchmark.{i}.{field.name}", getattr(result, field.name)) + else: + passed = False + logger.log(f"benchmark.{i}.status", "fail") + logger.log(f"benchmark.{i}.error", str(result)) # TODO: Make sure result implements __str__? 
+ break + + logger.log("check", "pass" if passed else "fail") + elif mode == "profile": + run_profiling(logger, tests) + else: + # invalid mode + return 2 + + +if __name__ == "__main__": + multiprocessing.set_start_method('spawn') + sys.exit(main()) diff --git a/dist-infer/all2all/reference.py b/dist-infer/all2all/reference.py new file mode 100644 index 0000000..ef8a2d7 --- /dev/null +++ b/dist-infer/all2all/reference.py @@ -0,0 +1,285 @@ +# pytorch_all2all.py +import os +import torch +import torch.distributed as dist +import dataclasses +from task import input_t, output_t + + +# ---------------- MoE config ---------------- +@dataclasses.dataclass +class MoEConfig: + num_experts: int + experts_per_token: int + hidden_dim: int + max_num_tokens: int + in_dtype: torch.dtype = torch.float16 + out_dtype: torch.dtype = torch.float16 + + +# ---------------- data per dp rank ---------------- +class RankTestData: + def __init__(self, cfg: MoEConfig, rng: torch.Generator, rank: int): + device = torch.device(f"cuda:{rank}") + self.num_tokens = int( + torch.randint( + 1, cfg.max_num_tokens, [1], generator=rng, device=device + ).item() + ) + # token expert map + self.indices = torch.empty( + self.num_tokens, cfg.experts_per_token, dtype=torch.int32, device=device + ) + for i in range(self.num_tokens): + perm = torch.randperm(cfg.num_experts, generator=rng, device=device) + self.indices[i] = perm[: cfg.experts_per_token] + # topk weights + self.weights = torch.rand( + self.num_tokens, + cfg.experts_per_token, + dtype=torch.float32, + generator=rng, + device=device, + ) + # dp tokens, input of dispatch + self.x = torch.randn( + self.num_tokens, + cfg.hidden_dim, + dtype=cfg.in_dtype, + generator=rng, + device=device, + ) + + +# ---------------- All2All pytorch impl ---------------- +class PyTorchAllToAll: + META_DIM = 5 # global_exp, src_rank, src_token, src_k, pad + + def __init__(self, cfg: MoEConfig, rank: int, world_size: int): + self.cfg = cfg + self.rank = rank + self.world_size = world_size + # num experts per rank + self.num_local_experts = cfg.num_experts // world_size + # max recv tokens per rank + self.max_recv = cfg.max_num_tokens * world_size + + # ---------- dispatch ---------- + def dispatch(self, dp_x: torch.Tensor, indices: torch.Tensor): + device = dp_x.device + cfg = self.cfg + + # ---------1. get counts of send and recv for each rank ----------- + # 1.1 token nums to send to each rank + send_counts = [0] * self.world_size + # 1.2 token id to send to each rank + token_map = [[] for _ in range(self.world_size)] + # 1.3 token meta data, need update for combine + meta_map = [[] for _ in range(self.world_size)] + for t, expert_list in enumerate(indices.tolist()): + for k, e in enumerate(expert_list): + dst_rank = e // self.num_local_experts + send_counts[dst_rank] += 1 + token_map[dst_rank].append(t) + meta_map[dst_rank].extend( + [e, self.rank, t, k, 0] + ) # srcGobalExpert, srcRank, srcIndex, expert index + + send_counts_t = torch.tensor(send_counts, dtype=torch.long, device=device) + # 1.3 token nums to recv from each rank + recv_counts_t = torch.empty(self.world_size, dtype=torch.long, device=device) + dist.all_to_all_single(recv_counts_t, send_counts_t) + # ---------2. 
send and recv buffer, order by tokens on each rank ---------- + send_buf = torch.cat([dp_x[idx_list] for idx_list in token_map], dim=0) + total_recv = int(recv_counts_t.sum().item()) + recv_buf = torch.empty( + total_recv, cfg.hidden_dim, dtype=cfg.in_dtype, device=device + ) + + # 2.1 meta buf for send and recv + send_meta = torch.tensor( + [v for sub in meta_map for v in sub], dtype=torch.int32, device=device + ).view(-1, self.META_DIM) + recv_meta = torch.empty( + total_recv, self.META_DIM, dtype=torch.int32, device=device + ) + # ---------3. dispatch send_buf to recv_buf by recv and send counts-------------- + dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts_t.tolist(), + input_split_sizes=send_counts_t.tolist(), + ) + + dist.all_to_all_single( + recv_meta.view(-1), + send_meta.view(-1), + output_split_sizes=[c * self.META_DIM for c in recv_counts_t.tolist()], + input_split_sizes=[c * self.META_DIM for c in send_counts_t.tolist()], + ) + recv_meta = recv_meta.view(-1, self.META_DIM) + # ---------4. define output tensor of dispatch ------------ + # 4.1 num tokens per expert + expert_num_tokens = torch.zeros( + self.num_local_experts, dtype=torch.int32, device=device + ) + # 4.2 token tensor on each expert + expert_x = torch.empty( + (self.num_local_experts, self.max_recv, cfg.hidden_dim), + dtype=cfg.in_dtype, + device=device, + ) + expert_meta = torch.empty( + (self.num_local_experts, self.max_recv, self.META_DIM), + dtype=torch.int32, + device=device, + ) + # ---------5. dispatch send_meta to recv_meta by recv and send counts------ + # ---------6. write tokens to each expert on each rank ------ + # 6.1 fetch the local expert id of corresponding token i + for i in range(total_recv): + global_eid = int(recv_meta[i, 0].item()) + local_eid = global_eid % self.num_local_experts + # output, store token buf and token meta and token nums of each expert + expert_x[local_eid, expert_num_tokens[local_eid]] = recv_buf[i] + expert_meta[local_eid, expert_num_tokens[local_eid]] = recv_meta[i] + expert_num_tokens[local_eid] += 1 + # 6.2 after dispatch, token nums and token and meta of token on expert + return expert_num_tokens, expert_x, expert_meta + + # ---------- combine ---------- + def combine( + self, + out_tokens: torch.Tensor, # output, (max num tokens, token dim) + weights: torch.Tensor, # topk weight + expert_meta: torch.Tensor, # input + expert_y: torch.Tensor, # input, (num_local_experts, max_num_tokens * num_dp, token_dim) + expert_num_tokens: torch.Tensor, + ): # input + device = out_tokens.device + cfg = self.cfg + + # 1. count send-back tokens in cur rank + send_counts = [0] * self.world_size + # 1.1 token that will send back + y_map = [[] for _ in range(self.world_size)] + # 1.2 meta info of each token that send back to its src rank + meta_map = [[] for _ in range(self.world_size)] + + # 2. 
traverse each token of each local expert of each rank, fill into send_counts and y_map and meta_map + for local_eid in range(self.num_local_experts): + cnt = int(expert_num_tokens[local_eid].item()) + for j in range(cnt): + # meta info token j of local eid + meta = expert_meta[local_eid, j] + dst_rank = int(meta[1].item()) + send_counts[dst_rank] += 1 + # token j and its meta that send back to dst rank/local eid + y_map[dst_rank].append(expert_y[local_eid, j].unsqueeze(0)) + meta_map[dst_rank].extend(meta.tolist()) + # token nums that cur rank plan to send to other ranks + send_counts_t = torch.tensor(send_counts, dtype=torch.long, device=device) + # token nums that will recv from other ranks + recv_counts_t = torch.empty(self.world_size, dtype=torch.long, device=device) + # call all2all to send send counts and recv recv_counts_t at each rank by all2all + dist.all_to_all_single(recv_counts_t, send_counts_t) + # 3.send buffers of each rank, that is, the tokens at its experts + y_map_tensors = [] + for sub_list in y_map: + if sub_list: + y_map_tensors.append(torch.cat(sub_list, dim=0)) + else: + y_map_tensors.append( + torch.empty((0, cfg.hidden_dim), dtype=cfg.out_dtype, device=device) + ) + send_buf = torch.cat(y_map_tensors, dim=0) + # 4. flatten send meta by tokens + send_meta = torch.tensor( + [v for sub in meta_map for v in sub], dtype=torch.int32, device=device + ).view(-1, self.META_DIM) + # 5. total recv tokens of cur rank + total_recv = int(recv_counts_t.sum().item()) + # 6. recv buffer of cur rank + recv_buf = torch.empty( + total_recv, cfg.hidden_dim, dtype=cfg.out_dtype, device=device + ) + recv_meta = torch.empty( + total_recv, self.META_DIM, dtype=torch.int32, device=device + ) + # 7. call all2all to send and recv for each rank + dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_counts_t.tolist(), + input_split_sizes=send_counts_t.tolist(), + ) + # 8. call all2all to send meta and recv meta for each rank + dist.all_to_all_single( + recv_meta.view(-1), + send_meta.view(-1), + output_split_sizes=[c * self.META_DIM for c in recv_counts_t.tolist()], + input_split_sizes=[c * self.META_DIM for c in send_counts_t.tolist()], + ) + # 9. restore recv meta + recv_meta = recv_meta.view(-1, self.META_DIM) + + # 10. 
write back tokens from recv buf, per meta info, and do weighted sum + for i in range(total_recv): + src_token = int(recv_meta[i, 2].item()) + src_k = int(recv_meta[i, 3].item()) + src_rank = int(recv_meta[i, 1].item()) + w = weights[src_token, src_k].to(torch.float32) + out_tokens[src_token] += recv_buf[i].to(torch.float32) * w + + return out_tokens + + +def generate_input( + num_experts, experts_per_token, hidden_dim, max_num_tokens, seed, rank, world_size +): + device = torch.device(f"cuda:{rank}") + gen = torch.Generator(device=device) + gen.manual_seed(seed + rank) + + cfg = MoEConfig( + num_experts=num_experts, + experts_per_token=experts_per_token, + hidden_dim=hidden_dim, + max_num_tokens=max_num_tokens, + in_dtype=torch.float16, + out_dtype=torch.float16, + ) + rank_data = RankTestData(cfg, gen, rank) + return cfg, rank_data, rank, world_size + + +def ref_kernel(data: input_t) -> output_t: + cfg, rank_data, rank, world_size = data + + ata = PyTorchAllToAll(cfg, rank, world_size) + + expert_num, expert_x, expert_meta = ata.dispatch(rank_data.x, rank_data.indices) + expert_y = expert_x.to(cfg.out_dtype) * (1 + rank) + y = torch.zeros( + cfg.max_num_tokens, + cfg.hidden_dim, + dtype=cfg.out_dtype, + device=rank_data.x.device, + ) + + ata.combine(y, rank_data.weights, expert_meta, expert_y, expert_num) + + return y[: rank_data.num_tokens] + +custom_kernel = ref_kernel + +def check_implementation(data: input_t, output: output_t): + expected = ref_kernel(data) + if output.device != expected.device: + return False, f"Output device mismatch: {output.device} != {expected.device}" + res = torch.allclose(output, expected, rtol=1e-2, atol=5e-3) + if not res: + return False, f"Output values mismatch, {output} != {expected}" + + return True, "" + diff --git a/dist-infer/all2all/submit.py b/dist-infer/all2all/submit.py new file mode 100644 index 0000000..f495b99 --- /dev/null +++ b/dist-infer/all2all/submit.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python3 +import sys +import os +import subprocess +import datetime +import time +import re +import math + +def gen_submission(): + with open("template.py") as f: + template = f.read() + with open("all2all.cpp") as f: + code = f.read() + ret = template.replace("{{}}", code) + ret = ret.replace("\\", "@") + return ret + +def extract_and_geom_mean(text: str): + pattern = re.compile(r'(\d+(?:\.\d+)?)\s*±\s*\d+(?:\.\d+)?') + values = [float(m.group(1)) for m in pattern.finditer(text)] + + if not values: + return None, [] + + log_sum = sum(math.log(v) for v in values) + geom_mean = math.exp(log_sum / len(values)) + return geom_mean, values + +def main(): + if len(sys.argv) > 1 and sys.argv[1] != "local_test": + pyfile = sys.argv[1] + else: + pyfile = "submission.py" + code = gen_submission() + if "local_test" not in sys.argv: + code = code.replace("#define LOCAL_TEST", "") + with open(pyfile, "w") as f: + f.write(code) + if "local_test" in sys.argv: + return + + timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M:%S-%f") + logfile = f"logs/all2all-{timestamp}.log" + + os.makedirs("logs", exist_ok=True) + + print(f"submiting, log file: {logfile}") + + cmd = [ + "popcorn-cli", "submit", + "--gpu", "MI300x8", + "--leaderboard", "amd-all2all", + "--mode", "benchmark", + pyfile, + "-o", logfile, + ] + + start = time.time() + + timeout = 180 + try: + subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, timeout=timeout, check=True) + except subprocess.TimeoutExpired: + print(f"Error: Command timed out after {timeout}s", file=sys.stderr) + 
sys.exit(1) + except subprocess.CalledProcessError as e: + print(f"Error: Command failed with exit code {e.returncode}", file=sys.stderr) + sys.exit(e.returncode) + + with open(logfile) as f: + output = f.read() + geom_mean, values = extract_and_geom_mean(output) + print(output) + + extra_log = "\n" + if geom_mean: + extra_log += f"geom_mean: {geom_mean:.2f}, values: {values}\n" + + print(extra_log, end="") + with open(logfile, "a") as f: + f.write(extra_log) + + end = time.time() + print(f"submit done, time cost: {end - start:.2f}s") + +if __name__ == "__main__": + main() + diff --git a/dist-infer/all2all/task.py b/dist-infer/all2all/task.py new file mode 100644 index 0000000..ee2016a --- /dev/null +++ b/dist-infer/all2all/task.py @@ -0,0 +1,17 @@ +import torch +from typing import TypeVar, TypedDict, TYPE_CHECKING + +if TYPE_CHECKING: + from reference import MoEConfig, RankTestData + + +input_t = TypeVar("input_t", bound=tuple["MoEConfig", "RankTestData", int, int]) +output_t = TypeVar("output_t", bound=torch.Tensor) + + +class TestSpec(TypedDict): + num_experts: int + experts_per_token: int + hidden_dim: int + max_num_tokens: int + seed: int diff --git a/dist-infer/all2all/template.py b/dist-infer/all2all/template.py new file mode 100644 index 0000000..14bb524 --- /dev/null +++ b/dist-infer/all2all/template.py @@ -0,0 +1,335 @@ +import torch +import torch.distributed as dist +from torch.utils.cpp_extension import load + +from task import input_t, output_t + +CUDA_SRC = r""" +{{}} +""" + +import sys +import os +import time +from filelock import FileLock +from contextlib import contextmanager +import functools + +os.environ.update( + { + "CXX": "clang++", + "PYTORCH_ROCM_ARCH": "gfx942", + # "NCCL_P2P_DISABLE": "1", + # "NCCL_IB_DISABLE": "1", + # "NCCL_SHM_DISABLE": "1", + # "NCCL_DEBUG": "WARNING", + } +) + +# don't overwrite existing source file to avoid recompile among multiple ranks +lock_path = "all2all-compile.lock" +with FileLock(lock_path): + with open("all2all.cu", "w") as f: + f.write(CUDA_SRC.replace("@", chr(92))) + os.makedirs("torch-build", exist_ok=True) + module = load( + name="all2all", + sources=["all2all.cu"], + build_directory="torch-build", + verbose=False, + extra_cuda_cflags=["--offload-arch=gfx942", "-std=c++20", "-O2"], + extra_cflags=["-O2"], + ) + + +def print0(out: str, all=False): + rank = dist.get_rank() + if rank == 0 or all: + print(f"[rank {rank}] {out}", file=sys.stderr) + + +def barrier(): + dist.barrier() + torch.cuda.current_stream().synchronize() + + +comm = None +should_udpate_comm = True + +orignal_init_pg = dist.init_process_group + + +def hooked_init_pg(*args, **kwargs): + global should_update_comm + should_update_comm = True + ret = orignal_init_pg(*args, **kwargs) + # print0(f"init pg: {args}, {kwargs}", True) + return ret + + +dist.init_process_group = hooked_init_pg + + +def all_get_comm(rank, world_size, cfg): + global comm, should_update_comm + config = ( + rank, + cfg.experts_per_token, + cfg.hidden_dim, + cfg.max_num_tokens, + cfg.num_experts, + ) + if should_update_comm: + should_update_comm = False + # clean up old comm + del comm + # always set device first to avoid using wrong gpu + torch.cuda.set_device(dist.get_rank()) + # create a new comm + print0(f"create new comm: {config}", True) + comm = module.All2all(*config) + ipc_handle = comm.get_ipc_handle() + ipc_handles = [None] * world_size + dist.all_gather_object(ipc_handles, ipc_handle) + comm.init(ipc_handles) + barrier() + return comm + + +from reference import PyTorchAllToAll, 
ref_kernel + + +def diff_allclose(ref, other, rtol, atol, max_print=10): + diff = torch.abs(ref - other) + mask = diff > (atol + rtol * torch.abs(ref)) + if mask.any(): + idx = mask.nonzero(as_tuple=False) + print0(f"{idx.shape[0]} elements mismatch", True) + for i in range(min(max_print, idx.shape[0])): + coord = tuple(idx[i].tolist()) + print0( + f" coord={coord}: ref={ref[coord].item()}, other={other[coord].item()}", + True, + ) + + +def all_assert(exp: bool, err: str): + gathered = [False] * dist.get_world_size() + dist.all_gather_object(gathered, exp) + assert all(gathered), err + + +def sort_recv_x(recv_x: torch.Tensor, n_elem: int): + out = recv_x + # alphabetical sort + for d in reversed(range(recv_x.shape[2])): + indices = out[:, :, d].argsort(dim=-1, stable=True) + indices_exp = indices.unsqueeze(-1).expand(*out.shape) + out = out.gather(dim=1, index=indices_exp) + return out[:, :, :n_elem] + + +def check_dispatch(recv_x: torch.Tensor, ref: torch.Tensor, ref_recv_cnt: torch.Tensor): + recv_x = recv_x.clone() + ref = ref.clone() + + all_assert(ref.shape == recv_x.shape, "dispatch shape mismatch") + num_local_experts, max_num_tokens, hidden_dim = recv_x.shape + max_num_tokens /= 8 + + max_token = 0 + # mask out invalid token + for i in range(num_local_experts): + num_tokens = ref_recv_cnt[i].item() + max_token = max(max_token, num_tokens) + ref[i, num_tokens:] = 0 + recv_x[i, num_tokens:] = 0 + ref = ref[:, :max_token] + recv_x = recv_x[:, :max_token] + # only compare the first n_elem in each token + n_elem = hidden_dim + ref = sort_recv_x(ref, n_elem) + recv_x = sort_recv_x(recv_x, n_elem) + check = torch.allclose(ref, recv_x, rtol=1e-2, atol=5e-3) + if not check: + diff_allclose(ref, recv_x, 1e-2, 5e-3) + print0(f"{ref=}", True) + print0(f"{recv_x=}", True) + all_assert(check, "dispatch result mismatch") + + +def check_combine(y: torch.Tensor, ref_y: torch.Tensor, num_tokens: int): + y = y.clone() + ref_y = ref_y.clone() + + all_assert(y.shape == ref_y.shape, "combine shape mismatch") + _, hidden_dim = ref_y.shape + # only compare the first n_elem in each token + n_elem = hidden_dim + y = y[:num_tokens, :n_elem] + ref_y = ref_y[:num_tokens, :n_elem] + check = torch.allclose(y, ref_y, rtol=1e-2, atol=5e-3) + if not check: + diff_allclose(ref_y, y, 1e-2, 5e-3) + print0(f"{y=}", True) + print0(f"{ref_y=}", True) + all_assert(check, "combine result mismatch") + + +@contextmanager +def host_timer(): + end = None + + def wait_for_time(): + if end is None: + return 0.0 + return (end - start) * 1000.0 + + try: + start = time.time() + yield wait_for_time + finally: + end = time.time() + + +def report_host_time(name=""): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + with host_timer() as t: + result = func(*args, **kwargs) + print0(f"{name}: {t():.3f}ms", True) + return result + + return wrapper + + return decorator + + +@contextmanager +def timer(): + end = None + + def wait_for_time(): + if end is None: + return 0.0 + end.synchronize() + return start.elapsed_time(end) + + try: + start = torch.cuda.Event(enable_timing=True) + start.record() + yield wait_for_time + finally: + end = torch.cuda.Event(enable_timing=True) + end.record() + + +def custom_kernel_test(data: input_t, check=False) -> output_t: + cfg, rank_data, rank, world_size = data + # step 0: always set device first to avoid using wrong gpu + torch.cuda.set_device(rank) + # step 1: setup communicator + # rank could change in different testcase, so make sure to update rank in comm + if check: + 
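# Check mode only: build the pure-PyTorch PyTorchAllToAll reference from reference.py and
# replay dispatch / ffn / combine on it, so the HIP kernel's intermediate results can be
# compared tensor-by-tensor via check_dispatch / check_combine above. The benchmarked path
# (custom_kernel_bench) skips all of this.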
ata = PyTorchAllToAll(cfg, rank, world_size) + with host_timer() as t_comm: + comm = all_get_comm(rank, world_size, cfg) + # step 2: dispatch + with timer() as t_dispatch: + recv_x, recv_count = comm.dispatch(rank_data.indices, rank_data.x) + if check: + expert_num, ref_recv_x, expert_meta = ata.dispatch( + rank_data.x, rank_data.indices + ) + torch.cuda.synchronize() + check_dispatch(recv_x, ref_recv_x, expert_num) + # step 3: ffn + with timer() as t_ffn: + comm.ffn() + if check: + ref_recv_x = ref_recv_x.to(cfg.out_dtype) * (1 + rank) + # step 4: combine + with timer() as t_combine: + y = comm.combine(rank_data.indices, rank_data.weights) + if check: + ref_y = torch.zeros( + (cfg.max_num_tokens, cfg.hidden_dim), + dtype=cfg.out_dtype, + device=rank_data.x.device, + ) + ata.combine(ref_y, rank_data.weights, expert_meta, ref_recv_x, expert_num) + torch.cuda.synchronize() + check_combine(y, ref_y, rank_data.num_tokens) + print0( + f"comm: {t_comm():.3f}ms, dispatch: {t_dispatch():.3f}ms, ffn: {t_ffn():.3f}ms, combine: {t_combine():.3f}ms", + True, + ) + return y[: rank_data.num_tokens] + + +# @report_host_time("custom_kernel") +def custom_kernel_bench(data: input_t) -> output_t: + cfg, rank_data, rank, world_size = data + # step 1: lazy set device and setup communicator + # rank could change in different testcase, so make sure to update rank in comm + comm = all_get_comm(rank, world_size, cfg) + # step 2: dispatch + recv_x, recv_count = comm.dispatch(rank_data.indices, rank_data.x) + # step 3: ffn & combine + y = comm.combine(rank_data.indices, rank_data.weights) + return y[: rank_data.num_tokens] + + +def custom_kernel_repeat(data: input_t, fun=custom_kernel_test) -> output_t: + for _ in range(10): + fun(data) + return fun(data) + + +def empty_kernel(data: input_t) -> output_t: + cfg, rank_data, rank, world_size = data + torch.cuda.set_device(rank) + + gathered = [None] * world_size + dist.all_gather_object(gathered, should_udpate_comm) + all_assert(all(gathered) or not any(gathered), "should_update_comm mismatch") + + comm = all_get_comm(rank, world_size, cfg) + return ref_kernel(data) + + +def custom_kernel_graph(data: input_t) -> output_t: + # warmup: set device, init comm + # avoid stream synchronizing in capture + torch.cuda.set_device(dist.get_rank()) + custom_kernel_bench(data) + torch.cuda.synchronize() + # capture in correct device + s = torch.cuda.Stream(device=dist.get_rank()) + # this is required, don't know why though + with torch.cuda.stream(s): + custom_kernel_bench(data) + torch.cuda.synchronize() + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g, stream=s): + custom_kernel_bench(data) + torch.cuda.set_device(dist.get_rank()) + # warmup graph + g.replay() + torch.cuda.synchronize() + with host_timer() as t_graph: + for _ in range(1): + g.replay() + torch.cuda.synchronize() + with host_timer() as t_normal: + for _ in range(1): + ret = custom_kernel_bench(data) + torch.cuda.synchronize() + # cuda graph is slightly slower + # normal: 0.241, graph: 0.277 + print0(f"normal: {t_normal():.3f}, graph: {t_graph():.3f}") + return ret + + +custom_kernel = custom_kernel_bench diff --git a/dist-infer/all2all/utils.py b/dist-infer/all2all/utils.py new file mode 100644 index 0000000..396c6bf --- /dev/null +++ b/dist-infer/all2all/utils.py @@ -0,0 +1,176 @@ +import random +from typing import Tuple + +import numpy as np +import torch + + +def set_seed(seed=42): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + 
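# Of the two calls below, torch.cuda.manual_seed seeds only the current device while
# manual_seed_all covers every visible GPU; the latter is what keeps multi-GPU runs
# reproducible per device.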
torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_device(use_cuda: bool = True) -> torch.device: + """Get the appropriate device (GPU or CPU).""" + if use_cuda: + if torch.cuda.is_available(): + return torch.device("cuda") + elif torch.backends.mps.is_available(): + return torch.device("mps") + else: + print("No compatible GPU found. Falling back to CPU.") + return torch.device("cpu") + + +# Adapted from https://github.com/linkedin/Liger-Kernel/blob/main/test/utils.py +@torch.no_grad() +def verbose_allclose( + received: torch.Tensor, expected: torch.Tensor, rtol=1e-05, atol=1e-08, max_print=5 +) -> Tuple[bool, list[str]]: + """ + Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches. + + Parameters: + received (torch.Tensor): Tensor we actually got. + expected (torch.Tensor): Tensor we expected to receive. + rtol (float): Relative tolerance; relative to expected + atol (float): Absolute tolerance. + max_print (int): Maximum number of mismatched elements to print. + """ + # Check if the shapes of the tensors match + if received.shape != expected.shape: + return False, ["SIZE MISMATCH"] + + # Calculate the difference between the tensors + diff = torch.abs(received.to(torch.float32) - expected.to(torch.float32)) + + # Determine the tolerance + tolerance = atol + rtol * torch.abs(expected) + + # Find tolerance mismatched elements + tol_mismatched = diff > tolerance + + # Find nan mismatched elements + nan_mismatched = torch.logical_xor(torch.isnan(received), torch.isnan(expected)) + + # Find +inf mismatched elements + posinf_mismatched = torch.logical_xor( + torch.isposinf(received), torch.isposinf(expected) + ) + # Find -inf mismatched elements + neginf_mismatched = torch.logical_xor( + torch.isneginf(received), torch.isneginf(expected) + ) + + # Find all mismatched elements + mismatched = torch.logical_or( + torch.logical_or(tol_mismatched, nan_mismatched), + torch.logical_or(posinf_mismatched, neginf_mismatched), + ) + + mismatched_indices = torch.nonzero(mismatched) + + # Count the number of mismatched elements + num_mismatched = mismatched.count_nonzero().item() + + # Generate detailed information if there are mismatches + if num_mismatched >= 1: + mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] + + for index in mismatched_indices[:max_print]: + i = tuple(index.tolist()) + mismatch_details.append(f"ERROR at {i}: {received[i]} {expected[i]}") + if num_mismatched > max_print: + mismatch_details.append( + f"... and {num_mismatched - max_print} more mismatched elements." + ) + return False, mismatch_details + + return True, [f"Maximum error: {torch.max(diff)}"] + + +@torch.no_grad() +def verbose_allequal( + received: torch.Tensor, expected: torch.Tensor, max_print: int = 5 +) -> Tuple[bool, list[str]]: + """ + Assert that two tensors are element-wise perfectly equal, providing detailed information about mismatches. + + Parameters: + received (torch.Tensor): Tensor we actually got. + expected (torch.Tensor): Tensor we expected to receive. + max_print (int): Maximum number of mismatched elements to print. 
+ + Returns: + Empty string if tensors are equal, otherwise detailed error information + """ + mismatched = torch.not_equal(received, expected) + mismatched_indices = torch.nonzero(mismatched) + + # Count the number of mismatched elements + num_mismatched = mismatched.count_nonzero().item() + + # Generate detailed information if there are mismatches + if num_mismatched >= 1: + mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] + + for index in mismatched_indices[:max_print]: + i = tuple(index.tolist()) + mismatch_details.append(f"ERROR at {i}: {received[i]} {expected[i]}") + if num_mismatched > max_print: + mismatch_details.append( + f"... and {num_mismatched - max_print} more mismatched elements." + ) + return False, mismatch_details + + return True, [] + + +def match_reference(data, output, reference: callable, rtol=1e-05, atol=1e-08): + """ + Convenient "default" implementation for tasks' `check_implementation` function. + """ + expected = reference(data) + good, reasons = verbose_allclose(output, expected, rtol=rtol, atol=atol) + + if len(reasons) > 0: + return good, "\\n".join(reasons) + + return good, "" + + +def make_match_reference(reference: callable, **kwargs): + def wrapped(data, output): + return match_reference(data, output, reference=reference, **kwargs) + + return wrapped + + +class DisableCuDNNTF32: + def __init__(self): + self.allow_tf32 = torch.backends.cudnn.allow_tf32 + self.deterministic = torch.backends.cudnn.deterministic + pass + + def __enter__(self): + torch.backends.cudnn.allow_tf32 = False + torch.backends.cudnn.deterministic = True + return self + + def __exit__(self, exc_type, exc_value, traceback): + torch.backends.cudnn.allow_tf32 = self.allow_tf32 + torch.backends.cudnn.deterministic = self.deterministic + + +def clear_l2_cache(): + # import cupy as cp + # cp.cuda.runtime.deviceSetLimit(cp.cuda.runtime.cudaLimitPersistingL2CacheSize, 0) + # create a large dummy tensor + dummy = torch.empty((32, 1024, 1024), dtype=torch.int64, device="cuda") + # write stuff to + dummy.fill_(42) + del dummy diff --git a/dist-infer/compose.yml b/dist-infer/compose.yml new file mode 100644 index 0000000..639ccf1 --- /dev/null +++ b/dist-infer/compose.yml @@ -0,0 +1,11 @@ +name: default +services: + torch: + image: rocm/pytorch:rocm6.3.1_ubuntu22.04_py3.10_pytorch + volumes: + - .:/workspace + - .vscode-server:/root/.vscode-server + - .popcorn.yaml:/root/.popcorn.yaml + init: true + tty: true + command: sleep inf diff --git a/dist-infer/gemm-rs/.gitignore b/dist-infer/gemm-rs/.gitignore new file mode 100644 index 0000000..19ddece --- /dev/null +++ b/dist-infer/gemm-rs/.gitignore @@ -0,0 +1,12 @@ +/.cache +/.vscode +/build +/submission.py +/submit.sh +/logs +/venv +# profiling results +/workloads +*.csv + +__pycache__ \ No newline at end of file diff --git a/dist-infer/gemm-rs/CMakeLists.txt b/dist-infer/gemm-rs/CMakeLists.txt new file mode 100644 index 0000000..b48fce9 --- /dev/null +++ b/dist-infer/gemm-rs/CMakeLists.txt @@ -0,0 +1,36 @@ +cmake_minimum_required(VERSION 3.21) +cmake_policy(VERSION 3.21.3...3.27) +set(PROJECT_NAME "gemm_rs") +project(${PROJECT_NAME} LANGUAGES HIP CXX) +set(CMAKE_CXX_STANDARD 20) +find_package(Python3 REQUIRED COMPONENTS Development Interpreter) +find_package(Torch CONFIG REQUIRED) +find_package(HIP CONFIG REQUIRED) + +find_library(TORCH_PYTHON_LIBRARY torch_python PATH ${TORCH_INSTALL_PREFIX}/lib) +add_library(${PROJECT_NAME} SHARED src/gemm_rs.cc src/gemm_rs_kernel.cc) +set_source_files_properties(${PROJECT_NAME} PROPERTIES 
LANGUAGE HIP) +target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_PYTHON_LIBRARY} ${TORCH_LIBRARIES} hip::device Python3::Python) +target_compile_definitions(${PROJECT_NAME} PRIVATE -D__ENABLE_LOCAL_DEBUG__) +set_target_properties(${PROJECT_NAME} PROPERTIES PREFIX "") +target_compile_features(${PROJECT_NAME} PRIVATE cxx_std_20) + +# perf gemm +add_library(perf_gemm SHARED src/perf_gemm.cc) +set_source_files_properties(perf_gemm PROPERTIES LANGUAGE HIP) +target_link_libraries(perf_gemm PRIVATE ${TORCH_PYTHON_LIBRARY} ${TORCH_LIBRARIES} hip::device Python3::Python) +target_compile_definitions(perf_gemm PRIVATE -D__ENABLE_LOCAL_DEBUG__) +set_target_properties(perf_gemm PROPERTIES PREFIX "") +target_compile_features(perf_gemm PRIVATE cxx_std_20) +target_compile_options(perf_gemm PRIVATE "--save-temps") + +# for host compile +target_compile_definitions(${PROJECT_NAME} PRIVATE -D__${CMAKE_HIP_ARCHITECTURES}__) +# for debug +target_compile_options(${PROJECT_NAME} PRIVATE "--save-temps") + +# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -gline-tables-only") +# add_executable(test_basic_gemm src/test_basic_gemm.cc) +# target_link_libraries(test_basic_gemm PRIVATE ${TORCH_LIBRARIES} hip::device composable_kernel::device_other_operations composable_kernel::device_gemm_operations) +# add_executable(test_optimized_gemm src/test_optimized_gemm.cc) +# target_link_libraries(test_optimized_gemm PRIVATE ${TORCH_LIBRARIES} hip::device composable_kernel::device_other_operations composable_kernel::device_gemm_operations) diff --git a/dist-infer/gemm-rs/README.md b/dist-infer/gemm-rs/README.md new file mode 100644 index 0000000..e2fb39b --- /dev/null +++ b/dist-infer/gemm-rs/README.md @@ -0,0 +1,18 @@ +# AMD GEMM-Rs +## Build +```bash +cd /workspace/gemm-rs +source /opt/conda/bin/activate py_3.10 +export PATH="/opt/ompi/bin:/opt/ucx/bin:/opt/cache/bin:/opt/rocm/llvm/bin:/opt/rocm/opencl/bin:/opt/rocm/hip/bin:/opt/rocm/hcc/bin:/opt/rocm/bin:/opt/conda/envs/py_3.10/bin:/opt/conda/bin:$PATH" +export CMAKE_HIP_ARCHITECTURES=gfx942 +export PYTORCH_ROCM_ARCH=gfx942 + +# cmake -B build -S . -G Ninja -DTorch_DIR=/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/share/cmake/Torch/ -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DCMAKE_HIP_ARCHITECTURES=gfx942 -DAMDGPU_TARGETS=gfx942 -DCMAKE_BUILD_TYPE=Release +cmake -B build -S . 
-G Ninja -DTorch_DIR=/usr/local/lib/python3.12/dist-packages/torch/share/cmake/Torch/ -DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DCMAKE_HIP_ARCHITECTURES=gfx942 -DAMDGPU_TARGETS=gfx942 -DCMAKE_BUILD_TYPE=Release +cmake --build build +# Local Test (single node) +python benchmark_gemm.py +python benchmark_rs.py +# Local Test (multi nodes) +python local_test.py +``` diff --git a/dist-infer/gemm-rs/benchmark_gemm.py b/dist-infer/gemm-rs/benchmark_gemm.py new file mode 100644 index 0000000..4cbd89d --- /dev/null +++ b/dist-infer/gemm-rs/benchmark_gemm.py @@ -0,0 +1,121 @@ +import sys +import os +sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)), 'build')) + +import perf_gemm + +import torch +import time + + +test_configs = [ + {"M": 8192, "N": 8192, "K": 3696, "name": "8192x8192x3696"}, #warmup GPU + {"M": 8192, "N": 3696, "K": 8192, "name": "8192x3696x8192"}, + + {"M": 64, "N": 7168, "K": 2304, "name": "64x7168x2304"}, + {"M": 512, "N": 4096, "K": 1536, "name": "512x4096x1536"}, + {"M": 2048, "N": 2880, "K": 360, "name": "2048x2880x360"}, + {"M": 8192, "N": 4096, "K": 1792, "name": "8192x4096x1792"}, + {"M": 8192, "N": 8192, "K": 3696, "name": "8192x8192x3696"}, + + {"M": 64, "N": 2304, "K": 7168, "name": "64x2304x7168"}, + {"M": 512, "N": 1536, "K": 4096, "name": "512x1536x4096"}, + {"M": 2048, "N": 360, "K": 2880, "name": "2048x360x2880"}, + {"M": 8192, "N": 1792, "K": 4096, "name": "8192x1792x4096"}, + {"M": 8192, "N": 3696, "K": 8192, "name": "8192x3696x8192"}, + + + +] + + +for config in test_configs: + M, N, K = config["M"], config["N"], config["K"] + + + mem_used = [] + def generate_data(): + torch.manual_seed(42) + x = torch.rand((M, K), dtype=torch.bfloat16, device='cuda') * 2 - 1 + x *= 0.01 + w = torch.rand((N, K), dtype=torch.bfloat16, device='cuda') * 2 - 1 + w *= 0.01 + b = torch.randn((N, ), dtype=torch.bfloat16, device='cuda') * 2 - 1 + b *= 0.01 + # b = torch.zeros((N, ), dtype=torch.bfloat16, device='cuda') + mem_used.append([x, w, b]) + return x, w, b + + x, w, b = generate_data() + + # torch.set_printoptions(threshold=torch.inf, linewidth=1000000000) + def ref_fn(x, w, b): + out = torch.matmul(x, w.T) + # mem_used.append(out) + return out + + def our_fn(x, w, b): + dummy_signal = torch.empty(x.shape[0], w.shape[0], dtype=torch.int32, device='cuda') + out = perf_gemm.launch_gemm(x, w, b, dummy_signal, 0) + # t = perf_gemm.__debug_get_workspace_tensor(M, N, 2) + # torch.cuda.synchronize() + # mem_used.append(out) + # print(out.data_ptr()) + # torch.cuda.synchronize() + # print(out.float()) + # print(t.sum(dim=0)) + # torch.set_printoptions(threshold=1000000000, linewidth=1000000000) + # print(torch.stack([t.sum(dim=0)[:100, 1], out.float()[:100, 1]])) + # print() + # torch.testing.assert_close(t.sum(dim=0), out.float(), atol=1e-1, rtol=1e-1) + # print("reduce correctness check") + # # return out + return out + + def clear_all_cache(): + z = torch.zeros(256 * 1024 * 1024, dtype=torch.uint8, device='cuda') + z.fill_(42) + mem_used.append(z) + + + def benchmark(name: str, fn: callable, repeat=100): + records = [] + for _ in range(repeat): + event_start = torch.cuda.Event(enable_timing=True) + event_end = torch.cuda.Event(enable_timing=True) + x, w, b = generate_data() + clear_all_cache() + torch.cuda.synchronize() + event_start.record() + fn(x, w, b) + event_end.record() + torch.cuda.synchronize() + records.append(event_start.elapsed_time(event_end) * 1e3) + tot = sum(records[2:]) / len(records[2:]) # skip first two warmup + + tflops = 2 * M * N * K / tot / 1e6 # 
Corrected TFLOPS calculation + print(f"{name}: {tot:.2f} μs {tflops:.2f} TFlops, [{', '.join('%.2f' % t for t in records[2:5])}]") + + + + print(f"==== {config['name']} ====") + + ref_t = ref_fn(x, w, b) + our_t = our_fn(x, w, b) + + + + # correctness + try: + torch.testing.assert_close(ref_t + b, our_t, atol=1e-1, rtol=1e-2) + except AssertionError as e: + print(our_t) + print(ref_t) + print("Error:", e) + + # performace + benchmark('ref_fn', ref_fn) + benchmark('our_fn', our_fn) + benchmark('ref_fn', ref_fn) + benchmark('our_fn', our_fn) + print("=" * 20, end='\n\n') diff --git a/dist-infer/gemm-rs/benchmark_rs.py b/dist-infer/gemm-rs/benchmark_rs.py new file mode 100644 index 0000000..afccb4d --- /dev/null +++ b/dist-infer/gemm-rs/benchmark_rs.py @@ -0,0 +1,106 @@ +import sys +import os +sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)), 'build')) + +import perf_gemm + +import torch +import time + + +WORLD_SIZE = 8 + +test_configs = [ + # {"M": 64, "N": 7168, "name": "64x7168"}, + # {"M": 512, "N": 4096, "name": "512x4096"}, + # {"M": 2048, "N": 2880, "name": "2048x2880"}, + # {"M": 4096, "N": 4096, "name": "4096x4096"}, + {"M": 8192, "N": 4096, "name": "8192x4096"}, + # {"M": 8192, "N": 8192, "name": "8192x8192"}, +] + + +for config in test_configs: + M, N = config["M"], config["N"] + assert M % WORLD_SIZE == 0, f"M={M} must be divisible by WORLD_SIZE={WORLD_SIZE}" + + mem_used = [] + + def generate_data(): + c_list = [] + for i in range(WORLD_SIZE): + c = torch.randn((M, N), dtype=torch.bfloat16, device='cuda') * 2 - 1 + c_list.append(c) + mem_used.append(c) + return c_list + + def ref_fn(c_list, rank): + # Reference implementation: sum all tensors and extract the rank's slice + stacked = torch.stack(c_list, dim=0) # [WORLD_SIZE, M, N] + summed = torch.sum(stacked, dim=0) # [M, N] + M_per_rank = M // WORLD_SIZE + out = summed[rank * M_per_rank : (rank + 1) * M_per_rank, :] + mem_used.append(out) + return out + + def our_fn(c_list, rank): + dummy_signal = [torch.zeros(c_list[0].shape[0], c_list[0].shape[1], dtype=torch.int32, device='cuda')] * WORLD_SIZE + out = perf_gemm.launch_reduce_scatter(c_list, dummy_signal, 0, rank) + mem_used.append(out) + return out + + def clear_all_cache(): + z = torch.zeros(256 * 1024 * 1024, dtype=torch.uint8, device='cuda') + z.fill_(42) + mem_used.append(z) + + def benchmark(name: str, fn: callable, rank: int, repeat=5): + records = [] + for _ in range(repeat): + event_start = torch.cuda.Event(enable_timing=True) + event_end = torch.cuda.Event(enable_timing=True) + c_list = generate_data() + clear_all_cache() + torch.cuda.synchronize() + event_start.record() + fn(c_list, rank) + event_end.record() + torch.cuda.synchronize() + records.append(event_start.elapsed_time(event_end) * 1e3) + + tot = sum(records[2:]) / len(records[2:]) # skip first two warmup + + # Calculate bandwidth + # Read: WORLD_SIZE tensors of size M * N * 2 bytes (bfloat16) + # Write: 1 tensor of size (M / WORLD_SIZE) * N * 2 bytes + bytes_read = WORLD_SIZE * M * N * 2 + bytes_write = (M // WORLD_SIZE) * N * 2 + total_bytes = bytes_read + bytes_write + bandwidth_gb_s = total_bytes / tot / 1e3 # Convert to GB/s (μs -> s, bytes -> GB) + + print(f"{name}: {tot:.2f} μs {bandwidth_gb_s:.2f} GB/s, [{', '.join('%.2f' % t for t in records)}]") + + # Test for rank 0 + rank = 0 + c_list = generate_data() + + ref_t = ref_fn(c_list, rank) + our_t = our_fn(c_list, rank) + + print(f"==== {config['name']} (rank={rank}) ====") + + # Correctness check + try: + 
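# The reference sums in bfloat16 while the kernel is free to accumulate in fp32, so the two
# can differ in the low bits; hence the relatively loose atol/rtol of 1e-2 below.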
torch.testing.assert_close(ref_t, our_t, atol=1e-2, rtol=1e-2) + print("✓ Correctness check passed") + except AssertionError as e: + print(ref_t) + print(our_t) + print("✗ Correctness check failed:", e) + + # Performance benchmark + benchmark('ref_fn', ref_fn, rank) + benchmark('our_fn', our_fn, rank) + benchmark('ref_fn', ref_fn, rank) + benchmark('our_fn', our_fn, rank) + print("=" * 20, end='\n\n') diff --git a/dist-infer/gemm-rs/gen_submission.py b/dist-infer/gemm-rs/gen_submission.py new file mode 100644 index 0000000..2f68fae --- /dev/null +++ b/dist-infer/gemm-rs/gen_submission.py @@ -0,0 +1,32 @@ +import zlib +import base64 + +# 1. Read source files +with open('src/perf_gemm.cc', 'rb') as f: + kernel_cc = f.read() + +# 2. Concatenate C++ sources +kernel_cc = kernel_cc.replace(b'#include "gemm_rs_kernel.h"\n', b'') + +# 3. Compress and encode sources +encoded_cpp = base64.b64encode(zlib.compress(b'', level=9)) +encoded_cuda = base64.b64encode(zlib.compress(kernel_cc, level=9)) + +# 4. Read the template.py template +with open('template.py', 'r') as f: + submission_template = f.read() + +# 5. Format the final submission file content +submission_content = submission_template.replace( + 'CPP_WRAPPER = ""', + f'CPP_WRAPPER = zlib.decompress(base64.b64decode({encoded_cpp!r})).decode("utf-8")' +).replace( + 'CUDA_SRC = ""', + f'CUDA_SRC = zlib.decompress(base64.b64decode({encoded_cuda!r})).decode("utf-8")' +) + +# 6. Write the new submission.py +with open('submission.py', 'w') as f: + f.write(submission_content) + +print("submission.py has been generated successfully.") diff --git a/dist-infer/gemm-rs/local_test.py b/dist-infer/gemm-rs/local_test.py new file mode 100644 index 0000000..efa51af --- /dev/null +++ b/dist-infer/gemm-rs/local_test.py @@ -0,0 +1,344 @@ +import torch +import os +import sys +import numpy as np +import argparse +import statistics + +# Ensure the build directory is in the path +sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)), 'build')) + +import gemm_rs + +# Test configurations based on the provided shapes from test_perf.py +test_configs = [ + {"M": 64, "N": 7168, "K": 2304, "bias": True, "name": "64x7168x2304"}, + {"M": 512, "N": 4096, "K": 1536, "bias": True, "name": "512x4096x1536"}, + {"M": 2048, "N": 2880, "K": 360, "bias": True, "name": "2048x2880x360"}, + {"M": 4096, "N": 4096, "K": 512, "bias": True, "name": "4096x4096x512"}, + {"M": 8192, "N": 4096, "K": 1792, "bias": True, "name": "8192x4096x1792"}, + {"M": 8192, "N": 8192, "K": 3696, "bias": True, "name": "8192x8192x3696"}, +] + +WORLD_SIZE = 8 + +# Initialize GemmRS object +rs = gemm_rs.GemmRS(0, WORLD_SIZE) + +def print_first_20_errors(expected, actual, test_name, rtol=1e-2, atol=1e-2): + """Print the first 20 error values when tensors don't match""" + expected_flat = expected.flatten() + actual_flat = actual.flatten() + + # Calculate tolerances + tolerance = atol + rtol * torch.abs(expected_flat) + diff = torch.abs(expected_flat - actual_flat) + error_mask = diff > tolerance + + error_indices = torch.nonzero(error_mask, as_tuple=False).flatten() + + if len(error_indices) > 0: + print(f" 📝 {test_name} Error Details:") + print(f" Total errors: {len(error_indices)} out of {len(expected_flat)} elements") + print(f" Error rate: {len(error_indices) / len(expected_flat) * 100:.2f}%") + print(f" 📋 First 20 errors:") + print(f" {'Index':<8} {'Expected':<12} {'Actual':<12} {'Diff':<12} {'RelErr%':<10}") + print(f" {'-'*8} {'-'*12} {'-'*12} {'-'*12} {'-'*10}") + + for i, idx in 
enumerate(error_indices[:20]): + idx = idx.item() + exp_val = expected_flat[idx].item() + act_val = actual_flat[idx].item() + diff_val = abs(exp_val - act_val) + rel_err = (diff_val / abs(exp_val)) * 100 if abs(exp_val) > 1e-10 else float('inf') + + print(f" {idx:<8} {exp_val:<12.6f} {act_val:<12.6f} {diff_val:<12.6f} {rel_err:<10.2f}") + +def elapsed_time(func: callable, repeat=100): + """Measures the elapsed time of a function.""" + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + start.record() + for _ in range(repeat): + func() + end.record() + + torch.cuda.synchronize() + return start.elapsed_time(end) / repeat + +def test_gemm(x, w, b): + """Tests the GEMM operation.""" + return rs.test_gemm(x, w, b) + +def test_rs(all_ranks, target_rank): + """Tests the RS (Reduce-Scatter) operation.""" + return rs.test_rs(all_ranks, target_rank) + +def test_gemm_rs(x, w, b, all_inputs, world_size, fake_rank): + """Tests the fused GEMM_RS operation.""" + return rs.test_gemm_rs(x, w, b, all_inputs, world_size, fake_rank) + +def main(): + parser = argparse.ArgumentParser(description="Run performance tests for GEMM, RS, and GEMM_RS.") + parser.add_argument('--gemm', action='store_true', help='Run only the GEMM test') + parser.add_argument('--rs', action='store_true', help='Run only the RS test') + parser.add_argument('--gemm-rs', action='store_true', help='Run only the GEMM_RS test') + args = parser.parse_args() + + # If no specific test is selected, run all of them + run_all = not (args.gemm or args.rs or args.gemm_rs) + + # Store execution times for geometric mean calculation + times_rocblas_gemm = [] + times_our_gemm = [] + times_manual_rs = [] + times_our_rs = [] + times_baseline_gemm_rs = [] + times_our_gemm_rs = [] + + # Track correctness statistics + total_tests = 0 + gemm_matches = 0 + rs_matches = 0 + gemm_rs_matches = 0 + gemm_total = 0 + rs_total = 0 + gemm_rs_total = 0 + + # Main testing loop + for config in test_configs: + M, N, K = config["M"], config["N"], config["K"] + M_per_rank = M // WORLD_SIZE + print(f"\n🧪 Testing {config['name']} (M={M}, N={N}, K={K})") + + # Create generator for reproducible results + gen = torch.Generator(device="cuda") + gen.manual_seed(42) + + # Generate input data + x = (torch.rand((M, K), dtype=torch.bfloat16, device="cuda", generator=gen) * 2 - 1) * 0.01 + w = (torch.rand((N, K), dtype=torch.bfloat16, device="cuda", generator=gen) * 2 - 1) * 0.01 + b = (torch.rand((N,), dtype=torch.bfloat16, device="cuda", generator=gen) * 2 - 1) * 0.01 if config["bias"] else None + + if run_all or args.gemm: + # --- rocBLAS Baseline (GEMM) --- + rocblas_result = torch.matmul(x, w.T) + if b is not None: + rocblas_result += b + + # --- Our GEMM Kernel --- + try: + # First check correctness + our_gemm_result = test_gemm(x, w, b) + + # Verify results before timing + try: + torch.testing.assert_close(our_gemm_result, rocblas_result, rtol=1e-2, atol=1e-2) + print(f" ✅ GEMM Results match (within tolerance)") + gemm_matches += 1 + + # Only run performance tests if correctness check passes + func_rocblas_gemm = lambda: torch.matmul(x, w.T) + elapsed_time(func_rocblas_gemm) # Warmup + t_rocblas_gemm = elapsed_time(func_rocblas_gemm) + times_rocblas_gemm.append(t_rocblas_gemm) + rocblas_gemm_tflops = 2 * M * N * K / t_rocblas_gemm / 1e9 + print(f" - rocBLAS GEMM: {rocblas_gemm_tflops:.2f} TFLOPS ({t_rocblas_gemm * 1000:.1f} μs)") + + func_gemm = lambda: test_gemm(x, w, b) + elapsed_time(func_gemm) # Warmup + 
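# elapsed_time records two CUDA events around `repeat` back-to-back launches and returns the
# mean per-call time in milliseconds; the extra call above is a throwaway warmup so compile
# and cache effects stay out of the reported number.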
t_gemm = elapsed_time(func_gemm) + times_our_gemm.append(t_gemm) + gemm_tflops = 2 * M * N * K / t_gemm / 1e9 + gemm_speedup = t_rocblas_gemm / t_gemm + print(f" - Our GEMM: {gemm_tflops:.2f} TFLOPS ({t_gemm * 1000:.1f} μs) (Speedup: {gemm_speedup:.2f}x)") + + except AssertionError as e: + print(f" ❌ GEMM Results do not match: {str(e)}") + print_first_20_errors(rocblas_result, our_gemm_result, "GEMM") + gemm_total += 1 + except Exception as e: + print(f" - Our GEMM: FAILED - {str(e)}") + + if run_all or args.rs: + torch.manual_seed(42) + # --- RS (Reduce-Scatter) Test --- + # Generate input for RS: simulate WORLD_SIZE ranks each with M_per_rank rows + M_per_rank = M // WORLD_SIZE + all_ranks = [ + torch.rand(M, N, dtype=torch.bfloat16, device="cuda", generator=gen) + for _ in range(WORLD_SIZE) + ] + + # Baseline for RS: manual reduce-scatter for rank 0 + def manual_rs(): + expected = torch.zeros((M_per_rank, N), dtype=torch.float32, device='cuda') + for rank in range(WORLD_SIZE): + expected += all_ranks[rank][M_per_rank * 0: M_per_rank * (0 + 1), :].float() + return expected.to(torch.bfloat16) + + manual_rs_result = manual_rs() + + # Our RS kernel (test for rank 0) + try: + # First check correctness + our_rs_result = test_rs(all_ranks, 0) + + # Verify results before timing + try: + torch.testing.assert_close(our_rs_result, manual_rs_result, rtol=1e-2, atol=1e-2) + print(f" ✅ RS Results match (within tolerance)") + rs_matches += 1 + + # Only run performance tests if correctness check passes + func_manual_rs = manual_rs + elapsed_time(func_manual_rs) # Warmup + t_manual_rs = elapsed_time(func_manual_rs) + times_manual_rs.append(t_manual_rs) + # TFLOPS for RS is based on (WORLD_SIZE - 1) additions for each element in the slice + rs_tflops_manual = (WORLD_SIZE - 1) * M_per_rank * N / t_manual_rs / 1e9 + print(f" - Manual RS: {rs_tflops_manual:.2f} TFLOPS ({t_manual_rs * 1000:.1f} μs)") + + func_rs = lambda: test_rs(all_ranks, 0) + elapsed_time(func_rs) # Warmup + t_rs = elapsed_time(func_rs) + times_our_rs.append(t_rs) + rs_tflops = (WORLD_SIZE - 1) * M_per_rank * N / t_rs / 1e9 + rs_speedup = t_manual_rs / t_rs + print(f" - Our RS: {rs_tflops:.2f} TFLOPS ({t_rs * 1000:.1f} μs) (Speedup: {rs_speedup:.2f}x)") + + except AssertionError as e: + print(f" ❌ RS Results do not match: {str(e)}") + print_first_20_errors(manual_rs_result, our_rs_result, "RS") + rs_total += 1 + except Exception as e: + print(f" - Our RS: FAILED - {str(e)}") + + if run_all or args.gemm_rs: + # --- Fused GEMM_RS Test --- + # GEMM_RS does GEMM first, then RS on the result + # Generate input for GEMM_RS: simulate WORLD_SIZE ranks each contributing to GEMM + + # Create input matrices for each rank (simulating distributed GEMM inputs) + torch.manual_seed(42) + gen_gemm_rs = torch.Generator(device="cuda") + gen_gemm_rs.manual_seed(42) + + # Generate GEMM inputs for all ranks with same range as main x, w, b + all_inputs = [] + for rank in range(WORLD_SIZE): + x_rank = (torch.rand((M, K), dtype=torch.bfloat16, device="cuda", generator=gen_gemm_rs) * 2 - 1) * 0.01 + w_rank = (torch.rand((N, K), dtype=torch.bfloat16, device="cuda", generator=gen_gemm_rs) * 2 - 1) * 0.01 + b_rank = (torch.rand((N,), dtype=torch.bfloat16, device="cuda", generator=gen_gemm_rs) * 2 - 1) * 0.01 if config["bias"] else None + + # Compute GEMM + BIAS for this rank + gemm_result = torch.matmul(x_rank, w_rank.T) + if b_rank is not None: + gemm_result += b_rank + all_inputs.append(gemm_result) + + x = (torch.rand((M,K), dtype=torch.bfloat16, device="cuda", 
generator=gen_gemm_rs) * 2 - 1) * 0.01 + w = (torch.rand((N,K), dtype=torch.bfloat16, device="cuda", generator=gen_gemm_rs) * 2 - 1) * 0.01 + b = (torch.rand((N,), dtype=torch.bfloat16, device="cuda", generator=gen_gemm_rs) * 2 - 1) * 0.01 if config["bias"] else None + # Create GEMM results from all ranks for the RS operation + + # Baseline: Manual GEMM + RS for rank 0 + def manual_gemm_rs(): + # The baseline is just doing RS on the pre-computed GEMM results for rank 0's slice + M_per_rank = M // WORLD_SIZE + start_idx = M_per_rank * 0 # rank 0 + end_idx = M_per_rank * (0 + 1) + + rs_result = torch.zeros((M_per_rank, N), dtype=torch.float32, device='cuda') + for rank in range(WORLD_SIZE): + if rank == 0: # rank 0 + res = torch.matmul(x, w.T)[start_idx:end_idx] + if b is not None: + res += b + rs_result += res.float() + else: + rs_result += all_inputs[rank][start_idx:end_idx, :].float() + return rs_result.to(torch.bfloat16) + baseline_gemm_rs_result = manual_gemm_rs() + + torch.cuda.synchronize() + try: + # First check correctness + our_gemm_rs_result = test_gemm_rs(x, w, b, all_inputs, WORLD_SIZE, 0) + # Verify results before timing + try: + torch.testing.assert_close(our_gemm_rs_result, baseline_gemm_rs_result, rtol=1e-2, atol=1e-2) + print(f" ✅ GEMM_RS Results match (within tolerance)") + gemm_rs_matches += 1 + + # Only run performance tests if correctness check passes + func_baseline_gemm_rs = manual_gemm_rs + elapsed_time(func_baseline_gemm_rs) # Warmup + t_baseline_gemm_rs = elapsed_time(func_baseline_gemm_rs) + times_baseline_gemm_rs.append(t_baseline_gemm_rs) + + # TFLOPS calculation: GEMM operations + RS operations + gemm_ops = WORLD_SIZE * 2 * M * N * K # GEMM for all ranks + rs_ops = (WORLD_SIZE - 1) * (M // WORLD_SIZE) * N # RS operations + total_ops = gemm_ops + rs_ops + gemm_rs_tflops_baseline = total_ops / t_baseline_gemm_rs / 1e9 + print(f" - Baseline GEMM+RS: {gemm_rs_tflops_baseline:.2f} TFLOPS ({t_baseline_gemm_rs * 1000:.1f} μs)") + + # Reset for our kernel test + torch.cuda.synchronize() + + func_gemm_rs = lambda: test_gemm_rs(x, w, b, all_inputs, WORLD_SIZE, 0) + elapsed_time(func_gemm_rs) # Warmup + t_gemm_rs = elapsed_time(func_gemm_rs) + times_our_gemm_rs.append(t_gemm_rs) + gemm_rs_tflops = total_ops / t_gemm_rs / 1e9 + gemm_rs_speedup = t_baseline_gemm_rs / t_gemm_rs + print(f" - Our GEMM_RS: {gemm_rs_tflops:.2f} TFLOPS ({t_gemm_rs * 1000:.1f} μs) (Speedup: {gemm_rs_speedup:.2f}x)") + + except AssertionError as e: + print(f" ❌ GEMM_RS Results do not match: {str(e)}") + print_first_20_errors(baseline_gemm_rs_result, our_gemm_rs_result, "GEMM_RS") + gemm_rs_total += 1 + except Exception as e: + print(f" - Our GEMM_RS: FAILED - {str(e)}") + + print(f"\n{'='*60}") + print("📊 All tests completed.") + + # Display correctness statistics + print(f"\n🎯 Correctness Statistics:") + if gemm_total > 0: + print(f" - GEMM matches: {gemm_matches}/{gemm_total}") + if rs_total > 0: + print(f" - RS matches: {rs_matches}/{rs_total}") + if gemm_rs_total > 0: + print(f" - GEMM_RS matches: {gemm_rs_matches}/{gemm_rs_total}") + + # Calculate and display geometric means + print("\n📈 Geometric Mean of Execution Times:") + if times_rocblas_gemm: + geom_mean_rocblas = statistics.geometric_mean(times_rocblas_gemm) + print(f" - rocBLAS GEMM: {geom_mean_rocblas * 1000:.1f} μs") + if times_our_gemm: + geom_mean_our_gemm = statistics.geometric_mean(times_our_gemm) + print(f" - Our GEMM: {geom_mean_our_gemm * 1000:.1f} μs") + if times_manual_rs: + geom_mean_manual_rs = 
statistics.geometric_mean(times_manual_rs) + print(f" - Manual RS: {geom_mean_manual_rs * 1000:.1f} μs") + if times_our_rs: + geom_mean_our_rs = statistics.geometric_mean(times_our_rs) + print(f" - Our RS: {geom_mean_our_rs * 1000:.1f} μs") + if times_baseline_gemm_rs: + geom_mean_baseline = statistics.geometric_mean(times_baseline_gemm_rs) + print(f" - Baseline GEMM+RS: {geom_mean_baseline * 1000:.1f} μs") + if times_our_gemm_rs: + geom_mean_our_gemm_rs = statistics.geometric_mean(times_our_gemm_rs) + print(f" - Our GEMM_RS: {geom_mean_our_gemm_rs * 1000:.1f} μs") + + print(f"{'='*60}") + +if __name__ == "__main__": + main() + main() \ No newline at end of file diff --git a/dist-infer/gemm-rs/requirements.txt b/dist-infer/gemm-rs/requirements.txt new file mode 100644 index 0000000..0ff6ac7 --- /dev/null +++ b/dist-infer/gemm-rs/requirements.txt @@ -0,0 +1,12 @@ +astunparse==1.6.2 +colorlover +dash>=1.12.0 +matplotlib +pandas>=1.4.3 +pymongo +tabulate +tqdm +dash-svg +dash-bootstrap-components +kaleido==0.2.1 +plotille \ No newline at end of file diff --git a/dist-infer/gemm-rs/src/common.h b/dist-infer/gemm-rs/src/common.h new file mode 100644 index 0000000..dea6616 --- /dev/null +++ b/dist-infer/gemm-rs/src/common.h @@ -0,0 +1,55 @@ +#pragma once + + +#define FORCE_INLINE __attribute__((always_inline)) + +namespace perf_gemm { + +constexpr int NUM_XCDS = 8; +constexpr int AMDGCN_WAVEFRONT_SIZE = 64; + +using bfloat16_t = __bf16; + +template +struct PackN_t { + using t = __attribute__((vector_size(N * sizeof(dtype)))) dtype; + static constexpr auto n = N; + static constexpr auto H = N / 2; + union { + dtype x[N]; + t pack; + struct { dtype low[H], high[H]; }; + }; +}; + +using bf16x4_t = PackN_t; +using fp32x4_t = PackN_t; + + +__device__ FORCE_INLINE constexpr bfloat16_t fast_f32tob16(float f) { +#ifdef FAST_UNSAFE_CAST + union { + float fp32; + unsigned int u32; + } u = {f}; + u.u32 += 0x7FFF + ((u.u32 >> 16) & 1); + auto ret = u.u32 >> 16; + return reinterpret_cast(ret); +#else + return static_cast(f); +#endif +} + + +__device__ __host__ FORCE_INLINE constexpr int ceil_div(int a, int b) { + return (a + b - 1) / b; +} + +template +__device__ __host__ FORCE_INLINE constexpr int exact_div() { + static_assert(a % b == 0); + return a / b; +} + + +} // namespace gemm_rs \ No newline at end of file diff --git a/dist-infer/gemm-rs/src/gemm_rs.cc b/dist-infer/gemm-rs/src/gemm_rs.cc new file mode 100644 index 0000000..5e2fe80 --- /dev/null +++ b/dist-infer/gemm-rs/src/gemm_rs.cc @@ -0,0 +1,130 @@ +#include "gemm_rs_kernel.h" +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace gemm_rs { + + +class GemmRS { +private: + int rank_; + int world_size_; + void *ipc_mems_[MAX_WORLD_SIZE]; + void *sig_buf_[MAX_WORLD_SIZE]; + + void check_device() { + int device; + C10_HIP_CHECK(hipGetDevice(&device)); + TORCH_CHECK(device == rank_); + } + +public: + GemmRS(int rank, int world_size): rank_(rank), world_size_(world_size) { + // C10_HIP_CHECK(hipExtMallocWithFlags(&ipc_mems_[rank_], MAX_IPC_MEM_SIZE, hipDeviceMallocUncached)); + C10_HIP_CHECK(hipMalloc(&ipc_mems_[rank_], MAX_IPC_MEM_SIZE)); + C10_HIP_CHECK(hipMemset(ipc_mems_[rank_], 0, MAX_IPC_MEM_SIZE)); + sig_buf_[rank_] = reinterpret_cast(ipc_mems_[rank_]) + MAX_IPC_MEM_SIZE - SIGNAL_BUF_SIZE; + } + ~GemmRS() {} + + pybind11::bytearray get_ipc_handle() { + check_device(); + hipIpcMemHandle_t ipc_handle; + C10_HIP_CHECK(hipIpcGetMemHandle(&ipc_handle, ipc_mems_[rank_])); + return {ipc_handle.reserved, 
HIP_IPC_HANDLE_SIZE}; + } + + auto init_dist(const std::vector &ipc_handles) { + int world_size = ipc_handles.size(); + for (int i = 0; i < world_size; i++) { + if (i == rank_) continue; + hipIpcMemHandle_t handle; + auto handle_buf = std::string(ipc_handles[i]); + TORCH_CHECK(handle_buf.size() == HIP_IPC_HANDLE_SIZE); + std::memcpy(handle.reserved, handle_buf.data(), HIP_IPC_HANDLE_SIZE); + C10_HIP_CHECK(hipIpcOpenMemHandle(&ipc_mems_[i], handle, hipIpcMemLazyEnablePeerAccess)); + sig_buf_[i] = reinterpret_cast(ipc_mems_[i]) + MAX_IPC_MEM_SIZE - SIGNAL_BUF_SIZE; // last for signal + + } + + } + + + torch::Tensor gemm_rs(const torch::Tensor& input, const torch::Tensor& weight, const c10::optional& bias) { + int M = input.size(0); + int K = input.size(1); + int N = weight.size(0); + TORCH_CHECK(K == weight.size(1), "Incompatible GEMM size"); + auto stream = at::cuda::getCurrentHIPStream(); + auto out = torch::empty({M / world_size_, N}, input.options()); + TORCH_CHECK(M * N * sizeof(at::BFloat16) + SIGNAL_BUF_SIZE <= MAX_IPC_MEM_SIZE, "Input size exceeds MAX_IPC_MEM_SIZE"); + launch_gemm_rs_dist(input.const_data_ptr(), weight.const_data_ptr(), bias ? bias->const_data_ptr() : nullptr, ipc_mems_, sig_buf_, out.mutable_data_ptr(), rank_, world_size_, M, N, K, stream); + return out; + } + + torch::Tensor test_gemm(const torch::Tensor& input, const torch::Tensor& weight, const torch::Tensor &bias) { + int M = input.size(0); + int K = input.size(1); + int N = weight.size(0); + TORCH_CHECK(K == weight.size(1), "Incompatible GEMM size"); + TORCH_CHECK(M * N * sizeof(at::BFloat16) + SIGNAL_BUF_SIZE <= MAX_IPC_MEM_SIZE, "Input size exceeds MAX_IPC_MEM_SIZE"); + auto stream = at::cuda::getCurrentHIPStream(); + // C10_HIP_CHECK(hipMemsetAsync(ipc_mems_[rank_], 0, M * N * sizeof(at::BFloat16), stream)); + launch_gemm(input.const_data_ptr(), weight.const_data_ptr(), bias.const_data_ptr(), ipc_mems_[rank_], M, N, K, stream); + auto t = torch::from_blob(ipc_mems_[rank_], {M, N}, input.options().dtype(at::kBFloat16)); + // C10_HIP_CHECK(hipStreamSynchronize(stream)); + return t; + } + + torch::Tensor test_rs(const std::vector& inputs, int fake_rank) { + int M = inputs[0].size(0); + int N = inputs[0].size(1); + auto stream = at::cuda::getCurrentHIPStream(); + TORCH_CHECK(inputs.size() <= MAX_WORLD_SIZE, "inputs size exceeds MAX_WORLD_SIZE"); + std::vector input_ptrs; + for (const auto& t : inputs) { + TORCH_CHECK(t.size(0) == M && t.size(1) == N, "All inputs must have the same shape"); + input_ptrs.push_back(t.const_data_ptr()); + } + int world_size = inputs.size(); + auto out = torch::empty({M / world_size, N}, inputs[0].options()); + launch_rs(input_ptrs.data(), out.mutable_data_ptr(), fake_rank, world_size, M, N, stream); + return out; + } + + torch::Tensor test_gemm_rs(const torch::Tensor& input, const torch::Tensor& weight, std::optional bias, const std::vector& all_inputs, int world_size, int fake_rank) { + int M = input.size(0); + int K = input.size(1); + int N = weight.size(0); + TORCH_CHECK(K == weight.size(1), "Incompatible GEMM size"); + auto stream = at::cuda::getCurrentHIPStream(); + auto out = torch::empty({M / world_size, N}, input.options()); + std::vector rs_bufs(world_size); + for (int i = 0; i < world_size; i++) { + rs_bufs[i] = all_inputs[i].mutable_data_ptr(); + } + TORCH_CHECK(M * N * sizeof(at::BFloat16) + SIGNAL_BUF_SIZE <= MAX_IPC_MEM_SIZE, "Input size exceeds MAX_IPC_MEM_SIZE"); + launch_gemm_rs(input.const_data_ptr(), weight.const_data_ptr(), bias ? 
bias->const_data_ptr() : nullptr, rs_bufs.data(), sig_buf_[rank_], out.mutable_data_ptr(), fake_rank, world_size, M, N, K, stream); + return out; + } +}; + +} // namespace gemm_rs + +PYBIND11_MODULE(gemm_rs, m) { + pybind11::class_(m, "GemmRS") + .def(pybind11::init(), py::arg("rank"), py::arg("world_size")) + .def("get_ipc_handle", &gemm_rs::GemmRS::get_ipc_handle) + .def("init_dist", &gemm_rs::GemmRS::init_dist) + .def("gemm_rs", &gemm_rs::GemmRS::gemm_rs) + .def("test_gemm", &gemm_rs::GemmRS::test_gemm) + .def("test_rs", &gemm_rs::GemmRS::test_rs) + .def("test_gemm_rs", &gemm_rs::GemmRS::test_gemm_rs); +} \ No newline at end of file diff --git a/dist-infer/gemm-rs/src/gemm_rs_kernel.cc b/dist-infer/gemm-rs/src/gemm_rs_kernel.cc new file mode 100644 index 0000000..dfe92f3 --- /dev/null +++ b/dist-infer/gemm-rs/src/gemm_rs_kernel.cc @@ -0,0 +1,648 @@ +#include "gemm_rs_kernel.h" +#include +#include +#include +#include +#include +#include +#include +#define FAST_UNSAFE_CAST +// #define SWIZZLE_XCD_PID +// #define SWIZZLE_L2_TILE +// #define FORCE_LOAD_BIAS +#define FORCE_INLINE __attribute__((always_inline)) + + +namespace gemm_rs { + + +constexpr int NUM_XCDS = 8; +constexpr int AMDGCN_WAVEFRONT_SIZE = 64; + +using bfloat16_t = __bf16; + +template +struct PackN_t { + using t = __attribute__((vector_size(N * sizeof(dtype)))) dtype; + static constexpr auto n = N; + static constexpr auto H = N / 2; + union { + dtype x[N]; + t pack; + struct { dtype low[H], high[H]; }; + }; +}; + +using bf16x4_t = PackN_t; +using fp32x4_t = PackN_t; + + +__device__ FORCE_INLINE bfloat16_t fast_f32tob16(float f) { +#ifdef FAST_UNSAFE_CAST + union { + float fp32; + unsigned int u32; + } u = {f}; + u.u32 += 0x7FFF + ((u.u32 >> 16) & 1); + auto ret = u.u32 >> 16; + return reinterpret_cast(ret); +#else + return static_cast(f); +#endif +} + + +__device__ FORCE_INLINE float fast_b16tof32(bfloat16_t bf) { +#ifdef FAST_UNSAFE_CAST + union { + float fp32; + unsigned int u32; + } u; + u.u32 = reinterpret_cast(bf) << 16; + return u.fp32; +#else + return static_cast(bf); +#endif +} + + +__device__ __host__ FORCE_INLINE constexpr int ceil_div(int a, int b) { + return (a + b - 1) / b; +} + +template +__device__ __host__ FORCE_INLINE constexpr int exact_div() { + static_assert(a % b == 0); + return a / b; +} + + +__device__ FORCE_INLINE inline void block_sync_lds() { + __builtin_amdgcn_s_waitcnt(0xc07f); + __builtin_amdgcn_s_barrier(); +} + + +template +__device__ FORCE_INLINE inline void compute_tile_indices( + int tile_id, + int &tile_m_id, + int &tile_n_id +) { + static_assert(GROUP_SIZE_M % 8 == 0); + if constexpr (GROUP_SIZE_M > 0) { + // Swizzle pattern for better L2 cache locality + // Groups tiles in blocks of GROUP_SIZE_M x num_tile_n + constexpr int num_pid_in_group = GROUP_SIZE_M * num_tile_n; + + // Which group does this tile belong to? 
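// Illustrative example (arbitrary numbers, not the tuned config): with GROUP_SIZE_M = 8 and
// num_tile_n = 4, tiles 0..31 form group 0; consecutive tile_ids walk 8 M-tiles down one
// N-column before moving to the next column, so co-scheduled blocks reuse the same w tile
// from L2.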
+ const int group_id = tile_id / num_pid_in_group; + + // First M-dimension tile in this group + const int first_pid_m = group_id * GROUP_SIZE_M; + + // Actual group size (handling boundary case) + const int group_size_m = min(GROUP_SIZE_M, num_tile_m - first_pid_m); + + // Position within the group + const int idx_in_group = tile_id % num_pid_in_group; + + // Swizzled tile indices: alternate M then N within group + tile_m_id = first_pid_m + (idx_in_group % group_size_m); + tile_n_id = idx_in_group / group_size_m; + } else { + tile_m_id = tile_id / num_tile_n; + tile_n_id = tile_id % num_tile_n; + } +} + +template +__device__ int FORCE_INLINE remap_xcd_pid(int pid) { + if constexpr(REMAP_XCD) { + return (pid % NUM_XCDS) * (NUM_GEMM_SMS / NUM_XCDS) + (pid / NUM_XCDS); + } else { + return pid; + } +} + + +template +struct EpilogueSignal { + FORCE_INLINE __device__ void operator()(int tid, int tile_m_id, int tile_n_id, int *signal_ptr, int round_trip) const { + if (tid == 0) { + auto signal_arr = reinterpret_cast(signal_ptr); + __hip_atomic_store(&signal_arr[tile_m_id][tile_n_id], round_trip, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_SYSTEM); + } + } +}; + + +template +__launch_bounds__(NUM_THREADS) +__global__ void gemm_kernel( + const bfloat16_t *x, // [M, K] + const bfloat16_t *w, // [N, K] + const bfloat16_t *b, // [N] or nullptr + std::array c, // WORLD_SIZE * [M, N] + std::array signal, // WORLD_SIZE * [M / BM][N / BN] + bfloat16_t *out, // [M / WORLD_SIZE, N] + int rank, + int round_trip +) { + const int tid = threadIdx.x; + const int pid = blockIdx.x; + __builtin_assume(tid >= 0 && tid < NUM_THREADS); + __builtin_assume(pid >= 0 && pid < (NUM_GEMM_SMS + NUM_RS_SMS)); + const int lane_id = __lane_id(); + __builtin_assume(lane_id >= 0 && lane_id < 64); + + + if (pid < NUM_GEMM_SMS) { /* GEMM */ + + auto *c_ptr = c[rank]; + auto *signal_ptr = signal[rank]; + constexpr int v_mfma_f32_16x16x16_bf16 = (BM * BN * BK) / (WARP_M * WARP_N) / (16*16*16); + // Compiler will merge two ds_{read,write}_b64 to ds_{read,write`}2st64_b64 + constexpr int ds_read_b128_a = (BM * BK / WARP_M) / 64 / 8; + constexpr int ds_read_b128_b = (BN * BK / WARP_N) / 64 / 8; + constexpr int ds_read_b128 = ds_read_b128_a + ds_read_b128_b; + constexpr int ds_write_b128_a = (BM * BK) / NUM_THREADS / 8; + constexpr int ds_write_b128_b = (BN * BK) / NUM_THREADS / 8; + constexpr int ds_write_b128 = ds_write_b128_a + ds_write_b128_b; + constexpr int buffer_load_dwordx2_a = (BM * BK) / NUM_THREADS / 4; + constexpr int buffer_load_dwordx2_b = (BN * BK) / NUM_THREADS / 4; + constexpr int buffer_load_dwordx2 = buffer_load_dwordx2_a + buffer_load_dwordx2_b; + const int pid0 = blockIdx.x; + constexpr int NUM_XCDS = 8; + constexpr int num_tile_m = ceil_div(M, BM); + constexpr int num_tile_n = ceil_div(N, BN); + constexpr int num_tile_k = ceil_div(K, BK); + constexpr int num_tiles = num_tile_m * num_tile_n; + #ifdef SWIZZLE_XCD_PID + const int pid = (pid0 % NUM_XCDS) * (NUM_GEMM_SMS / NUM_XCDS) + (pid0 / NUM_XCDS); + #else + const int pid = pid0; + #endif + + const int tid = threadIdx.x; + const int lane_id = __lane_id(); + __builtin_assume(pid >= 0 && pid < NUM_GEMM_SMS + NUM_RS_SMS); + __builtin_assume(tid >= 0 && tid < NUM_THREADS); + __builtin_assume(lane_id >= 0 && lane_id < 64); + // each thread load 4 elements + static_assert(BK % 4 == 0 && NUM_THREADS * 4 % BK == 0); + constexpr int WM = 16, WN = 16, WK = 16; + + constexpr int Frag_M = exact_div(); + constexpr int Frag_N = exact_div(); + constexpr int Frag_K = exact_div(); + 
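// Each 64-lane wavefront owns a Frag_M x Frag_N grid of 16x16 accumulator tiles; warp_m and
// warp_n (derived just below from warp_id) select its sub-block of the BM x BN macro tile,
// and every BK slice is consumed as Frag_K fragments of 16 elements along K.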
const int warp_id = __builtin_amdgcn_readfirstlane(tid / 64); + // const int warp_id = 0; + const int warp_m = warp_id / WARP_N; + const int warp_n = warp_id % WARP_N; + using FragX = bf16x4_t; + using FragW = bf16x4_t; + using FragC = fp32x4_t; + __shared__ bfloat16_t s_x[BM][BK]; + __shared__ bfloat16_t s_w[BN][BK]; + bf16x4_t vgpr_x[ceil_div(BM * BK, NUM_THREADS * 4)]; + bf16x4_t vgpr_w[ceil_div(BN * BK, NUM_THREADS * 4)]; + + FragC frag_c[Frag_M][Frag_N]; + FragX frag_x[Frag_M][Frag_K]; + FragW frag_w[Frag_N][Frag_K]; + + auto load_vgpr = [&](int m, int n, int k) FORCE_INLINE { + auto x_arr = ck::make_wave_buffer_resource(const_cast(x), M * K); + auto w_arr = ck::make_wave_buffer_resource(const_cast(w), N * K); + int v_offset = ((tid * 4 / BK) * K + (tid * 4 % BK)) * sizeof(bfloat16_t); + uint32_t src_addr_shift = (K % BK == 0) || (k + tid * 4 % BK < K) ? 0 : 0x80000000; + ck::static_for<0, sizeof(vgpr_x) / sizeof(vgpr_x[0]), 1>{}([&](auto t) { + int s_offset = ((m * K + k) + t * NUM_THREADS * 4 / BK * K) * sizeof(bfloat16_t); + vgpr_x[t] = __builtin_bit_cast(bf16x4_t, ck::amd_buffer_load_impl_raw( + x_arr, v_offset + src_addr_shift, s_offset)); + }); + ck::static_for<0, sizeof(vgpr_w) / sizeof(vgpr_w[0]), 1>{}([&](auto t) { + int s_offset = ((n * K + k) + t * NUM_THREADS * 4 / BK * K) * sizeof(bfloat16_t); + vgpr_w[t] = __builtin_bit_cast(bf16x4_t, ck::amd_buffer_load_impl_raw( + w_arr, v_offset + src_addr_shift, s_offset)); + }); + + }; + + auto load_lds = [&]() FORCE_INLINE { + // diagonal swizzle, shape=[16, 64] dtype=bfloat16 + #pragma unroll + for (int t=0;t(&s_x[row0 + row1][col1]) = vgpr_x[t]; + } + #pragma unroll + for (int t=0;t(&s_w[row0 + row1][col1]) = vgpr_w[t]; + } + }; + + auto zero_all_frags = [&]() FORCE_INLINE { + ck::static_for<0, Frag_M, 1>{}([&](auto i) { + ck::static_for<0, Frag_N, 1>{}([&](auto j) { + ck::static_for<0, 4, 1>{}([&](auto t) { frag_c[i][j].x[t] = 0; }); + }); + }); + }; + + + auto frags_load = [&]() FORCE_INLINE { + ck::static_for<0, Frag_K, 1>{}([&](auto k) { + ck::static_for<0, Frag_M, 1>{}([&](auto i) { + const int row1 = (warp_m * Frag_M + i) * WM; + const int row0 = lane_id % 16; + const int col0 = k * 16 + lane_id / 16 * 4; + const int col1 = (row0 * 4 + col0) % BK; + frag_x[i][k] = *reinterpret_cast(&s_x[row0 + row1][col1]); + }); + ck::static_for<0, Frag_N, 1>{}([&](auto j) { + const int row1 = (warp_n * Frag_N + j) * WN; + const int row0 = lane_id % 16; + const int col0 = k * 16 + lane_id / 16 * 4; + const int col1 = (row0 * 4 + col0) % BK; + frag_w[j][k] = *reinterpret_cast(&s_w[row0 + row1][col1]); + }); + }); + }; + + auto frags_mfma = [&]() FORCE_INLINE { + ck::static_for<0, Frag_M, 1>{}([&](auto i) { + ck::static_for<0, Frag_N, 1>{}([&](auto j) { + ck::static_for<0, Frag_K, 1>{}([&](auto k) { + // a: [16][16], b: [16][16], c: [16][16] + // mfma requires a: row-major, b: col-major, out: col-major + // so we compute w^T * x^T = c^T so we can treat out as col-major + frag_c[i][j].pack = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(frag_w[j][k].pack, frag_x[i][k].pack, frag_c[i][j].pack, 0, 0, 0); + }); + }); + }); + }; + + + auto store_frags = [&](int m, int n) FORCE_INLINE { + auto b_arr = ck::make_wave_buffer_resource(const_cast(b), N); + auto c_arr = ck::make_wave_buffer_resource(c_ptr, M * N); + fp32x4_t c_out[Frag_M][Frag_N]; + ck::static_for<0, Frag_M, 1>{}([&](auto i) { + ck::static_for<0, Frag_N, 1>{}([&](auto j) { + ck::static_for<0, 4, 1>{}([&](auto t) { + // v_accvgpr_read_b32 + c_out[i][j].x[t] = frag_c[i][j].x[t]; + }); + // 
c_out: [16][16] + int row = lane_id % 16; + int col = lane_id / 16 * 4; + uint32_t src_addr_shift = (N % BN == 0) || (n + (j + warp_n * Frag_N) * WN + col < N) ? 0 : 0x80000000; + // load b + int b_s_offset = (n + (j + warp_n * Frag_N) * WN) * sizeof(bfloat16_t); + int b_v_offset = col * sizeof(bfloat16_t) + src_addr_shift; + auto b_vec = __builtin_bit_cast(bf16x4_t, ck::amd_buffer_load_impl_raw( + b_arr, b_v_offset, b_s_offset)); + // compute c + bf16x4_t c_out_bf16; + #pragma unroll + for (int t = 0; t < 4; ++t) { + c_out_bf16.x[t] = fast_f32tob16(c_out[i][j].x[t] + b_vec.x[t]); + } + // write c + int c_s_offset = b_s_offset + (m + (i + warp_m * Frag_M) * WM) * N * sizeof(bfloat16_t); + int c_v_offset = b_v_offset + (row * N) * sizeof(bfloat16_t); + ck::amd_buffer_store_impl_raw(c_out_bf16.pack, c_arr, c_v_offset, c_s_offset); + }); + }); + }; + + + + for (int tile_id=pid; tile_id(tile_id, tile_m_id, tile_n_id); + #else + // int tile_m_id = tile_id / num_tile_n; + // int tile_n_id = tile_id % num_tile_n; + int tile_m_id = tile_id % num_tile_m; + int tile_n_id = tile_id / num_tile_m; + #endif + int m = tile_m_id * BM; + int n = tile_n_id * BN; + load_vgpr(m, n, 0); // GDS -> VGPR #0 + load_lds(); // VGPR -> LDS #0 + load_vgpr(m, n, 1 * BK); // GDS -> VGPR #1 + zero_all_frags(); + block_sync_lds(); + frags_load(); // LDS -> FRAG #0 + __builtin_amdgcn_sched_barrier(0); + // #pragma clang loop unroll_count(2) + // #pragma unroll 2 + // #pragma unroll + for (int tile_k_id = 1; tile_k_id < (num_tile_k - 1); ++tile_k_id) { + asm volatile(R"( + ; Main Loop Begin + )" ::: "memory"); + block_sync_lds(); + // Stage 1 + load_lds(); // VGPR -> LDS #1 + load_vgpr(m, n, (tile_k_id + 1) * BK); // GDS -> VGPR #2(k+1) + frags_mfma(); // MFMA #0(k-1) + // 120 + #pragma unroll + for (int k = 0; k < buffer_load_dwordx2 / 2; ++k) { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA + } + + block_sync_lds(); + // Stage 2 + frags_load(); // LDS -> FRAG #1(k) + // 60 + #pragma unroll + for (int k = 0; k < ds_read_b128; ++k) { + __builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + } + __builtin_amdgcn_sched_barrier(0); + asm volatile(R"( + ; Main Loop End + )" ::: "memory"); + } + frags_mfma(); // MFMA #1(n-2) + block_sync_lds(); + load_lds(); // VGPR -> LDS #2(n-1) + block_sync_lds(); + frags_load(); // LDS -> FRAG #2(n-1) + frags_mfma(); // MFMA #2(n-1) + store_frags(m, n); + + __builtin_amdgcn_s_barrier(); + if (tid == 0) { + __hip_atomic_store(&signal_ptr[tile_m_id * num_tile_n + tile_n_id], round_trip, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_SYSTEM); + } + } + } else { /* Reduce Scatter */ + constexpr int M_per_rank = exact_div(); + static_assert(BN % 8 == 0); + constexpr int NUM_TILES_M = ceil_div(M_per_rank, BM); + constexpr int NUM_TILES_N = ceil_div(N, BN); + constexpr int TOTAL_TILES = NUM_TILES_M * NUM_TILES_N; + constexpr int ELEMENTS_PER_THREAD = 8; + struct rs_vec_t { bfloat16_t data[ELEMENTS_PER_THREAD]; }; + static_assert(N % ELEMENTS_PER_THREAD == 0); + const int rs_pid = pid - NUM_GEMM_SMS; + for (int tile_id = rs_pid; tile_id < TOTAL_TILES; tile_id += NUM_RS_SMS) { + int tile_m = tile_id % NUM_TILES_M; + 
int tile_n = tile_id / NUM_TILES_M; + float accum[BM * BN / NUM_THREADS] = {}; + static_assert((sizeof(accum) / sizeof(float)) % ELEMENTS_PER_THREAD == 0); + #pragma clang loop unroll_count(4) + for (int r = 0; r < WORLD_SIZE; r++) { + int swizzle_rank = (rs_pid + r) % WORLD_SIZE; + const int M_begin = rank * M_per_rank; + const int M_end = (rank + 1) * M_per_rank; + // since tile may be not well aligned, we have to wait at most two signal + // we should wait signal[signal_m0 .. signal_m1][signal_n] + const int tile_row_begin = M_begin + tile_m * BM; + const int tile_row_end = min(tile_row_begin + BM, M_end); + int signal_m0 = tile_row_begin / BM; + int signal_m1 = (tile_row_end - 1) / BM; + int signal_n = tile_n; + int signal_m = signal_m0 + tid; + if (signal_m <= signal_m1) { + auto *signal_arr = reinterpret_cast(signal[swizzle_rank]); + while (__hip_atomic_load(&signal_arr[signal_m][signal_n], __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_SYSTEM) != round_trip) {} // spin wait + } + __syncthreads(); + auto input_arr = ck::make_wave_buffer_resource(c[swizzle_rank] + M_begin * N, M_per_rank * N); + // auto input_arr = reinterpret_cast(c[swizzle_rank] + M_begin * N); + #pragma unroll + for (int t = 0; t < sizeof(accum) / sizeof(float); t+=ELEMENTS_PER_THREAD) { + int i = (t * NUM_THREADS + tid * ELEMENTS_PER_THREAD) / BN; + int j = (t * NUM_THREADS + tid * ELEMENTS_PER_THREAD) % BN; + int global_i = tile_m * BM + i; + int global_j = tile_n * BN + j; + auto vec = ck::bit_cast(ck::amd_buffer_load_impl_raw( + input_arr, (global_i * N + global_j) * sizeof(bfloat16_t), 0)); + #pragma unroll + for (int k = 0; k < ELEMENTS_PER_THREAD; ++k) accum[t + k] += fast_b16tof32(vec.data[k]); + } + } + // auto out_arr = reinterpret_cast(out); + auto out_arr = ck::make_wave_buffer_resource(out, M_per_rank * N); + #pragma unroll + for (int t = 0; t < sizeof(accum) / sizeof(float); t+=ELEMENTS_PER_THREAD) { + int i = (t * NUM_THREADS + tid * ELEMENTS_PER_THREAD) / BN; + int j = (t * NUM_THREADS + tid * ELEMENTS_PER_THREAD) % BN; + int global_i = tile_m * BM + i; + int global_j = tile_n * BN + j; + rs_vec_t vec; + #pragma unroll + for (int k = 0; k < ELEMENTS_PER_THREAD; ++k) vec.data[k] = fast_f32tob16(accum[t + k]); + using r_t = typename ck::vector_type::type; + ck::amd_buffer_store_impl_raw(ck::bit_cast(vec), + out_arr, (global_i * N + global_j) * sizeof(bfloat16_t), 0 + ); + } + } + } +} + +constexpr long IntKey(int M, int N, int K, int WORLD_SIZE, bool bias) { + union { + struct {uint16_t M, N, K, WORLD_SIZE;} data; + long key; + } u { .data={ + static_cast(M) , + static_cast(N), + static_cast(K), + static_cast(WORLD_SIZE)}}; + return u.key; +} + + +constexpr int WORLD_SIZE = 8; +constexpr int MI300X_NUM_SMS = 304; +constexpr int RS_SMS = 16; +constexpr int GEMM_SMS = MI300X_NUM_SMS - RS_SMS; + +// Simuate IntraNode environment for test +class GEMMFactory { +private: + std::unordered_map> gemm_map_; + std::unordered_map> rs_map_; + std::unordered_map> gemm_rs_map_; + int *dummy_signal_buf_; + bfloat16_t *zero_bias_buf_; + int round_trip_ = 0; + + template + inline void RegisterGEMM() { + static_assert(M * N * sizeof(bfloat16_t) <= MAX_IPC_MEM_SIZE, "C size exceeds MAX_IPC_MEM_SIZE"); + constexpr int NUM_THREADS = WARP_M * WARP_N * AMDGCN_WAVEFRONT_SIZE; + TORCH_CHECK(gemm_map_.count(IntKey(M, N, K, 8, 0)) == 0); + gemm_map_[IntKey(M, N, K, 8, 0)] = [&](const bfloat16_t *x, const bfloat16_t *w, const bfloat16_t *b, bfloat16_t *c, hipStream_t stream) { + std::array signal = {dummy_signal_buf_, nullptr}; + 
std::array c_ptrs = {c, nullptr}; + hipLaunchKernelGGL(HIP_KERNEL_NAME(gemm_kernel), + dim3(304), dim3(NUM_THREADS), 0, stream, + x, w, b, c_ptrs, signal, nullptr, 0, 0 + ); + }; +#ifdef __ENABLE_LOCAL_DEBUG__ + TORCH_CHECK(rs_map_.count(IntKey(M, N, 0, WORLD_SIZE, false)) == 0); + rs_map_[IntKey(M, N, 0, WORLD_SIZE, false)] = [&](const bfloat16_t *const *inputs, bfloat16_t *out, int *sig[], int rank, hipStream_t stream) mutable { + std::array signal; + std::array c_ptrs; + for (int i = 0; i < WORLD_SIZE; i++) { + signal[i] = sig[i]; + c_ptrs[i] = const_cast(inputs[i]); + } + int round_trip = ++round_trip_; + hipLaunchKernelGGL(HIP_KERNEL_NAME(gemm_kernel), + dim3(MI300X_NUM_SMS), dim3(NUM_THREADS), 0, stream, + nullptr, nullptr, nullptr, c_ptrs, signal, out, rank, round_trip + ); + }; +#endif + TORCH_CHECK(gemm_rs_map_.count(IntKey(M, N, K, WORLD_SIZE, false)) == 0); + gemm_rs_map_[IntKey(M, N, K, WORLD_SIZE, false)] = [&](const bfloat16_t *x, const bfloat16_t *w, const bfloat16_t *b, bfloat16_t *c[], int *sig[], bfloat16_t *out, int rank, hipStream_t stream) mutable { + std::array signal; + std::array c_ptrs; + for (int i = 0; i < WORLD_SIZE; i++) { + signal[i] = sig[i]; + c_ptrs[i] = c[i]; + } + // TODO: sync all rank + // C10_HIP_CHECK(hipDeviceSynchronize()); // make sure memset is done before kernel launch + + hipLaunchKernelGGL(HIP_KERNEL_NAME(gemm_kernel), + dim3(MI300X_NUM_SMS), dim3(NUM_THREADS), 0, stream, + x, w, b, c_ptrs, signal, out, rank, ++round_trip_ + ); + }; + } + + GEMMFactory() { + C10_HIP_CHECK(hipExtMallocWithFlags(reinterpret_cast(&dummy_signal_buf_), SIGNAL_BUF_SIZE, hipDeviceMallocUncached)); // dummy buffer for signal + C10_HIP_CHECK(hipMemset(dummy_signal_buf_, 0, SIGNAL_BUF_SIZE)); +#ifdef FORCE_LOAD_BIAS + C10_HIP_CHECK(hipMalloc(reinterpret_cast(&zero_bias_buf_), 8192 * sizeof(bfloat16_t))); + C10_HIP_CHECK(hipMemset(zero_bias_buf_, 0, 8192 * sizeof(bfloat16_t))); +#endif + // RegisterGEMM<8192, 8192, 2048>(); + // RegisterGEMM<4096, 4096, 4096>(); + // RegisterGEMM<4096, 4096, 2048>(); + + // int M, int N, int K, int BM, int BN, int BK, int WARP_M, int WARP_N, int GROUP_SIZE_M + RegisterGEMM<64, 7168, 2304, 32, 128, 64, 2, 2, 0 >(); // 64 x 7168 x (18432/8) + RegisterGEMM<512, 4096, 1536, 64, 128, 64, 2, 2, 8 >(); // 512 x 4096 x (12288/8) + RegisterGEMM<2048, 2880, 360, 128, 128, 64, 2, 2, 8 >(); // 2048 x 2880 x (2880/8) + RegisterGEMM<4096, 4096, 512, 128, 256, 64, 2, 2, 8 >(); // 4096 x 4096 x (4096/8) + RegisterGEMM<8192, 4096, 1792, 256, 224, 64, 2, 2, 16>(); // 8192 x 4096 x (14336/8) + RegisterGEMM<8192, 8192, 3696, 224, 256, 64, 2, 2, 16>(); // 8192 x 8192 x (29568/8) + } + +public: + static GEMMFactory* get() { + static std::unique_ptr instance; + if (!instance) { + instance = std::unique_ptr(new GEMMFactory()); + } + return instance.get(); + } + + void *get_completed_signal_for_next(hipStream_t stream) { + int round_trip = round_trip_ + 1; + C10_HIP_CHECK(hipMemsetD32Async(dummy_signal_buf_, round_trip, SIGNAL_BUF_SIZE / sizeof(int), stream)); + return dummy_signal_buf_; + } + + void launch_gemm(const bfloat16_t *x_ptr, const bfloat16_t *w_ptr, const bfloat16_t *b_ptr, bfloat16_t *c_ptr, int M, int N, int K, hipStream_t stream) { + auto key = IntKey(M, N, K, 8, 0); + auto it = gemm_map_.find(key); + TORCH_CHECK(it != gemm_map_.end(), "Unsupported GEMM size: ", M, "x", N, "x", K); + auto func = it->second; + func(x_ptr, w_ptr, b_ptr, c_ptr, stream); + } + + void launch_rs(const bfloat16_t *inputs[], bfloat16_t *out, int *signal_buf[], int rank, 
int world_size, int M, int N, hipStream_t stream) { + auto key = IntKey(M, N, 0, world_size, false); // K is not used in RS, set to a dummy value + auto it = rs_map_.find(key); + TORCH_CHECK(it != rs_map_.end(), "Unsupported RS size: ", M, "x", N); + auto func = it->second; + func(inputs, out, signal_buf, rank, stream); + } + + void launch_gemm_rs(const bfloat16_t *x_ptr, const bfloat16_t *w_ptr, const bfloat16_t *b_ptr, bfloat16_t *c_ptr[], int *signal_buf[], bfloat16_t *out, int rank, int M, int N, int K, int world_size, hipStream_t stream) { + bool bias = b_ptr != nullptr; + auto key = IntKey(M, N, K, world_size, bias); + auto it = gemm_rs_map_.find(key); + TORCH_CHECK(it != gemm_rs_map_.end(), "Unsupported GEMM+RS size: ", M, "x", N, "x", K, "-", bias); + auto func = it->second; + func(x_ptr, w_ptr, b_ptr, c_ptr, signal_buf, out, rank, stream); + } + +}; + + + +void launch_gemm(const void *x, const void *w, const void *b, void *out, int M, int N, int K, hipStream_t stream) { + auto *x_ptr = reinterpret_cast(x); + auto *w_ptr = reinterpret_cast(w); + auto *b_ptr = reinterpret_cast(b); + auto *c_ptr = reinterpret_cast(out); + GEMMFactory::get()->launch_gemm(x_ptr, w_ptr, b_ptr, c_ptr, M, N, K, stream); +} + +void launch_rs(const void *inputs[], void *output, int rank, int world_size, int M, int N, hipStream_t stream) { + auto *inputs_ptr = reinterpret_cast(inputs); + auto *output_ptr = reinterpret_cast(output); + std::vector signal_buf(world_size, reinterpret_cast(GEMMFactory::get()->get_completed_signal_for_next(stream))); + GEMMFactory::get()->launch_rs(inputs_ptr, output_ptr, signal_buf.data(), rank, world_size, M, N, stream); +} + +void launch_gemm_rs_dist(const void *x, const void *w,const void *b, void *rs_buf[], void *sig_buf[], void *output, int rank, int world_size, int M, int N, int K, hipStream_t stream) { + auto * x_ptr = reinterpret_cast(x); + auto * w_ptr = reinterpret_cast(w); + auto * b_ptr = reinterpret_cast(b); + auto * out_ptr = reinterpret_cast(output); + auto *rs_buf_ptr = reinterpret_cast(const_cast(rs_buf)); + auto *sig_buf_ptr = reinterpret_cast(const_cast(sig_buf)); + GEMMFactory::get()->launch_gemm_rs(x_ptr, w_ptr, b_ptr, rs_buf_ptr, sig_buf_ptr, out_ptr, rank, M, N, K, world_size, stream); +} + +void launch_gemm_rs(const void *x, const void *w,const void *b, void *rs_buf[], void *sig_buf, void *output, int rank, int world_size, int M, int N, int K, hipStream_t stream) { + std::vector signal_buf(world_size); + for (int i = 0; i < world_size; i++) { + if (i == rank) { + signal_buf[i] = sig_buf; + } else { + signal_buf[i] = GEMMFactory::get()->get_completed_signal_for_next(stream); + } + + } + launch_gemm_rs_dist(x, w, b, rs_buf, signal_buf.data(), output, rank, world_size, M, N, K, stream); +} + +} // namespace gemm_rs \ No newline at end of file diff --git a/dist-infer/gemm-rs/src/gemm_rs_kernel.h b/dist-infer/gemm-rs/src/gemm_rs_kernel.h new file mode 100644 index 0000000..0b60550 --- /dev/null +++ b/dist-infer/gemm-rs/src/gemm_rs_kernel.h @@ -0,0 +1,20 @@ +#pragma once +#include + + +namespace gemm_rs { + +constexpr int MAX_WORLD_SIZE = 8; +constexpr size_t MAX_IPC_MEM_SIZE = 256 * (1UL << 20); // 256MB +constexpr size_t SIGNAL_BUF_SIZE = 1 * (1UL << 20); // 1MB + +void launch_gemm(const void *x, const void *w, const void *b, void *out, int M, int N, int K, hipStream_t stream); + +void launch_rs(const void *inputs[], void *output, int rank, int world_size, int M, int N, hipStream_t stream); + +void launch_gemm_rs(const void *x, const void *w, const void *b, 
void *rs_buf[], void *sig_buf, void *output, int rank, int world_size, int M, int N, int K, hipStream_t stream); + +void launch_gemm_rs_dist(const void *x, const void *w, const void *b, void *rs_buf[], void *sig_buf[], void *output, int rank, int world_size, int M, int N, int K, hipStream_t stream); + + +} // namespace gemm_rs \ No newline at end of file diff --git a/dist-infer/gemm-rs/src/perf_gemm.cc b/dist-infer/gemm-rs/src/perf_gemm.cc new file mode 100644 index 0000000..bc4523c --- /dev/null +++ b/dist-infer/gemm-rs/src/perf_gemm.cc @@ -0,0 +1,1085 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#define FAST_UNSAFE_CAST +// #define DEBUG_SIGNAL_CYCLE +// #define SWIZZLE_XCD_PID +// #define SWIZZLE_L2_TILE + +#define FORCE_INLINE __attribute__((always_inline)) + +namespace roc_isa { + constexpr int AMDGCN_WAVEFRONT_SIZE = 64; +namespace issue_latency { + constexpr int v_mfma_f32_16x16x16_bf16 = 4; + constexpr int ds_read_b128 = 2 * 4; + constexpr int ds_write_b128 = 5 * 4; + constexpr int buffer_load_dwordx2 = 1 * 2; +} // namespace issue_latency + +// Trait for different GEMM size categories +enum class GemmSizeCategory { + LARGE, // 256x224, 224x256, 256x256 + MIDDLE, // 256x128, 128x256, 128x128 + SMALL // 128x64, 64x128, 128x32, 32x128, 64x32, 32x32 +}; + +template +struct GemmSizeTrait { + // Large GEMM + // (256x224) MFMA 224, DS_READ 30, DS_WRITE 15, BUFFER_LOAD 30 + // Middle GEMM + // (256X128) MFMA 128, DS_READ_24, DS_WRITE 12, BUFFER_LOAD 24 + // (128X128) MFMA 64, DS_READ 16, DS_WRITE 8, BUFFER_LOAD 16 + // Small GEMM + // (128x64) MFMA 32, DS_WRITE 6, DS_READ 12, BUFFER_LOAD 12 + // (128x32) MFMA 16, DS_WRITE 5, DS_READ 8 , BUFFER_LOAD 10 + // (64x32) MFMA 8, DS_WRITE 3, DS_READ 6 , BUFFER_LOAD 6 + static constexpr GemmSizeCategory category = + ((BM == 256 && BN == 224) || (BM == 224 && BN == 256) || (BM == 256 && BN == 256)) ? GemmSizeCategory::LARGE : + ((BM == 256 && BN == 128) || (BM == 128 && BN == 256) || (BM == 128 && BN == 128)) ? 
GemmSizeCategory::MIDDLE : + GemmSizeCategory::SMALL; +}; + +// Schedule configuration trait based on category +template +struct ScheduleConfig; + +template<> +struct ScheduleConfig { + // Stage 1: DS_WRITE(1) -> MFMA(2) -> VMEM(1) -> MFMA(3) + static constexpr int stage1_ds_write = 1; + static constexpr int stage1_mfma_before_vmem = 2; + static constexpr int stage1_vmem = 1; + static constexpr int stage1_mfma_after_vmem = 3; + + // Stage 2: MFMA(2) -> DS_READ(1) + static constexpr int stage2_mfma = 2; + static constexpr int stage2_ds_read = 1; +}; + +template<> +struct ScheduleConfig { + // Stage 1: DS_WRITE(1) -> MFMA(2) -> VMEM(1) -> MFMA(1) + static constexpr int stage1_ds_write = 1; + static constexpr int stage1_mfma_before_vmem = 2; + static constexpr int stage1_vmem = 1; + static constexpr int stage1_mfma_after_vmem = 1; + + // Stage 2: MFMA(1) -> DS_READ(2) + static constexpr int stage2_mfma = 1; + static constexpr int stage2_ds_read = 2; +}; + +template<> +struct ScheduleConfig { + // Stage 1: DS_WRITE(1) -> MFMA(1) -> VMEM(1) -> MFMA(1) + static constexpr int stage1_ds_write = 1; + static constexpr int stage1_mfma_before_vmem = 1; + static constexpr int stage1_vmem = 1; + static constexpr int stage1_mfma_after_vmem = 1; + + // Stage 2: MFMA(1) -> DS_READ(1) + static constexpr int stage2_mfma = 1; + static constexpr int stage2_ds_read = 1; +}; + +template +struct InstCalculator { + static constexpr int v_mfma_f32_16x16x16_bf16 = (BM * BN * BK) / (WARP_M * WARP_N) / (16*16*16); + // Compiler will merge two ds_{read,write}_b64 to ds_{read,write`}2st64_b64 + static constexpr int ds_read_b128_a = (BM * BK / WARP_M) / 64 / 8; + static constexpr int ds_read_b128_b = (BN * BK / WARP_N) / 64 / 8; + static constexpr int ds_read_b128 = ds_read_b128_a + ds_read_b128_b; + static constexpr int ds_write_b128_a = (BM * BK) / NUM_THREADS / 8; + static constexpr int ds_write_b128_b = (BN * BK) / NUM_THREADS / 8; + static constexpr int ds_write_b128 = ds_write_b128_a + ds_write_b128_b; + static constexpr int buffer_load_dwordx2_a = (BM * BK) / NUM_THREADS / 4; + static constexpr int buffer_load_dwordx2_b = (BN * BK) / NUM_THREADS / 4; + static constexpr int buffer_load_dwordx2 = buffer_load_dwordx2_a + buffer_load_dwordx2_b; + + // Get schedule configuration based on BM and BN + using size_trait = GemmSizeTrait; + using schedule_config = ScheduleConfig; +}; + +} // namespace roc_isa + +namespace test { + constexpr int BM = 128, BN = 128; + constexpr int WARP_M = 2, WARP_N = 2, NUM_THREADS = 256, BK = 64; + constexpr int MFMA_NUM = roc_isa::InstCalculator::v_mfma_f32_16x16x16_bf16; + constexpr int DS_READ_NUM = roc_isa::InstCalculator::ds_read_b128; + constexpr int DS_WRITE_NUM = roc_isa::InstCalculator::ds_write_b128; + constexpr int BUFFER_LOAD_NUM = roc_isa::InstCalculator::buffer_load_dwordx2; +} + + +using bfloat16_t = __bf16; + +__device__ __host__ FORCE_INLINE constexpr int ceil_div(int a, int b) { + return (a + b - 1) / b; +} + +template +__device__ __host__ FORCE_INLINE constexpr int exact_div() { + static_assert(a % b == 0); + return a / b; +} + +__device__ __host__ FORCE_INLINE constexpr int i_min(int a, int b) { + return a < b ? a : b; +} + +__device__ __host__ FORCE_INLINE constexpr int i_max(int a, int b) { + return a > b ? 
a : b; +} + + +template +struct PackN_t { + using t = __attribute__((vector_size(N * sizeof(dtype)))) dtype; + static constexpr auto n = N; + static constexpr auto H = N / 2; + union { + dtype x[N]; + t pack; + struct { dtype low[H], high[H]; }; + }; +}; + +using bf16x4_t = PackN_t; +using fp32x4_t = PackN_t; +using bf16x8_t = PackN_t; +using fp32x8_t = PackN_t; +using i32x4_t = PackN_t; + +#define FORCE_INLINE __attribute__((always_inline)) + + +__device__ ck::int32x4_t inline make_wave_buffer_resource(const void* ptr, uint32_t size = 0xffffffff) { + ck::int32x4_t res; + + // Pack the 64-bit pointer into two 32-bit integers + uint64_t ptr_val = reinterpret_cast(ptr); + res.x = static_cast(ptr_val); + res.y = static_cast(ptr_val >> 32); + + // Set buffer size and format + res.z = size; // Buffer size in bytes + res.w = 0x00020000; // hardcoded for gfx942 + + res.x = __builtin_amdgcn_readfirstlane(res.x); + res.y = __builtin_amdgcn_readfirstlane(res.y); + res.z = __builtin_amdgcn_readfirstlane(res.z); + res.w = __builtin_amdgcn_readfirstlane(res.w); + return res; +} + +__device__ FORCE_INLINE bfloat16_t fast_f32tob16(float f) { +#ifdef FAST_UNSAFE_CAST + union { + float fp32; + unsigned int u32; + } u = {f}; + u.u32 += 0x7FFF + ((u.u32 >> 16) & 1); + auto ret = u.u32 >> 16; + return reinterpret_cast(ret); +#else + return static_cast(f); +#endif +} + + +__device__ FORCE_INLINE float fast_b16tof32(bfloat16_t bf) { +#ifdef FAST_UNSAFE_CAST + union { + float fp32; + unsigned int u32; + } u; + u.u32 = (reinterpret_cast(bf)) << 16; + return u.fp32; +#else + return static_cast(bf); +#endif +} + +__device__ void block_sync_lds() { + __builtin_amdgcn_s_waitcnt(0xc07f); + __builtin_amdgcn_s_barrier(); +} + +__device__ void block_sync_gds() { + __builtin_amdgcn_s_waitcnt(0xf70); + __builtin_amdgcn_s_barrier(); +} + +template +__device__ FORCE_INLINE void compute_tile_indices( + int tile_id, + int &tile_m_id, + int &tile_n_id +) { + if constexpr (GROUP_SIZE_N == 0) { + // No swizzle + tile_m_id = tile_id / num_tile_n; + tile_n_id = tile_id % num_tile_n; + } else { + // Swizzle pattern for better L2 cache locality + // Groups tiles in blocks of num_tile_m x GROUP_SIZE_N + constexpr int num_pid_in_group = num_tile_m * GROUP_SIZE_N; + + // Which group does this tile belong to? + const int group_id = tile_id / num_pid_in_group; + + // First N-dimension tile in this group + const int first_pid_n = group_id * GROUP_SIZE_N; + + // Actual group size (handling boundary case) + const int group_size_n = min(GROUP_SIZE_N, num_tile_n - first_pid_n); + + // Position within the group + const int idx_in_group = tile_id % num_pid_in_group; + + // Swizzled tile indices: alternate N then M within group + tile_n_id = first_pid_n + (idx_in_group % group_size_n); + tile_m_id = idx_in_group / group_size_n; + } +} + +// M-dimension grouped version for better L2 cache locality when M > N +template +__device__ FORCE_INLINE void compute_tile_indices_m_grouped( + int tile_id, + int &tile_m_id, + int &tile_n_id +) { + if constexpr (GROUP_SIZE_M == 0) { + // No swizzle + tile_m_id = tile_id % num_tile_m; + tile_n_id = tile_id / num_tile_m; + } else { + // Swizzle pattern for better L2 cache locality + // Groups tiles in blocks of GROUP_SIZE_M x num_tile_n + constexpr int num_pid_in_group = GROUP_SIZE_M * num_tile_n; + + // Which group does this tile belong to? 
+ const int group_id = tile_id / num_pid_in_group; + + // First M-dimension tile in this group + const int first_pid_m = group_id * GROUP_SIZE_M; + + // Actual group size (handling boundary case) + const int group_size_m = min(GROUP_SIZE_M, num_tile_m - first_pid_m); + + // Position within the group + const int idx_in_group = tile_id % num_pid_in_group; + + // Swizzled tile indices: alternate M then N within group + tile_m_id = first_pid_m + (idx_in_group % group_size_m); + tile_n_id = idx_in_group / group_size_m; + } +} + +const int WORLD_SIZE = 8; +constexpr size_t MAX_IPC_MEM_SIZE = 256 * (1UL << 20); // 256MB +constexpr size_t SIGNAL_BUF_SIZE = 1 * (1UL << 20); // 1MB +constexpr int NUM_THREADS = 256; +constexpr int NUM_SMS = 304; +// Fused kernel: pid < NUM_GEMM_SMS runs GEMM, next NUM_RS_SMS pids run ReduceScatter +template< + int M, int N, int K, bool LOAD_BIAS, + int BM, int BN, int BK, + int NUM_SMS, int NUM_GEMM_SMS, int NUM_RS_SMS, + int NUM_THREADS, int WARP_M, int WARP_N, int GROUP_SIZE_N, + int SPLIT_K = 1 +> +__launch_bounds__(NUM_THREADS) +__global__ void fused_gemm_rs_kernel( + const bfloat16_t *x, // GEMM: M x K + const bfloat16_t *w, // GEMM: N x K + const bfloat16_t *b, // GEMM: N + const std::array c_all, // unified C buffers for all ranks (GEMM writes c_all[rank], RS reads all) + const std::array signal_all, // unified signal buffers for all ranks (GEMM writes signal_all[rank], RS reads all) + int signal_val, + bfloat16_t *rs_out, // RS out: [M / WORLD_SIZE, N] + float *workspace, // temporary FP32 workspace: [SPLIT_K, M, N] + int rank +) { + static_assert(SPLIT_K >= 1, "SPLIT_K must be >= 1"); + static_assert(K % SPLIT_K == 0, "K must be divisible by SPLIT_K"); + + const int pid = __builtin_amdgcn_readfirstlane(blockIdx.x); + const int tid = threadIdx.x; + const int lane_id = __lane_id(); + const int warp_id = __builtin_amdgcn_readfirstlane(tid / roc_isa::AMDGCN_WAVEFRONT_SIZE); + __builtin_assume(pid >= 0 && pid < NUM_SMS); + __builtin_assume(tid >= 0 && tid < NUM_THREADS); + __builtin_assume(lane_id >= 0 && lane_id < 64); + + if (pid < NUM_GEMM_SMS) { + // GEMM Kernel + constexpr int num_tile_m = ceil_div(M, BM); + constexpr int num_tile_n = ceil_div(N, BN); + constexpr int num_tiles = num_tile_m * num_tile_n * SPLIT_K; + // each split handles K_per_split + constexpr int K_per_split = exact_div(); + constexpr int num_tile_k = ceil_div(K_per_split, BK); + + using inst_nums = roc_isa::InstCalculator; + static_assert(BK % 4 == 0 && NUM_THREADS * 4 % BK == 0); + constexpr int WM = 16, WN = 16, WK = 16; + + constexpr int Frag_M = exact_div(); + constexpr int Frag_N = exact_div(); + constexpr int Frag_K = exact_div(); + const int warp_m = warp_id / WARP_N; + const int warp_n = warp_id % WARP_N; + using FragX = bf16x4_t; + using FragW = bf16x4_t; + using FragC = fp32x4_t; + __shared__ bfloat16_t s_x[BM][BK]; + __shared__ bfloat16_t s_w[BN][BK]; + bf16x4_t vgpr_x[ceil_div(BM * BK, NUM_THREADS * 4)]; + bf16x4_t vgpr_w[ceil_div(BN * BK, NUM_THREADS * 4)]; + + FragC frag_c[Frag_M][Frag_N]; + FragX frag_x[Frag_M][Frag_K]; + FragW frag_w[Frag_N][Frag_K]; + fp32x4_t out_fp32[Frag_M][Frag_N]; // AccVGPR -> VGPR Buffer + auto b_arr = ck::make_wave_buffer_resource(const_cast(b), N); + auto c_arr = ck::make_wave_buffer_resource(c_all[rank], M * N); + auto x_arr = ck::make_wave_buffer_resource(const_cast(x), M * K); + auto w_arr = ck::make_wave_buffer_resource(const_cast(w), N * K); + auto *signal_arr = reinterpret_cast(signal_all[rank]); + + + // TODO: optimized use 
ck::amd_buffer_store + int last_tile_m_id = num_tile_m, last_tile_n_id = num_tile_n; + auto release_signal = [&]() FORCE_INLINE { + if (tid == 0) { + __hip_atomic_store(&signal_arr[last_tile_m_id][last_tile_n_id], signal_val, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM); + } + + }; + + constexpr int NUM_XCDS = 8; + for (int tile_id=pid; tile_id(tile_id / SPLIT_K, tile_m_id, tile_n_id); + int m = tile_m_id * BM; + int n = tile_n_id * BN; + + int k_offset = split_k_id * K_per_split * sizeof(bfloat16_t); + int v_offset = ((tid * 4 / BK) * K + (tid * 4 % BK)) * sizeof(bfloat16_t); + auto load_vgpr = [&](int k) FORCE_INLINE { + uint32_t src_addr_shift = ((K_per_split % BK == 0) || (k + tid * 4 % BK < K_per_split)) ? 0 : 0x80000000; + ck::static_for<0, sizeof(vgpr_x) / sizeof(vgpr_x[0]), 1>{}([&](auto t) { + int s_offset = ((m * K + k) + t * NUM_THREADS * 4 / BK * K) * sizeof(bfloat16_t); + vgpr_x[t] = __builtin_bit_cast(bf16x4_t, ck::amd_buffer_load_impl_raw( + x_arr, v_offset + src_addr_shift, s_offset + k_offset)); + }); + ck::static_for<0, sizeof(vgpr_w) / sizeof(vgpr_w[0]), 1>{}([&](auto t) { + int s_offset = ((n * K + k) + t * NUM_THREADS * 4 / BK * K) * sizeof(bfloat16_t); + vgpr_w[t] = __builtin_bit_cast(bf16x4_t, ck::amd_buffer_load_impl_raw( + w_arr, v_offset + src_addr_shift, s_offset + k_offset)); + }); + + }; + + + auto load_lds = [&]() FORCE_INLINE { + // diagonal swizzle, shape=[16, 64] dtype=bfloat16 + #pragma unroll + for (int t=0;t(&s_x[row0 + row1][col1]) = vgpr_x[t]; + } + #pragma unroll + for (int t=0;t(&s_w[row0 + row1][col1]) = vgpr_w[t]; + } + }; + + auto zero_all_frags = [&]() FORCE_INLINE { + ck::static_for<0, Frag_M, 1>{}([&](auto i) { + ck::static_for<0, Frag_N, 1>{}([&](auto j) { + ck::static_for<0, 4, 1>{}([&](auto t) { frag_c[i][j].x[t] = 0; }); + }); + }); + }; + + + + auto frags_load = [&]() { + #pragma unroll + for (int k=0;k(&s_x[row0 + row1][col1]); + } + } + #pragma unroll + for (int k=0;k(&s_w[row0 + row1][col1]); + } + } + }; + + auto frags_mfma = [&] { + #pragma unroll + for (int i=0;i VGPR + #pragma unroll + for (int i=0; i( + b_arr, b_v_offset, b_s_offset)); + #pragma unroll + for (int t = 0; t < 4; ++t) { + out_fp32[i][j].x[t] += static_cast(LOAD_BIAS ? b_vec.x[t] : 0); + } + } + } + + ck::static_for<0, Frag_M, 1>{}([&](auto i) { + ck::static_for<0, Frag_N, 1>{}([&](auto j) { + bf16x4_t c_out_bf16; + #pragma unroll + for (int t = 0; t < 4; ++t) c_out_bf16.x[t] = fast_f32tob16(out_fp32[i][j].x[t]); + int row = lane_id % 16; + int col = lane_id / 16 * 4; + uint32_t src_addr_shift = (N % BN == 0) || (n + (j + warp_n * Frag_N) * WN + col < N) ? 0 : 0x80000000; + int b_s_offset = (n + (j + warp_n * Frag_N) * WN) * sizeof(bfloat16_t); + int c_s_offset = b_s_offset + (m + (i + warp_m * Frag_M) * WM) * N * sizeof(bfloat16_t); + int c_v_offset = col * sizeof(bfloat16_t) + src_addr_shift + (row * N) * sizeof(bfloat16_t); + ck::amd_buffer_store_impl_raw(c_out_bf16.pack, c_arr, c_v_offset, c_s_offset); + }); + }); + } else { + // SPLIT_K > 1: store FP32 partials into workspace [split_id, M, N] (row-major floats) + auto ws_arr = ck::make_wave_buffer_resource(workspace + split_k_id * M * N, M * N); + ck::static_for<0, Frag_M, 1>{}([&](auto i) { + ck::static_for<0, Frag_N, 1>{}([&](auto j) { + int row = lane_id % 16; + int col = lane_id / 16 * 4; + uint32_t src_addr_shift = (N % BN == 0) || (n + (j + warp_n * Frag_N) * WN + col < N) ? 
0 : 0x80000000; + int b_s_offset = (n + (j + warp_n * Frag_N) * WN) * sizeof(float); + int c_s_offset = b_s_offset + (m + (i + warp_m * Frag_M) * WM) * N * sizeof(float); + int c_v_offset = col * sizeof(float) + src_addr_shift + (row * N) * sizeof(float); + ck::amd_buffer_store_impl_raw(out_fp32[i][j].pack, ws_arr, c_v_offset, c_s_offset); + }); + }); + } + }; + + + + + load_vgpr(0); // GDS -> VGPR #0 + load_lds(); // VGPR -> LDS #0 + load_vgpr(1 * BK); // GDS -> VGPR #1 + zero_all_frags(); + // __builtin_amdgcn_s_waitcnt(0x70); + // __builtin_amdgcn_s_barrier(); + block_sync_lds(); + // release_signal(); + frags_load(); // LDS -> FRAG #0 + __builtin_amdgcn_sched_barrier(0); + for (int tile_k_id = 1; tile_k_id < (num_tile_k - 1); ++tile_k_id) { + block_sync_lds(); + // Stage 1 + load_lds(); // VGPR -> LDS #1 + load_vgpr((tile_k_id + 1) * BK); // GDS -> VGPR #2(k+1) + frags_mfma(); // MFMA #0(k-1) + #pragma unroll + for (int k = 0; k < inst_nums::buffer_load_dwordx2; ++k) { + __builtin_amdgcn_sched_group_barrier(0x200, inst_nums::schedule_config::stage1_ds_write, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, inst_nums::schedule_config::stage1_mfma_before_vmem, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, inst_nums::schedule_config::stage1_vmem, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, inst_nums::schedule_config::stage1_mfma_after_vmem, 0); // MFMA + } + block_sync_lds(); + // Stage 2 + frags_load(); // LDS -> FRAG #1(k) + #pragma unroll + for (int k = 0; k < inst_nums::ds_read_b128; ++k) { + __builtin_amdgcn_sched_group_barrier(0x008, inst_nums::schedule_config::stage2_mfma, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, inst_nums::schedule_config::stage2_ds_read, 0); // DS read + } + __builtin_amdgcn_sched_barrier(0); + + } + frags_mfma(); // MFMA #1(n-2) + block_sync_lds(); + load_lds(); // VGPR -> LDS #2(n-1) + block_sync_lds(); + frags_load(); // LDS -> FRAG #2(n-1) + frags_mfma(); // MFMA #2(n-1) + + // __builtin_amdgcn_sched_barrier(0); + store_frags(); + __syncthreads(); + if (tid == 0) { + __hip_atomic_store(&signal_arr[tile_m_id][tile_n_id], signal_val, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM); + } + } + // block_sync_gds(); + // release_signal(); + } else if (pid >= NUM_SMS - NUM_RS_SMS) { + const int pid_rs = pid - (NUM_SMS - NUM_RS_SMS); + const int initial_tile_id = pid_rs * WARP_M * WARP_N + warp_id; + const int batch_tile = WARP_M * WARP_N * NUM_RS_SMS; + constexpr int M_per_rank = M / WORLD_SIZE; + constexpr int num_tile_m = ceil_div(M_per_rank, BM); + constexpr int num_tile_n = ceil_div(N, BN); + i32x4_t swizzle_c[WORLD_SIZE]; // SGPR + int* swizzle_signal = signal_all[(lane_id / 2) % WORLD_SIZE]; // VGPR + // TODO: benchmark PIPELINE_STAGES + constexpr int TOTAL_STAGES = BM * BN / 64 / 8; + constexpr int PIPELINE_STAGES = i_min(4, i_max(TOTAL_STAGES / 2, 1)); + bf16x8_t reg_c_buf[PIPELINE_STAGES][WORLD_SIZE]; + fp32x8_t reg_c_sum[PIPELINE_STAGES]; + ck::static_for<0, WORLD_SIZE, 1>{}([&](auto k) { + int r = __builtin_amdgcn_readfirstlane((pid_rs + k) % WORLD_SIZE); + swizzle_c[k].pack = ck::make_wave_buffer_resource(c_all[r] + rank * M_per_rank * N, M_per_rank * N); + }); + + for (int tile_id=initial_tile_id; tile_id(tile_id, tile_m_id, tile_n_id); + + int signal_id_m_begin = (rank * M_per_rank + tile_m_id * BM) / BM; + int signal_id_m_close = std::min(rank * M_per_rank + M_per_rank - 1, rank * M_per_rank + tile_m_id * BM + BM - 1) / BM; + int signal_id_n = tile_n_id; + +#ifdef DEBUG_SIGNAL_CYCLE + long long 
begin; + if (M == 8192 && N == 8192 && rank == 0 && tid == 0) { + begin = clock64(); + } +#endif + + // __builtin_amdgcn_sched_barrier(0); + // __builtin_amdgcn_s_setprio(0); + // __builtin_amdgcn_sched_barrier(0); + + if (lane_id < WORLD_SIZE * 2) { + const int signal_id = lane_id % 2; + const int target_rank = lane_id / 2; + if (signal_id_m_begin + signal_id <= signal_id_m_close) { + auto *signal_arr = reinterpret_cast(swizzle_signal); + while (__hip_atomic_load(&signal_arr[signal_id_m_begin + signal_id][signal_id_n], __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM) != signal_val) { + // __builtin_amdgcn_s_sleep(4); + } + } + } + + // __builtin_amdgcn_sched_barrier(0); + // __builtin_amdgcn_s_setprio(1); + // __builtin_amdgcn_sched_barrier(0); + // __threadfence_system(); + + // #pragma unroll + // for (int r = 0; r < WORLD_SIZE; ++r) { + // // int v_offset = signal_id_m_begin + lane_id <= signal_id_m_close ? lane_id * N * sizeof(int) : 0x80000000; + // // int s_offset = (signal_id_m_begin * N + signal_id_n) * sizeof(int); + // // while (!(v_offset == 0x80000000 || ck::amd_buffer_load_impl_raw(swizzle_signal[r].pack, v_offset, s_offset) == signal_val)) { + + // // } + // if (signal_id_m_begin + lane_id <= signal_id_m_close) { + // auto *signal_arr = reinterpret_cast(signal_all[r]); + // while (__hip_atomic_load(&signal_arr[signal_id_m_begin + lane_id][signal_id_n], __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM) != signal_val) {} + // } + // } +#ifdef DEBUG_SIGNAL_CYCLE + if (tid == 0 && M == 8192 && N == 8192 && rank == 0) { + auto duration = clock64() - begin; + printf("rank0 wait signal: %lld cycles\n", duration); + } +#endif + // __threadfence_system(); + + auto buf_load = [&](int t) FORCE_INLINE { + int v_offset = ((lane_id * 8 / BN) * N + (lane_id * 8 % BN)) * sizeof(bfloat16_t); + int s_offset = ((t * 8 * 64 / BN + tile_m_id * BM) * N + (tile_n_id * BN)) * sizeof(bfloat16_t); + uint32_t src_addr_shift = ((N % BN == 0) || ((tile_n_id * BN) + (lane_id * 8 % BN) < N)) ? 0 : 0x80000000; + int stage = t % PIPELINE_STAGES; + ck::static_for<0, WORLD_SIZE, 1>{}([&](auto r) { + reg_c_buf[stage][r] = __builtin_bit_cast(bf16x8_t, ck::amd_buffer_load_impl_raw( + swizzle_c[r].pack, v_offset + src_addr_shift, s_offset)); + }); + }; + + auto compute_store = [&](int t) FORCE_INLINE { + int stage = t % PIPELINE_STAGES; + int v_offset = ((lane_id * 8 / BN) * N + (lane_id * 8 % BN)) * sizeof(bfloat16_t); + int s_offset = ((t * 8 * 64 / BN + tile_m_id * BM) * N + (tile_n_id * BN)) * sizeof(bfloat16_t); + uint32_t src_addr_shift = ((N % BN == 0) || ((tile_n_id * BN) + (lane_id * 8 % BN) < N)) ? 
0 : 0x80000000; + ck::static_for<0, WORLD_SIZE, 1>{}([&](auto r) { + ck::static_for<0, 8, 1>{}([&](auto k) { + if constexpr(r == 0) reg_c_sum[stage].x[k] = reg_c_buf[stage][r].x[k]; + else reg_c_sum[stage].x[k] += reg_c_buf[stage][r].x[k]; + }); + }); + auto out_arr = ck::make_wave_buffer_resource(rs_out, M_per_rank * N); + bf16x8_t out_val; + ck::static_for<0, 8, 1>{}([&](auto k) { out_val.x[k] = static_cast(reg_c_sum[stage].x[k]); }); + ck::amd_buffer_store_impl_raw(out_val.pack, out_arr, v_offset + src_addr_shift, s_offset); + }; + + + + #pragma unroll + for (int t = 0; t < PIPELINE_STAGES; ++t) { + buf_load(t); + } + if constexpr (PIPELINE_STAGES < TOTAL_STAGES) { + #pragma unroll PIPELINE_STAGES + for (int t = 0; t < TOTAL_STAGES - PIPELINE_STAGES; t++) { + compute_store(t); + buf_load(t + PIPELINE_STAGES); + } + } + #pragma unroll + for (int t = TOTAL_STAGES - PIPELINE_STAGES; t < TOTAL_STAGES; t++) { + compute_store(t); + } + } + + } +} + +template +__launch_bounds__(BLOCK_SIZE) +__global__ void reduce_kernel(const float *workspace, const bfloat16_t *b, bfloat16_t *c) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int idx = tid * 4; + if (idx >= M * N) return; + static_assert(N % 4 == 0, "N must be multiple of 4"); + auto *ws_arr = reinterpret_cast(workspace); // [SPLITK, M, N] + auto *b_arr = reinterpret_cast(b); // [N] + auto *c_arr = reinterpret_cast(c); // [M, N] + fp32x4_t sum = {}; + int row = idx / N, col = idx % N; + #pragma unroll + for (auto k = 0; k < SPLITK_K; ++k) { + fp32x4_t data = *reinterpret_cast(&ws_arr[k][row][col]); + ck::static_for<0, 4, 1>{}([&](auto i) { sum.x[i] += data.x[i]; }); + } + if constexpr (LOAD_BIAS) { + bf16x4_t bias = *reinterpret_cast(&b_arr[col]); + ck::static_for<0, 4, 1>{}([&](auto i) { sum.x[i] += static_cast(bias.x[i]); }); + } + bf16x4_t out; + ck::static_for<0, 4, 1>{}([&](auto i) { out.x[i] = fast_f32tob16(sum.x[i]); }); + *reinterpret_cast(&c_arr[row][col]) = *reinterpret_cast(&out); +} + +template< + int M, int N, int K, int LOAD_BIAS, + int BM, int BN, int BK, + int NUM_SMS, int NUM_GEMM_SMS, int NUM_RS_SMS, + int NUM_THREADS, int WARP_M, int WARP_N, int GROUP_SIZE_N, + int SPLIT_K = 1 +> +void kernel_launcher( + const bfloat16_t* x, const bfloat16_t* w, const bfloat16_t* b, + const std::array& c_all, + const std::array& signal_all, + int signal_val, bfloat16_t* rs_out, float* workspace, int rank) +{ + dim3 grid(NUM_SMS); + dim3 block(NUM_THREADS); + auto stream = at::cuda::getCurrentHIPStream().stream(); + if constexpr (SPLIT_K == 1) { + hipLaunchKernelGGL( + HIP_KERNEL_NAME(fused_gemm_rs_kernel), + grid, block, 0, stream, + x, w, b, c_all, signal_all, signal_val, rs_out, + /*workspace*/ static_cast(nullptr), + rank); + } else { + hipLaunchKernelGGL( + HIP_KERNEL_NAME(fused_gemm_rs_kernel), + grid, block, 0, stream, + x, w, b, c_all, signal_all, signal_val, rs_out, + workspace, + rank); + constexpr int BLOCK_SIZE = 512; + dim3 block_reduce(BLOCK_SIZE); + dim3 grid_reduce(exact_div()); + hipLaunchKernelGGL( + HIP_KERNEL_NAME(reduce_kernel), + grid, block, 0, stream, + workspace, b, c_all[rank]); + } +} + +#ifndef __PERF_GEMM_HEADER__ + +using KernelFn = void (*)(const bfloat16_t* x, const bfloat16_t* w, const bfloat16_t* b, + const std::array& c_all, + const std::array& signal_all, + int signal_val, bfloat16_t* rs_out, float* workspace, int rank); + +struct KernelRegistery { + std::unordered_map gemm_map; + std::unordered_map rs_map; + std::unordered_map fused_full_map; + + void* workspace_ptr = nullptr; + size_t 
workspace_size = 0; +}; + +union ShapeKey { + struct { uint16_t M, N, K, PAD = 0; } data; + int64_t key; + static_assert(sizeof(data) == sizeof(key)); +}; + + + +KernelRegistery& get_kernel_registry() { + static KernelRegistery registry; + constexpr bool WB = true; // w/ bias + constexpr bool WO = false; // w/o bias +#ifndef __GPUMODE_BENCHMARK__ +#define REGISTER_KERNELS(M,N,K,LOAD_BIAS,BM,BN,BK,WARP_M,WARP_N,GROUP_SIZE_N,NUM_RS_SMS,SPLIT_K) \ + registry.rs_map[ShapeKey{M,N,0}.key] = kernel_launcher; \ + registry.gemm_map[ShapeKey{M,N,K}.key] = kernel_launcher; \ + registry.fused_full_map[ShapeKey{M,N,K}.key] = kernel_launcher; +// #define REGISTER_KERNELS(M,N,K,LOAD_BIAS,BM,BN,BK,WARP_M,WARP_N,GROUP_SIZE_N,NUM_RS_SMS,SPLIT_K) \ +// registry.gemm_map[ShapeKey{M,N,K}.key] = kernel_launcher; + + + if (__builtin_expect(registry.workspace_ptr == nullptr, false)) { + // minimal registration updated with SPLIT_K (use 1 for now) + // REGISTER_KERNELS(64, 7168, 2304, WB, 32, 64, 64, 2, 2, 16, 8, 1); + // REGISTER_KERNELS(512, 4096, 1536, WB, 64, 128, 64, 2, 2, 16, 8, 1); + // REGISTER_KERNELS(2048, 2880, 360, WB, 128, 128, 64, 2, 2, 16, 8, 1); + // REGISTER_KERNELS(4096, 4096, 512, WB, 224, 256, 64, 2, 2, 16, 8, 1); + // REGISTER_KERNELS(8192, 4096, 1792, WB, 224, 256, 64, 2, 2, 8, 8, 1); + // REGISTER_KERNELS(8192, 8192, 3696, WB, 224, 256, 64, 2, 2, 8, 8, 1); + + // AG + // REGISTER_KERNELS(64, 2304, 7168, WB, 64, 64, 64, 2, 2, 0, 4, 16); + // REGISTER_KERNELS(512, 1536, 4096, WB, 64, 128, 64, 2, 2, 16, 4, 1); + // REGISTER_KERNELS(2048, 360, 2880, WB, 128, 64, 64, 2, 2, 16, 4, 1); + // REGISTER_KERNELS(4096, 512, 4096, WB, 256, 128, 64, 2, 2, 16, 4, 1); + // REGISTER_KERNELS(8192, 1792, 4096, WB, 256, 224, 64, 2, 2, 32, 4, 1); + // REGISTER_KERNELS(8192, 3696, 8192, WB, 256, 224, 64, 2, 2, 32, 0, 1); + + + // REGISTER_KERNELS(8192, 8192, 3696, WB, 256, 224, 64, 2, 2, 16, 8, 1); // 619.00 TFLOPS + + REGISTER_KERNELS(64 , 7168, 2304, WB, 32 , 64 , 128, 2, 2, 8 , 32, 1 ); // 90.92 TFLOPS + REGISTER_KERNELS(512 , 4096, 1536, WB, 128, 64 , 128, 2, 2, 8 , 48, 1 ); // 194.13 TFLOPS + REGISTER_KERNELS(2048, 2880, 360 , WB, 128, 256, 64, 2, 2, 8 , 48, 1 ); // 150.06 TFLOPS + REGISTER_KERNELS(4096, 4096, 512 , WB, 256, 128, 64, 2, 2, 16, 48, 1 ); // 240.03 TFLOPS + REGISTER_KERNELS(8192, 4096, 1792, WB, 224, 256, 64, 2, 2, 16, 32, 1 ); // 491.16 TFLOPS + REGISTER_KERNELS(8192, 8192, 3696, WB, 224, 256, 64, 2, 2, 0 , 8, 1 ); // 620.40 TFLOPS + + // REGISTER_KERNELS(64 , 2304, 7168, WB, 32 , 64 , 128, 2, 2, 0 , 8, 4); // 51.87 TFLOPS + // REGISTER_KERNELS(512 , 1536, 4096, WB, 64 , 64 , 128, 2, 2, 32, 8, 1); // 166.24 TFLOPS + // REGISTER_KERNELS(2048, 360 , 2880, WB, 64 , 64 , 128, 2, 2, 8 , 8, 1); // 161.53 TFLOPS + // REGISTER_KERNELS(4096, 512 , 4096, WB, 128, 64 , 128, 2, 2, 40, 8, 1); // 253.96 TFLOPS + // REGISTER_KERNELS(8192, 1792, 4096, WB, 256, 224, 64, 2, 2, 8 , 8, 1); // 494.63 TFLOPS + // REGISTER_KERNELS(8192, 3696, 8192, WB, 256, 224, 64, 2, 2, 32, 8, 1); // 577.45 TFLOPS + + constexpr size_t PREALLOC_WORKSPACE = 2 * 1024UL * 1024UL * 1024UL; // 2GB + registry.workspace_size = PREALLOC_WORKSPACE; + C10_HIP_CHECK(hipMalloc(®istry.workspace_ptr, registry.workspace_size)); + C10_HIP_CHECK(hipMemset(registry.workspace_ptr, 0, registry.workspace_size)); + } + +#undef REGISTER_KERNELS + +#else +// #define REGISTER_KERNELS(M,N,K,LOAD_BIAS,BM,BN,BK,WARP_M,WARP_N,GROUP_SIZE_N,NUM_RS_SMS,SPLIT_K) \ +// registry.fused_full_map[ShapeKey{M,N,K}.key] = kernel_launcher; + +#define 
REGISTER_KERNELS(M,N,K,LOAD_BIAS,BM,BN,BK,WARP_M,WARP_N,GROUP_SIZE_N,NUM_RS_SMS,SPLIT_K) \ + registry.rs_map[ShapeKey{M,N,0}.key] = kernel_launcher; \ + registry.gemm_map[ShapeKey{M,N,K}.key] = kernel_launcher; \ + registry.fused_full_map[ShapeKey{M,N,K}.key] = kernel_launcher; + + if (__builtin_expect(registry.gemm_map.empty(), false)) { + // // Online Benchmark parameters here + REGISTER_KERNELS(64 , 7168, 2304, WB, 32 , 64, 128, 2, 2, 8 , 48, 1 ); // 90.92 TFLOPS + REGISTER_KERNELS(512 , 4096, 1536, WB, 128, 64, 128, 2, 2, 8 , 48, 1 ); // 194.13 TFLOPS + REGISTER_KERNELS(2048, 2880, 360 , WB, 128, 256, 64, 2, 2, 8 , 48, 1 ); // 150.06 TFLOPS + REGISTER_KERNELS(4096, 4096, 512 , WB, 256, 128, 64, 2, 2, 16, 48, 1 ); // 240.03 TFLOPS + REGISTER_KERNELS(8192, 4096, 1792, WB, 224, 256, 64, 2, 2, 16, 32, 1 ); // 491.16 TFLOPS + REGISTER_KERNELS(8192, 8192, 3696, WB, 224, 256, 64, 2, 2, 0 , 8, 1 ); // 620.40 TFLOPS + // // New shapes + REGISTER_KERNELS(64, 2880, 360, WB, 32, 64, 64, 2, 2, 8, 16, 1); + REGISTER_KERNELS(64, 3584, 1792, WB, 32, 64, 64, 2, 2, 8, 16, 1); + REGISTER_KERNELS(512, 3584, 1792, WB, 64, 128, 64, 2, 2, 8, 16, 1); + REGISTER_KERNELS(512, 4608, 4608, WB, 64, 128, 64, 2, 2, 8, 16, 1); + REGISTER_KERNELS(2048, 4096, 896, WB, 128, 128, 64, 2, 2, 8, 16, 1); + REGISTER_KERNELS(2048, 8192, 3840, WB, 128, 256, 64, 2, 2, 8, 8, 1); + REGISTER_KERNELS(4096, 2880, 360, WB, 256, 128, 64, 2, 2, 8, 16, 1); + REGISTER_KERNELS(4096, 8192, 256, WB, 224, 256, 64, 2, 2, 8, 8, 1); + REGISTER_KERNELS(8192, 3584, 1792, WB, 224, 256, 64, 2, 2, 8, 8, 1); + REGISTER_KERNELS(8192, 4608, 4608, WB, 224, 256, 64, 2, 2, 8, 8, 1); + REGISTER_KERNELS(8192, 8192, 3584, WB, 224, 256, 64, 2, 2, 8, 8, 1); + + + } +#endif + + + + return registry; +} + + + +torch::Tensor launch_gemm(torch::Tensor &x, torch::Tensor &w, torch::Tensor &b, torch::Tensor &signal, int signal_val, std::optional out_opt) { + auto M = x.size(0); + auto N = w.size(0); + auto K = x.size(1); + auto out = out_opt ? 
*out_opt : torch::empty({M, N}, x.options()); + auto ®istry = get_kernel_registry(); + auto kernel_it = registry.gemm_map.find(ShapeKey{static_cast(M),static_cast(N),static_cast(K)}.key); + TORCH_CHECK(kernel_it != registry.gemm_map.end(), "Unsupported GEMM size: ", M, "x", N, "x", K); + int device = 0; + C10_HIP_CHECK(hipGetDevice(&device)); + int rank = device; + + std::array c_ptrs{}; + c_ptrs[rank] = reinterpret_cast(out.data_ptr()); + std::array signal_ptrs{}; + signal_ptrs[rank] = reinterpret_cast(signal.data_ptr()); + + kernel_it->second( + reinterpret_cast(x.const_data_ptr()), + reinterpret_cast(w.const_data_ptr()), + reinterpret_cast(b.const_data_ptr()), + c_ptrs, signal_ptrs, signal_val, + reinterpret_cast(out.data_ptr()), + static_cast(registry.workspace_ptr), + rank); + return out; +} + +torch::Tensor launch_reduce_scatter(std::array &c, std::array &signal, int signal_val, int rank) { + auto M = c[0].size(0); + auto N = c[0].size(1); + auto out = torch::empty({M / WORLD_SIZE, N}, c[0].options()); + auto ®istry = get_kernel_registry(); + auto key = ShapeKey{static_cast(M),static_cast(N),0}.key; + auto it = registry.rs_map.find(key); + TORCH_CHECK(it != registry.rs_map.end(), "Unsupported ReduceScatter size: ", M, "x", N); + + std::array c_ptrs{}; + std::array signal_ptrs{}; + for (int i = 0; i < WORLD_SIZE; i++) { + c_ptrs[i] = reinterpret_cast(c[i].data_ptr()); + signal_ptrs[i] = reinterpret_cast(signal[i].data_ptr()); + } + + it->second(/*x*/static_cast(nullptr), + /*w*/static_cast(nullptr), + /*b*/static_cast(nullptr), + c_ptrs, signal_ptrs, signal_val, + reinterpret_cast(out.data_ptr()), + static_cast(registry.workspace_ptr), + rank); + return out; +} + +torch::Tensor launch_fused( + torch::Tensor &x, + torch::Tensor &w, + torch::Tensor &b, + std::array &c, // unified C buffers for all ranks + std::array &signal, // unified signal buffers for all ranks + int signal_val, + int rank +) { + dim3 grid(NUM_SMS); + dim3 block(NUM_THREADS); + + auto M = x.size(0); + auto N = w.size(0); + auto K = x.size(1); + auto rs_out = torch::empty({M / WORLD_SIZE, N}, c.front().options()); + auto ®istry = get_kernel_registry(); + auto key = ShapeKey{static_cast(M),static_cast(N),static_cast(K)}.key; + auto it = registry.fused_full_map.find(key); + std::array c_ptrs{}; + std::array signal_ptrs{}; + for (int i = 0; i < WORLD_SIZE; i++) { + c_ptrs[i] = reinterpret_cast(c[i].data_ptr()); + signal_ptrs[i] = reinterpret_cast(signal[i].data_ptr()); + } + it->second( + reinterpret_cast(x.const_data_ptr()), + reinterpret_cast(w.const_data_ptr()), + reinterpret_cast(b.const_data_ptr()), + c_ptrs, signal_ptrs, signal_val, + reinterpret_cast(rs_out.data_ptr()), + static_cast(registry.workspace_ptr), + rank); + return rs_out; +} + + +class GemmRS { +private: + int rank_; + int world_size_; + void *ipc_mems_[WORLD_SIZE]; + void *sig_buf_[WORLD_SIZE]; + + void check_device() { + int device; + C10_HIP_CHECK(hipGetDevice(&device)); + TORCH_CHECK(device == rank_); + } + +public: + GemmRS(int rank, int world_size): rank_(rank), world_size_(world_size) { + C10_HIP_CHECK(hipExtMallocWithFlags(&ipc_mems_[rank_], MAX_IPC_MEM_SIZE, hipDeviceMallocUncached)); + // C10_HIP_CHECK(hipMalloc(&ipc_mems_[rank_], MAX_IPC_MEM_SIZE)); + C10_HIP_CHECK(hipMemset(ipc_mems_[rank_], 0, MAX_IPC_MEM_SIZE)); + sig_buf_[rank_] = reinterpret_cast(ipc_mems_[rank_]) + MAX_IPC_MEM_SIZE - SIGNAL_BUF_SIZE; + TORCH_CHECK(world_size_ == WORLD_SIZE, "Only support world_size = ", WORLD_SIZE); + } + ~GemmRS() {} + + pybind11::bytearray 
get_ipc_handle() { + check_device(); + hipIpcMemHandle_t ipc_handle; + C10_HIP_CHECK(hipIpcGetMemHandle(&ipc_handle, ipc_mems_[rank_])); + return {ipc_handle.reserved, HIP_IPC_HANDLE_SIZE}; + } + + auto init_dist(const std::vector &ipc_handles) { + check_device(); + int world_size = ipc_handles.size(); + TORCH_CHECK(world_size == WORLD_SIZE, "Mismatched world size"); + for (int i = 0; i < world_size; i++) { + if (i == rank_) continue; + hipIpcMemHandle_t handle; + auto handle_buf = std::string(ipc_handles[i]); + TORCH_CHECK(handle_buf.size() == HIP_IPC_HANDLE_SIZE); + std::memcpy(handle.reserved, handle_buf.data(), HIP_IPC_HANDLE_SIZE); + C10_HIP_CHECK(hipIpcOpenMemHandle(&ipc_mems_[i], handle, hipIpcMemLazyEnablePeerAccess)); + sig_buf_[i] = reinterpret_cast(ipc_mems_[i]) + MAX_IPC_MEM_SIZE - SIGNAL_BUF_SIZE; // last for signal + } + } + + std::array get_c_tensors(int M, int N) { + check_device(); + TORCH_CHECK(M % world_size_ == 0, "M must be divisible by world_size"); + + std::array c; + for (int i = 0; i < world_size_; i++) { + TORCH_CHECK(ipc_mems_[i] != nullptr, "IPC memory not initialized"); + c[i] = torch::from_blob(ipc_mems_[i], {M, N}, torch::TensorOptions().dtype(torch::kBFloat16).device(torch::kCUDA, rank_)); + } + return c; + } + + std::array get_signal_tensors() { + check_device(); + std::array signal; + for (int i = 0; i < world_size_; i++) { + TORCH_CHECK(sig_buf_[i] != nullptr, "Signal buffer not initialized"); + signal[i] = torch::from_blob(sig_buf_[i], {1}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA, rank_)); + } + return signal; + } + +}; + + +PYBIND11_MODULE(perf_gemm, m) { + m.def("launch_gemm", &launch_gemm, "Launch GEMM kernel", + pybind11::arg("x"), pybind11::arg("w"), pybind11::arg("b"), pybind11::arg("signal"), pybind11::arg("signal_val"), pybind11::arg("out") = std::nullopt); + m.def("launch_reduce_scatter", &launch_reduce_scatter, "Launch ReduceScatter kernel", + pybind11::arg("c"), pybind11::arg("signal"), pybind11::arg("signal_val"), pybind11::arg("rank")); + m.def("launch_fused", &launch_fused, "Launch fused GEMM+RS kernel", + pybind11::arg("x"), pybind11::arg("w"), pybind11::arg("b"), + pybind11::arg("c"), pybind11::arg("signal"), + pybind11::arg("signal_val"), pybind11::arg("rank")); + m.def("__debug_get_workspace_tensor", [](int M, int N, int split_k){ + auto ®istry = get_kernel_registry(); + TORCH_CHECK(registry.workspace_ptr != nullptr, "Workspace not initialized"); + size_t required_size = static_cast(M) * static_cast(N) * sizeof(float) * split_k; + TORCH_CHECK(required_size <= registry.workspace_size, "Requested workspace size ", required_size, " exceeds preallocated size ", registry.workspace_size); + return torch::from_blob(registry.workspace_ptr, {static_cast(split_k), M, N}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)); + }); + pybind11::class_(m, "GemmRS") + .def(pybind11::init(), pybind11::arg("rank"), pybind11::arg("world_size")) + .def("get_ipc_handle", &GemmRS::get_ipc_handle, "Get IPC handle for the current rank") + .def("init_dist", &GemmRS::init_dist, "Initialize distributed GemmRS with IPC handles") + .def("get_c_tensors", &GemmRS::get_c_tensors, "Get tensors for C matrices") + .def("get_signal_tensors", &GemmRS::get_signal_tensors, "Get tensors for signal buffers"); +} + +#endif // __PERF_GEMM_HEADER__ diff --git a/dist-infer/gemm-rs/src/perf_gemm.h b/dist-infer/gemm-rs/src/perf_gemm.h new file mode 100644 index 0000000..5d35966 --- /dev/null +++ b/dist-infer/gemm-rs/src/perf_gemm.h @@ -0,0 
+1,303 @@ +#pragma once +#include "common.h" +#include +#include +#include +#include + + +namespace perf_gemm { + + +__device__ FORCE_INLINE inline void block_sync_lds() { + __builtin_amdgcn_s_waitcnt(0xc07f); + __builtin_amdgcn_s_barrier(); +} + + +template +__device__ FORCE_INLINE inline void compute_tile_indices( + int tile_id, + int &tile_m_id, + int &tile_n_id +) { + static_assert(GROUP_SIZE_M % 8 == 0); + if constexpr (GROUP_SIZE_M > 0) { + // Swizzle pattern for better L2 cache locality + // Groups tiles in blocks of GROUP_SIZE_M x num_tile_n + constexpr int num_pid_in_group = GROUP_SIZE_M * num_tile_n; + + // Which group does this tile belong to? + const int group_id = tile_id / num_pid_in_group; + + // First M-dimension tile in this group + const int first_pid_m = group_id * GROUP_SIZE_M; + + // Actual group size (handling boundary case) + const int group_size_m = min(GROUP_SIZE_M, num_tile_m - first_pid_m); + + // Position within the group + const int idx_in_group = tile_id % num_pid_in_group; + + // Swizzled tile indices: alternate M then N within group + tile_m_id = first_pid_m + (idx_in_group % group_size_m); + tile_n_id = idx_in_group / group_size_m; + } else { + tile_m_id = tile_id / num_tile_n; + tile_n_id = tile_id % num_tile_n; + } +} + +template +int __device__ remap_xcd_pid(int pid) { + if constexpr(REMAP_XCD) { + return (pid % NUM_XCDS) * (NUM_GEMM_SMS / NUM_XCDS) + (pid / NUM_XCDS); + } else { + return pid; + } +} + + +struct EpilogueNOP { + void operator()(int tid, int tile_m, int tile_n, void *signal_ptr) const {} +}; + +template< + int M, int N, int K, + int BM, int BN, int BK, + int NUM_SMS, int NUM_GEMM_SMS, int NUM_THREADS, + int WARP_M, int WARP_N, + int GROUP_SIZE_M, bool REMAP_XCD, + typename Epilogue = EpilogueNOP +> +__device__ void FORCE_INLINE gemm_kernel( + const bfloat16_t *x, // M x K + const bfloat16_t *w, // N x K + const bfloat16_t *b, // N + bfloat16_t *c, // M x N + int *signal_ptr, + int signal_val +) { + constexpr int num_tile_m = ceil_div(M, BM); + constexpr int num_tile_n = ceil_div(N, BN); + constexpr int num_tile_k = ceil_div(K, BK); + constexpr int num_tiles = num_tile_m * num_tile_n; + const int pid = remap_xcd_pid(threadIdx.x); + const int tid = threadIdx.x; + const int lane_id = __lane_id(); + __builtin_assume(pid >= 0 && pid < NUM_SMS); + __builtin_assume(tid >= 0 && tid < NUM_THREADS); + __builtin_assume(lane_id >= 0 && lane_id < 64); + // each thread load 4 elements + static_assert(BK % 4 == 0 && NUM_THREADS * 4 % BK == 0); + constexpr int WM = 16, WN = 16, WK = 16; + + constexpr int Frag_M = exact_div(); + constexpr int Frag_N = exact_div(); + constexpr int Frag_K = exact_div(); + const int warp_id = __builtin_amdgcn_readfirstlane(tid / AMDGCN_WAVEFRONT_SIZE); + const int warp_m = warp_id / WARP_N; + const int warp_n = warp_id % WARP_N; + using FragX = bf16x4_t; + using FragW = bf16x4_t; + using FragC = fp32x4_t; + __shared__ bfloat16_t s_x[BM][BK]; + __shared__ bfloat16_t s_w[BN][BK]; + bf16x4_t vgpr_x[ceil_div(BM * BK, NUM_THREADS * 4)]; + bf16x4_t vgpr_w[ceil_div(BN * BK, NUM_THREADS * 4)]; + + FragC frag_c[Frag_M][Frag_N]; + FragX frag_x[Frag_M][Frag_K]; + FragW frag_w[Frag_N][Frag_K]; + + + constexpr int v_mfma_f32_16x16x16_bf16 = (BM * BN * BK) / (WARP_M * WARP_N) / (16*16*16); + // Compiler will merge two ds_{read,write}_b64 to ds_{read,write`}2st64_b64 + constexpr int ds_read_b128_a = (BM * BK / WARP_M) / 64 / 8; + constexpr int ds_read_b128_b = (BN * BK / WARP_N) / 64 / 8; + constexpr int ds_read_b128 = ds_read_b128_a + 
ds_read_b128_b; + constexpr int ds_write_b128_a = (BM * BK) / NUM_THREADS / 8; + constexpr int ds_write_b128_b = (BN * BK) / NUM_THREADS / 8; + constexpr int ds_write_b128 = ds_write_b128_a + ds_write_b128_b; + constexpr int buffer_load_dwordx2_a = (BM * BK) / NUM_THREADS / 4; + constexpr int buffer_load_dwordx2_b = (BN * BK) / NUM_THREADS / 4; + constexpr int buffer_load_dwordx2 = buffer_load_dwordx2_a + buffer_load_dwordx2_b; + + + auto load_vgpr = [&](int m, int n, int k) FORCE_INLINE { + auto x_arr = ck::make_wave_buffer_resource(const_cast(x), M * K); + auto w_arr = ck::make_wave_buffer_resource(const_cast(w), N * K); + int v_offset = ((tid * 4 / BK) * K + (tid * 4 % BK)) * sizeof(bfloat16_t); + uint32_t src_addr_shift = (K % BK == 0) || (k + tid * 4 % BK < K) ? 0 : 0x80000000; + ck::static_for<0, sizeof(vgpr_x) / sizeof(vgpr_x[0]), 1>{}([&](auto t) { + int s_offset = ((m * K + k) + t * NUM_THREADS * 4 / BK * K) * sizeof(bfloat16_t); + vgpr_x[t] = __builtin_bit_cast(bf16x4_t, ck::amd_buffer_load_impl_raw( + x_arr, v_offset + src_addr_shift, s_offset)); + }); + ck::static_for<0, sizeof(vgpr_w) / sizeof(vgpr_w[0]), 1>{}([&](auto t) { + int s_offset = ((n * K + k) + t * NUM_THREADS * 4 / BK * K) * sizeof(bfloat16_t); + vgpr_w[t] = __builtin_bit_cast(bf16x4_t, ck::amd_buffer_load_impl_raw( + w_arr, v_offset + src_addr_shift, s_offset)); + }); + + }; + + auto load_lds = [&]() FORCE_INLINE { + // diagonal swizzle, shape=[16, 64] dtype=bfloat16 + #pragma unroll + for (int t=0;t(&s_x[row0 + row1][col1]) = vgpr_x[t]; + } + #pragma unroll + for (int t=0;t(&s_w[row0 + row1][col1]) = vgpr_w[t]; + } + }; + + auto zero_all_frags = [&]() FORCE_INLINE { + ck::static_for<0, Frag_M, 1>{}([&](auto i) { + ck::static_for<0, Frag_N, 1>{}([&](auto j) { + ck::static_for<0, 4, 1>{}([&](auto t) { frag_c[i][j].x[t] = 0; }); + }); + }); + }; + + + auto frags_load = [&]() FORCE_INLINE { + ck::static_for<0, Frag_K, 1>{}([&](auto k) { + ck::static_for<0, Frag_M, 1>{}([&](auto i) { + const int row1 = (warp_m * Frag_M + i) * WM; + const int row0 = lane_id % 16; + const int col0 = k * 16 + lane_id / 16 * 4; + const int col1 = (row0 * 4 + col0) % BK; + frag_x[i][k] = *reinterpret_cast(&s_x[row0 + row1][col1]); + }); + ck::static_for<0, Frag_N, 1>{}([&](auto j) { + const int row1 = (warp_n * Frag_N + j) * WN; + const int row0 = lane_id % 16; + const int col0 = k * 16 + lane_id / 16 * 4; + const int col1 = (row0 * 4 + col0) % BK; + frag_w[j][k] = *reinterpret_cast(&s_w[row0 + row1][col1]); + }); + }); + }; + + auto frags_mfma = [&]() FORCE_INLINE { + ck::static_for<0, Frag_M, 1>{}([&](auto i) { + ck::static_for<0, Frag_N, 1>{}([&](auto j) { + ck::static_for<0, Frag_K, 1>{}([&](auto k) { + // a: [16][16], b: [16][16], c: [16][16] + // mfma requires a: row-major, b: col-major, out: col-major + // so we compute w^T * x^T = c^T so we can treat out as col-major + frag_c[i][j].pack = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(frag_w[j][k].pack, frag_x[i][k].pack, frag_c[i][j].pack, 0, 0, 0); + }); + }); + }); + }; + + + + auto store_frags = [&](int m, int n) FORCE_INLINE { + auto b_arr = ck::make_wave_buffer_resource(const_cast(b), N); + auto c_arr = ck::make_wave_buffer_resource(c, M * N); + fp32x4_t c_out[Frag_M][Frag_N]; + ck::static_for<0, Frag_M, 1>{}([&](auto i) { + ck::static_for<0, Frag_N, 1>{}([&](auto j) { + ck::static_for<0, 4, 1>{}([&](auto t) { + // v_accvgpr_read_b32 + c_out[i][j].x[t] = frag_c[i][j].x[t]; + }); + // c_out: [16][16] + int row = lane_id % 16; + int col = lane_id / 16 * 4; + uint32_t src_addr_shift 
= (N % BN == 0) || (n + (j + warp_n * Frag_N) * WN + col < N) ? 0 : 0x80000000; + // load b + int b_s_offset = (n + (j + warp_n * Frag_N) * WN) * sizeof(bfloat16_t); + int b_v_offset = col * sizeof(bfloat16_t) + src_addr_shift; + auto b_vec = __builtin_bit_cast(bf16x4_t, ck::amd_buffer_load_impl_raw( + b_arr, b_v_offset, b_s_offset)); + // compute c + bf16x4_t c_out_bf16; + #pragma unroll + for (int t = 0; t < 4; ++t) { + c_out_bf16.x[t] = fast_f32tob16(c_out[i][j].x[t] + b_vec.x[t]); + } + // write c + int c_s_offset = b_s_offset + (m + (i + warp_m * Frag_M) * WM) * N * sizeof(bfloat16_t); + int c_v_offset = b_v_offset + (row * N) * sizeof(bfloat16_t); + ck::amd_buffer_store_impl_raw(c_out_bf16.pack, c_arr, c_v_offset, c_s_offset); + }); + }); + }; + + + + for (int tile_id=pid; tile_id(tile_id, tile_m_id, tile_n_id); + int m = tile_m_id * BM; + int n = tile_n_id * BN; + load_vgpr(m, n, 0); // GDS -> VGPR #0 + load_lds(); // VGPR -> LDS #0 + load_vgpr(m, n, 1 * BK); // GDS -> VGPR #1 + zero_all_frags(); + block_sync_lds(); + frags_load(); // LDS -> FRAG #0 + __builtin_amdgcn_sched_barrier(0); + // #pragma clang loop unroll_count(2) + // #pragma unroll 2 + // #pragma unroll + for (int tile_k_id = 1; tile_k_id < (num_tile_k - 1); ++tile_k_id) { + asm volatile(R"( + ; Main Loop Begin + )" ::: "memory"); + block_sync_lds(); + // Stage 1 + load_lds(); // VGPR -> LDS #1 + load_vgpr(m, n, (tile_k_id + 1) * BK); // GDS -> VGPR #2(k+1) + frags_mfma(); // MFMA #0(k-1) + // 120 + #pragma unroll + for (int k = 0; k < buffer_load_dwordx2 / 2; ++k) { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA + } + + block_sync_lds(); + // Stage 2 + frags_load(); // LDS -> FRAG #1(k) + // 60 + #pragma unroll + for (int k = 0; k < ds_read_b128; ++k) { + __builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + } + __builtin_amdgcn_sched_barrier(0); + asm volatile(R"( + ; Main Loop End + )" ::: "memory"); + } + frags_mfma(); // MFMA #1(n-2) + block_sync_lds(); + load_lds(); // VGPR -> LDS #2(n-1) + block_sync_lds(); + frags_load(); // LDS -> FRAG #2(n-1) + frags_mfma(); // MFMA #2(n-1) + store_frags(m, n); + // Epilogue{}(tid, tile_m_id, tile_n_id, signal_ptr, signal_val); + } +} + +} // namespace perf_gemm \ No newline at end of file diff --git a/dist-infer/gemm-rs/submit.py b/dist-infer/gemm-rs/submit.py new file mode 100644 index 0000000..febfc2f --- /dev/null +++ b/dist-infer/gemm-rs/submit.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +import sys +import os +import subprocess +import datetime +import time + +def main(): + timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M:%S-%f") + logfile = f"logs/gemm-rs-{timestamp}.log" + + os.makedirs("logs", exist_ok=True) + + print(f"submiting, log file: {logfile}") + + cmd = [ + "popcorn-cli", "submit", + "--gpu", "MI300x8", + "--leaderboard", "amd-gemm-rs", + "--mode", "benchmark", + "submission.py", + "-o", logfile, + ] + + start = time.time() + + timeout = 180 + try: + # Use default stdio; remove invalid stdout=subprocess.STDOUT + subprocess.run(cmd, timeout=timeout, check=True) + except subprocess.TimeoutExpired: + print(f"Error: Command timed out after {timeout}s", 
file=sys.stderr)
+        sys.exit(1)
+    except FileNotFoundError:
+        print(f"Error: Command not found: {cmd[0]}", file=sys.stderr)
+        sys.exit(1)
+    except subprocess.CalledProcessError as e:
+        print(f"Error: Command failed with exit code {e.returncode}", file=sys.stderr)
+        sys.exit(e.returncode)
+
+    with open(logfile) as f:
+        print(f.read())
+
+    end = time.time()
+    print(f"submit done, time cost: {end - start:.2f}s")
+
+if __name__ == "__main__":
+    main()
+
diff --git a/dist-infer/gemm-rs/task.py b/dist-infer/gemm-rs/task.py
new file mode 100644
index 0000000..1de3edd
--- /dev/null
+++ b/dist-infer/gemm-rs/task.py
@@ -0,0 +1,14 @@
+from typing import TypedDict, TypeVar, Tuple, Optional
+import torch
+
+input_t = TypeVar("input_t", bound=Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]])
+output_t = TypeVar("output_t", bound=torch.Tensor)
+
+
+class TestSpec(TypedDict):
+    world_size: int
+    m: int
+    n: int
+    k: int
+    has_bias: bool
+    seed: int
diff --git a/dist-infer/gemm-rs/task.yml b/dist-infer/gemm-rs/task.yml
new file mode 100644
index 0000000..6eac274
--- /dev/null
+++ b/dist-infer/gemm-rs/task.yml
@@ -0,0 +1,85 @@
+# name: gemm-rs
+
+files:
+  - {"name": "submission.py", "source": "@SUBMISSION@"}
+  - {"name": "task.py", "source": "task.py"}
+  - {"name": "utils.py", "source": "../utils.py"}
+  - {"name": "reference.py", "source": "reference.py"}
+  - {"name": "eval.py", "source": "../eval.py"}
+
+lang: "py"
+multi_gpu: true
+
+description: |
+  Implement a Gemm-ReduceScatter kernel on a single MI300X node.
+
+  Gemm-ReduceScatter is a technique that combines the ReduceScatter
+  communication pattern with General Matrix Multiplication (GEMM) to optimize
+  the performance of transformer models on GPUs. It is particularly useful for
+  handling large models that exceed the memory capacity of a single GPU by
+  distributing the model across multiple GPUs and efficiently scattering the
+  results of matrix multiplications.
+
+  Your task:
+  - Implement the Gemm-RS kernel to perform matrix multiplications in a
+    distributed manner, leveraging the ReduceScatter operation to distribute
+    data across multiple GPUs.
+  - Ensure that the implementation is optimized for the MI300X architecture,
+    taking advantage of its specific hardware features for maximum performance.
+
+  Input:
+  - `data`: Tuple of (input: torch.Tensor, weight: torch.Tensor,
+    bias: Optional[torch.Tensor])
+    - input: Local input tensor of shape [M, local_K].
+    - weight: Weight tensor of shape [N, local_K].
+    - bias: Bias tensor of shape [N], or None.
+
+  Output:
+  - Tuple containing:
+    - output: Resulting tensor of shape [M // world_size, N]
+
+  The ranking criterion is the geometric mean of the benchmark results.
+
+  For the grand prize, your kernel will be evaluated against the speed-of-light
+  analysis and the AMD implementations; the solution closest to both will be
+  awarded the grand prize.
+  ```
+  The speed of light analysis is:
+  m     n     k      has_bias  time[us]
+  64    7168  18432  False     6.46
+  512   4096  12288  True      8.19
+  2048  2880  2880   True      23.04
+  4096  4096  4096   False     65.54
+  8192  4096  14336  True      131.07
+  8192  8192  29568  False     379.43
+  ```
+config:
+  main: "eval.py"
+
+templates:
+  Python: "submission.py"
+
+ranking_by: "geom"
+ranked_timeout: 360 # just in case
+
+tests:
+  - {"world_size": 8, "m": 64, "n": 2880, "k": 2880, "has_bias": True, "seed": 2035}
+  - {"world_size": 8, "m": 64, "n": 3584, "k": 14336, "has_bias": True, "seed": 13}
+  - {"world_size": 8, "m": 512, "n": 3584, "k": 14336, "has_bias": True, "seed": 4297}
+  - {"world_size": 8, "m": 512, "n": 4608, "k": 36864, "has_bias": False, "seed": 1597}
+  - {"world_size": 8, "m": 2048, "n": 4096, "k": 7168, "has_bias": False, "seed": 716}
+  - {"world_size": 8, "m": 2048, "n": 8192, "k": 30720, "has_bias": False, "seed": 20201}
+  - {"world_size": 8, "m": 4096, "n": 2880, "k": 2880, "has_bias": True, "seed": 136}
+  - {"world_size": 8, "m": 4096, "n": 8192, "k": 2048, "has_bias": True, "seed": 138}
+  - {"world_size": 8, "m": 8192, "n": 3584, "k": 14336, "has_bias": True, "seed": 748}
+  - {"world_size": 8, "m": 8192, "n": 4608, "k": 36864, "has_bias": True, "seed": 4422}
+  - {"world_size": 8, "m": 8192, "n": 8192, "k": 28672, "has_bias": False, "seed": 1536}
+
+
+benchmarks:
+  - {"world_size": 8, "m": 64, "n": 7168, "k": 18432, "has_bias": False, "seed": 1234}
+  - {"world_size": 8, "m": 512, "n": 4096, "k": 12288, "has_bias": True, "seed": 663}
+  - {"world_size": 8, "m": 2048, "n": 2880, "k": 2880, "has_bias": True, "seed": 166}
+  - {"world_size": 8, "m": 4096, "n": 4096, "k": 4096, "has_bias": False, "seed": 1371}
+  - {"world_size": 8, "m": 8192, "n": 4096, "k": 14336, "has_bias": True, "seed": 7168}
+  - {"world_size": 8, "m": 8192, "n": 8192, "k": 29568, "has_bias": False, "seed": 42}
diff --git a/dist-infer/gemm-rs/template.py b/dist-infer/gemm-rs/template.py
new file mode 100644
index 0000000..79c9e40
--- /dev/null
+++ b/dist-infer/gemm-rs/template.py
@@ -0,0 +1,241 @@
+#!POPCORN leaderboard amd-gemm-rs
+import sys
+import time
+from task import input_t, output_t
+from typing import Optional
+import torch
+import torch.distributed as dist
+from torch.utils.cpp_extension import load_inline
+import zlib
+import base64
+import os
+os.environ.update(
+    {
+        "HSA_XNACK": "0",
+        "CXX": "clang++",
+        "PYTORCH_ROCM_ARCH": "gfx942",
+    }
+)
+
+
+CPP_WRAPPER = ""
+CUDA_SRC = ""
+
+module = load_inline(
+    name='perf_gemm',
+    cpp_sources=[CPP_WRAPPER],
+    cuda_sources=[CUDA_SRC],
+    verbose=True,
+    extra_cuda_cflags=["--offload-arch=gfx942", "-std=c++20", "-U__HIP_NO_HALF_OPERATORS__", "-U__HIP_NO_HALF_CONVERSIONS__", "-D__GPUMODE_BENCHMARK__"],
+    extra_cflags=["-Ofast", "-ffast-math", "-march=native", "-funroll-loops", "-fomit-frame-pointer"],
+)
+
+comm = None
+should_update_comm = True
+round_trip = 0
+
+original_init_pg = dist.init_process_group
+
+def hooked_init_pg(*args, **kwargs):
+    global should_update_comm
+    should_update_comm = True
+    ret = original_init_pg(*args, **kwargs)
+    # print0(f"init pg: {args}, {kwargs}", True)
+    return ret
+
+def dist_print(*args, **kwargs):
+    if not dist.is_initialized() or dist.get_rank() == 0:
+        print(*args, **kwargs, flush=True)
+
+def dist_print_err(*args, **kwargs):
+    if not dist.is_initialized() or dist.get_rank() == 0:
+        print(*args, **kwargs, flush=True, file=sys.stderr)
+
+
+
+def all_get_comm():
+    global comm, should_update_comm, round_trip
+    rank = dist.get_rank()
+    torch.cuda.set_device(rank)
+    if
should_update_comm: + should_update_comm = False + round_trip = 0 + # create a new comm + world_size = dist.get_world_size() + dist_print(f"create new comm: {rank} {world_size}") + del comm + comm = module.GemmRS(rank, world_size) + ipc_handle = comm.get_ipc_handle() + ipc_handles = [None] * world_size + dist.all_gather_object(ipc_handles, ipc_handle) + comm.init_dist(ipc_handles) + dist.barrier() + torch.cuda.synchronize() + round_trip += 1 + return comm, round_trip + +dist.init_process_group = hooked_init_pg + + +def ref_kernel(data: input_t) -> output_t: + """ + Reference kernel for Gemm-ReduceScatter operation. + + Args: + data: Tuple of (input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]) + - input: Local input tensor of shape [M, local_K]. + - weight: Weight tensor of shape [N, local_K]. + - bias: Optional bias tensor of shape [N] or None. + Returns: + Tuple containing: + - output: Resulting tensor of shape [M // world_size, N]. + """ + input, weight, bias = data + M, local_K = input.shape + N = weight.shape[0] + world_size = torch.distributed.get_world_size() + # matmul + output = torch.matmul(input, weight.T) + if bias is not None: + output = output + bias + # reduce scatter + rs_output = torch.empty((M // world_size, N), dtype=output.dtype, device=input.device) + torch.distributed.reduce_scatter_tensor(rs_output, output) + return rs_output + +def ref_kernel_debug(data: input_t) -> output_t: + input, weight, bias = data + M, local_K = input.shape + N = weight.shape[0] + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + t0 = time.perf_counter() + torch.cuda.synchronize() + output = torch.matmul(input, weight.T) + if bias is not None: + output = output + bias + torch.cuda.synchronize() + t1 = time.perf_counter() + torch.distributed.barrier() + rs_output = torch.empty((M // world_size, N), dtype=output.dtype, device=input.device) + torch.distributed.reduce_scatter_tensor(rs_output, output) + torch.cuda.synchronize() + t2 = time.perf_counter() + comm, round_trip = all_get_comm() + if round_trip < 8: + dist_print_err(f"rank {rank} gemm time: {((t1 - t0)*1e6):.2f}μs, perf: {((M*N*local_K*2)/(t1 - t0) * 1e-12):.2f}TFlops/s") + dist_print_err(f"shape: {M}x{N}x{local_K}, time: {((t2 - t1)*1e6):.2f}μs, perf: {((M*N*2)/(t2 - t1) * 1e-9):.2f}GB/s") + return rs_output + +def gemm_then_rs_kernel_debug(data: input_t) -> output_t: + input, weight, bias = data + M, local_K = input.shape + N = weight.shape[0] + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + comm, round_trip = all_get_comm() + ipc_tensor = comm.get_c_tensors(M, N) + signal_tensor = comm.get_signal_tensors() + t0 = time.perf_counter() + torch.cuda.synchronize() + output = module.launch_gemm(input, weight, bias, signal_tensor[rank], round_trip, ipc_tensor[rank]) + torch.cuda.synchronize() + t1 = time.perf_counter() + if round_trip < 8: + dist_print_err(f"shape: {M}x{N}x{local_K}, gemm time: {((t1 - t0)*1e6):.2f}μs, perf: {((M*N*local_K*2)/(t1 - t0) * 1e-12):.2f}TFlops/s") + torch.distributed.barrier() + + output = module.launch_reduce_scatter(ipc_tensor, signal_tensor, round_trip, rank) + torch.cuda.synchronize() + t2 = time.perf_counter() + if round_trip < 8: + dist_print_err(f"shape: {M}x{N}x{local_K}, rs time: {((t2 - t1)*1e6):.2f}μs, perf: {((M*N*2)/(t2 - t1) * 1e-9):.2f}GB/s") + return output + rs_output = ref_kernel(data) + if not torch.allclose(output, rs_output, atol=1e-2, rtol=1e-2): + dist_print_err("mismatch in 
gemm_then_rs_kernel") + dist_print_err("output:", output) + dist_print_err("rs_output:", rs_output) + diff = torch.abs(output - rs_output) + dist_print_err("mismatch count:", torch.sum(diff > 1e-2).item()) + dist_print_err("diff:", diff) + dist_print_err("max diff:", torch.max(diff).item()) + return rs_output + +def gemm_then_rs_kernel(data: input_t) -> output_t: + input, weight, bias = data + M, local_K = input.shape + N = weight.shape[0] + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + comm, round_trip = all_get_comm() + ipc_tensor = comm.get_c_tensors(M, N) + signal_tensor = comm.get_signal_tensors() + output = module.launch_gemm(input, weight, bias, signal_tensor[rank], round_trip, ipc_tensor[rank]) + output = module.launch_reduce_scatter(ipc_tensor, signal_tensor, round_trip, rank) + return output + + +def gemm_rs_kernel(data: input_t) -> output_t: + input, weight, bias = data + M, _ = input.shape + N = weight.shape[0] + rank = torch.distributed.get_rank() + comm, round_trip = all_get_comm() + ipc_tensor = comm.get_c_tensors(M, N) + signal_tensor = comm.get_signal_tensors() + rs_output = module.launch_fused(input, weight, bias, ipc_tensor, signal_tensor, round_trip, rank) + # dist_print_err(f"gemm_rs_kernel {input.shape}x{weight.shape} rank {rank} round_trip {round_trip}") + return rs_output + +def gemm_rs_kernel_debug(data: input_t) -> output_t: + input, weight, bias = data + M, local_K = input.shape + N = weight.shape[0] + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + comm, round_trip = all_get_comm() + ipc_tensor = comm.get_c_tensors(M, N) + signal_tensor = comm.get_signal_tensors() + t0 = time.perf_counter() + torch.cuda.synchronize() + rs_output = module.launch_fused(input, weight, bias, ipc_tensor, signal_tensor, round_trip, rank) + torch.cuda.synchronize() + t1 = time.perf_counter() + torch.distributed.barrier() + if round_trip < 8: + dist_print_err(f"shape {M}x{local_K}x{N} ? {torch.all(bias == 0).item()} fused time: {((t1 - t0)*1e6):.2f}μs, perf: {((M*N*local_K*2)/(t1 - t0) * 1e-12):.2f}TFlops/s") + rs_output_ref = ref_kernel(data) + if not torch.allclose(rs_output, rs_output_ref, atol=1e-2, rtol=1e-2): + dist_print_err("mismatch in gemm_rs_kernel") + dist_print_err("rs_output:", rs_output) + dist_print_err("rs_output_ref:", rs_output_ref) + diff = torch.abs(rs_output - rs_output_ref) + dist_print_err("mismatch count:", torch.sum(diff > 1e-2).item()) + dist_print_err("diff:", diff) + dist_print_err("max diff:", torch.max(diff).item()) + return rs_output_ref + + +def check_implementation(data: input_t) -> output_t: + rf_res = gemm_rs_kernel(data) + ref_res = ref_kernel(data) + x_shape = data[0].shape + w_shape = data[1].shape + + if not torch.allclose(rf_res, ref_res, atol=1e-2, rtol=1e-2): + dist_print_err(f"{x_shape[0]}x{x_shape[1]}x{w_shape[0]} mismatch") + dist_print_err("rf_res:", rf_res) + dist_print_err("ref_res:", ref_res) + diff = torch.abs(rf_res - ref_res) + dist_print_err("mismatch count:", torch.sum(diff > 1e-2).item()) + dist_print_err("diff:", diff) + dist_print_err("max diff:", torch.max(diff).item()) + return ref_res + +# custom_kernel = check_implementation +# custom_kernel = gemm_rs_kernel +custom_kernel = gemm_then_rs_kernel_debug +# custom_kernel = ref_kernel +# custom_kernel = ref_kernel_debug +# custom_kernel = gemm_rs_kernel_debug \ No newline at end of file
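
For readers without an 8-GPU MI300X node, the following is a minimal single-process sketch of the Gemm-ReduceScatter semantics that the task description and `ref_kernel` above specify: every rank multiplies its [M, local_K] input shard by its [N, local_K] weight shard, the partial [M, N] products are summed across ranks, and rank r keeps the r-th block of M // world_size rows. The helper name `emulate_gemm_rs` and the toy shapes are illustrative only; this is not part of the submission.

```
# Hypothetical shape/semantics check, mirroring ref_kernel above.
# Note: like ref_kernel, bias (if given) is added once per rank before the sum.
import torch

def emulate_gemm_rs(x_shards, w_shards, bias, world_size):
    # x_shards[r]: [M, K // world_size], w_shards[r]: [N, K // world_size]
    partials = [x @ w.T + (bias if bias is not None else 0)
                for x, w in zip(x_shards, w_shards)]
    full = torch.stack(partials).sum(dim=0)  # [M, N], reduced over ranks
    m_local = full.shape[0] // world_size
    # rank r's reduce-scatter output is the r-th row block of the reduced result
    return [full[r * m_local:(r + 1) * m_local] for r in range(world_size)]

if __name__ == "__main__":
    world_size, M, N, K = 8, 64, 32, 256
    xs = [torch.randn(M, K // world_size) for _ in range(world_size)]
    ws = [torch.randn(N, K // world_size) for _ in range(world_size)]
    outs = emulate_gemm_rs(xs, ws, None, world_size)
    assert outs[0].shape == (M // world_size, N)
```

In the actual template, `gemm_then_rs_kernel` issues the GEMM and the reduce-scatter as two separate launches, while `gemm_rs_kernel` uses a single `launch_fused` call; both exchange data through the IPC-mapped `ipc_tensor` buffers and `signal_tensor` flags created in `all_get_comm`.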