From 2f1e0497d1bb43236a33f3dc079ae0dceced37f1 Mon Sep 17 00:00:00 2001 From: amirakb89 Date: Fri, 7 Nov 2025 05:56:35 +0000 Subject: [PATCH 1/4] opt grid barrier --- csrc/kernels/internode_ll.cu | 66 +++++++++++++++++++----------------- 1 file changed, 35 insertions(+), 31 deletions(-) diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu index 9cdc77d..723fb6b 100644 --- a/csrc/kernels/internode_ll.cu +++ b/csrc/kernels/internode_ll.cu @@ -19,9 +19,10 @@ namespace internode_ll { __device__ void grid_barrier(int* global_counter, int num_blocks) { volatile int ret; __syncthreads(); - __threadfence(); - if (threadIdx.x == 0 ) { - ret = __hip_atomic_fetch_add( &global_counter[0], 1, + if (threadIdx.x == 0 ) { + __threadfence(); + + ret = __hip_atomic_fetch_add( &global_counter[0], 1, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); } __syncthreads(); @@ -120,18 +121,18 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4); EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0); -#ifdef USE_ROCM - // 16 is the max possible number of warps in AMD GPUs - constexpr int kMaxNumWarps = 1024 / kWarpSize; - constexpr int num_sync_large_iteration = kMaxNumWarps ; - __shared__ volatile uint8_t sync_large_warp_counters[num_sync_large_iteration]; - - #pragma unroll - for (int i = thread_id; i < num_sync_large_iteration; i += blockDim.x) { - sync_large_warp_counters[i] = 0; - } - __syncthreads(); -#endif +//#ifdef USE_ROCM +// // 16 is the max possible number of warps in AMD GPUs +// constexpr int kMaxNumWarps = 1024 / kWarpSize; +// constexpr int num_sync_large_iteration = kMaxNumWarps ; +// __shared__ volatile uint8_t sync_large_warp_counters[num_sync_large_iteration]; +// +// #pragma unroll +// for (int i = thread_id; i < num_sync_large_iteration; i += blockDim.x) { +// sync_large_warp_counters[i] = 0; +// } +// __syncthreads(); +//#endif // Sending phase if ((phases & LOW_LATENCY_SEND_PHASE) == 0) goto LOW_LATENCY_DISPATCH_RECV; @@ -165,7 +166,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, // Read auto int4_value = __ldg(x_int4 + i); - if (kUseFP8) { + if constexpr(kUseFP8) { // Calculate local amax auto bf16_values = reinterpret_cast(&int4_value); float fp32_values[kNumElemsPerRead]; @@ -267,10 +268,10 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, next_clean[i] = 0; // Notify before executing `int_p` - syncwarp(); - #pragma unroll - for (int i = lane_id; i < num_experts; i += kWarpSize) - atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG); + //syncwarp(); + //#pragma unroll + //for (int i = lane_id; i < num_experts; i += kWarpSize) + // atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG); } // This SM should be responsible for some destination experts, read `topk_idx` for them int expert_count[kNumWarpGroups] = {0}; @@ -278,7 +279,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, const auto expert_end_idx = min(expert_begin_idx + kNumWarpGroups, num_experts); // Per lane count - #pragma unroll 8 + #pragma unroll 2 for (int i = lane_id; i < num_tokens * num_topk; i += kWarpSize) { auto idx = static_cast(__ldg(topk_idx + i)); if (idx >= expert_begin_idx and idx < expert_end_idx) @@ -296,6 +297,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, } } +#if 0 if (thread_id == 0 and num_ranks > 8){ #if defined(ROCM_DISABLE_CTX) internode::shmem_fence(); @@ -303,6 +305,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, internode::shmem_ctx_quiet(ctx); #endif } +#endif //revert sync_large_warp_counters to 0 for next sync __syncthreads(); @@ -313,7 +316,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, const auto num_tokens_sent = shared_num_tokens_sent_per_expert[responsible_expert_idx - sm_id * kNumWarpGroups]; // Wait local sends issued and send expert counts - while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2); + while (ld_volatile_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG); if (dst_rank != rank) { #ifdef USE_ROCM #if defined(ROCM_DISABLE_CTX) @@ -377,13 +380,14 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, } #ifdef USE_ROCM // no needs to reset because there is no iteration - if (lane_id == 0){ - volatile int ret = __hip_atomic_fetch_add( - &sync_large_warp_counters[warp_group_id], 1, - __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP); - } - syncwarp(); - while (sync_large_warp_counters[warp_group_id] < (kNumWarpsPerGroup)); + // if (lane_id == 0){ + // volatile int ret = __hip_atomic_fetch_add( + // &sync_large_warp_counters[warp_group_id], 1, + // __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP); + // } + // syncwarp(); + // while (sync_large_warp_counters[warp_group_id] < (kNumWarpsPerGroup)); + __syncthreads(); #else asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 2), "r"(kNumWarpsPerGroup * 32)); #endif @@ -406,7 +410,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global); // Copy scales - if (kUseFP8) { + if constexpr(kUseFP8) { const auto src_scales = reinterpret_cast(reinterpret_cast(src_data) + hidden_bytes); const auto dst_scales = reinterpret_cast(recv_x_scales + recv_token_begin_idx + i); const auto scale_stride = num_ranks * num_max_dispatch_tokens_per_rank; @@ -728,4 +732,4 @@ LAUNCH_KERNEL_NON_COOPERATIVE(&cfg, combine_func, \ } // namespace internode_ll -} // namespace deep_ep \ No newline at end of file +} // namespace deep_ep From 6fcf7169a5a200def4c39eda01ab6f60a41bcc97 Mon Sep 17 00:00:00 2001 From: amirakb89 Date: Sat, 8 Nov 2025 00:48:37 +0000 Subject: [PATCH 2/4] change groupsize combine --- csrc/kernels/internode_ll.cu | 84 +++++++++++++++++++----------------- csrc/kernels/utils.cuh | 16 ++++--- tests/test_low_latency.py | 2 +- 3 files changed, 55 insertions(+), 47 deletions(-) diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu index 723fb6b..9a806ac 100644 --- a/csrc/kernels/internode_ll.cu +++ b/csrc/kernels/internode_ll.cu @@ -16,10 +16,11 @@ namespace deep_ep { namespace internode_ll { -__device__ void grid_barrier(int* global_counter, int num_blocks) { +__device__ void grid_barrier(int* global_counter, int num_blocks, int do_fence=1) { volatile int ret; __syncthreads(); if (threadIdx.x == 0 ) { + if (do_fence) __threadfence(); ret = __hip_atomic_fetch_add( &global_counter[0], 1, @@ -511,21 +512,21 @@ combine(void* combined_x, // Message package // BF16 mode: always use BF16 for hidden data (ignoring the extra flag slot) constexpr size_t num_bytes_per_slot = sizeof(int4) + kHidden * sizeof(gpu_bfloat16_t); - EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization"); - __syncthreads(); -#ifdef USE_ROCM - // 16 is the max possible number of warps in AMD GPUs - constexpr int kMaxNumWarps = 1024 / kWarpSize; - __shared__ volatile int sync_large_warp_counters[kMaxNumWarps]; - if (threadIdx.x==0){ - // printf("combine"); - #pragma unroll - for (int i = 0; i < kMaxNumWarps; ++i) { - sync_large_warp_counters[i] = 0; - } - } - __syncthreads(); -#endif + // EP_STATIC_ASSERT(num_bytes_per_slot % sizeof(int4) == 0, "Invalid vectorization"); +// __syncthreads(); +// #ifdef USE_ROCM +// // 16 is the max possible number of warps in AMD GPUs +// constexpr int kMaxNumWarps = 1024 / kWarpSize; +// __shared__ volatile int sync_large_warp_counters[kMaxNumWarps]; +// if (threadIdx.x==0){ +// // printf("combine"); +// #pragma unroll +// for (int i = 0; i < kMaxNumWarps; ++i) { +// sync_large_warp_counters[i] = 0; +// } +// } +// __syncthreads(); +// #endif // Sending phase if ((phases & LOW_LATENCY_SEND_PHASE) == 0) @@ -540,7 +541,7 @@ combine(void* combined_x, // Notify before executing `int_p` syncwarp(); if (lane_id == 0) - atomic_add_release_global(atomic_clean_flag, num_experts); + atomic_add_relaxed_global(atomic_clean_flag, num_experts); } // Issue IBGDA sends @@ -571,11 +572,11 @@ combine(void* combined_x, const auto dst_ptr = reinterpret_cast(rdma_recv_x) + (global_expert_idx * num_max_dispatch_tokens_per_rank + src_idx) * num_bytes_per_slot + sizeof(int4); if (dst_rank == rank) { const auto dst_int4_ptr = reinterpret_cast(dst_ptr); - UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global); + UNROLLED_WARP_COPY(4, lane_id, hidden_bf16_int4, dst_int4_ptr, x_int4, ld_nc_global, st_na_global); } else { const auto buf_int4_ptr = reinterpret_cast(buf_ptr); if (not zero_copy) - UNROLLED_WARP_COPY(7, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global); + UNROLLED_WARP_COPY(4, lane_id, hidden_bf16_int4, buf_int4_ptr, x_int4, ld_nc_global, st_na_global); //nvshmemi_ibgda_put_nbi_warp(dst_ptr, buf_ptr, hidden * sizeof(gpu_bfloat16_t), dst_rank, local_expert_idx, lane_id, token_idx - offset); #if defined(ROCM_DISABLE_CTX) @@ -584,6 +585,7 @@ combine(void* combined_x, internode::shmem_ctx_schar_put_nbi_warp(ctx, #endif reinterpret_cast(dst_ptr), reinterpret_cast(buf_ptr), hidden * sizeof(gpu_bfloat16_t), dst_rank); +#if 0 if (num_ranks > 8){ #if defined(ROCM_DISABLE_CTX) internode::shmem_fence(); @@ -591,24 +593,26 @@ combine(void* combined_x, internode::shmem_ctx_quiet(ctx); #endif } +#endif //0 } } // Put finishing flag - EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group"); + // EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group"); #ifdef USE_ROCM - if (lane_id == 0){ - volatile int ret = __hip_atomic_fetch_add( - &sync_large_warp_counters[warp_group_id], 1, - __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP); - } - syncwarp(); - while (sync_large_warp_counters[warp_group_id] < (kNumWarpsPerGroup)); + // if (lane_id == 0){ + // volatile int ret = __hip_atomic_fetch_add( + // &sync_large_warp_counters[warp_group_id], 1, + // __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP); + // } + // syncwarp(); + // while (sync_large_warp_counters[warp_group_id] < (kNumWarpsPerGroup)); + __syncthreads(); #else asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 1), "r"(kNumWarpsPerGroup * 32)); #endif - if (sub_warp_id == 1 and lane_id == 0) { - while (ld_acquire_global(atomic_clean_flag) == 0); + if (sub_warp_id == 0 and lane_id == 0) { + while (ld_volatile_global(atomic_clean_flag) == 0); if (dst_rank != rank) { #ifdef USE_ROCM #if defined(ROCM_DISABLE_CTX) @@ -622,9 +626,9 @@ combine(void* combined_x, } else { st_na_release(reinterpret_cast(rdma_recv_flag + global_expert_idx), 1); } - atomic_add_release_global(atomic_clean_flag, -1); + atomic_add_relaxed_global(atomic_clean_flag, -1); } - syncwarp(); + //syncwarp(); } // Receiving phase @@ -634,16 +638,16 @@ combine(void* combined_x, // Wait all ranks to arrive and notify PCIe usage if (responsible_expert_idx < num_experts) { - EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Invalid number of warps per group"); + // EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Invalid number of warps per group"); if (sub_warp_id == 0 and lane_id == 0){ - while (ld_acquire_global(reinterpret_cast(rdma_recv_flag + responsible_expert_idx)) == 0); + while (ld_volatile_global(reinterpret_cast(rdma_recv_flag + responsible_expert_idx)) == 0); } } - grid_barrier(global_atomic_counter, num_sms); + grid_barrier(global_atomic_counter, num_sms, 0); // Reduce tokens with FP8 cast - EP_DEVICE_ASSERT(num_topk <= kWarpSize and hidden_bf16_int4 <= num_threads); - EP_STATIC_ASSERT(kHidden % (kWarpSize * kNumElemsPerInt4) == 0, "Invalid vectorization"); + // EP_DEVICE_ASSERT(num_topk <= kWarpSize and hidden_bf16_int4 <= num_threads); + // EP_STATIC_ASSERT(kHidden % (kWarpSize * kNumElemsPerInt4) == 0, "Invalid vectorization"); if (thread_id < hidden_bf16_int4) { for (int token_idx = sm_id; token_idx < num_combined_tokens; token_idx += num_sms) { // Read top-k indices and weights @@ -665,7 +669,7 @@ combine(void* combined_x, // Reduce auto x_vec = ld_nc_global(reinterpret_cast(rdma_buffer_row) + thread_id); const auto x_bf16 = reinterpret_cast(&x_vec); - #pragma unroll + #pragma unroll 4 for (int j = 0; j < kNumElemsPerInt4; ++ j) combined_values[j] += static_cast(x_bf16[j]) * reg_topk_weights[i]; } @@ -673,7 +677,7 @@ combine(void* combined_x, // Write results int4& combined_int4 = *reinterpret_cast(combined_values); auto combined_bf16 = reinterpret_cast(&combined_values); - #pragma unroll + #pragma unroll 4 for (int j = 0; j < kNumElemsPerInt4; ++ j) combined_bf16[j] = static_cast(combined_values[j]); (reinterpret_cast(combined_x) + token_idx * hidden_bf16_int4)[thread_id] = combined_int4; @@ -695,8 +699,8 @@ void combine(void* combined_x, void* workspace, cudaStream_t stream, int phases, bool zero_copy) { #ifdef USE_ROCM - constexpr int kNumWarpsPerGroup = 4; - constexpr int kNumWarpGroups = 4; + constexpr int kNumWarpsPerGroup = 8; + constexpr int kNumWarpGroups = 2; #else constexpr int kNumWarpsPerGroup = 10; constexpr int kNumWarpGroups = 3; diff --git a/csrc/kernels/utils.cuh b/csrc/kernels/utils.cuh index 69182d0..da22ef4 100644 --- a/csrc/kernels/utils.cuh +++ b/csrc/kernels/utils.cuh @@ -232,18 +232,22 @@ __device__ __forceinline__ int ld_acquire_global(const int *ptr) { return ret; } //not used -__device__ __forceinline__ int atomic_add_release_sys_global(const int* ptr, int value) { + +__device__ __forceinline__ int atomic_add_release_global(const int* ptr, int value) { int ret; -#ifndef USE_ROCM - asm volatile("atom.add.release.sys.global.s32 %0, [%1], %2;" : "=r"(ret) : "l"(ptr), "r"(value)); +#ifdef USE_ROCM + ret = __hip_atomic_fetch_add(const_cast (ptr), value, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_AGENT); +#else + asm volatile("atom.add.release.gpu.global.s32 %0, [%1], %2;" : "=r"(ret) : "l"(ptr), "r"(value)); #endif - return ret; +return ret; } + //inter -__device__ __forceinline__ int atomic_add_release_global(const int* ptr, int value) { +__device__ __forceinline__ int atomic_add_relaxed_global(const int* ptr, int value) { int ret; #ifdef USE_ROCM - ret = __hip_atomic_fetch_add(const_cast (ptr), value, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_AGENT); + ret = __hip_atomic_fetch_add(const_cast (ptr), value, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); #else asm volatile("atom.add.release.gpu.global.s32 %0, [%1], %2;" : "=r"(ret) : "l"(ptr), "r"(value)); #endif diff --git a/tests/test_low_latency.py b/tests/test_low_latency.py index 0c65024..527c51f 100644 --- a/tests/test_low_latency.py +++ b/tests/test_low_latency.py @@ -173,5 +173,5 @@ def test_loop(local_rank: int, num_local_ranks: int): if __name__ == '__main__': # TODO: you may modify NUMA binding for less CPU overhead - num_processes = 8 + num_processes = 1 torch.multiprocessing.spawn(test_loop, args=(num_processes,), nprocs=num_processes) From 603b03cf294a4cc48dab2731b09f95f69883cefa Mon Sep 17 00:00:00 2001 From: amirakb89 Date: Mon, 10 Nov 2025 23:41:55 +0000 Subject: [PATCH 3/4] opt dispatch --- csrc/deep_ep.cpp | 8 +++++++- csrc/kernels/internode_ll.cu | 32 ++++++++++++++++++-------------- csrc/kernels/utils.cuh | 9 +++++++++ tests/test_low_latency.py | 8 ++++---- tests/utils.py | 2 +- 5 files changed, 39 insertions(+), 20 deletions(-) diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp index aacd712..7472c0d 100644 --- a/csrc/deep_ep.cpp +++ b/csrc/deep_ep.cpp @@ -1098,6 +1098,10 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i // Kernel launch auto next_clean_meta = next_buffer.clean_meta(); + + if (not return_recv_hook) + stream_wait(compute_stream, launch_stream); + internode::barrier(); auto launcher = [=](int phases) { internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales_ptr, packed_recv_src_info.data_ptr(), packed_recv_layout_range.data_ptr(), @@ -1195,9 +1199,11 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id } else { combined_x = torch::empty({num_combined_tokens, hidden}, x.options()); } - // Kernel launch auto next_clean_meta = next_buffer.clean_meta(); + if (not return_recv_hook) + stream_wait(compute_stream, launch_stream); + internode::barrier(); auto launcher = [=](int phases) { internode_ll::combine(combined_x.data_ptr(), buffer.combine_rdma_recv_data_buffer, buffer.combine_rdma_recv_flag_buffer, diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu index 9a806ac..d86b1fe 100644 --- a/csrc/kernels/internode_ll.cu +++ b/csrc/kernels/internode_ll.cu @@ -92,7 +92,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, const auto thread_id = static_cast(threadIdx.x); const auto warp_id = thread_id / kWarpSize, lane_id = get_lane_id(); const auto num_sms = static_cast(gridDim.x); - const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup; + constexpr auto num_warps = kNumWarpGroups * kNumWarpsPerGroup; const auto num_local_experts = num_experts / num_ranks; const auto warp_group_id = warp_id / kNumWarpsPerGroup; const auto sub_warp_id = warp_id % kNumWarpsPerGroup; @@ -148,7 +148,8 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(gpu_bfloat16_t); EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0); EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize % kNumPerChannels == 0, "Invalid vectorization"); - const auto num_threads = (num_warps - 1) * kWarpSize; + //const auto num_threads = (num_warps - 1) * kWarpSize; + constexpr int num_threads = kNumWarpGroups * kNumWarpsPerGroup * kWarpSize; const size_t hidden_bf16_int4 = kHidden / kNumElemsPerRead; for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) { @@ -248,7 +249,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, // NOTES: only 2 load iterations for 7K hidden with 8 unrolls const auto* src_int4_ptr = reinterpret_cast(src_ptr); const auto* dst_int4_ptr = reinterpret_cast(dst_ptr); - UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global); + UNROLLED_WARP_COPY(4, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global); } // Increase counter after finishing @@ -288,12 +289,12 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, } // Warp reduce - #pragma unroll + #pragma unroll 2 for (int i = expert_begin_idx; i < expert_end_idx; ++ i) { auto sum = warp_reduce_sum(expert_count[i - expert_begin_idx]); if (lane_id == 0) { shared_num_tokens_sent_per_expert[i - expert_begin_idx] = sum; - atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG - sum); + atomic_add_relaxed_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG - sum); } } } @@ -321,7 +322,9 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, if (dst_rank != rank) { #ifdef USE_ROCM #if defined(ROCM_DISABLE_CTX) - internode::shmem_long_atomic_add( rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank); + //internode::shmem_long_atomic_add( rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank); + rocshmem::rocshmem_long_p( rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank); + #else internode::shmem_ctx_long_atomic_add(ctx, rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank); #endif @@ -340,7 +343,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, if (dst_rank == 0) packed_recv_count[dst_expert_local_idx] = 0; } - syncwarp(); + //syncwarp(); // Receiving phase LOW_LATENCY_DISPATCH_RECV: @@ -371,8 +374,8 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, // NOTES: using sub-warp 1 to overlap with sub-warp 0 int num_recv_tokens, recv_token_begin_idx; EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group"); - if (sub_warp_id == 1 and lane_id == 0) { - while ((num_recv_tokens = ld_acquire_global(reinterpret_cast(rdma_recv_count + local_expert_idx * num_ranks + src_rank))) == 0); + if (sub_warp_id == 0 and lane_id == 0) { + while ((num_recv_tokens = ld_volatile_global(reinterpret_cast(rdma_recv_count + local_expert_idx * num_ranks + src_rank))) == 0); num_recv_tokens = -num_recv_tokens - 1; recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens); shared_num_recv_tokens[warp_group_id] = num_recv_tokens; @@ -408,7 +411,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, // NOTES: only 2 load iterations for 7K hidden with 7 unrolls const auto src_data = reinterpret_cast(reinterpret_cast(src_src_idx) + sizeof(int4)); const auto dst_data = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4; - UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global); + UNROLLED_WARP_COPY(8, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global); // Copy scales if constexpr(kUseFP8) { @@ -439,14 +442,14 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales, void* workspace, cudaStream_t stream, int phases) { #ifdef USE_ROCM - constexpr int kNumWarpsPerGroup = 5; + constexpr int kNumWarpsPerGroup = 8; constexpr int kNumWarpGroups = 2; #else constexpr int kNumWarpsPerGroup = 10; constexpr int kNumWarpGroups = 3; #endif constexpr int kNumMaxTopK = 9; - EP_STATIC_ASSERT(kNumMaxTopK + 1 <= kNumWarpGroups * kNumWarpsPerGroup, "Too many top-k selections"); + // EP_STATIC_ASSERT(kNumMaxTopK + 1 <= kNumWarpGroups * kNumWarpsPerGroup, "Too many top-k selections"); const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup; const auto num_sms = cell_div(num_experts, kNumWarpGroups); @@ -616,7 +619,8 @@ combine(void* combined_x, if (dst_rank != rank) { #ifdef USE_ROCM #if defined(ROCM_DISABLE_CTX) - internode::shmem_long_atomic_add(rdma_recv_flag + global_expert_idx, 1, dst_rank); + //internode::shmem_long_atomic_add(rdma_recv_flag + global_expert_idx, 1, dst_rank); + rocshmem::rocshmem_long_p(rdma_recv_flag + global_expert_idx, 1, dst_rank); #else internode::shmem_ctx_long_atomic_add(ctx, rdma_recv_flag + global_expert_idx, 1, dst_rank); #endif @@ -643,7 +647,7 @@ combine(void* combined_x, while (ld_volatile_global(reinterpret_cast(rdma_recv_flag + responsible_expert_idx)) == 0); } } - grid_barrier(global_atomic_counter, num_sms, 0); + grid_barrier(global_atomic_counter, num_sms, 1); // Reduce tokens with FP8 cast // EP_DEVICE_ASSERT(num_topk <= kWarpSize and hidden_bf16_int4 <= num_threads); diff --git a/csrc/kernels/utils.cuh b/csrc/kernels/utils.cuh index da22ef4..10fd2dc 100644 --- a/csrc/kernels/utils.cuh +++ b/csrc/kernels/utils.cuh @@ -231,6 +231,15 @@ __device__ __forceinline__ int ld_acquire_global(const int *ptr) { #endif return ret; } +__device__ __forceinline__ int ld_acquire_global(const int64_t *ptr) { + int64_t ret; +#ifdef USE_ROCM + ret = __hip_atomic_load(ptr, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_AGENT); +#else + asm volatile("ld.acquire.gpu.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr)); +#endif + return ret; +} //not used __device__ __forceinline__ int atomic_add_release_global(const int* ptr, int value) { diff --git a/tests/test_low_latency.py b/tests/test_low_latency.py index 527c51f..91dc09a 100644 --- a/tests/test_low_latency.py +++ b/tests/test_low_latency.py @@ -162,16 +162,16 @@ def test_loop(local_rank: int, num_local_ranks: int): num_qps_per_rank=num_experts // num_ranks) test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=1) - do_pressure_test = False - for seed in range(int(1e9) if do_pressure_test else 0): + do_pressure_test = True + for seed in range(int(10) if do_pressure_test else 0): if local_rank == 0: print(f'Testing with seed {seed} ...', flush=True) ref_hash = test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=seed) - for i in range(20): + for i in range(2): assert test_main(num_tokens, hidden, num_experts, num_topk, rank, num_ranks, group, buffer, seed=seed) == ref_hash, f'Error: seed={seed}' if __name__ == '__main__': # TODO: you may modify NUMA binding for less CPU overhead - num_processes = 1 + num_processes = 8 torch.multiprocessing.spawn(test_loop, args=(num_processes,), nprocs=num_processes) diff --git a/tests/utils.py b/tests/utils.py index 7665889..d980c5c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -143,7 +143,7 @@ def __exit__(self, *_): self.errnull_file.close() -def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False, +def bench_kineto(fn, kernel_names, num_tests: int = 100, suppress_kineto_output: bool = False, trace_path: Optional[str] = None, barrier_comm_profiling: bool = False): # Profile suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress From bd16740a3c826de18944ff6828e94411b45b2935 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 17 Nov 2025 21:28:07 +0000 Subject: [PATCH 4/4] Enable CTX for multinode --- csrc/deep_ep.cpp | 10 +- csrc/kernels/internode_ll.cu | 171 +++++++++++++++++------------------ csrc/kernels/utils.cuh | 11 ++- tests/test_low_latency.py | 8 +- 4 files changed, 97 insertions(+), 103 deletions(-) diff --git a/csrc/deep_ep.cpp b/csrc/deep_ep.cpp index 7472c0d..02a55cb 100644 --- a/csrc/deep_ep.cpp +++ b/csrc/deep_ep.cpp @@ -1098,10 +1098,6 @@ Buffer::low_latency_dispatch(const torch::Tensor& x, const torch::Tensor& topk_i // Kernel launch auto next_clean_meta = next_buffer.clean_meta(); - - if (not return_recv_hook) - stream_wait(compute_stream, launch_stream); - internode::barrier(); auto launcher = [=](int phases) { internode_ll::dispatch(packed_recv_x.data_ptr(), packed_recv_x_scales_ptr, packed_recv_src_info.data_ptr(), packed_recv_layout_range.data_ptr(), @@ -1199,11 +1195,9 @@ Buffer::low_latency_combine(const torch::Tensor& x, const torch::Tensor& topk_id } else { combined_x = torch::empty({num_combined_tokens, hidden}, x.options()); } + // Kernel launch auto next_clean_meta = next_buffer.clean_meta(); - if (not return_recv_hook) - stream_wait(compute_stream, launch_stream); - internode::barrier(); auto launcher = [=](int phases) { internode_ll::combine(combined_x.data_ptr(), buffer.combine_rdma_recv_data_buffer, buffer.combine_rdma_recv_flag_buffer, @@ -1318,4 +1312,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .def("low_latency_dispatch", &deep_ep::Buffer::low_latency_dispatch) .def("low_latency_combine", &deep_ep::Buffer::low_latency_combine) .def("get_next_low_latency_combine_buffer", &deep_ep::Buffer::get_next_low_latency_combine_buffer); -} +} \ No newline at end of file diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu index d86b1fe..59a6c57 100644 --- a/csrc/kernels/internode_ll.cu +++ b/csrc/kernels/internode_ll.cu @@ -16,13 +16,11 @@ namespace deep_ep { namespace internode_ll { -__device__ void grid_barrier(int* global_counter, int num_blocks, int do_fence=1) { +__device__ void grid_barrier(int* global_counter, int num_blocks) { volatile int ret; __syncthreads(); if (threadIdx.x == 0 ) { - if (do_fence) __threadfence(); - ret = __hip_atomic_fetch_add( &global_counter[0], 1, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT); } @@ -75,7 +73,7 @@ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0, clean_0, num_clean_int_0, clean_1, num_clean_int_1); } -template +template __global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * kWarpSize, 1) void dispatch(void* packed_recv_x, float* packed_recv_x_scales, int* packed_recv_src_info, int64_t* packed_recv_layout_range, @@ -100,7 +98,9 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, #if !defined(ROCM_DISABLE_CTX) __shared__ internode::shmem_ctx_t ctx; - internode::shmem_wg_ctx_create(&ctx); + if constexpr (multinode) { + EP_DEVICE_ASSERT(internode::shmem_wg_ctx_create(&ctx) == 0); + } #endif // FP8 staffs @@ -122,18 +122,6 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4); EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0); -//#ifdef USE_ROCM -// // 16 is the max possible number of warps in AMD GPUs -// constexpr int kMaxNumWarps = 1024 / kWarpSize; -// constexpr int num_sync_large_iteration = kMaxNumWarps ; -// __shared__ volatile uint8_t sync_large_warp_counters[num_sync_large_iteration]; -// -// #pragma unroll -// for (int i = thread_id; i < num_sync_large_iteration; i += blockDim.x) { -// sync_large_warp_counters[i] = 0; -// } -// __syncthreads(); -//#endif // Sending phase if ((phases & LOW_LATENCY_SEND_PHASE) == 0) goto LOW_LATENCY_DISPATCH_RECV; @@ -148,7 +136,6 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(gpu_bfloat16_t); EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0); EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize % kNumPerChannels == 0, "Invalid vectorization"); - //const auto num_threads = (num_warps - 1) * kWarpSize; constexpr int num_threads = kNumWarpGroups * kNumWarpsPerGroup * kWarpSize; const size_t hidden_bf16_int4 = kHidden / kNumElemsPerRead; @@ -239,7 +226,8 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, #if defined(ROCM_DISABLE_CTX) internode::shmemx_int8_put_nbi_warp( #else - internode::shmem_ctx_schar_put_nbi_warp(ctx, + if constexpr (multinode) + internode::shmem_ctx_schar_put_nbi_warp(ctx, #endif reinterpret_cast(dst_ptr), reinterpret_cast(src_ptr), num_bytes_per_msg, dst_rank); #else @@ -268,12 +256,6 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, #pragma unroll for (int i = lane_id; i < num_next_clean_int; i += kWarpSize) next_clean[i] = 0; - - // Notify before executing `int_p` - //syncwarp(); - //#pragma unroll - //for (int i = lane_id; i < num_experts; i += kWarpSize) - // atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG); } // This SM should be responsible for some destination experts, read `topk_idx` for them int expert_count[kNumWarpGroups] = {0}; @@ -299,16 +281,16 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, } } -#if 0 - if (thread_id == 0 and num_ranks > 8){ + if constexpr (multinode){ + if (thread_id == 0 ){ #if defined(ROCM_DISABLE_CTX) internode::shmem_fence(); #else - internode::shmem_ctx_quiet(ctx); + if constexpr (multinode) + internode::shmem_ctx_quiet(ctx); #endif + } } -#endif - //revert sync_large_warp_counters to 0 for next sync __syncthreads(); // Issue count sends @@ -322,11 +304,15 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, if (dst_rank != rank) { #ifdef USE_ROCM #if defined(ROCM_DISABLE_CTX) - //internode::shmem_long_atomic_add( rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank); - rocshmem::rocshmem_long_p( rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank); - + if constexpr (multinode){ // does CTX depend on multinode? + internode::shmem_long_atomic_add( rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank); + }else{ + rocshmem::rocshmem_long_p(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank); + } #else - internode::shmem_ctx_long_atomic_add(ctx, rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank); + if constexpr (multinode){ + internode::shmem_ctx_long_atomic_add(ctx, rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank); + } #endif #else //CUDA nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank, dst_expert_local_idx); @@ -343,11 +329,14 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, if (dst_rank == 0) packed_recv_count[dst_expert_local_idx] = 0; } - //syncwarp(); // Receiving phase LOW_LATENCY_DISPATCH_RECV: if ((phases & LOW_LATENCY_RECV_PHASE) == 0) + if constexpr (multinode) + #if defined(USE_ROCM) && !defined(ROCM_DISABLE_CTX) && defined(multinode) + internode::shmem_wg_ctx_destroy(&ctx); + #endif return; // For send-and-recv kernels, we need a grid sync for making `packed_recv_count` visible @@ -375,7 +364,11 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, int num_recv_tokens, recv_token_begin_idx; EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group"); if (sub_warp_id == 0 and lane_id == 0) { - while ((num_recv_tokens = ld_volatile_global(reinterpret_cast(rdma_recv_count + local_expert_idx * num_ranks + src_rank))) == 0); + if constexpr (multinode){ + while ((num_recv_tokens = ld_acquire_global(reinterpret_cast(rdma_recv_count + local_expert_idx * num_ranks + src_rank))) == 0); + }else{ + while ((num_recv_tokens = ld_volatile_global(reinterpret_cast(rdma_recv_count + local_expert_idx * num_ranks + src_rank))) == 0); + } num_recv_tokens = -num_recv_tokens - 1; recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens); shared_num_recv_tokens[warp_group_id] = num_recv_tokens; @@ -383,15 +376,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, recv_range[src_rank] = pack2(num_recv_tokens, recv_token_begin_idx); } #ifdef USE_ROCM - // no needs to reset because there is no iteration - // if (lane_id == 0){ - // volatile int ret = __hip_atomic_fetch_add( - // &sync_large_warp_counters[warp_group_id], 1, - // __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP); - // } - // syncwarp(); - // while (sync_large_warp_counters[warp_group_id] < (kNumWarpsPerGroup)); - __syncthreads(); + __syncthreads(); #else asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 2), "r"(kNumWarpsPerGroup * 32)); #endif @@ -426,7 +411,9 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales, } } #if !defined(ROCM_DISABLE_CTX) - internode::shmem_wg_ctx_destroy(&ctx); + if constexpr (multinode) { + internode::shmem_wg_ctx_destroy(&ctx); + } #endif } @@ -459,10 +446,17 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales, auto atomic_counter_per_expert = reinterpret_cast(workspace); auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts; EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES); - + + bool multinode = (num_ranks > 8); #define DISPATCH_LAUNCH_CASE(hidden) { \ -auto dispatch_func = use_fp8 ? dispatch