From 909dbdb1361d5ffcab2396948db244c5c9cbcb91 Mon Sep 17 00:00:00 2001 From: Richard Chamberlain Date: Thu, 6 Nov 2025 09:09:38 -0800 Subject: [PATCH 01/22] Merging code from IFU branch. --- README.md | 10 +- csrc/config.hpp | 4 +- csrc/deep_ep.cpp | 1269 ++++++++++++++++++++++---------- csrc/deep_ep.hpp | 240 ++++-- csrc/kernels/api.cuh | 346 ++++++--- csrc/kernels/configs.cuh | 15 + csrc/kernels/exception.cuh | 2 +- csrc/kernels/internode.cu | 160 ++-- csrc/kernels/internode_ll.cu | 176 +++-- csrc/kernels/intranode.cu | 108 ++- csrc/kernels/launch.cuh | 12 + csrc/kernels/runtime.cu | 21 +- csrc/kernels/shmem_wrapper.cuh | 2 +- csrc/kernels/utils.cuh | 73 +- deep_ep/__init__.py | 2 +- deep_ep/buffer.py | 327 +++++--- deep_ep/utils.py | 47 +- setup.py | 99 ++- tests/test_internode.py | 252 +++++-- tests/test_intranode.py | 184 +++-- tests/test_low_latency.py | 362 ++++++--- tests/utils.py | 115 ++- third-party/README.md | 6 - 23 files changed, 2772 insertions(+), 1060 deletions(-) diff --git a/README.md b/README.md index e56889b..75eda73 100644 --- a/README.md +++ b/README.md @@ -31,23 +31,15 @@ DeepEP (AMD version) depends on [rocSHMEM](https://github.com/ROCm/rocSHMEM). 
Pl git clone https://github.com/ROCm/DeepEP cd DeepEP - -# To use DeepEP with MPI, please proceed with these commands -# Export OMPI dir in the next command (e.g., it's $BUILD_DIR/ompi in third-party/README.md) +# export OMPI dir in the next command (e.g., it's $BUILD_DIR/ompi in third-party/README.md) export OMPI_DIR= python3 setup.py --variant rocm build develop -# To use DeepEP without MPI, please make sure rocSHMEM was built with this flag -DUSE_EXTERNAL_MPI=OFF -# Then install DeepEP using this command -python3 setup.py --variant rocm --disable-mpi build develop - # Run test cases # NOTES: you may modify the `init_dist` function in `tests/utils.py` # according to your own cluster settings, and launch into multiple nodes python3 tests/test_intranode.py python3 tests/test_internode.py -# Set the required ROCSHMEM heap size (for example, for DeepSeek models) -export ROCSHMEM_HEAP_SIZE=2147483648 python3 tests/test_low_latency.py ``` diff --git a/csrc/config.hpp b/csrc/config.hpp index 83c60fe..9acf674 100644 --- a/csrc/config.hpp +++ b/csrc/config.hpp @@ -7,13 +7,13 @@ namespace deep_ep { template -dtype_t cell_div(dtype_t a, dtype_t b) { +dtype_t ceil_div(dtype_t a, dtype_t b) { return (a + b - 1) / b; } template dtype_t align(dtype_t a, dtype_t b) { - return cell_div(a, b) * b; + return ceil_div(a, b) * b; } struct Config { diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp index aacd712..ed36325 100644 --- a/csrc/deep_ep.cpp +++ b/csrc/deep_ep.cpp @@ -11,27 +11,36 @@ #include "kernels/api.cuh" #include "kernels/configs.cuh" -int get_env_with_default_value(const std::string& env_path, const std::string& default_value) { - const char* value = std::getenv(env_path.c_str()); - std::string value_str = (value != nullptr) ? 
std::string(value) : default_value; - return std::stoi(value_str); -} - namespace deep_ep { -Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode): - rank(rank), num_ranks(num_ranks), - num_nvl_bytes(num_nvl_bytes), num_rdma_bytes(num_rdma_bytes), - low_latency_mode(low_latency_mode), - comm_stream(at::cuda::getStreamFromPool(true)) { - // Task fifo memory - int64_t fifo_bytes = sizeof(int) * NUM_MAX_FIFO_SLOTS; - int64_t buffer_ptr_bytes = sizeof(void*) * NUM_MAX_NVL_PEERS; - int64_t task_ptr_bytes = sizeof(int*) * NUM_MAX_NVL_PEERS; +Buffer::Buffer(int rank, + int num_ranks, + int64_t num_nvl_bytes, + int64_t num_rdma_bytes, + bool low_latency_mode, + bool explicitly_destroy, + bool enable_shrink) + : rank(rank), + num_ranks(num_ranks), + num_nvl_bytes(num_nvl_bytes), + num_rdma_bytes(num_rdma_bytes), + enable_shrink(enable_shrink), + low_latency_mode(low_latency_mode), + explicitly_destroy(explicitly_destroy), + comm_stream(at::cuda::getStreamFromPool(true)) { + // Metadata memory + int64_t barrier_signal_bytes = NUM_MAX_FIFO_SLOTS * sizeof(int); + int64_t buffer_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(void*); + int64_t barrier_signal_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(int*); // Common checks - EP_HOST_ASSERT(num_nvl_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and (num_nvl_bytes <= std::numeric_limits::max() or num_rdma_bytes == 0)); - EP_HOST_ASSERT(num_rdma_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and (low_latency_mode or num_rdma_bytes <= std::numeric_limits::max())); + EP_STATIC_ASSERT(NUM_BUFFER_ALIGNMENT_BYTES % sizeof(int4) == 0, "Invalid alignment"); + EP_HOST_ASSERT(num_nvl_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and + (num_nvl_bytes <= std::numeric_limits::max() or num_rdma_bytes == 0)); + EP_HOST_ASSERT(num_rdma_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and + (low_latency_mode or num_rdma_bytes <= std::numeric_limits::max())); + EP_HOST_ASSERT(num_nvl_bytes / sizeof(int4) < 
std::numeric_limits::max()); + EP_HOST_ASSERT(num_rdma_bytes / sizeof(int4) < std::numeric_limits::max()); EP_HOST_ASSERT(0 <= rank and rank < num_ranks and (num_ranks <= NUM_MAX_NVL_PEERS * NUM_MAX_RDMA_PEERS or low_latency_mode)); EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0); if (num_rdma_bytes > 0) @@ -41,32 +50,36 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_ CUDA_CHECK(cudaGetDevice(&device_id)); rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; num_rdma_ranks = std::max(1, num_ranks / NUM_MAX_NVL_PEERS), num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS); +#ifdef DISABLE_NVSHMEM + EP_HOST_ASSERT(num_rdma_ranks == 1 and not low_latency_mode and "NVSHMEM is disabled during compilation"); +#endif // Get device info cudaDeviceProp device_prop = {}; CUDA_CHECK(cudaGetDeviceProperties(&device_prop, device_id)); -#ifdef USE_ROCM - sscanf(device_prop.gcnArchName, "gfx%d", &gfx); - EP_HOST_ASSERT(gfx >= 942); -#endif + num_device_sms = device_prop.multiProcessorCount; + + // Number of per-channel bytes cannot be large + EP_HOST_ASSERT(ceil_div(num_nvl_bytes, num_device_sms / 2) < std::numeric_limits::max()); + EP_HOST_ASSERT(ceil_div(num_rdma_bytes, num_device_sms / 2) < std::numeric_limits::max()); if (num_nvl_bytes > 0) { - // Local IPC: alloc local memory and set local IPC handle + // Local IPC: alloc local memory and set local IPC handles #ifdef USE_ROCM - CUDA_CHECK(hipExtMallocWithFlags(&buffer_ptrs[nvl_rank], num_nvl_bytes + fifo_bytes + buffer_ptr_bytes + task_ptr_bytes, hipDeviceMallocUncached)); + CUDA_CHECK(hipExtMallocWithFlags(&buffer_ptrs[nvl_rank], num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes + barrier_signal_ptr_bytes, hipDeviceMallocUncached)); #else - CUDA_CHECK(cudaMalloc(&buffer_ptrs[nvl_rank], num_nvl_bytes + fifo_bytes + buffer_ptr_bytes + task_ptr_bytes)); + CUDA_CHECK(cudaMalloc(&buffer_ptrs[nvl_rank], num_nvl_bytes + 
barrier_signal_bytes + buffer_ptr_bytes + barrier_signal_ptr_bytes)); #endif CUDA_CHECK(cudaIpcGetMemHandle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank])); - buffer_ptrs_gpu = reinterpret_cast(reinterpret_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + fifo_bytes); + buffer_ptrs_gpu = reinterpret_cast(reinterpret_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + barrier_signal_bytes); - // Set task fifo - EP_HOST_ASSERT(NUM_MAX_FIFO_SLOTS % num_nvl_ranks == 0); - task_fifo_ptrs[nvl_rank] = reinterpret_cast(reinterpret_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes); - task_fifo_ptrs_gpu = reinterpret_cast(reinterpret_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + fifo_bytes + buffer_ptr_bytes); + // Set barrier signals + barrier_signal_ptrs[nvl_rank] = reinterpret_cast(reinterpret_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes); + barrier_signal_ptrs_gpu = + reinterpret_cast(reinterpret_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes); // No need to synchronize, will do a full device sync during `sync` - CUDA_CHECK(cudaMemsetAsync(task_fifo_ptrs[nvl_rank], 0, fifo_bytes, comm_stream)); + CUDA_CHECK(cudaMemsetAsync(barrier_signal_ptrs[nvl_rank], 0, barrier_signal_bytes, comm_stream)); } // Create 32 MiB workspace @@ -85,7 +98,7 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_ // MoE expert-level counter CUDA_CHECK(cudaMallocHost(&moe_recv_expert_counter, sizeof(int) * NUM_MAX_LOCAL_EXPERTS, cudaHostAllocMapped)); CUDA_CHECK(cudaHostGetDevicePointer(reinterpret_cast(&moe_recv_expert_counter_mapped), const_cast(moe_recv_expert_counter), 0)); - for (int i = 0; i < NUM_MAX_LOCAL_EXPERTS; ++ i) + for (int i = 0; i < NUM_MAX_LOCAL_EXPERTS; ++i) moe_recv_expert_counter[i] = -1; // MoE RDMA-level counter @@ -97,43 +110,12 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_ } Buffer::~Buffer() noexcept(false) { - // Synchronize - CUDA_CHECK(cudaDeviceSynchronize()); - - if 
(num_nvl_bytes > 0) { - // Barrier - intranode::barrier(task_fifo_ptrs_gpu, head, nvl_rank, num_nvl_ranks, comm_stream); - move_fifo_slots(); - CUDA_CHECK(cudaDeviceSynchronize()); - - // Close remote IPC - if (is_available()) { - for (int i = 0; i < num_nvl_ranks; ++ i) if (i != nvl_rank) - CUDA_CHECK(cudaIpcCloseMemHandle(buffer_ptrs[i])); - } - - // Free local buffer and error flag - CUDA_CHECK(cudaFree(buffer_ptrs[nvl_rank])); + if (not explicitly_destroy) { + destroy(); + } else if (not destroyed) { + printf("WARNING: destroy() was not called before DeepEP buffer destruction, which can leak resources.\n"); + fflush(stdout); } - - // Free NVSHMEM - if (num_rdma_bytes > 0) { - CUDA_CHECK(cudaDeviceSynchronize()); - internode::barrier(); - internode::free(rdma_buffer_ptr); - internode::finalize(); - } - - // Free cuBLAS handle, workspace and MoE counter - CUDA_CHECK(cudaFree(workspace)); - CUDA_CHECK(cudaFreeHost(const_cast(moe_recv_counter))); - - // Free chunked mode staffs - CUDA_CHECK(cudaFreeHost(const_cast(moe_recv_expert_counter))); -} - -void Buffer::move_fifo_slots(int num_slots) { - head = (head + num_ranks * num_slots) % NUM_MAX_FIFO_SLOTS; } bool Buffer::is_available() const { @@ -165,21 +147,81 @@ pybind11::bytearray Buffer::get_local_ipc_handle() const { } pybind11::bytearray Buffer::get_local_nvshmem_unique_id() const { +#ifndef DISABLE_NVSHMEM EP_HOST_ASSERT(rdma_rank == 0 and "Only RDMA rank 0 can get NVSHMEM unique ID"); auto unique_id = internode::get_unique_id(); return {reinterpret_cast(unique_id.data()), unique_id.size()}; +#else + EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation"); +#endif } torch::Tensor Buffer::get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const { torch::ScalarType casted_dtype = torch::python::detail::py_object_to_dtype(dtype); auto element_bytes = static_cast(elementSize(casted_dtype)); - auto base_ptr = reinterpret_cast(use_rdma_buffer ? 
rdma_buffer_ptr : buffer_ptrs[nvl_rank]) + offset; + auto base_ptr = static_cast(use_rdma_buffer ? rdma_buffer_ptr : buffer_ptrs[nvl_rank]) + offset; auto num_bytes = use_rdma_buffer ? num_rdma_bytes : num_nvl_bytes; return torch::from_blob(base_ptr, num_bytes / element_bytes, torch::TensorOptions().dtype(casted_dtype).device(at::kCUDA)); } -void Buffer::sync(const std::vector &device_ids, - const std::vector> &all_gathered_handles, +torch::Stream Buffer::get_comm_stream() const { + return comm_stream; +} + +void Buffer::move_fifo_slots(int num_slots) { + head = (head + num_ranks * num_slots) % NUM_MAX_FIFO_SLOTS; +} + +void Buffer::destroy() { + EP_HOST_ASSERT(not destroyed); + + // Synchronize + CUDA_CHECK(cudaDeviceSynchronize()); + + if (num_nvl_bytes > 0) { + // Barrier + intranode::barrier(barrier_signal_ptrs_gpu, nvl_rank, num_nvl_ranks, comm_stream, head); + move_fifo_slots(); + CUDA_CHECK(cudaDeviceSynchronize()); + + // Close remote IPC + if (is_available()) { + for (int i = 0; i < num_nvl_ranks; ++i) + if (i != nvl_rank) + CUDA_CHECK(cudaIpcCloseMemHandle(buffer_ptrs[i])); + } + + // Free local buffer and error flag + CUDA_CHECK(cudaFree(buffer_ptrs[nvl_rank])); + } + + // Free NVSHMEM +#ifndef DISABLE_NVSHMEM + if (is_available() and num_rdma_bytes > 0) { + CUDA_CHECK(cudaDeviceSynchronize()); + internode::barrier(); + internode::free(rdma_buffer_ptr); + if (enable_shrink) { + internode::free(mask_buffer_ptr); + internode::free(sync_buffer_ptr); + } + internode::finalize(); + } +#endif + + // Free workspace and MoE counter + CUDA_CHECK(cudaFree(workspace)); + CUDA_CHECK(cudaFreeHost(const_cast(moe_recv_counter))); + + // Free chunked mode staffs + CUDA_CHECK(cudaFreeHost(const_cast(moe_recv_expert_counter))); + + destroyed = true; + available = false; +} + +void Buffer::sync(const std::vector& device_ids, + const std::vector>& all_gathered_handles, const std::optional& root_unique_id_opt) { EP_HOST_ASSERT(not is_available()); @@ -187,26 +229,27 @@ void 
Buffer::sync(const std::vector &device_ids, if (num_nvl_bytes > 0) { EP_HOST_ASSERT(num_ranks == device_ids.size()); EP_HOST_ASSERT(device_ids.size() == all_gathered_handles.size()); - for (int i = 0, offset = rdma_rank * num_nvl_ranks; i < num_nvl_ranks; ++ i) { + for (int i = 0, offset = rdma_rank * num_nvl_ranks; i < num_nvl_ranks; ++i) { EP_HOST_ASSERT(all_gathered_handles[offset + i].has_value()); auto handle_str = std::string(all_gathered_handles[offset + i].value()); EP_HOST_ASSERT(handle_str.size() == CUDA_IPC_HANDLE_SIZE); if (offset + i != rank) { std::memcpy(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE); CUDA_CHECK(cudaIpcOpenMemHandle(&buffer_ptrs[i], ipc_handles[i], cudaIpcMemLazyEnablePeerAccess)); - task_fifo_ptrs[i] = reinterpret_cast(reinterpret_cast(buffer_ptrs[i]) + num_nvl_bytes); + barrier_signal_ptrs[i] = reinterpret_cast(static_cast(buffer_ptrs[i]) + num_nvl_bytes); } else { EP_HOST_ASSERT(std::memcmp(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE) == 0); } } - // Copy all buffer and task pointers to GPU + // Copy all buffer and barrier signal pointers to GPU CUDA_CHECK(cudaMemcpy(buffer_ptrs_gpu, buffer_ptrs, sizeof(void*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice)); - CUDA_CHECK(cudaMemcpy(task_fifo_ptrs_gpu, task_fifo_ptrs, sizeof(int*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(barrier_signal_ptrs_gpu, barrier_signal_ptrs, sizeof(int*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice)); CUDA_CHECK(cudaDeviceSynchronize()); } // Sync NVSHMEM handles and allocate memory +#ifndef DISABLE_NVSHMEM if (num_rdma_bytes > 0) { // Initialize NVSHMEM EP_HOST_ASSERT(root_unique_id_opt.has_value()); @@ -217,20 +260,36 @@ void Buffer::sync(const std::vector &device_ids, auto num_nvshmem_ranks = low_latency_mode ? 
num_ranks : num_rdma_ranks; EP_HOST_ASSERT(nvshmem_rank == internode::init(root_unique_id, nvshmem_rank, num_nvshmem_ranks, low_latency_mode)); internode::barrier(); + // Allocate rdma_buffer_ptr = internode::alloc(num_rdma_bytes, NUM_BUFFER_ALIGNMENT_BYTES); + // Clean buffer (mainly for low-latency mode) CUDA_CHECK(cudaMemset(rdma_buffer_ptr, 0, num_rdma_bytes)); + + // Allocate and clean shrink buffer + if (enable_shrink) { + int num_mask_buffer_bytes = num_ranks * sizeof(int); + int num_sync_buffer_bytes = num_ranks * sizeof(int); + mask_buffer_ptr = reinterpret_cast(internode::alloc(num_mask_buffer_bytes, NUM_BUFFER_ALIGNMENT_BYTES)); + sync_buffer_ptr = reinterpret_cast(internode::alloc(num_sync_buffer_bytes, NUM_BUFFER_ALIGNMENT_BYTES)); + CUDA_CHECK(cudaMemset(mask_buffer_ptr, 0, num_mask_buffer_bytes)); + CUDA_CHECK(cudaMemset(sync_buffer_ptr, 0, num_sync_buffer_bytes)); + } + // Barrier internode::barrier(); CUDA_CHECK(cudaDeviceSynchronize()); } +#endif + + // Ready to use available = true; } std::tuple, torch::Tensor, torch::Tensor, std::optional> -Buffer::get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts, - std::optional& previous_event, bool async, bool allocate_on_comm_stream) { +Buffer::get_dispatch_layout( + const torch::Tensor& topk_idx, int num_experts, std::optional& previous_event, bool async, bool allocate_on_comm_stream) { EP_HOST_ASSERT(topk_idx.dim() == 2); EP_HOST_ASSERT(topk_idx.is_contiguous()); EP_HOST_ASSERT(num_experts > 0); @@ -258,24 +317,27 @@ Buffer::get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts, if (is_internode_available()) num_tokens_per_rdma_rank = torch::empty({num_rdma_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); - internode::get_dispatch_layout(topk_idx.data_ptr(), - num_tokens_per_rank.data_ptr(), - num_tokens_per_rdma_rank.has_value() ? 
num_tokens_per_rdma_rank.value().data_ptr() : nullptr, - num_tokens_per_expert.data_ptr(), - is_token_in_rank.data_ptr(), - num_tokens, num_topk, num_ranks, num_experts, - comm_stream); + layout::get_dispatch_layout(topk_idx.data_ptr(), + num_tokens_per_rank.data_ptr(), + num_tokens_per_rdma_rank.has_value() ? num_tokens_per_rdma_rank.value().data_ptr() : nullptr, + num_tokens_per_expert.data_ptr(), + is_token_in_rank.data_ptr(), + num_tokens, + num_topk, + num_ranks, + num_experts, + comm_stream); // Wait streams std::optional event; if (async) { event = EventHandle(comm_stream); - for (auto& t: {topk_idx, num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank}) { + for (auto& t : {topk_idx, num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank}) { t.record_stream(comm_stream); if (allocate_on_comm_stream) t.record_stream(compute_stream); } - for (auto& to: {num_tokens_per_rdma_rank}) { + for (auto& to : {num_tokens_per_rdma_rank}) { to.has_value() ? to->record_stream(comm_stream) : void(); if (allocate_on_comm_stream) to.has_value() ? 
to->record_stream(compute_stream) : void(); @@ -291,12 +353,33 @@ Buffer::get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts, return {num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, event}; } -std::tuple, std::optional, std::optional, std::vector, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::optional> -Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional& x_scales, - const std::optional& topk_idx, const std::optional& topk_weights, - const std::optional& num_tokens_per_rank, const torch::Tensor& is_token_in_rank, const std::optional& num_tokens_per_expert, - int cached_num_recv_tokens, const std::optional& cached_rank_prefix_matrix, const std::optional& cached_channel_prefix_matrix, - int expert_alignment, const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream) { +std::tuple, + std::optional, + std::optional, + std::vector, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + std::optional> +Buffer::intranode_dispatch(const torch::Tensor& x, + const std::optional& x_scales, + const std::optional& topk_idx, + const std::optional& topk_weights, + const std::optional& num_tokens_per_rank, + const torch::Tensor& is_token_in_rank, + const std::optional& num_tokens_per_expert, + int cached_num_recv_tokens, + const std::optional& cached_rank_prefix_matrix, + const std::optional& cached_channel_prefix_matrix, + int expert_alignment, + int num_worst_tokens, + const Config& config, + std::optional& previous_event, + bool async, + bool allocate_on_comm_stream) { bool cached_mode = cached_rank_prefix_matrix.has_value(); // One channel use two blocks, even-numbered blocks for sending, odd-numbered blocks for receiving. 
@@ -343,7 +426,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optionalsize(0) and num_tokens == topk_weights->size(0)); EP_HOST_ASSERT(num_topk == topk_weights->size(1)); EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32); - topk_idx_ptr = topk_idx->data_ptr(); + topk_idx_ptr = topk_idx->data_ptr(); topk_weights_ptr = topk_weights->data_ptr(); } // FP8 scales checks float* x_scales_ptr = nullptr; - int num_scales = 0; + int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0; if (x_scales.has_value()) { EP_HOST_ASSERT(x.element_size() == 1); - EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32); - EP_HOST_ASSERT(x_scales->dim() > 0 and x_scales->dim() < 3 and x_scales->is_contiguous()); + EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32 or x_scales->scalar_type() == torch::kInt); + EP_HOST_ASSERT(x_scales->dim() == 2); EP_HOST_ASSERT(x_scales->size(0) == num_tokens); num_scales = x_scales->dim() == 1 ? 1 : static_cast(x_scales->size(1)); - x_scales_ptr = x_scales->data_ptr(); + x_scales_ptr = static_cast(x_scales->data_ptr()); + scale_token_stride = static_cast(x_scales->stride(0)); + scale_hidden_stride = static_cast(x_scales->stride(1)); } // Allocate all tensors on comm stream if set @@ -393,16 +478,15 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional(), num_memset_int, - buffer_ptrs_gpu, task_fifo_ptrs_gpu, head, rank, num_ranks, - comm_stream); + intranode::cached_notify_dispatch( + rank_prefix_matrix.data_ptr(), num_memset_int, buffer_ptrs_gpu, barrier_signal_ptrs_gpu, rank, num_ranks, comm_stream, head); move_fifo_slots(2); } else { rank_prefix_matrix = torch::empty({num_ranks, num_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); @@ -414,92 +498,153 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optionaldata_ptr(), moe_recv_counter_mapped, num_ranks, - num_tokens_per_expert->data_ptr(), moe_recv_expert_counter_mapped, num_experts, - num_tokens, 
is_token_in_rank.data_ptr(), channel_prefix_matrix.data_ptr(), + intranode::notify_dispatch(num_tokens_per_rank->data_ptr(), + moe_recv_counter_mapped, + num_ranks, + num_tokens_per_expert->data_ptr(), + moe_recv_expert_counter_mapped, + num_experts, + num_tokens, + is_token_in_rank.data_ptr(), + channel_prefix_matrix.data_ptr(), rank_prefix_matrix.data_ptr(), - num_memset_int, expert_alignment, - buffer_ptrs_gpu, task_fifo_ptrs_gpu, head, rank, - comm_stream, num_channels); + num_memset_int, + expert_alignment, + buffer_ptrs_gpu, + barrier_signal_ptrs_gpu, + rank, + comm_stream, + num_channels, + head); move_fifo_slots(3); - // Synchronize total received tokens and tokens per expert - auto start_time = std::chrono::high_resolution_clock::now(); - while (true) { - // Read total count - num_recv_tokens = static_cast(*moe_recv_counter); - - // Read per-expert count - bool ready = (num_recv_tokens >= 0); - for (int i = 0; i < num_local_experts and ready; ++i) - ready &= moe_recv_expert_counter[i] >= 0; - - if (ready) - break; - - // Timeout check - if (std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - start_time).count() > NUM_CPU_TIMEOUT_SECS) - throw std::runtime_error("DeepEP error: CPU recv timeout"); + if (num_worst_tokens > 0) { + // No CPU sync, just allocate the worst case + num_recv_tokens = num_worst_tokens; + + // Must be forward with top-k stuffs + EP_HOST_ASSERT(topk_idx.has_value()); + EP_HOST_ASSERT(topk_weights.has_value()); + } else { + // Synchronize total received tokens and tokens per expert + auto start_time = std::chrono::high_resolution_clock::now(); + while (true) { + // Read total count + num_recv_tokens = static_cast(*moe_recv_counter); + + // Read per-expert count + bool ready = (num_recv_tokens >= 0); + for (int i = 0; i < num_local_experts and ready; ++i) + { + ready &= moe_recv_expert_counter[i] >= 0; + } + + if (ready) + break; + std::this_thread::sleep_for(std::chrono::milliseconds(5000)); + // Timeout check + if 
(std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - start_time).count() > + NUM_CPU_TIMEOUT_SECS) + throw std::runtime_error("DeepEP error: CPU recv timeout"); + } + num_recv_tokens_per_expert_list = std::vector(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts); } - num_recv_tokens_per_expert_list = std::vector(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts); } // Allocate new tensors auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options()); auto recv_src_idx = torch::empty({num_recv_tokens}, dtype(torch::kInt32).device(torch::kCUDA)); - auto recv_topk_idx = std::optional(), recv_topk_weights = std::optional(), recv_x_scales = std::optional(); + auto recv_topk_idx = std::optional(), recv_topk_weights = std::optional(), + recv_x_scales = std::optional(); auto recv_channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA)); auto send_head = torch::empty({num_tokens, num_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); // Assign pointers - int64_t* recv_topk_idx_ptr = nullptr; + topk_idx_t* recv_topk_idx_ptr = nullptr; float* recv_topk_weights_ptr = nullptr; float* recv_x_scales_ptr = nullptr; if (topk_idx.has_value()) { recv_topk_idx = torch::empty({num_recv_tokens, num_topk}, topk_idx->options()); recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options()); - recv_topk_idx_ptr = recv_topk_idx->data_ptr(); + recv_topk_idx_ptr = recv_topk_idx->data_ptr(); recv_topk_weights_ptr = recv_topk_weights->data_ptr(); } if (x_scales.has_value()) { - recv_x_scales = x_scales->dim() == 1 ? - torch::empty({num_recv_tokens}, x_scales->options()) : - torch::empty({num_recv_tokens, num_scales}, x_scales->options()); - recv_x_scales_ptr = recv_x_scales->data_ptr(); + recv_x_scales = x_scales->dim() == 1 ? 
torch::empty({num_recv_tokens}, x_scales->options()) + : torch::empty({num_recv_tokens, num_scales}, x_scales->options()); + recv_x_scales_ptr = static_cast(recv_x_scales->data_ptr()); } // Dispatch - EP_HOST_ASSERT(num_ranks * num_ranks * sizeof(int) + // Size prefix matrix - num_channels * num_ranks * sizeof(int) + // Channel start offset - num_channels * num_ranks * sizeof(int) + // Channel end offset - num_channels * num_ranks * sizeof(int) * 2 + // Queue head and tail - num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * hidden * recv_x.element_size() + // Data buffer - num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(int) + // Source index buffer - num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(int64_t) + // Top-k index buffer - num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(float) + // Top-k weight buffer - num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(float) * num_scales // FP8 scale buffer - <= num_nvl_bytes); - intranode::dispatch(recv_x.data_ptr(), recv_x_scales_ptr, recv_src_idx.data_ptr(), recv_topk_idx_ptr, recv_topk_weights_ptr, recv_channel_prefix_matrix.data_ptr(), + EP_HOST_ASSERT( + num_ranks * num_ranks * sizeof(int) + // Size prefix matrix + num_channels * num_ranks * sizeof(int) + // Channel start offset + num_channels * num_ranks * sizeof(int) + // Channel end offset + num_channels * num_ranks * sizeof(int) * 2 + // Queue head and tail + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * hidden * recv_x.element_size() + // Data buffer + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(int) + // Source index buffer + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(topk_idx_t) + // Top-k index buffer + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(float) + // Top-k weight buffer 
+ num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(float) * num_scales // FP8 scale buffer + <= num_nvl_bytes); + intranode::dispatch(recv_x.data_ptr(), + recv_x_scales_ptr, + recv_src_idx.data_ptr(), + recv_topk_idx_ptr, + recv_topk_weights_ptr, + recv_channel_prefix_matrix.data_ptr(), send_head.data_ptr(), - x.data_ptr(), x_scales_ptr, topk_idx_ptr, topk_weights_ptr, - is_token_in_rank.data_ptr(), channel_prefix_matrix.data_ptr(), - num_tokens, static_cast(hidden * recv_x.element_size() / sizeof(int4)), num_topk, num_experts, num_scales, - buffer_ptrs_gpu, rank, num_ranks, comm_stream, config.num_sms, - config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens); + x.data_ptr(), + x_scales_ptr, + topk_idx_ptr, + topk_weights_ptr, + is_token_in_rank.data_ptr(), + channel_prefix_matrix.data_ptr(), + num_tokens, + num_worst_tokens, + static_cast(hidden * recv_x.element_size() / sizeof(int4)), + num_topk, + num_experts, + num_scales, + scale_token_stride, + scale_hidden_stride, + buffer_ptrs_gpu, + rank, + num_ranks, + comm_stream, + config.num_sms, + config.num_max_nvl_chunked_send_tokens, + config.num_max_nvl_chunked_recv_tokens); // Wait streams std::optional event; if (async) { event = EventHandle(comm_stream); - for (auto& t: {x, is_token_in_rank, rank_prefix_matrix, channel_prefix_matrix, recv_x, recv_src_idx, recv_channel_prefix_matrix, send_head}) { + for (auto& t : {x, + is_token_in_rank, + rank_prefix_matrix, + channel_prefix_matrix, + recv_x, + recv_src_idx, + recv_channel_prefix_matrix, + send_head}) { t.record_stream(comm_stream); if (allocate_on_comm_stream) t.record_stream(compute_stream); } - for (auto& to: {x_scales, topk_idx, topk_weights, num_tokens_per_rank, num_tokens_per_expert, cached_channel_prefix_matrix, cached_rank_prefix_matrix, recv_topk_idx, recv_topk_weights, recv_x_scales}) { + for (auto& to : {x_scales, + topk_idx, + topk_weights, + num_tokens_per_rank, + num_tokens_per_expert, + 
cached_channel_prefix_matrix, + cached_rank_prefix_matrix, + recv_topk_idx, + recv_topk_weights, + recv_x_scales}) { to.has_value() ? to->record_stream(comm_stream) : void(); if (allocate_on_comm_stream) to.has_value() ? to->record_stream(compute_stream) : void(); @@ -513,18 +658,37 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional, std::optional> -Buffer::intranode_combine(const torch::Tensor& x, const std::optional& topk_weights, - const torch::Tensor& src_idx, const torch::Tensor& rank_prefix_matrix, const torch::Tensor& channel_prefix_matrix, - const torch::Tensor& send_head, const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream) { +std::tuple, std::optional> Buffer::intranode_combine( + const torch::Tensor& x, + const std::optional& topk_weights, + const torch::Tensor& src_idx, + const torch::Tensor& rank_prefix_matrix, + const torch::Tensor& channel_prefix_matrix, + const torch::Tensor& send_head, + const Config& config, + std::optional& previous_event, + bool async, + bool allocate_on_comm_stream) { EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous()); EP_HOST_ASSERT(src_idx.dim() == 1 and src_idx.is_contiguous() and src_idx.scalar_type() == torch::kInt32); EP_HOST_ASSERT(send_head.dim() == 2 and send_head.is_contiguous() and send_head.scalar_type() == torch::kInt32); - EP_HOST_ASSERT(rank_prefix_matrix.dim() == 2 and rank_prefix_matrix.is_contiguous() and rank_prefix_matrix.scalar_type() == torch::kInt32); - EP_HOST_ASSERT(channel_prefix_matrix.dim() == 2 and channel_prefix_matrix.is_contiguous() and channel_prefix_matrix.scalar_type() == torch::kInt32); + EP_HOST_ASSERT(rank_prefix_matrix.dim() == 2 and rank_prefix_matrix.is_contiguous() and + rank_prefix_matrix.scalar_type() == torch::kInt32); + EP_HOST_ASSERT(channel_prefix_matrix.dim() == 2 and channel_prefix_matrix.is_contiguous() and + channel_prefix_matrix.scalar_type() == torch::kInt32); // One channel use two blocks, even-numbered blocks 
for sending, odd-numbered blocks for receiving. EP_HOST_ASSERT(config.num_sms % 2 == 0); @@ -569,40 +733,76 @@ Buffer::intranode_combine(const torch::Tensor& x, const std::optional(), - num_channels, num_recv_tokens, num_channels * num_ranks * 2, - task_fifo_ptrs_gpu, head, rank, num_ranks, - comm_stream); - + intranode::cached_notify_combine(buffer_ptrs_gpu, + send_head.data_ptr(), + num_channels, + num_recv_tokens, + num_channels * num_ranks * 2, + barrier_signal_ptrs_gpu, + rank, + num_ranks, + comm_stream, + head); // NOTES: this function uses two FIFO slots (barrier before and after) - move_fifo_slots(2); + move_fifo_slots(2); + + // Assign bias pointers + /*auto bias_opts = std::vector>({bias_0, bias_1}); + void* bias_ptrs[2] = {nullptr, nullptr}; + for (int i = 0; i < 2; ++i) + if (bias_opts[i].has_value()) { + auto bias = bias_opts[i].value(); + EP_HOST_ASSERT(bias.dim() == 2 and bias.is_contiguous()); + EP_HOST_ASSERT(bias.scalar_type() == x.scalar_type()); + EP_HOST_ASSERT(bias.size(0) == num_recv_tokens and bias.size(1) == hidden); + bias_ptrs[i] = bias.data_ptr(); + } + */ // Combine data auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options()); - EP_HOST_ASSERT(num_channels * num_ranks * sizeof(int) * 2 + // Queue head and tail - num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * hidden * x.element_size() + // Data buffer - num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(int) + // Source index buffer - num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(float) // Top-k weight buffer + EP_HOST_ASSERT(num_channels * num_ranks * sizeof(int) * 2 + // Queue head and tail + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * hidden * x.element_size() + // Data buffer + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(int) + // Source index buffer + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * 
sizeof(float) // Top-k weight buffer <= num_nvl_bytes); + intranode::combine(at::cuda::ScalarTypeToCudaDataType(x.scalar_type()), - recv_x.data_ptr(), recv_topk_weights_ptr, - x.data_ptr(), topk_weights_ptr, - src_idx.data_ptr(), rank_prefix_matrix.data_ptr(), channel_prefix_matrix.data_ptr(), - send_head.data_ptr(), num_tokens, num_recv_tokens, hidden, num_topk, - buffer_ptrs_gpu, rank, num_ranks, - comm_stream, config.num_sms, - config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens); + recv_x.data_ptr(), + recv_topk_weights_ptr, + x.data_ptr(), + topk_weights_ptr, + src_idx.data_ptr(), + rank_prefix_matrix.data_ptr(), + channel_prefix_matrix.data_ptr(), + send_head.data_ptr(), + num_tokens, + num_recv_tokens, + hidden, + num_topk, + buffer_ptrs_gpu, + rank, + num_ranks, + comm_stream, + config.num_sms, + config.num_max_nvl_chunked_send_tokens, + config.num_max_nvl_chunked_recv_tokens); // Wait streams std::optional event; if (async) { event = EventHandle(comm_stream); - for (auto& t: {x, src_idx, send_head, rank_prefix_matrix, channel_prefix_matrix, recv_x}) { + for (auto& t : {x, src_idx, send_head, rank_prefix_matrix, channel_prefix_matrix, recv_x}) { t.record_stream(comm_stream); if (allocate_on_comm_stream) t.record_stream(compute_stream); } - for (auto& to: {topk_weights, recv_topk_weights}) { + /*for (auto& to : {topk_weights, recv_topk_weights, bias_0, bias_1}) { + to.has_value() ? to->record_stream(comm_stream) : void(); + if (allocate_on_comm_stream) + to.has_value() ? to->record_stream(compute_stream) : void(); + }*/ + for (auto& to : {topk_weights, recv_topk_weights}) { to.has_value() ? to->record_stream(comm_stream) : void(); if (allocate_on_comm_stream) to.has_value() ? 
to->record_stream(compute_stream) : void(); @@ -618,16 +818,46 @@ Buffer::intranode_combine(const torch::Tensor& x, const std::optional, std::optional, std::optional, std::vector, torch::Tensor, torch::Tensor, std::optional, torch::Tensor, std::optional, torch::Tensor, std::optional, std::optional, std::optional, std::optional> -Buffer::internode_dispatch(const torch::Tensor& x, const std::optional& x_scales, - const std::optional& topk_idx, const std::optional& topk_weights, - const std::optional& num_tokens_per_rank, const std::optional& num_tokens_per_rdma_rank, - const torch::Tensor& is_token_in_rank, const std::optional& num_tokens_per_expert, - int cached_num_recv_tokens, int cached_num_rdma_recv_tokens, - const std::optional& cached_rdma_channel_prefix_matrix, const std::optional& cached_recv_rdma_rank_prefix_sum, - const std::optional& cached_gbl_channel_prefix_matrix, const std::optional& cached_recv_gbl_rank_prefix_sum, - int expert_alignment, const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream) { +std::tuple, + std::optional, + std::optional, + std::vector, + torch::Tensor, + torch::Tensor, + std::optional, + torch::Tensor, + std::optional, + torch::Tensor, + std::optional, + std::optional, + std::optional, + std::optional> +Buffer::internode_dispatch(const torch::Tensor& x, + const std::optional& x_scales, + const std::optional& topk_idx, + const std::optional& topk_weights, + const std::optional& num_tokens_per_rank, + const std::optional& num_tokens_per_rdma_rank, + const torch::Tensor& is_token_in_rank, + const std::optional& num_tokens_per_expert, + int cached_num_recv_tokens, + int cached_num_rdma_recv_tokens, + const std::optional& cached_rdma_channel_prefix_matrix, + const std::optional& cached_recv_rdma_rank_prefix_sum, + const std::optional& cached_gbl_channel_prefix_matrix, + const std::optional& cached_recv_gbl_rank_prefix_sum, + int expert_alignment, + const Config& config, + std::optional& 
previous_event, + bool async, + bool allocate_on_comm_stream) { +#ifndef DISABLE_NVSHMEM + // In dispatch, CPU will busy-wait until GPU receive tensor size metadata from other ranks, which can be quite long. + // If users of DeepEP need to execute other Python code on other threads, such as KV transfer, their code will get stuck due to GIL + // unless we release GIL here. + pybind11::gil_scoped_release release; + const int num_channels = config.num_sms / 2; EP_HOST_ASSERT(config.num_sms % 2 == 0); EP_HOST_ASSERT(0 < get_num_rdma_ranks() and get_num_rdma_ranks() <= NUM_MAX_RDMA_PEERS); @@ -661,11 +891,13 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optionaldim() == 2 and cached_rdma_channel_prefix_matrix->is_contiguous()); - EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->size(0) == num_rdma_ranks and cached_rdma_channel_prefix_matrix->size(1) == num_channels); + EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->size(0) == num_rdma_ranks and + cached_rdma_channel_prefix_matrix->size(1) == num_channels); EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->dim() == 1 and cached_recv_rdma_rank_prefix_sum->is_contiguous()); EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->size(0) == num_rdma_ranks); EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->dim() == 2 and cached_gbl_channel_prefix_matrix->is_contiguous()); - EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->size(0) == num_ranks and cached_gbl_channel_prefix_matrix->size(1) == num_channels); + EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->size(0) == num_ranks and + cached_gbl_channel_prefix_matrix->size(1) == num_channels); EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->dim() == 1 and cached_recv_gbl_rank_prefix_sum->is_contiguous()); EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->size(0) == num_ranks); } else { @@ -678,12 +910,13 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optionalsize(0) / num_ranks <= NUM_MAX_LOCAL_EXPERTS); } - auto num_tokens = 
static_cast(x.size(0)), hidden = static_cast(x.size(1)), hidden_int4 = static_cast(x.size(1) * x.element_size() / sizeof(int4)); + auto num_tokens = static_cast(x.size(0)), hidden = static_cast(x.size(1)), + hidden_int4 = static_cast(x.size(1) * x.element_size() / sizeof(int4)); auto num_experts = cached_mode ? 0 : static_cast(num_tokens_per_expert->size(0)), num_local_experts = num_experts / num_ranks; // Top-k checks int num_topk = 0; - int64_t* topk_idx_ptr = nullptr; + topk_idx_t* topk_idx_ptr = nullptr; float* topk_weights_ptr = nullptr; EP_HOST_ASSERT(topk_idx.has_value() == topk_weights.has_value()); if (topk_idx.has_value()) { @@ -694,20 +927,22 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optionalsize(0) and num_tokens == topk_weights->size(0)); EP_HOST_ASSERT(num_topk == topk_weights->size(1)); EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32); - topk_idx_ptr = topk_idx->data_ptr(); + topk_idx_ptr = topk_idx->data_ptr(); topk_weights_ptr = topk_weights->data_ptr(); } // FP8 scales checks float* x_scales_ptr = nullptr; - int num_scales = 0; + int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0; if (x_scales.has_value()) { EP_HOST_ASSERT(x.element_size() == 1); - EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32); - EP_HOST_ASSERT(x_scales->dim() > 0 and x_scales->dim() < 3 and x_scales->is_contiguous()); + EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32 or x_scales->scalar_type() == torch::kInt); + EP_HOST_ASSERT(x_scales->dim() == 2); EP_HOST_ASSERT(x_scales->size(0) == num_tokens); num_scales = x_scales->dim() == 1 ? 
1 : static_cast(x_scales->size(1)); - x_scales_ptr = x_scales->data_ptr(); + x_scales_ptr = static_cast(x_scales->data_ptr()); + scale_token_stride = static_cast(x_scales->stride(0)); + scale_hidden_stride = static_cast(x_scales->stride(1)); } // Allocate all tensors on comm stream if set @@ -743,14 +978,28 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optionaldata_ptr(), moe_recv_counter_mapped, num_ranks, - num_tokens_per_rdma_rank->data_ptr(), moe_recv_rdma_counter_mapped, - num_tokens_per_expert->data_ptr(), moe_recv_expert_counter_mapped, num_experts, - is_token_in_rank.data_ptr(), num_tokens, num_channels, - hidden_int4, num_scales, num_topk, expert_alignment, - rdma_channel_prefix_matrix.data_ptr(), recv_rdma_rank_prefix_sum.data_ptr(), - gbl_channel_prefix_matrix.data_ptr(), recv_gbl_rank_prefix_sum.data_ptr(), - rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens, - buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens, - task_fifo_ptrs_gpu, head, rank, comm_stream, + internode::notify_dispatch(num_tokens_per_rank->data_ptr(), + moe_recv_counter_mapped, + num_ranks, + num_tokens_per_rdma_rank->data_ptr(), + moe_recv_rdma_counter_mapped, + num_tokens_per_expert->data_ptr(), + moe_recv_expert_counter_mapped, + num_experts, + is_token_in_rank.data_ptr(), + num_tokens, + num_channels, + hidden_int4, + num_scales, + num_topk, + expert_alignment, + rdma_channel_prefix_matrix.data_ptr(), + recv_rdma_rank_prefix_sum.data_ptr(), + gbl_channel_prefix_matrix.data_ptr(), + recv_gbl_rank_prefix_sum.data_ptr(), + rdma_buffer_ptr, + config.num_max_rdma_chunked_recv_tokens, + buffer_ptrs_gpu, + config.num_max_nvl_chunked_recv_tokens, + barrier_signal_ptrs_gpu, + rank, + comm_stream, config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), - num_nvl_bytes, low_latency_mode); + num_nvl_bytes, + low_latency_mode); move_fifo_slots(3); // Synchronize total received tokens and tokens per expert @@ -785,17 +1051,15 @@ 
Buffer::internode_dispatch(const torch::Tensor& x, const std::optional= 0) and (num_rdma_recv_tokens >= 0); - for (int i = 0; i < num_local_experts and ready; ++ i) + for (int i = 0; i < num_local_experts and ready; ++i) ready &= moe_recv_expert_counter[i] >= 0; if (ready) break; // Timeout check - if (std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - start_time).count() > NUM_CPU_TIMEOUT_SECS) { - printf("Global rank: %d, num_recv_tokens: %d, num_rdma_recv_tokens: %d\n", rank, num_recv_tokens, num_rdma_recv_tokens); - for (int i = 0; i < num_local_experts; ++ i) - printf("moe_recv_expert_counter[%d]: %d\n", i, moe_recv_expert_counter[i]); + if (std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - start_time).count() > + NUM_CPU_TIMEOUT_SECS) { throw std::runtime_error("DeepEP error: timeout (dispatch CPU)"); } } @@ -804,7 +1068,8 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional(), recv_topk_weights = std::optional(), recv_x_scales = std::optional(); + auto recv_topk_idx = std::optional(), recv_topk_weights = std::optional(), + recv_x_scales = std::optional(); auto recv_src_meta = std::optional(); auto recv_rdma_channel_prefix_matrix = std::optional(); auto recv_gbl_channel_prefix_matrix = std::optional(); @@ -819,56 +1084,94 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optionaloptions()); recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options()); - recv_topk_idx_ptr = recv_topk_idx->data_ptr(); + recv_topk_idx_ptr = recv_topk_idx->data_ptr(); recv_topk_weights_ptr = recv_topk_weights->data_ptr(); } if (x_scales.has_value()) { - recv_x_scales = x_scales->dim() == 1 ? - torch::empty({num_recv_tokens}, x_scales->options()) : - torch::empty({num_recv_tokens, num_scales}, x_scales->options()); - recv_x_scales_ptr = recv_x_scales->data_ptr(); + recv_x_scales = x_scales->dim() == 1 ? 
torch::empty({num_recv_tokens}, x_scales->options()) + : torch::empty({num_recv_tokens, num_scales}, x_scales->options()); + recv_x_scales_ptr = static_cast(recv_x_scales->data_ptr()); } // Launch data dispatch // NOTES: the buffer size checks are moved into the `.cu` file - internode::dispatch(recv_x.data_ptr(), recv_x_scales_ptr, recv_topk_idx_ptr, recv_topk_weights_ptr, + internode::dispatch(recv_x.data_ptr(), + recv_x_scales_ptr, + recv_topk_idx_ptr, + recv_topk_weights_ptr, cached_mode ? nullptr : recv_src_meta->data_ptr(), - x.data_ptr(), x_scales_ptr, topk_idx_ptr, topk_weights_ptr, - cached_mode ? nullptr : send_rdma_head->data_ptr(), cached_mode ? nullptr : send_nvl_head->data_ptr(), + x.data_ptr(), + x_scales_ptr, + topk_idx_ptr, + topk_weights_ptr, + cached_mode ? nullptr : send_rdma_head->data_ptr(), + cached_mode ? nullptr : send_nvl_head->data_ptr(), cached_mode ? nullptr : recv_rdma_channel_prefix_matrix->data_ptr(), cached_mode ? nullptr : recv_gbl_channel_prefix_matrix->data_ptr(), - rdma_channel_prefix_matrix.data_ptr(), recv_rdma_rank_prefix_sum.data_ptr(), - gbl_channel_prefix_matrix.data_ptr(), recv_gbl_rank_prefix_sum.data_ptr(), - num_tokens, hidden_int4, num_scales, num_topk, num_experts, + rdma_channel_prefix_matrix.data_ptr(), + recv_rdma_rank_prefix_sum.data_ptr(), + gbl_channel_prefix_matrix.data_ptr(), + recv_gbl_rank_prefix_sum.data_ptr(), is_token_in_rank.data_ptr(), - rdma_buffer_ptr, config.num_max_rdma_chunked_send_tokens, config.num_max_rdma_chunked_recv_tokens, - buffer_ptrs_gpu, config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens, - rank, num_ranks, cached_mode, - comm_stream, num_channels, low_latency_mode); + num_tokens, + hidden_int4, + num_scales, + num_topk, + num_experts, + scale_token_stride, + scale_hidden_stride, + rdma_buffer_ptr, + config.num_max_rdma_chunked_send_tokens, + config.num_max_rdma_chunked_recv_tokens, + buffer_ptrs_gpu, + config.num_max_nvl_chunked_send_tokens, + 
config.num_max_nvl_chunked_recv_tokens, + rank, + num_ranks, + cached_mode, + comm_stream, + num_channels, + low_latency_mode); // Wait streams std::optional event; if (async) { event = EventHandle(comm_stream); - for (auto& t: {x, is_token_in_rank, recv_x, - rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum}) { + for (auto& t : {x, + is_token_in_rank, + recv_x, + rdma_channel_prefix_matrix, + recv_rdma_rank_prefix_sum, + gbl_channel_prefix_matrix, + recv_gbl_rank_prefix_sum}) { t.record_stream(comm_stream); if (allocate_on_comm_stream) t.record_stream(compute_stream); } - for (auto& to: {x_scales, topk_idx, topk_weights, - num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, - cached_rdma_channel_prefix_matrix, cached_recv_rdma_rank_prefix_sum, - cached_gbl_channel_prefix_matrix, cached_recv_gbl_rank_prefix_sum, - recv_topk_idx, recv_topk_weights, recv_x_scales, - recv_rdma_channel_prefix_matrix, recv_gbl_channel_prefix_matrix, send_rdma_head, send_nvl_head, - recv_src_meta}) { + for (auto& to : {x_scales, + topk_idx, + topk_weights, + num_tokens_per_rank, + num_tokens_per_rdma_rank, + num_tokens_per_expert, + cached_rdma_channel_prefix_matrix, + cached_recv_rdma_rank_prefix_sum, + cached_gbl_channel_prefix_matrix, + cached_recv_gbl_rank_prefix_sum, + recv_topk_idx, + recv_topk_weights, + recv_x_scales, + recv_rdma_channel_prefix_matrix, + recv_gbl_channel_prefix_matrix, + send_rdma_head, + send_nvl_head, + recv_src_meta}) { to.has_value() ? to->record_stream(comm_stream) : void(); if (allocate_on_comm_stream) to.has_value() ? 
to->record_stream(compute_stream) : void(); @@ -882,33 +1185,64 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional, std::optional> -Buffer::internode_combine(const torch::Tensor& x, const std::optional& topk_weights, - const torch::Tensor& src_meta, const torch::Tensor& is_combined_token_in_rank, - const torch::Tensor& rdma_channel_prefix_matrix, const torch::Tensor& rdma_rank_prefix_sum, const torch::Tensor& gbl_channel_prefix_matrix, - const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head, - const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream) { +std::tuple, std::optional> Buffer::internode_combine( + const torch::Tensor& x, + const std::optional& topk_weights, + const std::optional& bias_0, + const std::optional& bias_1, + const torch::Tensor& src_meta, + const torch::Tensor& is_combined_token_in_rank, + const torch::Tensor& rdma_channel_prefix_matrix, + const torch::Tensor& rdma_rank_prefix_sum, + const torch::Tensor& gbl_channel_prefix_matrix, + const torch::Tensor& combined_rdma_head, + const torch::Tensor& combined_nvl_head, + const Config& config, + std::optional& previous_event, + bool async, + bool allocate_on_comm_stream) { +#ifndef DISABLE_NVSHMEM const int num_channels = config.num_sms / 2; EP_HOST_ASSERT(config.num_sms % 2 == 0); // Shape and contiguous checks EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous()); EP_HOST_ASSERT(src_meta.dim() == 2 and src_meta.is_contiguous() and src_meta.scalar_type() == torch::kByte); - EP_HOST_ASSERT(is_combined_token_in_rank.dim() == 2 and is_combined_token_in_rank.is_contiguous() and is_combined_token_in_rank.scalar_type() == torch::kBool); - EP_HOST_ASSERT(rdma_channel_prefix_matrix.dim() == 2 and rdma_channel_prefix_matrix.is_contiguous() and rdma_channel_prefix_matrix.scalar_type() == torch::kInt32); - EP_HOST_ASSERT(rdma_rank_prefix_sum.dim() == 1 and rdma_rank_prefix_sum.is_contiguous() and 
rdma_rank_prefix_sum.scalar_type() == torch::kInt32); - EP_HOST_ASSERT(gbl_channel_prefix_matrix.dim() == 2 and gbl_channel_prefix_matrix.is_contiguous() and gbl_channel_prefix_matrix.scalar_type() == torch::kInt32); - EP_HOST_ASSERT(combined_rdma_head.dim() == 2 and combined_rdma_head.is_contiguous() and combined_rdma_head.scalar_type() == torch::kInt32); + EP_HOST_ASSERT(is_combined_token_in_rank.dim() == 2 and is_combined_token_in_rank.is_contiguous() and + is_combined_token_in_rank.scalar_type() == torch::kBool); + EP_HOST_ASSERT(rdma_channel_prefix_matrix.dim() == 2 and rdma_channel_prefix_matrix.is_contiguous() and + rdma_channel_prefix_matrix.scalar_type() == torch::kInt32); + EP_HOST_ASSERT(rdma_rank_prefix_sum.dim() == 1 and rdma_rank_prefix_sum.is_contiguous() and + rdma_rank_prefix_sum.scalar_type() == torch::kInt32); + EP_HOST_ASSERT(gbl_channel_prefix_matrix.dim() == 2 and gbl_channel_prefix_matrix.is_contiguous() and + gbl_channel_prefix_matrix.scalar_type() == torch::kInt32); + EP_HOST_ASSERT(combined_rdma_head.dim() == 2 and combined_rdma_head.is_contiguous() and + combined_rdma_head.scalar_type() == torch::kInt32); EP_HOST_ASSERT(combined_nvl_head.dim() == 2 and combined_nvl_head.is_contiguous() and combined_nvl_head.scalar_type() == torch::kInt32); - auto num_tokens = static_cast(x.size(0)), hidden = static_cast(x.size(1)), hidden_int4 = static_cast(x.size(1) * x.element_size() / sizeof(int4)); + auto num_tokens = static_cast(x.size(0)), hidden = static_cast(x.size(1)), + hidden_int4 = static_cast(x.size(1) * x.element_size() / sizeof(int4)); auto num_combined_tokens = static_cast(is_combined_token_in_rank.size(0)); EP_HOST_ASSERT((hidden * x.element_size()) % sizeof(int4) == 0); EP_HOST_ASSERT(src_meta.size(1) == internode::get_source_meta_bytes()); @@ -916,7 +1250,8 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional(), - rdma_channel_prefix_matrix.data_ptr(), rdma_rank_prefix_sum.data_ptr(), combined_nvl_head.data_ptr(), - 
rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens, - buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens, - task_fifo_ptrs_gpu, head, rank, comm_stream, + internode::cached_notify(hidden_int4, + 0, + 0, + num_topk, + num_ranks, + num_channels, + num_combined_tokens, + combined_rdma_head.data_ptr(), + rdma_channel_prefix_matrix.data_ptr(), + rdma_rank_prefix_sum.data_ptr(), + combined_nvl_head.data_ptr(), + rdma_buffer_ptr, + config.num_max_rdma_chunked_recv_tokens, + buffer_ptrs_gpu, + config.num_max_nvl_chunked_recv_tokens, + barrier_signal_ptrs_gpu, + rank, + comm_stream, config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), - num_nvl_bytes, false, low_latency_mode); + num_nvl_bytes, + false, + low_latency_mode); move_fifo_slots(2); + // Assign bias pointers + auto bias_opts = std::vector>({bias_0, bias_1}); + void* bias_ptrs[2] = {nullptr, nullptr}; + for (int i = 0; i < 2; ++i) + if (bias_opts[i].has_value()) { + auto bias = bias_opts[i].value(); + EP_HOST_ASSERT(bias.dim() == 2 and bias.is_contiguous()); + EP_HOST_ASSERT(bias.scalar_type() == x.scalar_type()); + EP_HOST_ASSERT(bias.size(0) == num_combined_tokens and bias.size(1) == hidden); + bias_ptrs[i] = bias.data_ptr(); + } + // Launch data combine auto combined_x = torch::empty({num_combined_tokens, hidden}, x.options()); internode::combine(at::cuda::ScalarTypeToCudaDataType(x.scalar_type()), - combined_x.data_ptr(), combined_topk_weights_ptr, + combined_x.data_ptr(), + combined_topk_weights_ptr, is_combined_token_in_rank.data_ptr(), - x.data_ptr(), topk_weights_ptr, - combined_rdma_head.data_ptr(), combined_nvl_head.data_ptr(), - src_meta.data_ptr(), rdma_channel_prefix_matrix.data_ptr(), rdma_rank_prefix_sum.data_ptr(), gbl_channel_prefix_matrix.data_ptr(), - num_tokens, num_combined_tokens, hidden, num_topk, - rdma_buffer_ptr, config.num_max_rdma_chunked_send_tokens, config.num_max_rdma_chunked_recv_tokens, - buffer_ptrs_gpu, config.num_max_nvl_chunked_send_tokens, 
config.num_max_nvl_chunked_recv_tokens, - rank, num_ranks, comm_stream, num_channels, low_latency_mode); + x.data_ptr(), + topk_weights_ptr, + bias_ptrs[0], + bias_ptrs[1], + combined_rdma_head.data_ptr(), + combined_nvl_head.data_ptr(), + src_meta.data_ptr(), + rdma_channel_prefix_matrix.data_ptr(), + rdma_rank_prefix_sum.data_ptr(), + gbl_channel_prefix_matrix.data_ptr(), + num_tokens, + num_combined_tokens, + hidden, + num_topk, + rdma_buffer_ptr, + config.num_max_rdma_chunked_send_tokens, + config.num_max_rdma_chunked_recv_tokens, + buffer_ptrs_gpu, + config.num_max_nvl_chunked_send_tokens, + config.num_max_nvl_chunked_recv_tokens, + rank, + num_ranks, + comm_stream, + num_channels, + low_latency_mode); // Wait streams std::optional event; if (async) { event = EventHandle(comm_stream); - for (auto& t: {x, src_meta, - is_combined_token_in_rank, rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, - combined_x, combined_rdma_head, combined_nvl_head}) { + for (auto& t : {x, + src_meta, + is_combined_token_in_rank, + rdma_channel_prefix_matrix, + rdma_rank_prefix_sum, + gbl_channel_prefix_matrix, + combined_x, + combined_rdma_head, + combined_nvl_head}) { t.record_stream(comm_stream); if (allocate_on_comm_stream) t.record_stream(compute_stream); } - for (auto& to: {topk_weights, combined_topk_weights}) { + for (auto& to : {topk_weights, combined_topk_weights, bias_0, bias_1}) { to.has_value() ? to->record_stream(comm_stream) : void(); if (allocate_on_comm_stream) to.has_value() ? 
to->record_stream(compute_stream) : void(); @@ -1004,12 +1389,14 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional(clean_meta_0.first), + clean_meta_0.second, + reinterpret_cast(clean_meta_1.first), + clean_meta_1.second, + rank, + num_ranks, + mask_buffer_ptr, + sync_buffer_ptr, at::cuda::getCurrentCUDAStream()); -#endif //DISABLE_INTERNODE +#else + EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation"); +#endif } -std::tuple, torch::Tensor, torch::Tensor, torch::Tensor, std::optional, std::optional>> -Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx, - int num_max_dispatch_tokens_per_rank, int num_experts, - bool use_fp8, bool async, bool return_recv_hook) { -#if DISABLE_INTERNODE - throw std::runtime_error("Low-latency mode is disabled"); -#else - +std::tuple, + torch::Tensor, + torch::Tensor, + torch::Tensor, + std::optional, + std::optional>> +Buffer::low_latency_dispatch(const torch::Tensor& x, + const torch::Tensor& topk_idx, + const std::optional& cumulative_local_expert_recv_stats, + const std::optional& dispatch_wait_recv_cost_stats, + int num_max_dispatch_tokens_per_rank, + int num_experts, + bool use_fp8, + bool round_scale, + bool use_ue8m0, + bool async, + bool return_recv_hook) { +#ifndef DISABLE_NVSHMEM EP_HOST_ASSERT(low_latency_mode); // Tensor checks @@ -1045,71 +1451,108 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i EP_HOST_ASSERT(x.size(1) % sizeof(int4) == 0 and x.size(1) % 128 == 0); EP_HOST_ASSERT(topk_idx.dim() == 2 and topk_idx.is_contiguous()); EP_HOST_ASSERT(x.size(0) == topk_idx.size(0) and x.size(0) <= num_max_dispatch_tokens_per_rank); - EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt64); + EP_HOST_ASSERT(topk_idx.scalar_type() == c10::CppTypeToScalarType::value); EP_HOST_ASSERT(num_experts % num_ranks == 0); + // Diagnosis tensors + if (cumulative_local_expert_recv_stats.has_value()) { + 
EP_HOST_ASSERT(cumulative_local_expert_recv_stats->scalar_type() == torch::kInt); + EP_HOST_ASSERT(cumulative_local_expert_recv_stats->dim() == 1 and cumulative_local_expert_recv_stats->is_contiguous()); + EP_HOST_ASSERT(cumulative_local_expert_recv_stats->size(0) == num_experts / num_ranks); + } + if (dispatch_wait_recv_cost_stats.has_value()) { + EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->scalar_type() == torch::kInt64); + EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->dim() == 1 and dispatch_wait_recv_cost_stats->is_contiguous()); + EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->size(0) == num_ranks); + } + auto num_tokens = static_cast(x.size(0)), hidden = static_cast(x.size(1)); - auto num_scales = hidden / 128, num_topk = static_cast(topk_idx.size(1)); - int num_local_experts = num_experts / num_ranks; + auto num_topk = static_cast(topk_idx.size(1)); + auto num_local_experts = num_experts / num_ranks; // Buffer control LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes); auto buffer = layout.buffers[low_latency_buffer_idx]; auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1]; - - // Buffer control - LowLatencyLayout nvl_layout(nvl_buffer_ptrs[nvl_rank], num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); - EP_HOST_ASSERT(nvl_layout.total_bytes <= num_rdma_bytes); - auto nvl_buffer = nvl_layout.buffers[low_latency_buffer_idx ^= 1]; - auto nvl_next_buffer = nvl_layout.buffers[low_latency_buffer_idx ^= 1]; auto global_atomic_counter = torch::zeros({1}, torch::dtype(torch::kInt32).device(torch::kCUDA)); - + + // Wait previous tasks to be finished // NOTES: the hook mode will always use the default stream auto compute_stream = at::cuda::getCurrentCUDAStream(); auto launch_stream = return_recv_hook ? 
compute_stream : comm_stream; - EP_HOST_ASSERT(not (async and return_recv_hook)); + EP_HOST_ASSERT(not(async and return_recv_hook)); if (not return_recv_hook) stream_wait(launch_stream, compute_stream); // Allocate packed tensors - auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, - x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fn: torch::kBFloat16)); #ifdef USE_ROCM - if (gfx == 942){ - packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, - x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fnuz: torch::kBFloat16)); - } -#endif - auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA)); + auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, + x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fnuz : torch::kBFloat16)); +#else + auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, + x.options().dtype(use_fp8 ? 
torch::kFloat8_e4m3fn : torch::kBFloat16)); +#endif + auto packed_recv_src_info = + torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA)); auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA)); auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA)); // Allocate column-majored scales auto packed_recv_x_scales = std::optional(); - float* packed_recv_x_scales_ptr = nullptr; + void* packed_recv_x_scales_ptr = nullptr; + EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4"); + if (use_fp8) { - EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4"); - packed_recv_x_scales = torch::empty({num_local_experts, num_scales, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kFloat32).device(torch::kCUDA)); + // TODO: support unaligned cases + EP_HOST_ASSERT(hidden % 512 == 0); + if (not use_ue8m0) { + packed_recv_x_scales = torch::empty({num_local_experts, hidden / 128, num_ranks * num_max_dispatch_tokens_per_rank}, + torch::dtype(torch::kFloat32).device(torch::kCUDA)); + } else { + EP_HOST_ASSERT(round_scale); + packed_recv_x_scales = torch::empty({num_local_experts, hidden / 512, num_ranks * num_max_dispatch_tokens_per_rank}, + torch::dtype(torch::kInt).device(torch::kCUDA)); + } packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2); - packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr(); + packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr(); } // Kernel launch auto next_clean_meta = next_buffer.clean_meta(); auto launcher = [=](int phases) { - internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales_ptr, - packed_recv_src_info.data_ptr(), 
packed_recv_layout_range.data_ptr(), - packed_recv_count.data_ptr(), - global_atomic_counter.data_ptr(), - buffer.dispatch_rdma_recv_data_buffer, buffer.dispatch_rdma_recv_count_buffer, - buffer.dispatch_rdma_send_buffer, - x.data_ptr(), topk_idx.data_ptr(), - next_clean_meta.first, next_clean_meta.second, - num_tokens, hidden, num_max_dispatch_tokens_per_rank, - num_topk, num_experts, rank, num_ranks, use_fp8, - workspace, launch_stream, phases); + internode_ll::dispatch( + packed_recv_x.data_ptr(), + reinterpret_cast(packed_recv_x_scales_ptr), + packed_recv_src_info.data_ptr(), + packed_recv_layout_range.data_ptr(), + packed_recv_count.data_ptr(), + mask_buffer_ptr, + cumulative_local_expert_recv_stats.has_value() ? cumulative_local_expert_recv_stats->data_ptr() : nullptr, + dispatch_wait_recv_cost_stats.has_value() ? dispatch_wait_recv_cost_stats->data_ptr() : nullptr, + buffer.dispatch_rdma_recv_data_buffer, + reinterpret_cast(buffer.dispatch_rdma_recv_count_buffer), + buffer.dispatch_rdma_send_buffer, + x.data_ptr(), + topk_idx.data_ptr(), + reinterpret_cast(next_clean_meta.first), + next_clean_meta.second, + num_tokens, + hidden, + num_max_dispatch_tokens_per_rank, + num_topk, + num_experts, + rank, + num_ranks, + use_fp8, + round_scale, + use_ue8m0, + workspace, + num_device_sms, + launch_stream, + phases, + global_atomic_counter.data_ptr()); }; launcher(return_recv_hook ? 
LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE)); @@ -1130,19 +1573,27 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i // Return values return {packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, recv_hook}; -#endif //DISABLE_INTERNODE -} - -std::tuple, std::optional>> -Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights, - const torch::Tensor& src_info, const torch::Tensor& layout_range, - int num_max_dispatch_tokens_per_rank, int num_experts, - bool zero_copy, bool async, bool return_recv_hook, - const std::optional& out) { -#if DISABLE_INTERNODE - throw std::runtime_error("Low-latency mode is disabled"); #else + EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation"); + return {}; +#endif +} +std::tuple, std::optional>> Buffer::low_latency_combine( + const torch::Tensor& x, + const torch::Tensor& topk_idx, + const torch::Tensor& topk_weights, + const torch::Tensor& src_info, + const torch::Tensor& layout_range, + const std::optional& combine_wait_recv_cost_stats, + int num_max_dispatch_tokens_per_rank, + int num_experts, + bool use_logfmt, + bool zero_copy, + bool async, + bool return_recv_hook, + const std::optional& out) { +#ifndef DISABLE_NVSHMEM EP_HOST_ASSERT(low_latency_mode); // Tensor checks @@ -1152,7 +1603,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id EP_HOST_ASSERT(x.size(2) % sizeof(int4) == 0 and x.size(2) % 128 == 0); EP_HOST_ASSERT(topk_idx.dim() == 2 and topk_idx.is_contiguous()); EP_HOST_ASSERT(topk_idx.size(0) == topk_weights.size(0) and topk_idx.size(1) == topk_weights.size(1)); - EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt64); + EP_HOST_ASSERT(topk_idx.scalar_type() == c10::CppTypeToScalarType::value); EP_HOST_ASSERT(topk_weights.dim() == 2 and topk_weights.is_contiguous()); 
EP_HOST_ASSERT(topk_weights.size(0) <= num_max_dispatch_tokens_per_rank); EP_HOST_ASSERT(topk_weights.scalar_type() == torch::kFloat32); @@ -1161,27 +1612,30 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id EP_HOST_ASSERT(layout_range.dim() == 2 and layout_range.is_contiguous()); EP_HOST_ASSERT(layout_range.scalar_type() == torch::kInt64); EP_HOST_ASSERT(layout_range.size(0) == num_experts / num_ranks and layout_range.size(1) == num_ranks); + + if (combine_wait_recv_cost_stats.has_value()) { + EP_HOST_ASSERT(combine_wait_recv_cost_stats->scalar_type() == torch::kInt64); + EP_HOST_ASSERT(combine_wait_recv_cost_stats->dim() == 1 and combine_wait_recv_cost_stats->is_contiguous()); + EP_HOST_ASSERT(combine_wait_recv_cost_stats->size(0) == num_ranks); + } + auto hidden = static_cast(x.size(2)); - auto num_local_experts = num_experts / num_ranks, num_topk = static_cast(topk_weights.size(1)); + auto num_topk = static_cast(topk_weights.size(1)); auto num_combined_tokens = static_cast(topk_weights.size(0)); auto global_atomic_counter = torch::zeros({1}, torch::dtype(torch::kInt32).device(torch::kCUDA)); + + // Buffer control LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes); auto buffer = layout.buffers[low_latency_buffer_idx]; auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1]; - - // Buffer control - LowLatencyLayout nvl_layout(nvl_buffer_ptrs[nvl_rank], num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); - EP_HOST_ASSERT(nvl_layout.total_bytes <= num_rdma_bytes); - auto nvl_buffer = nvl_layout.buffers[low_latency_buffer_idx ^= 1]; - auto nvl_next_buffer = nvl_layout.buffers[low_latency_buffer_idx ^= 1]; // Wait previous tasks to be finished // NOTES: the hook mode will always use the default stream auto compute_stream = at::cuda::getCurrentCUDAStream(); auto launch_stream = return_recv_hook ? 
compute_stream : comm_stream; - EP_HOST_ASSERT(not (async and return_recv_hook)); + EP_HOST_ASSERT(not(async and return_recv_hook)); if (not return_recv_hook) stream_wait(launch_stream, compute_stream); @@ -1200,16 +1654,32 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id auto next_clean_meta = next_buffer.clean_meta(); auto launcher = [=](int phases) { internode_ll::combine(combined_x.data_ptr(), - buffer.combine_rdma_recv_data_buffer, buffer.combine_rdma_recv_flag_buffer, + buffer.combine_rdma_recv_data_buffer, + reinterpret_cast(buffer.combine_rdma_recv_flag_buffer), buffer.combine_rdma_send_buffer, - x.data_ptr(), topk_idx.data_ptr(), topk_weights.data_ptr(), - src_info.data_ptr(), layout_range.data_ptr(), - global_atomic_counter.data_ptr(), - next_clean_meta.first, next_clean_meta.second, - num_combined_tokens, hidden, num_max_dispatch_tokens_per_rank, - num_topk, num_experts, rank, num_ranks, - workspace, launch_stream, - phases, zero_copy); + x.data_ptr(), + topk_idx.data_ptr(), + topk_weights.data_ptr(), + src_info.data_ptr(), + layout_range.data_ptr(), + mask_buffer_ptr, + combine_wait_recv_cost_stats.has_value() ? combine_wait_recv_cost_stats->data_ptr() : nullptr, + reinterpret_cast(next_clean_meta.first), + next_clean_meta.second, + num_combined_tokens, + hidden, + num_max_dispatch_tokens_per_rank, + num_topk, + num_experts, + rank, + num_ranks, + use_logfmt, + workspace, + num_device_sms, + launch_stream, + phases, + zero_copy, + global_atomic_counter.data_ptr()); }; launcher(return_recv_hook ? 
LOW_LATENCY_SEND_PHASE : (LOW_LATENCY_SEND_PHASE | LOW_LATENCY_RECV_PHASE)); @@ -1230,12 +1700,16 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id // Return values return {combined_x, event, recv_hook}; -#endif // DISABLE_INTERNODE +#else + EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation"); + return {}; +#endif } -torch::Tensor -Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) { +torch::Tensor Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) const { +#ifndef DISABLE_NVSHMEM LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); + auto buffer = layout.buffers[low_latency_buffer_idx]; auto dtype = torch::kBFloat16; auto num_msg_elems = static_cast(buffer.num_bytes_per_combine_msg / elementSize(torch::kBFloat16)); @@ -1245,36 +1719,40 @@ Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank {num_experts / num_ranks, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, {num_ranks * num_max_dispatch_tokens_per_rank * num_msg_elems, num_msg_elems, 1}, torch::TensorOptions().dtype(dtype).device(torch::kCUDA)); +#else + EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation"); + return {}; +#endif } -std::string Buffer::get_local_ipc_handle_string() const { - return std::string(reinterpret_cast(ipc_handles[nvl_rank].reserved), CUDA_IPC_HANDLE_SIZE); +bool is_sm90_compiled() { +#ifndef DISABLE_SM90_FEATURES + return true; +#else + return false; +#endif } -std::string Buffer::get_local_nvshmem_unique_id_string() const { - EP_HOST_ASSERT(rdma_rank == 0 and "Only RDMA rank 0 can get NVSHMEM unique ID"); - auto unique_id = internode::get_unique_id(); - return std::string(reinterpret_cast(unique_id.data()), unique_id.size()); +void Buffer::low_latency_update_mask_buffer(int rank_to_mask, bool mask) { + 
EP_HOST_ASSERT(mask_buffer_ptr != nullptr and "Shrink mode must be enabled"); + EP_HOST_ASSERT(rank_to_mask >= 0 and rank_to_mask < num_ranks); + internode_ll::update_mask_buffer(mask_buffer_ptr, rank_to_mask, mask, at::cuda::getCurrentCUDAStream()); } -void Buffer::sync_string(const std::vector &device_ids, - const std::vector &all_gathered_handles, - const std::string& root_unique_id_opt) { - std::vector> py_all_gathered_handles; - for (auto& handle : all_gathered_handles) { - std::optional py_handle_opt = std::nullopt; - if (!handle.empty()) { - py_handle_opt.emplace(handle.c_str(), handle.size()); - } - py_all_gathered_handles.push_back(py_handle_opt); - } - std::optional py_root_unique_id_opt = std::nullopt; - if (!root_unique_id_opt.empty()) { - py_root_unique_id_opt.emplace(root_unique_id_opt.c_str(), root_unique_id_opt.size()); - } - sync(device_ids, py_all_gathered_handles, py_root_unique_id_opt); +void Buffer::low_latency_query_mask_buffer(const torch::Tensor& mask_status) { + EP_HOST_ASSERT(mask_buffer_ptr != nullptr and "Shrink mode must be enabled"); + EP_HOST_ASSERT(mask_status.numel() == num_ranks && mask_status.scalar_type() == torch::kInt32); + + internode_ll::query_mask_buffer( + mask_buffer_ptr, num_ranks, reinterpret_cast(mask_status.data_ptr()), at::cuda::getCurrentCUDAStream()); +} + +void Buffer::low_latency_clean_mask_buffer() { + EP_HOST_ASSERT(mask_buffer_ptr != nullptr and "Shrink mode must be enabled"); + internode_ll::clean_mask_buffer(mask_buffer_ptr, num_ranks, at::cuda::getCurrentCUDAStream()); } -} // namespace deep_ep + +} // namespace deep_ep PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.doc() = "DeepEP: an efficient expert-parallel communication library"; @@ -1282,8 +1760,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { pybind11::class_(m, "Config") .def(pybind11::init(), py::arg("num_sms") = 20, - py::arg("num_max_nvl_chunked_send_tokens") = 6, py::arg("num_max_nvl_chunked_recv_tokens") = 256, - 
py::arg("num_max_rdma_chunked_send_tokens") = 6, py::arg("num_max_rdma_chunked_recv_tokens") = 256) + py::arg("num_max_nvl_chunked_send_tokens") = 6, + py::arg("num_max_nvl_chunked_recv_tokens") = 256, + py::arg("num_max_rdma_chunked_send_tokens") = 6, + py::arg("num_max_rdma_chunked_recv_tokens") = 256) .def("get_nvl_buffer_size_hint", &deep_ep::Config::get_nvl_buffer_size_hint) .def("get_rdma_buffer_size_hint", &deep_ep::Config::get_rdma_buffer_size_hint); m.def("get_low_latency_rdma_size_hint", &deep_ep::get_low_latency_rdma_size_hint); @@ -1293,7 +1773,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def("current_stream_wait", &deep_ep::EventHandle::current_stream_wait); pybind11::class_(m, "Buffer") - .def(pybind11::init()) + .def(pybind11::init()) .def("is_available", &deep_ep::Buffer::is_available) .def("get_num_rdma_ranks", &deep_ep::Buffer::get_num_rdma_ranks) .def("get_rdma_rank", &deep_ep::Buffer::get_rdma_rank) @@ -1302,7 +1782,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def("get_local_ipc_handle", &deep_ep::Buffer::get_local_ipc_handle) .def("get_local_nvshmem_unique_id", &deep_ep::Buffer::get_local_nvshmem_unique_id) .def("get_local_buffer_tensor", &deep_ep::Buffer::get_local_buffer_tensor) + .def("get_comm_stream", &deep_ep::Buffer::get_comm_stream) .def("sync", &deep_ep::Buffer::sync) + .def("destroy", &deep_ep::Buffer::destroy) .def("get_dispatch_layout", &deep_ep::Buffer::get_dispatch_layout) .def("intranode_dispatch", &deep_ep::Buffer::intranode_dispatch) .def("intranode_combine", &deep_ep::Buffer::intranode_combine) @@ -1311,5 +1793,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def("clean_low_latency_buffer", &deep_ep::Buffer::clean_low_latency_buffer) .def("low_latency_dispatch", &deep_ep::Buffer::low_latency_dispatch) .def("low_latency_combine", &deep_ep::Buffer::low_latency_combine) + .def("low_latency_update_mask_buffer", &deep_ep::Buffer::low_latency_update_mask_buffer) + .def("low_latency_query_mask_buffer", 
&deep_ep::Buffer::low_latency_query_mask_buffer) + .def("low_latency_clean_mask_buffer", &deep_ep::Buffer::low_latency_clean_mask_buffer) .def("get_next_low_latency_combine_buffer", &deep_ep::Buffer::get_next_low_latency_combine_buffer); + + m.def("is_sm90_compiled", deep_ep::is_sm90_compiled); + m.attr("topk_idx_t") = + py::reinterpret_borrow((PyObject*)torch::getTHPDtype(c10::CppTypeToScalarType::value)); } diff --git a/csrc/deep_ep.hpp b/csrc/deep_ep.hpp index 969d164..3b7652f 100644 --- a/csrc/deep_ep.hpp +++ b/csrc/deep_ep.hpp @@ -8,6 +8,7 @@ #include #include #include + #include #include @@ -35,21 +36,21 @@ struct Buffer { void* buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr}; void** buffer_ptrs_gpu = nullptr; - void* nvl_buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr}; - void** nvl_buffer_ptrs_gpu = nullptr; // NVSHMEM Buffer int64_t num_rdma_bytes; void* rdma_buffer_ptr = nullptr; + // Shrink mode buffer + bool enable_shrink = false; + int* mask_buffer_ptr = nullptr; + int* sync_buffer_ptr = nullptr; + // Device info and communication int device_id; -#ifdef USE_ROCM - int gfx; -#endif + int num_device_sms; int rank, rdma_rank, nvl_rank; int num_ranks, num_rdma_ranks, num_nvl_ranks; cudaIpcMemHandle_t ipc_handles[NUM_MAX_NVL_PEERS]; - cudaIpcMemHandle_t pxn_ipc_handles[NUM_MAX_NVL_PEERS]; // Stream for communication at::cuda::CUDAStream comm_stream; @@ -59,8 +60,15 @@ struct Buffer { // Task fifo int head = 0; - int* task_fifo_ptrs[NUM_MAX_NVL_PEERS] = {nullptr}; - int** task_fifo_ptrs_gpu = nullptr; + + // Whether explicit `destroy()` is required. 
+ bool explicitly_destroy; + // After `destroy()` be called, this flag will be true + bool destroyed = false; + + // Barrier signals + int* barrier_signal_ptrs[NUM_MAX_NVL_PEERS] = {nullptr}; + int** barrier_signal_ptrs_gpu = nullptr; // Workspace void* workspace = nullptr; @@ -79,9 +87,15 @@ struct Buffer { private: void move_fifo_slots(int num_slots = 1); - + public: - Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode); + Buffer(int rank, + int num_ranks, + int64_t num_nvl_bytes, + int64_t num_rdma_bytes, + bool low_latency_mode, + bool explicitly_destroy, + bool enable_shrink); ~Buffer() noexcept(false); @@ -100,67 +114,161 @@ struct Buffer { pybind11::bytearray get_local_ipc_handle() const; pybind11::bytearray get_local_nvshmem_unique_id() const; - - pybind11::bytearray get_local_pxn_ipc_handle() const; torch::Tensor get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const; - void sync(const std::vector& device_ids, const std::vector>& all_gathered_handles, const std::optional& root_unique_id_opt); - - std::tuple, torch::Tensor, torch::Tensor, std::optional> - get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts, std::optional& previous_event, - bool async, bool allocate_on_comm_stream); - - std::tuple, std::optional, std::optional, std::vector, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::optional> - intranode_dispatch(const torch::Tensor& x, const std::optional& x_scales, - const std::optional& topk_idx, const std::optional& topk_weights, - const std::optional& num_tokens_per_rank, const torch::Tensor& is_token_in_rank, const std::optional& num_tokens_per_expert, - int cached_num_recv_tokens, const std::optional& cached_rank_prefix_matrix, const std::optional& cached_channel_prefix_matrix, - int expert_alignment, const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream); - - 
std::tuple, std::optional> - intranode_combine(const torch::Tensor& x, const std::optional& topk_weights, - const torch::Tensor& src_idx, const torch::Tensor& rank_prefix_matrix, const torch::Tensor& channel_prefix_matrix, - const torch::Tensor& send_head, const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream); - - std::tuple, std::optional, std::optional, std::vector, torch::Tensor, torch::Tensor, std::optional, torch::Tensor, std::optional, torch::Tensor, std::optional, std::optional, std::optional, std::optional> - internode_dispatch(const torch::Tensor& x, const std::optional& x_scales, - const std::optional& topk_idx, const std::optional& topk_weights, - const std::optional& num_tokens_per_rank, const std::optional& num_tokens_per_rdma_rank, - const torch::Tensor& is_token_in_rank, const std::optional& num_tokens_per_expert, - int cached_num_recv_tokens, int cached_num_rdma_recv_tokens, - const std::optional& cached_rdma_channel_prefix_matrix, const std::optional& cached_recv_rdma_rank_prefix_sum, - const std::optional& cached_gbl_channel_prefix_matrix, const std::optional& cached_recv_gbl_rank_prefix_sum, - int expert_alignment, const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream); - - std::tuple, std::optional> - internode_combine(const torch::Tensor& x, const std::optional& topk_weights, - const torch::Tensor& src_meta, const torch::Tensor& is_combined_token_in_rank, - const torch::Tensor& rdma_channel_prefix_matrix, const torch::Tensor& rdma_rank_prefix_sum, const torch::Tensor& gbl_channel_prefix_matrix, - const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head, - const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream); + torch::Stream get_comm_stream() const; + + void sync(const std::vector& device_ids, + const std::vector>& all_gathered_handles, + const std::optional& root_unique_id_opt); + + void destroy(); + 
+ std::tuple, torch::Tensor, torch::Tensor, std::optional> get_dispatch_layout( + const torch::Tensor& topk_idx, + int num_experts, + std::optional& previous_event, + bool async, + bool allocate_on_comm_stream); + + std::tuple, + std::optional, + std::optional, + std::vector, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + std::optional> + intranode_dispatch(const torch::Tensor& x, + const std::optional& x_scales, + const std::optional& topk_idx, + const std::optional& topk_weights, + const std::optional& num_tokens_per_rank, + const torch::Tensor& is_token_in_rank, + const std::optional& num_tokens_per_expert, + int cached_num_recv_tokens, + const std::optional& cached_rank_prefix_matrix, + const std::optional& cached_channel_prefix_matrix, + int expert_alignment, + int num_worst_tokens, + const Config& config, + std::optional& previous_event, + bool async, + bool allocate_on_comm_stream); + + std::tuple, std::optional> intranode_combine( + const torch::Tensor& x, + const std::optional& topk_weights, + //const std::optional& bias_0, + //const std::optional& bias_1, + const torch::Tensor& src_idx, + const torch::Tensor& rank_prefix_matrix, + const torch::Tensor& channel_prefix_matrix, + const torch::Tensor& send_head, + const Config& config, + std::optional& previous_event, + bool async, + bool allocate_on_comm_stream); + + std::tuple, + std::optional, + std::optional, + std::vector, + torch::Tensor, + torch::Tensor, + std::optional, + torch::Tensor, + std::optional, + torch::Tensor, + std::optional, + std::optional, + std::optional, + std::optional> + internode_dispatch(const torch::Tensor& x, + const std::optional& x_scales, + const std::optional& topk_idx, + const std::optional& topk_weights, + const std::optional& num_tokens_per_rank, + const std::optional& num_tokens_per_rdma_rank, + const torch::Tensor& is_token_in_rank, + const std::optional& num_tokens_per_expert, + int cached_num_recv_tokens, + int 
cached_num_rdma_recv_tokens, + const std::optional& cached_rdma_channel_prefix_matrix, + const std::optional& cached_recv_rdma_rank_prefix_sum, + const std::optional& cached_gbl_channel_prefix_matrix, + const std::optional& cached_recv_gbl_rank_prefix_sum, + int expert_alignment, + const Config& config, + std::optional& previous_event, + bool async, + bool allocate_on_comm_stream); + + std::tuple, std::optional> internode_combine( + const torch::Tensor& x, + const std::optional& topk_weights, + const std::optional& bias_0, + const std::optional& bias_1, + const torch::Tensor& src_meta, + const torch::Tensor& is_combined_token_in_rank, + const torch::Tensor& rdma_channel_prefix_matrix, + const torch::Tensor& rdma_rank_prefix_sum, + const torch::Tensor& gbl_channel_prefix_matrix, + const torch::Tensor& combined_rdma_head, + const torch::Tensor& combined_nvl_head, + const Config& config, + std::optional& previous_event, + bool async, + bool allocate_on_comm_stream); void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts); - std::tuple, torch::Tensor, torch::Tensor, torch::Tensor, std::optional, std::optional>> - low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx, - int num_max_dispatch_tokens_per_rank, int num_experts, - bool use_fp8, bool async, bool return_recv_hook); - - std::tuple, std::optional>> - low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights, - const torch::Tensor& src_info, const torch::Tensor& layout_range, - int num_max_dispatch_tokens_per_rank, int num_experts, - bool zero_copy, bool async, bool return_recv_hook, - const std::optional& out = std::nullopt); - - torch::Tensor - get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts); - - // addtional interface for c++ - std::string get_local_ipc_handle_string() const; - std::string get_local_nvshmem_unique_id_string() const; - void 
sync_string(const std::vector& device_ids, const std::vector& all_gathered_handles, const std::string& root_unique_id_opt); + std::tuple, + torch::Tensor, + torch::Tensor, + torch::Tensor, + std::optional, + std::optional>> + low_latency_dispatch(const torch::Tensor& x, + const torch::Tensor& topk_idx, + const std::optional& cumulative_local_expert_recv_stats, + const std::optional& dispatch_wait_recv_cost_stats, + int num_max_dispatch_tokens_per_rank, + int num_experts, + bool use_fp8, + bool round_scale, + bool use_ue8m0, + bool async, + bool return_recv_hook); + + std::tuple, std::optional>> low_latency_combine( + const torch::Tensor& x, + const torch::Tensor& topk_idx, + const torch::Tensor& topk_weights, + const torch::Tensor& src_info, + const torch::Tensor& layout_range, + const std::optional& combine_wait_recv_cost_stats, + int num_max_dispatch_tokens_per_rank, + int num_experts, + bool use_logfmt, + bool zero_copy, + bool async, + bool return_recv_hook, + const std::optional& out = std::nullopt); + + torch::Tensor get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) const; + + void low_latency_update_mask_buffer(int rank_to_mask, bool mask); + + void low_latency_query_mask_buffer(const torch::Tensor& mask_status); + + void low_latency_clean_mask_buffer(); }; -} // namespace deep_ep +} // namespace deep_ep diff --git a/csrc/kernels/api.cuh b/csrc/kernels/api.cuh index f18a859..c540346 100644 --- a/csrc/kernels/api.cuh +++ b/csrc/kernels/api.cuh @@ -7,7 +7,8 @@ namespace deep_ep { // Intranode runtime namespace intranode { -void barrier(int **task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream); +//void barrier(int **task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream); +void barrier(int **task_fifo_ptrs, int rank, int num_ranks, cudaStream_t stream, int head = 0); } // namespace intranode @@ -30,6 +31,24 @@ void finalize(); } // namespace internode + +// Layout kernels 
+namespace layout { + +void get_dispatch_layout(const topk_idx_t* topk_idx, + int* num_tokens_per_rank, + int* num_tokens_per_rdma_rank, + int* num_tokens_per_expert, + bool* is_token_in_rank, + int num_tokens, + int num_topk, + int num_ranks, + int num_experts, + cudaStream_t stream); + +} // namespace layout + + // Intranode kernels namespace intranode { @@ -37,32 +56,67 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts, int num_tokens, const bool* is_token_in_rank, int* channel_prefix_matrix, int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment, - void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank, - cudaStream_t stream, int num_sms); + void** buffer_ptrs, int **task_fifo_ptrs, int rank, + cudaStream_t stream, int num_sms, int head = 0); void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int, - void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank, int num_ranks, - cudaStream_t stream); - -void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset, - int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights, - const bool* is_token_in_rank, const int* channel_prefix_matrix, - int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales, - void** buffer_ptrs, int rank, int num_ranks, - cudaStream_t stream, int num_sms, - int num_max_send_tokens, int num_recv_buffer_tokens); + void** buffer_ptrs, int **task_fifo_ptrs, int rank, int num_ranks, + cudaStream_t stream, int head = 0); + +void dispatch(void* recv_x, + float* recv_x_scales, + int* recv_src_idx, + topk_idx_t* recv_topk_idx, + float* recv_topk_weights, + int* recv_channel_offset, + int* send_head, + const void* x, + const float* x_scales, + const topk_idx_t* topk_idx, + const float* 
topk_weights, + const bool* is_token_in_rank, + const int* channel_prefix_matrix, + int num_tokens, + int num_worst_tokens, + int hidden_int4, + int num_topk, + int num_experts, + int num_scales, + int scale_token_stride, + int scale_hidden_stride, + void** buffer_ptrs, + int rank, + int num_ranks, + cudaStream_t stream, + int num_sms, + int num_max_send_tokens, + int num_recv_buffer_tokens); void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int num_recv_tokens, int num_memset_int, - int** task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream); + int** task_fifo_ptrs, int rank, int num_ranks, cudaStream_t stream, int head = 0); void combine(cudaDataType_t type, - void* recv_x, float* recv_topk_weights, - const void* x, const float* topk_weights, - const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix, - int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk, - void** buffer_ptrs, int rank, int num_ranks, - cudaStream_t stream, int num_sms, - int num_max_send_tokens, int num_recv_buffer_tokens); + void* recv_x, + float* recv_topk_weights, + const void* x, + const float* topk_weights, + //const void* bias_0, + //const void* bias_1, + const int* src_idx, + const int* rank_prefix_matrix, + const int* channel_prefix_matrix, + int* send_head, + int num_tokens, + int num_recv_tokens, + int hidden, + int num_topk, + void** buffer_ptrs, + int rank, + int num_ranks, + cudaStream_t stream, + int num_sms, + int num_max_send_tokens, + int num_recv_buffer_tokens); } // namespace intranode @@ -78,82 +132,210 @@ void get_dispatch_layout(const int64_t* topk_idx, int num_tokens, int num_topk, int num_ranks, int num_experts, cudaStream_t stream); -void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks, - const int* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped, - const int* num_tokens_per_expert, int* 
moe_recv_expert_counter_mapped, int num_experts, - const bool* is_token_in_rank, int num_tokens, int num_channels, - int hidden_int4, int num_scales, int num_topk, int expert_alignment, - int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum, - int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum, - void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens, - void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens, - int** task_fifo_ptrs, int head, int rank, - cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes, - bool low_latency_mode); - -void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv_topk_weights, void* recv_src_meta, - const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights, - int* send_rdma_head, int* send_nvl_head, - int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix, - const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum, - const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum, - int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts, +void notify_dispatch(const int* num_tokens_per_rank, + int* moe_recv_counter_mapped, + int num_ranks, + const int* num_tokens_per_rdma_rank, + int* moe_recv_rdma_counter_mapped, + const int* num_tokens_per_expert, + int* moe_recv_expert_counter_mapped, + int num_experts, + const bool* is_token_in_rank, + int num_tokens, + int num_channels, + int hidden_int4, + int num_scales, + int num_topk, + int expert_alignment, + int* rdma_channel_prefix_matrix, + int* recv_rdma_rank_prefix_sum, + int* gbl_channel_prefix_matrix, + int* recv_gbl_rank_prefix_sum, + void* rdma_buffer_ptr, + int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, + int num_max_nvl_chunked_recv_tokens, + int** barrier_signal_ptrs, + int rank, + cudaStream_t stream, + int64_t num_rdma_bytes, + int64_t num_nvl_bytes, + bool low_latency_mode, + int head = 0); + 
+void dispatch(void* recv_x, + float* recv_x_scales, + topk_idx_t* recv_topk_idx, + float* recv_topk_weights, + void* recv_src_meta, + const void* x, + const float* x_scales, + const topk_idx_t* topk_idx, + const float* topk_weights, + int* send_rdma_head, + int* send_nvl_head, + int* recv_rdma_channel_prefix_matrix, + int* recv_gbl_channel_prefix_matrix, + const int* rdma_channel_prefix_matrix, + const int* recv_rdma_rank_prefix_sum, + const int* gbl_channel_prefix_matrix, + const int* recv_gbl_rank_prefix_sum, const bool* is_token_in_rank, - void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, - void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, - int rank, int num_ranks, bool is_cached_dispatch, - cudaStream_t stream, int num_channels, bool low_latency_mode); - -void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights, - int num_ranks, int num_channels, int num_combined_tokens, int* combined_rdma_head, - const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head, - void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens, - void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens, - int** task_fifo_ptrs, int head, int rank, cudaStream_t stream, - int64_t num_rdma_bytes, int64_t num_nvl_bytes, - bool is_cached_dispatch, bool low_latency_mode); + int num_tokens, + int hidden_int4, + int num_scales, + int num_topk, + int num_experts, + int scale_token_stride, + int scale_hidden_stride, + void* rdma_buffer_ptr, + int num_max_rdma_chunked_send_tokens, + int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, + int num_max_nvl_chunked_send_tokens, + int num_max_nvl_chunked_recv_tokens, + int rank, + int num_ranks, + bool is_cached_dispatch, + cudaStream_t stream, + int num_channels, + bool low_latency_mode); + +void cached_notify(int hidden_int4, + int num_scales, + int num_topk_idx, + int 
num_topk_weights, + int num_ranks, + int num_channels, + int num_combined_tokens, + int* combined_rdma_head, + const int* rdma_channel_prefix_matrix, + const int* rdma_rank_prefix_sum, + int* combined_nvl_head, + void* rdma_buffer_ptr, + int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, + int num_max_nvl_chunked_recv_tokens, + int** barrier_signal_ptrs, + int rank, + cudaStream_t stream, + int64_t num_rdma_bytes, + int64_t num_nvl_bytes, + bool is_cached_dispatch, + bool low_latency_mode, + int head = 0); + void combine(cudaDataType_t type, - void* combined_x, float* combined_topk_weights, + void* combined_x, + float* combined_topk_weights, const bool* is_combined_token_in_rank, - const void* x, const float* topk_weights, - const int* combined_rdma_head, const int* combined_nvl_head, - const void* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix, - int num_tokens, int num_combined_tokens, int hidden, int num_topk, - void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, - void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, - int rank, int num_ranks, cudaStream_t stream, int num_channels, bool low_latency_mode); + const void* x, + const float* topk_weights, + const void* bias_0, + const void* bias_1, + const int* combined_rdma_head, + const int* combined_nvl_head, + const void* src_meta, + const int* rdma_channel_prefix_matrix, + const int* rdma_rank_prefix_sum, + const int* gbl_channel_prefix_matrix, + int num_tokens, + int num_combined_tokens, + int hidden, + int num_topk, + void* rdma_buffer_ptr, + int num_max_rdma_chunked_send_tokens, + int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, + int num_max_nvl_chunked_send_tokens, + int num_max_nvl_chunked_recv_tokens, + int rank, + int num_ranks, + cudaStream_t stream, + int num_channels, + bool low_latency_mode); + } // namespace internode #if 
!DISABLE_INTERNODE // Internode low-latency kernels namespace internode_ll { -void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0, - int64_t* clean_1, int num_clean_int_1, +void clean_low_latency_buffer(int64_t* clean_0, + int num_clean_int_0, + int64_t* clean_1, + int num_clean_int_1, + int rank, + int num_ranks, + int* mask_buffer_ptr, + int* sync_buffer_ptr, cudaStream_t stream); -void dispatch(void* packed_recv_x, float* packed_recv_x_scales, - int* packed_recv_src_info, int64_t* packed_recv_layout_range, +void dispatch(void* packed_recv_x, + float* packed_recv_x_scales, + int* packed_recv_src_info, + int64_t* packed_recv_layout_range, int* packed_recv_count, - int* global_atomic_counter, - void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x, - const void* x, const int64_t* topk_idx, - int64_t* next_clean, int num_next_clean_int, - int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank, - int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8, - void* workspace, cudaStream_t stream, int phases); + int* mask_buffer_ptr, + int* cumulative_local_expert_recv_stats, + int64_t* dispatch_wait_recv_cost_stats, + void* rdma_recv_x, + int64_t* rdma_recv_count, + void* rdma_x, + const void* x, + const topk_idx_t* topk_idx, + int64_t* next_clean, + int num_next_clean_int, + int num_tokens, + int hidden, + int num_max_dispatch_tokens_per_rank, + int num_topk, + int num_experts, + int rank, + int num_ranks, + bool use_fp8, + bool round_scale, + bool use_ue8m0, + void* workspace, + int num_device_sms, + cudaStream_t stream, + int phases, + int* global_atomic_counter = NULL); void combine(void* combined_x, - void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x, - const void* x, const int64_t* topk_idx, const float* topk_weights, - const int* src_info, const int64_t* layout_range, - int* global_atomic_counter, - int64_t* next_clean, int num_next_clean_int, - int num_combined_tokens, int hidden, int 
num_max_dispatch_tokens_per_rank, - int num_topk, int num_experts, int rank, int num_ranks, - void* workspace, cudaStream_t stream, - int phases, bool zero_copy); + void* rdma_recv_x, + int64_t* rdma_recv_flag, + void* rdma_send_x, + const void* x, + const topk_idx_t* topk_idx, + const float* topk_weights, + const int* src_info, + const int64_t* layout_range, + int* mask_buffer_ptr, + int64_t* combine_wait_recv_cost_stats, + int64_t* next_clean, + int num_next_clean_int, + int num_combined_tokens, + int hidden, + int num_max_dispatch_tokens_per_rank, + int num_topk, + int num_experts, + int rank, + int num_ranks, + bool use_logfmt, + void* workspace, + int num_device_sms, + cudaStream_t stream, + int phases, + bool zero_copy, + int* global_atomic_counter = NULL); + +void query_mask_buffer(int* mask_buffer_ptr, int num_ranks, int* output_mask_tensor, cudaStream_t stream); + +void update_mask_buffer(int* mask_buffer_ptr, int rank_to_mask, bool mask, cudaStream_t stream); + +void clean_mask_buffer(int* mask_buffer_ptr, int num_ranks, cudaStream_t stream); } // namespace internode_ll #endif diff --git a/csrc/kernels/configs.cuh b/csrc/kernels/configs.cuh index 4f53a58..9aee4ae 100644 --- a/csrc/kernels/configs.cuh +++ b/csrc/kernels/configs.cuh @@ -29,6 +29,21 @@ __host__ __device__ __forceinline__ void host_device_printf(const char* format, #define printf host_device_printf #endif +namespace deep_ep { + +#ifndef TOPK_IDX_BITS +#define TOPK_IDX_BITS 64 +#endif + +#define INT_BITS_T2(bits) int##bits##_t +#define INT_BITS_T(bits) INT_BITS_T2(bits) +typedef INT_BITS_T(TOPK_IDX_BITS) topk_idx_t; // int32_t or int64_t +#undef INT_BITS_T +#undef INT_BITS_T2 + +} // namespace deep_ep + + #ifdef USE_ROCM static constexpr int32_t kWarpSize = 64; // For ROCm equals to half the wave size or Nvidia warp size diff --git a/csrc/kernels/exception.cuh b/csrc/kernels/exception.cuh index 81b4be9..77ae7a6 100644 --- a/csrc/kernels/exception.cuh +++ b/csrc/kernels/exception.cuh @@ -46,7 
+46,7 @@ do { \ do { \ if (not (cond)) { \ printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \ - trap(); \ + abort();\ } \ } while (0) #else diff --git a/csrc/kernels/internode.cu b/csrc/kernels/internode.cu index 7197d53..f38d203 100644 --- a/csrc/kernels/internode.cu +++ b/csrc/kernels/internode.cu @@ -406,18 +406,36 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in } } -void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks, - const int* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped, - const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts, - const bool* is_token_in_rank, int num_tokens, int num_channels, - int hidden_int4, int num_scales, int num_topk, int expert_alignment, - int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum, - int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum, - void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens, - void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens, - int** task_fifo_ptrs, int head, int rank, - cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes, - bool low_latency_mode) { +void notify_dispatch(const int* num_tokens_per_rank, + int* moe_recv_counter_mapped, + int num_ranks, + const int* num_tokens_per_rdma_rank, + int* moe_recv_rdma_counter_mapped, + const int* num_tokens_per_expert, + int* moe_recv_expert_counter_mapped, + int num_experts, + const bool* is_token_in_rank, + int num_tokens, + int num_channels, + int hidden_int4, + int num_scales, + int num_topk, + int expert_alignment, + int* rdma_channel_prefix_matrix, + int* recv_rdma_rank_prefix_sum, + int* gbl_channel_prefix_matrix, + int* recv_gbl_rank_prefix_sum, + void* rdma_buffer_ptr, + int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, + int num_max_nvl_chunked_recv_tokens, + int** barrier_signal_ptrs, + int rank, + cudaStream_t stream, + int64_t 
num_rdma_bytes, + int64_t num_nvl_bytes, + bool low_latency_mode, + int head = 0) { #define NOTIFY_DISPATCH_LAUNCH_CASE(num_rdma_ranks) { \ auto notify_dispatch_func = low_latency_mode ? \ notify_dispatch : notify_dispatch; \ @@ -431,7 +449,7 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, \ gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, \ rdma_buffer_ptr, \ - buffer_ptrs, task_fifo_ptrs, head, rank, \ + buffer_ptrs, barrier_signal_ptrs, head, rank, \ cpu_rdma_team); } break constexpr int kNumThreads = 256; @@ -1135,18 +1153,43 @@ asm volatile( #endif } -void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv_topk_weights, void* recv_src_meta, - const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights, - int* send_rdma_head, int* send_nvl_head, - int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix, - const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum, - const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum, - int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts, +void dispatch(void* recv_x, + float* recv_x_scales, + topk_idx_t* recv_topk_idx, + float* recv_topk_weights, + void* recv_src_meta, + const void* x, + const float* x_scales, + const topk_idx_t* topk_idx, + const float* topk_weights, + int* send_rdma_head, + int* send_nvl_head, + int* recv_rdma_channel_prefix_matrix, + int* recv_gbl_channel_prefix_matrix, + const int* rdma_channel_prefix_matrix, + const int* recv_rdma_rank_prefix_sum, + const int* gbl_channel_prefix_matrix, + const int* recv_gbl_rank_prefix_sum, const bool* is_token_in_rank, - void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, - void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, - int rank, int 
num_ranks, bool is_cached_dispatch, - cudaStream_t stream, int num_channels, bool low_latency_mode) { + int num_tokens, + int hidden_int4, + int num_scales, + int num_topk, + int num_experts, + int scale_token_stride, + int scale_hidden_stride, + void* rdma_buffer_ptr, + int num_max_rdma_chunked_send_tokens, + int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, + int num_max_nvl_chunked_send_tokens, + int num_max_nvl_chunked_recv_tokens, + int rank, + int num_ranks, + bool is_cached_dispatch, + cudaStream_t stream, + int num_channels, + bool low_latency_mode) { constexpr int kNumDispatchRDMASenderWarps = 7; #define DISPATCH_LAUNCH_CASE(num_rdma_ranks) { \ @@ -1284,14 +1327,29 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in } } -void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights, - int num_ranks, int num_channels, int num_combined_tokens, int* combined_rdma_head, - const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head, - void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens, - void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens, - int** task_fifo_ptrs, int head, int rank, cudaStream_t stream, - int64_t num_rdma_bytes, int64_t num_nvl_bytes, - bool is_cached_dispatch, bool low_latency_mode) { +void cached_notify(int hidden_int4, + int num_scales, + int num_topk_idx, + int num_topk_weights, + int num_ranks, + int num_channels, + int num_combined_tokens, + int* combined_rdma_head, + const int* rdma_channel_prefix_matrix, + const int* rdma_rank_prefix_sum, + int* combined_nvl_head, + void* rdma_buffer_ptr, + int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, + int num_max_nvl_chunked_recv_tokens, + int** barrier_signal_ptrs, + int rank, + cudaStream_t stream, + int64_t num_rdma_bytes, + int64_t num_nvl_bytes, + bool is_cached_dispatch, + bool low_latency_mode, + int head = 0) { const int num_threads = std::max(128, kWarpSize * 
num_channels); const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; @@ -1313,7 +1371,7 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to combined_rdma_head, num_combined_tokens, num_channels, rdma_channel_prefix_matrix, rdma_rank_prefix_sum, combined_nvl_head, rdma_buffer_ptr, - buffer_ptrs, task_fifo_ptrs, head, rank, num_ranks, + buffer_ptrs, barrier_signal_ptrs, head, rank, num_ranks, is_cached_dispatch, cpu_rdma_team); } @@ -1962,16 +2020,36 @@ combine(int4* combined_x, float* combined_topk_weights, #endif } + void combine(cudaDataType_t type, - void* combined_x, float* combined_topk_weights, + void* combined_x, + float* combined_topk_weights, const bool* is_combined_token_in_rank, - const void* x, const float* topk_weights, - const int* combined_rdma_head, const int* combined_nvl_head, - const void* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix, - int num_tokens, int num_combined_tokens, int hidden, int num_topk, - void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, - void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, - int rank, int num_ranks, cudaStream_t stream, int num_channels, bool low_latency_mode) + const void* x, + const float* topk_weights, + const void* bias_0, + const void* bias_1, + const int* combined_rdma_head, + const int* combined_nvl_head, + const void* src_meta, + const int* rdma_channel_prefix_matrix, + const int* rdma_rank_prefix_sum, + const int* gbl_channel_prefix_matrix, + int num_tokens, + int num_combined_tokens, + int hidden, + int num_topk, + void* rdma_buffer_ptr, + int num_max_rdma_chunked_send_tokens, + int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, + int num_max_nvl_chunked_send_tokens, + int num_max_nvl_chunked_recv_tokens, + int rank, + int num_ranks, + cudaStream_t stream, + int num_channels, + bool 
low_latency_mode) { const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu index 9cdc77d..cd731be 100644 --- a/csrc/kernels/internode_ll.cu +++ b/csrc/kernels/internode_ll.cu @@ -63,8 +63,14 @@ __global__ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0, #endif } -void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0, - int64_t* clean_1, int num_clean_int_1, +void clean_low_latency_buffer(int64_t* clean_0, + int num_clean_int_0, + int64_t* clean_1, + int num_clean_int_1, + int rank, + int num_ranks, + int* mask_buffer_ptr, + int* sync_buffer_ptr, cudaStream_t stream) { constexpr int kNumThreads = 256; @@ -98,7 +104,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, #if !defined(ROCM_DISABLE_CTX) __shared__ internode::shmem_ctx_t ctx; - internode::shmem_wg_ctx_create(&ctx); + EP_DEVICE_ASSERT(internode::shmem_wg_ctx_create(&ctx) == 0); #endif // FP8 staffs @@ -198,12 +204,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, for (int j = 0; j < kNumElemsPerRead; j += 2) { float2 fp32x2 = {fp32_values[j] * scale, fp32_values[j + 1] * scale}; #ifdef USE_ROCM -#if defined(__gfx942__) fp8x2_values[j / 2] = __hip_cvt_float2_to_fp8x2(fp32x2, __HIP_SATFINITE, __HIP_E4M3_FNUZ); -#endif -#if defined(__gfx950__) - fp8x2_values[j / 2] = __hip_cvt_float2_to_fp8x2(fp32x2, __HIP_SATFINITE, __HIP_E4M3); -#endif #else fp8x2_values[j / 2] = __nv_cvt_float2_to_fp8x2(fp32x2, __NV_SATFINITE, __NV_E4M3); #endif @@ -239,6 +240,11 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, internode::shmem_ctx_schar_put_nbi_warp(ctx, #endif reinterpret_cast(dst_ptr), reinterpret_cast(src_ptr), num_bytes_per_msg, dst_rank); +#if defined(ROCM_DISABLE_CTX) + internode::shmem_fence(); +#else + internode::shmem_ctx_quiet(ctx); +#endif #else nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, dst_rank, dst_expert_local_idx, lane_id, slot_idx); #endif 
@@ -296,13 +302,6 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, } } - if (thread_id == 0 and num_ranks > 8){ -#if defined(ROCM_DISABLE_CTX) - internode::shmem_fence(); -#else - internode::shmem_ctx_quiet(ctx); -#endif - } //revert sync_large_warp_counters to 0 for next sync __syncthreads(); @@ -325,7 +324,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank, dst_expert_local_idx); #endif } else { - st_na_release(reinterpret_cast(rdma_recv_count + dst_expert_local_idx * num_ranks + rank), -num_tokens_sent - 1); + st_na_release(reinterpret_cast(rdma_recv_count + dst_expert_local_idx * num_ranks + rank), -num_tokens_sent - 1); } // Clean workspace for next use @@ -340,8 +339,12 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, // Receiving phase LOW_LATENCY_DISPATCH_RECV: - if ((phases & LOW_LATENCY_RECV_PHASE) == 0) + if ((phases & LOW_LATENCY_RECV_PHASE) == 0) { +#if defined(USE_ROCM) && !defined(ROCM_DISABLE_CTX) + internode::shmem_wg_ctx_destroy(&ctx); +#endif return; + } // For send-and-recv kernels, we need a grid sync for making `packed_recv_count` visible if (phases & LOW_LATENCY_SEND_PHASE){ @@ -422,16 +425,36 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, #endif } -void dispatch(void* packed_recv_x, float* packed_recv_x_scales, - int* packed_recv_src_info, int64_t* packed_recv_layout_range, +void dispatch(void* packed_recv_x, + float* packed_recv_x_scales, + int* packed_recv_src_info, + int64_t* packed_recv_layout_range, int* packed_recv_count, - int* global_atomic_counter, - void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x, - const void* x, const int64_t* topk_idx, - int64_t* next_clean, int num_next_clean_int, - int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank, - int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8, - void* workspace, 
cudaStream_t stream, int phases) { + int* mask_buffer_ptr, + int* cumulative_local_expert_recv_stats, + int64_t* dispatch_wait_recv_cost_stats, + void* rdma_recv_x, + int64_t* rdma_recv_count, + void* rdma_x, + const void* x, + const topk_idx_t* topk_idx, + int64_t* next_clean, + int num_next_clean_int, + int num_tokens, + int hidden, + int num_max_dispatch_tokens_per_rank, + int num_topk, + int num_experts, + int rank, + int num_ranks, + bool use_fp8, + bool round_scale, + bool use_ue8m0, + void* workspace, + int num_device_sms, + cudaStream_t stream, + int phases, + int* global_atomic_counter = NULL) { #ifdef USE_ROCM constexpr int kNumWarpsPerGroup = 5; @@ -488,7 +511,7 @@ combine(void* combined_x, #if !defined(ROCM_DISABLE_CTX) __shared__ internode::shmem_ctx_t ctx; - internode::shmem_wg_ctx_create(&ctx); + EP_DEVICE_ASSERT(internode::shmem_wg_ctx_create(&ctx) == 0); #endif const auto sm_id = static_cast(blockIdx.x); const auto num_sms = static_cast(gridDim.x); @@ -580,14 +603,13 @@ combine(void* combined_x, internode::shmem_ctx_schar_put_nbi_warp(ctx, #endif reinterpret_cast(dst_ptr), reinterpret_cast(buf_ptr), hidden * sizeof(gpu_bfloat16_t), dst_rank); - if (num_ranks > 8){ + #if defined(ROCM_DISABLE_CTX) internode::shmem_fence(); #else internode::shmem_ctx_quiet(ctx); #endif } - } } // Put finishing flag @@ -616,7 +638,7 @@ combine(void* combined_x, nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_flag + global_expert_idx, 1, dst_rank, local_expert_idx); #endif } else { - st_na_release(reinterpret_cast(rdma_recv_flag + global_expert_idx), 1); + st_na_release(reinterpret_cast(rdma_recv_flag + global_expert_idx), 1); } atomic_add_release_global(atomic_clean_flag, -1); } @@ -625,8 +647,12 @@ combine(void* combined_x, // Receiving phase LOW_LATENCY_COMBINE_RECV: - if ((phases & LOW_LATENCY_RECV_PHASE) == 0) + if ((phases & LOW_LATENCY_RECV_PHASE) == 0) { +#if defined(USE_ROCM) && !defined(ROCM_DISABLE_CTX) + internode::shmem_wg_ctx_destroy(&ctx); +#endif return; + } 
// Wait all ranks to arrive and notify PCIe usage if (responsible_expert_idx < num_experts) { @@ -681,15 +707,32 @@ combine(void* combined_x, } void combine(void* combined_x, - void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x, - const void* x, const int64_t* topk_idx, const float* topk_weights, - const int* src_info, const int64_t* layout_range, - int* global_atomic_counter, - int64_t* next_clean, int num_next_clean_int, - int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank, - int num_topk, int num_experts, int rank, int num_ranks, - void* workspace, cudaStream_t stream, - int phases, bool zero_copy) { + void* rdma_recv_x, + int64_t* rdma_recv_flag, + void* rdma_send_x, + const void* x, + const topk_idx_t* topk_idx, + const float* topk_weights, + const int* src_info, + const int64_t* layout_range, + int* mask_buffer_ptr, + int64_t* combine_wait_recv_cost_stats, + int64_t* next_clean, + int num_next_clean_int, + int num_combined_tokens, + int hidden, + int num_max_dispatch_tokens_per_rank, + int num_topk, + int num_experts, + int rank, + int num_ranks, + bool use_logfmt, + void* workspace, + int num_device_sms, + cudaStream_t stream, + int phases, + bool zero_copy, + int* global_atomic_counter = NULL) { #ifdef USE_ROCM constexpr int kNumWarpsPerGroup = 4; constexpr int kNumWarpGroups = 4; @@ -726,6 +769,57 @@ LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func, \ #undef COMBINE_LAUNCH_CASE } + +template +__launch_bounds__(kNumThreads, 1) __global__ void query_mask_buffer(int* mask_buffer_ptr, int num_ranks, int* mask_tensor) { + const auto num_sms = static_cast(gridDim.x); + const auto sm_id = static_cast(blockIdx.x); + const auto num_threads = num_sms * kNumThreads; + const auto thread_id = sm_id * kNumThreads + static_cast(threadIdx.x); + for (int rank_id = thread_id; rank_id < num_ranks; rank_id += num_threads) { + mask_tensor[rank_id] = mask_buffer_ptr[rank_id]; + } +} + +void query_mask_buffer(int* mask_buffer_ptr, int num_ranks, 
int* mask_tensor, cudaStream_t stream) { + constexpr int num_sms = 1; + constexpr int kNumThreads = 1024; + SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream); + LAUNCH_KERNEL(&cfg, query_mask_buffer, mask_buffer_ptr, num_ranks, mask_tensor); +} + +template +__launch_bounds__(kNumThreads, 1) __global__ void update_mask_buffer(int* mask_buffer_ptr, int rank_to_mask, bool mask) { + const auto sm_id = static_cast(blockIdx.x); + const auto thread_id = static_cast(threadIdx.x); + if (sm_id == 0 && thread_id == 0) { + atomicExch(mask_buffer_ptr + rank_to_mask, mask ? 1 : 0); + } +} + +void update_mask_buffer(int* mask_buffer_ptr, int rank, bool mask, cudaStream_t stream) { + constexpr int num_sms = 1; + constexpr int kNumThreads = 32; + SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream); + LAUNCH_KERNEL(&cfg, update_mask_buffer, mask_buffer_ptr, rank, mask); +} + +template +__launch_bounds__(kNumThreads, 1) __global__ void clean_mask_buffer(int* mask_buffer_ptr, int num_ranks) { + auto thread_id = static_cast(threadIdx.x); + #pragma unroll + for (int i = thread_id; i < num_ranks; i += kNumThreads) + mask_buffer_ptr[i] = 0; +} + +void clean_mask_buffer(int* mask_buffer_ptr, int num_ranks, cudaStream_t stream) { + constexpr int num_sms = 1; + constexpr int kNumThreads = 32; + SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream); + LAUNCH_KERNEL(&cfg, clean_mask_buffer, mask_buffer_ptr, num_ranks); +} + + } // namespace internode_ll -} // namespace deep_ep \ No newline at end of file +} // namespace deep_ep diff --git a/csrc/kernels/intranode.cu b/csrc/kernels/intranode.cu index 6738ca8..3ef7e35 100644 --- a/csrc/kernels/intranode.cu +++ b/csrc/kernels/intranode.cu @@ -18,13 +18,14 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, auto sm_id = static_cast(blockIdx.x); auto thread_id = static_cast(threadIdx.x), num_threads = static_cast(blockDim.x); auto lane_id = thread_id % kWarpSize, warp_id = thread_id / kWarpSize, num_warps = num_threads / 
kWarpSize; - + if (sm_id == 0) { // Barrier first barrier_device(task_fifo_ptrs, head, rank); move_fifo_slots(head); __syncthreads(); + int *per_rank_buffer, *per_expert_buffer; if (thread_id < kNumRanks) { per_rank_buffer = reinterpret_cast(buffer_ptrs[thread_id]); @@ -36,9 +37,10 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, // - `per_expert_buffer[rank][i, j]` means the number of tokens from rank i to local expert j int num_experts_per_rank = num_experts / kNumRanks; if (thread_id < kNumRanks) { - #pragma unroll - for (int i = 0; i < kNumRanks; ++ i) - per_rank_buffer[rank * kNumRanks + i] = num_tokens_per_rank[i]; + //#pragma unroll + //for (int i = 0; i < kNumRanks; ++ i) + // per_rank_buffer[rank * kNumRanks + i] = num_tokens_per_rank[i]; + per_rank_buffer[rank * kNumRanks + thread_id] = num_tokens_per_rank[thread_id]; #pragma unroll for (int i = 0; i < num_experts_per_rank; ++ i) per_expert_buffer[rank * num_experts_per_rank + i] = num_tokens_per_expert[thread_id * num_experts_per_rank + i]; @@ -112,25 +114,39 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, } } -void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks, - const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts, - int num_tokens, const bool* is_token_in_rank, int* channel_prefix_matrix, - int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment, - void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank, - cudaStream_t stream, int num_channels) { +void notify_dispatch(const int* num_tokens_per_rank, + int* moe_recv_counter_mapped, + int num_ranks, + const int* num_tokens_per_expert, + int* moe_recv_expert_counter_mapped, + int num_experts, + int num_tokens, + const bool* is_token_in_rank, + int* channel_prefix_matrix, + int* rank_prefix_matrix_copy, + int num_memset_int, + int expert_alignment, + void** buffer_ptrs, + int** 
barrier_signal_ptrs, + int rank, + cudaStream_t stream, + int num_channels, + int head=0) { + #define NOTIFY_DISPATCH_LAUNCH_CASE(ranks) \ LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, notify_dispatch, \ num_tokens_per_rank, moe_recv_counter_mapped, \ num_tokens_per_expert, moe_recv_expert_counter_mapped, num_experts, \ num_tokens, num_channels, is_token_in_rank, channel_prefix_matrix, \ rank_prefix_matrix_copy, num_memset_int, expert_alignment, \ - buffer_ptrs, task_fifo_ptrs, head, rank); \ + buffer_ptrs, barrier_signal_ptrs, head, rank); \ break constexpr int kNumThreads = 128; EP_HOST_ASSERT(num_experts % num_ranks == 0); EP_HOST_ASSERT(num_experts / num_ranks <= kNumThreads and num_ranks <= kNumThreads); + SETUP_LAUNCH_CONFIG(1 + num_ranks, kNumThreads, stream); SWITCH_RANKS(NOTIFY_DISPATCH_LAUNCH_CASE); #undef NOTIFY_DISPATCH_LAUNCH_CASE @@ -177,7 +193,7 @@ cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int, void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int, void** buffer_ptrs, int** task_fifo_ptrs, - int head, int rank, int num_ranks, cudaStream_t stream) { + int rank, int num_ranks, cudaStream_t stream, int head=0) { #define CACHED_NOTIFY_DISPATCH_LAUNCH_CASE(ranks) \ LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, cached_notify_dispatch, \ rank_prefix_matrix, num_memset_int, buffer_ptrs, task_fifo_ptrs, head, rank); \ @@ -493,12 +509,36 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to } } -void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset, - int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights, - const bool* is_token_in_rank, const int* channel_prefix_matrix, - int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales, - void** buffer_ptrs, int rank, int num_ranks, - cudaStream_t stream, int num_sms, int num_max_send_tokens, int 
num_recv_buffer_tokens) { + +void dispatch(void* recv_x, + float* recv_x_scales, + int* recv_src_idx, + topk_idx_t* recv_topk_idx, + float* recv_topk_weights, + int* recv_channel_offset, + int* send_head, + const void* x, + const float* x_scales, + const topk_idx_t* topk_idx, + const float* topk_weights, + const bool* is_token_in_rank, + const int* channel_prefix_matrix, + int num_tokens, + int num_worst_tokens,// + int hidden_int4, + int num_topk, + int num_experts, + int num_scales, + int scale_token_stride,// + int scale_hidden_stride,// + void** buffer_ptrs, + int rank, + int num_ranks, + cudaStream_t stream, + int num_sms, + int num_max_send_tokens, + int num_recv_buffer_tokens) +{ constexpr int kNumThreads = (kWarpSize == 64 ? 1024 : 512); #define DISPATCH_LAUNCH_CASE(ranks) \ @@ -574,8 +614,8 @@ cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int num_recv_tokens, int num_memset_int, - int** task_fifo_ptrs, int head, int rank, int num_ranks, - cudaStream_t stream) { + int** task_fifo_ptrs, int rank, int num_ranks, + cudaStream_t stream, int head = 0 ) { #define CACHED_NOTIFY_COMBINE(ranks) \ LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, cached_notify_combine, \ buffer_ptrs, send_head, num_channels, num_recv_tokens, num_memset_int, task_fifo_ptrs, head, rank); \ @@ -605,13 +645,13 @@ combine(dtype_t* recv_x, float* recv_topk_weights, const auto num_channels = num_sms / 2; const bool is_sender = sm_id % 2 == 0; const int responsible_channel = sm_id / 2; + EP_DEVICE_ASSERT(num_topk <= kWarpSize); constexpr int kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t); int hidden_int4 = hidden * sizeof(dtype_t) / sizeof(int4); auto x_int4 = reinterpret_cast(x); auto recv_int4 = reinterpret_cast(recv_x); - if (is_sender) { // Workers for sending // Several warps are responsible for a single rank @@ -867,13 +907,25 @@ combine(dtype_t* recv_x, float* recv_topk_weights, } void 
combine(cudaDataType_t type, - void* recv_x, float* recv_topk_weights, - const void* x, const float* topk_weights, - const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix, - int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk, - void** buffer_ptrs, int rank, int num_ranks, - cudaStream_t stream, int num_sms, - int num_max_send_tokens, int num_recv_buffer_tokens) { + void* recv_x, + float* recv_topk_weights, + const void* x, + const float* topk_weights, + const int* src_idx, + const int* rank_prefix_matrix, + const int* channel_prefix_matrix, + int* send_head, + int num_tokens, + int num_recv_tokens, + int hidden, + int num_topk, + void** buffer_ptrs, + int rank, + int num_ranks, + cudaStream_t stream, + int num_sms, + int num_max_send_tokens, + int num_recv_buffer_tokens) { constexpr int kNumThreads = kWarpSize == 64 ? 1024 : 768; #define COMBINE_LAUNCH_CASE(dtype, ranks) \ diff --git a/csrc/kernels/launch.cuh b/csrc/kernels/launch.cuh index 7ab2a40..a4a064d 100644 --- a/csrc/kernels/launch.cuh +++ b/csrc/kernels/launch.cuh @@ -82,6 +82,18 @@ inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, #endif // #if defined(USE_ROCM) #endif // #ifndef LAUNCH_KERNEL + +#ifndef SET_SHARED_MEMORY_FOR_TMA +#ifndef DISABLE_SM90_FEATURES +#define SET_SHARED_MEMORY_FOR_TMA(kernel) \ + EP_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); \ + cfg.dynamicSmemBytes = smem_size; +#else +#define SET_SHARED_MEMORY_FOR_TMA(kernel) void() +#endif +#endif + + #define SWITCH_RANKS(case_macro) \ switch (num_ranks) { \ case 2: case_macro(2); \ diff --git a/csrc/kernels/runtime.cu b/csrc/kernels/runtime.cu index e336f34..7fa6031 100644 --- a/csrc/kernels/runtime.cu +++ b/csrc/kernels/runtime.cu @@ -16,7 +16,7 @@ __global__ void barrier(int** task_fifo_ptrs, int head, int rank) { barrier_device(task_fifo_ptrs, head, rank); } -void barrier(int** task_fifo_ptrs, 
int head, int rank, int num_ranks, cudaStream_t stream) { +/*void barrier(int** task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream) { #define BARRIER_LAUNCH_CASE(ranks) \ LAUNCH_KERNEL(&cfg, barrier, task_fifo_ptrs, head, rank); \ break @@ -24,6 +24,16 @@ void barrier(int** task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream SETUP_LAUNCH_CONFIG(1, kWarpSize, stream); SWITCH_RANKS(BARRIER_LAUNCH_CASE); #undef BARRIER_LAUNCH_CASE +}*/ + +void barrier(int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream, int head = 0) { +#define BARRIER_LAUNCH_CASE(ranks) \ + LAUNCH_KERNEL(&cfg, barrier, barrier_signal_ptrs, head, rank); \ + break + + SETUP_LAUNCH_CONFIG(1, kWarpSize, stream); + SWITCH_RANKS(BARRIER_LAUNCH_CASE); +#undef BARRIER_LAUNCH_CASE } } // namespace intranode @@ -53,8 +63,13 @@ int init(const std::vector &root_unique_id_val, int rank, int num_ranks if (low_latency_mode and num_ranks > NUM_MAX_NVL_PEERS) { EP_HOST_ASSERT(cpu_rdma_team == SHMEM_TEAM_INVALID); EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0); - EP_HOST_ASSERT(shmem_team_split_strided(SHMEM_TEAM_WORLD, rank % NUM_MAX_NVL_PEERS, NUM_MAX_NVL_PEERS, - num_ranks / NUM_MAX_NVL_PEERS, &cpu_rdma_team_config, 0, &cpu_rdma_team) == 0); + EP_HOST_ASSERT(shmem_team_split_strided(SHMEM_TEAM_WORLD, + rank % NUM_MAX_NVL_PEERS, + NUM_MAX_NVL_PEERS, + num_ranks / NUM_MAX_NVL_PEERS, + &cpu_rdma_team_config, + 0, + &cpu_rdma_team) == 0); //TODO::issue on ROCM: enable it for ROCM #ifndef USE_ROCM EP_HOST_ASSERT(cpu_rdma_team != SHMEM_TEAM_INVALID); diff --git a/csrc/kernels/shmem_wrapper.cuh b/csrc/kernels/shmem_wrapper.cuh index 6e84b6f..1e90717 100644 --- a/csrc/kernels/shmem_wrapper.cuh +++ b/csrc/kernels/shmem_wrapper.cuh @@ -65,7 +65,7 @@ static inline const auto &shmem_ibgda_amo_nonfetch_add = #if !defined(ROCM_DISABLE_CTX) using shmem_ctx_t = rocshmem::rocshmem_ctx_t; static inline const auto &shmem_wg_ctx_create = [] __device__(rocshmem::rocshmem_ctx_t *ctx) 
{ - rocshmem::rocshmem_wg_ctx_create(0, ctx); + return rocshmem::rocshmem_wg_ctx_create(0, ctx); }; static inline const auto &shmem_wg_ctx_destroy = rocshmem::rocshmem_wg_ctx_destroy; diff --git a/csrc/kernels/utils.cuh b/csrc/kernels/utils.cuh index 69182d0..8a43e55 100644 --- a/csrc/kernels/utils.cuh +++ b/csrc/kernels/utils.cuh @@ -2,6 +2,10 @@ #include "exception.cuh" +#ifdef USE_ROCM +#define syncthreads() __syncthreads() +#endif + #define UNROLLED_WARP_COPY(UNROLL_FACTOR, LANE_ID, N, DST, SRC, LD_FUNC, ST_FUNC) \ { \ constexpr int kLoopStride = kWarpSize * (UNROLL_FACTOR); \ @@ -562,15 +566,6 @@ __device__ __forceinline__ void st_na_release(const uint32_t *ptr, uint32_t val) #endif } -__device__ __forceinline__ void st_na_release(const int64_t *ptr, int64_t val) { -#ifdef USE_ROCM - int64_t* non_const_ptr = const_cast(ptr); - __hip_atomic_store(non_const_ptr, val, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_AGENT); -#else - asm volatile("st.release.gpu.global.L1::no_allocate.b64 [%0], %1;" : : "l"(ptr), "l"(val)); -#endif -} - __device__ __forceinline__ void st_na_release(const uint64_t *ptr, uint64_t val) { #ifdef USE_ROCM uint64_t* non_const_ptr = const_cast(ptr); @@ -761,4 +756,64 @@ barrier_device(int **task_fifo_ptrs, int head, int rank, int tag = 0) { timeout_check(task_fifo_ptrs, head, rank, 0, tag); } + +template +__forceinline__ __device__ void barrier_block(int** barrier_signal_ptrs, int rank) { + auto thread_id = static_cast(threadIdx.x); + + // For non-sync-only cases, the memory operations by other threads in the block must be visible to the `sys` scope + if constexpr (not kSyncOnly) { + memory_fence(); + __syncthreads(); + } + + // Add self-ranks, sub other ranks + if (thread_id < kNumRanks) { + atomicAdd_system(barrier_signal_ptrs[rank] + thread_id, FINISHED_SUM_TAG); + atomicSub_system(barrier_signal_ptrs[thread_id] + rank, FINISHED_SUM_TAG); + } + EP_DEVICE_ASSERT(kNumRanks <= blockDim.x); + + // Check timeout + auto start_time = clock64(); + while 
(true) { + auto value = thread_id < kNumRanks ? ld_volatile_global(barrier_signal_ptrs[rank] + thread_id) : 0; + if (__all_sync(kFullWarpMask, value <= 0)) + break; + + if (clock64() - start_time > NUM_TIMEOUT_CYCLES and thread_id < kNumRanks) { + printf("DeepEP timeout check failed: rank = %d, thread = %d, value = %d)\n", rank, thread_id, value); + trap(); + } + } + __syncthreads(); +} + + +__device__ __forceinline__ uint32_t elect_one_sync() { +#ifndef DISABLE_SM90_FEATURES + uint32_t pred = 0; + asm volatile( + "{\n" + ".reg .b32 %%rx;\n" + ".reg .pred %%px;\n" + " elect.sync %%rx|%%px, %1;\n" + "@%%px mov.s32 %0, 1;\n" + "}\n" + : "+r"(pred) + : "r"(0xffffffff)); + return pred; +#else + return get_lane_id() == 0; +#endif +} + } // namespace deep_ep + +template +__host__ __device__ constexpr dtype_t align_down(dtype_t a, dtype_t b) { + return a / b * b; +} + + + diff --git a/deep_ep/__init__.py b/deep_ep/__init__.py index 7fb801f..2b20d83 100644 --- a/deep_ep/__init__.py +++ b/deep_ep/__init__.py @@ -4,4 +4,4 @@ from .buffer import Buffer # noinspection PyUnresolvedReferences -from deep_ep_cpp import Config +from deep_ep_cpp import Config, topk_idx_t diff --git a/deep_ep/buffer.py b/deep_ep/buffer.py index f0a8ee1..fcf8d77 100644 --- a/deep_ep/buffer.py +++ b/deep_ep/buffer.py @@ -7,7 +7,7 @@ import deep_ep_cpp # noinspection PyUnresolvedReferences from deep_ep_cpp import Config, EventHandle -from .utils import EventOverlap +from .utils import EventOverlap, check_nvlink_connections class Buffer: @@ -29,9 +29,18 @@ class Buffer: num_sms: int = 20 - def __init__(self, group: dist.ProcessGroup, - num_nvl_bytes: int = 0, num_rdma_bytes: int = 0, - low_latency_mode: bool = False, num_qps_per_rank: int = 1) -> None: + def __init__(self, + group: Optional[dist.ProcessGroup], + num_nvl_bytes: int = 0, + num_rdma_bytes: int = 0, + low_latency_mode: bool = False, + num_qps_per_rank: int = 24, + allow_nvlink_for_low_latency_mode: bool = True, + allow_mnnvl: bool = False, + 
use_fabric: bool = False, + explicitly_destroy: bool = False, + enable_shrink: bool = False, + comm: Optional["mpi4py.MPI.Comm"] = None) -> None: # noqa: F821 """ Initialize the communication buffer. @@ -42,56 +51,105 @@ def __init__(self, group: dist.ProcessGroup, low_latency_mode: whether to enable low-latency mode. num_qps_per_rank: the number of QPs for RDMA, the low-latency mode requires that this number equals to the number of local experts. - """ + allow_nvlink_for_low_latency_mode: whether to allow NVLink traffic for low-latency mode; you should notice + this is somewhat incompatible with the hook-based overlapping. + Warning: PCIe connections may lead to errors due to memory ordering issues, + please make sure all connections are via NVLink. + allow_mnnvl: whether to allow MNNVL. + use_fabric: whether to use the fabric API for memory buffers. + enable_shrink: whether to enable shrink mode. When enabled, a mask buffer is allocated to support masking ranks dynamically. + explicitly_destroy: If this flag is set to True, you need to explicitly call `destroy()` to release resources; + otherwise, the resources will be released by the destructor. + Note: Releasing resources in the destructor may cause Python's exception handling process to hang. + comm: the `mpi4py.MPI.Comm` communicator to use in case the group parameter is absent. 
+ """ + check_nvlink_connections(group) # Initialize the CPP runtime - self.rank = group.rank() - self.group_size = group.size() - self.group = group + if group is not None: + self.rank = group.rank() + self.group = group + self.group_size = group.size() + + def all_gather_object(obj): + object_list = [None] * self.group_size + dist.all_gather_object(object_list, obj, group) + return object_list + elif comm is not None: + self.rank = comm.Get_rank() + self.group = comm + self.group_size = comm.Get_size() + + def all_gather_object(obj): + return comm.allgather(obj) + else: + raise ValueError("Either 'group' or 'comm' must be provided.") self.num_nvl_bytes = num_nvl_bytes self.num_rdma_bytes = num_rdma_bytes self.low_latency_mode = low_latency_mode - self.runtime = deep_ep_cpp.Buffer(self.rank, self.group_size, num_nvl_bytes, num_rdma_bytes, low_latency_mode) + self.explicitly_destroy = explicitly_destroy + self.enable_shrink = enable_shrink + self.runtime = deep_ep_cpp.Buffer(self.rank, self.group_size, num_nvl_bytes, num_rdma_bytes, low_latency_mode, explicitly_destroy, + enable_shrink)#, use_fabric) # Synchronize device IDs - device_ids = [None, ] * self.group_size local_device_id = self.runtime.get_local_device_id() - dist.all_gather_object(device_ids, local_device_id, group) + device_ids = all_gather_object(local_device_id) # Synchronize IPC handles - ipc_handles = [None, ] * self.group_size local_ipc_handle = self.runtime.get_local_ipc_handle() - dist.all_gather_object(ipc_handles, local_ipc_handle, group) + ipc_handles = all_gather_object(local_ipc_handle) # Synchronize NVSHMEM unique IDs root_unique_id = None if self.runtime.get_num_rdma_ranks() > 1 or low_latency_mode: - # Enable IBGDA for the low latency mode, which refers to "no package forwarding between NVLink and RDMA" - if low_latency_mode: - assert num_qps_per_rank > 0 - os.environ['NVSHMEM_DISABLE_P2P'] = '1' - os.environ['NVSHMEM_IB_ENABLE_IBGDA'] = '1' - os.environ['NVSHMEM_IBGDA_NIC_HANDLER'] = 
'gpu' - os.environ['NVSHMEM_IBGDA_NUM_RC_PER_PE'] = f'{num_qps_per_rank}' - # Make sure QP depth is always larger than the number of on-flight WRs, so that we can skip WQ slot check - os.environ['NVSHMEM_QP_DEPTH'] = '1024' - # NOTES: NVSHMEM initialization requires at least 256 MiB - os.environ['NVSHMEM_CUMEM_GRANULARITY'] = f'{2 ** 29}' - dev_id = torch.cuda.current_device() - os.environ["NVSHMEM_HCA_LIST"] = f"fic2_soe_bond{dev_id // 2}:1" + # Enable IBGDA + assert num_qps_per_rank > 0 + os.environ['NVSHMEM_DISABLE_P2P'] = '0' if allow_nvlink_for_low_latency_mode else '1' + os.environ['NVSHMEM_IB_ENABLE_IBGDA'] = '1' + os.environ['NVSHMEM_IBGDA_NUM_RC_PER_PE'] = f'{num_qps_per_rank}' + + # Make sure QP depth is always larger than the number of on-flight WRs, so that we can skip WQ slot check + self.nvshmem_qp_depth = int(os.environ.get('NVSHMEM_QP_DEPTH', '1024')) + os.environ['NVSHMEM_QP_DEPTH'] = str(self.nvshmem_qp_depth) + + # Reduce gpu memory usage + # 6 default teams + 1 extra team + os.environ['NVSHMEM_MAX_TEAMS'] = '7' + # Disable NVLink SHArP + os.environ['NVSHMEM_DISABLE_NVLS'] = '1' + # NOTES: NVSHMEM initialization requires at least 256 MiB + os.environ['NVSHMEM_CUMEM_GRANULARITY'] = f'{2 ** 29}' + + if not allow_mnnvl: + # Disable multi-node NVLink detection + os.environ['NVSHMEM_DISABLE_MNNVL'] = '1' + # Synchronize using the root ID - nvshmem_unique_ids = [None, ] * self.group_size if (low_latency_mode and self.rank == 0) or (not low_latency_mode and self.runtime.get_rdma_rank() == 0): root_unique_id = self.runtime.get_local_nvshmem_unique_id() - - dist.all_gather_object(nvshmem_unique_ids, root_unique_id, group) + nvshmem_unique_ids = all_gather_object(root_unique_id) root_unique_id = nvshmem_unique_ids[0 if low_latency_mode else self.runtime.get_root_rdma_rank(True)] # Make CPP runtime available self.runtime.sync(device_ids, ipc_handles, root_unique_id) - assert self.runtime.is_available() + def destroy(self): + """ + Destroy the cpp runtime 
and release resources. + + """ + + assert self.explicitly_destroy, '`explicitly_destroy` flag must be set' + + self.runtime.destroy() + self.runtime = None + + @staticmethod + def is_sm90_compiled(): + return deep_ep_cpp.is_sm90_compiled() + @staticmethod def set_num_sms(new_num_sms: int) -> None: """ @@ -130,8 +188,21 @@ def get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank: int, hidden """ return deep_ep_cpp.get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts) - def get_local_buffer_tensor(self, dtype: torch.dtype, size: Optional[torch.Size] = None, - offset: int = 0, use_rdma_buffer: bool = False) -> torch.Tensor: + def get_comm_stream(self) -> torch.Stream: + """ + Get the communication stream. + + Returns: + stream: the communication stream. + """ + ts: torch.Stream = self.runtime.get_comm_stream() + return torch.cuda.Stream(stream_id=ts.stream_id, device_index=ts.device_index, device_type=ts.device_type) + + def get_local_buffer_tensor(self, + dtype: torch.dtype, + size: Optional[torch.Size] = None, + offset: int = 0, + use_rdma_buffer: bool = False) -> torch.Tensor: """ Get the raw buffer (slice supported) as a PyTorch tensor. @@ -148,6 +219,16 @@ def get_local_buffer_tensor(self, dtype: torch.dtype, size: Optional[torch.Size] assert tensor.numel() >= size.numel() return tensor[:size.numel()].view(size) + @staticmethod + def _unpack_bias(bias: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]): + bias_0, bias_1 = None, None + if isinstance(bias, torch.Tensor): + bias_0 = bias + elif isinstance(bias, tuple): + assert len(bias) == 2 + bias_0, bias_1 = bias + return bias_0, bias_1 + @staticmethod def get_dispatch_config(num_ranks: int) -> Config: """ @@ -160,6 +241,7 @@ def get_dispatch_config(num_ranks: int) -> Config: config: the recommended config. 
""" + # TODO: automatically tune config_map = { 2: Config(Buffer.num_sms, 16, 256, 6, 128), 4: Config(Buffer.num_sms, 16, 256, 6, 128), @@ -187,6 +269,7 @@ def get_combine_config(num_ranks: int) -> Config: config: the recommended config. """ + # TODO: automatically tune config_map = { 2: Config(Buffer.num_sms, 6, 256, 6, 128), 4: Config(Buffer.num_sms, 6, 256, 6, 128), @@ -211,8 +294,8 @@ def get_dispatch_layout(self, topk_idx: torch.Tensor, num_experts: int, Calculate the layout required for later communication. Arguments: - topk_idx: `[num_tokens, num_topk]`, dtype must be `torch.int64`, the expert indices selected by each token, - `-1` means no selections. + topk_idx: `[num_tokens, num_topk]`, dtype must be `deep_ep.topk_idx_t` (typically `torch.int64`), the expert + indices selected by each token, `-1` means no selections. num_experts: the number of experts. previous_event: the event to wait before actually executing the kernel. async_finish: the current stream will not wait for the communication kernels to be finished if set. @@ -236,7 +319,8 @@ def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], handle: Optional[Tuple] = None, num_tokens_per_rank: Optional[torch.Tensor] = None, num_tokens_per_rdma_rank: Optional[torch.Tensor] = None, is_token_in_rank: Optional[torch.Tensor] = None, num_tokens_per_expert: Optional[torch.Tensor] = None, - topk_idx: Optional[torch.Tensor] = None, topk_weights: Optional[torch.Tensor] = None, expert_alignment: int = 1, + topk_idx: Optional[torch.Tensor] = None, topk_weights: Optional[torch.Tensor] = None, + expert_alignment: int = 1, num_worst_tokens: int = 0, config: Optional[Config] = None, previous_event: Optional[EventOverlap] = None, async_finish: bool = False, allocate_on_comm_stream: bool = False) -> \ @@ -259,10 +343,12 @@ def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], rank (with the same GPU index), return `None` for intranode settings. 
is_token_in_rank: `[num_tokens, num_ranks]` with `torch.bool`, whether a token be sent to a rank. num_tokens_per_expert: `[num_experts]` with `torch.int`, the number of tokens to be sent to each expert. - topk_idx: `[num_tokens, num_topk]` with `torch.int64`, the expert indices selected by each token, - `-1` means no selections. + topk_idx: `[num_tokens, num_topk]` with `deep_ep.topk_idx_t` (typically `torch.int64`), the expert indices + selected by each token, `-1` means no selections. topk_weights: `[num_tokens, num_topk]` with `torch.float`, the expert weights of each token to dispatch. expert_alignment: align the number of tokens received by each local expert to this variable. + num_worst_tokens: the worst number of tokens to receive, if specified, there will be no CPU sync, and it + will be CUDA-graph compatible. Please also notice that this flag is for intranode only. config: the performance tuning config. previous_event: the event to wait before actually executing the kernel. async_finish: the current stream will not wait for the communication kernels to be finished if set. @@ -274,7 +360,8 @@ def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], recv_topk_idx: received expert indices. recv_topk_weights: received expert weights. num_recv_tokens_per_expert_list: Python list shaped `[num_local_experts]`, the received token count by - each local expert, aligned to the input `expert_alignment`. + each local expert, aligned to the input `expert_alignment`. If `num_worst_tokens` is specified, the list + will be empty. handle: the returned communication handle. event: the event after executing the kernel (valid only if `async_finish` is set). 
""" @@ -283,8 +370,10 @@ def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], # Internode if self.runtime.get_num_rdma_ranks() > 1: - return self.internode_dispatch(x, handle, num_tokens_per_rank, num_tokens_per_rdma_rank, is_token_in_rank, num_tokens_per_expert, - topk_idx, topk_weights, expert_alignment, config, previous_event, async_finish, allocate_on_comm_stream) + assert num_worst_tokens == 0, 'Internode dispatch does not support `num_worst_tokens > 0`' + return self.internode_dispatch(x, handle, num_tokens_per_rank, num_tokens_per_rdma_rank, is_token_in_rank, + num_tokens_per_expert, topk_idx, topk_weights, expert_alignment, config, previous_event, + async_finish, allocate_on_comm_stream) # Launch the kernel with cached or non-cached mode x, x_scales = x if isinstance(x, tuple) else (x, None) @@ -293,22 +382,26 @@ def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, is_token_in_rank, send_head = handle num_recv_tokens = recv_src_idx.size(0) recv_x, recv_x_scales, _, _, _, _, _, _, _, _, event = self.runtime.intranode_dispatch( - x, x_scales, None, None, - None, is_token_in_rank, None, num_recv_tokens, rank_prefix_matrix, channel_prefix_matrix, - expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream) + x, x_scales, None, None, None, is_token_in_rank, None, num_recv_tokens, rank_prefix_matrix, channel_prefix_matrix, + expert_alignment, num_worst_tokens, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream) return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event) else: assert num_tokens_per_rank is not None and is_token_in_rank is not None and num_tokens_per_expert is not None recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, rank_prefix_matrix, 
channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, send_head, event = \ self.runtime.intranode_dispatch(x, x_scales, topk_idx, topk_weights, - num_tokens_per_rank, is_token_in_rank, num_tokens_per_expert, 0, None, None, - expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream) + num_tokens_per_rank, is_token_in_rank, num_tokens_per_expert, 0, None, None, + expert_alignment, num_worst_tokens, config, + getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream) handle = (rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, is_token_in_rank, send_head) - return (recv_x, recv_x_scales) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(event) + return ( + recv_x, recv_x_scales + ) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap( + event) # noinspection PyTypeChecker def combine(self, x: torch.Tensor, handle: Tuple, topk_weights: Optional[torch.Tensor] = None, + bias: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = None, config: Optional[Config] = None, previous_event: Optional[EventOverlap] = None, async_finish: bool = False, allocate_on_comm_stream: bool = False) -> \ @@ -324,6 +417,7 @@ def combine(self, x: torch.Tensor, handle: Tuple, x: `[num_tokens, hidden]` with `torch.bfloat16`, the tokens to send for reducing to its original ranks. handle: a must-set communication handle, you can obtain this from the dispatch function. topk_weights: `[num_tokens, num_topk]` with `torch.float`, the tokens' top-k weights for reducing to its original ranks. + bias: 0, 1 or 2 `[num_tokens, hidden]` with `torch.bfloat16` final bias to the output. config: the performance tuning config. previous_event: the event to wait before actually executing the kernel. 
async_finish: the current stream will not wait for the communication kernels to be finished if set. @@ -339,16 +433,17 @@ def combine(self, x: torch.Tensor, handle: Tuple, # Internode if self.runtime.get_num_rdma_ranks() > 1: - return self.internode_combine(x, handle, topk_weights, config, previous_event, async_finish, allocate_on_comm_stream) + return self.internode_combine(x, handle, topk_weights, bias, config, previous_event, async_finish, allocate_on_comm_stream) # NOTES: the second `_` is for the sending side, so we should use the third one rank_prefix_matrix, _, channel_prefix_matrix, src_idx, is_recv_token_in_rank, send_head = handle + #bias_0, bias_1 = Buffer._unpack_bias(bias) # Launch the kernel - recv_x, recv_topk_weights, event = self.runtime.intranode_combine( - x, topk_weights, - src_idx, rank_prefix_matrix, channel_prefix_matrix, send_head, config, - getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream) + recv_x, recv_topk_weights, event = self.runtime.intranode_combine(x, topk_weights, src_idx, rank_prefix_matrix, + channel_prefix_matrix, send_head, config, + getattr(previous_event, 'event', + None), async_finish, allocate_on_comm_stream) return recv_x, recv_topk_weights, EventOverlap(event) # noinspection PyTypeChecker @@ -379,9 +474,7 @@ def internode_dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Te num_recv_tokens = recv_src_meta.size(0) num_rdma_recv_tokens = send_nvl_head.size(0) recv_x, recv_x_scales, _, _, _, _, _, _, _, _, _, _, _, _, event = self.runtime.internode_dispatch( - x, x_scales, topk_idx, topk_weights, - None, None, is_token_in_rank, None, - num_recv_tokens, num_rdma_recv_tokens, + x, x_scales, topk_idx, topk_weights, None, None, is_token_in_rank, None, num_recv_tokens, num_rdma_recv_tokens, rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, expert_alignment, config, getattr(previous_event, 'event', None), async_finish, 
allocate_on_comm_stream) return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event) @@ -396,15 +489,18 @@ def internode_dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Te num_tokens_per_rank, num_tokens_per_rdma_rank, is_token_in_rank, num_tokens_per_expert, 0, 0, None, None, None, None, expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream) - handle = (is_token_in_rank, - rdma_channel_prefix_matrix, gbl_channel_prefix_matrix, - recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, - recv_src_meta, send_rdma_head, send_nvl_head) - return (recv_x, recv_x_scales) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(event) + handle = (is_token_in_rank, rdma_channel_prefix_matrix, gbl_channel_prefix_matrix, recv_rdma_channel_prefix_matrix, + recv_rdma_rank_prefix_sum, recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, recv_src_meta, send_rdma_head, + send_nvl_head) + return ( + recv_x, recv_x_scales + ) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap( + event) # noinspection PyTypeChecker def internode_combine(self, x: torch.Tensor, handle: Union[tuple, list], topk_weights: Optional[torch.Tensor] = None, + bias: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = None, config: Optional[Config] = None, previous_event: Optional[EventOverlap] = None, async_finish: bool = False, allocate_on_comm_stream: bool = False) -> \ @@ -415,19 +511,20 @@ def internode_combine(self, x: torch.Tensor, handle: Union[tuple, list], """ assert config is not None - # Unpack handle + # Unpack handle and bias is_combined_token_in_rank, \ _, _, \ rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, gbl_rank_prefix_sum, \ src_meta, 
send_rdma_head, send_nvl_head = handle + bias_0, bias_1 = Buffer._unpack_bias(bias) # Launch the kernel - combined_x, combined_topk_weights, event = self.runtime.internode_combine( - x, topk_weights, - src_meta, is_combined_token_in_rank, - rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, - send_rdma_head, send_nvl_head, config, getattr(previous_event, 'event', None), - async_finish, allocate_on_comm_stream) + combined_x, combined_topk_weights, event = self.runtime.internode_combine(x, topk_weights, bias_0, bias_1, src_meta, + is_combined_token_in_rank, rdma_channel_prefix_matrix, + rdma_rank_prefix_sum, gbl_channel_prefix_matrix, + send_rdma_head, send_nvl_head, config, + getattr(previous_event, 'event', + None), async_finish, allocate_on_comm_stream) return combined_x, combined_topk_weights, EventOverlap(event) def clean_low_latency_buffer(self, num_max_dispatch_tokens_per_rank: int, hidden: int, num_experts: int) -> None: @@ -447,99 +544,147 @@ def clean_low_latency_buffer(self, num_max_dispatch_tokens_per_rank: int, hidden # noinspection PyTypeChecker def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor, num_max_dispatch_tokens_per_rank: int, num_experts: int, - use_fp8: bool = True, async_finish: bool = False, return_recv_hook: bool = False) -> \ + cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None, + dispatch_wait_recv_cost_stats: Optional[torch.Tensor] = None, + use_fp8: bool = True, round_scale: bool = False, use_ue8m0: bool = False, + async_finish: bool = False, return_recv_hook: bool = False) -> \ Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]: """ A low-latency implementation for dispatching with IBGDA. This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA (specifically, IBGDA must be enabled). - Even for ranks in the same node, NVLink are fully disabled for simplicity. 
- Warning: as there are only two buffers, and the returned tensors reuse the buffer, you can not hold more than 2 - low-latency kernels' result tensor at a single moment. + Warning: as there are only two buffers, and the returned tensors reuse the buffer, you cannot hold more than 2 + low-latency kernels' result tensors at a single moment. Arguments: x: `torch.Tensor` with `torch.bfloat16`, shaped as `[num_tokens, hidden]`, only several hidden shapes are supported. The number of tokens to be dispatched must be less than `num_max_dispatch_tokens_per_rank`. - topk_idx: `torch.Tensor` with `torch.int64`, shaped as `[num_tokens, num_topk]`, only several top-k shapes - are supported. `-1` indices (not selecting any expert) are supported. + topk_idx: `torch.Tensor` with `deep_ep.topk_idx_t` (typically `torch.int64`), shaped as `[num_tokens, num_topk]`, + only several top-k shapes are supported. `-1` indices (not selecting any expert) are supported. num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value. num_experts: the number of all experts. + cumulative_local_expert_recv_stats: a cumulative expert count tensor for statistics, which should have shape + `[num_local_experts]` and be typed as `torch.int`. This is useful for online service EP load balance + monitoring. + dispatch_wait_recv_cost_stats: a cumulative time spent waiting to receive each token tensor for statistics, + which should have shape `[num_ranks, num_ranks]` and be typed as `torch.int64`. + This is useful for detecting and precisely localizing slow anomalies. use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors. + round_scale: whether round the scaling factors into power of 2. + use_ue8m0: whether use UE8M0 as scaling factor format (available only with `round_scale=True`). async_finish: the current stream will not wait for the communication kernels to be finished if set. 
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues, but **without actually receiving the data**. You must call the received hook to make sure the data's arrival. - If you not set this flag, the kernel will ensure the data's arrival. + If you do not set this flag, the kernel will ensure the data's arrival. Returns: recv_x: a tensor or tuple with received tokens for each expert. With `use_fp8=True`: the first element is a `torch.Tensor` shaped as `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`. The second tensor is the corresponding scales for the first element with shape - `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`. + `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`, + if `use_ue8m0=False`. With `use_ue8m0=True`, the second one is packed and shaped as + `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 512]` with type `torch.int`. Notice that, the last-two-dimension of the scaling tensors are in column-major for TMA compatibility. With `use_fp8=False`, the result would be a tensor shaped as `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`. Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are, as we do not synchronize CPU received count with GPU (also not incompatible with CUDA graph if synced). recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each - expert receive. As mentioned before, all not tokens are valid in `recv_x`. + expert receives. As mentioned before, not all tokens are valid in `recv_x`. handle: the communication handle to be used in the `low_latency_combine` function. event: the event after executing the kernel (valid only if `async_finish` is set). 
hook: the receiving hook function (valid only if `return_recv_hook` is set). """ + assert self.nvshmem_qp_depth >= (num_max_dispatch_tokens_per_rank + 1) * 2 packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, hook = \ self.runtime.low_latency_dispatch(x, topk_idx, + cumulative_local_expert_recv_stats, + dispatch_wait_recv_cost_stats, num_max_dispatch_tokens_per_rank, num_experts, - use_fp8, async_finish, return_recv_hook) + use_fp8, round_scale, use_ue8m0, + async_finish, return_recv_hook) handle = (packed_recv_src_info, packed_recv_layout_range, num_max_dispatch_tokens_per_rank, x.size(1), num_experts) - tensors_to_record = (x, topk_idx, - packed_recv_x, packed_recv_x_scales, packed_recv_count, - packed_recv_src_info, packed_recv_layout_range) + tensors_to_record = (x, topk_idx, packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, + packed_recv_layout_range, cumulative_local_expert_recv_stats) return (packed_recv_x, packed_recv_x_scales) if use_fp8 else packed_recv_x, packed_recv_count, handle, \ EventOverlap(event, tensors_to_record if async_finish else None), hook # noinspection PyTypeChecker def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, - handle: tuple, zero_copy: bool = False, async_finish: bool = False, - return_recv_hook: bool = False, out: Optional[torch.Tensor] = None) -> \ + handle: tuple, use_logfmt: bool = False, zero_copy: bool = False, async_finish: bool = False, + return_recv_hook: bool = False, out: Optional[torch.Tensor] = None, + combine_wait_recv_cost_stats: Optional[torch.Tensor] = None) -> \ Tuple[torch.Tensor, EventOverlap, Callable]: """ A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA. This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA (specifically, IBGDA must be enabled). 
- Even for ranks in the same node, NVLink are fully disabled for simplicity. - Warning: as there are only two buffers, and the returned tensors reuse the buffer, you can not hold more than 2 - low-latency kernels' result tensor at a single moment. + Warning: as there are only two buffers, and the returned tensors reuse the buffer, you cannot hold more than 2 + low-latency kernels' result tensors at a single moment. Arguments: x: `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`, the local calculated tokens to be sent to this original rank and reduced. - topk_idx: `[num_combined_tokens, num_topk]` with `torch.int64`, the expert indices selected by the dispatched - tokens. `-1` indices (not selecting any expert) are supported. Note that, `num_combined_tokens` equals - to the number of dispatched tokens. + topk_idx: `[num_combined_tokens, num_topk]` with `deep_ep.topk_idx_t` (typically `torch.int64`), the expert + indices selected by the dispatched tokens. `-1` indices (not selecting any expert) are supported. Note that, + `num_combined_tokens` equals to the number of dispatched tokens. topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched tokens. The received tokens will be reduced with the weights in this tensor. handle: the communication handle given by the `dispatch` function. + use_logfmt: whether to use an internal "LogFMT with dynamic per-64-channel cast" format (10 bits). zero_copy: whether the tensor is already copied into the RDMA buffer, should be cooperative with `get_next_low_latency_combine_buffer`. async_finish: the current stream will not wait for the communication kernels to be finished if set. return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues, but **without actually receiving the data**. You must call the received hook to make sure the data's arrival. 
- If you not set this flag, the kernel will ensure the data's arrival. + If you do not set this flag, the kernel will ensure the data's arrival. out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly. + combine_wait_recv_cost_stats: a cumulative time spent waiting to receive each token tensor for statistics, + which should have shape `[num_ranks, num_ranks]` and be typed as `torch.int64`. + This is useful for detecting and precisely localizing slow anomalies. Returns: - combined_x: the reduced token tensor, with shape `[num_combined_tokens, num_topk]` and type `torch.bfloat16`. + combined_x: the reduced token tensor, with shape `[num_combined_tokens, hidden]` and type `torch.bfloat16`. event: the event after executing the kernel (valid only if `async_finish` is set). hook: the receiving hook function (valid only if `return_recv_hook` is set). """ src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle + assert self.nvshmem_qp_depth >= (num_max_dispatch_tokens_per_rank + 1) * 2 combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range, - num_max_dispatch_tokens_per_rank, num_experts, - zero_copy, async_finish, return_recv_hook, out) + combine_wait_recv_cost_stats, num_max_dispatch_tokens_per_rank, + num_experts, use_logfmt, zero_copy, async_finish, return_recv_hook, out) tensors_to_record = (x, topk_idx, topk_weights, src_info, layout_range, combined_x) return combined_x, EventOverlap(event, tensors_to_record if async_finish else None), hook + def low_latency_update_mask_buffer(self, rank_to_mask: int, mask: bool = False): + """ + Mask (unmask) a rank during communication (dispatch, combine, and clean) + + Arguments: + rank_to_mask: the rank to mask (unmask). + mask: if True, will mask the rank (do not recvfrom/sendto the rank), otherwise will unmask the rank. 
+ + """ + self.runtime.low_latency_update_mask_buffer(rank_to_mask, mask) + + def low_latency_query_mask_buffer(self, mask_status: torch.Tensor): + """ + Query the mask status of all ranks + + Arguments: + mask_status: `[num_ranks]` with `torch.int`, the mask status of each rank. `1` means masked and `0` means unmasked. + + """ + self.runtime.low_latency_query_mask_buffer(mask_status) + + def low_latency_clean_mask_buffer(self): + """ + Clean the mask buffer + + """ + self.runtime.low_latency_clean_mask_buffer() + def get_next_low_latency_combine_buffer(self, handle: object): """ Get the raw registered RDMA buffer tensor for next low-latency combine, so that the next combine kernel can skip the copying. diff --git a/deep_ep/utils.py b/deep_ep/utils.py index 009aa2a..e61a2c5 100644 --- a/deep_ep/utils.py +++ b/deep_ep/utils.py @@ -1,8 +1,10 @@ +import os import torch +import torch.distributed as dist from typing import Any, Optional, Tuple # noinspection PyUnresolvedReferences -from deep_ep_cpp import Config, EventHandle +from deep_ep_cpp import EventHandle class EventOverlap: @@ -14,8 +16,7 @@ class EventOverlap: extra_tensors: an easier way to simulate PyTorch tensor `record_stream`, may be useful with CUDA graph. """ - def __init__(self, event: Optional[EventHandle] = None, - extra_tensors: Optional[Tuple[torch.Tensor]] = None) -> None: + def __init__(self, event: Optional[EventHandle] = None, extra_tensors: Optional[Tuple[torch.Tensor]] = None) -> None: """ Initialize the class. @@ -58,3 +59,43 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: """ if self.event is not None: self.event.current_stream_wait() + + +def check_nvlink_connections(group: dist.ProcessGroup): + """ + Check NVLink connection between every pair of GPUs. + + Arguments: + group: the communication group. 
+ """ + # Check NVLink connection + # NOTES: some A100 PCIE GPUs only have pairwise NVLink connection, so that we can only use EP2 + # TODO: check all cases, all local-node GPUs in the group should be connected via NVLink + if 'PCIE' in torch.cuda.get_device_name(): + assert group.size() <= 2, 'PCIe GPUs only have pairwise NVLink connections' + + # noinspection PyUnresolvedReferences + import pynvml + pynvml.nvmlInit() + + # noinspection PyTypeChecker + devices = os.environ.get('CUDA_VISIBLE_DEVICES', '0,1,2,3,4,5,6,7').strip(',').split(',') + physical_device_idx = int(devices[torch.cuda.current_device()]) + physical_device_indices = [ + 0, + ] * group.size() + dist.all_gather_object(physical_device_indices, physical_device_idx, group) + + # Check whether they are all connected via NVLink + # Reference: https://github.com/vllm-project/vllm/blob/b8e809a057765c574726a6077fd124db5077ce1f/vllm/platforms/cuda.py#L438 + handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_indices] + for i, handle in enumerate(handles): + for j, peer_handle in enumerate(handles): + if i >= j: + continue + status = pynvml.nvmlDeviceGetP2PStatus(handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK) + assert status == pynvml.NVML_P2P_STATUS_OK,\ + f'GPU {physical_device_indices[i]} and GPU {physical_device_indices[j]} are not connected via NVLink' + + # Close NVML + pynvml.nvmlShutdown() diff --git a/setup.py b/setup.py index ac2bd30..b6a8111 100644 --- a/setup.py +++ b/setup.py @@ -7,6 +7,13 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension +# Wheel specific: the wheels only include the soname of the host library `libnvshmem_host.so.X` +def get_nvshmem_host_lib_name(base_dir): + path = Path(base_dir).joinpath('lib') + for file in path.rglob('libnvshmem_host.so.*'): + return file.name + raise ModuleNotFoundError('libnvshmem_host.so not found') + if __name__ == "__main__": # Add argument parser for handling --variant flag parser = 
argparse.ArgumentParser(description="DeepEP setup configuration") @@ -21,16 +28,39 @@ parser.add_argument("--verbose", action="store_true", help="Verbose build") parser.add_argument("--enable_timer", action="store_true", help="Enable timer to debug time out in internode") parser.add_argument("--rocm-disable-ctx", action="store_true", help="Disable workgroup context optimization in internode") - parser.add_argument("--disable-mpi", action="store_true", help="Disable MPI detection and configuration") # Get the arguments to be parsed and separate setuptools arguments args, unknown_args = parser.parse_known_args() variant = args.variant debug = args.debug rocm_disable_ctx = args.rocm_disable_ctx - disable_mpi = args.disable_mpi enable_timer = args.enable_timer + + if variant != "rocm": + disable_nvshmem = False + nvshmem_dir = os.getenv('NVSHMEM_DIR', None) + nvshmem_host_lib = 'libnvshmem_host.so' + if nvshmem_dir is None: + try: + nvshmem_dir = importlib.util.find_spec("nvidia.nvshmem").submodule_search_locations[0] + nvshmem_host_lib = get_nvshmem_host_lib_name(nvshmem_dir) + import nvidia.nvshmem as nvshmem # noqa: F401 + except (ModuleNotFoundError, AttributeError, IndexError): + print( + 'Warning: `NVSHMEM_DIR` is not specified, and the NVSHMEM module is not installed. All internode and low-latency features are disabled\n' + ) + disable_nvshmem = True + else: + disable_nvshmem = False + + if not disable_nvshmem: + assert os.path.exists(nvshmem_dir), f'The specified NVSHMEM directory does not exist: {nvshmem_dir}' + + #else: + # disable_nvshmem = False + + # Reset sys.argv for setuptools to avoid conflicts sys.argv = [sys.argv[0]] + unknown_args @@ -56,11 +86,9 @@ ), f"Failed to find {shmem_variant_name}" print(f"{shmem_variant_name} directory: {shmem_dir}") - ompi_dir = None - if variant == "rocm" and not disable_mpi: + if variant == "rocm": # Attempt to auto-detect OpenMPI installation directory if OMPI_DIR not set. 
# The first existing candidate containing bin/mpicc will be used. - print("MPI detection enabled for ROCm variant") ompi_dir_env = os.getenv("OMPI_DIR", "").strip() candidate_dirs = [ ompi_dir_env if ompi_dir_env else None, @@ -72,41 +100,30 @@ "/usr/local/ompi", "/usr/local/openmpi", ] + ompi_dir = None for d in candidate_dirs: if not d: continue mpicc_path = os.path.join(d, "bin", "mpicc") if os.path.exists(d) and os.path.exists(mpicc_path): ompi_dir = d - break - assert ompi_dir is not None, ( - f"Failed to find OpenMPI installation. " - f"Searched: {', '.join([d for d in candidate_dirs if d])}. " - f"Set OMPI_DIR environment variable or use --disable-mpi flag." - ) + break + if ompi_dir is None: + # Fallback to root (will trigger the assert below) + ompi_dir = "/" print(f"Detected OpenMPI directory: {ompi_dir}") - elif variant == "rocm" and disable_mpi: - print("MPI detection disabled for ROCm variant") - elif variant == "cuda" and not disable_mpi: - print("MPI detection enabled for CUDA variant") - else: - print("MPI detection disabled for CUDA variant") + assert os.path.exists(ompi_dir), f"Failed to find OMPI: {ompi_dir}" # TODO: currently, we only support Hopper architecture, we may add Ampere support later if variant == "rocm": - arch = os.getenv("PYTORCH_ROCM_ARCH") - allowed_arch = {"gfx942", "gfx950"} - if arch not in allowed_arch: - raise EnvironmentError( - f"Invalid PYTORCH_ROCM_ARCH='{arch}'. 
" - f"Use one of: {', '.join(sorted(allowed_arch))}.") + os.environ["PYTORCH_ROCM_ARCH"] = os.getenv("PYTORCH_ROCM_ARCH", "gfx942") elif variant == "cuda": os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0" optimization_flag = "-O0" if debug else "-O3" debug_symbol_flags = ["-g", "-ggdb"] if debug else [] define_macros = ( - ["-DUSE_ROCM=1", "-fgpu-rdc",] if variant == "rocm" else [] + ["-DUSE_ROCM=1", "-DDISABLE_SM90_FEATURES=1", "-fgpu-rdc",] if variant == "rocm" else [] ) if enable_timer: define_macros.append("-DENABLE_TIMER") @@ -138,19 +155,20 @@ nvcc_flags = [f"{optimization_flag}"] + debug_symbol_flags + define_macros include_dirs = ["csrc/", f"{shmem_dir}/include"] - if variant == "rocm" and ompi_dir is not None: + if variant == "rocm": include_dirs.append(f"{ompi_dir}/include") sources = [ "csrc/deep_ep.cpp", "csrc/kernels/runtime.cu", + 'csrc/kernels/layout.cu', "csrc/kernels/intranode.cu", "csrc/kernels/internode.cu", "csrc/kernels/internode_ll.cu", ] library_dirs = [f"{shmem_dir}/lib"] - if variant == "rocm" and ompi_dir is not None: + if variant == "rocm": library_dirs.append(f"{ompi_dir}/lib") # Disable aggressive PTX instructions @@ -158,6 +176,13 @@ cxx_flags.append("-DDISABLE_AGGRESSIVE_PTX_INSTRS") nvcc_flags.append("-DDISABLE_AGGRESSIVE_PTX_INSTRS") + + # Bits of `topk_idx.dtype`, choices are 32 and 64 + if "TOPK_IDX_BITS" in os.environ: + topk_idx_bits = int(os.environ['TOPK_IDX_BITS']) + cxx_flags.append(f'-DTOPK_IDX_BITS={topk_idx_bits}') + nvcc_flags.append(f'-DTOPK_IDX_BITS={topk_idx_bits}') + shmem_lib_name = "nvshmem" if variant == "cuda" else "rocshmem" # Disable DLTO (default by PyTorch) nvcc_dlink = ["-dlink", f"-L{shmem_dir}/lib", f"-l{shmem_lib_name}"] @@ -172,15 +197,10 @@ "-lamdhip64", "-lhsa-runtime64", "-libverbs", + f"-l:libmpi.so", + f"-Wl,-rpath,{ompi_dir}/lib", ] ) - if not disable_mpi: - extra_link_args.extend( - [ - f"-l:libmpi.so", - f"-Wl,-rpath,{ompi_dir}/lib", - ] - ) extra_compile_args = { "cxx": cxx_flags, @@ -189,6 +209,17 @@ 
if variant == "cuda": extra_compile_args["nvcc_dlink"] = nvcc_dlink + + # Summary + print('Build summary:') + print(f' > Sources: {sources}') + print(f' > Includes: {include_dirs}') + print(f' > Libraries: {library_dirs}') + print(f' > Compilation flags: {extra_compile_args}') + print(f' > Link flags: {extra_link_args}') + print(f' > NVSHMEM path: {shmem_dir}') + print() + # noinspection PyBroadException try: cmd = ["git", "rev-parse", "--short", "HEAD"] @@ -198,7 +229,7 @@ setuptools.setup( name="deep_ep", - version="1.0.0" + revision, + version="1.2.1" + revision, packages=setuptools.find_packages(include=["deep_ep"]), ext_modules=[ CUDAExtension( diff --git a/tests/test_internode.py b/tests/test_internode.py index 5f45d0f..c03d344 100644 --- a/tests/test_internode.py +++ b/tests/test_internode.py @@ -1,3 +1,4 @@ +import argparse import os import time import torch @@ -5,16 +6,27 @@ # noinspection PyUnresolvedReferences import deep_ep -from utils import init_dist, bench, calc_diff, create_grouped_scores, inplace_unique, per_token_cast_to_fp8, per_token_cast_back +from utils import init_dist, bench, bench_kineto, calc_diff, create_grouped_scores, inplace_unique, per_token_cast_to_fp8, per_token_cast_back, hash_tensor # Test compatibility with low latency functions import test_low_latency -import argparse -def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: int, num_nodes: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup): +# noinspection PyShadowingNames +def test_main(args: argparse.Namespace, + num_sms: int, + local_rank: int, + num_local_ranks: int, + num_ranks: int, + num_nodes: int, + rank: int, + buffer: deep_ep.Buffer, + group: dist.ProcessGroup, + skip_benchmark: bool = False): # Settings - num_tokens, hidden, num_topk_groups, num_topk, num_experts = 4096, 7168, min(num_nodes, 4), 8, (256 // num_ranks) * num_ranks + num_tokens, hidden = args.num_tokens, args.hidden + num_topk_groups, num_topk, num_experts = 
args.num_topk_groups, args.num_topk, args.num_experts + assert num_experts % num_ranks == 0 and num_local_ranks == 8 if local_rank == 0: print(f'[config] num_tokens={num_tokens}, hidden={hidden}, num_topk_groups={num_topk_groups}, num_topk={num_topk}', flush=True) @@ -23,19 +35,24 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') x_e4m3 = per_token_cast_to_fp8(x) + x_pure_rand_e4m3 = per_token_cast_to_fp8(x_pure_rand) + x_e4m3 = (x_e4m3[0], x_e4m3[1].T.contiguous().T) scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1 group_scores = scores.view(num_tokens, num_nodes, -1).amax(dim=-1) group_idx = torch.topk(group_scores, k=num_topk_groups, dim=-1, sorted=False).indices masked_scores = create_grouped_scores(scores, group_idx, num_nodes) topk_idx = torch.topk(masked_scores, num_topk, dim=-1, largest=True, sorted=False)[1] + topk_idx = topk_idx.to(deep_ep.topk_idx_t) topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank topk_weights_pure_rand = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda') rank_idx = topk_idx // (num_experts // num_ranks) + rank_idx = rank_idx.to(torch.int64) rank_idx.masked_fill_(topk_idx == -1, -1) inplace_unique(rank_idx, num_ranks) rdma_rank_idx = rank_idx // num_local_ranks rdma_rank_idx.masked_fill_(rank_idx == -1, -1) inplace_unique(rdma_rank_idx, num_nodes) + hash_value = 0 # RDMA dispatch counts rdma_idx = topk_idx // (num_experts // num_nodes) @@ -77,12 +94,12 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0] if local_rank == 0: print(f'[layout] Kernel performance: {t * 1000:.3f} ms', flush=True) - print() + print('', flush=True) 
group.barrier() time.sleep(1) # Config - rdma_buffer_size, nvl_buffer_size = 128, (720 if num_ranks in (144, 160) else 512) + rdma_buffer_size, nvl_buffer_size = 128, (720 if num_ranks in (24, 48, 96, 144, 160) else 512) config = deep_ep.Config(num_sms, 8, nvl_buffer_size, 16, rdma_buffer_size) # Test dispatch @@ -99,37 +116,59 @@ def check_data(check_x, recv_gbl_rank_prefix_sum): for async_mode in (False, True): for current_x in (x_pure_rand, x, x_e4m3): for with_topk in (False, True): + is_rand = current_x is x_pure_rand or current_x is x_pure_rand_e4m3 if local_rank == 0: - print(f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...', flush=True, end='') - dispatch_args = {'x': current_x, 'num_tokens_per_rank': num_tokens_per_rank, 'num_tokens_per_rdma_rank': num_tokens_per_rdma_rank, 'is_token_in_rank': is_token_in_rank, - 'num_tokens_per_expert': num_tokens_per_expert, 'config': config, 'async_finish': async_mode} + print( + f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...', + flush=True, + end='') + dispatch_args = { + 'x': current_x, + 'num_tokens_per_rank': num_tokens_per_rank, + 'num_tokens_per_rdma_rank': num_tokens_per_rdma_rank, + 'is_token_in_rank': is_token_in_rank, + 'num_tokens_per_expert': num_tokens_per_expert, + 'config': config, + 'async_finish': async_mode + } if with_topk: - dispatch_args.update({'topk_idx': topk_idx, 'topk_weights': topk_weights_pure_rand if current_x is x_pure_rand else topk_weights}) + dispatch_args.update({'topk_idx': topk_idx, 'topk_weights': topk_weights_pure_rand if is_rand else topk_weights}) if previous_mode: dispatch_args.update({'previous_event': buffer.capture()}) - recv_x, recv_topk_idx, recv_topk_weights, recv_num_tokens_per_expert_list, handle, event = 
buffer.dispatch(**dispatch_args) + recv_x, recv_topk_idx, recv_topk_weights, recv_num_tokens_per_expert_list, handle, event = buffer.dispatch( + **dispatch_args) event.current_stream_wait() if async_mode else () + + if current_x is x_pure_rand or current_x is x: + hash_value += hash_tensor(recv_x) + else: + hash_value += hash_tensor(recv_x[0]) + hash_value += hash_tensor(recv_x[1]) + recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x # Checks recv_gbl_rank_prefix_sum = handle[-4] - assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(0), f'{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}' + assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(0), \ + f'{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}' assert gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist() == recv_num_tokens_per_expert_list - if current_x is not x_pure_rand: + if not is_rand: check_data(recv_x, recv_gbl_rank_prefix_sum) if with_topk: # Check `topk_idx` - assert (recv_topk_idx.eq(-1) | ((recv_topk_idx >= 0) & (recv_topk_idx < (num_experts // num_ranks)))).sum().item() == recv_topk_idx.numel() + assert (recv_topk_idx.eq(-1) | + ((recv_topk_idx >= 0) & + (recv_topk_idx < (num_experts // num_ranks)))).sum().item() == recv_topk_idx.numel() for i, count in enumerate(recv_num_tokens_per_expert_list): assert recv_topk_idx.eq(i).sum().item() == count # Check `topk_weights` - if current_x is not x_pure_rand: - recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax(dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)] + if not is_rand: + recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax( + dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)] check_data(recv_topk_weights, recv_gbl_rank_prefix_sum) # Test cached dispatch (must without top-k staffs) - # NOTES: handle must be refreshed if not with_topk: dispatch_args = {'x': current_x, 'handle': handle, 'config': config, 
'async_finish': async_mode} if previous_mode: @@ -137,10 +176,13 @@ def check_data(check_x, recv_gbl_rank_prefix_sum): recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args) event.current_stream_wait() if async_mode else () recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x - if current_x is not x_pure_rand: + if not is_rand: check_data(recv_x, recv_gbl_rank_prefix_sum) # Test combine + bias_0 = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') + bias_1 = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') + combine_args = {'x': recv_x, 'bias': (bias_0, bias_1), 'handle': handle, 'config': config, 'async_finish': async_mode} combine_args = {'x': recv_x, 'handle': handle, 'config': config, 'async_finish': async_mode} if with_topk: combine_args.update({'topk_weights': recv_topk_weights}) @@ -148,14 +190,17 @@ def check_data(check_x, recv_gbl_rank_prefix_sum): combine_args.update({'previous_event': buffer.capture()}) combined_x, combined_topk_weights, event = buffer.combine(**combine_args) event.current_stream_wait() if async_mode else () - check_x = combined_x.float() / is_token_in_rank.sum(dim=1).unsqueeze(1) - ref_x = x_pure_rand if current_x is x_pure_rand else x - assert calc_diff(check_x, ref_x) < 5e-6 + check_x = (combined_x.float() - bias_0.float() - bias_1.float()) / is_token_in_rank.sum(dim=1).unsqueeze(1) + ref_x = x_pure_rand if is_rand else x + assert calc_diff(check_x, ref_x) < 5e-4 if current_x is x_pure_rand_e4m3 else 5e-6 if with_topk: - check_topk_weights = combined_topk_weights if (current_x is x_pure_rand) else (combined_topk_weights / is_token_in_rank.sum(dim=1).unsqueeze(1)) - ref_topk_weights = topk_weights_pure_rand if current_x is x_pure_rand else topk_weights + check_topk_weights = combined_topk_weights if is_rand else (combined_topk_weights / + is_token_in_rank.sum(dim=1).unsqueeze(1)) + ref_topk_weights = topk_weights_pure_rand if is_rand else topk_weights assert 
calc_diff(check_topk_weights, ref_topk_weights) < 1e-9 + hash_value += hash_tensor(recv_x) + # For later tuning dispatch_bf16_rdma_send_bytes = num_rdma_token_sent * hidden * 2 dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2 @@ -165,7 +210,10 @@ def check_data(check_x, recv_gbl_rank_prefix_sum): if local_rank == 0: print(' passed', flush=True) if local_rank == 0: - print() + print('', flush=True) + + if skip_benchmark: + return hash_value # Tune dispatch performance best_dispatch_results = None @@ -174,18 +222,29 @@ def check_data(check_x, recv_gbl_rank_prefix_sum): best_time, best_results = 1e10, None rdma_send_bytes = (dispatch_bf16_rdma_send_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_rdma_send_bytes nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes - for nvl_chunk_size in range(4, 33, 4): + for nvl_chunk_size in range(4, 45, 4): for rdma_chunk_size in range(4, 33, 4): config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size) tune_args = {'x': current_x, 'handle': handle, 'config': config} - t = bench(lambda: buffer.dispatch(**tune_args))[0] + t, notify_t = bench_kineto( + lambda: buffer.dispatch(**tune_args), # noqa: B023 + ('dispatch', 'notify'), + suppress_kineto_output=True) if t < best_time: - best_time, best_results = t, (num_sms, nvl_chunk_size, rdma_chunk_size) + best_time, best_results = t, (num_sms, nvl_chunk_size, rdma_chunk_size, notify_t) if local_rank == 0: - print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {rdma_send_bytes / 1e9 / t:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ') + print( + f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: ' + f'{notify_t * 1e6:.0f} + {t * 1e6:.0f} us, ' + f'{rdma_send_bytes / 1e9 / t:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ', + flush=True) if local_rank == 0: 
- print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {rdma_send_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)') - print() + print( + f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: ' + f'{best_results[3] * 1e6:.0f} + {best_time * 1e6:.0f} us, ' + f'{rdma_send_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)', + flush=True) + print('', flush=True) if isinstance(current_x, tuple): # Gather FP8 the best config from rank 0 @@ -193,66 +252,125 @@ def check_data(check_x, recv_gbl_rank_prefix_sum): all_best_fp8_results_list = [torch.zeros_like(best_dispatch_results) for _ in range(torch.distributed.get_world_size())] dist.all_gather(all_best_fp8_results_list, best_dispatch_results, group=group) best_dispatch_results = all_best_fp8_results_list[0].tolist() - dispatch_config = deep_ep.Config(best_dispatch_results[0], best_dispatch_results[1], nvl_buffer_size, best_dispatch_results[2], rdma_buffer_size) + dispatch_config = deep_ep.Config(best_dispatch_results[0], best_dispatch_results[1], nvl_buffer_size, best_dispatch_results[2], + rdma_buffer_size) - dispatch_args = {'x': x, 'num_tokens_per_rank': num_tokens_per_rank, 'num_tokens_per_rdma_rank': num_tokens_per_rdma_rank, - 'is_token_in_rank': is_token_in_rank, 'num_tokens_per_expert': num_tokens_per_expert, - 'config': dispatch_config if dispatch_config is not None else config} + dispatch_args = { + 'x': x, + 'num_tokens_per_rank': num_tokens_per_rank, + 'num_tokens_per_rdma_rank': num_tokens_per_rdma_rank, + 'is_token_in_rank': is_token_in_rank, + 'num_tokens_per_expert': num_tokens_per_expert, + 'config': dispatch_config if dispatch_config is not None else config + } recv_x, _, _, _, handle, _ = 
buffer.dispatch(**dispatch_args) # Tune combine performance best_time, best_results = 1e10, None - for nvl_chunk_size in range(1, 5, 1): - # TODO: Sort out the assertation for 16 nodes - upper_bound = 29 if num_ranks == 128 else 33 - for rdma_chunk_size in range(8, upper_bound, 4): + for nvl_chunk_size in range(1, 8, 1): + for rdma_chunk_size in range(12 if num_nodes == 2 else 8, 33, 4): config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size) tune_args = {'x': recv_x, 'handle': handle, 'config': config} - t = bench(lambda: buffer.combine(**tune_args))[0] + t, notify_t = bench_kineto( + lambda: buffer.combine(**tune_args), # noqa: B023 + ('combine', 'notify'), + suppress_kineto_output=True) if local_rank == 0: - print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {combine_bf16_rdma_recv_bytes / 1e9 / t:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ') + print( + f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: ' + f'{notify_t * 1e6:.0f} + {t * 1e6:.0f} us, ' + f'{combine_bf16_rdma_recv_bytes / 1e9 / t:.2f} GB/s (RDMA), ' + f'{combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ', + flush=True) if t < best_time: - best_time, best_results = t, (num_sms, nvl_chunk_size, rdma_chunk_size) + best_time, best_results = t, (num_sms, nvl_chunk_size, rdma_chunk_size, notify_t) if local_rank == 0: - print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {combine_bf16_rdma_recv_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)') - print() + print( + f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}, ' + f'{best_results[3] * 1e6:.2f} + {best_time * 1e6:.2f} us, ' + f'{combine_bf16_rdma_recv_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / 
best_time:.2f} GB/s (NVL)', + flush=True) + print('', flush=True) + return hash_value -# noinspection PyUnboundLocalVariable -def test_loop(local_rank: int, num_local_ranks: int, backend: str): - num_nodes = int(os.getenv('WORLD_SIZE', 2)) - rank, num_ranks, group = init_dist(local_rank, num_local_ranks, backend=backend) - test_ll_compatibility = False - if test_ll_compatibility: +# noinspection PyUnboundLocalVariable,PyShadowingNames +def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): + num_nodes = int(os.getenv('WORLD_SIZE', 1)) + rank, num_ranks, group = init_dist(local_rank, num_local_ranks) + if args.test_ll_compatibility: ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9 - buffer = deep_ep.Buffer(group, int(1e9), int(1e9), low_latency_mode=test_ll_compatibility, - num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1)) + num_sms = 24 + num_qps_per_rank = max(num_sms, ll_num_experts // num_ranks if args.test_ll_compatibility else 0) + + buffer = deep_ep.Buffer(group, + int(2e9), + int(1e9), + low_latency_mode=args.test_ll_compatibility, + num_qps_per_rank=num_qps_per_rank, + explicitly_destroy=True) assert num_local_ranks == 8 and num_ranks > 8 - torch.manual_seed(rank) - for i in (32, ): - test_main(i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group) + for seed in range(int(1e9)): if local_rank == 0: - print() + print(f'Testing with seed {seed} ...', flush=True) + torch.manual_seed(rank + seed) + ref_hash = 0 + for i in (num_sms, ): + ref_hash += test_main(args, i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group, + args.pressure_test_mode == 1) + if local_rank == 0: + print('', flush=True) + if args.pressure_test_mode == 0: + break + + if local_rank == 0: + print(f'{ref_hash=}') + print('', flush=True) + + for _ in range(20): + torch.manual_seed(rank + seed) + current_hash = 0 + for i in (num_sms, ): + current_hash += test_main(args, i, 
local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group, + args.pressure_test_mode == 1) + if local_rank == 0: + print('', flush=True) + assert current_hash == ref_hash # Test compatibility with low latency functions - if test_ll_compatibility: + if args.test_ll_compatibility: buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts) test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1) + # Destroy the buffer runtime and communication group + buffer.destroy() + dist.barrier() + dist.destroy_process_group() + if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Test internode communication') - parser.add_argument('--backend', type=str, choices=['mpi', 'nccl'], default='nccl', - help='Backend for distributed communication (mpi or nccl)') + parser = argparse.ArgumentParser(description='Test internode EP kernels') + parser.add_argument('--num-processes', type=int, default=8, help='Number of processes to spawn (default: 8)') + parser.add_argument('--num-tokens', type=int, default=4096, help='Number of tokens (default: 4096)') + parser.add_argument('--hidden', type=int, default=7168, help='Hidden dimension size (default: 7168)') + parser.add_argument('--num-topk-groups', type=int, default=None, help='Number of top-k groups (default: `min(num_nodes, 4)`)') + parser.add_argument('--num-topk', type=int, default=8, help='Number of top-k experts (default: 8)') + parser.add_argument( + '--pressure-test-mode', + type=int, + default=0, + help='Pressure test mode. 
0: don\'t do pressure test, 1: do pressure test without benchmarks, 2: do pressure test with benchmarks') + parser.add_argument('--num-experts', type=int, default=256, help='Number of experts (default: 256)') + parser.add_argument('--test-ll-compatibility', action='store_true', help='whether to test compatibility with low-latency kernels') args = parser.parse_args() - num_processes = 8 - if args.backend == 'mpi': - dist.init_process_group(backend='mpi') - rank = dist.get_rank() - local_rank = rank % num_processes - test_loop(local_rank=local_rank, num_local_ranks=num_processes, backend='mpi') - else: - torch.multiprocessing.spawn(test_loop, args=(num_processes, 'nccl'), nprocs=num_processes) + + # Set default `num_topk_groups` if not provided + if args.num_topk_groups is None: + num_nodes = int(os.getenv('WORLD_SIZE', 1)) + args.num_topk_groups = min(num_nodes, 4) + + num_processes = args.num_processes + torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes) diff --git a/tests/test_intranode.py b/tests/test_intranode.py index 68a95d8..53dda47 100644 --- a/tests/test_intranode.py +++ b/tests/test_intranode.py @@ -1,4 +1,4 @@ -import os +import argparse import time import torch import torch.distributed as dist @@ -11,9 +11,13 @@ import test_low_latency -def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup): +# noinspection PyShadowingNames +def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer: deep_ep.Buffer, + group: dist.ProcessGroup): # Settings - num_tokens, hidden, num_topk, num_experts = 4096, 7168, 8, (256 // num_ranks) * num_ranks + num_tokens, hidden = args.num_tokens, args.hidden + num_topk, num_experts = args.num_topk, args.num_experts + assert num_experts % num_ranks == 0 if local_rank == 0: print(f'[config] num_tokens={num_tokens}, hidden={hidden}, num_topk={num_topk}', flush=True) @@ -21,12 +25,15 @@ 
def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer: # Random data x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') - x_e4m3 = per_token_cast_to_fp8(x) + x_e4m3 = per_token_cast_to_fp8(x) if deep_ep.Buffer.is_sm90_compiled() else None + x_e4m3 = (x_e4m3[0], x_e4m3[1].T.contiguous().T) if x_e4m3 is not None else None scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1 topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False)[1] + topk_idx = topk_idx.to(deep_ep.topk_idx_t) topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank topk_weights_pure_rand = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda') rank_idx = topk_idx // (num_experts // num_ranks) + rank_idx = rank_idx.to(torch.int64) rank_idx.masked_fill_(topk_idx == -1, -1) inplace_unique(rank_idx, num_ranks) @@ -60,7 +67,7 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer: t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0] if local_rank == 0: print(f'[layout] Kernel performance: {t * 1000:.3f} ms', flush=True) - print() + print('', flush=True) group.barrier() time.sleep(1) @@ -80,39 +87,74 @@ def check_data(check_x, rank_prefix_matrix): for previous_mode in (False, True): for async_mode in (False, True): - for current_x in (x_pure_rand, x, x_e4m3): + for current_x in filter(lambda elem: elem is not None, (x_pure_rand, x, x_e4m3)): for with_topk in (False, True): if local_rank == 0: - print(f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...', flush=True, end='') - dispatch_args = {'x': current_x, 'num_tokens_per_rank': num_tokens_per_rank, 'is_token_in_rank': is_token_in_rank, - 
'num_tokens_per_expert': num_tokens_per_expert, 'config': config, 'async_finish': async_mode} + print( + f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...', + flush=True, + end='') + dispatch_args = { + 'x': current_x, + 'num_tokens_per_rank': num_tokens_per_rank, + 'is_token_in_rank': is_token_in_rank, + 'num_tokens_per_expert': num_tokens_per_expert, + 'config': config, + 'async_finish': async_mode + } if with_topk: - dispatch_args.update({'topk_idx': topk_idx, 'topk_weights': topk_weights_pure_rand if current_x is x_pure_rand else topk_weights}) + dispatch_args.update({ + 'topk_idx': topk_idx, + 'topk_weights': topk_weights_pure_rand if current_x is x_pure_rand else topk_weights + }) if previous_mode: dispatch_args.update({'previous_event': buffer.capture()}) - recv_x, recv_topk_idx, recv_topk_weights, recv_num_tokens_per_expert_list, handle, event = buffer.dispatch(**dispatch_args) + recv_x, recv_topk_idx, recv_topk_weights, recv_num_tokens_per_expert_list, handle, event = buffer.dispatch( + **dispatch_args) event.current_stream_wait() if async_mode else () recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x # Checks rank_prefix_matrix = handle[0] - assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(0), f'{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}' + assert gbl_num_tokens_per_rank[rank].item() == recv_x.size( + 0), f'{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}' assert gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist() == recv_num_tokens_per_expert_list if current_x is not x_pure_rand: check_data(recv_x, rank_prefix_matrix) + recv_topk_weights_clone = None if with_topk: # Check `topk_idx` - assert (recv_topk_idx.eq(-1) | ((recv_topk_idx >= 0) & (recv_topk_idx < (num_experts // num_ranks)))).sum().item() == recv_topk_idx.numel() + assert (recv_topk_idx.eq(-1) | + 
((recv_topk_idx >= 0) & + (recv_topk_idx < (num_experts // num_ranks)))).sum().item() == recv_topk_idx.numel() for i, count in enumerate(recv_num_tokens_per_expert_list): assert recv_topk_idx.eq(i).sum().item() == count # Check `topk_weights` + recv_topk_weights_clone = recv_topk_weights.clone() if current_x is not x_pure_rand: - recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax(dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)] + recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax( + dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)] check_data(recv_topk_weights, rank_prefix_matrix) + # Test `num_worst_tokens != 0` + if with_topk: + num_worst_tokens = num_tokens * num_ranks + dispatch_args.update({'num_worst_tokens': num_worst_tokens}) + recv_worst_x, recv_worst_topk_idx, recv_worst_topk_weights, empty_list, _, event = buffer.dispatch(**dispatch_args) + event.current_stream_wait() if async_mode else () + recv_worst_x = per_token_cast_back(*recv_worst_x) if isinstance(recv_worst_x, tuple) else recv_worst_x + assert len(empty_list) == 0 + assert num_worst_tokens == recv_worst_x.size(0) + assert num_worst_tokens == recv_worst_topk_idx.size(0) + assert num_worst_tokens == recv_worst_topk_weights.size(0) + assert torch.equal(recv_x, recv_worst_x[:recv_x.size(0)]) + assert torch.equal(recv_topk_idx, recv_worst_topk_idx[:recv_x.size(0)]) + assert torch.equal(recv_topk_weights_clone, recv_worst_topk_weights[:recv_x.size(0)]) + #TODO check why overflow area is not all -1. 
+ #assert torch.all(recv_worst_topk_idx[recv_x.size(0):] == -1).item() + # Test cached dispatch (must without top-k staffs) - # NOTES: handle must be refreshed if not with_topk: dispatch_args = {'x': current_x, 'handle': handle, 'config': config, 'async_finish': async_mode} if previous_mode: @@ -128,14 +170,16 @@ def check_data(check_x, rank_prefix_matrix): if with_topk: combine_args.update({'topk_weights': recv_topk_weights}) if previous_mode: - dispatch_args.update({'previous_event': buffer.capture()}) + combine_args.update({'previous_event': buffer.capture()}) combined_x, combined_topk_weights, event = buffer.combine(**combine_args) event.current_stream_wait() if async_mode else () check_x = combined_x.float() / is_token_in_rank.sum(dim=1).unsqueeze(1) ref_x = x_pure_rand if current_x is x_pure_rand else x assert calc_diff(check_x, ref_x) < 5e-6 if with_topk: - check_topk_weights = combined_topk_weights if (current_x is x_pure_rand) else (combined_topk_weights / is_token_in_rank.sum(dim=1).unsqueeze(1)) + check_topk_weights = combined_topk_weights if (current_x + is x_pure_rand) else (combined_topk_weights / + is_token_in_rank.sum(dim=1).unsqueeze(1)) ref_topk_weights = topk_weights_pure_rand if current_x is x_pure_rand else topk_weights assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9 @@ -146,79 +190,123 @@ def check_data(check_x, rank_prefix_matrix): if local_rank == 0: print(' passed', flush=True) if local_rank == 0: - print() + print('', flush=True) # Tune dispatch performance best_dispatch_results = None fp8_factor = (1 + 4 / 128) / 2 - for current_x in (x_e4m3, x): + for current_x in filter(lambda elem: elem is not None, (x_e4m3, x)): best_time, best_results = 1e10, None nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes - for nvl_chunk_size in range(4, 150, 4): - config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size) + for nvl_chunk_size in tuple(range(4, 
33, 2)) + (0, ): + if nvl_chunk_size > 0: + config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size) + else: + # Test default config as well + deep_ep.Buffer.set_num_sms(num_sms) + config = deep_ep.Buffer.get_dispatch_config(num_ranks) tune_args = {'x': current_x, 'handle': handle, 'config': config} - t = bench(lambda: buffer.dispatch(**tune_args))[0] - if t < best_time: + t = bench(lambda: buffer.dispatch(**tune_args))[0] # noqa: B023 + if t < best_time and nvl_chunk_size > 0: best_time, best_results = t, (num_sms, nvl_chunk_size) if local_rank == 0: - print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}: {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL), time {t * 1000 * 1000:.2f} us', flush=True) + print( + f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else "default"}: ' + f'{nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL), {t * 1e6:.2f} us', + flush=True) if local_rank == 0: - print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL), time {best_time * 1000 * 1000:.2f} us', flush=True) - print() + print( + f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL), t: {best_time * 1e6:.2f} us', + flush=True) + print('', flush=True) - if isinstance(current_x, tuple): - # Gather FP8 the best config from rank 0 + # Gather the best config from rank 0 and the first test setting + if best_dispatch_results is None: best_dispatch_results = torch.tensor([best_results[0], best_results[1]], dtype=torch.int32, device='cuda') all_best_fp8_results_list = [torch.zeros_like(best_dispatch_results) for _ in range(torch.distributed.get_world_size())] dist.all_gather(all_best_fp8_results_list, best_dispatch_results, group=group) best_dispatch_results = all_best_fp8_results_list[0].tolist() dispatch_config = 
deep_ep.Config(best_dispatch_results[0], best_dispatch_results[1], nvl_buffer_size) - dispatch_args = {'x': x, 'num_tokens_per_rank': num_tokens_per_rank, - 'is_token_in_rank': is_token_in_rank, 'num_tokens_per_expert': num_tokens_per_expert, - 'config': dispatch_config if dispatch_config is not None else config} + dispatch_args = { + 'x': x, + 'num_tokens_per_rank': num_tokens_per_rank, + 'is_token_in_rank': is_token_in_rank, + 'num_tokens_per_expert': num_tokens_per_expert, + 'config': dispatch_config if dispatch_config is not None else config + } recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args) # Tune combine performance best_time, best_results = 1e10, None - for nvl_chunk_size in range(1, 35, 1): - config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size) + for nvl_chunk_size in tuple(range(1, 17, 1)) + (0, ): + if nvl_chunk_size > 0: + config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size) + else: + # Test default config as well + deep_ep.Buffer.set_num_sms(num_sms) + config = deep_ep.Buffer.get_combine_config(num_ranks) tune_args = {'x': recv_x, 'handle': handle, 'config': config} - t = bench(lambda: buffer.combine(**tune_args))[0] + t = bench(lambda: buffer.combine(**tune_args))[0] # noqa: B023 if local_rank == 0: - print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}: {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL), time {t * 1000 * 1000:.2f} us', flush=True) - if t < best_time: + print( + f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else "default"}: ' + f'{combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL), {t * 1e6:.2f} us', + flush=True) + if t < best_time and nvl_chunk_size > 0: best_time, best_results = t, (num_sms, nvl_chunk_size) if local_rank == 0: - print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}: {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL), time {best_time * 1000 * 1000:.2f} us', flush=True) - print() + print( + f'[tuning] 
Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}: {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL), t: {best_time * 1e6:.2f} us', + flush=True) + print('', flush=True) -# noinspection PyUnboundLocalVariable -def test_loop(local_rank: int, num_local_ranks: int): +# noinspection PyUnboundLocalVariable,PyShadowingNames +def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): rank, num_ranks, group = init_dist(local_rank, num_local_ranks) test_ll_compatibility, num_rdma_bytes = False, 0 if test_ll_compatibility: ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9 num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(ll_num_tokens, ll_hidden, num_ranks, ll_num_experts) - buffer = deep_ep.Buffer(group, int(1e9), num_rdma_bytes, low_latency_mode=test_ll_compatibility, - num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1)) + buffer = deep_ep.Buffer(group, + int(2e9), + num_rdma_bytes, + low_latency_mode=test_ll_compatibility, + num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1), + explicitly_destroy=True, + allow_mnnvl=args.allow_mnnvl, + use_fabric=args.use_fabric) torch.manual_seed(rank) - for i in (64, ): - test_main(i, local_rank, num_ranks, rank, buffer, group) + for i in (24, ): + test_main(args, i, local_rank, num_ranks, rank, buffer, group) if local_rank == 0: - print() + print('', flush=True) # Test compatibility with low latency functions if test_ll_compatibility: buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts) test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1) - dist.destroy_process_group(group) + # Destroy the buffer runtime and communication group + buffer.destroy() + dist.barrier() + dist.destroy_process_group() + if __name__ == '__main__': - num_processes = 8 - torch.multiprocessing.spawn(test_loop, args=(num_processes, ), 
nprocs=num_processes) + parser = argparse.ArgumentParser(description='Test intranode EP kernels') + parser.add_argument('--num-processes', type=int, default=8, help='Number of processes to spawn (default: 8)') + parser.add_argument('--num-tokens', type=int, default=4096, help='Number of tokens (default: 4096)') + parser.add_argument('--hidden', type=int, default=7168, help='Hidden dimension size (default: 7168)') + parser.add_argument('--num-topk', type=int, default=8, help='Number of top-k experts (default: 8)') + parser.add_argument('--num-experts', type=int, default=256, help='Number of experts (default: 256)') + parser.add_argument('--allow-mnnvl', action="store_true", help='Enable MNNVL support') + parser.add_argument('--use-fabric', action="store_true", help='Enable fabric mode') + args = parser.parse_args() + + num_processes = args.num_processes + torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes) diff --git a/tests/test_low_latency.py b/tests/test_low_latency.py index 0c65024..5dfddc7 100644 --- a/tests/test_low_latency.py +++ b/tests/test_low_latency.py @@ -1,104 +1,195 @@ +import argparse import random import torch import torch.distributed as dist from functools import partial +from typing import Literal, Set import deep_ep from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back -def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, - rank: int, num_ranks: int, group: dist.ProcessGroup, buffer: deep_ep.Buffer, seed: int = 0): +def simulate_failure_and_skip(rank: int, api: Literal["dispatch", "combine", "clean"], expected_masked_ranks: Set[int]): + # Simulates rank failure when the rank first calls the corresponding communication API + failed_api_ranks = { + # API -> rank to fail (rank fails when it first calls the corresponding communication API) + 'dispatch': 1, + 'combine': 3, + 'clean': 5 + } + if rank in expected_masked_ranks: + # Rank already failed + 
return True + if api in failed_api_ranks.keys(): + expected_masked_ranks.add(failed_api_ranks[api]) + if failed_api_ranks[api] == rank: + print(f"Rank {rank} failed when first calling {api} communication API, exit...", flush=True) + return True + return False + + +def query_mask_buffer_and_check(api: Literal["dispatch", "combine", "clean"], buffer: deep_ep.Buffer, mask_status: torch.Tensor, + expected_masked_ranks: Set[int]): + buffer.low_latency_query_mask_buffer(mask_status) + assert set(mask_status.nonzero().squeeze(-1).tolist()) == expected_masked_ranks + + +def test_main(num_tokens: int, + hidden: int, + num_experts: int, + num_topk: int, + rank: int, + num_ranks: int, + group: dist.ProcessGroup, + buffer: deep_ep.Buffer, + use_logfmt: bool = False, + shrink_test: bool = False, + seed: int = 0): torch.manual_seed(seed + rank) random.seed(seed + rank) assert num_experts % num_ranks == 0 num_local_experts = num_experts // num_ranks - # NOTES: the integers greater than 256 exceeds the BF16 precision limit + # NOTES: the integers greater than 256 exceed the BF16 precision limit rank_offset = 128 assert num_ranks - rank_offset < 257, 'Too many ranks (exceeding test precision limit)' x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * (rank - rank_offset) x[:, -128:] = torch.arange(num_tokens, device='cuda').to(torch.bfloat16).view(-1, 1) + x_list = [x] + for _ in range(4 if use_logfmt else 0): + # NOTES: make more LogFMT casts and also with some BF16 + x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.5 * random.random()) + # NOTES: the last one is for performance testing + # Most of the values in the perf case is lower than the threshold, casting most channels + x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1) + scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1 topk_idx = torch.topk(scores, num_topk, dim=-1, 
largest=True, sorted=True)[1] + topk_idx = topk_idx.to(deep_ep.topk_idx_t) topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda').abs() + # Randomly mask some positions - for i in range(10): + for _ in range(10): topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = -1 + all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='cuda') + dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group) + + # For failure simulation and shrink testing + mask_status = torch.zeros((num_ranks, ), dtype=torch.int, device='cuda') + expected_masked_ranks = set() + # Check dispatch correctness do_check = True hash_value, num_times = 0, 0 + for current_x in x_list: + for return_recv_hook in (False, True): + for dispatch_use_fp8 in (False, True): + for round_scale in (False, True) if dispatch_use_fp8 else (False, ): + for use_ue8m0 in (False,) if round_scale else (False, ): + if shrink_test and simulate_failure_and_skip(rank, "dispatch", expected_masked_ranks): + break + num_times += 1 + for _ in range((num_times % 2) + 1): + cumulative_local_expert_recv_stats = torch.zeros((num_local_experts, ), dtype=torch.int, device='cuda') + packed_recv_x, packed_recv_count, handle, event, hook = \ + buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts, + use_fp8=dispatch_use_fp8, round_scale=round_scale, use_ue8m0=use_ue8m0, + cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats, + async_finish=not return_recv_hook, return_recv_hook=return_recv_hook) + hook() if return_recv_hook else event.current_stream_wait() + if shrink_test: + query_mask_buffer_and_check("dispatch", buffer, mask_status, expected_masked_ranks) + packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x + simulated_gemm_x = per_token_cast_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape) 
\ + if dispatch_use_fp8 else packed_recv_x.clone() + for i in range(num_local_experts if do_check else 0): + expert_id = rank * num_local_experts + i + recv_x = per_token_cast_back(packed_recv_x[0][i], packed_recv_x[1][i]) if dispatch_use_fp8 else packed_recv_x[i] + recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i] - for return_recv_hook in (False, True): - for dispatch_use_fp8 in (False, True): - num_times += 1 - for i in range((num_times) + 1): - packed_recv_x, packed_recv_count, handle, event, hook = \ - buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts, use_fp8=dispatch_use_fp8, - async_finish=not return_recv_hook, return_recv_hook=return_recv_hook) - hook() if return_recv_hook else event.current_stream_wait() - packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x - simulated_gemm_x = per_token_cast_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape) \ - if dispatch_use_fp8 else packed_recv_x.clone() - #print(f"rank{rank}: packed_recv_x[0]\n{packed_recv_x[0].cpu()}\n") - #print(f"rank{rank}: packed_recv_x[1]\n{packed_recv_x[1].cpu()}\n") - #print(f"simulated_gemm_x{simulated_gemm_x.cpu()}") - all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='cuda') - dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group) - for i in range(num_local_experts if do_check else 0): - expert_id = rank * num_local_experts + i - recv_x = per_token_cast_back(packed_recv_x[0][i], packed_recv_x[1][i]) if dispatch_use_fp8 else packed_recv_x[i] - recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i] - - # Check expert indices - int_mask = (2 ** 32) - 1 - num_valid_tokens = recv_count.item() - assert num_valid_tokens == (recv_layout_range & int_mask).sum().item(), f'{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()' 
- assert num_valid_tokens == (all_topk_idx == expert_id).sum().item(), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}' - - # Check received data - recv_x = recv_x[:num_valid_tokens] - recv_x_amin = recv_x[:, :-128].amin(dim=-1) - recv_src_info = recv_src_info[:num_valid_tokens] - assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1)) - assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0 - for j in range(num_ranks): - begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item() - assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item() - assert (recv_x[begin_idx:begin_idx + count][:-128] - j).sum().item() == 0 - if dispatch_use_fp8: - hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens]) - hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens]) - else: - hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens]) - - # Check combine correctness - for zero_copy in (False,True): - if zero_copy: - buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x - out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') - combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle, - async_finish=not return_recv_hook, - return_recv_hook=return_recv_hook, out=out) - hook() if return_recv_hook else event.current_stream_wait() - if do_check: - diff = calc_diff(x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x) - assert torch.isnan(combined_x).sum().item() == 0 - assert diff < 1e-5, f'Error: diff={diff}' - hash_value ^= hash_tensor(combined_x) - - def create_test_cast_with_outliers(num_outliers): - tmp = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') - tmp /= tmp.abs().amax(dim=1).view(-1, 1) - assert tmp.abs().amax().item() <= 1 - - # Create some amax outliers - for i in 
range(num_outliers): - tmp[random.randint(0, num_tokens - 1)] *= 1e3 - return tmp + # Check expert indices + int_mask = (2**32) - 1 + num_valid_tokens = recv_count.item() + # cumulative_local_expert_recv_stats not currently enabled. + #assert cumulative_local_expert_recv_stats[i].item( + #) == num_valid_tokens, f'{cumulative_local_expert_recv_stats[i].item()} != {num_valid_tokens}' + assert num_valid_tokens == ( + recv_layout_range + & int_mask).sum().item(), f'{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()' + assert num_valid_tokens == (all_topk_idx == expert_id).sum(dim=[1, 2])[mask_status == 0].sum().item( + ), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum(dim=[1, 2])[mask_status==0].sum().item()}' + + if num_valid_tokens == 0: + continue + # Check received data + if current_x is x: + recv_x = recv_x[:num_valid_tokens] + recv_x_amin = recv_x[:, :-128].amin(dim=-1) + recv_src_info = recv_src_info[:num_valid_tokens] + assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1)) + if round_scale: + assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007 + else: + assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0 + for j in range(num_ranks): + if shrink_test and mask_status[j]: + continue + begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item() + if not round_scale: + assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item() + assert (recv_x[begin_idx:begin_idx + count, :-128] - j + rank_offset).sum().item() == 0 + if dispatch_use_fp8: + hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens]) + hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens]) + else: + hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens]) + + # Check combine correctness + if shrink_test and simulate_failure_and_skip(rank, "combine", expected_masked_ranks): + break + for zero_copy in (False, ) if use_logfmt 
else (False, True): + if zero_copy: + buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x + out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') + combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, + topk_idx, + topk_weights, + handle, + use_logfmt=use_logfmt, + async_finish=not return_recv_hook, + zero_copy=zero_copy, + return_recv_hook=return_recv_hook, + out=out) + hook() if return_recv_hook else event.current_stream_wait() + if shrink_test: + query_mask_buffer_and_check("combine", buffer, mask_status, expected_masked_ranks) + if do_check: + if shrink_test: + owner_by_expert = (torch.arange(num_experts, device='cuda') // num_local_experts) + fail_owner_mask = (mask_status == 1).index_select(0, owner_by_expert) + valid_topk_idx = topk_idx >= 0 + failed_topk_idx = torch.zeros_like(topk_idx, device='cuda', dtype=torch.bool) + failed_topk_idx[valid_topk_idx] = fail_owner_mask.index_select(0, topk_idx[valid_topk_idx]) + topk_idx[failed_topk_idx] = -1 + diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x) + assert torch.isnan(combined_x).sum().item() == 0 + if not round_scale: + assert diff < (9e-4 if dispatch_use_fp8 else 1e-5), f'Error: {diff=}, {dispatch_use_fp8=}, {zero_copy=}' + hash_value ^= hash_tensor(combined_x) + + # Clean buffer API + if shrink_test: + if simulate_failure_and_skip(rank, "clean", expected_masked_ranks): + break + + buffer.clean_low_latency_buffer(num_tokens, hidden, num_experts) + query_mask_buffer_and_check("clean", buffer, mask_status, expected_masked_ranks) + + if shrink_test: + return # noinspection PyShadowingNames def large_gemm_with_hook(hook): @@ -108,70 +199,133 @@ def large_gemm_with_hook(hook): hook() # noinspection PyShadowingNames - def test_func(zero_copy: bool, return_recv_hook: bool): + def test_func(return_recv_hook: bool): recv_x, recv_count, handle, event, hook = \ - buffer.low_latency_dispatch(x, 
topk_idx, num_tokens, num_experts, use_fp8=True, - async_finish=False, return_recv_hook=return_recv_hook) + buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts, + cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats, + use_fp8=True, async_finish=False, return_recv_hook=return_recv_hook) large_gemm_with_hook(hook) if return_recv_hook else None - if zero_copy: - buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x - combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle, - zero_copy=zero_copy, return_recv_hook=return_recv_hook) + combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, + topk_idx, + topk_weights, + handle, + use_logfmt=use_logfmt, + return_recv_hook=return_recv_hook) large_gemm_with_hook(hook) if return_recv_hook else None # Calculate bandwidth num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2 + num_logfmt10_bytes = hidden * 10 / 8 + hidden / 128 * 4 num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0 for i in range(num_tokens): num_selections = (topk_idx[i] != -1).sum().item() num_dispatch_comm_bytes += num_fp8_bytes * num_selections - num_combine_comm_bytes += num_bf16_bytes * num_selections + num_combine_comm_bytes += (num_logfmt10_bytes if use_logfmt else num_bf16_bytes) * num_selections # Dispatch + combine testing - avg_t, min_t, max_t = bench(partial(test_func, zero_copy=zero_copy, return_recv_hook=False)) - print(f'[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, ' - f'avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 1e6:.2f} us, max_t={max_t * 1e6:.2f} us', flush=True) + avg_t, min_t, max_t = bench(partial(test_func, return_recv_hook=False)) + print( + f'[rank {rank}] Dispatch + combine bandwidth: {(num_dispatch_comm_bytes + num_combine_comm_bytes) / 1e9 / avg_t:.2f} GB/s, ' + f'avg_t={avg_t * 1e6:.2f} us, min_t={min_t * 
1e6:.2f} us, max_t={max_t * 1e6:.2f} us', + flush=True) # Separate profiling for return_recv_hook in (False, True): group.barrier() - - dispatch_t, combine_t = bench_kineto(partial(test_func, zero_copy=False, return_recv_hook=return_recv_hook), - kernel_names=('dispatch', 'combine'), barrier_comm_profiling=True, - suppress_kineto_output=True) + dispatch_t, combine_t = bench_kineto(partial(test_func, return_recv_hook=return_recv_hook), + kernel_names=('dispatch', 'combine'), + barrier_comm_profiling=True, + suppress_kineto_output=True, + num_kernels_per_period=2 if return_recv_hook else 1) if not return_recv_hook: - print(f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | ' - f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us') + print( + f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | ' + f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us', + flush=True) else: - print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t * 2 * 1e6:.2f} us | ' - f'Combine send/recv time: {combine_t * 2 * 1e6:.2f} us') - + print( + f'[rank {rank}] Dispatch send/recv time: {dispatch_t[0] * 1e6:.2f} + {dispatch_t[1] * 1e6:.2f} us | ' + f'Combine send/recv time: {combine_t[0] * 1e6:.2f} + {combine_t[1] * 1e6:.2f} us', + flush=True) return hash_value -# noinspection PyUnboundLocalVariable -def test_loop(local_rank: int, num_local_ranks: int): +# noinspection PyUnboundLocalVariable,PyShadowingNames +def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): rank, num_ranks, group = init_dist(local_rank, num_local_ranks) - # The default setting of deepEP upstream is below: - num_tokens, hidden, num_topk, num_experts = 128, 7168, 8, 288 + num_tokens, hidden = args.num_tokens, args.hidden + num_topk, num_experts = 
args.num_topk, args.num_experts num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts) if local_rank == 0: print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True) - buffer = deep_ep.Buffer(group, num_rdma_bytes=num_rdma_bytes, low_latency_mode=True, - num_qps_per_rank=num_experts // num_ranks) - test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=1) + buffer = deep_ep.Buffer(group, + num_rdma_bytes=num_rdma_bytes, + low_latency_mode=True, + num_qps_per_rank=num_experts // num_ranks, + allow_nvlink_for_low_latency_mode=not args.disable_nvlink, + explicitly_destroy=True, + allow_mnnvl=args.allow_mnnvl, + enable_shrink=args.shrink_test) + test_main(num_tokens, + hidden, + num_experts, + num_topk, + rank, + num_ranks, + group, + buffer, + use_logfmt=args.use_logfmt, + shrink_test=args.shrink_test, + seed=1) - do_pressure_test = False + do_pressure_test = args.pressure_test for seed in range(int(1e9) if do_pressure_test else 0): if local_rank == 0: print(f'Testing with seed {seed} ...', flush=True) - ref_hash = test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=seed) - for i in range(20): - assert test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=seed) == ref_hash, f'Error: seed={seed}' + ref_hash = test_main(num_tokens, + hidden, + num_experts, + num_topk, + rank, + num_ranks, + group, + buffer, + use_logfmt=args.use_logfmt, + seed=seed) + for _ in range(20): + assert test_main(num_tokens, + hidden, + num_experts, + num_topk, + rank, + num_ranks, + group, + buffer, + use_logfmt=args.use_logfmt, + seed=seed) == ref_hash, f'Error: seed={seed}' + + # Destroy the buffer runtime and communication group + buffer.destroy() + dist.barrier() + dist.destroy_process_group() if __name__ == '__main__': # TODO: you may modify NUMA binding for less CPU overhead - num_processes = 8 - 
torch.multiprocessing.spawn(test_loop, args=(num_processes,), nprocs=num_processes) + # TODO: buggy with `num_tokens=512` + parser = argparse.ArgumentParser(description='Test low-latency EP kernels') + parser.add_argument('--num-processes', type=int, default=8, help='Number of processes to spawn (default: 8)') + parser.add_argument('--num-tokens', type=int, default=128, help='Number of tokens (default: 128)') + parser.add_argument('--hidden', type=int, default=7168, help='Hidden dimension size (default: 7168)') + parser.add_argument('--num-topk', type=int, default=8, help='Number of top-k experts (default: 8)') + parser.add_argument('--num-experts', type=int, default=288, help='Number of experts (default: 288)') + parser.add_argument('--allow-mnnvl', action="store_true", help='Allow MNNVL for communication') + parser.add_argument('--disable-nvlink', action='store_true', help='Whether to disable NVLink for testing') + parser.add_argument('--use-logfmt', action='store_true', help='Whether to test LogFMT combine') + parser.add_argument("--pressure-test", action='store_true', help='Whether to do pressure test') + parser.add_argument("--shrink-test", action='store_true', help='Whether to simulate failure and test shrink mode') + args = parser.parse_args() + + num_processes = args.num_processes + torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes) diff --git a/tests/utils.py b/tests/utils.py index 7665889..1390b2b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,26 +1,34 @@ +import inspect +import json +import tempfile +from pathlib import Path + +import numpy as np import os import sys -import numpy as np import torch import torch.distributed as dist -from typing import Optional +from typing import Optional, Union -def init_dist(local_rank: int, num_local_ranks: int, backend: str = 'nccl'): +def init_dist(local_rank: int, num_local_ranks: int): # NOTES: you may rewrite this function with your own cluster settings - if backend 
== 'nccl': - ip = os.getenv('MASTER_ADDR', '127.0.0.1') - port = int(os.getenv('MASTER_PORT', '8361')) - node_rank = int(os.getenv('RANK', 0)) + ip = os.getenv('MASTER_ADDR', '127.0.0.1') + port = int(os.getenv('MASTER_PORT', '8361')) num_nodes = int(os.getenv('WORLD_SIZE', 1)) - assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8 - if backend == 'nccl': - dist.init_process_group( - backend='nccl', - init_method=f'tcp://{ip}:{port}', - world_size=num_nodes * num_local_ranks, - rank=node_rank * num_local_ranks + local_rank - ) + node_rank = int(os.getenv('RANK', 0)) + + sig = inspect.signature(dist.init_process_group) + params = { + 'backend': 'nccl', + 'init_method': f'tcp://{ip}:{port}', + 'world_size': num_nodes * num_local_ranks, + 'rank': node_rank * num_local_ranks + local_rank, + } + if 'device_id' in sig.parameters: + # noinspection PyTypeChecker + params['device_id'] = torch.device(f'cuda:{local_rank}') + dist.init_process_group(**params) torch.set_default_dtype(torch.bfloat16) torch.set_default_device('cuda') torch.cuda.set_device(local_rank) @@ -35,18 +43,35 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): return (1 - sim).item() +def align_up(x, y): + return (x + y - 1) // y * y + + def per_token_cast_to_fp8(x: torch.Tensor): - assert x.dim() == 2 and x.size(1) % 128 == 0 + assert x.dim() == 2 m, n = x.shape - x_view = x.view(m, -1, 128) - x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) - return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1) + aligned_n = align_up(n, 128) + x_padded = torch.nn.functional.pad(x, (0, aligned_n - n), mode='constant', value=0) + x_padded_view = x_padded.view(m, -1, 128) + x_amax = x_padded_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + return (x_padded_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view( + m, aligned_n)[:, :n].contiguous(), (x_amax / 448.0).view(m, -1) def per_token_cast_back(x_fp8: 
torch.Tensor, x_scales: torch.Tensor): - x_fp32 = x_fp8.to(torch.float32).view(x_fp8.size(0), -1, 128) + if x_fp8.numel() == 0: + return x_fp8.to(torch.bfloat16) + + assert x_fp8.dim() == 2 + m, n = x_fp8.shape + aligned_n = align_up(n, 128) + x_fp8_padded = torch.nn.functional.pad(x_fp8, (0, aligned_n - n), mode='constant', value=0) + if x_scales.dtype == torch.int: + x_scales = x_scales.view(dtype=torch.uint8).to(torch.int) << 23 + x_scales = x_scales.view(dtype=torch.float) + x_fp32_padded = x_fp8_padded.to(torch.float32).view(x_fp8.size(0), -1, 128) x_scales = x_scales.view(x_fp8.size(0), -1, 1) - return (x_fp32 * x_scales).view(x_fp8.shape).to(torch.bfloat16) + return (x_fp32_padded * x_scales).view(x_fp8_padded.shape).to(torch.bfloat16)[:, :n].contiguous() def inplace_unique(x: torch.Tensor, num_slots: int): @@ -72,7 +97,7 @@ def create_grouped_scores(scores: torch.Tensor, group_idx: torch.Tensor, num_gro return (scores * mask).view(num_tokens, num_experts) -def bench(fn, num_warmups: int = 20, num_tests: int = 30, post_fn=None): +def bench(fn, num_warmups: int = 50, num_tests: int = 50, post_fn=None): # Flush L2 cache with 256 MB data torch.cuda.synchronize() cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda') @@ -101,6 +126,7 @@ def bench(fn, num_warmups: int = 20, num_tests: int = 30, post_fn=None): class empty_suppress: + def __enter__(self): return self @@ -109,6 +135,7 @@ def __exit__(self, *_): class suppress_stdout_stderr: + def __enter__(self): self.outnull_file = open(os.devnull, 'w') self.errnull_file = open(os.devnull, 'w') @@ -143,14 +170,19 @@ def __exit__(self, *_): self.errnull_file.close() -def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False, - trace_path: Optional[str] = None, barrier_comm_profiling: bool = False): +def bench_kineto(fn, + kernel_names: Union[str, tuple], + num_tests: int = 30, + suppress_kineto_output: bool = False, + trace_path: Optional[str] = None, + 
barrier_comm_profiling: bool = False, + num_kernels_per_period: int = 1): # Profile suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress with suppress(): - schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) + schedule = torch.profiler.schedule(wait=1, warmup=0, active=1, repeat=1) with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) as prof: - for i in range(2): + for _ in range(2): # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead if barrier_comm_profiling: lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda') @@ -159,11 +191,12 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda')) for _ in range(num_tests): fn() + torch.cuda.synchronize() prof.step() # Parse the profiling table - assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple) - is_tupled = isinstance(kernel_names, tuple) + assert isinstance(kernel_names, (str, tuple)) + is_tuple = isinstance(kernel_names, tuple) prof_lines = prof.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n') kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names assert all([isinstance(name, str) for name in kernel_names]) @@ -174,20 +207,36 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: if trace_path is not None: prof.export_chrome_trace(trace_path) - # Return average kernel times + # Return average kernel durations units = {'ms': 1e3, 'us': 1e6} - kernel_times = [] + kernel_durations = [] for name in kernel_names: for line in prof_lines: if name in line: time_str = line.split()[-2] for unit, scale in units.items(): if unit in time_str: - kernel_times.append(float(time_str.replace(unit, '')) / scale) + kernel_durations.append(float(time_str.replace(unit, '')) / scale) break break - 
return tuple(kernel_times) if is_tupled else kernel_times[0] + + # Expand the kernels by periods + if num_kernels_per_period > 1: + with tempfile.NamedTemporaryFile(suffix='.json') as tmp: + prof.export_chrome_trace(tmp.name) + profile_data = json.loads(Path(tmp.name).read_text()) + + for i, kernel_name in enumerate(kernel_names): + events = [event for event in profile_data['traceEvents'] if f'::{kernel_name}' in event['name']] + events = sorted(events, key=lambda event: event['ts']) + durations = [event['dur'] / 1e6 for event in events] + assert len(durations) % num_kernels_per_period == 0 + num_kernel_patterns = len(durations) // num_kernels_per_period + kernel_durations[i] = [sum(durations[j::num_kernels_per_period]) / num_kernel_patterns for j in range(num_kernels_per_period)] + + # Return execution durations + return kernel_durations if is_tuple else kernel_durations[0] def hash_tensor(t: torch.Tensor): - return t.view(torch.int64).sum().item() + return t.view(torch.int).sum().item() diff --git a/third-party/README.md b/third-party/README.md index 505efc8..39ad467 100644 --- a/third-party/README.md +++ b/third-party/README.md @@ -23,12 +23,6 @@ MPI_ROOT=$BUILD_DIR/ompi ../rocSHMEM/scripts/build_configs/gda_mlx5 --fresh \ -DUSE_IPC=ON \ -DGDA_BNXT=ON -# To build rocSHMEM with MPI disabled, please add this flag -DUSE_EXTERNAL_MPI=OFF -MPI_ROOT=$BUILD_DIR/ompi ../rocSHMEM/scripts/build_configs/gda_mlx5 --fresh \ - -DUSE_IPC=ON \ - -DGDA_BNXT=ON - -DUSE_EXTERNAL_MPI=OFF - # You may pass additional arguments to Cmake, # e.g., -DBUILD_LOCAL_GPU_TARGET_ONLY=ON ``` From 3e7610619fdaea09f84dfac63ed0e803571af004 Mon Sep 17 00:00:00 2001 From: Richard Chamberlain Date: Fri, 7 Nov 2025 04:22:06 -0800 Subject: [PATCH 02/22] Adding support for disabling MPI. 
--- README.md | 10 ++++++++- setup.py | 49 ++++++++++++++++++++++++++++++----------- tests/test_internode.py | 12 ++++++++-- tests/utils.py | 38 ++++++++++++++++++-------------- 4 files changed, 77 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index 75eda73..0b617e4 100644 --- a/README.md +++ b/README.md @@ -31,15 +31,23 @@ DeepEP (AMD version) depends on [rocSHMEM](https://github.com/ROCm/rocSHMEM). Pl git clone https://github.com/ROCm/DeepEP cd DeepEP -# export OMPI dir in the next command (e.g., it's $BUILD_DIR/ompi in third-party/README.md) +# To use DeepEP with MPI, please proceed with these commands +# Export OMPI dir in the next command (e.g., it's $BUILD_DIR/ompi in third-party/README.md) export OMPI_DIR= python3 setup.py --variant rocm build develop + +# To use DeepEP without MPI, please make sure rocSHMEM was built with this flag -DUSE_EXTERNAL_MPI=OFF +# Then install DeepEP using this command +python3 setup.py --variant rocm --disable-mpi build develop + # Run test cases # NOTES: you may modify the `init_dist` function in `tests/utils.py` # according to your own cluster settings, and launch into multiple nodes python3 tests/test_intranode.py python3 tests/test_internode.py +# Set the required ROCSHMEM heap size (for example, for DeepSeek models) +export ROCSHMEM_HEAP_SIZE=2147483648 python3 tests/test_low_latency.py ``` diff --git a/setup.py b/setup.py index b6a8111..6645bd6 100644 --- a/setup.py +++ b/setup.py @@ -28,12 +28,14 @@ def get_nvshmem_host_lib_name(base_dir): parser.add_argument("--verbose", action="store_true", help="Verbose build") parser.add_argument("--enable_timer", action="store_true", help="Enable timer to debug time out in internode") parser.add_argument("--rocm-disable-ctx", action="store_true", help="Disable workgroup context optimization in internode") + parser.add_argument("--disable-mpi", action="store_true", help="Disable MPI detection and configuration") # Get the arguments to be parsed and separate 
setuptools arguments args, unknown_args = parser.parse_known_args() variant = args.variant debug = args.debug rocm_disable_ctx = args.rocm_disable_ctx + disable_mpi = args.disable_mpi enable_timer = args.enable_timer @@ -57,8 +59,8 @@ def get_nvshmem_host_lib_name(base_dir): if not disable_nvshmem: assert os.path.exists(nvshmem_dir), f'The specified NVSHMEM directory does not exist: {nvshmem_dir}' - #else: - # disable_nvshmem = False + else: + disable_nvshmem = False # Reset sys.argv for setuptools to avoid conflicts @@ -86,7 +88,8 @@ def get_nvshmem_host_lib_name(base_dir): ), f"Failed to find {shmem_variant_name}" print(f"{shmem_variant_name} directory: {shmem_dir}") - if variant == "rocm": + ompi_dir = None + if variant == "rocm" and not disable_mpi: # Attempt to auto-detect OpenMPI installation directory if OMPI_DIR not set. # The first existing candidate containing bin/mpicc will be used. ompi_dir_env = os.getenv("OMPI_DIR", "").strip() @@ -107,16 +110,30 @@ def get_nvshmem_host_lib_name(base_dir): mpicc_path = os.path.join(d, "bin", "mpicc") if os.path.exists(d) and os.path.exists(mpicc_path): ompi_dir = d - break - if ompi_dir is None: - # Fallback to root (will trigger the assert below) - ompi_dir = "/" + break + + assert ompi_dir is not None, ( + f"Failed to find OpenMPI installation. " + f"Searched: {', '.join([d for d in candidate_dirs if d])}. " + f"Set OMPI_DIR environment variable or use --disable-mpi flag." 
+ ) print(f"Detected OpenMPI directory: {ompi_dir}") - assert os.path.exists(ompi_dir), f"Failed to find OMPI: {ompi_dir}" + elif variant == "rocm" and disable_mpi: + print("MPI detection disabled for ROCm variant") + elif variant == "cuda" and not disable_mpi: + print("MPI detection enabled for CUDA variant") + else: + print("MPI detection disabled for CUDA variant") + # TODO: currently, we only support Hopper architecture, we may add Ampere support later if variant == "rocm": - os.environ["PYTORCH_ROCM_ARCH"] = os.getenv("PYTORCH_ROCM_ARCH", "gfx942") + arch = os.getenv("PYTORCH_ROCM_ARCH") + allowed_arch = {"gfx942", "gfx950"} + if arch not in allowed_arch: + raise EnvironmentError( + f"Invalid PYTORCH_ROCM_ARCH='{arch}'. " + f"Use one of: {', '.join(sorted(allowed_arch))}.") elif variant == "cuda": os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0" @@ -155,7 +172,7 @@ def get_nvshmem_host_lib_name(base_dir): nvcc_flags = [f"{optimization_flag}"] + debug_symbol_flags + define_macros include_dirs = ["csrc/", f"{shmem_dir}/include"] - if variant == "rocm": + if variant == "rocm" and ompi_dir is not None: include_dirs.append(f"{ompi_dir}/include") sources = [ @@ -168,7 +185,7 @@ def get_nvshmem_host_lib_name(base_dir): ] library_dirs = [f"{shmem_dir}/lib"] - if variant == "rocm": + if variant == "rocm" and ompi_dir is not None: library_dirs.append(f"{ompi_dir}/lib") # Disable aggressive PTX instructions @@ -197,10 +214,15 @@ def get_nvshmem_host_lib_name(base_dir): "-lamdhip64", "-lhsa-runtime64", "-libverbs", - f"-l:libmpi.so", - f"-Wl,-rpath,{ompi_dir}/lib", ] ) + if not disable_mpi: + extra_link_args.extend( + [ + f"-l:libmpi.so", + f"-Wl,-rpath,{ompi_dir}/lib", + ] + ) extra_compile_args = { "cxx": cxx_flags, @@ -218,6 +240,7 @@ def get_nvshmem_host_lib_name(base_dir): print(f' > Compilation flags: {extra_compile_args}') print(f' > Link flags: {extra_link_args}') print(f' > NVSHMEM path: {shmem_dir}') + print(f' > Disable MPI: {disable_mpi}') print() # noinspection 
PyBroadException diff --git a/tests/test_internode.py b/tests/test_internode.py index c03d344..3315b3d 100644 --- a/tests/test_internode.py +++ b/tests/test_internode.py @@ -298,7 +298,7 @@ def check_data(check_x, recv_gbl_rank_prefix_sum): # noinspection PyUnboundLocalVariable,PyShadowingNames def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): num_nodes = int(os.getenv('WORLD_SIZE', 1)) - rank, num_ranks, group = init_dist(local_rank, num_local_ranks) + rank, num_ranks, group = init_dist(local_rank, num_local_ranks, backend=args.backend) if args.test_ll_compatibility: ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9 @@ -353,6 +353,7 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): if __name__ == '__main__': parser = argparse.ArgumentParser(description='Test internode EP kernels') + parser.add_argument('--backend', type=str, choices=['mpi', 'nccl'], default='nccl',help='Backend for distributed communication (mpi or nccl)') parser.add_argument('--num-processes', type=int, default=8, help='Number of processes to spawn (default: 8)') parser.add_argument('--num-tokens', type=int, default=4096, help='Number of tokens (default: 4096)') parser.add_argument('--hidden', type=int, default=7168, help='Hidden dimension size (default: 7168)') @@ -373,4 +374,11 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): args.num_topk_groups = min(num_nodes, 4) num_processes = args.num_processes - torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes) + if args.backend == 'mpi': + dist.init_process_group(backend='mpi') + rank = dist.get_rank() + local_rank = rank % num_processes + test_loop(local_rank=local_rank, num_local_ranks=num_processes, args=args) + else: + torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes) + diff --git a/tests/utils.py b/tests/utils.py index 1390b2b..6ac111e 100644 --- 
a/tests/utils.py +++ b/tests/utils.py @@ -11,24 +11,30 @@ from typing import Optional, Union -def init_dist(local_rank: int, num_local_ranks: int): +def init_dist(local_rank: int, num_local_ranks: int, backend: str = 'nccl'): # NOTES: you may rewrite this function with your own cluster settings - ip = os.getenv('MASTER_ADDR', '127.0.0.1') - port = int(os.getenv('MASTER_PORT', '8361')) + if backend == 'nccl': + ip = os.getenv('MASTER_ADDR', '127.0.0.1') + port = int(os.getenv('MASTER_PORT', '8361')) + node_rank = int(os.getenv('RANK', 0)) + num_nodes = int(os.getenv('WORLD_SIZE', 1)) - node_rank = int(os.getenv('RANK', 0)) - - sig = inspect.signature(dist.init_process_group) - params = { - 'backend': 'nccl', - 'init_method': f'tcp://{ip}:{port}', - 'world_size': num_nodes * num_local_ranks, - 'rank': node_rank * num_local_ranks + local_rank, - } - if 'device_id' in sig.parameters: - # noinspection PyTypeChecker - params['device_id'] = torch.device(f'cuda:{local_rank}') - dist.init_process_group(**params) + + assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8 + + if backend == 'nccl': + sig = inspect.signature(dist.init_process_group) + params = { + 'backend': backend, + 'init_method': f'tcp://{ip}:{port}', + 'world_size': num_nodes * num_local_ranks, + 'rank': node_rank * num_local_ranks + local_rank, + } + if 'device_id' in sig.parameters: + # noinspection PyTypeChecker + params['device_id'] = torch.device(f'cuda:{local_rank}') + dist.init_process_group(**params) + torch.set_default_dtype(torch.bfloat16) torch.set_default_device('cuda') torch.cuda.set_device(local_rank) From 00e9f4e1f1af35bb8bfc3fc29d8b0aea887aab6e Mon Sep 17 00:00:00 2001 From: Richard Chamberlain Date: Fri, 7 Nov 2025 04:24:38 -0800 Subject: [PATCH 03/22] Restored readme code. 
--- third-party/README.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/third-party/README.md b/third-party/README.md index 39ad467..505efc8 100644 --- a/third-party/README.md +++ b/third-party/README.md @@ -23,6 +23,12 @@ MPI_ROOT=$BUILD_DIR/ompi ../rocSHMEM/scripts/build_configs/gda_mlx5 --fresh \ -DUSE_IPC=ON \ -DGDA_BNXT=ON +# To build rocSHMEM with MPI disabled, please add this flag -DUSE_EXTERNAL_MPI=OFF +MPI_ROOT=$BUILD_DIR/ompi ../rocSHMEM/scripts/build_configs/gda_mlx5 --fresh \ + -DUSE_IPC=ON \ + -DGDA_BNXT=ON + -DUSE_EXTERNAL_MPI=OFF + # You may pass additional arguments to Cmake, # e.g., -DBUILD_LOCAL_GPU_TARGET_ONLY=ON ``` From 2645c3e513840210376314cadb21f486b9173ae3 Mon Sep 17 00:00:00 2001 From: Richard Chamberlain Date: Fri, 7 Nov 2025 08:58:07 -0800 Subject: [PATCH 04/22] Adding missing kernel file layout.cu --- csrc/kernels/layout.cu | 153 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 153 insertions(+) create mode 100644 csrc/kernels/layout.cu diff --git a/csrc/kernels/layout.cu b/csrc/kernels/layout.cu new file mode 100644 index 0000000..c3a16ae --- /dev/null +++ b/csrc/kernels/layout.cu @@ -0,0 +1,153 @@ +#include "configs.cuh" +#include "exception.cuh" +#include "launch.cuh" + +namespace deep_ep { + +namespace layout { + +template +__global__ void get_dispatch_layout(const topk_idx_t* topk_idx, + int* num_tokens_per_rank, + int* num_tokens_per_rdma_rank, + int* num_tokens_per_expert, + bool* is_token_in_rank, + int num_tokens, + int num_topk, + int num_ranks, + int num_experts) { + auto sm_id = static_cast(blockIdx.x); + auto thread_id = static_cast(threadIdx.x); + + // Count expert statistics + __shared__ int num_tokens_per_expert_per_thread[kNumThreads][kNumExpertsPerSM]; + int expert_begin_idx = sm_id * kNumExpertsPerSM, expert_end_idx = min(expert_begin_idx + kNumExpertsPerSM, num_experts); + if (expert_begin_idx < expert_end_idx) { + // Per-thread count + #pragma unroll + for (int i = 0; i < kNumExpertsPerSM; ++i) + 
num_tokens_per_expert_per_thread[thread_id][i] = 0; + #pragma unroll + for (int i = thread_id; i < num_tokens; i += kNumThreads) { + auto shifted_topk_idx = topk_idx + i * num_topk; + #pragma unroll + for (int j = 0, expert_idx; j < num_topk; ++j) { + expert_idx = static_cast(shifted_topk_idx[j]); + if (expert_begin_idx <= expert_idx and expert_idx < expert_end_idx) + ++num_tokens_per_expert_per_thread[thread_id][expert_idx - expert_begin_idx]; + } + } + __syncthreads(); + + // Sum up + EP_STATIC_ASSERT(kNumExpertsPerSM <= kNumThreads, "Too many experts per SM"); + if (expert_begin_idx + thread_id < expert_end_idx) { + int sum = 0; + #pragma unroll + for (int i = 0; i < kNumThreads; ++i) + sum += num_tokens_per_expert_per_thread[i][thread_id]; + num_tokens_per_expert[expert_begin_idx + thread_id] = sum; + } + return; + } + + if (num_tokens_per_rdma_rank != nullptr) + EP_DEVICE_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0 and num_ranks > NUM_MAX_NVL_PEERS); + + // Count rank statistics + constexpr int kNumRDMARanksPerSM = kNumRanksPerSM / NUM_MAX_NVL_PEERS; + __shared__ int num_tokens_per_rank_per_thread[kNumThreads][kNumRanksPerSM]; + __shared__ int num_tokens_per_rdma_rank_per_thread[kNumThreads][kNumRDMARanksPerSM]; + auto sm_begin = (num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM; + int rank_begin_idx = (sm_id - sm_begin) * kNumRanksPerSM, rank_end_idx = min(rank_begin_idx + kNumRanksPerSM, num_ranks); + int rdma_rank_begin_idx = rank_begin_idx / NUM_MAX_NVL_PEERS, rdma_rank_end_idx = rank_end_idx / NUM_MAX_NVL_PEERS; + if (rank_begin_idx < rank_end_idx) { + const auto num_expert_per_rank = num_experts / num_ranks; + auto expert_begin = rank_begin_idx * num_expert_per_rank; + auto expert_end = rank_end_idx * num_expert_per_rank; + + // Per-thread count + #pragma unroll + for (int i = 0; i < kNumRanksPerSM; ++i) + num_tokens_per_rank_per_thread[thread_id][i] = 0; + #pragma unroll + for (int i = 0; i < kNumRDMARanksPerSM; ++i) + 
num_tokens_per_rdma_rank_per_thread[thread_id][i] = 0; + #pragma unroll + for (int i = thread_id; i < num_tokens; i += kNumThreads) { + auto shifted_topk_idx = topk_idx + i * num_topk; + int is_in_rank[kNumRanksPerSM] = {0}, is_in_rdma_rank[kNumRDMARanksPerSM] = {0}; + #pragma unroll + for (int j = 0, expert_idx, rank_idx; j < num_topk; ++j) { + expert_idx = static_cast(shifted_topk_idx[j]); + if (expert_begin <= expert_idx and expert_idx < expert_end) { + // Count single rank + rank_idx = expert_idx / num_expert_per_rank - rank_begin_idx; + is_in_rank[rank_idx]++, is_in_rdma_rank[rank_idx / NUM_MAX_NVL_PEERS]++; + } + } + + auto shifted_is_token_in_rank = is_token_in_rank + i * num_ranks; + #pragma unroll + for (int j = 0; j + rank_begin_idx < rank_end_idx; ++j) { + shifted_is_token_in_rank[j + rank_begin_idx] = (is_in_rank[j] > 0); + num_tokens_per_rank_per_thread[thread_id][j] += (is_in_rank[j] > 0); + } + + #pragma unroll + for (int j = 0; j + rdma_rank_begin_idx < rdma_rank_end_idx; ++j) + num_tokens_per_rdma_rank_per_thread[thread_id][j] += (is_in_rdma_rank[j] > 0); + } + __syncthreads(); + + // Sum up + EP_STATIC_ASSERT(kNumRanksPerSM <= kNumThreads, "Too many ranks per SM"); + if (rank_begin_idx + thread_id < rank_end_idx) { + int sum = 0; + #pragma unroll + for (int i = 0; i < kNumThreads; ++i) + sum += num_tokens_per_rank_per_thread[i][thread_id]; + num_tokens_per_rank[rank_begin_idx + thread_id] = sum; + } + + if (num_tokens_per_rdma_rank != nullptr and rdma_rank_begin_idx + thread_id < rdma_rank_end_idx) { + int sum = 0; + #pragma unroll + for (int i = 0; i < kNumThreads; ++i) + sum += num_tokens_per_rdma_rank_per_thread[i][thread_id]; + num_tokens_per_rdma_rank[rdma_rank_begin_idx + thread_id] = sum; + } + } +} + +void get_dispatch_layout(const topk_idx_t* topk_idx, + int* num_tokens_per_rank, + int* num_tokens_per_rdma_rank, + int* num_tokens_per_expert, + bool* is_token_in_rank, + int num_tokens, + int num_topk, + int num_ranks, + int num_experts, + 
cudaStream_t stream) { + constexpr int kNumThreads = 256, kNumExpertsPerSM = 4, kNumRanksPerSM = 8; + int num_sms = ((num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM) + (num_ranks + kNumRanksPerSM - 1) / kNumRanksPerSM; + EP_STATIC_ASSERT(kNumRanksPerSM % NUM_MAX_NVL_PEERS == 0, "Invalid number of ranks per SM"); + + SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream); + LAUNCH_KERNEL(&cfg, + (get_dispatch_layout), + topk_idx, + num_tokens_per_rank, + num_tokens_per_rdma_rank, + num_tokens_per_expert, + is_token_in_rank, + num_tokens, + num_topk, + num_ranks, + num_experts); +} + +} // namespace layout + +} // namespace deep_ep From cefd395660b444b4e4abd5ed631e77a29d59624c Mon Sep 17 00:00:00 2001 From: RichardChamberlain1 Date: Fri, 7 Nov 2025 18:34:29 +0000 Subject: [PATCH 05/22] Update internode_ll.cu --- csrc/kernels/internode_ll.cu | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu index cd731be..6234f67 100644 --- a/csrc/kernels/internode_ll.cu +++ b/csrc/kernels/internode_ll.cu @@ -204,7 +204,12 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, for (int j = 0; j < kNumElemsPerRead; j += 2) { float2 fp32x2 = {fp32_values[j] * scale, fp32_values[j + 1] * scale}; #ifdef USE_ROCM +#if defined(__gfx942__) fp8x2_values[j / 2] = __hip_cvt_float2_to_fp8x2(fp32x2, __HIP_SATFINITE, __HIP_E4M3_FNUZ); +#endif +#if defined(__gfx950__) + fp8x2_values[j / 2] = __hip_cvt_float2_to_fp8x2(fp32x2, __HIP_SATFINITE, __HIP_E4M3); +#endif #else fp8x2_values[j / 2] = __nv_cvt_float2_to_fp8x2(fp32x2, __NV_SATFINITE, __NV_E4M3); #endif @@ -240,11 +245,6 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, internode::shmem_ctx_schar_put_nbi_warp(ctx, #endif reinterpret_cast(dst_ptr), reinterpret_cast(src_ptr), num_bytes_per_msg, dst_rank); -#if defined(ROCM_DISABLE_CTX) - internode::shmem_fence(); -#else - internode::shmem_ctx_quiet(ctx); -#endif #else 
nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, dst_rank, dst_expert_local_idx, lane_id, slot_idx); #endif From 030be0b3568a6ec193d3a56d20f2157bbb82c29a Mon Sep 17 00:00:00 2001 From: RichardChamberlain1 Date: Mon, 10 Nov 2025 12:19:14 +0000 Subject: [PATCH 06/22] Fix gfx950 FP8 datatypes --- csrc/deep_ep.cpp | 6 +++++- csrc/deep_ep.hpp | 3 +++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp index ed36325..db32686 100644 --- a/csrc/deep_ep.cpp +++ b/csrc/deep_ep.cpp @@ -57,6 +57,10 @@ Buffer::Buffer(int rank, // Get device info cudaDeviceProp device_prop = {}; CUDA_CHECK(cudaGetDeviceProperties(&device_prop, device_id)); +#ifdef USE_ROCM + sscanf(device_prop.gcnArchName, "gfx%d", &gfx); + EP_HOST_ASSERT(gfx >= 942); +#endif num_device_sms = device_prop.multiProcessorCount; // Number of per-channel bytes cannot be large @@ -1489,7 +1493,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, // Allocate packed tensors #ifdef USE_ROCM auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, - x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fnuz : torch::kBFloat16)); + x.options().dtype(use_fp8 ? (gfx == 942 ? torch::kFloat8_e4m3fnuz : torch::kFloat8_e4m3fn) : torch::kBFloat16)); #else auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, x.options().dtype(use_fp8 ? 
torch::kFloat8_e4m3fn : torch::kBFloat16)); diff --git a/csrc/deep_ep.hpp b/csrc/deep_ep.hpp index 3b7652f..5ec4635 100644 --- a/csrc/deep_ep.hpp +++ b/csrc/deep_ep.hpp @@ -47,6 +47,9 @@ struct Buffer { // Device info and communication int device_id; +#ifdef USE_ROCM + int gfx; +#endif int num_device_sms; int rank, rdma_rank, nvl_rank; int num_ranks, num_rdma_ranks, num_nvl_ranks; From 751e87f1158159fa8f6784f1b606f7a1048373f0 Mon Sep 17 00:00:00 2001 From: RichardChamberlain1 Date: Mon, 10 Nov 2025 12:36:53 +0000 Subject: [PATCH 07/22] Address review comments --- csrc/deep_ep.cpp | 1 + csrc/deep_ep.hpp | 2 -- csrc/kernels/intranode.cu | 3 --- csrc/kernels/runtime.cu | 9 --------- setup.py | 4 ++++ 5 files changed, 5 insertions(+), 14 deletions(-) diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp index db32686..41f465a 100644 --- a/csrc/deep_ep.cpp +++ b/csrc/deep_ep.cpp @@ -752,6 +752,7 @@ std::tuple, std::optional>({bias_0, bias_1}); void* bias_ptrs[2] = {nullptr, nullptr}; for (int i = 0; i < 2; ++i) diff --git a/csrc/deep_ep.hpp b/csrc/deep_ep.hpp index 5ec4635..f758d68 100644 --- a/csrc/deep_ep.hpp +++ b/csrc/deep_ep.hpp @@ -166,8 +166,6 @@ struct Buffer { std::tuple, std::optional> intranode_combine( const torch::Tensor& x, const std::optional& topk_weights, - //const std::optional& bias_0, - //const std::optional& bias_1, const torch::Tensor& src_idx, const torch::Tensor& rank_prefix_matrix, const torch::Tensor& channel_prefix_matrix, diff --git a/csrc/kernels/intranode.cu b/csrc/kernels/intranode.cu index 3ef7e35..7b36cc2 100644 --- a/csrc/kernels/intranode.cu +++ b/csrc/kernels/intranode.cu @@ -37,9 +37,6 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, // - `per_expert_buffer[rank][i, j]` means the number of tokens from rank i to local expert j int num_experts_per_rank = num_experts / kNumRanks; if (thread_id < kNumRanks) { - //#pragma unroll - //for (int i = 0; i < kNumRanks; ++ i) - // per_rank_buffer[rank * kNumRanks + i] 
= num_tokens_per_rank[i]; per_rank_buffer[rank * kNumRanks + thread_id] = num_tokens_per_rank[thread_id]; #pragma unroll for (int i = 0; i < num_experts_per_rank; ++ i) diff --git a/csrc/kernels/runtime.cu b/csrc/kernels/runtime.cu index 7fa6031..606e433 100644 --- a/csrc/kernels/runtime.cu +++ b/csrc/kernels/runtime.cu @@ -16,15 +16,6 @@ __global__ void barrier(int** task_fifo_ptrs, int head, int rank) { barrier_device(task_fifo_ptrs, head, rank); } -/*void barrier(int** task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream) { -#define BARRIER_LAUNCH_CASE(ranks) \ - LAUNCH_KERNEL(&cfg, barrier, task_fifo_ptrs, head, rank); \ - break - - SETUP_LAUNCH_CONFIG(1, kWarpSize, stream); - SWITCH_RANKS(BARRIER_LAUNCH_CASE); -#undef BARRIER_LAUNCH_CASE -}*/ void barrier(int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream, int head = 0) { #define BARRIER_LAUNCH_CASE(ranks) \ diff --git a/setup.py b/setup.py index 6645bd6..3e34858 100644 --- a/setup.py +++ b/setup.py @@ -197,6 +197,10 @@ def get_nvshmem_host_lib_name(base_dir): # Bits of `topk_idx.dtype`, choices are 32 and 64 if "TOPK_IDX_BITS" in os.environ: topk_idx_bits = int(os.environ['TOPK_IDX_BITS']) + assert topk_idx_bits in (32, 64), ( + f"Invalid TOPK_IDX_BITS={topk_idx_bits}. " + "Must be either 32 or 64." + ) cxx_flags.append(f'-DTOPK_IDX_BITS={topk_idx_bits}') nvcc_flags.append(f'-DTOPK_IDX_BITS={topk_idx_bits}') From 57fa1186c3d7d69151053af37097cefc540bf206 Mon Sep 17 00:00:00 2001 From: RichardChamberlain1 Date: Mon, 10 Nov 2025 12:45:57 +0000 Subject: [PATCH 08/22] Update utils.cuh Removed unused definition. 
--- csrc/kernels/utils.cuh | 4 ---- 1 file changed, 4 deletions(-) diff --git a/csrc/kernels/utils.cuh b/csrc/kernels/utils.cuh index 8a43e55..037eb69 100644 --- a/csrc/kernels/utils.cuh +++ b/csrc/kernels/utils.cuh @@ -2,10 +2,6 @@ #include "exception.cuh" -#ifdef USE_ROCM -#define syncthreads() __syncthreads() -#endif - #define UNROLLED_WARP_COPY(UNROLL_FACTOR, LANE_ID, N, DST, SRC, LD_FUNC, ST_FUNC) \ { \ constexpr int kLoopStride = kWarpSize * (UNROLL_FACTOR); \ From 4c9c51e7447c504287bc62ef2a228429ac7d3f70 Mon Sep 17 00:00:00 2001 From: Richard Chamberlain Date: Thu, 13 Nov 2025 02:45:23 -0800 Subject: [PATCH 09/22] Removed broken buffer cleanup code --- csrc/deep_ep.cpp | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp index 24f11ca..e151153 100644 --- a/csrc/deep_ep.cpp +++ b/csrc/deep_ep.cpp @@ -131,28 +131,8 @@ Buffer::~Buffer() noexcept(false) { printf("WARNING: destroy() was not called before DeepEP buffer destruction, which can leak resources.\n"); fflush(stdout); } - - // Free NVSHMEM - if (num_rdma_bytes > 0) { - CUDA_CHECK(cudaDeviceSynchronize()); - internode::barrier(); - internode::free(rdma_buffer_ptr); - internode::finalize(); - } - - // Free cuBLAS handle, workspace and MoE counter - CUDA_CHECK(cudaFree(workspace)); - CUDA_CHECK(cudaFree(dispatch_global_atomic_counter)); - CUDA_CHECK(cudaFree(combine_global_atomic_counter)); - CUDA_CHECK(cudaFreeHost(const_cast(moe_recv_counter))); - - // Free chunked mode staffs - CUDA_CHECK(cudaFreeHost(const_cast(moe_recv_expert_counter))); } -void Buffer::move_fifo_slots(int num_slots) { - head = (head + num_ranks * num_slots) % NUM_MAX_FIFO_SLOTS; -} bool Buffer::is_available() const { return available; From dc41ca08393c1bb72d3cf7f0314e6b0a0fb9e957 Mon Sep 17 00:00:00 2001 From: RichardChamberlain1 Date: Wed, 19 Nov 2025 18:28:59 +0000 Subject: [PATCH 10/22] Update shmem_wrapper.cuh --- csrc/kernels/shmem_wrapper.cuh | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/csrc/kernels/shmem_wrapper.cuh b/csrc/kernels/shmem_wrapper.cuh index 1e90717..275798e 100644 --- a/csrc/kernels/shmem_wrapper.cuh +++ b/csrc/kernels/shmem_wrapper.cuh @@ -65,7 +65,7 @@ static inline const auto &shmem_ibgda_amo_nonfetch_add = #if !defined(ROCM_DISABLE_CTX) using shmem_ctx_t = rocshmem::rocshmem_ctx_t; static inline const auto &shmem_wg_ctx_create = [] __device__(rocshmem::rocshmem_ctx_t *ctx) { - return rocshmem::rocshmem_wg_ctx_create(0, ctx); + return rocshmem::rocshmem_wg_ctx_create(ctx); }; static inline const auto &shmem_wg_ctx_destroy = rocshmem::rocshmem_wg_ctx_destroy; From 5865024795e201a3e25039e4710f67d037911da9 Mon Sep 17 00:00:00 2001 From: RichardChamberlain1 Date: Thu, 20 Nov 2025 16:36:26 +0000 Subject: [PATCH 11/22] Update shmem_wrapper.cuh --- csrc/kernels/shmem_wrapper.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/kernels/shmem_wrapper.cuh b/csrc/kernels/shmem_wrapper.cuh index 275798e..3a6d653 100644 --- a/csrc/kernels/shmem_wrapper.cuh +++ b/csrc/kernels/shmem_wrapper.cuh @@ -65,7 +65,7 @@ static inline const auto &shmem_ibgda_amo_nonfetch_add = #if !defined(ROCM_DISABLE_CTX) using shmem_ctx_t = rocshmem::rocshmem_ctx_t; static inline const auto &shmem_wg_ctx_create = [] __device__(rocshmem::rocshmem_ctx_t *ctx) { - return rocshmem::rocshmem_wg_ctx_create(ctx); + return rocshmem::rocshmem_wg_ctx_create(0,ctx); }; static inline const auto &shmem_wg_ctx_destroy = rocshmem::rocshmem_wg_ctx_destroy; From 3ec2aab580bf23c704fd47ef2f334a4b64330e41 Mon Sep 17 00:00:00 2001 From: Richard Chamberlain Date: Thu, 6 Nov 2025 09:09:38 -0800 Subject: [PATCH 12/22] Merging code from IFU branch. 
--- README.md | 6 +- csrc/config.hpp | 4 +- csrc/deep_ep.cpp | 1162 ++++++++++++++++++++++---------- csrc/deep_ep.hpp | 240 +++++-- csrc/kernels/api.cuh | 346 +++++++--- csrc/kernels/configs.cuh | 15 + csrc/kernels/exception.cuh | 2 +- csrc/kernels/internode.cu | 160 +++-- csrc/kernels/internode_ll.cu | 153 ++++- csrc/kernels/intranode.cu | 108 ++- csrc/kernels/launch.cuh | 12 + csrc/kernels/runtime.cu | 21 +- csrc/kernels/shmem_wrapper.cuh | 2 +- csrc/kernels/utils.cuh | 73 +- deep_ep/__init__.py | 2 +- deep_ep/buffer.py | 327 ++++++--- deep_ep/utils.py | 47 +- setup.py | 99 ++- tests/test_internode.py | 252 +++++-- tests/test_intranode.py | 184 +++-- tests/test_low_latency.py | 354 +++++++--- tests/utils.py | 106 ++- third-party/README.md | 6 - 23 files changed, 2688 insertions(+), 993 deletions(-) diff --git a/README.md b/README.md index 82bc7ad..4dd1e64 100644 --- a/README.md +++ b/README.md @@ -31,9 +31,7 @@ DeepEP (AMD version) depends on [rocSHMEM](https://github.com/ROCm/rocSHMEM). 
Pl git clone https://github.com/ROCm/DeepEP cd DeepEP - -# To use DeepEP with MPI, please proceed with these commands -# Export OMPI dir in the next command (e.g., it's $BUILD_DIR/ompi in third-party/README.md) +# export OMPI dir in the next command (e.g., it's $BUILD_DIR/ompi in third-party/README.md) export OMPI_DIR= python3 setup.py --variant rocm build develop --user @@ -46,8 +44,6 @@ python3 setup.py --variant rocm --disable-mpi build develop --user # according to your own cluster settings, and launch into multiple nodes python3 tests/test_intranode.py python3 tests/test_internode.py -# Set the required ROCSHMEM heap size (for example, for DeepSeek models) -export ROCSHMEM_HEAP_SIZE=2147483648 python3 tests/test_low_latency.py ``` diff --git a/csrc/config.hpp b/csrc/config.hpp index 83c60fe..9acf674 100644 --- a/csrc/config.hpp +++ b/csrc/config.hpp @@ -7,13 +7,13 @@ namespace deep_ep { template -dtype_t cell_div(dtype_t a, dtype_t b) { +dtype_t ceil_div(dtype_t a, dtype_t b) { return (a + b - 1) / b; } template dtype_t align(dtype_t a, dtype_t b) { - return cell_div(a, b) * b; + return ceil_div(a, b) * b; } struct Config { diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp index 22eb144..57ffd33 100644 --- a/csrc/deep_ep.cpp +++ b/csrc/deep_ep.cpp @@ -11,27 +11,36 @@ #include "kernels/api.cuh" #include "kernels/configs.cuh" -int get_env_with_default_value(const std::string& env_path, const std::string& default_value) { - const char* value = std::getenv(env_path.c_str()); - std::string value_str = (value != nullptr) ? 
std::string(value) : default_value; - return std::stoi(value_str); -} - namespace deep_ep { -Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode): - rank(rank), num_ranks(num_ranks), - num_nvl_bytes(num_nvl_bytes), num_rdma_bytes(num_rdma_bytes), - low_latency_mode(low_latency_mode), - comm_stream(at::cuda::getStreamFromPool(true)) { - // Task fifo memory - int64_t fifo_bytes = sizeof(int) * NUM_MAX_FIFO_SLOTS; - int64_t buffer_ptr_bytes = sizeof(void*) * NUM_MAX_NVL_PEERS; - int64_t task_ptr_bytes = sizeof(int*) * NUM_MAX_NVL_PEERS; +Buffer::Buffer(int rank, + int num_ranks, + int64_t num_nvl_bytes, + int64_t num_rdma_bytes, + bool low_latency_mode, + bool explicitly_destroy, + bool enable_shrink) + : rank(rank), + num_ranks(num_ranks), + num_nvl_bytes(num_nvl_bytes), + num_rdma_bytes(num_rdma_bytes), + enable_shrink(enable_shrink), + low_latency_mode(low_latency_mode), + explicitly_destroy(explicitly_destroy), + comm_stream(at::cuda::getStreamFromPool(true)) { + // Metadata memory + int64_t barrier_signal_bytes = NUM_MAX_FIFO_SLOTS * sizeof(int); + int64_t buffer_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(void*); + int64_t barrier_signal_ptr_bytes = NUM_MAX_NVL_PEERS * sizeof(int*); // Common checks - EP_HOST_ASSERT(num_nvl_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and (num_nvl_bytes <= std::numeric_limits::max() or num_rdma_bytes == 0)); - EP_HOST_ASSERT(num_rdma_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and (low_latency_mode or num_rdma_bytes <= std::numeric_limits::max())); + EP_STATIC_ASSERT(NUM_BUFFER_ALIGNMENT_BYTES % sizeof(int4) == 0, "Invalid alignment"); + EP_HOST_ASSERT(num_nvl_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and + (num_nvl_bytes <= std::numeric_limits::max() or num_rdma_bytes == 0)); + EP_HOST_ASSERT(num_rdma_bytes % NUM_BUFFER_ALIGNMENT_BYTES == 0 and + (low_latency_mode or num_rdma_bytes <= std::numeric_limits::max())); + EP_HOST_ASSERT(num_nvl_bytes / sizeof(int4) < 
std::numeric_limits::max()); + EP_HOST_ASSERT(num_rdma_bytes / sizeof(int4) < std::numeric_limits::max()); EP_HOST_ASSERT(0 <= rank and rank < num_ranks and (num_ranks <= NUM_MAX_NVL_PEERS * NUM_MAX_RDMA_PEERS or low_latency_mode)); EP_HOST_ASSERT(num_ranks < NUM_MAX_NVL_PEERS or num_ranks % NUM_MAX_NVL_PEERS == 0); if (num_rdma_bytes > 0) @@ -41,32 +50,36 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_ CUDA_CHECK(cudaGetDevice(&device_id)); rdma_rank = rank / NUM_MAX_NVL_PEERS, nvl_rank = rank % NUM_MAX_NVL_PEERS; num_rdma_ranks = std::max(1, num_ranks / NUM_MAX_NVL_PEERS), num_nvl_ranks = std::min(num_ranks, NUM_MAX_NVL_PEERS); +#ifdef DISABLE_NVSHMEM + EP_HOST_ASSERT(num_rdma_ranks == 1 and not low_latency_mode and "NVSHMEM is disabled during compilation"); +#endif // Get device info cudaDeviceProp device_prop = {}; CUDA_CHECK(cudaGetDeviceProperties(&device_prop, device_id)); -#ifdef USE_ROCM - sscanf(device_prop.gcnArchName, "gfx%d", &gfx); - EP_HOST_ASSERT(gfx >= 942); -#endif + num_device_sms = device_prop.multiProcessorCount; + + // Number of per-channel bytes cannot be large + EP_HOST_ASSERT(ceil_div(num_nvl_bytes, num_device_sms / 2) < std::numeric_limits::max()); + EP_HOST_ASSERT(ceil_div(num_rdma_bytes, num_device_sms / 2) < std::numeric_limits::max()); if (num_nvl_bytes > 0) { - // Local IPC: alloc local memory and set local IPC handle + // Local IPC: alloc local memory and set local IPC handles #ifdef USE_ROCM - CUDA_CHECK(hipExtMallocWithFlags(&buffer_ptrs[nvl_rank], num_nvl_bytes + fifo_bytes + buffer_ptr_bytes + task_ptr_bytes, hipDeviceMallocUncached)); + CUDA_CHECK(hipExtMallocWithFlags(&buffer_ptrs[nvl_rank], num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes + barrier_signal_ptr_bytes, hipDeviceMallocUncached)); #else - CUDA_CHECK(cudaMalloc(&buffer_ptrs[nvl_rank], num_nvl_bytes + fifo_bytes + buffer_ptr_bytes + task_ptr_bytes)); + CUDA_CHECK(cudaMalloc(&buffer_ptrs[nvl_rank], num_nvl_bytes + 
barrier_signal_bytes + buffer_ptr_bytes + barrier_signal_ptr_bytes)); #endif CUDA_CHECK(cudaIpcGetMemHandle(&ipc_handles[nvl_rank], buffer_ptrs[nvl_rank])); - buffer_ptrs_gpu = reinterpret_cast(reinterpret_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + fifo_bytes); + buffer_ptrs_gpu = reinterpret_cast(reinterpret_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + barrier_signal_bytes); - // Set task fifo - EP_HOST_ASSERT(NUM_MAX_FIFO_SLOTS % num_nvl_ranks == 0); - task_fifo_ptrs[nvl_rank] = reinterpret_cast(reinterpret_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes); - task_fifo_ptrs_gpu = reinterpret_cast(reinterpret_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + fifo_bytes + buffer_ptr_bytes); + // Set barrier signals + barrier_signal_ptrs[nvl_rank] = reinterpret_cast(reinterpret_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes); + barrier_signal_ptrs_gpu = + reinterpret_cast(reinterpret_cast(buffer_ptrs[nvl_rank]) + num_nvl_bytes + barrier_signal_bytes + buffer_ptr_bytes); // No need to synchronize, will do a full device sync during `sync` - CUDA_CHECK(cudaMemsetAsync(task_fifo_ptrs[nvl_rank], 0, fifo_bytes, comm_stream)); + CUDA_CHECK(cudaMemsetAsync(barrier_signal_ptrs[nvl_rank], 0, barrier_signal_bytes, comm_stream)); } // Create 32 MiB workspace @@ -96,7 +109,7 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_ // MoE expert-level counter CUDA_CHECK(cudaMallocHost(&moe_recv_expert_counter, sizeof(int) * NUM_MAX_LOCAL_EXPERTS, cudaHostAllocMapped)); CUDA_CHECK(cudaHostGetDevicePointer(reinterpret_cast(&moe_recv_expert_counter_mapped), const_cast(moe_recv_expert_counter), 0)); - for (int i = 0; i < NUM_MAX_LOCAL_EXPERTS; ++ i) + for (int i = 0; i < NUM_MAX_LOCAL_EXPERTS; ++i) moe_recv_expert_counter[i] = -1; // MoE RDMA-level counter @@ -108,23 +121,11 @@ Buffer::Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_ } Buffer::~Buffer() noexcept(false) { - // Synchronize - CUDA_CHECK(cudaDeviceSynchronize()); - - if 
(num_nvl_bytes > 0) { - // Barrier - intranode::barrier(task_fifo_ptrs_gpu, head, nvl_rank, num_nvl_ranks, comm_stream); - move_fifo_slots(); - CUDA_CHECK(cudaDeviceSynchronize()); - - // Close remote IPC - if (is_available()) { - for (int i = 0; i < num_nvl_ranks; ++ i) if (i != nvl_rank) - CUDA_CHECK(cudaIpcCloseMemHandle(buffer_ptrs[i])); - } - - // Free local buffer and error flag - CUDA_CHECK(cudaFree(buffer_ptrs[nvl_rank])); + if (not explicitly_destroy) { + destroy(); + } else if (not destroyed) { + printf("WARNING: destroy() was not called before DeepEP buffer destruction, which can leak resources.\n"); + fflush(stdout); } // Free NVSHMEM @@ -178,21 +179,81 @@ pybind11::bytearray Buffer::get_local_ipc_handle() const { } pybind11::bytearray Buffer::get_local_nvshmem_unique_id() const { +#ifndef DISABLE_NVSHMEM EP_HOST_ASSERT(rdma_rank == 0 and "Only RDMA rank 0 can get NVSHMEM unique ID"); auto unique_id = internode::get_unique_id(); return {reinterpret_cast(unique_id.data()), unique_id.size()}; +#else + EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation"); +#endif } torch::Tensor Buffer::get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const { torch::ScalarType casted_dtype = torch::python::detail::py_object_to_dtype(dtype); auto element_bytes = static_cast(elementSize(casted_dtype)); - auto base_ptr = reinterpret_cast(use_rdma_buffer ? rdma_buffer_ptr : buffer_ptrs[nvl_rank]) + offset; + auto base_ptr = static_cast(use_rdma_buffer ? rdma_buffer_ptr : buffer_ptrs[nvl_rank]) + offset; auto num_bytes = use_rdma_buffer ? 
num_rdma_bytes : num_nvl_bytes; return torch::from_blob(base_ptr, num_bytes / element_bytes, torch::TensorOptions().dtype(casted_dtype).device(at::kCUDA)); } -void Buffer::sync(const std::vector &device_ids, - const std::vector> &all_gathered_handles, +torch::Stream Buffer::get_comm_stream() const { + return comm_stream; +} + +void Buffer::move_fifo_slots(int num_slots) { + head = (head + num_ranks * num_slots) % NUM_MAX_FIFO_SLOTS; +} + +void Buffer::destroy() { + EP_HOST_ASSERT(not destroyed); + + // Synchronize + CUDA_CHECK(cudaDeviceSynchronize()); + + if (num_nvl_bytes > 0) { + // Barrier + intranode::barrier(barrier_signal_ptrs_gpu, nvl_rank, num_nvl_ranks, comm_stream, head); + move_fifo_slots(); + CUDA_CHECK(cudaDeviceSynchronize()); + + // Close remote IPC + if (is_available()) { + for (int i = 0; i < num_nvl_ranks; ++i) + if (i != nvl_rank) + CUDA_CHECK(cudaIpcCloseMemHandle(buffer_ptrs[i])); + } + + // Free local buffer and error flag + CUDA_CHECK(cudaFree(buffer_ptrs[nvl_rank])); + } + + // Free NVSHMEM +#ifndef DISABLE_NVSHMEM + if (is_available() and num_rdma_bytes > 0) { + CUDA_CHECK(cudaDeviceSynchronize()); + internode::barrier(); + internode::free(rdma_buffer_ptr); + if (enable_shrink) { + internode::free(mask_buffer_ptr); + internode::free(sync_buffer_ptr); + } + internode::finalize(); + } +#endif + + // Free workspace and MoE counter + CUDA_CHECK(cudaFree(workspace)); + CUDA_CHECK(cudaFreeHost(const_cast(moe_recv_counter))); + + // Free chunked mode staffs + CUDA_CHECK(cudaFreeHost(const_cast(moe_recv_expert_counter))); + + destroyed = true; + available = false; +} + +void Buffer::sync(const std::vector& device_ids, + const std::vector>& all_gathered_handles, const std::optional& root_unique_id_opt) { EP_HOST_ASSERT(not is_available()); @@ -200,26 +261,27 @@ void Buffer::sync(const std::vector &device_ids, if (num_nvl_bytes > 0) { EP_HOST_ASSERT(num_ranks == device_ids.size()); EP_HOST_ASSERT(device_ids.size() == all_gathered_handles.size()); - 
for (int i = 0, offset = rdma_rank * num_nvl_ranks; i < num_nvl_ranks; ++ i) { + for (int i = 0, offset = rdma_rank * num_nvl_ranks; i < num_nvl_ranks; ++i) { EP_HOST_ASSERT(all_gathered_handles[offset + i].has_value()); auto handle_str = std::string(all_gathered_handles[offset + i].value()); EP_HOST_ASSERT(handle_str.size() == CUDA_IPC_HANDLE_SIZE); if (offset + i != rank) { std::memcpy(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE); CUDA_CHECK(cudaIpcOpenMemHandle(&buffer_ptrs[i], ipc_handles[i], cudaIpcMemLazyEnablePeerAccess)); - task_fifo_ptrs[i] = reinterpret_cast(reinterpret_cast(buffer_ptrs[i]) + num_nvl_bytes); + barrier_signal_ptrs[i] = reinterpret_cast(static_cast(buffer_ptrs[i]) + num_nvl_bytes); } else { EP_HOST_ASSERT(std::memcmp(ipc_handles[i].reserved, handle_str.c_str(), CUDA_IPC_HANDLE_SIZE) == 0); } } - // Copy all buffer and task pointers to GPU + // Copy all buffer and barrier signal pointers to GPU CUDA_CHECK(cudaMemcpy(buffer_ptrs_gpu, buffer_ptrs, sizeof(void*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice)); - CUDA_CHECK(cudaMemcpy(task_fifo_ptrs_gpu, task_fifo_ptrs, sizeof(int*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(barrier_signal_ptrs_gpu, barrier_signal_ptrs, sizeof(int*) * NUM_MAX_NVL_PEERS, cudaMemcpyHostToDevice)); CUDA_CHECK(cudaDeviceSynchronize()); } // Sync NVSHMEM handles and allocate memory +#ifndef DISABLE_NVSHMEM if (num_rdma_bytes > 0) { // Initialize NVSHMEM EP_HOST_ASSERT(root_unique_id_opt.has_value()); @@ -230,20 +292,36 @@ void Buffer::sync(const std::vector &device_ids, auto num_nvshmem_ranks = low_latency_mode ? 
num_ranks : num_rdma_ranks; EP_HOST_ASSERT(nvshmem_rank == internode::init(root_unique_id, nvshmem_rank, num_nvshmem_ranks, low_latency_mode)); internode::barrier(); + // Allocate rdma_buffer_ptr = internode::alloc(num_rdma_bytes, NUM_BUFFER_ALIGNMENT_BYTES); + // Clean buffer (mainly for low-latency mode) CUDA_CHECK(cudaMemset(rdma_buffer_ptr, 0, num_rdma_bytes)); + + // Allocate and clean shrink buffer + if (enable_shrink) { + int num_mask_buffer_bytes = num_ranks * sizeof(int); + int num_sync_buffer_bytes = num_ranks * sizeof(int); + mask_buffer_ptr = reinterpret_cast(internode::alloc(num_mask_buffer_bytes, NUM_BUFFER_ALIGNMENT_BYTES)); + sync_buffer_ptr = reinterpret_cast(internode::alloc(num_sync_buffer_bytes, NUM_BUFFER_ALIGNMENT_BYTES)); + CUDA_CHECK(cudaMemset(mask_buffer_ptr, 0, num_mask_buffer_bytes)); + CUDA_CHECK(cudaMemset(sync_buffer_ptr, 0, num_sync_buffer_bytes)); + } + // Barrier internode::barrier(); CUDA_CHECK(cudaDeviceSynchronize()); } +#endif + + // Ready to use available = true; } std::tuple, torch::Tensor, torch::Tensor, std::optional> -Buffer::get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts, - std::optional& previous_event, bool async, bool allocate_on_comm_stream) { +Buffer::get_dispatch_layout( + const torch::Tensor& topk_idx, int num_experts, std::optional& previous_event, bool async, bool allocate_on_comm_stream) { EP_HOST_ASSERT(topk_idx.dim() == 2); EP_HOST_ASSERT(topk_idx.is_contiguous()); EP_HOST_ASSERT(num_experts > 0); @@ -271,24 +349,27 @@ Buffer::get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts, if (is_internode_available()) num_tokens_per_rdma_rank = torch::empty({num_rdma_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); - internode::get_dispatch_layout(topk_idx.data_ptr(), - num_tokens_per_rank.data_ptr(), - num_tokens_per_rdma_rank.has_value() ? 
num_tokens_per_rdma_rank.value().data_ptr() : nullptr, - num_tokens_per_expert.data_ptr(), - is_token_in_rank.data_ptr(), - num_tokens, num_topk, num_ranks, num_experts, - comm_stream); + layout::get_dispatch_layout(topk_idx.data_ptr(), + num_tokens_per_rank.data_ptr(), + num_tokens_per_rdma_rank.has_value() ? num_tokens_per_rdma_rank.value().data_ptr() : nullptr, + num_tokens_per_expert.data_ptr(), + is_token_in_rank.data_ptr(), + num_tokens, + num_topk, + num_ranks, + num_experts, + comm_stream); // Wait streams std::optional event; if (async) { event = EventHandle(comm_stream); - for (auto& t: {topk_idx, num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank}) { + for (auto& t : {topk_idx, num_tokens_per_rank, num_tokens_per_expert, is_token_in_rank}) { t.record_stream(comm_stream); if (allocate_on_comm_stream) t.record_stream(compute_stream); } - for (auto& to: {num_tokens_per_rdma_rank}) { + for (auto& to : {num_tokens_per_rdma_rank}) { to.has_value() ? to->record_stream(comm_stream) : void(); if (allocate_on_comm_stream) to.has_value() ? 
to->record_stream(compute_stream) : void(); @@ -304,12 +385,33 @@ Buffer::get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts, return {num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, is_token_in_rank, event}; } -std::tuple, std::optional, std::optional, std::vector, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::optional> -Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional& x_scales, - const std::optional& topk_idx, const std::optional& topk_weights, - const std::optional& num_tokens_per_rank, const torch::Tensor& is_token_in_rank, const std::optional& num_tokens_per_expert, - int cached_num_recv_tokens, const std::optional& cached_rank_prefix_matrix, const std::optional& cached_channel_prefix_matrix, - int expert_alignment, const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream) { +std::tuple, + std::optional, + std::optional, + std::vector, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + std::optional> +Buffer::intranode_dispatch(const torch::Tensor& x, + const std::optional& x_scales, + const std::optional& topk_idx, + const std::optional& topk_weights, + const std::optional& num_tokens_per_rank, + const torch::Tensor& is_token_in_rank, + const std::optional& num_tokens_per_expert, + int cached_num_recv_tokens, + const std::optional& cached_rank_prefix_matrix, + const std::optional& cached_channel_prefix_matrix, + int expert_alignment, + int num_worst_tokens, + const Config& config, + std::optional& previous_event, + bool async, + bool allocate_on_comm_stream) { bool cached_mode = cached_rank_prefix_matrix.has_value(); // One channel use two blocks, even-numbered blocks for sending, odd-numbered blocks for receiving. 
@@ -356,7 +458,7 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optionalsize(0) and num_tokens == topk_weights->size(0)); EP_HOST_ASSERT(num_topk == topk_weights->size(1)); EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32); - topk_idx_ptr = topk_idx->data_ptr(); + topk_idx_ptr = topk_idx->data_ptr(); topk_weights_ptr = topk_weights->data_ptr(); } // FP8 scales checks float* x_scales_ptr = nullptr; - int num_scales = 0; + int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0; if (x_scales.has_value()) { EP_HOST_ASSERT(x.element_size() == 1); - EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32); - EP_HOST_ASSERT(x_scales->dim() > 0 and x_scales->dim() < 3 and x_scales->is_contiguous()); + EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32 or x_scales->scalar_type() == torch::kInt); + EP_HOST_ASSERT(x_scales->dim() == 2); EP_HOST_ASSERT(x_scales->size(0) == num_tokens); num_scales = x_scales->dim() == 1 ? 1 : static_cast(x_scales->size(1)); - x_scales_ptr = x_scales->data_ptr(); + x_scales_ptr = static_cast(x_scales->data_ptr()); + scale_token_stride = static_cast(x_scales->stride(0)); + scale_hidden_stride = static_cast(x_scales->stride(1)); } // Allocate all tensors on comm stream if set @@ -406,16 +510,15 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional(), num_memset_int, - buffer_ptrs_gpu, task_fifo_ptrs_gpu, head, rank, num_ranks, - comm_stream); + intranode::cached_notify_dispatch( + rank_prefix_matrix.data_ptr(), num_memset_int, buffer_ptrs_gpu, barrier_signal_ptrs_gpu, rank, num_ranks, comm_stream, head); move_fifo_slots(2); } else { rank_prefix_matrix = torch::empty({num_ranks, num_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); @@ -427,92 +530,153 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optionaldata_ptr(), moe_recv_counter_mapped, num_ranks, - num_tokens_per_expert->data_ptr(), moe_recv_expert_counter_mapped, num_experts, - num_tokens, 
is_token_in_rank.data_ptr(), channel_prefix_matrix.data_ptr(), + intranode::notify_dispatch(num_tokens_per_rank->data_ptr(), + moe_recv_counter_mapped, + num_ranks, + num_tokens_per_expert->data_ptr(), + moe_recv_expert_counter_mapped, + num_experts, + num_tokens, + is_token_in_rank.data_ptr(), + channel_prefix_matrix.data_ptr(), rank_prefix_matrix.data_ptr(), - num_memset_int, expert_alignment, - buffer_ptrs_gpu, task_fifo_ptrs_gpu, head, rank, - comm_stream, num_channels); + num_memset_int, + expert_alignment, + buffer_ptrs_gpu, + barrier_signal_ptrs_gpu, + rank, + comm_stream, + num_channels, + head); move_fifo_slots(3); - // Synchronize total received tokens and tokens per expert - auto start_time = std::chrono::high_resolution_clock::now(); - while (true) { - // Read total count - num_recv_tokens = static_cast(*moe_recv_counter); - - // Read per-expert count - bool ready = (num_recv_tokens >= 0); - for (int i = 0; i < num_local_experts and ready; ++i) - ready &= moe_recv_expert_counter[i] >= 0; - - if (ready) - break; - - // Timeout check - if (std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - start_time).count() > NUM_CPU_TIMEOUT_SECS) - throw std::runtime_error("DeepEP error: CPU recv timeout"); + if (num_worst_tokens > 0) { + // No CPU sync, just allocate the worst case + num_recv_tokens = num_worst_tokens; + + // Must be forward with top-k stuffs + EP_HOST_ASSERT(topk_idx.has_value()); + EP_HOST_ASSERT(topk_weights.has_value()); + } else { + // Synchronize total received tokens and tokens per expert + auto start_time = std::chrono::high_resolution_clock::now(); + while (true) { + // Read total count + num_recv_tokens = static_cast(*moe_recv_counter); + + // Read per-expert count + bool ready = (num_recv_tokens >= 0); + for (int i = 0; i < num_local_experts and ready; ++i) + { + ready &= moe_recv_expert_counter[i] >= 0; + } + + if (ready) + break; + std::this_thread::sleep_for(std::chrono::milliseconds(5000)); + // Timeout check + if 
(std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - start_time).count() > + NUM_CPU_TIMEOUT_SECS) + throw std::runtime_error("DeepEP error: CPU recv timeout"); + } + num_recv_tokens_per_expert_list = std::vector(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts); } - num_recv_tokens_per_expert_list = std::vector(moe_recv_expert_counter, moe_recv_expert_counter + num_local_experts); } // Allocate new tensors auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options()); auto recv_src_idx = torch::empty({num_recv_tokens}, dtype(torch::kInt32).device(torch::kCUDA)); - auto recv_topk_idx = std::optional(), recv_topk_weights = std::optional(), recv_x_scales = std::optional(); + auto recv_topk_idx = std::optional(), recv_topk_weights = std::optional(), + recv_x_scales = std::optional(); auto recv_channel_prefix_matrix = torch::empty({num_ranks, num_channels}, dtype(torch::kInt32).device(torch::kCUDA)); auto send_head = torch::empty({num_tokens, num_ranks}, dtype(torch::kInt32).device(torch::kCUDA)); // Assign pointers - int64_t* recv_topk_idx_ptr = nullptr; + topk_idx_t* recv_topk_idx_ptr = nullptr; float* recv_topk_weights_ptr = nullptr; float* recv_x_scales_ptr = nullptr; if (topk_idx.has_value()) { recv_topk_idx = torch::empty({num_recv_tokens, num_topk}, topk_idx->options()); recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options()); - recv_topk_idx_ptr = recv_topk_idx->data_ptr(); + recv_topk_idx_ptr = recv_topk_idx->data_ptr(); recv_topk_weights_ptr = recv_topk_weights->data_ptr(); } if (x_scales.has_value()) { - recv_x_scales = x_scales->dim() == 1 ? - torch::empty({num_recv_tokens}, x_scales->options()) : - torch::empty({num_recv_tokens, num_scales}, x_scales->options()); - recv_x_scales_ptr = recv_x_scales->data_ptr(); + recv_x_scales = x_scales->dim() == 1 ? 
torch::empty({num_recv_tokens}, x_scales->options()) + : torch::empty({num_recv_tokens, num_scales}, x_scales->options()); + recv_x_scales_ptr = static_cast(recv_x_scales->data_ptr()); } // Dispatch - EP_HOST_ASSERT(num_ranks * num_ranks * sizeof(int) + // Size prefix matrix - num_channels * num_ranks * sizeof(int) + // Channel start offset - num_channels * num_ranks * sizeof(int) + // Channel end offset - num_channels * num_ranks * sizeof(int) * 2 + // Queue head and tail - num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * hidden * recv_x.element_size() + // Data buffer - num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(int) + // Source index buffer - num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(int64_t) + // Top-k index buffer - num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(float) + // Top-k weight buffer - num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(float) * num_scales // FP8 scale buffer - <= num_nvl_bytes); - intranode::dispatch(recv_x.data_ptr(), recv_x_scales_ptr, recv_src_idx.data_ptr(), recv_topk_idx_ptr, recv_topk_weights_ptr, recv_channel_prefix_matrix.data_ptr(), + EP_HOST_ASSERT( + num_ranks * num_ranks * sizeof(int) + // Size prefix matrix + num_channels * num_ranks * sizeof(int) + // Channel start offset + num_channels * num_ranks * sizeof(int) + // Channel end offset + num_channels * num_ranks * sizeof(int) * 2 + // Queue head and tail + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * hidden * recv_x.element_size() + // Data buffer + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(int) + // Source index buffer + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(topk_idx_t) + // Top-k index buffer + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(float) + // Top-k weight buffer 
+ num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(float) * num_scales // FP8 scale buffer + <= num_nvl_bytes); + intranode::dispatch(recv_x.data_ptr(), + recv_x_scales_ptr, + recv_src_idx.data_ptr(), + recv_topk_idx_ptr, + recv_topk_weights_ptr, + recv_channel_prefix_matrix.data_ptr(), send_head.data_ptr(), - x.data_ptr(), x_scales_ptr, topk_idx_ptr, topk_weights_ptr, - is_token_in_rank.data_ptr(), channel_prefix_matrix.data_ptr(), - num_tokens, static_cast(hidden * recv_x.element_size() / sizeof(int4)), num_topk, num_experts, num_scales, - buffer_ptrs_gpu, rank, num_ranks, comm_stream, config.num_sms, - config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens); + x.data_ptr(), + x_scales_ptr, + topk_idx_ptr, + topk_weights_ptr, + is_token_in_rank.data_ptr(), + channel_prefix_matrix.data_ptr(), + num_tokens, + num_worst_tokens, + static_cast(hidden * recv_x.element_size() / sizeof(int4)), + num_topk, + num_experts, + num_scales, + scale_token_stride, + scale_hidden_stride, + buffer_ptrs_gpu, + rank, + num_ranks, + comm_stream, + config.num_sms, + config.num_max_nvl_chunked_send_tokens, + config.num_max_nvl_chunked_recv_tokens); // Wait streams std::optional event; if (async) { event = EventHandle(comm_stream); - for (auto& t: {x, is_token_in_rank, rank_prefix_matrix, channel_prefix_matrix, recv_x, recv_src_idx, recv_channel_prefix_matrix, send_head}) { + for (auto& t : {x, + is_token_in_rank, + rank_prefix_matrix, + channel_prefix_matrix, + recv_x, + recv_src_idx, + recv_channel_prefix_matrix, + send_head}) { t.record_stream(comm_stream); if (allocate_on_comm_stream) t.record_stream(compute_stream); } - for (auto& to: {x_scales, topk_idx, topk_weights, num_tokens_per_rank, num_tokens_per_expert, cached_channel_prefix_matrix, cached_rank_prefix_matrix, recv_topk_idx, recv_topk_weights, recv_x_scales}) { + for (auto& to : {x_scales, + topk_idx, + topk_weights, + num_tokens_per_rank, + num_tokens_per_expert, + 
cached_channel_prefix_matrix, + cached_rank_prefix_matrix, + recv_topk_idx, + recv_topk_weights, + recv_x_scales}) { to.has_value() ? to->record_stream(comm_stream) : void(); if (allocate_on_comm_stream) to.has_value() ? to->record_stream(compute_stream) : void(); @@ -526,18 +690,37 @@ Buffer::intranode_dispatch(const torch::Tensor& x, const std::optional, std::optional> -Buffer::intranode_combine(const torch::Tensor& x, const std::optional& topk_weights, - const torch::Tensor& src_idx, const torch::Tensor& rank_prefix_matrix, const torch::Tensor& channel_prefix_matrix, - const torch::Tensor& send_head, const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream) { +std::tuple, std::optional> Buffer::intranode_combine( + const torch::Tensor& x, + const std::optional& topk_weights, + const torch::Tensor& src_idx, + const torch::Tensor& rank_prefix_matrix, + const torch::Tensor& channel_prefix_matrix, + const torch::Tensor& send_head, + const Config& config, + std::optional& previous_event, + bool async, + bool allocate_on_comm_stream) { EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous()); EP_HOST_ASSERT(src_idx.dim() == 1 and src_idx.is_contiguous() and src_idx.scalar_type() == torch::kInt32); EP_HOST_ASSERT(send_head.dim() == 2 and send_head.is_contiguous() and send_head.scalar_type() == torch::kInt32); - EP_HOST_ASSERT(rank_prefix_matrix.dim() == 2 and rank_prefix_matrix.is_contiguous() and rank_prefix_matrix.scalar_type() == torch::kInt32); - EP_HOST_ASSERT(channel_prefix_matrix.dim() == 2 and channel_prefix_matrix.is_contiguous() and channel_prefix_matrix.scalar_type() == torch::kInt32); + EP_HOST_ASSERT(rank_prefix_matrix.dim() == 2 and rank_prefix_matrix.is_contiguous() and + rank_prefix_matrix.scalar_type() == torch::kInt32); + EP_HOST_ASSERT(channel_prefix_matrix.dim() == 2 and channel_prefix_matrix.is_contiguous() and + channel_prefix_matrix.scalar_type() == torch::kInt32); // One channel use two blocks, even-numbered blocks 
for sending, odd-numbered blocks for receiving. EP_HOST_ASSERT(config.num_sms % 2 == 0); @@ -582,40 +765,76 @@ Buffer::intranode_combine(const torch::Tensor& x, const std::optional(), - num_channels, num_recv_tokens, num_channels * num_ranks * 2, - task_fifo_ptrs_gpu, head, rank, num_ranks, - comm_stream); - + intranode::cached_notify_combine(buffer_ptrs_gpu, + send_head.data_ptr(), + num_channels, + num_recv_tokens, + num_channels * num_ranks * 2, + barrier_signal_ptrs_gpu, + rank, + num_ranks, + comm_stream, + head); // NOTES: this function uses two FIFO slots (barrier before and after) - move_fifo_slots(2); + move_fifo_slots(2); + + // Assign bias pointers + /*auto bias_opts = std::vector>({bias_0, bias_1}); + void* bias_ptrs[2] = {nullptr, nullptr}; + for (int i = 0; i < 2; ++i) + if (bias_opts[i].has_value()) { + auto bias = bias_opts[i].value(); + EP_HOST_ASSERT(bias.dim() == 2 and bias.is_contiguous()); + EP_HOST_ASSERT(bias.scalar_type() == x.scalar_type()); + EP_HOST_ASSERT(bias.size(0) == num_recv_tokens and bias.size(1) == hidden); + bias_ptrs[i] = bias.data_ptr(); + } + */ // Combine data auto recv_x = torch::empty({num_recv_tokens, hidden}, x.options()); - EP_HOST_ASSERT(num_channels * num_ranks * sizeof(int) * 2 + // Queue head and tail - num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * hidden * x.element_size() + // Data buffer - num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(int) + // Source index buffer - num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * sizeof(float) // Top-k weight buffer + EP_HOST_ASSERT(num_channels * num_ranks * sizeof(int) * 2 + // Queue head and tail + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * hidden * x.element_size() + // Data buffer + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * sizeof(int) + // Source index buffer + num_channels * num_ranks * config.num_max_nvl_chunked_recv_tokens * num_topk * 
sizeof(float) // Top-k weight buffer <= num_nvl_bytes); + intranode::combine(at::cuda::ScalarTypeToCudaDataType(x.scalar_type()), - recv_x.data_ptr(), recv_topk_weights_ptr, - x.data_ptr(), topk_weights_ptr, - src_idx.data_ptr(), rank_prefix_matrix.data_ptr(), channel_prefix_matrix.data_ptr(), - send_head.data_ptr(), num_tokens, num_recv_tokens, hidden, num_topk, - buffer_ptrs_gpu, rank, num_ranks, - comm_stream, config.num_sms, - config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens); + recv_x.data_ptr(), + recv_topk_weights_ptr, + x.data_ptr(), + topk_weights_ptr, + src_idx.data_ptr(), + rank_prefix_matrix.data_ptr(), + channel_prefix_matrix.data_ptr(), + send_head.data_ptr(), + num_tokens, + num_recv_tokens, + hidden, + num_topk, + buffer_ptrs_gpu, + rank, + num_ranks, + comm_stream, + config.num_sms, + config.num_max_nvl_chunked_send_tokens, + config.num_max_nvl_chunked_recv_tokens); // Wait streams std::optional event; if (async) { event = EventHandle(comm_stream); - for (auto& t: {x, src_idx, send_head, rank_prefix_matrix, channel_prefix_matrix, recv_x}) { + for (auto& t : {x, src_idx, send_head, rank_prefix_matrix, channel_prefix_matrix, recv_x}) { t.record_stream(comm_stream); if (allocate_on_comm_stream) t.record_stream(compute_stream); } - for (auto& to: {topk_weights, recv_topk_weights}) { + /*for (auto& to : {topk_weights, recv_topk_weights, bias_0, bias_1}) { + to.has_value() ? to->record_stream(comm_stream) : void(); + if (allocate_on_comm_stream) + to.has_value() ? to->record_stream(compute_stream) : void(); + }*/ + for (auto& to : {topk_weights, recv_topk_weights}) { to.has_value() ? to->record_stream(comm_stream) : void(); if (allocate_on_comm_stream) to.has_value() ? 
to->record_stream(compute_stream) : void(); @@ -631,16 +850,46 @@ Buffer::intranode_combine(const torch::Tensor& x, const std::optional, std::optional, std::optional, std::vector, torch::Tensor, torch::Tensor, std::optional, torch::Tensor, std::optional, torch::Tensor, std::optional, std::optional, std::optional, std::optional> -Buffer::internode_dispatch(const torch::Tensor& x, const std::optional& x_scales, - const std::optional& topk_idx, const std::optional& topk_weights, - const std::optional& num_tokens_per_rank, const std::optional& num_tokens_per_rdma_rank, - const torch::Tensor& is_token_in_rank, const std::optional& num_tokens_per_expert, - int cached_num_recv_tokens, int cached_num_rdma_recv_tokens, - const std::optional& cached_rdma_channel_prefix_matrix, const std::optional& cached_recv_rdma_rank_prefix_sum, - const std::optional& cached_gbl_channel_prefix_matrix, const std::optional& cached_recv_gbl_rank_prefix_sum, - int expert_alignment, const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream) { +std::tuple, + std::optional, + std::optional, + std::vector, + torch::Tensor, + torch::Tensor, + std::optional, + torch::Tensor, + std::optional, + torch::Tensor, + std::optional, + std::optional, + std::optional, + std::optional> +Buffer::internode_dispatch(const torch::Tensor& x, + const std::optional& x_scales, + const std::optional& topk_idx, + const std::optional& topk_weights, + const std::optional& num_tokens_per_rank, + const std::optional& num_tokens_per_rdma_rank, + const torch::Tensor& is_token_in_rank, + const std::optional& num_tokens_per_expert, + int cached_num_recv_tokens, + int cached_num_rdma_recv_tokens, + const std::optional& cached_rdma_channel_prefix_matrix, + const std::optional& cached_recv_rdma_rank_prefix_sum, + const std::optional& cached_gbl_channel_prefix_matrix, + const std::optional& cached_recv_gbl_rank_prefix_sum, + int expert_alignment, + const Config& config, + std::optional& 
previous_event, + bool async, + bool allocate_on_comm_stream) { +#ifndef DISABLE_NVSHMEM + // In dispatch, CPU will busy-wait until GPU receive tensor size metadata from other ranks, which can be quite long. + // If users of DeepEP need to execute other Python code on other threads, such as KV transfer, their code will get stuck due to GIL + // unless we release GIL here. + pybind11::gil_scoped_release release; + const int num_channels = config.num_sms / 2; EP_HOST_ASSERT(config.num_sms % 2 == 0); EP_HOST_ASSERT(0 < get_num_rdma_ranks() and get_num_rdma_ranks() <= NUM_MAX_RDMA_PEERS); @@ -674,11 +923,13 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optionaldim() == 2 and cached_rdma_channel_prefix_matrix->is_contiguous()); - EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->size(0) == num_rdma_ranks and cached_rdma_channel_prefix_matrix->size(1) == num_channels); + EP_HOST_ASSERT(cached_rdma_channel_prefix_matrix->size(0) == num_rdma_ranks and + cached_rdma_channel_prefix_matrix->size(1) == num_channels); EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->dim() == 1 and cached_recv_rdma_rank_prefix_sum->is_contiguous()); EP_HOST_ASSERT(cached_recv_rdma_rank_prefix_sum->size(0) == num_rdma_ranks); EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->dim() == 2 and cached_gbl_channel_prefix_matrix->is_contiguous()); - EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->size(0) == num_ranks and cached_gbl_channel_prefix_matrix->size(1) == num_channels); + EP_HOST_ASSERT(cached_gbl_channel_prefix_matrix->size(0) == num_ranks and + cached_gbl_channel_prefix_matrix->size(1) == num_channels); EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->dim() == 1 and cached_recv_gbl_rank_prefix_sum->is_contiguous()); EP_HOST_ASSERT(cached_recv_gbl_rank_prefix_sum->size(0) == num_ranks); } else { @@ -691,12 +942,13 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optionalsize(0) / num_ranks <= NUM_MAX_LOCAL_EXPERTS); } - auto num_tokens = 
static_cast(x.size(0)), hidden = static_cast(x.size(1)), hidden_int4 = static_cast(x.size(1) * x.element_size() / sizeof(int4)); + auto num_tokens = static_cast(x.size(0)), hidden = static_cast(x.size(1)), + hidden_int4 = static_cast(x.size(1) * x.element_size() / sizeof(int4)); auto num_experts = cached_mode ? 0 : static_cast(num_tokens_per_expert->size(0)), num_local_experts = num_experts / num_ranks; // Top-k checks int num_topk = 0; - int64_t* topk_idx_ptr = nullptr; + topk_idx_t* topk_idx_ptr = nullptr; float* topk_weights_ptr = nullptr; EP_HOST_ASSERT(topk_idx.has_value() == topk_weights.has_value()); if (topk_idx.has_value()) { @@ -707,20 +959,22 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optionalsize(0) and num_tokens == topk_weights->size(0)); EP_HOST_ASSERT(num_topk == topk_weights->size(1)); EP_HOST_ASSERT(topk_weights->scalar_type() == torch::kFloat32); - topk_idx_ptr = topk_idx->data_ptr(); + topk_idx_ptr = topk_idx->data_ptr(); topk_weights_ptr = topk_weights->data_ptr(); } // FP8 scales checks float* x_scales_ptr = nullptr; - int num_scales = 0; + int num_scales = 0, scale_token_stride = 0, scale_hidden_stride = 0; if (x_scales.has_value()) { EP_HOST_ASSERT(x.element_size() == 1); - EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32); - EP_HOST_ASSERT(x_scales->dim() > 0 and x_scales->dim() < 3 and x_scales->is_contiguous()); + EP_HOST_ASSERT(x_scales->scalar_type() == torch::kFloat32 or x_scales->scalar_type() == torch::kInt); + EP_HOST_ASSERT(x_scales->dim() == 2); EP_HOST_ASSERT(x_scales->size(0) == num_tokens); num_scales = x_scales->dim() == 1 ? 
1 : static_cast(x_scales->size(1)); - x_scales_ptr = x_scales->data_ptr(); + x_scales_ptr = static_cast(x_scales->data_ptr()); + scale_token_stride = static_cast(x_scales->stride(0)); + scale_hidden_stride = static_cast(x_scales->stride(1)); } // Allocate all tensors on comm stream if set @@ -756,14 +1010,28 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optionaldata_ptr(), moe_recv_counter_mapped, num_ranks, - num_tokens_per_rdma_rank->data_ptr(), moe_recv_rdma_counter_mapped, - num_tokens_per_expert->data_ptr(), moe_recv_expert_counter_mapped, num_experts, - is_token_in_rank.data_ptr(), num_tokens, num_channels, - hidden_int4, num_scales, num_topk, expert_alignment, - rdma_channel_prefix_matrix.data_ptr(), recv_rdma_rank_prefix_sum.data_ptr(), - gbl_channel_prefix_matrix.data_ptr(), recv_gbl_rank_prefix_sum.data_ptr(), - rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens, - buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens, - task_fifo_ptrs_gpu, head, rank, comm_stream, + internode::notify_dispatch(num_tokens_per_rank->data_ptr(), + moe_recv_counter_mapped, + num_ranks, + num_tokens_per_rdma_rank->data_ptr(), + moe_recv_rdma_counter_mapped, + num_tokens_per_expert->data_ptr(), + moe_recv_expert_counter_mapped, + num_experts, + is_token_in_rank.data_ptr(), + num_tokens, + num_channels, + hidden_int4, + num_scales, + num_topk, + expert_alignment, + rdma_channel_prefix_matrix.data_ptr(), + recv_rdma_rank_prefix_sum.data_ptr(), + gbl_channel_prefix_matrix.data_ptr(), + recv_gbl_rank_prefix_sum.data_ptr(), + rdma_buffer_ptr, + config.num_max_rdma_chunked_recv_tokens, + buffer_ptrs_gpu, + config.num_max_nvl_chunked_recv_tokens, + barrier_signal_ptrs_gpu, + rank, + comm_stream, config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), - num_nvl_bytes, low_latency_mode); + num_nvl_bytes, + low_latency_mode); move_fifo_slots(3); // Synchronize total received tokens and tokens per expert @@ -798,17 +1083,15 @@ 
Buffer::internode_dispatch(const torch::Tensor& x, const std::optional= 0) and (num_rdma_recv_tokens >= 0); - for (int i = 0; i < num_local_experts and ready; ++ i) + for (int i = 0; i < num_local_experts and ready; ++i) ready &= moe_recv_expert_counter[i] >= 0; if (ready) break; // Timeout check - if (std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - start_time).count() > NUM_CPU_TIMEOUT_SECS) { - printf("Global rank: %d, num_recv_tokens: %d, num_rdma_recv_tokens: %d\n", rank, num_recv_tokens, num_rdma_recv_tokens); - for (int i = 0; i < num_local_experts; ++ i) - printf("moe_recv_expert_counter[%d]: %d\n", i, moe_recv_expert_counter[i]); + if (std::chrono::duration_cast(std::chrono::high_resolution_clock::now() - start_time).count() > + NUM_CPU_TIMEOUT_SECS) { throw std::runtime_error("DeepEP error: timeout (dispatch CPU)"); } } @@ -817,7 +1100,8 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional(), recv_topk_weights = std::optional(), recv_x_scales = std::optional(); + auto recv_topk_idx = std::optional(), recv_topk_weights = std::optional(), + recv_x_scales = std::optional(); auto recv_src_meta = std::optional(); auto recv_rdma_channel_prefix_matrix = std::optional(); auto recv_gbl_channel_prefix_matrix = std::optional(); @@ -832,56 +1116,94 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optionaloptions()); recv_topk_weights = torch::empty({num_recv_tokens, num_topk}, topk_weights->options()); - recv_topk_idx_ptr = recv_topk_idx->data_ptr(); + recv_topk_idx_ptr = recv_topk_idx->data_ptr(); recv_topk_weights_ptr = recv_topk_weights->data_ptr(); } if (x_scales.has_value()) { - recv_x_scales = x_scales->dim() == 1 ? - torch::empty({num_recv_tokens}, x_scales->options()) : - torch::empty({num_recv_tokens, num_scales}, x_scales->options()); - recv_x_scales_ptr = recv_x_scales->data_ptr(); + recv_x_scales = x_scales->dim() == 1 ? 
torch::empty({num_recv_tokens}, x_scales->options()) + : torch::empty({num_recv_tokens, num_scales}, x_scales->options()); + recv_x_scales_ptr = static_cast(recv_x_scales->data_ptr()); } // Launch data dispatch // NOTES: the buffer size checks are moved into the `.cu` file - internode::dispatch(recv_x.data_ptr(), recv_x_scales_ptr, recv_topk_idx_ptr, recv_topk_weights_ptr, + internode::dispatch(recv_x.data_ptr(), + recv_x_scales_ptr, + recv_topk_idx_ptr, + recv_topk_weights_ptr, cached_mode ? nullptr : recv_src_meta->data_ptr(), - x.data_ptr(), x_scales_ptr, topk_idx_ptr, topk_weights_ptr, - cached_mode ? nullptr : send_rdma_head->data_ptr(), cached_mode ? nullptr : send_nvl_head->data_ptr(), + x.data_ptr(), + x_scales_ptr, + topk_idx_ptr, + topk_weights_ptr, + cached_mode ? nullptr : send_rdma_head->data_ptr(), + cached_mode ? nullptr : send_nvl_head->data_ptr(), cached_mode ? nullptr : recv_rdma_channel_prefix_matrix->data_ptr(), cached_mode ? nullptr : recv_gbl_channel_prefix_matrix->data_ptr(), - rdma_channel_prefix_matrix.data_ptr(), recv_rdma_rank_prefix_sum.data_ptr(), - gbl_channel_prefix_matrix.data_ptr(), recv_gbl_rank_prefix_sum.data_ptr(), - num_tokens, hidden_int4, num_scales, num_topk, num_experts, + rdma_channel_prefix_matrix.data_ptr(), + recv_rdma_rank_prefix_sum.data_ptr(), + gbl_channel_prefix_matrix.data_ptr(), + recv_gbl_rank_prefix_sum.data_ptr(), is_token_in_rank.data_ptr(), - rdma_buffer_ptr, config.num_max_rdma_chunked_send_tokens, config.num_max_rdma_chunked_recv_tokens, - buffer_ptrs_gpu, config.num_max_nvl_chunked_send_tokens, config.num_max_nvl_chunked_recv_tokens, - rank, num_ranks, cached_mode, - comm_stream, num_channels, low_latency_mode); + num_tokens, + hidden_int4, + num_scales, + num_topk, + num_experts, + scale_token_stride, + scale_hidden_stride, + rdma_buffer_ptr, + config.num_max_rdma_chunked_send_tokens, + config.num_max_rdma_chunked_recv_tokens, + buffer_ptrs_gpu, + config.num_max_nvl_chunked_send_tokens, + 
config.num_max_nvl_chunked_recv_tokens, + rank, + num_ranks, + cached_mode, + comm_stream, + num_channels, + low_latency_mode); // Wait streams std::optional event; if (async) { event = EventHandle(comm_stream); - for (auto& t: {x, is_token_in_rank, recv_x, - rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum}) { + for (auto& t : {x, + is_token_in_rank, + recv_x, + rdma_channel_prefix_matrix, + recv_rdma_rank_prefix_sum, + gbl_channel_prefix_matrix, + recv_gbl_rank_prefix_sum}) { t.record_stream(comm_stream); if (allocate_on_comm_stream) t.record_stream(compute_stream); } - for (auto& to: {x_scales, topk_idx, topk_weights, - num_tokens_per_rank, num_tokens_per_rdma_rank, num_tokens_per_expert, - cached_rdma_channel_prefix_matrix, cached_recv_rdma_rank_prefix_sum, - cached_gbl_channel_prefix_matrix, cached_recv_gbl_rank_prefix_sum, - recv_topk_idx, recv_topk_weights, recv_x_scales, - recv_rdma_channel_prefix_matrix, recv_gbl_channel_prefix_matrix, send_rdma_head, send_nvl_head, - recv_src_meta}) { + for (auto& to : {x_scales, + topk_idx, + topk_weights, + num_tokens_per_rank, + num_tokens_per_rdma_rank, + num_tokens_per_expert, + cached_rdma_channel_prefix_matrix, + cached_recv_rdma_rank_prefix_sum, + cached_gbl_channel_prefix_matrix, + cached_recv_gbl_rank_prefix_sum, + recv_topk_idx, + recv_topk_weights, + recv_x_scales, + recv_rdma_channel_prefix_matrix, + recv_gbl_channel_prefix_matrix, + send_rdma_head, + send_nvl_head, + recv_src_meta}) { to.has_value() ? to->record_stream(comm_stream) : void(); if (allocate_on_comm_stream) to.has_value() ? 
to->record_stream(compute_stream) : void(); @@ -895,33 +1217,64 @@ Buffer::internode_dispatch(const torch::Tensor& x, const std::optional, std::optional> -Buffer::internode_combine(const torch::Tensor& x, const std::optional& topk_weights, - const torch::Tensor& src_meta, const torch::Tensor& is_combined_token_in_rank, - const torch::Tensor& rdma_channel_prefix_matrix, const torch::Tensor& rdma_rank_prefix_sum, const torch::Tensor& gbl_channel_prefix_matrix, - const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head, - const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream) { +std::tuple, std::optional> Buffer::internode_combine( + const torch::Tensor& x, + const std::optional& topk_weights, + const std::optional& bias_0, + const std::optional& bias_1, + const torch::Tensor& src_meta, + const torch::Tensor& is_combined_token_in_rank, + const torch::Tensor& rdma_channel_prefix_matrix, + const torch::Tensor& rdma_rank_prefix_sum, + const torch::Tensor& gbl_channel_prefix_matrix, + const torch::Tensor& combined_rdma_head, + const torch::Tensor& combined_nvl_head, + const Config& config, + std::optional& previous_event, + bool async, + bool allocate_on_comm_stream) { +#ifndef DISABLE_NVSHMEM const int num_channels = config.num_sms / 2; EP_HOST_ASSERT(config.num_sms % 2 == 0); // Shape and contiguous checks EP_HOST_ASSERT(x.dim() == 2 and x.is_contiguous()); EP_HOST_ASSERT(src_meta.dim() == 2 and src_meta.is_contiguous() and src_meta.scalar_type() == torch::kByte); - EP_HOST_ASSERT(is_combined_token_in_rank.dim() == 2 and is_combined_token_in_rank.is_contiguous() and is_combined_token_in_rank.scalar_type() == torch::kBool); - EP_HOST_ASSERT(rdma_channel_prefix_matrix.dim() == 2 and rdma_channel_prefix_matrix.is_contiguous() and rdma_channel_prefix_matrix.scalar_type() == torch::kInt32); - EP_HOST_ASSERT(rdma_rank_prefix_sum.dim() == 1 and rdma_rank_prefix_sum.is_contiguous() and 
rdma_rank_prefix_sum.scalar_type() == torch::kInt32); - EP_HOST_ASSERT(gbl_channel_prefix_matrix.dim() == 2 and gbl_channel_prefix_matrix.is_contiguous() and gbl_channel_prefix_matrix.scalar_type() == torch::kInt32); - EP_HOST_ASSERT(combined_rdma_head.dim() == 2 and combined_rdma_head.is_contiguous() and combined_rdma_head.scalar_type() == torch::kInt32); + EP_HOST_ASSERT(is_combined_token_in_rank.dim() == 2 and is_combined_token_in_rank.is_contiguous() and + is_combined_token_in_rank.scalar_type() == torch::kBool); + EP_HOST_ASSERT(rdma_channel_prefix_matrix.dim() == 2 and rdma_channel_prefix_matrix.is_contiguous() and + rdma_channel_prefix_matrix.scalar_type() == torch::kInt32); + EP_HOST_ASSERT(rdma_rank_prefix_sum.dim() == 1 and rdma_rank_prefix_sum.is_contiguous() and + rdma_rank_prefix_sum.scalar_type() == torch::kInt32); + EP_HOST_ASSERT(gbl_channel_prefix_matrix.dim() == 2 and gbl_channel_prefix_matrix.is_contiguous() and + gbl_channel_prefix_matrix.scalar_type() == torch::kInt32); + EP_HOST_ASSERT(combined_rdma_head.dim() == 2 and combined_rdma_head.is_contiguous() and + combined_rdma_head.scalar_type() == torch::kInt32); EP_HOST_ASSERT(combined_nvl_head.dim() == 2 and combined_nvl_head.is_contiguous() and combined_nvl_head.scalar_type() == torch::kInt32); - auto num_tokens = static_cast(x.size(0)), hidden = static_cast(x.size(1)), hidden_int4 = static_cast(x.size(1) * x.element_size() / sizeof(int4)); + auto num_tokens = static_cast(x.size(0)), hidden = static_cast(x.size(1)), + hidden_int4 = static_cast(x.size(1) * x.element_size() / sizeof(int4)); auto num_combined_tokens = static_cast(is_combined_token_in_rank.size(0)); EP_HOST_ASSERT((hidden * x.element_size()) % sizeof(int4) == 0); EP_HOST_ASSERT(src_meta.size(1) == internode::get_source_meta_bytes()); @@ -929,7 +1282,8 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional(), - rdma_channel_prefix_matrix.data_ptr(), rdma_rank_prefix_sum.data_ptr(), combined_nvl_head.data_ptr(), - 
rdma_buffer_ptr, config.num_max_rdma_chunked_recv_tokens, - buffer_ptrs_gpu, config.num_max_nvl_chunked_recv_tokens, - task_fifo_ptrs_gpu, head, rank, comm_stream, + internode::cached_notify(hidden_int4, + 0, + 0, + num_topk, + num_ranks, + num_channels, + num_combined_tokens, + combined_rdma_head.data_ptr(), + rdma_channel_prefix_matrix.data_ptr(), + rdma_rank_prefix_sum.data_ptr(), + combined_nvl_head.data_ptr(), + rdma_buffer_ptr, + config.num_max_rdma_chunked_recv_tokens, + buffer_ptrs_gpu, + config.num_max_nvl_chunked_recv_tokens, + barrier_signal_ptrs_gpu, + rank, + comm_stream, config.get_rdma_buffer_size_hint(hidden_int4 * sizeof(int4), num_ranks), - num_nvl_bytes, false, low_latency_mode); + num_nvl_bytes, + false, + low_latency_mode); move_fifo_slots(2); + // Assign bias pointers + auto bias_opts = std::vector>({bias_0, bias_1}); + void* bias_ptrs[2] = {nullptr, nullptr}; + for (int i = 0; i < 2; ++i) + if (bias_opts[i].has_value()) { + auto bias = bias_opts[i].value(); + EP_HOST_ASSERT(bias.dim() == 2 and bias.is_contiguous()); + EP_HOST_ASSERT(bias.scalar_type() == x.scalar_type()); + EP_HOST_ASSERT(bias.size(0) == num_combined_tokens and bias.size(1) == hidden); + bias_ptrs[i] = bias.data_ptr(); + } + // Launch data combine auto combined_x = torch::empty({num_combined_tokens, hidden}, x.options()); internode::combine(at::cuda::ScalarTypeToCudaDataType(x.scalar_type()), - combined_x.data_ptr(), combined_topk_weights_ptr, + combined_x.data_ptr(), + combined_topk_weights_ptr, is_combined_token_in_rank.data_ptr(), - x.data_ptr(), topk_weights_ptr, - combined_rdma_head.data_ptr(), combined_nvl_head.data_ptr(), - src_meta.data_ptr(), rdma_channel_prefix_matrix.data_ptr(), rdma_rank_prefix_sum.data_ptr(), gbl_channel_prefix_matrix.data_ptr(), - num_tokens, num_combined_tokens, hidden, num_topk, - rdma_buffer_ptr, config.num_max_rdma_chunked_send_tokens, config.num_max_rdma_chunked_recv_tokens, - buffer_ptrs_gpu, config.num_max_nvl_chunked_send_tokens, 
config.num_max_nvl_chunked_recv_tokens, - rank, num_ranks, comm_stream, num_channels, low_latency_mode); + x.data_ptr(), + topk_weights_ptr, + bias_ptrs[0], + bias_ptrs[1], + combined_rdma_head.data_ptr(), + combined_nvl_head.data_ptr(), + src_meta.data_ptr(), + rdma_channel_prefix_matrix.data_ptr(), + rdma_rank_prefix_sum.data_ptr(), + gbl_channel_prefix_matrix.data_ptr(), + num_tokens, + num_combined_tokens, + hidden, + num_topk, + rdma_buffer_ptr, + config.num_max_rdma_chunked_send_tokens, + config.num_max_rdma_chunked_recv_tokens, + buffer_ptrs_gpu, + config.num_max_nvl_chunked_send_tokens, + config.num_max_nvl_chunked_recv_tokens, + rank, + num_ranks, + comm_stream, + num_channels, + low_latency_mode); // Wait streams std::optional event; if (async) { event = EventHandle(comm_stream); - for (auto& t: {x, src_meta, - is_combined_token_in_rank, rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, - combined_x, combined_rdma_head, combined_nvl_head}) { + for (auto& t : {x, + src_meta, + is_combined_token_in_rank, + rdma_channel_prefix_matrix, + rdma_rank_prefix_sum, + gbl_channel_prefix_matrix, + combined_x, + combined_rdma_head, + combined_nvl_head}) { t.record_stream(comm_stream); if (allocate_on_comm_stream) t.record_stream(compute_stream); } - for (auto& to: {topk_weights, combined_topk_weights}) { + for (auto& to : {topk_weights, combined_topk_weights, bias_0, bias_1}) { to.has_value() ? to->record_stream(comm_stream) : void(); if (allocate_on_comm_stream) to.has_value() ? 
to->record_stream(compute_stream) : void(); @@ -1017,12 +1421,14 @@ Buffer::internode_combine(const torch::Tensor& x, const std::optional(clean_meta_0.first), + clean_meta_0.second, + reinterpret_cast(clean_meta_1.first), + clean_meta_1.second, + rank, + num_ranks, + mask_buffer_ptr, + sync_buffer_ptr, at::cuda::getCurrentCUDAStream()); -#endif //DISABLE_INTERNODE +#else + EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation"); +#endif } -std::tuple, torch::Tensor, torch::Tensor, torch::Tensor, std::optional, std::optional>> -Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx, - int num_max_dispatch_tokens_per_rank, int num_experts, - bool use_fp8, bool async, bool return_recv_hook) { -#if DISABLE_INTERNODE - throw std::runtime_error("Low-latency mode is disabled"); -#else - +std::tuple, + torch::Tensor, + torch::Tensor, + torch::Tensor, + std::optional, + std::optional>> +Buffer::low_latency_dispatch(const torch::Tensor& x, + const torch::Tensor& topk_idx, + const std::optional& cumulative_local_expert_recv_stats, + const std::optional& dispatch_wait_recv_cost_stats, + int num_max_dispatch_tokens_per_rank, + int num_experts, + bool use_fp8, + bool round_scale, + bool use_ue8m0, + bool async, + bool return_recv_hook) { +#ifndef DISABLE_NVSHMEM EP_HOST_ASSERT(low_latency_mode); // Tensor checks @@ -1058,12 +1483,24 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i EP_HOST_ASSERT(x.size(1) % sizeof(int4) == 0 and x.size(1) % 128 == 0); EP_HOST_ASSERT(topk_idx.dim() == 2 and topk_idx.is_contiguous()); EP_HOST_ASSERT(x.size(0) == topk_idx.size(0) and x.size(0) <= num_max_dispatch_tokens_per_rank); - EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt64); + EP_HOST_ASSERT(topk_idx.scalar_type() == c10::CppTypeToScalarType::value); EP_HOST_ASSERT(num_experts % num_ranks == 0); + // Diagnosis tensors + if (cumulative_local_expert_recv_stats.has_value()) { + 
EP_HOST_ASSERT(cumulative_local_expert_recv_stats->scalar_type() == torch::kInt); + EP_HOST_ASSERT(cumulative_local_expert_recv_stats->dim() == 1 and cumulative_local_expert_recv_stats->is_contiguous()); + EP_HOST_ASSERT(cumulative_local_expert_recv_stats->size(0) == num_experts / num_ranks); + } + if (dispatch_wait_recv_cost_stats.has_value()) { + EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->scalar_type() == torch::kInt64); + EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->dim() == 1 and dispatch_wait_recv_cost_stats->is_contiguous()); + EP_HOST_ASSERT(dispatch_wait_recv_cost_stats->size(0) == num_ranks); + } + auto num_tokens = static_cast(x.size(0)), hidden = static_cast(x.size(1)); - auto num_scales = hidden / 128, num_topk = static_cast(topk_idx.size(1)); - int num_local_experts = num_experts / num_ranks; + auto num_topk = static_cast(topk_idx.size(1)); + auto num_local_experts = num_experts / num_ranks; // Buffer control LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); @@ -1087,26 +1524,36 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i stream_wait(launch_stream, compute_stream); // Allocate packed tensors - auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, - x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fn: torch::kBFloat16)); #ifdef USE_ROCM - if (gfx == 942){ - packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, - x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fnuz: torch::kBFloat16)); - } -#endif - auto packed_recv_src_info = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA)); + auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, + x.options().dtype(use_fp8 ? 
torch::kFloat8_e4m3fnuz : torch::kBFloat16)); +#else + auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, + x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fn : torch::kBFloat16)); +#endif + auto packed_recv_src_info = + torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kInt32).device(torch::kCUDA)); auto packed_recv_layout_range = torch::empty({num_local_experts, num_ranks}, torch::dtype(torch::kInt64).device(torch::kCUDA)); auto packed_recv_count = torch::empty({num_local_experts}, torch::dtype(torch::kInt32).device(torch::kCUDA)); // Allocate column-majored scales auto packed_recv_x_scales = std::optional(); - float* packed_recv_x_scales_ptr = nullptr; + void* packed_recv_x_scales_ptr = nullptr; + EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4"); + if (use_fp8) { - EP_HOST_ASSERT((num_ranks * num_max_dispatch_tokens_per_rank) % 4 == 0 and "TMA requires the number of tokens to be multiple of 4"); - packed_recv_x_scales = torch::empty({num_local_experts, num_scales, num_ranks * num_max_dispatch_tokens_per_rank}, torch::dtype(torch::kFloat32).device(torch::kCUDA)); + // TODO: support unaligned cases + EP_HOST_ASSERT(hidden % 512 == 0); + if (not use_ue8m0) { + packed_recv_x_scales = torch::empty({num_local_experts, hidden / 128, num_ranks * num_max_dispatch_tokens_per_rank}, + torch::dtype(torch::kFloat32).device(torch::kCUDA)); + } else { + EP_HOST_ASSERT(round_scale); + packed_recv_x_scales = torch::empty({num_local_experts, hidden / 512, num_ranks * num_max_dispatch_tokens_per_rank}, + torch::dtype(torch::kInt).device(torch::kCUDA)); + } packed_recv_x_scales = torch::transpose(packed_recv_x_scales.value(), 1, 2); - packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr(); + packed_recv_x_scales_ptr = packed_recv_x_scales->data_ptr(); } // Kernel launch @@ -1143,19 +1590,27 
@@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i // Return values return {packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, recv_hook}; -#endif //DISABLE_INTERNODE -} - -std::tuple, std::optional>> -Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights, - const torch::Tensor& src_info, const torch::Tensor& layout_range, - int num_max_dispatch_tokens_per_rank, int num_experts, - bool zero_copy, bool async, bool return_recv_hook, - const std::optional& out) { -#if DISABLE_INTERNODE - throw std::runtime_error("Low-latency mode is disabled"); #else + EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation"); + return {}; +#endif +} +std::tuple, std::optional>> Buffer::low_latency_combine( + const torch::Tensor& x, + const torch::Tensor& topk_idx, + const torch::Tensor& topk_weights, + const torch::Tensor& src_info, + const torch::Tensor& layout_range, + const std::optional& combine_wait_recv_cost_stats, + int num_max_dispatch_tokens_per_rank, + int num_experts, + bool use_logfmt, + bool zero_copy, + bool async, + bool return_recv_hook, + const std::optional& out) { +#ifndef DISABLE_NVSHMEM EP_HOST_ASSERT(low_latency_mode); // Tensor checks @@ -1165,7 +1620,7 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id EP_HOST_ASSERT(x.size(2) % sizeof(int4) == 0 and x.size(2) % 128 == 0); EP_HOST_ASSERT(topk_idx.dim() == 2 and topk_idx.is_contiguous()); EP_HOST_ASSERT(topk_idx.size(0) == topk_weights.size(0) and topk_idx.size(1) == topk_weights.size(1)); - EP_HOST_ASSERT(topk_idx.scalar_type() == torch::kInt64); + EP_HOST_ASSERT(topk_idx.scalar_type() == c10::CppTypeToScalarType::value); EP_HOST_ASSERT(topk_weights.dim() == 2 and topk_weights.is_contiguous()); EP_HOST_ASSERT(topk_weights.size(0) <= num_max_dispatch_tokens_per_rank); EP_HOST_ASSERT(topk_weights.scalar_type() == 
torch::kFloat32); @@ -1174,8 +1629,15 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id EP_HOST_ASSERT(layout_range.dim() == 2 and layout_range.is_contiguous()); EP_HOST_ASSERT(layout_range.scalar_type() == torch::kInt64); EP_HOST_ASSERT(layout_range.size(0) == num_experts / num_ranks and layout_range.size(1) == num_ranks); + + if (combine_wait_recv_cost_stats.has_value()) { + EP_HOST_ASSERT(combine_wait_recv_cost_stats->scalar_type() == torch::kInt64); + EP_HOST_ASSERT(combine_wait_recv_cost_stats->dim() == 1 and combine_wait_recv_cost_stats->is_contiguous()); + EP_HOST_ASSERT(combine_wait_recv_cost_stats->size(0) == num_ranks); + } + auto hidden = static_cast(x.size(2)); - auto num_local_experts = num_experts / num_ranks, num_topk = static_cast(topk_weights.size(1)); + auto num_topk = static_cast(topk_weights.size(1)); auto num_combined_tokens = static_cast(topk_weights.size(0)); // Buffer control @@ -1183,12 +1645,6 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id EP_HOST_ASSERT(layout.total_bytes <= num_rdma_bytes); auto buffer = layout.buffers[low_latency_buffer_idx]; auto next_buffer = layout.buffers[low_latency_buffer_idx ^= 1]; - - // Buffer control - LowLatencyLayout nvl_layout(nvl_buffer_ptrs[nvl_rank], num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); - EP_HOST_ASSERT(nvl_layout.total_bytes <= num_rdma_bytes); - auto nvl_buffer = nvl_layout.buffers[low_latency_buffer_idx ^= 1]; - auto nvl_next_buffer = nvl_layout.buffers[low_latency_buffer_idx ^= 1]; // Wait previous tasks to be finished // NOTES: the hook mode will always use the default stream @@ -1214,7 +1670,8 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id auto next_clean_meta = next_buffer.clean_meta(); auto launcher = [=](int phases) { internode_ll::combine(combined_x.data_ptr(), - buffer.combine_rdma_recv_data_buffer, buffer.combine_rdma_recv_flag_buffer, + 
buffer.combine_rdma_recv_data_buffer, + reinterpret_cast(buffer.combine_rdma_recv_flag_buffer), buffer.combine_rdma_send_buffer, x.data_ptr(), topk_idx.data_ptr(), topk_weights.data_ptr(), src_info.data_ptr(), layout_range.data_ptr(), @@ -1244,12 +1701,16 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id // Return values return {combined_x, event, recv_hook}; -#endif // DISABLE_INTERNODE +#else + EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation"); + return {}; +#endif } -torch::Tensor -Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) { +torch::Tensor Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) const { +#ifndef DISABLE_NVSHMEM LowLatencyLayout layout(rdma_buffer_ptr, num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts); + auto buffer = layout.buffers[low_latency_buffer_idx]; auto dtype = torch::kBFloat16; auto num_msg_elems = static_cast(buffer.num_bytes_per_combine_msg / elementSize(torch::kBFloat16)); @@ -1259,36 +1720,40 @@ Buffer::get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank {num_experts / num_ranks, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, {num_ranks * num_max_dispatch_tokens_per_rank * num_msg_elems, num_msg_elems, 1}, torch::TensorOptions().dtype(dtype).device(torch::kCUDA)); +#else + EP_HOST_ASSERT(false and "NVSHMEM is disabled during compilation"); + return {}; +#endif } -std::string Buffer::get_local_ipc_handle_string() const { - return std::string(reinterpret_cast(ipc_handles[nvl_rank].reserved), CUDA_IPC_HANDLE_SIZE); +bool is_sm90_compiled() { +#ifndef DISABLE_SM90_FEATURES + return true; +#else + return false; +#endif } -std::string Buffer::get_local_nvshmem_unique_id_string() const { - EP_HOST_ASSERT(rdma_rank == 0 and "Only RDMA rank 0 can get NVSHMEM unique ID"); - auto unique_id = internode::get_unique_id(); - return 
std::string(reinterpret_cast(unique_id.data()), unique_id.size()); +void Buffer::low_latency_update_mask_buffer(int rank_to_mask, bool mask) { + EP_HOST_ASSERT(mask_buffer_ptr != nullptr and "Shrink mode must be enabled"); + EP_HOST_ASSERT(rank_to_mask >= 0 and rank_to_mask < num_ranks); + internode_ll::update_mask_buffer(mask_buffer_ptr, rank_to_mask, mask, at::cuda::getCurrentCUDAStream()); } -void Buffer::sync_string(const std::vector &device_ids, - const std::vector &all_gathered_handles, - const std::string& root_unique_id_opt) { - std::vector> py_all_gathered_handles; - for (auto& handle : all_gathered_handles) { - std::optional py_handle_opt = std::nullopt; - if (!handle.empty()) { - py_handle_opt.emplace(handle.c_str(), handle.size()); - } - py_all_gathered_handles.push_back(py_handle_opt); - } - std::optional py_root_unique_id_opt = std::nullopt; - if (!root_unique_id_opt.empty()) { - py_root_unique_id_opt.emplace(root_unique_id_opt.c_str(), root_unique_id_opt.size()); - } - sync(device_ids, py_all_gathered_handles, py_root_unique_id_opt); +void Buffer::low_latency_query_mask_buffer(const torch::Tensor& mask_status) { + EP_HOST_ASSERT(mask_buffer_ptr != nullptr and "Shrink mode must be enabled"); + EP_HOST_ASSERT(mask_status.numel() == num_ranks && mask_status.scalar_type() == torch::kInt32); + + internode_ll::query_mask_buffer( + mask_buffer_ptr, num_ranks, reinterpret_cast(mask_status.data_ptr()), at::cuda::getCurrentCUDAStream()); } -} // namespace deep_ep + +void Buffer::low_latency_clean_mask_buffer() { + EP_HOST_ASSERT(mask_buffer_ptr != nullptr and "Shrink mode must be enabled"); + internode_ll::clean_mask_buffer(mask_buffer_ptr, num_ranks, at::cuda::getCurrentCUDAStream()); +} + +} // namespace deep_ep PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.doc() = "DeepEP: an efficient expert-parallel communication library"; @@ -1296,8 +1761,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { pybind11::class_(m, "Config") .def(pybind11::init(), 
py::arg("num_sms") = 20, - py::arg("num_max_nvl_chunked_send_tokens") = 6, py::arg("num_max_nvl_chunked_recv_tokens") = 256, - py::arg("num_max_rdma_chunked_send_tokens") = 6, py::arg("num_max_rdma_chunked_recv_tokens") = 256) + py::arg("num_max_nvl_chunked_send_tokens") = 6, + py::arg("num_max_nvl_chunked_recv_tokens") = 256, + py::arg("num_max_rdma_chunked_send_tokens") = 6, + py::arg("num_max_rdma_chunked_recv_tokens") = 256) .def("get_nvl_buffer_size_hint", &deep_ep::Config::get_nvl_buffer_size_hint) .def("get_rdma_buffer_size_hint", &deep_ep::Config::get_rdma_buffer_size_hint); m.def("get_low_latency_rdma_size_hint", &deep_ep::get_low_latency_rdma_size_hint); @@ -1307,7 +1774,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def("current_stream_wait", &deep_ep::EventHandle::current_stream_wait); pybind11::class_(m, "Buffer") - .def(pybind11::init()) + .def(pybind11::init()) .def("is_available", &deep_ep::Buffer::is_available) .def("get_num_rdma_ranks", &deep_ep::Buffer::get_num_rdma_ranks) .def("get_rdma_rank", &deep_ep::Buffer::get_rdma_rank) @@ -1316,7 +1783,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def("get_local_ipc_handle", &deep_ep::Buffer::get_local_ipc_handle) .def("get_local_nvshmem_unique_id", &deep_ep::Buffer::get_local_nvshmem_unique_id) .def("get_local_buffer_tensor", &deep_ep::Buffer::get_local_buffer_tensor) + .def("get_comm_stream", &deep_ep::Buffer::get_comm_stream) .def("sync", &deep_ep::Buffer::sync) + .def("destroy", &deep_ep::Buffer::destroy) .def("get_dispatch_layout", &deep_ep::Buffer::get_dispatch_layout) .def("intranode_dispatch", &deep_ep::Buffer::intranode_dispatch) .def("intranode_combine", &deep_ep::Buffer::intranode_combine) @@ -1325,5 +1794,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def("clean_low_latency_buffer", &deep_ep::Buffer::clean_low_latency_buffer) .def("low_latency_dispatch", &deep_ep::Buffer::low_latency_dispatch) .def("low_latency_combine", &deep_ep::Buffer::low_latency_combine) + 
.def("low_latency_update_mask_buffer", &deep_ep::Buffer::low_latency_update_mask_buffer) + .def("low_latency_query_mask_buffer", &deep_ep::Buffer::low_latency_query_mask_buffer) + .def("low_latency_clean_mask_buffer", &deep_ep::Buffer::low_latency_clean_mask_buffer) .def("get_next_low_latency_combine_buffer", &deep_ep::Buffer::get_next_low_latency_combine_buffer); + + m.def("is_sm90_compiled", deep_ep::is_sm90_compiled); + m.attr("topk_idx_t") = + py::reinterpret_borrow((PyObject*)torch::getTHPDtype(c10::CppTypeToScalarType::value)); } diff --git a/csrc/deep_ep.hpp b/csrc/deep_ep.hpp index dfbd842..90fc163 100644 --- a/csrc/deep_ep.hpp +++ b/csrc/deep_ep.hpp @@ -8,6 +8,7 @@ #include #include #include + #include #include @@ -35,21 +36,21 @@ struct Buffer { void* buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr}; void** buffer_ptrs_gpu = nullptr; - void* nvl_buffer_ptrs[NUM_MAX_NVL_PEERS] = {nullptr}; - void** nvl_buffer_ptrs_gpu = nullptr; // NVSHMEM Buffer int64_t num_rdma_bytes; void* rdma_buffer_ptr = nullptr; + // Shrink mode buffer + bool enable_shrink = false; + int* mask_buffer_ptr = nullptr; + int* sync_buffer_ptr = nullptr; + // Device info and communication int device_id; -#ifdef USE_ROCM - int gfx; -#endif + int num_device_sms; int rank, rdma_rank, nvl_rank; int num_ranks, num_rdma_ranks, num_nvl_ranks; cudaIpcMemHandle_t ipc_handles[NUM_MAX_NVL_PEERS]; - cudaIpcMemHandle_t pxn_ipc_handles[NUM_MAX_NVL_PEERS]; // Stream for communication at::cuda::CUDAStream comm_stream; @@ -59,8 +60,15 @@ struct Buffer { // Task fifo int head = 0; - int* task_fifo_ptrs[NUM_MAX_NVL_PEERS] = {nullptr}; - int** task_fifo_ptrs_gpu = nullptr; + + // Whether explicit `destroy()` is required. 
+ bool explicitly_destroy; + // After `destroy()` be called, this flag will be true + bool destroyed = false; + + // Barrier signals + int* barrier_signal_ptrs[NUM_MAX_NVL_PEERS] = {nullptr}; + int** barrier_signal_ptrs_gpu = nullptr; // Workspace void* workspace = nullptr; @@ -83,9 +91,15 @@ struct Buffer { private: void move_fifo_slots(int num_slots = 1); - + public: - Buffer(int rank, int num_ranks, int64_t num_nvl_bytes, int64_t num_rdma_bytes, bool low_latency_mode); + Buffer(int rank, + int num_ranks, + int64_t num_nvl_bytes, + int64_t num_rdma_bytes, + bool low_latency_mode, + bool explicitly_destroy, + bool enable_shrink); ~Buffer() noexcept(false); @@ -104,67 +118,161 @@ struct Buffer { pybind11::bytearray get_local_ipc_handle() const; pybind11::bytearray get_local_nvshmem_unique_id() const; - - pybind11::bytearray get_local_pxn_ipc_handle() const; torch::Tensor get_local_buffer_tensor(const pybind11::object& dtype, int64_t offset, bool use_rdma_buffer) const; - void sync(const std::vector& device_ids, const std::vector>& all_gathered_handles, const std::optional& root_unique_id_opt); - - std::tuple, torch::Tensor, torch::Tensor, std::optional> - get_dispatch_layout(const torch::Tensor& topk_idx, int num_experts, std::optional& previous_event, - bool async, bool allocate_on_comm_stream); - - std::tuple, std::optional, std::optional, std::vector, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, std::optional> - intranode_dispatch(const torch::Tensor& x, const std::optional& x_scales, - const std::optional& topk_idx, const std::optional& topk_weights, - const std::optional& num_tokens_per_rank, const torch::Tensor& is_token_in_rank, const std::optional& num_tokens_per_expert, - int cached_num_recv_tokens, const std::optional& cached_rank_prefix_matrix, const std::optional& cached_channel_prefix_matrix, - int expert_alignment, const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream); - - 
std::tuple, std::optional> - intranode_combine(const torch::Tensor& x, const std::optional& topk_weights, - const torch::Tensor& src_idx, const torch::Tensor& rank_prefix_matrix, const torch::Tensor& channel_prefix_matrix, - const torch::Tensor& send_head, const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream); - - std::tuple, std::optional, std::optional, std::vector, torch::Tensor, torch::Tensor, std::optional, torch::Tensor, std::optional, torch::Tensor, std::optional, std::optional, std::optional, std::optional> - internode_dispatch(const torch::Tensor& x, const std::optional& x_scales, - const std::optional& topk_idx, const std::optional& topk_weights, - const std::optional& num_tokens_per_rank, const std::optional& num_tokens_per_rdma_rank, - const torch::Tensor& is_token_in_rank, const std::optional& num_tokens_per_expert, - int cached_num_recv_tokens, int cached_num_rdma_recv_tokens, - const std::optional& cached_rdma_channel_prefix_matrix, const std::optional& cached_recv_rdma_rank_prefix_sum, - const std::optional& cached_gbl_channel_prefix_matrix, const std::optional& cached_recv_gbl_rank_prefix_sum, - int expert_alignment, const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream); - - std::tuple, std::optional> - internode_combine(const torch::Tensor& x, const std::optional& topk_weights, - const torch::Tensor& src_meta, const torch::Tensor& is_combined_token_in_rank, - const torch::Tensor& rdma_channel_prefix_matrix, const torch::Tensor& rdma_rank_prefix_sum, const torch::Tensor& gbl_channel_prefix_matrix, - const torch::Tensor& combined_rdma_head, const torch::Tensor& combined_nvl_head, - const Config& config, std::optional& previous_event, bool async, bool allocate_on_comm_stream); + torch::Stream get_comm_stream() const; + + void sync(const std::vector& device_ids, + const std::vector>& all_gathered_handles, + const std::optional& root_unique_id_opt); + + void destroy(); + 
+ std::tuple, torch::Tensor, torch::Tensor, std::optional> get_dispatch_layout( + const torch::Tensor& topk_idx, + int num_experts, + std::optional& previous_event, + bool async, + bool allocate_on_comm_stream); + + std::tuple, + std::optional, + std::optional, + std::vector, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + torch::Tensor, + std::optional> + intranode_dispatch(const torch::Tensor& x, + const std::optional& x_scales, + const std::optional& topk_idx, + const std::optional& topk_weights, + const std::optional& num_tokens_per_rank, + const torch::Tensor& is_token_in_rank, + const std::optional& num_tokens_per_expert, + int cached_num_recv_tokens, + const std::optional& cached_rank_prefix_matrix, + const std::optional& cached_channel_prefix_matrix, + int expert_alignment, + int num_worst_tokens, + const Config& config, + std::optional& previous_event, + bool async, + bool allocate_on_comm_stream); + + std::tuple, std::optional> intranode_combine( + const torch::Tensor& x, + const std::optional& topk_weights, + //const std::optional& bias_0, + //const std::optional& bias_1, + const torch::Tensor& src_idx, + const torch::Tensor& rank_prefix_matrix, + const torch::Tensor& channel_prefix_matrix, + const torch::Tensor& send_head, + const Config& config, + std::optional& previous_event, + bool async, + bool allocate_on_comm_stream); + + std::tuple, + std::optional, + std::optional, + std::vector, + torch::Tensor, + torch::Tensor, + std::optional, + torch::Tensor, + std::optional, + torch::Tensor, + std::optional, + std::optional, + std::optional, + std::optional> + internode_dispatch(const torch::Tensor& x, + const std::optional& x_scales, + const std::optional& topk_idx, + const std::optional& topk_weights, + const std::optional& num_tokens_per_rank, + const std::optional& num_tokens_per_rdma_rank, + const torch::Tensor& is_token_in_rank, + const std::optional& num_tokens_per_expert, + int cached_num_recv_tokens, + int 
cached_num_rdma_recv_tokens, + const std::optional& cached_rdma_channel_prefix_matrix, + const std::optional& cached_recv_rdma_rank_prefix_sum, + const std::optional& cached_gbl_channel_prefix_matrix, + const std::optional& cached_recv_gbl_rank_prefix_sum, + int expert_alignment, + const Config& config, + std::optional& previous_event, + bool async, + bool allocate_on_comm_stream); + + std::tuple, std::optional> internode_combine( + const torch::Tensor& x, + const std::optional& topk_weights, + const std::optional& bias_0, + const std::optional& bias_1, + const torch::Tensor& src_meta, + const torch::Tensor& is_combined_token_in_rank, + const torch::Tensor& rdma_channel_prefix_matrix, + const torch::Tensor& rdma_rank_prefix_sum, + const torch::Tensor& gbl_channel_prefix_matrix, + const torch::Tensor& combined_rdma_head, + const torch::Tensor& combined_nvl_head, + const Config& config, + std::optional& previous_event, + bool async, + bool allocate_on_comm_stream); void clean_low_latency_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts); - std::tuple, torch::Tensor, torch::Tensor, torch::Tensor, std::optional, std::optional>> - low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_idx, - int num_max_dispatch_tokens_per_rank, int num_experts, - bool use_fp8, bool async, bool return_recv_hook); - - std::tuple, std::optional>> - low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_idx, const torch::Tensor& topk_weights, - const torch::Tensor& src_info, const torch::Tensor& layout_range, - int num_max_dispatch_tokens_per_rank, int num_experts, - bool zero_copy, bool async, bool return_recv_hook, - const std::optional& out = std::nullopt); - - torch::Tensor - get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts); - - // addtional interface for c++ - std::string get_local_ipc_handle_string() const; - std::string get_local_nvshmem_unique_id_string() const; - void 
sync_string(const std::vector& device_ids, const std::vector& all_gathered_handles, const std::string& root_unique_id_opt); + std::tuple, + torch::Tensor, + torch::Tensor, + torch::Tensor, + std::optional, + std::optional>> + low_latency_dispatch(const torch::Tensor& x, + const torch::Tensor& topk_idx, + const std::optional& cumulative_local_expert_recv_stats, + const std::optional& dispatch_wait_recv_cost_stats, + int num_max_dispatch_tokens_per_rank, + int num_experts, + bool use_fp8, + bool round_scale, + bool use_ue8m0, + bool async, + bool return_recv_hook); + + std::tuple, std::optional>> low_latency_combine( + const torch::Tensor& x, + const torch::Tensor& topk_idx, + const torch::Tensor& topk_weights, + const torch::Tensor& src_info, + const torch::Tensor& layout_range, + const std::optional& combine_wait_recv_cost_stats, + int num_max_dispatch_tokens_per_rank, + int num_experts, + bool use_logfmt, + bool zero_copy, + bool async, + bool return_recv_hook, + const std::optional& out = std::nullopt); + + torch::Tensor get_next_low_latency_combine_buffer(int num_max_dispatch_tokens_per_rank, int hidden, int num_experts) const; + + void low_latency_update_mask_buffer(int rank_to_mask, bool mask); + + void low_latency_query_mask_buffer(const torch::Tensor& mask_status); + + void low_latency_clean_mask_buffer(); }; -} // namespace deep_ep +} // namespace deep_ep diff --git a/csrc/kernels/api.cuh b/csrc/kernels/api.cuh index f18a859..c540346 100644 --- a/csrc/kernels/api.cuh +++ b/csrc/kernels/api.cuh @@ -7,7 +7,8 @@ namespace deep_ep { // Intranode runtime namespace intranode { -void barrier(int **task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream); +//void barrier(int **task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream); +void barrier(int **task_fifo_ptrs, int rank, int num_ranks, cudaStream_t stream, int head = 0); } // namespace intranode @@ -30,6 +31,24 @@ void finalize(); } // namespace internode + +// Layout kernels 
+namespace layout { + +void get_dispatch_layout(const topk_idx_t* topk_idx, + int* num_tokens_per_rank, + int* num_tokens_per_rdma_rank, + int* num_tokens_per_expert, + bool* is_token_in_rank, + int num_tokens, + int num_topk, + int num_ranks, + int num_experts, + cudaStream_t stream); + +} // namespace layout + + // Intranode kernels namespace intranode { @@ -37,32 +56,67 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts, int num_tokens, const bool* is_token_in_rank, int* channel_prefix_matrix, int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment, - void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank, - cudaStream_t stream, int num_sms); + void** buffer_ptrs, int **task_fifo_ptrs, int rank, + cudaStream_t stream, int num_sms, int head = 0); void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int, - void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank, int num_ranks, - cudaStream_t stream); - -void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset, - int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights, - const bool* is_token_in_rank, const int* channel_prefix_matrix, - int num_tokens, int hidden_int4, int num_topk, int num_experts, int num_scales, - void** buffer_ptrs, int rank, int num_ranks, - cudaStream_t stream, int num_sms, - int num_max_send_tokens, int num_recv_buffer_tokens); + void** buffer_ptrs, int **task_fifo_ptrs, int rank, int num_ranks, + cudaStream_t stream, int head = 0); + +void dispatch(void* recv_x, + float* recv_x_scales, + int* recv_src_idx, + topk_idx_t* recv_topk_idx, + float* recv_topk_weights, + int* recv_channel_offset, + int* send_head, + const void* x, + const float* x_scales, + const topk_idx_t* topk_idx, + const float* 
topk_weights, + const bool* is_token_in_rank, + const int* channel_prefix_matrix, + int num_tokens, + int num_worst_tokens, + int hidden_int4, + int num_topk, + int num_experts, + int num_scales, + int scale_token_stride, + int scale_hidden_stride, + void** buffer_ptrs, + int rank, + int num_ranks, + cudaStream_t stream, + int num_sms, + int num_max_send_tokens, + int num_recv_buffer_tokens); void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int num_recv_tokens, int num_memset_int, - int** task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream); + int** task_fifo_ptrs, int rank, int num_ranks, cudaStream_t stream, int head = 0); void combine(cudaDataType_t type, - void* recv_x, float* recv_topk_weights, - const void* x, const float* topk_weights, - const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix, - int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk, - void** buffer_ptrs, int rank, int num_ranks, - cudaStream_t stream, int num_sms, - int num_max_send_tokens, int num_recv_buffer_tokens); + void* recv_x, + float* recv_topk_weights, + const void* x, + const float* topk_weights, + //const void* bias_0, + //const void* bias_1, + const int* src_idx, + const int* rank_prefix_matrix, + const int* channel_prefix_matrix, + int* send_head, + int num_tokens, + int num_recv_tokens, + int hidden, + int num_topk, + void** buffer_ptrs, + int rank, + int num_ranks, + cudaStream_t stream, + int num_sms, + int num_max_send_tokens, + int num_recv_buffer_tokens); } // namespace intranode @@ -78,82 +132,210 @@ void get_dispatch_layout(const int64_t* topk_idx, int num_tokens, int num_topk, int num_ranks, int num_experts, cudaStream_t stream); -void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks, - const int* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped, - const int* num_tokens_per_expert, int* 
moe_recv_expert_counter_mapped, int num_experts, - const bool* is_token_in_rank, int num_tokens, int num_channels, - int hidden_int4, int num_scales, int num_topk, int expert_alignment, - int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum, - int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum, - void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens, - void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens, - int** task_fifo_ptrs, int head, int rank, - cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes, - bool low_latency_mode); - -void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv_topk_weights, void* recv_src_meta, - const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights, - int* send_rdma_head, int* send_nvl_head, - int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix, - const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum, - const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum, - int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts, +void notify_dispatch(const int* num_tokens_per_rank, + int* moe_recv_counter_mapped, + int num_ranks, + const int* num_tokens_per_rdma_rank, + int* moe_recv_rdma_counter_mapped, + const int* num_tokens_per_expert, + int* moe_recv_expert_counter_mapped, + int num_experts, + const bool* is_token_in_rank, + int num_tokens, + int num_channels, + int hidden_int4, + int num_scales, + int num_topk, + int expert_alignment, + int* rdma_channel_prefix_matrix, + int* recv_rdma_rank_prefix_sum, + int* gbl_channel_prefix_matrix, + int* recv_gbl_rank_prefix_sum, + void* rdma_buffer_ptr, + int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, + int num_max_nvl_chunked_recv_tokens, + int** barrier_signal_ptrs, + int rank, + cudaStream_t stream, + int64_t num_rdma_bytes, + int64_t num_nvl_bytes, + bool low_latency_mode, + int head = 0); + 
+void dispatch(void* recv_x, + float* recv_x_scales, + topk_idx_t* recv_topk_idx, + float* recv_topk_weights, + void* recv_src_meta, + const void* x, + const float* x_scales, + const topk_idx_t* topk_idx, + const float* topk_weights, + int* send_rdma_head, + int* send_nvl_head, + int* recv_rdma_channel_prefix_matrix, + int* recv_gbl_channel_prefix_matrix, + const int* rdma_channel_prefix_matrix, + const int* recv_rdma_rank_prefix_sum, + const int* gbl_channel_prefix_matrix, + const int* recv_gbl_rank_prefix_sum, const bool* is_token_in_rank, - void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, - void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, - int rank, int num_ranks, bool is_cached_dispatch, - cudaStream_t stream, int num_channels, bool low_latency_mode); - -void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights, - int num_ranks, int num_channels, int num_combined_tokens, int* combined_rdma_head, - const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head, - void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens, - void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens, - int** task_fifo_ptrs, int head, int rank, cudaStream_t stream, - int64_t num_rdma_bytes, int64_t num_nvl_bytes, - bool is_cached_dispatch, bool low_latency_mode); + int num_tokens, + int hidden_int4, + int num_scales, + int num_topk, + int num_experts, + int scale_token_stride, + int scale_hidden_stride, + void* rdma_buffer_ptr, + int num_max_rdma_chunked_send_tokens, + int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, + int num_max_nvl_chunked_send_tokens, + int num_max_nvl_chunked_recv_tokens, + int rank, + int num_ranks, + bool is_cached_dispatch, + cudaStream_t stream, + int num_channels, + bool low_latency_mode); + +void cached_notify(int hidden_int4, + int num_scales, + int num_topk_idx, + int 
num_topk_weights, + int num_ranks, + int num_channels, + int num_combined_tokens, + int* combined_rdma_head, + const int* rdma_channel_prefix_matrix, + const int* rdma_rank_prefix_sum, + int* combined_nvl_head, + void* rdma_buffer_ptr, + int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, + int num_max_nvl_chunked_recv_tokens, + int** barrier_signal_ptrs, + int rank, + cudaStream_t stream, + int64_t num_rdma_bytes, + int64_t num_nvl_bytes, + bool is_cached_dispatch, + bool low_latency_mode, + int head = 0); + void combine(cudaDataType_t type, - void* combined_x, float* combined_topk_weights, + void* combined_x, + float* combined_topk_weights, const bool* is_combined_token_in_rank, - const void* x, const float* topk_weights, - const int* combined_rdma_head, const int* combined_nvl_head, - const void* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix, - int num_tokens, int num_combined_tokens, int hidden, int num_topk, - void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, - void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, - int rank, int num_ranks, cudaStream_t stream, int num_channels, bool low_latency_mode); + const void* x, + const float* topk_weights, + const void* bias_0, + const void* bias_1, + const int* combined_rdma_head, + const int* combined_nvl_head, + const void* src_meta, + const int* rdma_channel_prefix_matrix, + const int* rdma_rank_prefix_sum, + const int* gbl_channel_prefix_matrix, + int num_tokens, + int num_combined_tokens, + int hidden, + int num_topk, + void* rdma_buffer_ptr, + int num_max_rdma_chunked_send_tokens, + int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, + int num_max_nvl_chunked_send_tokens, + int num_max_nvl_chunked_recv_tokens, + int rank, + int num_ranks, + cudaStream_t stream, + int num_channels, + bool low_latency_mode); + } // namespace internode #if 
!DISABLE_INTERNODE // Internode low-latency kernels namespace internode_ll { -void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0, - int64_t* clean_1, int num_clean_int_1, +void clean_low_latency_buffer(int64_t* clean_0, + int num_clean_int_0, + int64_t* clean_1, + int num_clean_int_1, + int rank, + int num_ranks, + int* mask_buffer_ptr, + int* sync_buffer_ptr, cudaStream_t stream); -void dispatch(void* packed_recv_x, float* packed_recv_x_scales, - int* packed_recv_src_info, int64_t* packed_recv_layout_range, +void dispatch(void* packed_recv_x, + float* packed_recv_x_scales, + int* packed_recv_src_info, + int64_t* packed_recv_layout_range, int* packed_recv_count, - int* global_atomic_counter, - void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x, - const void* x, const int64_t* topk_idx, - int64_t* next_clean, int num_next_clean_int, - int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank, - int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8, - void* workspace, cudaStream_t stream, int phases); + int* mask_buffer_ptr, + int* cumulative_local_expert_recv_stats, + int64_t* dispatch_wait_recv_cost_stats, + void* rdma_recv_x, + int64_t* rdma_recv_count, + void* rdma_x, + const void* x, + const topk_idx_t* topk_idx, + int64_t* next_clean, + int num_next_clean_int, + int num_tokens, + int hidden, + int num_max_dispatch_tokens_per_rank, + int num_topk, + int num_experts, + int rank, + int num_ranks, + bool use_fp8, + bool round_scale, + bool use_ue8m0, + void* workspace, + int num_device_sms, + cudaStream_t stream, + int phases, + int* global_atomic_counter = NULL); void combine(void* combined_x, - void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x, - const void* x, const int64_t* topk_idx, const float* topk_weights, - const int* src_info, const int64_t* layout_range, - int* global_atomic_counter, - int64_t* next_clean, int num_next_clean_int, - int num_combined_tokens, int hidden, int 
num_max_dispatch_tokens_per_rank, - int num_topk, int num_experts, int rank, int num_ranks, - void* workspace, cudaStream_t stream, - int phases, bool zero_copy); + void* rdma_recv_x, + int64_t* rdma_recv_flag, + void* rdma_send_x, + const void* x, + const topk_idx_t* topk_idx, + const float* topk_weights, + const int* src_info, + const int64_t* layout_range, + int* mask_buffer_ptr, + int64_t* combine_wait_recv_cost_stats, + int64_t* next_clean, + int num_next_clean_int, + int num_combined_tokens, + int hidden, + int num_max_dispatch_tokens_per_rank, + int num_topk, + int num_experts, + int rank, + int num_ranks, + bool use_logfmt, + void* workspace, + int num_device_sms, + cudaStream_t stream, + int phases, + bool zero_copy, + int* global_atomic_counter = NULL); + +void query_mask_buffer(int* mask_buffer_ptr, int num_ranks, int* output_mask_tensor, cudaStream_t stream); + +void update_mask_buffer(int* mask_buffer_ptr, int rank_to_mask, bool mask, cudaStream_t stream); + +void clean_mask_buffer(int* mask_buffer_ptr, int num_ranks, cudaStream_t stream); } // namespace internode_ll #endif diff --git a/csrc/kernels/configs.cuh b/csrc/kernels/configs.cuh index 4f53a58..9aee4ae 100644 --- a/csrc/kernels/configs.cuh +++ b/csrc/kernels/configs.cuh @@ -29,6 +29,21 @@ __host__ __device__ __forceinline__ void host_device_printf(const char* format, #define printf host_device_printf #endif +namespace deep_ep { + +#ifndef TOPK_IDX_BITS +#define TOPK_IDX_BITS 64 +#endif + +#define INT_BITS_T2(bits) int##bits##_t +#define INT_BITS_T(bits) INT_BITS_T2(bits) +typedef INT_BITS_T(TOPK_IDX_BITS) topk_idx_t; // int32_t or int64_t +#undef INT_BITS_T +#undef INT_BITS_T2 + +} // namespace deep_ep + + #ifdef USE_ROCM static constexpr int32_t kWarpSize = 64; // For ROCm equals to half the wave size or Nvidia warp size diff --git a/csrc/kernels/exception.cuh b/csrc/kernels/exception.cuh index 81b4be9..77ae7a6 100644 --- a/csrc/kernels/exception.cuh +++ b/csrc/kernels/exception.cuh @@ -46,7 
+46,7 @@ do { \ do { \ if (not (cond)) { \ printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \ - trap(); \ + abort();\ } \ } while (0) #else diff --git a/csrc/kernels/internode.cu b/csrc/kernels/internode.cu index 7197d53..f38d203 100644 --- a/csrc/kernels/internode.cu +++ b/csrc/kernels/internode.cu @@ -406,18 +406,36 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, in } } -void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks, - const int* num_tokens_per_rdma_rank, int* moe_recv_rdma_counter_mapped, - const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts, - const bool* is_token_in_rank, int num_tokens, int num_channels, - int hidden_int4, int num_scales, int num_topk, int expert_alignment, - int* rdma_channel_prefix_matrix, int* recv_rdma_rank_prefix_sum, - int* gbl_channel_prefix_matrix, int* recv_gbl_rank_prefix_sum, - void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens, - void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens, - int** task_fifo_ptrs, int head, int rank, - cudaStream_t stream, int64_t num_rdma_bytes, int64_t num_nvl_bytes, - bool low_latency_mode) { +void notify_dispatch(const int* num_tokens_per_rank, + int* moe_recv_counter_mapped, + int num_ranks, + const int* num_tokens_per_rdma_rank, + int* moe_recv_rdma_counter_mapped, + const int* num_tokens_per_expert, + int* moe_recv_expert_counter_mapped, + int num_experts, + const bool* is_token_in_rank, + int num_tokens, + int num_channels, + int hidden_int4, + int num_scales, + int num_topk, + int expert_alignment, + int* rdma_channel_prefix_matrix, + int* recv_rdma_rank_prefix_sum, + int* gbl_channel_prefix_matrix, + int* recv_gbl_rank_prefix_sum, + void* rdma_buffer_ptr, + int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, + int num_max_nvl_chunked_recv_tokens, + int** barrier_signal_ptrs, + int rank, + cudaStream_t stream, + int64_t 
num_rdma_bytes, + int64_t num_nvl_bytes, + bool low_latency_mode, + int head = 0) { #define NOTIFY_DISPATCH_LAUNCH_CASE(num_rdma_ranks) { \ auto notify_dispatch_func = low_latency_mode ? \ notify_dispatch : notify_dispatch; \ @@ -431,7 +449,7 @@ void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mappe rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, \ gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, \ rdma_buffer_ptr, \ - buffer_ptrs, task_fifo_ptrs, head, rank, \ + buffer_ptrs, barrier_signal_ptrs, head, rank, \ cpu_rdma_team); } break constexpr int kNumThreads = 256; @@ -1135,18 +1153,43 @@ asm volatile( #endif } -void dispatch(void* recv_x, float* recv_x_scales, int64_t* recv_topk_idx, float* recv_topk_weights, void* recv_src_meta, - const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights, - int* send_rdma_head, int* send_nvl_head, - int* recv_rdma_channel_prefix_matrix, int* recv_gbl_channel_prefix_matrix, - const int* rdma_channel_prefix_matrix, const int* recv_rdma_rank_prefix_sum, - const int* gbl_channel_prefix_matrix, const int* recv_gbl_rank_prefix_sum, - int num_tokens, int hidden_int4, int num_scales, int num_topk, int num_experts, +void dispatch(void* recv_x, + float* recv_x_scales, + topk_idx_t* recv_topk_idx, + float* recv_topk_weights, + void* recv_src_meta, + const void* x, + const float* x_scales, + const topk_idx_t* topk_idx, + const float* topk_weights, + int* send_rdma_head, + int* send_nvl_head, + int* recv_rdma_channel_prefix_matrix, + int* recv_gbl_channel_prefix_matrix, + const int* rdma_channel_prefix_matrix, + const int* recv_rdma_rank_prefix_sum, + const int* gbl_channel_prefix_matrix, + const int* recv_gbl_rank_prefix_sum, const bool* is_token_in_rank, - void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, - void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, - int rank, int 
num_ranks, bool is_cached_dispatch, - cudaStream_t stream, int num_channels, bool low_latency_mode) { + int num_tokens, + int hidden_int4, + int num_scales, + int num_topk, + int num_experts, + int scale_token_stride, + int scale_hidden_stride, + void* rdma_buffer_ptr, + int num_max_rdma_chunked_send_tokens, + int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, + int num_max_nvl_chunked_send_tokens, + int num_max_nvl_chunked_recv_tokens, + int rank, + int num_ranks, + bool is_cached_dispatch, + cudaStream_t stream, + int num_channels, + bool low_latency_mode) { constexpr int kNumDispatchRDMASenderWarps = 7; #define DISPATCH_LAUNCH_CASE(num_rdma_ranks) { \ @@ -1284,14 +1327,29 @@ __global__ void cached_notify(const int rdma_clean_offset, const int rdma_num_in } } -void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_topk_weights, - int num_ranks, int num_channels, int num_combined_tokens, int* combined_rdma_head, - const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, int* combined_nvl_head, - void* rdma_buffer_ptr, int num_max_rdma_chunked_recv_tokens, - void** buffer_ptrs, int num_max_nvl_chunked_recv_tokens, - int** task_fifo_ptrs, int head, int rank, cudaStream_t stream, - int64_t num_rdma_bytes, int64_t num_nvl_bytes, - bool is_cached_dispatch, bool low_latency_mode) { +void cached_notify(int hidden_int4, + int num_scales, + int num_topk_idx, + int num_topk_weights, + int num_ranks, + int num_channels, + int num_combined_tokens, + int* combined_rdma_head, + const int* rdma_channel_prefix_matrix, + const int* rdma_rank_prefix_sum, + int* combined_nvl_head, + void* rdma_buffer_ptr, + int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, + int num_max_nvl_chunked_recv_tokens, + int** barrier_signal_ptrs, + int rank, + cudaStream_t stream, + int64_t num_rdma_bytes, + int64_t num_nvl_bytes, + bool is_cached_dispatch, + bool low_latency_mode, + int head = 0) { const int num_threads = std::max(128, kWarpSize * 
num_channels); const auto num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; @@ -1313,7 +1371,7 @@ void cached_notify(int hidden_int4, int num_scales, int num_topk_idx, int num_to combined_rdma_head, num_combined_tokens, num_channels, rdma_channel_prefix_matrix, rdma_rank_prefix_sum, combined_nvl_head, rdma_buffer_ptr, - buffer_ptrs, task_fifo_ptrs, head, rank, num_ranks, + buffer_ptrs, barrier_signal_ptrs, head, rank, num_ranks, is_cached_dispatch, cpu_rdma_team); } @@ -1962,16 +2020,36 @@ combine(int4* combined_x, float* combined_topk_weights, #endif } + void combine(cudaDataType_t type, - void* combined_x, float* combined_topk_weights, + void* combined_x, + float* combined_topk_weights, const bool* is_combined_token_in_rank, - const void* x, const float* topk_weights, - const int* combined_rdma_head, const int* combined_nvl_head, - const void* src_meta, const int* rdma_channel_prefix_matrix, const int* rdma_rank_prefix_sum, const int* gbl_channel_prefix_matrix, - int num_tokens, int num_combined_tokens, int hidden, int num_topk, - void* rdma_buffer_ptr, int num_max_rdma_chunked_send_tokens, int num_max_rdma_chunked_recv_tokens, - void** buffer_ptrs, int num_max_nvl_chunked_send_tokens, int num_max_nvl_chunked_recv_tokens, - int rank, int num_ranks, cudaStream_t stream, int num_channels, bool low_latency_mode) + const void* x, + const float* topk_weights, + const void* bias_0, + const void* bias_1, + const int* combined_rdma_head, + const int* combined_nvl_head, + const void* src_meta, + const int* rdma_channel_prefix_matrix, + const int* rdma_rank_prefix_sum, + const int* gbl_channel_prefix_matrix, + int num_tokens, + int num_combined_tokens, + int hidden, + int num_topk, + void* rdma_buffer_ptr, + int num_max_rdma_chunked_send_tokens, + int num_max_rdma_chunked_recv_tokens, + void** buffer_ptrs, + int num_max_nvl_chunked_send_tokens, + int num_max_nvl_chunked_recv_tokens, + int rank, + int num_ranks, + cudaStream_t stream, + int num_channels, + bool 
low_latency_mode) { const int num_rdma_ranks = num_ranks / NUM_MAX_NVL_PEERS; diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu index 6252cff..66ce3c6 100644 --- a/csrc/kernels/internode_ll.cu +++ b/csrc/kernels/internode_ll.cu @@ -63,8 +63,14 @@ __global__ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0, #endif } -void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0, - int64_t* clean_1, int num_clean_int_1, +void clean_low_latency_buffer(int64_t* clean_0, + int num_clean_int_0, + int64_t* clean_1, + int num_clean_int_1, + int rank, + int num_ranks, + int* mask_buffer_ptr, + int* sync_buffer_ptr, cudaStream_t stream) { constexpr int kNumThreads = 256; @@ -187,12 +193,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, for (int j = 0;j < kNumElemsPerRead;j += 2) { float2 fp32x2 = {fp32_values[j] * scale, fp32_values[j + 1] * scale}; #ifdef USE_ROCM -#if defined(__gfx942__) fp8x2_values[j / 2] = __hip_cvt_float2_to_fp8x2(fp32x2, __HIP_SATFINITE, __HIP_E4M3_FNUZ); -#endif -#if defined(__gfx950__) - fp8x2_values[j / 2] = __hip_cvt_float2_to_fp8x2(fp32x2, __HIP_SATFINITE, __HIP_E4M3); -#endif #else fp8x2_values[j / 2] = __nv_cvt_float2_to_fp8x2(fp32x2, __NV_SATFINITE, __NV_E4M3); #endif @@ -230,8 +231,13 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, #else //DISABLE_CTX internode::shmem_ctx_schar_put_nbi_warp(ctx, reinterpret_cast(dst_ptr), reinterpret_cast(src_ptr), num_bytes_per_msg, dst_rank); #endif - } -#else //USE_ROCM + reinterpret_cast(dst_ptr), reinterpret_cast(src_ptr), num_bytes_per_msg, dst_rank); +#if defined(ROCM_DISABLE_CTX) + internode::shmem_fence(); +#else + internode::shmem_ctx_quiet(ctx); +#endif +#else nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, dst_rank, dst_expert_local_idx, lane_id, slot_idx); #endif } else { @@ -321,7 +327,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_count + 
dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank, dst_expert_local_idx); #endif } else { - st_na_release(reinterpret_cast(rdma_recv_count + dst_expert_local_idx * num_ranks + rank), -num_tokens_sent - 1); + st_na_release(reinterpret_cast(rdma_recv_count + dst_expert_local_idx * num_ranks + rank), -num_tokens_sent - 1); } // Clean workspace for next use @@ -437,16 +443,36 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, #endif } -void dispatch(void* packed_recv_x, float* packed_recv_x_scales, - int* packed_recv_src_info, int64_t* packed_recv_layout_range, +void dispatch(void* packed_recv_x, + float* packed_recv_x_scales, + int* packed_recv_src_info, + int64_t* packed_recv_layout_range, int* packed_recv_count, - int* global_atomic_counter, - void* rdma_recv_x, int64_t* rdma_recv_count, void* rdma_x, - const void* x, const int64_t* topk_idx, - int64_t* next_clean, int num_next_clean_int, - int num_tokens, int hidden, int num_max_dispatch_tokens_per_rank, - int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8, - void* workspace, cudaStream_t stream, int phases) { + int* mask_buffer_ptr, + int* cumulative_local_expert_recv_stats, + int64_t* dispatch_wait_recv_cost_stats, + void* rdma_recv_x, + int64_t* rdma_recv_count, + void* rdma_x, + const void* x, + const topk_idx_t* topk_idx, + int64_t* next_clean, + int num_next_clean_int, + int num_tokens, + int hidden, + int num_max_dispatch_tokens_per_rank, + int num_topk, + int num_experts, + int rank, + int num_ranks, + bool use_fp8, + bool round_scale, + bool use_ue8m0, + void* workspace, + int num_device_sms, + cudaStream_t stream, + int phases, + int* global_atomic_counter = NULL) { #ifdef USE_ROCM constexpr int kNumWarpsPerGroup = 8; @@ -604,7 +630,6 @@ combine(void* combined_x, internode::shmem_ctx_quiet(ctx); #endif } - } } // Put finishing flag @@ -636,7 +661,7 @@ combine(void* combined_x, nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_flag + global_expert_idx, 1, 
dst_rank, local_expert_idx); #endif //USE_ROCM } else { - st_na_release(reinterpret_cast(rdma_recv_flag + global_expert_idx), 1); + st_na_release(reinterpret_cast(rdma_recv_flag + global_expert_idx), 1); } atomic_add_relaxed_global(atomic_clean_flag, -1); if constexpr (kMultinode){ @@ -718,15 +743,32 @@ combine(void* combined_x, } void combine(void* combined_x, - void* rdma_recv_x, int64_t* rdma_recv_flag, void* rdma_send_x, - const void* x, const int64_t* topk_idx, const float* topk_weights, - const int* src_info, const int64_t* layout_range, - int* global_atomic_counter, - int64_t* next_clean, int num_next_clean_int, - int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank, - int num_topk, int num_experts, int rank, int num_ranks, - void* workspace, cudaStream_t stream, - int phases, bool zero_copy) { + void* rdma_recv_x, + int64_t* rdma_recv_flag, + void* rdma_send_x, + const void* x, + const topk_idx_t* topk_idx, + const float* topk_weights, + const int* src_info, + const int64_t* layout_range, + int* mask_buffer_ptr, + int64_t* combine_wait_recv_cost_stats, + int64_t* next_clean, + int num_next_clean_int, + int num_combined_tokens, + int hidden, + int num_max_dispatch_tokens_per_rank, + int num_topk, + int num_experts, + int rank, + int num_ranks, + bool use_logfmt, + void* workspace, + int num_device_sms, + cudaStream_t stream, + int phases, + bool zero_copy, + int* global_atomic_counter = NULL) { #ifdef USE_ROCM constexpr int kNumWarpsPerGroup = 8; constexpr int kNumWarpGroups = 2; @@ -764,6 +806,57 @@ LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func, \ #undef COMBINE_LAUNCH_CASE } + +template +__launch_bounds__(kNumThreads, 1) __global__ void query_mask_buffer(int* mask_buffer_ptr, int num_ranks, int* mask_tensor) { + const auto num_sms = static_cast(gridDim.x); + const auto sm_id = static_cast(blockIdx.x); + const auto num_threads = num_sms * kNumThreads; + const auto thread_id = sm_id * kNumThreads + static_cast(threadIdx.x); + for (int 
rank_id = thread_id; rank_id < num_ranks; rank_id += num_threads) { + mask_tensor[rank_id] = mask_buffer_ptr[rank_id]; + } +} + +void query_mask_buffer(int* mask_buffer_ptr, int num_ranks, int* mask_tensor, cudaStream_t stream) { + constexpr int num_sms = 1; + constexpr int kNumThreads = 1024; + SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream); + LAUNCH_KERNEL(&cfg, query_mask_buffer, mask_buffer_ptr, num_ranks, mask_tensor); +} + +template +__launch_bounds__(kNumThreads, 1) __global__ void update_mask_buffer(int* mask_buffer_ptr, int rank_to_mask, bool mask) { + const auto sm_id = static_cast(blockIdx.x); + const auto thread_id = static_cast(threadIdx.x); + if (sm_id == 0 && thread_id == 0) { + atomicExch(mask_buffer_ptr + rank_to_mask, mask ? 1 : 0); + } +} + +void update_mask_buffer(int* mask_buffer_ptr, int rank, bool mask, cudaStream_t stream) { + constexpr int num_sms = 1; + constexpr int kNumThreads = 32; + SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream); + LAUNCH_KERNEL(&cfg, update_mask_buffer, mask_buffer_ptr, rank, mask); +} + +template +__launch_bounds__(kNumThreads, 1) __global__ void clean_mask_buffer(int* mask_buffer_ptr, int num_ranks) { + auto thread_id = static_cast(threadIdx.x); + #pragma unroll + for (int i = thread_id; i < num_ranks; i += kNumThreads) + mask_buffer_ptr[i] = 0; +} + +void clean_mask_buffer(int* mask_buffer_ptr, int num_ranks, cudaStream_t stream) { + constexpr int num_sms = 1; + constexpr int kNumThreads = 32; + SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream); + LAUNCH_KERNEL(&cfg, clean_mask_buffer, mask_buffer_ptr, num_ranks); +} + + } // namespace internode_ll } // namespace deep_ep diff --git a/csrc/kernels/intranode.cu b/csrc/kernels/intranode.cu index 6738ca8..3ef7e35 100644 --- a/csrc/kernels/intranode.cu +++ b/csrc/kernels/intranode.cu @@ -18,13 +18,14 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, auto sm_id = static_cast(blockIdx.x); auto thread_id = static_cast(threadIdx.x), 
num_threads = static_cast(blockDim.x); auto lane_id = thread_id % kWarpSize, warp_id = thread_id / kWarpSize, num_warps = num_threads / kWarpSize; - + if (sm_id == 0) { // Barrier first barrier_device(task_fifo_ptrs, head, rank); move_fifo_slots(head); __syncthreads(); + int *per_rank_buffer, *per_expert_buffer; if (thread_id < kNumRanks) { per_rank_buffer = reinterpret_cast(buffer_ptrs[thread_id]); @@ -36,9 +37,10 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, // - `per_expert_buffer[rank][i, j]` means the number of tokens from rank i to local expert j int num_experts_per_rank = num_experts / kNumRanks; if (thread_id < kNumRanks) { - #pragma unroll - for (int i = 0; i < kNumRanks; ++ i) - per_rank_buffer[rank * kNumRanks + i] = num_tokens_per_rank[i]; + //#pragma unroll + //for (int i = 0; i < kNumRanks; ++ i) + // per_rank_buffer[rank * kNumRanks + i] = num_tokens_per_rank[i]; + per_rank_buffer[rank * kNumRanks + thread_id] = num_tokens_per_rank[thread_id]; #pragma unroll for (int i = 0; i < num_experts_per_rank; ++ i) per_expert_buffer[rank * num_experts_per_rank + i] = num_tokens_per_expert[thread_id * num_experts_per_rank + i]; @@ -112,25 +114,39 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, } } -void notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, int num_ranks, - const int* num_tokens_per_expert, int* moe_recv_expert_counter_mapped, int num_experts, - int num_tokens, const bool* is_token_in_rank, int* channel_prefix_matrix, - int* rank_prefix_matrix_copy, int num_memset_int, int expert_alignment, - void** buffer_ptrs, int **task_fifo_ptrs, int head, int rank, - cudaStream_t stream, int num_channels) { +void notify_dispatch(const int* num_tokens_per_rank, + int* moe_recv_counter_mapped, + int num_ranks, + const int* num_tokens_per_expert, + int* moe_recv_expert_counter_mapped, + int num_experts, + int num_tokens, + const bool* is_token_in_rank, + int* 
channel_prefix_matrix, + int* rank_prefix_matrix_copy, + int num_memset_int, + int expert_alignment, + void** buffer_ptrs, + int** barrier_signal_ptrs, + int rank, + cudaStream_t stream, + int num_channels, + int head=0) { + #define NOTIFY_DISPATCH_LAUNCH_CASE(ranks) \ LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, notify_dispatch, \ num_tokens_per_rank, moe_recv_counter_mapped, \ num_tokens_per_expert, moe_recv_expert_counter_mapped, num_experts, \ num_tokens, num_channels, is_token_in_rank, channel_prefix_matrix, \ rank_prefix_matrix_copy, num_memset_int, expert_alignment, \ - buffer_ptrs, task_fifo_ptrs, head, rank); \ + buffer_ptrs, barrier_signal_ptrs, head, rank); \ break constexpr int kNumThreads = 128; EP_HOST_ASSERT(num_experts % num_ranks == 0); EP_HOST_ASSERT(num_experts / num_ranks <= kNumThreads and num_ranks <= kNumThreads); + SETUP_LAUNCH_CONFIG(1 + num_ranks, kNumThreads, stream); SWITCH_RANKS(NOTIFY_DISPATCH_LAUNCH_CASE); #undef NOTIFY_DISPATCH_LAUNCH_CASE @@ -177,7 +193,7 @@ cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int, void cached_notify_dispatch(const int* rank_prefix_matrix, int num_memset_int, void** buffer_ptrs, int** task_fifo_ptrs, - int head, int rank, int num_ranks, cudaStream_t stream) { + int rank, int num_ranks, cudaStream_t stream, int head=0) { #define CACHED_NOTIFY_DISPATCH_LAUNCH_CASE(ranks) \ LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, cached_notify_dispatch, \ rank_prefix_matrix, num_memset_int, buffer_ptrs, task_fifo_ptrs, head, rank); \ @@ -493,12 +509,36 @@ dispatch(int4* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_to } } -void dispatch(void* recv_x, float* recv_x_scales, int* recv_src_idx, int64_t* recv_topk_idx, float* recv_topk_weights, int* recv_channel_offset, - int* send_head, const void* x, const float* x_scales, const int64_t* topk_idx, const float* topk_weights, - const bool* is_token_in_rank, const int* channel_prefix_matrix, - int num_tokens, int hidden_int4, int num_topk, int 
num_experts, int num_scales, - void** buffer_ptrs, int rank, int num_ranks, - cudaStream_t stream, int num_sms, int num_max_send_tokens, int num_recv_buffer_tokens) { + +void dispatch(void* recv_x, + float* recv_x_scales, + int* recv_src_idx, + topk_idx_t* recv_topk_idx, + float* recv_topk_weights, + int* recv_channel_offset, + int* send_head, + const void* x, + const float* x_scales, + const topk_idx_t* topk_idx, + const float* topk_weights, + const bool* is_token_in_rank, + const int* channel_prefix_matrix, + int num_tokens, + int num_worst_tokens,// + int hidden_int4, + int num_topk, + int num_experts, + int num_scales, + int scale_token_stride,// + int scale_hidden_stride,// + void** buffer_ptrs, + int rank, + int num_ranks, + cudaStream_t stream, + int num_sms, + int num_max_send_tokens, + int num_recv_buffer_tokens) +{ constexpr int kNumThreads = (kWarpSize == 64 ? 1024 : 512); #define DISPATCH_LAUNCH_CASE(ranks) \ @@ -574,8 +614,8 @@ cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int void cached_notify_combine(void** buffer_ptrs, int* send_head, int num_channels, int num_recv_tokens, int num_memset_int, - int** task_fifo_ptrs, int head, int rank, int num_ranks, - cudaStream_t stream) { + int** task_fifo_ptrs, int rank, int num_ranks, + cudaStream_t stream, int head = 0 ) { #define CACHED_NOTIFY_COMBINE(ranks) \ LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, cached_notify_combine, \ buffer_ptrs, send_head, num_channels, num_recv_tokens, num_memset_int, task_fifo_ptrs, head, rank); \ @@ -605,13 +645,13 @@ combine(dtype_t* recv_x, float* recv_topk_weights, const auto num_channels = num_sms / 2; const bool is_sender = sm_id % 2 == 0; const int responsible_channel = sm_id / 2; + EP_DEVICE_ASSERT(num_topk <= kWarpSize); constexpr int kDtypePerInt4 = sizeof(int4) / sizeof(dtype_t); int hidden_int4 = hidden * sizeof(dtype_t) / sizeof(int4); auto x_int4 = reinterpret_cast(x); auto recv_int4 = reinterpret_cast(recv_x); - if (is_sender) { // Workers 
for sending // Several warps are responsible for a single rank @@ -867,13 +907,25 @@ combine(dtype_t* recv_x, float* recv_topk_weights, } void combine(cudaDataType_t type, - void* recv_x, float* recv_topk_weights, - const void* x, const float* topk_weights, - const int* src_idx, const int* rank_prefix_matrix, const int* channel_prefix_matrix, - int* send_head, int num_tokens, int num_recv_tokens, int hidden, int num_topk, - void** buffer_ptrs, int rank, int num_ranks, - cudaStream_t stream, int num_sms, - int num_max_send_tokens, int num_recv_buffer_tokens) { + void* recv_x, + float* recv_topk_weights, + const void* x, + const float* topk_weights, + const int* src_idx, + const int* rank_prefix_matrix, + const int* channel_prefix_matrix, + int* send_head, + int num_tokens, + int num_recv_tokens, + int hidden, + int num_topk, + void** buffer_ptrs, + int rank, + int num_ranks, + cudaStream_t stream, + int num_sms, + int num_max_send_tokens, + int num_recv_buffer_tokens) { constexpr int kNumThreads = kWarpSize == 64 ? 
1024 : 768; #define COMBINE_LAUNCH_CASE(dtype, ranks) \ diff --git a/csrc/kernels/launch.cuh b/csrc/kernels/launch.cuh index 7ab2a40..a4a064d 100644 --- a/csrc/kernels/launch.cuh +++ b/csrc/kernels/launch.cuh @@ -82,6 +82,18 @@ inline void LAUNCH_KERNEL_NON_COOPERATIVE(T &&config, Kern &&kernel, #endif // #if defined(USE_ROCM) #endif // #ifndef LAUNCH_KERNEL + +#ifndef SET_SHARED_MEMORY_FOR_TMA +#ifndef DISABLE_SM90_FEATURES +#define SET_SHARED_MEMORY_FOR_TMA(kernel) \ + EP_HOST_ASSERT(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) == cudaSuccess); \ + cfg.dynamicSmemBytes = smem_size; +#else +#define SET_SHARED_MEMORY_FOR_TMA(kernel) void() +#endif +#endif + + #define SWITCH_RANKS(case_macro) \ switch (num_ranks) { \ case 2: case_macro(2); \ diff --git a/csrc/kernels/runtime.cu b/csrc/kernels/runtime.cu index e336f34..7fa6031 100644 --- a/csrc/kernels/runtime.cu +++ b/csrc/kernels/runtime.cu @@ -16,7 +16,7 @@ __global__ void barrier(int** task_fifo_ptrs, int head, int rank) { barrier_device(task_fifo_ptrs, head, rank); } -void barrier(int** task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream) { +/*void barrier(int** task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream) { #define BARRIER_LAUNCH_CASE(ranks) \ LAUNCH_KERNEL(&cfg, barrier, task_fifo_ptrs, head, rank); \ break @@ -24,6 +24,16 @@ void barrier(int** task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream SETUP_LAUNCH_CONFIG(1, kWarpSize, stream); SWITCH_RANKS(BARRIER_LAUNCH_CASE); #undef BARRIER_LAUNCH_CASE +}*/ + +void barrier(int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream, int head = 0) { +#define BARRIER_LAUNCH_CASE(ranks) \ + LAUNCH_KERNEL(&cfg, barrier, barrier_signal_ptrs, head, rank); \ + break + + SETUP_LAUNCH_CONFIG(1, kWarpSize, stream); + SWITCH_RANKS(BARRIER_LAUNCH_CASE); +#undef BARRIER_LAUNCH_CASE } } // namespace intranode @@ -53,8 +63,13 @@ int init(const std::vector 
&root_unique_id_val, int rank, int num_ranks if (low_latency_mode and num_ranks > NUM_MAX_NVL_PEERS) { EP_HOST_ASSERT(cpu_rdma_team == SHMEM_TEAM_INVALID); EP_HOST_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0); - EP_HOST_ASSERT(shmem_team_split_strided(SHMEM_TEAM_WORLD, rank % NUM_MAX_NVL_PEERS, NUM_MAX_NVL_PEERS, - num_ranks / NUM_MAX_NVL_PEERS, &cpu_rdma_team_config, 0, &cpu_rdma_team) == 0); + EP_HOST_ASSERT(shmem_team_split_strided(SHMEM_TEAM_WORLD, + rank % NUM_MAX_NVL_PEERS, + NUM_MAX_NVL_PEERS, + num_ranks / NUM_MAX_NVL_PEERS, + &cpu_rdma_team_config, + 0, + &cpu_rdma_team) == 0); //TODO::issue on ROCM: enable it for ROCM #ifndef USE_ROCM EP_HOST_ASSERT(cpu_rdma_team != SHMEM_TEAM_INVALID); diff --git a/csrc/kernels/shmem_wrapper.cuh b/csrc/kernels/shmem_wrapper.cuh index 6e84b6f..1e90717 100644 --- a/csrc/kernels/shmem_wrapper.cuh +++ b/csrc/kernels/shmem_wrapper.cuh @@ -65,7 +65,7 @@ static inline const auto &shmem_ibgda_amo_nonfetch_add = #if !defined(ROCM_DISABLE_CTX) using shmem_ctx_t = rocshmem::rocshmem_ctx_t; static inline const auto &shmem_wg_ctx_create = [] __device__(rocshmem::rocshmem_ctx_t *ctx) { - rocshmem::rocshmem_wg_ctx_create(0, ctx); + return rocshmem::rocshmem_wg_ctx_create(0, ctx); }; static inline const auto &shmem_wg_ctx_destroy = rocshmem::rocshmem_wg_ctx_destroy; diff --git a/csrc/kernels/utils.cuh b/csrc/kernels/utils.cuh index 02f6cd0..3410e0c 100644 --- a/csrc/kernels/utils.cuh +++ b/csrc/kernels/utils.cuh @@ -2,6 +2,10 @@ #include "exception.cuh" +#ifdef USE_ROCM +#define syncthreads() __syncthreads() +#endif + #define UNROLLED_WARP_COPY(UNROLL_FACTOR, LANE_ID, N, DST, SRC, LD_FUNC, ST_FUNC) \ { \ constexpr int kLoopStride = kWarpSize * (UNROLL_FACTOR); \ @@ -596,15 +600,6 @@ __device__ __forceinline__ void st_na_release(const uint32_t *ptr, uint32_t val) #endif } -__device__ __forceinline__ void st_na_release(const int64_t *ptr, int64_t val) { -#ifdef USE_ROCM - int64_t* non_const_ptr = const_cast(ptr); - 
__hip_atomic_store(non_const_ptr, val, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_AGENT); -#else - asm volatile("st.release.gpu.global.L1::no_allocate.b64 [%0], %1;" : : "l"(ptr), "l"(val)); -#endif -} - __device__ __forceinline__ void st_na_release(const uint64_t *ptr, uint64_t val) { #ifdef USE_ROCM uint64_t* non_const_ptr = const_cast(ptr); @@ -795,4 +790,64 @@ barrier_device(int **task_fifo_ptrs, int head, int rank, int tag = 0) { timeout_check(task_fifo_ptrs, head, rank, 0, tag); } + +template +__forceinline__ __device__ void barrier_block(int** barrier_signal_ptrs, int rank) { + auto thread_id = static_cast(threadIdx.x); + + // For non-sync-only cases, the memory operations by other threads in the block must be visible to the `sys` scope + if constexpr (not kSyncOnly) { + memory_fence(); + __syncthreads(); + } + + // Add self-ranks, sub other ranks + if (thread_id < kNumRanks) { + atomicAdd_system(barrier_signal_ptrs[rank] + thread_id, FINISHED_SUM_TAG); + atomicSub_system(barrier_signal_ptrs[thread_id] + rank, FINISHED_SUM_TAG); + } + EP_DEVICE_ASSERT(kNumRanks <= blockDim.x); + + // Check timeout + auto start_time = clock64(); + while (true) { + auto value = thread_id < kNumRanks ? 
ld_volatile_global(barrier_signal_ptrs[rank] + thread_id) : 0; + if (__all_sync(kFullWarpMask, value <= 0)) + break; + + if (clock64() - start_time > NUM_TIMEOUT_CYCLES and thread_id < kNumRanks) { + printf("DeepEP timeout check failed: rank = %d, thread = %d, value = %d)\n", rank, thread_id, value); + trap(); + } + } + __syncthreads(); +} + + +__device__ __forceinline__ uint32_t elect_one_sync() { +#ifndef DISABLE_SM90_FEATURES + uint32_t pred = 0; + asm volatile( + "{\n" + ".reg .b32 %%rx;\n" + ".reg .pred %%px;\n" + " elect.sync %%rx|%%px, %1;\n" + "@%%px mov.s32 %0, 1;\n" + "}\n" + : "+r"(pred) + : "r"(0xffffffff)); + return pred; +#else + return get_lane_id() == 0; +#endif +} + } // namespace deep_ep + +template +__host__ __device__ constexpr dtype_t align_down(dtype_t a, dtype_t b) { + return a / b * b; +} + + + diff --git a/deep_ep/__init__.py b/deep_ep/__init__.py index 7fb801f..2b20d83 100644 --- a/deep_ep/__init__.py +++ b/deep_ep/__init__.py @@ -4,4 +4,4 @@ from .buffer import Buffer # noinspection PyUnresolvedReferences -from deep_ep_cpp import Config +from deep_ep_cpp import Config, topk_idx_t diff --git a/deep_ep/buffer.py b/deep_ep/buffer.py index f0a8ee1..fcf8d77 100644 --- a/deep_ep/buffer.py +++ b/deep_ep/buffer.py @@ -7,7 +7,7 @@ import deep_ep_cpp # noinspection PyUnresolvedReferences from deep_ep_cpp import Config, EventHandle -from .utils import EventOverlap +from .utils import EventOverlap, check_nvlink_connections class Buffer: @@ -29,9 +29,18 @@ class Buffer: num_sms: int = 20 - def __init__(self, group: dist.ProcessGroup, - num_nvl_bytes: int = 0, num_rdma_bytes: int = 0, - low_latency_mode: bool = False, num_qps_per_rank: int = 1) -> None: + def __init__(self, + group: Optional[dist.ProcessGroup], + num_nvl_bytes: int = 0, + num_rdma_bytes: int = 0, + low_latency_mode: bool = False, + num_qps_per_rank: int = 24, + allow_nvlink_for_low_latency_mode: bool = True, + allow_mnnvl: bool = False, + use_fabric: bool = False, + explicitly_destroy: 
bool = False, + enable_shrink: bool = False, + comm: Optional["mpi4py.MPI.Comm"] = None) -> None: # noqa: F821 """ Initialize the communication buffer. @@ -42,56 +51,105 @@ def __init__(self, group: dist.ProcessGroup, low_latency_mode: whether to enable low-latency mode. num_qps_per_rank: the number of QPs for RDMA, the low-latency mode requires that this number equals to the number of local experts. - """ + allow_nvlink_for_low_latency_mode: whether allow NVLink traffic for low-latency mode, you should notice + this is somehow incompatible with the hook-based overlapping. + Warning: PCIe connections may lead to errors due to memory ordering issues, + please make sure all connections are via NVLink. + allow_mnnvl: whether to allow MNNVL + use_fabric: whether to use fabric API for memory buffers. + enable_shrink: whether to enable shrink mode. The enable mode allocates a mask buffer to support masking ranks dynamically. + explicitly_destroy: If this flag is set to True, you need to explicitly call `destroy()` to release resources; + otherwise, the resources will be released by the destructor. + Note: Releasing resources in the destructor may cause Python's exception handling process to hang. + comm: the `mpi4py.MPI.Comm` communicator to use in case the group parameter is absent. 
+ """ + check_nvlink_connections(group) # Initialize the CPP runtime - self.rank = group.rank() - self.group_size = group.size() - self.group = group + if group is not None: + self.rank = group.rank() + self.group = group + self.group_size = group.size() + + def all_gather_object(obj): + object_list = [None] * self.group_size + dist.all_gather_object(object_list, obj, group) + return object_list + elif comm is not None: + self.rank = comm.Get_rank() + self.group = comm + self.group_size = comm.Get_size() + + def all_gather_object(obj): + return comm.allgather(obj) + else: + raise ValueError("Either 'group' or 'comm' must be provided.") self.num_nvl_bytes = num_nvl_bytes self.num_rdma_bytes = num_rdma_bytes self.low_latency_mode = low_latency_mode - self.runtime = deep_ep_cpp.Buffer(self.rank, self.group_size, num_nvl_bytes, num_rdma_bytes, low_latency_mode) + self.explicitly_destroy = explicitly_destroy + self.enable_shrink = enable_shrink + self.runtime = deep_ep_cpp.Buffer(self.rank, self.group_size, num_nvl_bytes, num_rdma_bytes, low_latency_mode, explicitly_destroy, + enable_shrink)#, use_fabric) # Synchronize device IDs - device_ids = [None, ] * self.group_size local_device_id = self.runtime.get_local_device_id() - dist.all_gather_object(device_ids, local_device_id, group) + device_ids = all_gather_object(local_device_id) # Synchronize IPC handles - ipc_handles = [None, ] * self.group_size local_ipc_handle = self.runtime.get_local_ipc_handle() - dist.all_gather_object(ipc_handles, local_ipc_handle, group) + ipc_handles = all_gather_object(local_ipc_handle) # Synchronize NVSHMEM unique IDs root_unique_id = None if self.runtime.get_num_rdma_ranks() > 1 or low_latency_mode: - # Enable IBGDA for the low latency mode, which refers to "no package forwarding between NVLink and RDMA" - if low_latency_mode: - assert num_qps_per_rank > 0 - os.environ['NVSHMEM_DISABLE_P2P'] = '1' - os.environ['NVSHMEM_IB_ENABLE_IBGDA'] = '1' - os.environ['NVSHMEM_IBGDA_NIC_HANDLER'] = 
'gpu' - os.environ['NVSHMEM_IBGDA_NUM_RC_PER_PE'] = f'{num_qps_per_rank}' - # Make sure QP depth is always larger than the number of on-flight WRs, so that we can skip WQ slot check - os.environ['NVSHMEM_QP_DEPTH'] = '1024' - # NOTES: NVSHMEM initialization requires at least 256 MiB - os.environ['NVSHMEM_CUMEM_GRANULARITY'] = f'{2 ** 29}' - dev_id = torch.cuda.current_device() - os.environ["NVSHMEM_HCA_LIST"] = f"fic2_soe_bond{dev_id // 2}:1" + # Enable IBGDA + assert num_qps_per_rank > 0 + os.environ['NVSHMEM_DISABLE_P2P'] = '0' if allow_nvlink_for_low_latency_mode else '1' + os.environ['NVSHMEM_IB_ENABLE_IBGDA'] = '1' + os.environ['NVSHMEM_IBGDA_NUM_RC_PER_PE'] = f'{num_qps_per_rank}' + + # Make sure QP depth is always larger than the number of on-flight WRs, so that we can skip WQ slot check + self.nvshmem_qp_depth = int(os.environ.get('NVSHMEM_QP_DEPTH', '1024')) + os.environ['NVSHMEM_QP_DEPTH'] = str(self.nvshmem_qp_depth) + + # Reduce gpu memory usage + # 6 default teams + 1 extra team + os.environ['NVSHMEM_MAX_TEAMS'] = '7' + # Disable NVLink SHArP + os.environ['NVSHMEM_DISABLE_NVLS'] = '1' + # NOTES: NVSHMEM initialization requires at least 256 MiB + os.environ['NVSHMEM_CUMEM_GRANULARITY'] = f'{2 ** 29}' + + if not allow_mnnvl: + # Disable multi-node NVLink detection + os.environ['NVSHMEM_DISABLE_MNNVL'] = '1' + # Synchronize using the root ID - nvshmem_unique_ids = [None, ] * self.group_size if (low_latency_mode and self.rank == 0) or (not low_latency_mode and self.runtime.get_rdma_rank() == 0): root_unique_id = self.runtime.get_local_nvshmem_unique_id() - - dist.all_gather_object(nvshmem_unique_ids, root_unique_id, group) + nvshmem_unique_ids = all_gather_object(root_unique_id) root_unique_id = nvshmem_unique_ids[0 if low_latency_mode else self.runtime.get_root_rdma_rank(True)] # Make CPP runtime available self.runtime.sync(device_ids, ipc_handles, root_unique_id) - assert self.runtime.is_available() + def destroy(self): + """ + Destroy the cpp runtime 
and release resources. + + """ + + assert self.explicitly_destroy, '`explicitly_destroy` flag must be set' + + self.runtime.destroy() + self.runtime = None + + @staticmethod + def is_sm90_compiled(): + return deep_ep_cpp.is_sm90_compiled() + @staticmethod def set_num_sms(new_num_sms: int) -> None: """ @@ -130,8 +188,21 @@ def get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank: int, hidden """ return deep_ep_cpp.get_low_latency_rdma_size_hint(num_max_dispatch_tokens_per_rank, hidden, num_ranks, num_experts) - def get_local_buffer_tensor(self, dtype: torch.dtype, size: Optional[torch.Size] = None, - offset: int = 0, use_rdma_buffer: bool = False) -> torch.Tensor: + def get_comm_stream(self) -> torch.Stream: + """ + Get the communication stream. + + Returns: + stream: the communication stream. + """ + ts: torch.Stream = self.runtime.get_comm_stream() + return torch.cuda.Stream(stream_id=ts.stream_id, device_index=ts.device_index, device_type=ts.device_type) + + def get_local_buffer_tensor(self, + dtype: torch.dtype, + size: Optional[torch.Size] = None, + offset: int = 0, + use_rdma_buffer: bool = False) -> torch.Tensor: """ Get the raw buffer (slice supported) as a PyTorch tensor. @@ -148,6 +219,16 @@ def get_local_buffer_tensor(self, dtype: torch.dtype, size: Optional[torch.Size] assert tensor.numel() >= size.numel() return tensor[:size.numel()].view(size) + @staticmethod + def _unpack_bias(bias: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]): + bias_0, bias_1 = None, None + if isinstance(bias, torch.Tensor): + bias_0 = bias + elif isinstance(bias, tuple): + assert len(bias) == 2 + bias_0, bias_1 = bias + return bias_0, bias_1 + @staticmethod def get_dispatch_config(num_ranks: int) -> Config: """ @@ -160,6 +241,7 @@ def get_dispatch_config(num_ranks: int) -> Config: config: the recommended config. 
""" + # TODO: automatically tune config_map = { 2: Config(Buffer.num_sms, 16, 256, 6, 128), 4: Config(Buffer.num_sms, 16, 256, 6, 128), @@ -187,6 +269,7 @@ def get_combine_config(num_ranks: int) -> Config: config: the recommended config. """ + # TODO: automatically tune config_map = { 2: Config(Buffer.num_sms, 6, 256, 6, 128), 4: Config(Buffer.num_sms, 6, 256, 6, 128), @@ -211,8 +294,8 @@ def get_dispatch_layout(self, topk_idx: torch.Tensor, num_experts: int, Calculate the layout required for later communication. Arguments: - topk_idx: `[num_tokens, num_topk]`, dtype must be `torch.int64`, the expert indices selected by each token, - `-1` means no selections. + topk_idx: `[num_tokens, num_topk]`, dtype must be `deep_ep.topk_idx_t` (typically `torch.int64`), the expert + indices selected by each token, `-1` means no selections. num_experts: the number of experts. previous_event: the event to wait before actually executing the kernel. async_finish: the current stream will not wait for the communication kernels to be finished if set. @@ -236,7 +319,8 @@ def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], handle: Optional[Tuple] = None, num_tokens_per_rank: Optional[torch.Tensor] = None, num_tokens_per_rdma_rank: Optional[torch.Tensor] = None, is_token_in_rank: Optional[torch.Tensor] = None, num_tokens_per_expert: Optional[torch.Tensor] = None, - topk_idx: Optional[torch.Tensor] = None, topk_weights: Optional[torch.Tensor] = None, expert_alignment: int = 1, + topk_idx: Optional[torch.Tensor] = None, topk_weights: Optional[torch.Tensor] = None, + expert_alignment: int = 1, num_worst_tokens: int = 0, config: Optional[Config] = None, previous_event: Optional[EventOverlap] = None, async_finish: bool = False, allocate_on_comm_stream: bool = False) -> \ @@ -259,10 +343,12 @@ def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], rank (with the same GPU index), return `None` for intranode settings. 
is_token_in_rank: `[num_tokens, num_ranks]` with `torch.bool`, whether a token be sent to a rank. num_tokens_per_expert: `[num_experts]` with `torch.int`, the number of tokens to be sent to each expert. - topk_idx: `[num_tokens, num_topk]` with `torch.int64`, the expert indices selected by each token, - `-1` means no selections. + topk_idx: `[num_tokens, num_topk]` with `deep_ep.topk_idx_t` (typically `torch.int64`), the expert indices + selected by each token, `-1` means no selections. topk_weights: `[num_tokens, num_topk]` with `torch.float`, the expert weights of each token to dispatch. expert_alignment: align the number of tokens received by each local expert to this variable. + num_worst_tokens: the worst number of tokens to receive, if specified, there will be no CPU sync, and it + will be CUDA-graph compatible. Please also notice that this flag is for intranode only. config: the performance tuning config. previous_event: the event to wait before actually executing the kernel. async_finish: the current stream will not wait for the communication kernels to be finished if set. @@ -274,7 +360,8 @@ def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], recv_topk_idx: received expert indices. recv_topk_weights: received expert weights. num_recv_tokens_per_expert_list: Python list shaped `[num_local_experts]`, the received token count by - each local expert, aligned to the input `expert_alignment`. + each local expert, aligned to the input `expert_alignment`. If `num_worst_tokens` is specified, the list + will be empty. handle: the returned communication handle. event: the event after executing the kernel (valid only if `async_finish` is set). 
""" @@ -283,8 +370,10 @@ def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], # Internode if self.runtime.get_num_rdma_ranks() > 1: - return self.internode_dispatch(x, handle, num_tokens_per_rank, num_tokens_per_rdma_rank, is_token_in_rank, num_tokens_per_expert, - topk_idx, topk_weights, expert_alignment, config, previous_event, async_finish, allocate_on_comm_stream) + assert num_worst_tokens == 0, 'Internode dispatch does not support `num_worst_tokens > 0`' + return self.internode_dispatch(x, handle, num_tokens_per_rank, num_tokens_per_rdma_rank, is_token_in_rank, + num_tokens_per_expert, topk_idx, topk_weights, expert_alignment, config, previous_event, + async_finish, allocate_on_comm_stream) # Launch the kernel with cached or non-cached mode x, x_scales = x if isinstance(x, tuple) else (x, None) @@ -293,22 +382,26 @@ def dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, is_token_in_rank, send_head = handle num_recv_tokens = recv_src_idx.size(0) recv_x, recv_x_scales, _, _, _, _, _, _, _, _, event = self.runtime.intranode_dispatch( - x, x_scales, None, None, - None, is_token_in_rank, None, num_recv_tokens, rank_prefix_matrix, channel_prefix_matrix, - expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream) + x, x_scales, None, None, None, is_token_in_rank, None, num_recv_tokens, rank_prefix_matrix, channel_prefix_matrix, + expert_alignment, num_worst_tokens, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream) return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event) else: assert num_tokens_per_rank is not None and is_token_in_rank is not None and num_tokens_per_expert is not None recv_x, recv_x_scales, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, rank_prefix_matrix, 
channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, send_head, event = \ self.runtime.intranode_dispatch(x, x_scales, topk_idx, topk_weights, - num_tokens_per_rank, is_token_in_rank, num_tokens_per_expert, 0, None, None, - expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream) + num_tokens_per_rank, is_token_in_rank, num_tokens_per_expert, 0, None, None, + expert_alignment, num_worst_tokens, config, + getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream) handle = (rank_prefix_matrix, channel_prefix_matrix, recv_channel_prefix_matrix, recv_src_idx, is_token_in_rank, send_head) - return (recv_x, recv_x_scales) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(event) + return ( + recv_x, recv_x_scales + ) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap( + event) # noinspection PyTypeChecker def combine(self, x: torch.Tensor, handle: Tuple, topk_weights: Optional[torch.Tensor] = None, + bias: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = None, config: Optional[Config] = None, previous_event: Optional[EventOverlap] = None, async_finish: bool = False, allocate_on_comm_stream: bool = False) -> \ @@ -324,6 +417,7 @@ def combine(self, x: torch.Tensor, handle: Tuple, x: `[num_tokens, hidden]` with `torch.bfloat16`, the tokens to send for reducing to its original ranks. handle: a must-set communication handle, you can obtain this from the dispatch function. topk_weights: `[num_tokens, num_topk]` with `torch.float`, the tokens' top-k weights for reducing to its original ranks. + bias: 0, 1 or 2 `[num_tokens, hidden]` with `torch.bfloat16` final bias to the output. config: the performance tuning config. previous_event: the event to wait before actually executing the kernel. 
async_finish: the current stream will not wait for the communication kernels to be finished if set. @@ -339,16 +433,17 @@ def combine(self, x: torch.Tensor, handle: Tuple, # Internode if self.runtime.get_num_rdma_ranks() > 1: - return self.internode_combine(x, handle, topk_weights, config, previous_event, async_finish, allocate_on_comm_stream) + return self.internode_combine(x, handle, topk_weights, bias, config, previous_event, async_finish, allocate_on_comm_stream) # NOTES: the second `_` is for the sending side, so we should use the third one rank_prefix_matrix, _, channel_prefix_matrix, src_idx, is_recv_token_in_rank, send_head = handle + #bias_0, bias_1 = Buffer._unpack_bias(bias) # Launch the kernel - recv_x, recv_topk_weights, event = self.runtime.intranode_combine( - x, topk_weights, - src_idx, rank_prefix_matrix, channel_prefix_matrix, send_head, config, - getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream) + recv_x, recv_topk_weights, event = self.runtime.intranode_combine(x, topk_weights, src_idx, rank_prefix_matrix, + channel_prefix_matrix, send_head, config, + getattr(previous_event, 'event', + None), async_finish, allocate_on_comm_stream) return recv_x, recv_topk_weights, EventOverlap(event) # noinspection PyTypeChecker @@ -379,9 +474,7 @@ def internode_dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Te num_recv_tokens = recv_src_meta.size(0) num_rdma_recv_tokens = send_nvl_head.size(0) recv_x, recv_x_scales, _, _, _, _, _, _, _, _, _, _, _, _, event = self.runtime.internode_dispatch( - x, x_scales, topk_idx, topk_weights, - None, None, is_token_in_rank, None, - num_recv_tokens, num_rdma_recv_tokens, + x, x_scales, topk_idx, topk_weights, None, None, is_token_in_rank, None, num_recv_tokens, num_rdma_recv_tokens, rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, expert_alignment, config, getattr(previous_event, 'event', None), async_finish, 
allocate_on_comm_stream) return (recv_x, recv_x_scales) if x_scales is not None else recv_x, None, None, None, None, EventOverlap(event) @@ -396,15 +489,18 @@ def internode_dispatch(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Te num_tokens_per_rank, num_tokens_per_rdma_rank, is_token_in_rank, num_tokens_per_expert, 0, 0, None, None, None, None, expert_alignment, config, getattr(previous_event, 'event', None), async_finish, allocate_on_comm_stream) - handle = (is_token_in_rank, - rdma_channel_prefix_matrix, gbl_channel_prefix_matrix, - recv_rdma_channel_prefix_matrix, recv_rdma_rank_prefix_sum, recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, - recv_src_meta, send_rdma_head, send_nvl_head) - return (recv_x, recv_x_scales) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap(event) + handle = (is_token_in_rank, rdma_channel_prefix_matrix, gbl_channel_prefix_matrix, recv_rdma_channel_prefix_matrix, + recv_rdma_rank_prefix_sum, recv_gbl_channel_prefix_matrix, recv_gbl_rank_prefix_sum, recv_src_meta, send_rdma_head, + send_nvl_head) + return ( + recv_x, recv_x_scales + ) if x_scales is not None else recv_x, recv_topk_idx, recv_topk_weights, num_recv_tokens_per_expert_list, handle, EventOverlap( + event) # noinspection PyTypeChecker def internode_combine(self, x: torch.Tensor, handle: Union[tuple, list], topk_weights: Optional[torch.Tensor] = None, + bias: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] = None, config: Optional[Config] = None, previous_event: Optional[EventOverlap] = None, async_finish: bool = False, allocate_on_comm_stream: bool = False) -> \ @@ -415,19 +511,20 @@ def internode_combine(self, x: torch.Tensor, handle: Union[tuple, list], """ assert config is not None - # Unpack handle + # Unpack handle and bias is_combined_token_in_rank, \ _, _, \ rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, gbl_rank_prefix_sum, \ src_meta, 
send_rdma_head, send_nvl_head = handle + bias_0, bias_1 = Buffer._unpack_bias(bias) # Launch the kernel - combined_x, combined_topk_weights, event = self.runtime.internode_combine( - x, topk_weights, - src_meta, is_combined_token_in_rank, - rdma_channel_prefix_matrix, rdma_rank_prefix_sum, gbl_channel_prefix_matrix, - send_rdma_head, send_nvl_head, config, getattr(previous_event, 'event', None), - async_finish, allocate_on_comm_stream) + combined_x, combined_topk_weights, event = self.runtime.internode_combine(x, topk_weights, bias_0, bias_1, src_meta, + is_combined_token_in_rank, rdma_channel_prefix_matrix, + rdma_rank_prefix_sum, gbl_channel_prefix_matrix, + send_rdma_head, send_nvl_head, config, + getattr(previous_event, 'event', + None), async_finish, allocate_on_comm_stream) return combined_x, combined_topk_weights, EventOverlap(event) def clean_low_latency_buffer(self, num_max_dispatch_tokens_per_rank: int, hidden: int, num_experts: int) -> None: @@ -447,99 +544,147 @@ def clean_low_latency_buffer(self, num_max_dispatch_tokens_per_rank: int, hidden # noinspection PyTypeChecker def low_latency_dispatch(self, x: torch.Tensor, topk_idx: torch.Tensor, num_max_dispatch_tokens_per_rank: int, num_experts: int, - use_fp8: bool = True, async_finish: bool = False, return_recv_hook: bool = False) -> \ + cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None, + dispatch_wait_recv_cost_stats: Optional[torch.Tensor] = None, + use_fp8: bool = True, round_scale: bool = False, use_ue8m0: bool = False, + async_finish: bool = False, return_recv_hook: bool = False) -> \ Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, Tuple, EventOverlap, Callable]: """ A low-latency implementation for dispatching with IBGDA. This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA (specifically, IBGDA must be enabled). - Even for ranks in the same node, NVLink are fully disabled for simplicity. 
- Warning: as there are only two buffers, and the returned tensors reuse the buffer, you can not hold more than 2 - low-latency kernels' result tensor at a single moment. + Warning: as there are only two buffers, and the returned tensors reuse the buffer, you cannot hold more than 2 + low-latency kernels' result tensors at a single moment. Arguments: x: `torch.Tensor` with `torch.bfloat16`, shaped as `[num_tokens, hidden]`, only several hidden shapes are supported. The number of tokens to be dispatched must be less than `num_max_dispatch_tokens_per_rank`. - topk_idx: `torch.Tensor` with `torch.int64`, shaped as `[num_tokens, num_topk]`, only several top-k shapes - are supported. `-1` indices (not selecting any expert) are supported. + topk_idx: `torch.Tensor` with `deep_ep.topk_idx_t` (typically `torch.int64`), shaped as `[num_tokens, num_topk]`, + only several top-k shapes are supported. `-1` indices (not selecting any expert) are supported. num_max_dispatch_tokens_per_rank: the maximum number of tokens to dispatch, all the ranks must hold the same value. num_experts: the number of all experts. + cumulative_local_expert_recv_stats: a cumulative expert count tensor for statistics, which should have shape + `[num_local_experts]` and be typed as `torch.int`. This is useful for online service EP load balance + monitoring. + dispatch_wait_recv_cost_stats: a cumulative time spent waiting to receive each token tensor for statistics, + which should have shape `[num_ranks, num_ranks]` and be typed as `torch.int64`. + This is useful for detecting and precisely localizing slow anomalies. use_fp8: whether to enable FP8 casting, with this, the received data will be a tuple of FP8 tensor and scaling factors. + round_scale: whether round the scaling factors into power of 2. + use_ue8m0: whether use UE8M0 as scaling factor format (available only with `round_scale=True`). async_finish: the current stream will not wait for the communication kernels to be finished if set. 
return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues, but **without actually receiving the data**. You must call the received hook to make sure the data's arrival. - If you not set this flag, the kernel will ensure the data's arrival. + If you do not set this flag, the kernel will ensure the data's arrival. Returns: recv_x: a tensor or tuple with received tokens for each expert. With `use_fp8=True`: the first element is a `torch.Tensor` shaped as `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.float8_e4m3fn`. The second tensor is the corresponding scales for the first element with shape - `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`. + `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 128]` with `torch.float`, + if `use_ue8m0=False`. With `use_ue8m0=True`, the second one is packed and shaped as + `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden // 512]` with type `torch.int`. Notice that, the last-two-dimension of the scaling tensors are in column-major for TMA compatibility. With `use_fp8=False`, the result would be a tensor shaped as `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`. Moreover, not all tokens are valid, only some of the `num_max_dispatch_tokens_per_rank * num_ranks` are, as we do not synchronize CPU received count with GPU (also not incompatible with CUDA graph if synced). recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each - expert receive. As mentioned before, all not tokens are valid in `recv_x`. + expert receives. As mentioned before, not all tokens are valid in `recv_x`. handle: the communication handle to be used in the `low_latency_combine` function. event: the event after executing the kernel (valid only if `async_finish` is set). 
hook: the receiving hook function (valid only if `return_recv_hook` is set). """ + assert self.nvshmem_qp_depth >= (num_max_dispatch_tokens_per_rank + 1) * 2 packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, packed_recv_layout_range, event, hook = \ self.runtime.low_latency_dispatch(x, topk_idx, + cumulative_local_expert_recv_stats, + dispatch_wait_recv_cost_stats, num_max_dispatch_tokens_per_rank, num_experts, - use_fp8, async_finish, return_recv_hook) + use_fp8, round_scale, use_ue8m0, + async_finish, return_recv_hook) handle = (packed_recv_src_info, packed_recv_layout_range, num_max_dispatch_tokens_per_rank, x.size(1), num_experts) - tensors_to_record = (x, topk_idx, - packed_recv_x, packed_recv_x_scales, packed_recv_count, - packed_recv_src_info, packed_recv_layout_range) + tensors_to_record = (x, topk_idx, packed_recv_x, packed_recv_x_scales, packed_recv_count, packed_recv_src_info, + packed_recv_layout_range, cumulative_local_expert_recv_stats) return (packed_recv_x, packed_recv_x_scales) if use_fp8 else packed_recv_x, packed_recv_count, handle, \ EventOverlap(event, tensors_to_record if async_finish else None), hook # noinspection PyTypeChecker def low_latency_combine(self, x: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, - handle: tuple, zero_copy: bool = False, async_finish: bool = False, - return_recv_hook: bool = False, out: Optional[torch.Tensor] = None) -> \ + handle: tuple, use_logfmt: bool = False, zero_copy: bool = False, async_finish: bool = False, + return_recv_hook: bool = False, out: Optional[torch.Tensor] = None, + combine_wait_recv_cost_stats: Optional[torch.Tensor] = None) -> \ Tuple[torch.Tensor, EventOverlap, Callable]: """ A low-latency implementation for combining tokens (reduce **with weights**) with IBGDA. This kernel requires all the ranks (no matter intranode or internode) should be visible via RDMA (specifically, IBGDA must be enabled). 
- Even for ranks in the same node, NVLink are fully disabled for simplicity. - Warning: as there are only two buffers, and the returned tensors reuse the buffer, you can not hold more than 2 - low-latency kernels' result tensor at a single moment. + Warning: as there are only two buffers, and the returned tensors reuse the buffer, you cannot hold more than 2 + low-latency kernels' result tensors at a single moment. Arguments: x: `[num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]` with `torch.bfloat16`, the local calculated tokens to be sent to this original rank and reduced. - topk_idx: `[num_combined_tokens, num_topk]` with `torch.int64`, the expert indices selected by the dispatched - tokens. `-1` indices (not selecting any expert) are supported. Note that, `num_combined_tokens` equals - to the number of dispatched tokens. + topk_idx: `[num_combined_tokens, num_topk]` with `deep_ep.topk_idx_t` (typically `torch.int64`), the expert + indices selected by the dispatched tokens. `-1` indices (not selecting any expert) are supported. Note that, + `num_combined_tokens` equals to the number of dispatched tokens. topk_weights: `[num_combined_tokens, num_topk]` with `torch.float`, the expert weights selected by the dispatched tokens. The received tokens will be reduced with the weights in this tensor. handle: the communication handle given by the `dispatch` function. + use_logfmt: whether to use an internal "LogFMT with dynamic per-64-channel cast" format (10 bits). zero_copy: whether the tensor is already copied into the RDMA buffer, should be cooperative with `get_next_low_latency_combine_buffer`. async_finish: the current stream will not wait for the communication kernels to be finished if set. return_recv_hook: return a receiving hook if set. If set, the kernel will just do the RDMA request issues, but **without actually receiving the data**. You must call the received hook to make sure the data's arrival. 
- If you not set this flag, the kernel will ensure the data's arrival. + If you do not set this flag, the kernel will ensure the data's arrival. out: the in-place output tensor, if set, the kernel will write the result to this tensor and return it directly. + combine_wait_recv_cost_stats: a cumulative time spent waiting to receive each token tensor for statistics, + which should have shape `[num_ranks, num_ranks]` and be typed as `torch.int64`. + This is useful for detecting and precisely localizing slow anomalies. Returns: - combined_x: the reduced token tensor, with shape `[num_combined_tokens, num_topk]` and type `torch.bfloat16`. + combined_x: the reduced token tensor, with shape `[num_combined_tokens, hidden]` and type `torch.bfloat16`. event: the event after executing the kernel (valid only if `async_finish` is set). hook: the receiving hook function (valid only if `return_recv_hook` is set). """ src_info, layout_range, num_max_dispatch_tokens_per_rank, hidden, num_experts = handle + assert self.nvshmem_qp_depth >= (num_max_dispatch_tokens_per_rank + 1) * 2 combined_x, event, hook = self.runtime.low_latency_combine(x, topk_idx, topk_weights, src_info, layout_range, - num_max_dispatch_tokens_per_rank, num_experts, - zero_copy, async_finish, return_recv_hook, out) + combine_wait_recv_cost_stats, num_max_dispatch_tokens_per_rank, + num_experts, use_logfmt, zero_copy, async_finish, return_recv_hook, out) tensors_to_record = (x, topk_idx, topk_weights, src_info, layout_range, combined_x) return combined_x, EventOverlap(event, tensors_to_record if async_finish else None), hook + def low_latency_update_mask_buffer(self, rank_to_mask: int, mask: bool = False): + """ + Mask (unmask) a rank during communication (dispatch, combine, and clean) + + Arguments: + rank_to_mask: the rank to mask (unmask). + mask: if True, will mask the rank (do not recvfrom/sendto the rank), otherwise will unmask the rank. 
+ + """ + self.runtime.low_latency_update_mask_buffer(rank_to_mask, mask) + + def low_latency_query_mask_buffer(self, mask_status: torch.Tensor): + """ + Query the mask status of all ranks + + Arguments: + mask_status: `[num_ranks]` with `torch.int`, the mask status of each rank. `1` means masked and `0` means unmasked. + + """ + self.runtime.low_latency_query_mask_buffer(mask_status) + + def low_latency_clean_mask_buffer(self): + """ + Clean the mask buffer + + """ + self.runtime.low_latency_clean_mask_buffer() + def get_next_low_latency_combine_buffer(self, handle: object): """ Get the raw registered RDMA buffer tensor for next low-latency combine, so that the next combine kernel can skip the copying. diff --git a/deep_ep/utils.py b/deep_ep/utils.py index 009aa2a..e61a2c5 100644 --- a/deep_ep/utils.py +++ b/deep_ep/utils.py @@ -1,8 +1,10 @@ +import os import torch +import torch.distributed as dist from typing import Any, Optional, Tuple # noinspection PyUnresolvedReferences -from deep_ep_cpp import Config, EventHandle +from deep_ep_cpp import EventHandle class EventOverlap: @@ -14,8 +16,7 @@ class EventOverlap: extra_tensors: an easier way to simulate PyTorch tensor `record_stream`, may be useful with CUDA graph. """ - def __init__(self, event: Optional[EventHandle] = None, - extra_tensors: Optional[Tuple[torch.Tensor]] = None) -> None: + def __init__(self, event: Optional[EventHandle] = None, extra_tensors: Optional[Tuple[torch.Tensor]] = None) -> None: """ Initialize the class. @@ -58,3 +59,43 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: """ if self.event is not None: self.event.current_stream_wait() + + +def check_nvlink_connections(group: dist.ProcessGroup): + """ + Check NVLink connection between every pair of GPUs. + + Arguments: + group: the communication group. 
+ """ + # Check NVLink connection + # NOTES: some A100 PCIE GPUs only have pairwise NVLink connection, so that we can only use EP2 + # TODO: check all cases, all local-node GPUs in the group should be connected via NVLink + if 'PCIE' in torch.cuda.get_device_name(): + assert group.size() <= 2, 'PCIe GPUs only have pairwise NVLink connections' + + # noinspection PyUnresolvedReferences + import pynvml + pynvml.nvmlInit() + + # noinspection PyTypeChecker + devices = os.environ.get('CUDA_VISIBLE_DEVICES', '0,1,2,3,4,5,6,7').strip(',').split(',') + physical_device_idx = int(devices[torch.cuda.current_device()]) + physical_device_indices = [ + 0, + ] * group.size() + dist.all_gather_object(physical_device_indices, physical_device_idx, group) + + # Check whether they are all connected via NVLink + # Reference: https://github.com/vllm-project/vllm/blob/b8e809a057765c574726a6077fd124db5077ce1f/vllm/platforms/cuda.py#L438 + handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_indices] + for i, handle in enumerate(handles): + for j, peer_handle in enumerate(handles): + if i >= j: + continue + status = pynvml.nvmlDeviceGetP2PStatus(handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK) + assert status == pynvml.NVML_P2P_STATUS_OK,\ + f'GPU {physical_device_indices[i]} and GPU {physical_device_indices[j]} are not connected via NVLink' + + # Close NVML + pynvml.nvmlShutdown() diff --git a/setup.py b/setup.py index ac2bd30..b6a8111 100644 --- a/setup.py +++ b/setup.py @@ -7,6 +7,13 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension +# Wheel specific: the wheels only include the soname of the host library `libnvshmem_host.so.X` +def get_nvshmem_host_lib_name(base_dir): + path = Path(base_dir).joinpath('lib') + for file in path.rglob('libnvshmem_host.so.*'): + return file.name + raise ModuleNotFoundError('libnvshmem_host.so not found') + if __name__ == "__main__": # Add argument parser for handling --variant flag parser = 
argparse.ArgumentParser(description="DeepEP setup configuration") @@ -21,16 +28,39 @@ parser.add_argument("--verbose", action="store_true", help="Verbose build") parser.add_argument("--enable_timer", action="store_true", help="Enable timer to debug time out in internode") parser.add_argument("--rocm-disable-ctx", action="store_true", help="Disable workgroup context optimization in internode") - parser.add_argument("--disable-mpi", action="store_true", help="Disable MPI detection and configuration") # Get the arguments to be parsed and separate setuptools arguments args, unknown_args = parser.parse_known_args() variant = args.variant debug = args.debug rocm_disable_ctx = args.rocm_disable_ctx - disable_mpi = args.disable_mpi enable_timer = args.enable_timer + + if variant != "rocm": + disable_nvshmem = False + nvshmem_dir = os.getenv('NVSHMEM_DIR', None) + nvshmem_host_lib = 'libnvshmem_host.so' + if nvshmem_dir is None: + try: + nvshmem_dir = importlib.util.find_spec("nvidia.nvshmem").submodule_search_locations[0] + nvshmem_host_lib = get_nvshmem_host_lib_name(nvshmem_dir) + import nvidia.nvshmem as nvshmem # noqa: F401 + except (ModuleNotFoundError, AttributeError, IndexError): + print( + 'Warning: `NVSHMEM_DIR` is not specified, and the NVSHMEM module is not installed. All internode and low-latency features are disabled\n' + ) + disable_nvshmem = True + else: + disable_nvshmem = False + + if not disable_nvshmem: + assert os.path.exists(nvshmem_dir), f'The specified NVSHMEM directory does not exist: {nvshmem_dir}' + + #else: + # disable_nvshmem = False + + # Reset sys.argv for setuptools to avoid conflicts sys.argv = [sys.argv[0]] + unknown_args @@ -56,11 +86,9 @@ ), f"Failed to find {shmem_variant_name}" print(f"{shmem_variant_name} directory: {shmem_dir}") - ompi_dir = None - if variant == "rocm" and not disable_mpi: + if variant == "rocm": # Attempt to auto-detect OpenMPI installation directory if OMPI_DIR not set. 
# The first existing candidate containing bin/mpicc will be used. - print("MPI detection enabled for ROCm variant") ompi_dir_env = os.getenv("OMPI_DIR", "").strip() candidate_dirs = [ ompi_dir_env if ompi_dir_env else None, @@ -72,41 +100,30 @@ "/usr/local/ompi", "/usr/local/openmpi", ] + ompi_dir = None for d in candidate_dirs: if not d: continue mpicc_path = os.path.join(d, "bin", "mpicc") if os.path.exists(d) and os.path.exists(mpicc_path): ompi_dir = d - break - assert ompi_dir is not None, ( - f"Failed to find OpenMPI installation. " - f"Searched: {', '.join([d for d in candidate_dirs if d])}. " - f"Set OMPI_DIR environment variable or use --disable-mpi flag." - ) + break + if ompi_dir is None: + # Fallback to root (will trigger the assert below) + ompi_dir = "/" print(f"Detected OpenMPI directory: {ompi_dir}") - elif variant == "rocm" and disable_mpi: - print("MPI detection disabled for ROCm variant") - elif variant == "cuda" and not disable_mpi: - print("MPI detection enabled for CUDA variant") - else: - print("MPI detection disabled for CUDA variant") + assert os.path.exists(ompi_dir), f"Failed to find OMPI: {ompi_dir}" # TODO: currently, we only support Hopper architecture, we may add Ampere support later if variant == "rocm": - arch = os.getenv("PYTORCH_ROCM_ARCH") - allowed_arch = {"gfx942", "gfx950"} - if arch not in allowed_arch: - raise EnvironmentError( - f"Invalid PYTORCH_ROCM_ARCH='{arch}'. 
" - f"Use one of: {', '.join(sorted(allowed_arch))}.") + os.environ["PYTORCH_ROCM_ARCH"] = os.getenv("PYTORCH_ROCM_ARCH", "gfx942") elif variant == "cuda": os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0" optimization_flag = "-O0" if debug else "-O3" debug_symbol_flags = ["-g", "-ggdb"] if debug else [] define_macros = ( - ["-DUSE_ROCM=1", "-fgpu-rdc",] if variant == "rocm" else [] + ["-DUSE_ROCM=1", "-DDISABLE_SM90_FEATURES=1", "-fgpu-rdc",] if variant == "rocm" else [] ) if enable_timer: define_macros.append("-DENABLE_TIMER") @@ -138,19 +155,20 @@ nvcc_flags = [f"{optimization_flag}"] + debug_symbol_flags + define_macros include_dirs = ["csrc/", f"{shmem_dir}/include"] - if variant == "rocm" and ompi_dir is not None: + if variant == "rocm": include_dirs.append(f"{ompi_dir}/include") sources = [ "csrc/deep_ep.cpp", "csrc/kernels/runtime.cu", + 'csrc/kernels/layout.cu', "csrc/kernels/intranode.cu", "csrc/kernels/internode.cu", "csrc/kernels/internode_ll.cu", ] library_dirs = [f"{shmem_dir}/lib"] - if variant == "rocm" and ompi_dir is not None: + if variant == "rocm": library_dirs.append(f"{ompi_dir}/lib") # Disable aggressive PTX instructions @@ -158,6 +176,13 @@ cxx_flags.append("-DDISABLE_AGGRESSIVE_PTX_INSTRS") nvcc_flags.append("-DDISABLE_AGGRESSIVE_PTX_INSTRS") + + # Bits of `topk_idx.dtype`, choices are 32 and 64 + if "TOPK_IDX_BITS" in os.environ: + topk_idx_bits = int(os.environ['TOPK_IDX_BITS']) + cxx_flags.append(f'-DTOPK_IDX_BITS={topk_idx_bits}') + nvcc_flags.append(f'-DTOPK_IDX_BITS={topk_idx_bits}') + shmem_lib_name = "nvshmem" if variant == "cuda" else "rocshmem" # Disable DLTO (default by PyTorch) nvcc_dlink = ["-dlink", f"-L{shmem_dir}/lib", f"-l{shmem_lib_name}"] @@ -172,15 +197,10 @@ "-lamdhip64", "-lhsa-runtime64", "-libverbs", + f"-l:libmpi.so", + f"-Wl,-rpath,{ompi_dir}/lib", ] ) - if not disable_mpi: - extra_link_args.extend( - [ - f"-l:libmpi.so", - f"-Wl,-rpath,{ompi_dir}/lib", - ] - ) extra_compile_args = { "cxx": cxx_flags, @@ -189,6 +209,17 @@ 
if variant == "cuda": extra_compile_args["nvcc_dlink"] = nvcc_dlink + + # Summary + print('Build summary:') + print(f' > Sources: {sources}') + print(f' > Includes: {include_dirs}') + print(f' > Libraries: {library_dirs}') + print(f' > Compilation flags: {extra_compile_args}') + print(f' > Link flags: {extra_link_args}') + print(f' > NVSHMEM path: {shmem_dir}') + print() + # noinspection PyBroadException try: cmd = ["git", "rev-parse", "--short", "HEAD"] @@ -198,7 +229,7 @@ setuptools.setup( name="deep_ep", - version="1.0.0" + revision, + version="1.2.1" + revision, packages=setuptools.find_packages(include=["deep_ep"]), ext_modules=[ CUDAExtension( diff --git a/tests/test_internode.py b/tests/test_internode.py index 5f45d0f..c03d344 100644 --- a/tests/test_internode.py +++ b/tests/test_internode.py @@ -1,3 +1,4 @@ +import argparse import os import time import torch @@ -5,16 +6,27 @@ # noinspection PyUnresolvedReferences import deep_ep -from utils import init_dist, bench, calc_diff, create_grouped_scores, inplace_unique, per_token_cast_to_fp8, per_token_cast_back +from utils import init_dist, bench, bench_kineto, calc_diff, create_grouped_scores, inplace_unique, per_token_cast_to_fp8, per_token_cast_back, hash_tensor # Test compatibility with low latency functions import test_low_latency -import argparse -def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: int, num_nodes: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup): +# noinspection PyShadowingNames +def test_main(args: argparse.Namespace, + num_sms: int, + local_rank: int, + num_local_ranks: int, + num_ranks: int, + num_nodes: int, + rank: int, + buffer: deep_ep.Buffer, + group: dist.ProcessGroup, + skip_benchmark: bool = False): # Settings - num_tokens, hidden, num_topk_groups, num_topk, num_experts = 4096, 7168, min(num_nodes, 4), 8, (256 // num_ranks) * num_ranks + num_tokens, hidden = args.num_tokens, args.hidden + num_topk_groups, num_topk, num_experts = 
args.num_topk_groups, args.num_topk, args.num_experts + assert num_experts % num_ranks == 0 and num_local_ranks == 8 if local_rank == 0: print(f'[config] num_tokens={num_tokens}, hidden={hidden}, num_topk_groups={num_topk_groups}, num_topk={num_topk}', flush=True) @@ -23,19 +35,24 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') x_e4m3 = per_token_cast_to_fp8(x) + x_pure_rand_e4m3 = per_token_cast_to_fp8(x_pure_rand) + x_e4m3 = (x_e4m3[0], x_e4m3[1].T.contiguous().T) scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1 group_scores = scores.view(num_tokens, num_nodes, -1).amax(dim=-1) group_idx = torch.topk(group_scores, k=num_topk_groups, dim=-1, sorted=False).indices masked_scores = create_grouped_scores(scores, group_idx, num_nodes) topk_idx = torch.topk(masked_scores, num_topk, dim=-1, largest=True, sorted=False)[1] + topk_idx = topk_idx.to(deep_ep.topk_idx_t) topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank topk_weights_pure_rand = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda') rank_idx = topk_idx // (num_experts // num_ranks) + rank_idx = rank_idx.to(torch.int64) rank_idx.masked_fill_(topk_idx == -1, -1) inplace_unique(rank_idx, num_ranks) rdma_rank_idx = rank_idx // num_local_ranks rdma_rank_idx.masked_fill_(rank_idx == -1, -1) inplace_unique(rdma_rank_idx, num_nodes) + hash_value = 0 # RDMA dispatch counts rdma_idx = topk_idx // (num_experts // num_nodes) @@ -77,12 +94,12 @@ def test_main(num_sms: int, local_rank: int, num_local_ranks: int, num_ranks: in t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0] if local_rank == 0: print(f'[layout] Kernel performance: {t * 1000:.3f} ms', flush=True) - print() + print('', flush=True) 
group.barrier() time.sleep(1) # Config - rdma_buffer_size, nvl_buffer_size = 128, (720 if num_ranks in (144, 160) else 512) + rdma_buffer_size, nvl_buffer_size = 128, (720 if num_ranks in (24, 48, 96, 144, 160) else 512) config = deep_ep.Config(num_sms, 8, nvl_buffer_size, 16, rdma_buffer_size) # Test dispatch @@ -99,37 +116,59 @@ def check_data(check_x, recv_gbl_rank_prefix_sum): for async_mode in (False, True): for current_x in (x_pure_rand, x, x_e4m3): for with_topk in (False, True): + is_rand = current_x is x_pure_rand or current_x is x_pure_rand_e4m3 if local_rank == 0: - print(f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...', flush=True, end='') - dispatch_args = {'x': current_x, 'num_tokens_per_rank': num_tokens_per_rank, 'num_tokens_per_rdma_rank': num_tokens_per_rdma_rank, 'is_token_in_rank': is_token_in_rank, - 'num_tokens_per_expert': num_tokens_per_expert, 'config': config, 'async_finish': async_mode} + print( + f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...', + flush=True, + end='') + dispatch_args = { + 'x': current_x, + 'num_tokens_per_rank': num_tokens_per_rank, + 'num_tokens_per_rdma_rank': num_tokens_per_rdma_rank, + 'is_token_in_rank': is_token_in_rank, + 'num_tokens_per_expert': num_tokens_per_expert, + 'config': config, + 'async_finish': async_mode + } if with_topk: - dispatch_args.update({'topk_idx': topk_idx, 'topk_weights': topk_weights_pure_rand if current_x is x_pure_rand else topk_weights}) + dispatch_args.update({'topk_idx': topk_idx, 'topk_weights': topk_weights_pure_rand if is_rand else topk_weights}) if previous_mode: dispatch_args.update({'previous_event': buffer.capture()}) - recv_x, recv_topk_idx, recv_topk_weights, recv_num_tokens_per_expert_list, handle, event = 
buffer.dispatch(**dispatch_args) + recv_x, recv_topk_idx, recv_topk_weights, recv_num_tokens_per_expert_list, handle, event = buffer.dispatch( + **dispatch_args) event.current_stream_wait() if async_mode else () + + if current_x is x_pure_rand or current_x is x: + hash_value += hash_tensor(recv_x) + else: + hash_value += hash_tensor(recv_x[0]) + hash_value += hash_tensor(recv_x[1]) + recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x # Checks recv_gbl_rank_prefix_sum = handle[-4] - assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(0), f'{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}' + assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(0), \ + f'{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}' assert gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist() == recv_num_tokens_per_expert_list - if current_x is not x_pure_rand: + if not is_rand: check_data(recv_x, recv_gbl_rank_prefix_sum) if with_topk: # Check `topk_idx` - assert (recv_topk_idx.eq(-1) | ((recv_topk_idx >= 0) & (recv_topk_idx < (num_experts // num_ranks)))).sum().item() == recv_topk_idx.numel() + assert (recv_topk_idx.eq(-1) | + ((recv_topk_idx >= 0) & + (recv_topk_idx < (num_experts // num_ranks)))).sum().item() == recv_topk_idx.numel() for i, count in enumerate(recv_num_tokens_per_expert_list): assert recv_topk_idx.eq(i).sum().item() == count # Check `topk_weights` - if current_x is not x_pure_rand: - recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax(dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)] + if not is_rand: + recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax( + dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)] check_data(recv_topk_weights, recv_gbl_rank_prefix_sum) # Test cached dispatch (must without top-k staffs) - # NOTES: handle must be refreshed if not with_topk: dispatch_args = {'x': current_x, 'handle': handle, 'config': config, 
'async_finish': async_mode} if previous_mode: @@ -137,10 +176,13 @@ def check_data(check_x, recv_gbl_rank_prefix_sum): recv_x, _, _, _, _, event = buffer.dispatch(**dispatch_args) event.current_stream_wait() if async_mode else () recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x - if current_x is not x_pure_rand: + if not is_rand: check_data(recv_x, recv_gbl_rank_prefix_sum) # Test combine + bias_0 = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') + bias_1 = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') + combine_args = {'x': recv_x, 'bias': (bias_0, bias_1), 'handle': handle, 'config': config, 'async_finish': async_mode} combine_args = {'x': recv_x, 'handle': handle, 'config': config, 'async_finish': async_mode} if with_topk: combine_args.update({'topk_weights': recv_topk_weights}) @@ -148,14 +190,17 @@ def check_data(check_x, recv_gbl_rank_prefix_sum): combine_args.update({'previous_event': buffer.capture()}) combined_x, combined_topk_weights, event = buffer.combine(**combine_args) event.current_stream_wait() if async_mode else () - check_x = combined_x.float() / is_token_in_rank.sum(dim=1).unsqueeze(1) - ref_x = x_pure_rand if current_x is x_pure_rand else x - assert calc_diff(check_x, ref_x) < 5e-6 + check_x = (combined_x.float() - bias_0.float() - bias_1.float()) / is_token_in_rank.sum(dim=1).unsqueeze(1) + ref_x = x_pure_rand if is_rand else x + assert calc_diff(check_x, ref_x) < 5e-4 if current_x is x_pure_rand_e4m3 else 5e-6 if with_topk: - check_topk_weights = combined_topk_weights if (current_x is x_pure_rand) else (combined_topk_weights / is_token_in_rank.sum(dim=1).unsqueeze(1)) - ref_topk_weights = topk_weights_pure_rand if current_x is x_pure_rand else topk_weights + check_topk_weights = combined_topk_weights if is_rand else (combined_topk_weights / + is_token_in_rank.sum(dim=1).unsqueeze(1)) + ref_topk_weights = topk_weights_pure_rand if is_rand else topk_weights assert 
calc_diff(check_topk_weights, ref_topk_weights) < 1e-9 + hash_value += hash_tensor(recv_x) + # For later tuning dispatch_bf16_rdma_send_bytes = num_rdma_token_sent * hidden * 2 dispatch_bf16_nvl_recv_bytes = recv_x.numel() * 2 @@ -165,7 +210,10 @@ def check_data(check_x, recv_gbl_rank_prefix_sum): if local_rank == 0: print(' passed', flush=True) if local_rank == 0: - print() + print('', flush=True) + + if skip_benchmark: + return hash_value # Tune dispatch performance best_dispatch_results = None @@ -174,18 +222,29 @@ def check_data(check_x, recv_gbl_rank_prefix_sum): best_time, best_results = 1e10, None rdma_send_bytes = (dispatch_bf16_rdma_send_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_rdma_send_bytes nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes - for nvl_chunk_size in range(4, 33, 4): + for nvl_chunk_size in range(4, 45, 4): for rdma_chunk_size in range(4, 33, 4): config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size) tune_args = {'x': current_x, 'handle': handle, 'config': config} - t = bench(lambda: buffer.dispatch(**tune_args))[0] + t, notify_t = bench_kineto( + lambda: buffer.dispatch(**tune_args), # noqa: B023 + ('dispatch', 'notify'), + suppress_kineto_output=True) if t < best_time: - best_time, best_results = t, (num_sms, nvl_chunk_size, rdma_chunk_size) + best_time, best_results = t, (num_sms, nvl_chunk_size, rdma_chunk_size, notify_t) if local_rank == 0: - print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {rdma_send_bytes / 1e9 / t:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ') + print( + f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: ' + f'{notify_t * 1e6:.0f} + {t * 1e6:.0f} us, ' + f'{rdma_send_bytes / 1e9 / t:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL) ', + flush=True) if local_rank == 0: 
- print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {rdma_send_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)') - print() + print( + f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: ' + f'{best_results[3] * 1e6:.0f} + {best_time * 1e6:.0f} us, ' + f'{rdma_send_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL)', + flush=True) + print('', flush=True) if isinstance(current_x, tuple): # Gather FP8 the best config from rank 0 @@ -193,66 +252,125 @@ def check_data(check_x, recv_gbl_rank_prefix_sum): all_best_fp8_results_list = [torch.zeros_like(best_dispatch_results) for _ in range(torch.distributed.get_world_size())] dist.all_gather(all_best_fp8_results_list, best_dispatch_results, group=group) best_dispatch_results = all_best_fp8_results_list[0].tolist() - dispatch_config = deep_ep.Config(best_dispatch_results[0], best_dispatch_results[1], nvl_buffer_size, best_dispatch_results[2], rdma_buffer_size) + dispatch_config = deep_ep.Config(best_dispatch_results[0], best_dispatch_results[1], nvl_buffer_size, best_dispatch_results[2], + rdma_buffer_size) - dispatch_args = {'x': x, 'num_tokens_per_rank': num_tokens_per_rank, 'num_tokens_per_rdma_rank': num_tokens_per_rdma_rank, - 'is_token_in_rank': is_token_in_rank, 'num_tokens_per_expert': num_tokens_per_expert, - 'config': dispatch_config if dispatch_config is not None else config} + dispatch_args = { + 'x': x, + 'num_tokens_per_rank': num_tokens_per_rank, + 'num_tokens_per_rdma_rank': num_tokens_per_rdma_rank, + 'is_token_in_rank': is_token_in_rank, + 'num_tokens_per_expert': num_tokens_per_expert, + 'config': dispatch_config if dispatch_config is not None else config + } recv_x, _, _, _, handle, _ = 
buffer.dispatch(**dispatch_args) # Tune combine performance best_time, best_results = 1e10, None - for nvl_chunk_size in range(1, 5, 1): - # TODO: Sort out the assertation for 16 nodes - upper_bound = 29 if num_ranks == 128 else 33 - for rdma_chunk_size in range(8, upper_bound, 4): + for nvl_chunk_size in range(1, 8, 1): + for rdma_chunk_size in range(12 if num_nodes == 2 else 8, 33, 4): config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size, rdma_chunk_size, rdma_buffer_size) tune_args = {'x': recv_x, 'handle': handle, 'config': config} - t = bench(lambda: buffer.combine(**tune_args))[0] + t, notify_t = bench_kineto( + lambda: buffer.combine(**tune_args), # noqa: B023 + ('combine', 'notify'), + suppress_kineto_output=True) if local_rank == 0: - print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: {combine_bf16_rdma_recv_bytes / 1e9 / t:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ') + print( + f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}, RDMA chunk {rdma_chunk_size}: ' + f'{notify_t * 1e6:.0f} + {t * 1e6:.0f} us, ' + f'{combine_bf16_rdma_recv_bytes / 1e9 / t:.2f} GB/s (RDMA), ' + f'{combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL) ', + flush=True) if t < best_time: - best_time, best_results = t, (num_sms, nvl_chunk_size, rdma_chunk_size) + best_time, best_results = t, (num_sms, nvl_chunk_size, rdma_chunk_size, notify_t) if local_rank == 0: - print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}: {combine_bf16_rdma_recv_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL)') - print() + print( + f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}, RDMA chunk {best_results[2]}, ' + f'{best_results[3] * 1e6:.2f} + {best_time * 1e6:.2f} us, ' + f'{combine_bf16_rdma_recv_bytes / 1e9 / best_time:.2f} GB/s (RDMA), {combine_bf16_nvl_send_bytes / 1e9 / 
best_time:.2f} GB/s (NVL)', + flush=True) + print('', flush=True) + return hash_value -# noinspection PyUnboundLocalVariable -def test_loop(local_rank: int, num_local_ranks: int, backend: str): - num_nodes = int(os.getenv('WORLD_SIZE', 2)) - rank, num_ranks, group = init_dist(local_rank, num_local_ranks, backend=backend) - test_ll_compatibility = False - if test_ll_compatibility: +# noinspection PyUnboundLocalVariable,PyShadowingNames +def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): + num_nodes = int(os.getenv('WORLD_SIZE', 1)) + rank, num_ranks, group = init_dist(local_rank, num_local_ranks) + if args.test_ll_compatibility: ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9 - buffer = deep_ep.Buffer(group, int(1e9), int(1e9), low_latency_mode=test_ll_compatibility, - num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1)) + num_sms = 24 + num_qps_per_rank = max(num_sms, ll_num_experts // num_ranks if args.test_ll_compatibility else 0) + + buffer = deep_ep.Buffer(group, + int(2e9), + int(1e9), + low_latency_mode=args.test_ll_compatibility, + num_qps_per_rank=num_qps_per_rank, + explicitly_destroy=True) assert num_local_ranks == 8 and num_ranks > 8 - torch.manual_seed(rank) - for i in (32, ): - test_main(i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group) + for seed in range(int(1e9)): if local_rank == 0: - print() + print(f'Testing with seed {seed} ...', flush=True) + torch.manual_seed(rank + seed) + ref_hash = 0 + for i in (num_sms, ): + ref_hash += test_main(args, i, local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group, + args.pressure_test_mode == 1) + if local_rank == 0: + print('', flush=True) + if args.pressure_test_mode == 0: + break + + if local_rank == 0: + print(f'{ref_hash=}') + print('', flush=True) + + for _ in range(20): + torch.manual_seed(rank + seed) + current_hash = 0 + for i in (num_sms, ): + current_hash += test_main(args, i, 
local_rank, num_local_ranks, num_ranks, num_nodes, rank, buffer, group, + args.pressure_test_mode == 1) + if local_rank == 0: + print('', flush=True) + assert current_hash == ref_hash # Test compatibility with low latency functions - if test_ll_compatibility: + if args.test_ll_compatibility: buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts) test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1) + # Destroy the buffer runtime and communication group + buffer.destroy() + dist.barrier() + dist.destroy_process_group() + if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Test internode communication') - parser.add_argument('--backend', type=str, choices=['mpi', 'nccl'], default='nccl', - help='Backend for distributed communication (mpi or nccl)') + parser = argparse.ArgumentParser(description='Test internode EP kernels') + parser.add_argument('--num-processes', type=int, default=8, help='Number of processes to spawn (default: 8)') + parser.add_argument('--num-tokens', type=int, default=4096, help='Number of tokens (default: 4096)') + parser.add_argument('--hidden', type=int, default=7168, help='Hidden dimension size (default: 7168)') + parser.add_argument('--num-topk-groups', type=int, default=None, help='Number of top-k groups (default: `min(num_nodes, 4)`)') + parser.add_argument('--num-topk', type=int, default=8, help='Number of top-k experts (default: 8)') + parser.add_argument( + '--pressure-test-mode', + type=int, + default=0, + help='Pressure test mode. 
0: don\'t do pressure test, 1: do pressure test without benchmarks, 2: do pressure test with benchmarks') + parser.add_argument('--num-experts', type=int, default=256, help='Number of experts (default: 256)') + parser.add_argument('--test-ll-compatibility', action='store_true', help='whether to test compatibility with low-latency kernels') args = parser.parse_args() - num_processes = 8 - if args.backend == 'mpi': - dist.init_process_group(backend='mpi') - rank = dist.get_rank() - local_rank = rank % num_processes - test_loop(local_rank=local_rank, num_local_ranks=num_processes, backend='mpi') - else: - torch.multiprocessing.spawn(test_loop, args=(num_processes, 'nccl'), nprocs=num_processes) + + # Set default `num_topk_groups` if not provided + if args.num_topk_groups is None: + num_nodes = int(os.getenv('WORLD_SIZE', 1)) + args.num_topk_groups = min(num_nodes, 4) + + num_processes = args.num_processes + torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes) diff --git a/tests/test_intranode.py b/tests/test_intranode.py index 68a95d8..53dda47 100644 --- a/tests/test_intranode.py +++ b/tests/test_intranode.py @@ -1,4 +1,4 @@ -import os +import argparse import time import torch import torch.distributed as dist @@ -11,9 +11,13 @@ import test_low_latency -def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer: deep_ep.Buffer, group: dist.ProcessGroup): +# noinspection PyShadowingNames +def test_main(args: argparse.Namespace, num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer: deep_ep.Buffer, + group: dist.ProcessGroup): # Settings - num_tokens, hidden, num_topk, num_experts = 4096, 7168, 8, (256 // num_ranks) * num_ranks + num_tokens, hidden = args.num_tokens, args.hidden + num_topk, num_experts = args.num_topk, args.num_experts + assert num_experts % num_ranks == 0 if local_rank == 0: print(f'[config] num_tokens={num_tokens}, hidden={hidden}, num_topk={num_topk}', flush=True) @@ -21,12 +25,15 @@ 
def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer: # Random data x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * rank x_pure_rand = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') - x_e4m3 = per_token_cast_to_fp8(x) + x_e4m3 = per_token_cast_to_fp8(x) if deep_ep.Buffer.is_sm90_compiled() else None + x_e4m3 = (x_e4m3[0], x_e4m3[1].T.contiguous().T) if x_e4m3 is not None else None scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1 topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False)[1] + topk_idx = topk_idx.to(deep_ep.topk_idx_t) topk_weights = torch.ones((num_tokens, num_topk), dtype=torch.float32, device='cuda') * rank topk_weights_pure_rand = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda') rank_idx = topk_idx // (num_experts // num_ranks) + rank_idx = rank_idx.to(torch.int64) rank_idx.masked_fill_(topk_idx == -1, -1) inplace_unique(rank_idx, num_ranks) @@ -60,7 +67,7 @@ def test_main(num_sms: int, local_rank: int, num_ranks: int, rank: int, buffer: t = bench(lambda: buffer.get_dispatch_layout(topk_idx, num_experts))[0] if local_rank == 0: print(f'[layout] Kernel performance: {t * 1000:.3f} ms', flush=True) - print() + print('', flush=True) group.barrier() time.sleep(1) @@ -80,39 +87,74 @@ def check_data(check_x, rank_prefix_matrix): for previous_mode in (False, True): for async_mode in (False, True): - for current_x in (x_pure_rand, x, x_e4m3): + for current_x in filter(lambda elem: elem is not None, (x_pure_rand, x, x_e4m3)): for with_topk in (False, True): if local_rank == 0: - print(f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...', flush=True, end='') - dispatch_args = {'x': current_x, 'num_tokens_per_rank': num_tokens_per_rank, 'is_token_in_rank': is_token_in_rank, - 
'num_tokens_per_expert': num_tokens_per_expert, 'config': config, 'async_finish': async_mode} + print( + f'[testing] Running with {"FP8" if isinstance(current_x, tuple) else "BF16"}, {"with" if with_topk else "without"} top-k (async={async_mode}, previous={previous_mode}) ...', + flush=True, + end='') + dispatch_args = { + 'x': current_x, + 'num_tokens_per_rank': num_tokens_per_rank, + 'is_token_in_rank': is_token_in_rank, + 'num_tokens_per_expert': num_tokens_per_expert, + 'config': config, + 'async_finish': async_mode + } if with_topk: - dispatch_args.update({'topk_idx': topk_idx, 'topk_weights': topk_weights_pure_rand if current_x is x_pure_rand else topk_weights}) + dispatch_args.update({ + 'topk_idx': topk_idx, + 'topk_weights': topk_weights_pure_rand if current_x is x_pure_rand else topk_weights + }) if previous_mode: dispatch_args.update({'previous_event': buffer.capture()}) - recv_x, recv_topk_idx, recv_topk_weights, recv_num_tokens_per_expert_list, handle, event = buffer.dispatch(**dispatch_args) + recv_x, recv_topk_idx, recv_topk_weights, recv_num_tokens_per_expert_list, handle, event = buffer.dispatch( + **dispatch_args) event.current_stream_wait() if async_mode else () recv_x = per_token_cast_back(*recv_x) if isinstance(recv_x, tuple) else recv_x # Checks rank_prefix_matrix = handle[0] - assert gbl_num_tokens_per_rank[rank].item() == recv_x.size(0), f'{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}' + assert gbl_num_tokens_per_rank[rank].item() == recv_x.size( + 0), f'{gbl_num_tokens_per_rank[rank].item()} != {recv_x.size(0)}' assert gbl_num_tokens_per_expert.view(num_ranks, -1)[rank].tolist() == recv_num_tokens_per_expert_list if current_x is not x_pure_rand: check_data(recv_x, rank_prefix_matrix) + recv_topk_weights_clone = None if with_topk: # Check `topk_idx` - assert (recv_topk_idx.eq(-1) | ((recv_topk_idx >= 0) & (recv_topk_idx < (num_experts // num_ranks)))).sum().item() == recv_topk_idx.numel() + assert (recv_topk_idx.eq(-1) | + 
((recv_topk_idx >= 0) & + (recv_topk_idx < (num_experts // num_ranks)))).sum().item() == recv_topk_idx.numel() for i, count in enumerate(recv_num_tokens_per_expert_list): assert recv_topk_idx.eq(i).sum().item() == count # Check `topk_weights` + recv_topk_weights_clone = recv_topk_weights.clone() if current_x is not x_pure_rand: - recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax(dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)] + recv_topk_weights[recv_topk_idx.eq(-1)] = recv_topk_weights.amax( + dim=1, keepdim=True).expand_as(recv_topk_weights)[recv_topk_idx.eq(-1)] check_data(recv_topk_weights, rank_prefix_matrix) + # Test `num_worst_tokens != 0` + if with_topk: + num_worst_tokens = num_tokens * num_ranks + dispatch_args.update({'num_worst_tokens': num_worst_tokens}) + recv_worst_x, recv_worst_topk_idx, recv_worst_topk_weights, empty_list, _, event = buffer.dispatch(**dispatch_args) + event.current_stream_wait() if async_mode else () + recv_worst_x = per_token_cast_back(*recv_worst_x) if isinstance(recv_worst_x, tuple) else recv_worst_x + assert len(empty_list) == 0 + assert num_worst_tokens == recv_worst_x.size(0) + assert num_worst_tokens == recv_worst_topk_idx.size(0) + assert num_worst_tokens == recv_worst_topk_weights.size(0) + assert torch.equal(recv_x, recv_worst_x[:recv_x.size(0)]) + assert torch.equal(recv_topk_idx, recv_worst_topk_idx[:recv_x.size(0)]) + assert torch.equal(recv_topk_weights_clone, recv_worst_topk_weights[:recv_x.size(0)]) + #TODO check why overflow area is not all -1. 
+ #assert torch.all(recv_worst_topk_idx[recv_x.size(0):] == -1).item() + # Test cached dispatch (must without top-k staffs) - # NOTES: handle must be refreshed if not with_topk: dispatch_args = {'x': current_x, 'handle': handle, 'config': config, 'async_finish': async_mode} if previous_mode: @@ -128,14 +170,16 @@ def check_data(check_x, rank_prefix_matrix): if with_topk: combine_args.update({'topk_weights': recv_topk_weights}) if previous_mode: - dispatch_args.update({'previous_event': buffer.capture()}) + combine_args.update({'previous_event': buffer.capture()}) combined_x, combined_topk_weights, event = buffer.combine(**combine_args) event.current_stream_wait() if async_mode else () check_x = combined_x.float() / is_token_in_rank.sum(dim=1).unsqueeze(1) ref_x = x_pure_rand if current_x is x_pure_rand else x assert calc_diff(check_x, ref_x) < 5e-6 if with_topk: - check_topk_weights = combined_topk_weights if (current_x is x_pure_rand) else (combined_topk_weights / is_token_in_rank.sum(dim=1).unsqueeze(1)) + check_topk_weights = combined_topk_weights if (current_x + is x_pure_rand) else (combined_topk_weights / + is_token_in_rank.sum(dim=1).unsqueeze(1)) ref_topk_weights = topk_weights_pure_rand if current_x is x_pure_rand else topk_weights assert calc_diff(check_topk_weights, ref_topk_weights) < 1e-9 @@ -146,79 +190,123 @@ def check_data(check_x, rank_prefix_matrix): if local_rank == 0: print(' passed', flush=True) if local_rank == 0: - print() + print('', flush=True) # Tune dispatch performance best_dispatch_results = None fp8_factor = (1 + 4 / 128) / 2 - for current_x in (x_e4m3, x): + for current_x in filter(lambda elem: elem is not None, (x_e4m3, x)): best_time, best_results = 1e10, None nvl_recv_bytes = (dispatch_bf16_nvl_recv_bytes * fp8_factor) if isinstance(current_x, tuple) else dispatch_bf16_nvl_recv_bytes - for nvl_chunk_size in range(4, 150, 4): - config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size) + for nvl_chunk_size in tuple(range(4, 
33, 2)) + (0, ): + if nvl_chunk_size > 0: + config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size) + else: + # Test default config as well + deep_ep.Buffer.set_num_sms(num_sms) + config = deep_ep.Buffer.get_dispatch_config(num_ranks) tune_args = {'x': current_x, 'handle': handle, 'config': config} - t = bench(lambda: buffer.dispatch(**tune_args))[0] - if t < best_time: + t = bench(lambda: buffer.dispatch(**tune_args))[0] # noqa: B023 + if t < best_time and nvl_chunk_size > 0: best_time, best_results = t, (num_sms, nvl_chunk_size) if local_rank == 0: - print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}: {nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL), time {t * 1000 * 1000:.2f} us', flush=True) + print( + f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else "default"}: ' + f'{nvl_recv_bytes / 1e9 / t:.2f} GB/s (NVL), {t * 1e6:.2f} us', + flush=True) if local_rank == 0: - print(f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL), time {best_time * 1000 * 1000:.2f} us', flush=True) - print() + print( + f'[tuning] Best dispatch ({"FP8" if isinstance(current_x, tuple) else "BF16"}): SMs {best_results[0]}, NVL chunk {best_results[1]}, {nvl_recv_bytes / 1e9 / best_time:.2f} GB/s (NVL), t: {best_time * 1e6:.2f} us', + flush=True) + print('', flush=True) - if isinstance(current_x, tuple): - # Gather FP8 the best config from rank 0 + # Gather the best config from rank 0 and the first test setting + if best_dispatch_results is None: best_dispatch_results = torch.tensor([best_results[0], best_results[1]], dtype=torch.int32, device='cuda') all_best_fp8_results_list = [torch.zeros_like(best_dispatch_results) for _ in range(torch.distributed.get_world_size())] dist.all_gather(all_best_fp8_results_list, best_dispatch_results, group=group) best_dispatch_results = all_best_fp8_results_list[0].tolist() dispatch_config = 
deep_ep.Config(best_dispatch_results[0], best_dispatch_results[1], nvl_buffer_size) - dispatch_args = {'x': x, 'num_tokens_per_rank': num_tokens_per_rank, - 'is_token_in_rank': is_token_in_rank, 'num_tokens_per_expert': num_tokens_per_expert, - 'config': dispatch_config if dispatch_config is not None else config} + dispatch_args = { + 'x': x, + 'num_tokens_per_rank': num_tokens_per_rank, + 'is_token_in_rank': is_token_in_rank, + 'num_tokens_per_expert': num_tokens_per_expert, + 'config': dispatch_config if dispatch_config is not None else config + } recv_x, _, _, _, handle, _ = buffer.dispatch(**dispatch_args) # Tune combine performance best_time, best_results = 1e10, None - for nvl_chunk_size in range(1, 35, 1): - config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size) + for nvl_chunk_size in tuple(range(1, 17, 1)) + (0, ): + if nvl_chunk_size > 0: + config = deep_ep.Config(num_sms, nvl_chunk_size, nvl_buffer_size) + else: + # Test default config as well + deep_ep.Buffer.set_num_sms(num_sms) + config = deep_ep.Buffer.get_combine_config(num_ranks) tune_args = {'x': recv_x, 'handle': handle, 'config': config} - t = bench(lambda: buffer.combine(**tune_args))[0] + t = bench(lambda: buffer.combine(**tune_args))[0] # noqa: B023 if local_rank == 0: - print(f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size}: {combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL), time {t * 1000 * 1000:.2f} us', flush=True) - if t < best_time: + print( + f'[tuning] SMs {num_sms}, NVL chunk {nvl_chunk_size if nvl_chunk_size else "default"}: ' + f'{combine_bf16_nvl_send_bytes / 1e9 / t:.2f} GB/s (NVL), {t * 1e6:.2f} us', + flush=True) + if t < best_time and nvl_chunk_size > 0: best_time, best_results = t, (num_sms, nvl_chunk_size) if local_rank == 0: - print(f'[tuning] Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}: {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL), time {best_time * 1000 * 1000:.2f} us', flush=True) - print() + print( + f'[tuning] 
Best combine: SMs {best_results[0]}, NVL chunk {best_results[1]}: {combine_bf16_nvl_send_bytes / 1e9 / best_time:.2f} GB/s (NVL), t: {best_time * 1e6:.2f} us', + flush=True) + print('', flush=True) -# noinspection PyUnboundLocalVariable -def test_loop(local_rank: int, num_local_ranks: int): +# noinspection PyUnboundLocalVariable,PyShadowingNames +def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): rank, num_ranks, group = init_dist(local_rank, num_local_ranks) test_ll_compatibility, num_rdma_bytes = False, 0 if test_ll_compatibility: ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9 num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(ll_num_tokens, ll_hidden, num_ranks, ll_num_experts) - buffer = deep_ep.Buffer(group, int(1e9), num_rdma_bytes, low_latency_mode=test_ll_compatibility, - num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1)) + buffer = deep_ep.Buffer(group, + int(2e9), + num_rdma_bytes, + low_latency_mode=test_ll_compatibility, + num_qps_per_rank=(ll_num_experts // num_ranks if test_ll_compatibility else 1), + explicitly_destroy=True, + allow_mnnvl=args.allow_mnnvl, + use_fabric=args.use_fabric) torch.manual_seed(rank) - for i in (64, ): - test_main(i, local_rank, num_ranks, rank, buffer, group) + for i in (24, ): + test_main(args, i, local_rank, num_ranks, rank, buffer, group) if local_rank == 0: - print() + print('', flush=True) # Test compatibility with low latency functions if test_ll_compatibility: buffer.clean_low_latency_buffer(ll_num_tokens, ll_hidden, ll_num_experts) test_low_latency.test_main(ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk, rank, num_ranks, group, buffer, seed=1) - dist.destroy_process_group(group) + # Destroy the buffer runtime and communication group + buffer.destroy() + dist.barrier() + dist.destroy_process_group() + if __name__ == '__main__': - num_processes = 8 - torch.multiprocessing.spawn(test_loop, args=(num_processes, ), 
nprocs=num_processes) + parser = argparse.ArgumentParser(description='Test intranode EP kernels') + parser.add_argument('--num-processes', type=int, default=8, help='Number of processes to spawn (default: 8)') + parser.add_argument('--num-tokens', type=int, default=4096, help='Number of tokens (default: 4096)') + parser.add_argument('--hidden', type=int, default=7168, help='Hidden dimension size (default: 7168)') + parser.add_argument('--num-topk', type=int, default=8, help='Number of top-k experts (default: 8)') + parser.add_argument('--num-experts', type=int, default=256, help='Number of experts (default: 256)') + parser.add_argument('--allow-mnnvl', action="store_true", help='Enable MNNVL support') + parser.add_argument('--use-fabric', action="store_true", help='Enable fabric mode') + args = parser.parse_args() + + num_processes = args.num_processes + torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes) diff --git a/tests/test_low_latency.py b/tests/test_low_latency.py index d6887f7..4521599 100644 --- a/tests/test_low_latency.py +++ b/tests/test_low_latency.py @@ -1,104 +1,195 @@ +import argparse import random import torch import torch.distributed as dist from functools import partial +from typing import Literal, Set import deep_ep from utils import init_dist, bench, bench_kineto, calc_diff, hash_tensor, per_token_cast_back -def test_main(num_tokens: int, hidden: int, num_experts: int, num_topk: int, - rank: int, num_ranks: int, group: dist.ProcessGroup, buffer: deep_ep.Buffer, seed: int = 0): +def simulate_failure_and_skip(rank: int, api: Literal["dispatch", "combine", "clean"], expected_masked_ranks: Set[int]): + # Simulates rank failure when the rank first calls the corresponding communication API + failed_api_ranks = { + # API -> rank to fail (rank fails when it first calls the corresponding communication API) + 'dispatch': 1, + 'combine': 3, + 'clean': 5 + } + if rank in expected_masked_ranks: + # Rank already failed + 
return True + if api in failed_api_ranks.keys(): + expected_masked_ranks.add(failed_api_ranks[api]) + if failed_api_ranks[api] == rank: + print(f"Rank {rank} failed when first calling {api} communication API, exit...", flush=True) + return True + return False + + +def query_mask_buffer_and_check(api: Literal["dispatch", "combine", "clean"], buffer: deep_ep.Buffer, mask_status: torch.Tensor, + expected_masked_ranks: Set[int]): + buffer.low_latency_query_mask_buffer(mask_status) + assert set(mask_status.nonzero().squeeze(-1).tolist()) == expected_masked_ranks + + +def test_main(num_tokens: int, + hidden: int, + num_experts: int, + num_topk: int, + rank: int, + num_ranks: int, + group: dist.ProcessGroup, + buffer: deep_ep.Buffer, + use_logfmt: bool = False, + shrink_test: bool = False, + seed: int = 0): torch.manual_seed(seed + rank) random.seed(seed + rank) assert num_experts % num_ranks == 0 num_local_experts = num_experts // num_ranks - # NOTES: the integers greater than 256 exceeds the BF16 precision limit + # NOTES: the integers greater than 256 exceed the BF16 precision limit rank_offset = 128 assert num_ranks - rank_offset < 257, 'Too many ranks (exceeding test precision limit)' x = torch.ones((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * (rank - rank_offset) x[:, -128:] = torch.arange(num_tokens, device='cuda').to(torch.bfloat16).view(-1, 1) + x_list = [x] + for _ in range(4 if use_logfmt else 0): + # NOTES: make more LogFMT casts and also with some BF16 + x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.5 * random.random()) + # NOTES: the last one is for performance testing + # Most of the values in the perf case is lower than the threshold, casting most channels + x_list.append(torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') * 0.1) + scores = torch.randn((num_tokens, num_experts), dtype=torch.float32, device='cuda').abs() + 1 topk_idx = torch.topk(scores, num_topk, dim=-1, 
largest=True, sorted=True)[1] + topk_idx = topk_idx.to(deep_ep.topk_idx_t) topk_weights = torch.randn((num_tokens, num_topk), dtype=torch.float32, device='cuda').abs() + # Randomly mask some positions - for i in range(10): + for _ in range(10): topk_idx[random.randint(0, num_tokens - 1), random.randint(0, num_topk - 1)] = -1 + all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='cuda') + dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group) + + # For failure simulation and shrink testing + mask_status = torch.zeros((num_ranks, ), dtype=torch.int, device='cuda') + expected_masked_ranks = set() + # Check dispatch correctness do_check = True hash_value, num_times = 0, 0 + for current_x in x_list: + for return_recv_hook in (False, True): + for dispatch_use_fp8 in (False, True): + for round_scale in (False, True) if dispatch_use_fp8 else (False, ): + for use_ue8m0 in (False,) if round_scale else (False, ): + if shrink_test and simulate_failure_and_skip(rank, "dispatch", expected_masked_ranks): + break + num_times += 1 + for _ in range((num_times % 2) + 1): + cumulative_local_expert_recv_stats = torch.zeros((num_local_experts, ), dtype=torch.int, device='cuda') + packed_recv_x, packed_recv_count, handle, event, hook = \ + buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts, + use_fp8=dispatch_use_fp8, round_scale=round_scale, use_ue8m0=use_ue8m0, + cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats, + async_finish=not return_recv_hook, return_recv_hook=return_recv_hook) + hook() if return_recv_hook else event.current_stream_wait() + if shrink_test: + query_mask_buffer_and_check("dispatch", buffer, mask_status, expected_masked_ranks) + packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x + simulated_gemm_x = per_token_cast_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape) 
\ + if dispatch_use_fp8 else packed_recv_x.clone() + for i in range(num_local_experts if do_check else 0): + expert_id = rank * num_local_experts + i + recv_x = per_token_cast_back(packed_recv_x[0][i], packed_recv_x[1][i]) if dispatch_use_fp8 else packed_recv_x[i] + recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i] - for return_recv_hook in (False, True): - for dispatch_use_fp8 in (False, True): - num_times += 1 - for i in range((num_times) + 1): - packed_recv_x, packed_recv_count, handle, event, hook = \ - buffer.low_latency_dispatch(x, topk_idx, num_tokens, num_experts, use_fp8=dispatch_use_fp8, - async_finish=not return_recv_hook, return_recv_hook=return_recv_hook) - hook() if return_recv_hook else event.current_stream_wait() - packed_recv_x = (packed_recv_x[0], packed_recv_x[1].contiguous()) if dispatch_use_fp8 else packed_recv_x - simulated_gemm_x = per_token_cast_back(packed_recv_x[0].view(-1, hidden), packed_recv_x[1].view(-1, hidden // 128)).view(packed_recv_x[0].shape) \ - if dispatch_use_fp8 else packed_recv_x.clone() - #print(f"rank{rank}: packed_recv_x[0]\n{packed_recv_x[0].cpu()}\n") - #print(f"rank{rank}: packed_recv_x[1]\n{packed_recv_x[1].cpu()}\n") - #print(f"simulated_gemm_x{simulated_gemm_x.cpu()}") - all_topk_idx = torch.empty((num_ranks, num_tokens, num_topk), dtype=topk_idx.dtype, device='cuda') - dist.all_gather_into_tensor(all_topk_idx, topk_idx, group=group) - for i in range(num_local_experts if do_check else 0): - expert_id = rank * num_local_experts + i - recv_x = per_token_cast_back(packed_recv_x[0][i], packed_recv_x[1][i]) if dispatch_use_fp8 else packed_recv_x[i] - recv_count, recv_src_info, recv_layout_range = packed_recv_count[i], handle[0][i], handle[1][i] - - # Check expert indices - int_mask = (2 ** 32) - 1 - num_valid_tokens = recv_count.item() - assert num_valid_tokens == (recv_layout_range & int_mask).sum().item(), f'{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()' 
- assert num_valid_tokens == (all_topk_idx == expert_id).sum().item(), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum().item()}' - - # Check received data - recv_x = recv_x[:num_valid_tokens] - recv_x_amin = recv_x[:, :-128].amin(dim=-1) - recv_src_info = recv_src_info[:num_valid_tokens] - assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1)) - assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0 - for j in range(num_ranks): - begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item() - assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item() - assert (recv_x[begin_idx:begin_idx + count][:-128] - j).sum().item() == 0 - if dispatch_use_fp8: - hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens]) - hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens]) - else: - hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens]) - - # Check combine correctness - for zero_copy in (False,True): - if zero_copy: - buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x - out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') - combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle, - async_finish=not return_recv_hook, - return_recv_hook=return_recv_hook, out=out) - hook() if return_recv_hook else event.current_stream_wait() - if do_check: - diff = calc_diff(x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x) - assert torch.isnan(combined_x).sum().item() == 0 - assert diff < 1e-5, f'Error: diff={diff}' - hash_value ^= hash_tensor(combined_x) - - def create_test_cast_with_outliers(num_outliers): - tmp = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') - tmp /= tmp.abs().amax(dim=1).view(-1, 1) - assert tmp.abs().amax().item() <= 1 - - # Create some amax outliers - for i in 
range(num_outliers): - tmp[random.randint(0, num_tokens - 1)] *= 1e3 - return tmp + # Check expert indices + int_mask = (2**32) - 1 + num_valid_tokens = recv_count.item() + # cumulative_local_expert_recv_stats not currently enabled. + #assert cumulative_local_expert_recv_stats[i].item( + #) == num_valid_tokens, f'{cumulative_local_expert_recv_stats[i].item()} != {num_valid_tokens}' + assert num_valid_tokens == ( + recv_layout_range + & int_mask).sum().item(), f'{num_valid_tokens} != {recv_layout_range & int_mask}.sum().item()' + assert num_valid_tokens == (all_topk_idx == expert_id).sum(dim=[1, 2])[mask_status == 0].sum().item( + ), f'{num_valid_tokens} != {(all_topk_idx == expert_id).sum(dim=[1, 2])[mask_status==0].sum().item()}' + + if num_valid_tokens == 0: + continue + # Check received data + if current_x is x: + recv_x = recv_x[:num_valid_tokens] + recv_x_amin = recv_x[:, :-128].amin(dim=-1) + recv_src_info = recv_src_info[:num_valid_tokens] + assert torch.equal(recv_x_amin, recv_x[:, :-128].amax(dim=-1)) + if round_scale: + assert calc_diff(recv_x[:, -1], recv_src_info.view(-1)) < 0.007 + else: + assert (recv_x[:, -128:] - recv_src_info.view(-1, 1) % num_tokens).sum().item() == 0 + for j in range(num_ranks): + if shrink_test and mask_status[j]: + continue + begin_idx, count = (recv_layout_range[j] >> 32).item(), (recv_layout_range[j] & int_mask).item() + if not round_scale: + assert (recv_x_amin == j - rank_offset).sum().item() == (all_topk_idx[j] == expert_id).sum().item() + assert (recv_x[begin_idx:begin_idx + count, :-128] - j + rank_offset).sum().item() == 0 + if dispatch_use_fp8: + hash_value ^= hash_tensor(packed_recv_x[0][i, :num_valid_tokens]) + hash_value ^= hash_tensor(packed_recv_x[1][i, :num_valid_tokens]) + else: + hash_value ^= hash_tensor(packed_recv_x[i, :num_valid_tokens]) + + # Check combine correctness + if shrink_test and simulate_failure_and_skip(rank, "combine", expected_masked_ranks): + break + for zero_copy in (False, ) if use_logfmt 
else (False, True): + if zero_copy: + buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x + out = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') + combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, + topk_idx, + topk_weights, + handle, + use_logfmt=use_logfmt, + async_finish=not return_recv_hook, + zero_copy=zero_copy, + return_recv_hook=return_recv_hook, + out=out) + hook() if return_recv_hook else event.current_stream_wait() + if shrink_test: + query_mask_buffer_and_check("combine", buffer, mask_status, expected_masked_ranks) + if do_check: + if shrink_test: + owner_by_expert = (torch.arange(num_experts, device='cuda') // num_local_experts) + fail_owner_mask = (mask_status == 1).index_select(0, owner_by_expert) + valid_topk_idx = topk_idx >= 0 + failed_topk_idx = torch.zeros_like(topk_idx, device='cuda', dtype=torch.bool) + failed_topk_idx[valid_topk_idx] = fail_owner_mask.index_select(0, topk_idx[valid_topk_idx]) + topk_idx[failed_topk_idx] = -1 + diff = calc_diff(current_x * topk_weights.masked_fill(topk_idx == -1, 0).sum(dim=1).view(-1, 1), combined_x) + assert torch.isnan(combined_x).sum().item() == 0 + if not round_scale: + assert diff < (9e-4 if dispatch_use_fp8 else 1e-5), f'Error: {diff=}, {dispatch_use_fp8=}, {zero_copy=}' + hash_value ^= hash_tensor(combined_x) + + # Clean buffer API + if shrink_test: + if simulate_failure_and_skip(rank, "clean", expected_masked_ranks): + break + + buffer.clean_low_latency_buffer(num_tokens, hidden, num_experts) + query_mask_buffer_and_check("clean", buffer, mask_status, expected_masked_ranks) + + if shrink_test: + return # noinspection PyShadowingNames def large_gemm_with_hook(hook): @@ -108,24 +199,28 @@ def large_gemm_with_hook(hook): hook() # noinspection PyShadowingNames - def test_func(zero_copy: bool, return_recv_hook: bool): + def test_func(return_recv_hook: bool): recv_x, recv_count, handle, event, hook = \ - buffer.low_latency_dispatch(x, 
topk_idx, num_tokens, num_experts, use_fp8=True, - async_finish=False, return_recv_hook=return_recv_hook) + buffer.low_latency_dispatch(current_x, topk_idx, num_tokens, num_experts, + cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats, + use_fp8=True, async_finish=False, return_recv_hook=return_recv_hook) large_gemm_with_hook(hook) if return_recv_hook else None - if zero_copy: - buffer.get_next_low_latency_combine_buffer(handle)[:, :, :] = simulated_gemm_x - combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, topk_idx, topk_weights, handle, - zero_copy=zero_copy, return_recv_hook=return_recv_hook) + combined_x, event, hook = buffer.low_latency_combine(simulated_gemm_x, + topk_idx, + topk_weights, + handle, + use_logfmt=use_logfmt, + return_recv_hook=return_recv_hook) large_gemm_with_hook(hook) if return_recv_hook else None # Calculate bandwidth num_fp8_bytes, num_bf16_bytes = (hidden + hidden / 128 * 4 + 16), hidden * 2 + num_logfmt10_bytes = hidden * 10 / 8 + hidden / 128 * 4 num_dispatch_comm_bytes, num_combine_comm_bytes = 0, 0 for i in range(num_tokens): num_selections = (topk_idx[i] != -1).sum().item() num_dispatch_comm_bytes += num_fp8_bytes * num_selections - num_combine_comm_bytes += num_bf16_bytes * num_selections + num_combine_comm_bytes += (num_logfmt10_bytes if use_logfmt else num_bf16_bytes) * num_selections # Dispatch + combine testing avg_t, min_t, max_t = bench(partial(test_func, zero_copy=False, return_recv_hook=False)) @@ -135,43 +230,100 @@ def test_func(zero_copy: bool, return_recv_hook: bool): # Separate profiling for return_recv_hook in (False, True): group.barrier() - - dispatch_t, combine_t = bench_kineto(partial(test_func, zero_copy=False, return_recv_hook=return_recv_hook), - kernel_names=('dispatch', 'combine'), barrier_comm_profiling=True, - suppress_kineto_output=True) + dispatch_t, combine_t = bench_kineto(partial(test_func, return_recv_hook=return_recv_hook), + kernel_names=('dispatch', 
'combine'), + barrier_comm_profiling=True, + suppress_kineto_output=True, + num_kernels_per_period=2 if return_recv_hook else 1) if not return_recv_hook: - print(f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | ' - f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us') + print( + f'[rank {rank}] Dispatch bandwidth: {num_dispatch_comm_bytes / 1e9 / dispatch_t:.2f} GB/s, avg_t={dispatch_t * 1e6:.2f} us | ' + f'Combine bandwidth: {num_combine_comm_bytes / 1e9 / combine_t:.2f} GB/s, avg_t={combine_t * 1e6:.2f} us', + flush=True) else: - print(f'[rank {rank}] Dispatch send/recv time: {dispatch_t * 2 * 1e6:.2f} us | ' - f'Combine send/recv time: {combine_t * 2 * 1e6:.2f} us') - + print( + f'[rank {rank}] Dispatch send/recv time: {dispatch_t[0] * 1e6:.2f} + {dispatch_t[1] * 1e6:.2f} us | ' + f'Combine send/recv time: {combine_t[0] * 1e6:.2f} + {combine_t[1] * 1e6:.2f} us', + flush=True) return hash_value -# noinspection PyUnboundLocalVariable -def test_loop(local_rank: int, num_local_ranks: int): +# noinspection PyUnboundLocalVariable,PyShadowingNames +def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): rank, num_ranks, group = init_dist(local_rank, num_local_ranks) - # The default setting of deepEP upstream is below: - num_tokens, hidden, num_topk, num_experts = 128, 7168, 8, 288 + num_tokens, hidden = args.num_tokens, args.hidden + num_topk, num_experts = args.num_topk, args.num_experts num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(num_tokens, hidden, num_ranks, num_experts) if local_rank == 0: print(f'Allocating buffer size: {num_rdma_bytes / 1e6} MB ...', flush=True) - buffer = deep_ep.Buffer(group, num_rdma_bytes=num_rdma_bytes, low_latency_mode=True, - num_qps_per_rank=num_experts // num_ranks) - test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=1) + buffer 
= deep_ep.Buffer(group, + num_rdma_bytes=num_rdma_bytes, + low_latency_mode=True, + num_qps_per_rank=num_experts // num_ranks, + allow_nvlink_for_low_latency_mode=not args.disable_nvlink, + explicitly_destroy=True, + allow_mnnvl=args.allow_mnnvl, + enable_shrink=args.shrink_test) + test_main(num_tokens, + hidden, + num_experts, + num_topk, + rank, + num_ranks, + group, + buffer, + use_logfmt=args.use_logfmt, + shrink_test=args.shrink_test, + seed=1) - do_pressure_test = False + do_pressure_test = args.pressure_test for seed in range(int(1e9) if do_pressure_test else 0): if local_rank == 0: print(f'Testing with seed {seed} ...', flush=True) - ref_hash = test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=seed) - for i in range(20): - assert test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=seed) == ref_hash, f'Error: seed={seed}' + ref_hash = test_main(num_tokens, + hidden, + num_experts, + num_topk, + rank, + num_ranks, + group, + buffer, + use_logfmt=args.use_logfmt, + seed=seed) + for _ in range(20): + assert test_main(num_tokens, + hidden, + num_experts, + num_topk, + rank, + num_ranks, + group, + buffer, + use_logfmt=args.use_logfmt, + seed=seed) == ref_hash, f'Error: seed={seed}' + + # Destroy the buffer runtime and communication group + buffer.destroy() + dist.barrier() + dist.destroy_process_group() if __name__ == '__main__': # TODO: you may modify NUMA binding for less CPU overhead - num_processes = 8 - torch.multiprocessing.spawn(test_loop, args=(num_processes,), nprocs=num_processes) + # TODO: buggy with `num_tokens=512` + parser = argparse.ArgumentParser(description='Test low-latency EP kernels') + parser.add_argument('--num-processes', type=int, default=8, help='Number of processes to spawn (default: 8)') + parser.add_argument('--num-tokens', type=int, default=128, help='Number of tokens (default: 128)') + parser.add_argument('--hidden', type=int, default=7168, help='Hidden 
dimension size (default: 7168)') + parser.add_argument('--num-topk', type=int, default=8, help='Number of top-k experts (default: 8)') + parser.add_argument('--num-experts', type=int, default=288, help='Number of experts (default: 288)') + parser.add_argument('--allow-mnnvl', action="store_true", help='Allow MNNVL for communication') + parser.add_argument('--disable-nvlink', action='store_true', help='Whether to disable NVLink for testing') + parser.add_argument('--use-logfmt', action='store_true', help='Whether to test LogFMT combine') + parser.add_argument("--pressure-test", action='store_true', help='Whether to do pressure test') + parser.add_argument("--shrink-test", action='store_true', help='Whether to simulate failure and test shrink mode') + args = parser.parse_args() + + num_processes = args.num_processes + torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes) diff --git a/tests/utils.py b/tests/utils.py index d980c5c..94f4f02 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,26 +1,34 @@ +import inspect +import json +import tempfile +from pathlib import Path + +import numpy as np import os import sys -import numpy as np import torch import torch.distributed as dist -from typing import Optional +from typing import Optional, Union -def init_dist(local_rank: int, num_local_ranks: int, backend: str = 'nccl'): +def init_dist(local_rank: int, num_local_ranks: int): # NOTES: you may rewrite this function with your own cluster settings - if backend == 'nccl': - ip = os.getenv('MASTER_ADDR', '127.0.0.1') - port = int(os.getenv('MASTER_PORT', '8361')) - node_rank = int(os.getenv('RANK', 0)) + ip = os.getenv('MASTER_ADDR', '127.0.0.1') + port = int(os.getenv('MASTER_PORT', '8361')) num_nodes = int(os.getenv('WORLD_SIZE', 1)) - assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8 - if backend == 'nccl': - dist.init_process_group( - backend='nccl', - init_method=f'tcp://{ip}:{port}', - world_size=num_nodes * 
num_local_ranks, - rank=node_rank * num_local_ranks + local_rank - ) + node_rank = int(os.getenv('RANK', 0)) + + sig = inspect.signature(dist.init_process_group) + params = { + 'backend': 'nccl', + 'init_method': f'tcp://{ip}:{port}', + 'world_size': num_nodes * num_local_ranks, + 'rank': node_rank * num_local_ranks + local_rank, + } + if 'device_id' in sig.parameters: + # noinspection PyTypeChecker + params['device_id'] = torch.device(f'cuda:{local_rank}') + dist.init_process_group(**params) torch.set_default_dtype(torch.bfloat16) torch.set_default_device('cuda') torch.cuda.set_device(local_rank) @@ -35,18 +43,35 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): return (1 - sim).item() +def align_up(x, y): + return (x + y - 1) // y * y + + def per_token_cast_to_fp8(x: torch.Tensor): - assert x.dim() == 2 and x.size(1) % 128 == 0 + assert x.dim() == 2 m, n = x.shape - x_view = x.view(m, -1, 128) - x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) - return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1) + aligned_n = align_up(n, 128) + x_padded = torch.nn.functional.pad(x, (0, aligned_n - n), mode='constant', value=0) + x_padded_view = x_padded.view(m, -1, 128) + x_amax = x_padded_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + return (x_padded_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view( + m, aligned_n)[:, :n].contiguous(), (x_amax / 448.0).view(m, -1) def per_token_cast_back(x_fp8: torch.Tensor, x_scales: torch.Tensor): - x_fp32 = x_fp8.to(torch.float32).view(x_fp8.size(0), -1, 128) + if x_fp8.numel() == 0: + return x_fp8.to(torch.bfloat16) + + assert x_fp8.dim() == 2 + m, n = x_fp8.shape + aligned_n = align_up(n, 128) + x_fp8_padded = torch.nn.functional.pad(x_fp8, (0, aligned_n - n), mode='constant', value=0) + if x_scales.dtype == torch.int: + x_scales = x_scales.view(dtype=torch.uint8).to(torch.int) << 23 + x_scales = x_scales.view(dtype=torch.float) + 
x_fp32_padded = x_fp8_padded.to(torch.float32).view(x_fp8.size(0), -1, 128) x_scales = x_scales.view(x_fp8.size(0), -1, 1) - return (x_fp32 * x_scales).view(x_fp8.shape).to(torch.bfloat16) + return (x_fp32_padded * x_scales).view(x_fp8_padded.shape).to(torch.bfloat16)[:, :n].contiguous() def inplace_unique(x: torch.Tensor, num_slots: int): @@ -72,7 +97,7 @@ def create_grouped_scores(scores: torch.Tensor, group_idx: torch.Tensor, num_gro return (scores * mask).view(num_tokens, num_experts) -def bench(fn, num_warmups: int = 20, num_tests: int = 30, post_fn=None): +def bench(fn, num_warmups: int = 50, num_tests: int = 50, post_fn=None): # Flush L2 cache with 256 MB data torch.cuda.synchronize() cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda') @@ -101,6 +126,7 @@ def bench(fn, num_warmups: int = 20, num_tests: int = 30, post_fn=None): class empty_suppress: + def __enter__(self): return self @@ -109,6 +135,7 @@ def __exit__(self, *_): class suppress_stdout_stderr: + def __enter__(self): self.outnull_file = open(os.devnull, 'w') self.errnull_file = open(os.devnull, 'w') @@ -148,9 +175,9 @@ def bench_kineto(fn, kernel_names, num_tests: int = 100, suppress_kineto_output: # Profile suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress with suppress(): - schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) + schedule = torch.profiler.schedule(wait=1, warmup=0, active=1, repeat=1) with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) as prof: - for i in range(2): + for _ in range(2): # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead if barrier_comm_profiling: lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda') @@ -159,11 +186,12 @@ def bench_kineto(fn, kernel_names, num_tests: int = 100, suppress_kineto_output: dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda')) for _ in range(num_tests): fn() + 
torch.cuda.synchronize() prof.step() # Parse the profiling table - assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple) - is_tupled = isinstance(kernel_names, tuple) + assert isinstance(kernel_names, (str, tuple)) + is_tuple = isinstance(kernel_names, tuple) prof_lines = prof.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n') kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names assert all([isinstance(name, str) for name in kernel_names]) @@ -174,20 +202,36 @@ def bench_kineto(fn, kernel_names, num_tests: int = 100, suppress_kineto_output: if trace_path is not None: prof.export_chrome_trace(trace_path) - # Return average kernel times + # Return average kernel durations units = {'ms': 1e3, 'us': 1e6} - kernel_times = [] + kernel_durations = [] for name in kernel_names: for line in prof_lines: if name in line: time_str = line.split()[-2] for unit, scale in units.items(): if unit in time_str: - kernel_times.append(float(time_str.replace(unit, '')) / scale) + kernel_durations.append(float(time_str.replace(unit, '')) / scale) break break - return tuple(kernel_times) if is_tupled else kernel_times[0] + + # Expand the kernels by periods + if num_kernels_per_period > 1: + with tempfile.NamedTemporaryFile(suffix='.json') as tmp: + prof.export_chrome_trace(tmp.name) + profile_data = json.loads(Path(tmp.name).read_text()) + + for i, kernel_name in enumerate(kernel_names): + events = [event for event in profile_data['traceEvents'] if f'::{kernel_name}' in event['name']] + events = sorted(events, key=lambda event: event['ts']) + durations = [event['dur'] / 1e6 for event in events] + assert len(durations) % num_kernels_per_period == 0 + num_kernel_patterns = len(durations) // num_kernels_per_period + kernel_durations[i] = [sum(durations[j::num_kernels_per_period]) / num_kernel_patterns for j in range(num_kernels_per_period)] + + # Return execution durations + return kernel_durations if is_tuple 
else kernel_durations[0] def hash_tensor(t: torch.Tensor): - return t.view(torch.int64).sum().item() + return t.view(torch.int).sum().item() diff --git a/third-party/README.md b/third-party/README.md index 505efc8..39ad467 100644 --- a/third-party/README.md +++ b/third-party/README.md @@ -23,12 +23,6 @@ MPI_ROOT=$BUILD_DIR/ompi ../rocSHMEM/scripts/build_configs/gda_mlx5 --fresh \ -DUSE_IPC=ON \ -DGDA_BNXT=ON -# To build rocSHMEM with MPI disabled, please add this flag -DUSE_EXTERNAL_MPI=OFF -MPI_ROOT=$BUILD_DIR/ompi ../rocSHMEM/scripts/build_configs/gda_mlx5 --fresh \ - -DUSE_IPC=ON \ - -DGDA_BNXT=ON - -DUSE_EXTERNAL_MPI=OFF - # You may pass additional arguments to Cmake, # e.g., -DBUILD_LOCAL_GPU_TARGET_ONLY=ON ``` From 96f6f6140446bb65e2c53f5b63bea2d94a8ad595 Mon Sep 17 00:00:00 2001 From: Richard Chamberlain Date: Fri, 7 Nov 2025 04:22:06 -0800 Subject: [PATCH 13/22] Adding support for disabling MPI. --- README.md | 10 ++++++++- setup.py | 49 ++++++++++++++++++++++++++++++----------- tests/test_internode.py | 12 ++++++++-- tests/utils.py | 38 ++++++++++++++++++-------------- 4 files changed, 77 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index 4dd1e64..d38de3c 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,8 @@ DeepEP (AMD version) depends on [rocSHMEM](https://github.com/ROCm/rocSHMEM). 
Pl git clone https://github.com/ROCm/DeepEP cd DeepEP -# export OMPI dir in the next command (e.g., it's $BUILD_DIR/ompi in third-party/README.md) +# To use DeepEP with MPI, please proceed with these commands +# Export OMPI dir in the next command (e.g., it's $BUILD_DIR/ompi in third-party/README.md) export OMPI_DIR= python3 setup.py --variant rocm build develop --user @@ -39,11 +40,18 @@ python3 setup.py --variant rocm build develop --user # Then install DeepEP using this command python3 setup.py --variant rocm --disable-mpi build develop --user + +# To use DeepEP without MPI, please make sure rocSHMEM was built with this flag -DUSE_EXTERNAL_MPI=OFF +# Then install DeepEP using this command +python3 setup.py --variant rocm --disable-mpi build develop + # Run test cases # NOTES: you may modify the `init_dist` function in `tests/utils.py` # according to your own cluster settings, and launch into multiple nodes python3 tests/test_intranode.py python3 tests/test_internode.py +# Set the required ROCSHMEM heap size (for example, for DeepSeek models) +export ROCSHMEM_HEAP_SIZE=2147483648 python3 tests/test_low_latency.py ``` diff --git a/setup.py b/setup.py index b6a8111..6645bd6 100644 --- a/setup.py +++ b/setup.py @@ -28,12 +28,14 @@ def get_nvshmem_host_lib_name(base_dir): parser.add_argument("--verbose", action="store_true", help="Verbose build") parser.add_argument("--enable_timer", action="store_true", help="Enable timer to debug time out in internode") parser.add_argument("--rocm-disable-ctx", action="store_true", help="Disable workgroup context optimization in internode") + parser.add_argument("--disable-mpi", action="store_true", help="Disable MPI detection and configuration") # Get the arguments to be parsed and separate setuptools arguments args, unknown_args = parser.parse_known_args() variant = args.variant debug = args.debug rocm_disable_ctx = args.rocm_disable_ctx + disable_mpi = args.disable_mpi enable_timer = args.enable_timer @@ -57,8 +59,8 @@ def 
get_nvshmem_host_lib_name(base_dir): if not disable_nvshmem: assert os.path.exists(nvshmem_dir), f'The specified NVSHMEM directory does not exist: {nvshmem_dir}' - #else: - # disable_nvshmem = False + else: + disable_nvshmem = False # Reset sys.argv for setuptools to avoid conflicts @@ -86,7 +88,8 @@ def get_nvshmem_host_lib_name(base_dir): ), f"Failed to find {shmem_variant_name}" print(f"{shmem_variant_name} directory: {shmem_dir}") - if variant == "rocm": + ompi_dir = None + if variant == "rocm" and not disable_mpi: # Attempt to auto-detect OpenMPI installation directory if OMPI_DIR not set. # The first existing candidate containing bin/mpicc will be used. ompi_dir_env = os.getenv("OMPI_DIR", "").strip() @@ -107,16 +110,30 @@ def get_nvshmem_host_lib_name(base_dir): mpicc_path = os.path.join(d, "bin", "mpicc") if os.path.exists(d) and os.path.exists(mpicc_path): ompi_dir = d - break - if ompi_dir is None: - # Fallback to root (will trigger the assert below) - ompi_dir = "/" + break + + assert ompi_dir is not None, ( + f"Failed to find OpenMPI installation. " + f"Searched: {', '.join([d for d in candidate_dirs if d])}. " + f"Set OMPI_DIR environment variable or use --disable-mpi flag." + ) print(f"Detected OpenMPI directory: {ompi_dir}") - assert os.path.exists(ompi_dir), f"Failed to find OMPI: {ompi_dir}" + elif variant == "rocm" and disable_mpi: + print("MPI detection disabled for ROCm variant") + elif variant == "cuda" and not disable_mpi: + print("MPI detection enabled for CUDA variant") + else: + print("MPI detection disabled for CUDA variant") + # TODO: currently, we only support Hopper architecture, we may add Ampere support later if variant == "rocm": - os.environ["PYTORCH_ROCM_ARCH"] = os.getenv("PYTORCH_ROCM_ARCH", "gfx942") + arch = os.getenv("PYTORCH_ROCM_ARCH") + allowed_arch = {"gfx942", "gfx950"} + if arch not in allowed_arch: + raise EnvironmentError( + f"Invalid PYTORCH_ROCM_ARCH='{arch}'. 
" + f"Use one of: {', '.join(sorted(allowed_arch))}.") elif variant == "cuda": os.environ["TORCH_CUDA_ARCH_LIST"] = "9.0" @@ -155,7 +172,7 @@ def get_nvshmem_host_lib_name(base_dir): nvcc_flags = [f"{optimization_flag}"] + debug_symbol_flags + define_macros include_dirs = ["csrc/", f"{shmem_dir}/include"] - if variant == "rocm": + if variant == "rocm" and ompi_dir is not None: include_dirs.append(f"{ompi_dir}/include") sources = [ @@ -168,7 +185,7 @@ def get_nvshmem_host_lib_name(base_dir): ] library_dirs = [f"{shmem_dir}/lib"] - if variant == "rocm": + if variant == "rocm" and ompi_dir is not None: library_dirs.append(f"{ompi_dir}/lib") # Disable aggressive PTX instructions @@ -197,10 +214,15 @@ def get_nvshmem_host_lib_name(base_dir): "-lamdhip64", "-lhsa-runtime64", "-libverbs", - f"-l:libmpi.so", - f"-Wl,-rpath,{ompi_dir}/lib", ] ) + if not disable_mpi: + extra_link_args.extend( + [ + f"-l:libmpi.so", + f"-Wl,-rpath,{ompi_dir}/lib", + ] + ) extra_compile_args = { "cxx": cxx_flags, @@ -218,6 +240,7 @@ def get_nvshmem_host_lib_name(base_dir): print(f' > Compilation flags: {extra_compile_args}') print(f' > Link flags: {extra_link_args}') print(f' > NVSHMEM path: {shmem_dir}') + print(f' > Disable MPI: {disable_mpi}') print() # noinspection PyBroadException diff --git a/tests/test_internode.py b/tests/test_internode.py index c03d344..3315b3d 100644 --- a/tests/test_internode.py +++ b/tests/test_internode.py @@ -298,7 +298,7 @@ def check_data(check_x, recv_gbl_rank_prefix_sum): # noinspection PyUnboundLocalVariable,PyShadowingNames def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): num_nodes = int(os.getenv('WORLD_SIZE', 1)) - rank, num_ranks, group = init_dist(local_rank, num_local_ranks) + rank, num_ranks, group = init_dist(local_rank, num_local_ranks, backend=args.backend) if args.test_ll_compatibility: ll_num_tokens, ll_hidden, ll_num_experts, ll_num_topk = 16, 5120, 256, 9 @@ -353,6 +353,7 @@ def test_loop(local_rank: int, 
num_local_ranks: int, args: argparse.Namespace): if __name__ == '__main__': parser = argparse.ArgumentParser(description='Test internode EP kernels') + parser.add_argument('--backend', type=str, choices=['mpi', 'nccl'], default='nccl',help='Backend for distributed communication (mpi or nccl)') parser.add_argument('--num-processes', type=int, default=8, help='Number of processes to spawn (default: 8)') parser.add_argument('--num-tokens', type=int, default=4096, help='Number of tokens (default: 4096)') parser.add_argument('--hidden', type=int, default=7168, help='Hidden dimension size (default: 7168)') @@ -373,4 +374,11 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): args.num_topk_groups = min(num_nodes, 4) num_processes = args.num_processes - torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes) + if args.backend == 'mpi': + dist.init_process_group(backend='mpi') + rank = dist.get_rank() + local_rank = rank % num_processes + test_loop(local_rank=local_rank, num_local_ranks=num_processes, args=args) + else: + torch.multiprocessing.spawn(test_loop, args=(num_processes, args), nprocs=num_processes) + diff --git a/tests/utils.py b/tests/utils.py index 94f4f02..df49cec 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -11,24 +11,30 @@ from typing import Optional, Union -def init_dist(local_rank: int, num_local_ranks: int): +def init_dist(local_rank: int, num_local_ranks: int, backend: str = 'nccl'): # NOTES: you may rewrite this function with your own cluster settings - ip = os.getenv('MASTER_ADDR', '127.0.0.1') - port = int(os.getenv('MASTER_PORT', '8361')) + if backend == 'nccl': + ip = os.getenv('MASTER_ADDR', '127.0.0.1') + port = int(os.getenv('MASTER_PORT', '8361')) + node_rank = int(os.getenv('RANK', 0)) + num_nodes = int(os.getenv('WORLD_SIZE', 1)) - node_rank = int(os.getenv('RANK', 0)) - - sig = inspect.signature(dist.init_process_group) - params = { - 'backend': 'nccl', - 'init_method': 
f'tcp://{ip}:{port}', - 'world_size': num_nodes * num_local_ranks, - 'rank': node_rank * num_local_ranks + local_rank, - } - if 'device_id' in sig.parameters: - # noinspection PyTypeChecker - params['device_id'] = torch.device(f'cuda:{local_rank}') - dist.init_process_group(**params) + + assert (num_local_ranks < 8 and num_nodes == 1) or num_local_ranks == 8 + + if backend == 'nccl': + sig = inspect.signature(dist.init_process_group) + params = { + 'backend': backend, + 'init_method': f'tcp://{ip}:{port}', + 'world_size': num_nodes * num_local_ranks, + 'rank': node_rank * num_local_ranks + local_rank, + } + if 'device_id' in sig.parameters: + # noinspection PyTypeChecker + params['device_id'] = torch.device(f'cuda:{local_rank}') + dist.init_process_group(**params) + torch.set_default_dtype(torch.bfloat16) torch.set_default_device('cuda') torch.cuda.set_device(local_rank) From 98705ff48e412944b39982d6ac808336a0636dec Mon Sep 17 00:00:00 2001 From: Richard Chamberlain Date: Fri, 7 Nov 2025 04:24:38 -0800 Subject: [PATCH 14/22] Restored readme code. 
--- third-party/README.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/third-party/README.md b/third-party/README.md index 39ad467..505efc8 100644 --- a/third-party/README.md +++ b/third-party/README.md @@ -23,6 +23,12 @@ MPI_ROOT=$BUILD_DIR/ompi ../rocSHMEM/scripts/build_configs/gda_mlx5 --fresh \ -DUSE_IPC=ON \ -DGDA_BNXT=ON +# To build rocSHMEM with MPI disabled, please add this flag -DUSE_EXTERNAL_MPI=OFF +MPI_ROOT=$BUILD_DIR/ompi ../rocSHMEM/scripts/build_configs/gda_mlx5 --fresh \ + -DUSE_IPC=ON \ + -DGDA_BNXT=ON + -DUSE_EXTERNAL_MPI=OFF + # You may pass additional arguments to Cmake, # e.g., -DBUILD_LOCAL_GPU_TARGET_ONLY=ON ``` From 89661b527dbb5dd1be7d9d0c551a1bf822cacc03 Mon Sep 17 00:00:00 2001 From: Richard Chamberlain Date: Fri, 7 Nov 2025 08:58:07 -0800 Subject: [PATCH 15/22] Adding missing kernel file layout.cu --- csrc/kernels/layout.cu | 153 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 153 insertions(+) create mode 100644 csrc/kernels/layout.cu diff --git a/csrc/kernels/layout.cu b/csrc/kernels/layout.cu new file mode 100644 index 0000000..c3a16ae --- /dev/null +++ b/csrc/kernels/layout.cu @@ -0,0 +1,153 @@ +#include "configs.cuh" +#include "exception.cuh" +#include "launch.cuh" + +namespace deep_ep { + +namespace layout { + +template +__global__ void get_dispatch_layout(const topk_idx_t* topk_idx, + int* num_tokens_per_rank, + int* num_tokens_per_rdma_rank, + int* num_tokens_per_expert, + bool* is_token_in_rank, + int num_tokens, + int num_topk, + int num_ranks, + int num_experts) { + auto sm_id = static_cast(blockIdx.x); + auto thread_id = static_cast(threadIdx.x); + + // Count expert statistics + __shared__ int num_tokens_per_expert_per_thread[kNumThreads][kNumExpertsPerSM]; + int expert_begin_idx = sm_id * kNumExpertsPerSM, expert_end_idx = min(expert_begin_idx + kNumExpertsPerSM, num_experts); + if (expert_begin_idx < expert_end_idx) { + // Per-thread count + #pragma unroll + for (int i = 0; i < kNumExpertsPerSM; ++i) + 
num_tokens_per_expert_per_thread[thread_id][i] = 0; + #pragma unroll + for (int i = thread_id; i < num_tokens; i += kNumThreads) { + auto shifted_topk_idx = topk_idx + i * num_topk; + #pragma unroll + for (int j = 0, expert_idx; j < num_topk; ++j) { + expert_idx = static_cast(shifted_topk_idx[j]); + if (expert_begin_idx <= expert_idx and expert_idx < expert_end_idx) + ++num_tokens_per_expert_per_thread[thread_id][expert_idx - expert_begin_idx]; + } + } + __syncthreads(); + + // Sum up + EP_STATIC_ASSERT(kNumExpertsPerSM <= kNumThreads, "Too many experts per SM"); + if (expert_begin_idx + thread_id < expert_end_idx) { + int sum = 0; + #pragma unroll + for (int i = 0; i < kNumThreads; ++i) + sum += num_tokens_per_expert_per_thread[i][thread_id]; + num_tokens_per_expert[expert_begin_idx + thread_id] = sum; + } + return; + } + + if (num_tokens_per_rdma_rank != nullptr) + EP_DEVICE_ASSERT(num_ranks % NUM_MAX_NVL_PEERS == 0 and num_ranks > NUM_MAX_NVL_PEERS); + + // Count rank statistics + constexpr int kNumRDMARanksPerSM = kNumRanksPerSM / NUM_MAX_NVL_PEERS; + __shared__ int num_tokens_per_rank_per_thread[kNumThreads][kNumRanksPerSM]; + __shared__ int num_tokens_per_rdma_rank_per_thread[kNumThreads][kNumRDMARanksPerSM]; + auto sm_begin = (num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM; + int rank_begin_idx = (sm_id - sm_begin) * kNumRanksPerSM, rank_end_idx = min(rank_begin_idx + kNumRanksPerSM, num_ranks); + int rdma_rank_begin_idx = rank_begin_idx / NUM_MAX_NVL_PEERS, rdma_rank_end_idx = rank_end_idx / NUM_MAX_NVL_PEERS; + if (rank_begin_idx < rank_end_idx) { + const auto num_expert_per_rank = num_experts / num_ranks; + auto expert_begin = rank_begin_idx * num_expert_per_rank; + auto expert_end = rank_end_idx * num_expert_per_rank; + + // Per-thread count + #pragma unroll + for (int i = 0; i < kNumRanksPerSM; ++i) + num_tokens_per_rank_per_thread[thread_id][i] = 0; + #pragma unroll + for (int i = 0; i < kNumRDMARanksPerSM; ++i) + 
num_tokens_per_rdma_rank_per_thread[thread_id][i] = 0; + #pragma unroll + for (int i = thread_id; i < num_tokens; i += kNumThreads) { + auto shifted_topk_idx = topk_idx + i * num_topk; + int is_in_rank[kNumRanksPerSM] = {0}, is_in_rdma_rank[kNumRDMARanksPerSM] = {0}; + #pragma unroll + for (int j = 0, expert_idx, rank_idx; j < num_topk; ++j) { + expert_idx = static_cast(shifted_topk_idx[j]); + if (expert_begin <= expert_idx and expert_idx < expert_end) { + // Count single rank + rank_idx = expert_idx / num_expert_per_rank - rank_begin_idx; + is_in_rank[rank_idx]++, is_in_rdma_rank[rank_idx / NUM_MAX_NVL_PEERS]++; + } + } + + auto shifted_is_token_in_rank = is_token_in_rank + i * num_ranks; + #pragma unroll + for (int j = 0; j + rank_begin_idx < rank_end_idx; ++j) { + shifted_is_token_in_rank[j + rank_begin_idx] = (is_in_rank[j] > 0); + num_tokens_per_rank_per_thread[thread_id][j] += (is_in_rank[j] > 0); + } + + #pragma unroll + for (int j = 0; j + rdma_rank_begin_idx < rdma_rank_end_idx; ++j) + num_tokens_per_rdma_rank_per_thread[thread_id][j] += (is_in_rdma_rank[j] > 0); + } + __syncthreads(); + + // Sum up + EP_STATIC_ASSERT(kNumRanksPerSM <= kNumThreads, "Too many ranks per SM"); + if (rank_begin_idx + thread_id < rank_end_idx) { + int sum = 0; + #pragma unroll + for (int i = 0; i < kNumThreads; ++i) + sum += num_tokens_per_rank_per_thread[i][thread_id]; + num_tokens_per_rank[rank_begin_idx + thread_id] = sum; + } + + if (num_tokens_per_rdma_rank != nullptr and rdma_rank_begin_idx + thread_id < rdma_rank_end_idx) { + int sum = 0; + #pragma unroll + for (int i = 0; i < kNumThreads; ++i) + sum += num_tokens_per_rdma_rank_per_thread[i][thread_id]; + num_tokens_per_rdma_rank[rdma_rank_begin_idx + thread_id] = sum; + } + } +} + +void get_dispatch_layout(const topk_idx_t* topk_idx, + int* num_tokens_per_rank, + int* num_tokens_per_rdma_rank, + int* num_tokens_per_expert, + bool* is_token_in_rank, + int num_tokens, + int num_topk, + int num_ranks, + int num_experts, + 
cudaStream_t stream) { + constexpr int kNumThreads = 256, kNumExpertsPerSM = 4, kNumRanksPerSM = 8; + int num_sms = ((num_experts + kNumExpertsPerSM - 1) / kNumExpertsPerSM) + (num_ranks + kNumRanksPerSM - 1) / kNumRanksPerSM; + EP_STATIC_ASSERT(kNumRanksPerSM % NUM_MAX_NVL_PEERS == 0, "Invalid number of ranks per SM"); + + SETUP_LAUNCH_CONFIG(num_sms, kNumThreads, stream); + LAUNCH_KERNEL(&cfg, + (get_dispatch_layout), + topk_idx, + num_tokens_per_rank, + num_tokens_per_rdma_rank, + num_tokens_per_expert, + is_token_in_rank, + num_tokens, + num_topk, + num_ranks, + num_experts); +} + +} // namespace layout + +} // namespace deep_ep From 5003cfb7151b89ac89c6842d90fb439710288e36 Mon Sep 17 00:00:00 2001 From: RichardChamberlain1 Date: Fri, 7 Nov 2025 18:34:29 +0000 Subject: [PATCH 16/22] Update internode_ll.cu --- csrc/kernels/internode_ll.cu | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu index 66ce3c6..abd3232 100644 --- a/csrc/kernels/internode_ll.cu +++ b/csrc/kernels/internode_ll.cu @@ -193,7 +193,12 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, for (int j = 0;j < kNumElemsPerRead;j += 2) { float2 fp32x2 = {fp32_values[j] * scale, fp32_values[j + 1] * scale}; #ifdef USE_ROCM +#if defined(__gfx942__) fp8x2_values[j / 2] = __hip_cvt_float2_to_fp8x2(fp32x2, __HIP_SATFINITE, __HIP_E4M3_FNUZ); +#endif +#if defined(__gfx950__) + fp8x2_values[j / 2] = __hip_cvt_float2_to_fp8x2(fp32x2, __HIP_SATFINITE, __HIP_E4M3); +#endif #else fp8x2_values[j / 2] = __nv_cvt_float2_to_fp8x2(fp32x2, __NV_SATFINITE, __NV_E4M3); #endif @@ -232,11 +237,6 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, internode::shmem_ctx_schar_put_nbi_warp(ctx, reinterpret_cast(dst_ptr), reinterpret_cast(src_ptr), num_bytes_per_msg, dst_rank); #endif reinterpret_cast(dst_ptr), reinterpret_cast(src_ptr), num_bytes_per_msg, dst_rank); -#if defined(ROCM_DISABLE_CTX) - 
internode::shmem_fence(); -#else - internode::shmem_ctx_quiet(ctx); -#endif #else nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, dst_rank, dst_expert_local_idx, lane_id, slot_idx); #endif From d1f9414e6635e3d784a220125482460631a62c4c Mon Sep 17 00:00:00 2001 From: RichardChamberlain1 Date: Mon, 10 Nov 2025 12:19:14 +0000 Subject: [PATCH 17/22] Fix gfx950 FP8 datatypes --- csrc/deep_ep.cpp | 6 +++++- csrc/deep_ep.hpp | 3 +++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp index 57ffd33..aa2b60e 100644 --- a/csrc/deep_ep.cpp +++ b/csrc/deep_ep.cpp @@ -57,6 +57,10 @@ Buffer::Buffer(int rank, // Get device info cudaDeviceProp device_prop = {}; CUDA_CHECK(cudaGetDeviceProperties(&device_prop, device_id)); +#ifdef USE_ROCM + sscanf(device_prop.gcnArchName, "gfx%d", &gfx); + EP_HOST_ASSERT(gfx >= 942); +#endif num_device_sms = device_prop.multiProcessorCount; // Number of per-channel bytes cannot be large @@ -1526,7 +1530,7 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, // Allocate packed tensors #ifdef USE_ROCM auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, - x.options().dtype(use_fp8 ? torch::kFloat8_e4m3fnuz : torch::kBFloat16)); + x.options().dtype(use_fp8 ? (gfx == 942 ? torch::kFloat8_e4m3fnuz : torch::kFloat8_e4m3fn) : torch::kBFloat16)); #else auto packed_recv_x = torch::empty({num_local_experts, num_ranks * num_max_dispatch_tokens_per_rank, hidden}, x.options().dtype(use_fp8 ? 
torch::kFloat8_e4m3fn : torch::kBFloat16)); diff --git a/csrc/deep_ep.hpp b/csrc/deep_ep.hpp index 90fc163..94d8f0a 100644 --- a/csrc/deep_ep.hpp +++ b/csrc/deep_ep.hpp @@ -47,6 +47,9 @@ struct Buffer { // Device info and communication int device_id; +#ifdef USE_ROCM + int gfx; +#endif int num_device_sms; int rank, rdma_rank, nvl_rank; int num_ranks, num_rdma_ranks, num_nvl_ranks; From 68c1015ac5c67289ecb42027aaf1979e604bdbf3 Mon Sep 17 00:00:00 2001 From: RichardChamberlain1 Date: Mon, 10 Nov 2025 12:36:53 +0000 Subject: [PATCH 18/22] Address review comments --- csrc/deep_ep.cpp | 1 + csrc/deep_ep.hpp | 2 -- csrc/kernels/intranode.cu | 3 --- csrc/kernels/runtime.cu | 9 --------- setup.py | 4 ++++ 5 files changed, 5 insertions(+), 14 deletions(-) diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp index aa2b60e..cc9cee0 100644 --- a/csrc/deep_ep.cpp +++ b/csrc/deep_ep.cpp @@ -784,6 +784,7 @@ std::tuple, std::optional>({bias_0, bias_1}); void* bias_ptrs[2] = {nullptr, nullptr}; for (int i = 0; i < 2; ++i) diff --git a/csrc/deep_ep.hpp b/csrc/deep_ep.hpp index 94d8f0a..307db20 100644 --- a/csrc/deep_ep.hpp +++ b/csrc/deep_ep.hpp @@ -170,8 +170,6 @@ struct Buffer { std::tuple, std::optional> intranode_combine( const torch::Tensor& x, const std::optional& topk_weights, - //const std::optional& bias_0, - //const std::optional& bias_1, const torch::Tensor& src_idx, const torch::Tensor& rank_prefix_matrix, const torch::Tensor& channel_prefix_matrix, diff --git a/csrc/kernels/intranode.cu b/csrc/kernels/intranode.cu index 3ef7e35..7b36cc2 100644 --- a/csrc/kernels/intranode.cu +++ b/csrc/kernels/intranode.cu @@ -37,9 +37,6 @@ notify_dispatch(const int* num_tokens_per_rank, int* moe_recv_counter_mapped, // - `per_expert_buffer[rank][i, j]` means the number of tokens from rank i to local expert j int num_experts_per_rank = num_experts / kNumRanks; if (thread_id < kNumRanks) { - //#pragma unroll - //for (int i = 0; i < kNumRanks; ++ i) - // per_rank_buffer[rank * kNumRanks + i] 
= num_tokens_per_rank[i]; per_rank_buffer[rank * kNumRanks + thread_id] = num_tokens_per_rank[thread_id]; #pragma unroll for (int i = 0; i < num_experts_per_rank; ++ i) diff --git a/csrc/kernels/runtime.cu b/csrc/kernels/runtime.cu index 7fa6031..606e433 100644 --- a/csrc/kernels/runtime.cu +++ b/csrc/kernels/runtime.cu @@ -16,15 +16,6 @@ __global__ void barrier(int** task_fifo_ptrs, int head, int rank) { barrier_device(task_fifo_ptrs, head, rank); } -/*void barrier(int** task_fifo_ptrs, int head, int rank, int num_ranks, cudaStream_t stream) { -#define BARRIER_LAUNCH_CASE(ranks) \ - LAUNCH_KERNEL(&cfg, barrier, task_fifo_ptrs, head, rank); \ - break - - SETUP_LAUNCH_CONFIG(1, kWarpSize, stream); - SWITCH_RANKS(BARRIER_LAUNCH_CASE); -#undef BARRIER_LAUNCH_CASE -}*/ void barrier(int** barrier_signal_ptrs, int rank, int num_ranks, cudaStream_t stream, int head = 0) { #define BARRIER_LAUNCH_CASE(ranks) \ diff --git a/setup.py b/setup.py index 6645bd6..3e34858 100644 --- a/setup.py +++ b/setup.py @@ -197,6 +197,10 @@ def get_nvshmem_host_lib_name(base_dir): # Bits of `topk_idx.dtype`, choices are 32 and 64 if "TOPK_IDX_BITS" in os.environ: topk_idx_bits = int(os.environ['TOPK_IDX_BITS']) + assert topk_idx_bits in (32, 64), ( + f"Invalid TOPK_IDX_BITS={topk_idx_bits}. " + "Must be either 32 or 64." + ) cxx_flags.append(f'-DTOPK_IDX_BITS={topk_idx_bits}') nvcc_flags.append(f'-DTOPK_IDX_BITS={topk_idx_bits}') From 4f628e92914e6d63623ff143ac8231677fc51647 Mon Sep 17 00:00:00 2001 From: RichardChamberlain1 Date: Mon, 10 Nov 2025 12:45:57 +0000 Subject: [PATCH 19/22] Update utils.cuh Removed unused definition. 
--- csrc/kernels/utils.cuh | 4 ---- 1 file changed, 4 deletions(-) diff --git a/csrc/kernels/utils.cuh b/csrc/kernels/utils.cuh index 3410e0c..37de81d 100644 --- a/csrc/kernels/utils.cuh +++ b/csrc/kernels/utils.cuh @@ -2,10 +2,6 @@ #include "exception.cuh" -#ifdef USE_ROCM -#define syncthreads() __syncthreads() -#endif - #define UNROLLED_WARP_COPY(UNROLL_FACTOR, LANE_ID, N, DST, SRC, LD_FUNC, ST_FUNC) \ { \ constexpr int kLoopStride = kWarpSize * (UNROLL_FACTOR); \ From 231d01dc586da586be071d9d7a6a3c2b473db80e Mon Sep 17 00:00:00 2001 From: Richard Chamberlain Date: Thu, 13 Nov 2025 02:45:23 -0800 Subject: [PATCH 20/22] Removed broken buffer cleanup code --- csrc/deep_ep.cpp | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp index cc9cee0..0ccf04d 100644 --- a/csrc/deep_ep.cpp +++ b/csrc/deep_ep.cpp @@ -131,28 +131,8 @@ Buffer::~Buffer() noexcept(false) { printf("WARNING: destroy() was not called before DeepEP buffer destruction, which can leak resources.\n"); fflush(stdout); } - - // Free NVSHMEM - if (num_rdma_bytes > 0) { - CUDA_CHECK(cudaDeviceSynchronize()); - internode::barrier(); - internode::free(rdma_buffer_ptr); - internode::finalize(); - } - - // Free cuBLAS handle, workspace and MoE counter - CUDA_CHECK(cudaFree(workspace)); - CUDA_CHECK(cudaFree(dispatch_global_atomic_counter)); - CUDA_CHECK(cudaFree(combine_global_atomic_counter)); - CUDA_CHECK(cudaFreeHost(const_cast(moe_recv_counter))); - - // Free chunked mode staffs - CUDA_CHECK(cudaFreeHost(const_cast(moe_recv_expert_counter))); } -void Buffer::move_fifo_slots(int num_slots) { - head = (head + num_ranks * num_slots) % NUM_MAX_FIFO_SLOTS; -} bool Buffer::is_available() const { return available; From e4a8886c8c3b269af5b42b3e9ebb58c9b34adf9f Mon Sep 17 00:00:00 2001 From: RichardChamberlain1 Date: Wed, 19 Nov 2025 18:28:59 +0000 Subject: [PATCH 21/22] Update shmem_wrapper.cuh --- csrc/kernels/shmem_wrapper.cuh | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/csrc/kernels/shmem_wrapper.cuh b/csrc/kernels/shmem_wrapper.cuh index 1e90717..275798e 100644 --- a/csrc/kernels/shmem_wrapper.cuh +++ b/csrc/kernels/shmem_wrapper.cuh @@ -65,7 +65,7 @@ static inline const auto &shmem_ibgda_amo_nonfetch_add = #if !defined(ROCM_DISABLE_CTX) using shmem_ctx_t = rocshmem::rocshmem_ctx_t; static inline const auto &shmem_wg_ctx_create = [] __device__(rocshmem::rocshmem_ctx_t *ctx) { - return rocshmem::rocshmem_wg_ctx_create(0, ctx); + return rocshmem::rocshmem_wg_ctx_create(ctx); }; static inline const auto &shmem_wg_ctx_destroy = rocshmem::rocshmem_wg_ctx_destroy; From 304504321813224d97da3476c243d843a9fb7c7d Mon Sep 17 00:00:00 2001 From: RichardChamberlain1 Date: Thu, 20 Nov 2025 16:36:26 +0000 Subject: [PATCH 22/22] Update shmem_wrapper.cuh --- csrc/kernels/shmem_wrapper.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/kernels/shmem_wrapper.cuh b/csrc/kernels/shmem_wrapper.cuh index 275798e..3a6d653 100644 --- a/csrc/kernels/shmem_wrapper.cuh +++ b/csrc/kernels/shmem_wrapper.cuh @@ -65,7 +65,7 @@ static inline const auto &shmem_ibgda_amo_nonfetch_add = #if !defined(ROCM_DISABLE_CTX) using shmem_ctx_t = rocshmem::rocshmem_ctx_t; static inline const auto &shmem_wg_ctx_create = [] __device__(rocshmem::rocshmem_ctx_t *ctx) { - return rocshmem::rocshmem_wg_ctx_create(ctx); + return rocshmem::rocshmem_wg_ctx_create(0,ctx); }; static inline const auto &shmem_wg_ctx_destroy = rocshmem::rocshmem_wg_ctx_destroy;