diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu
index 9cdc77d..f3cc227 100644
--- a/csrc/kernels/internode_ll.cu
+++ b/csrc/kernels/internode_ll.cu
@@ -8,7 +8,7 @@
 #include
 #include
 // low latency+RocSHMEM has issue with CTX.
-#define ROCM_DISABLE_CTX
+//#define ROCM_DISABLE_CTX
 
 namespace cg = cooperative_groups;
 using namespace rocshmem;
@@ -19,9 +19,9 @@ namespace internode_ll {
 __device__ void grid_barrier(int* global_counter, int num_blocks) {
     volatile int ret;
     __syncthreads();
-    __threadfence();
-    if (threadIdx.x == 0) {
-        ret = __hip_atomic_fetch_add(&global_counter[0], 1,
+    if (threadIdx.x == 0) {
+        __threadfence();
+        ret = __hip_atomic_fetch_add(&global_counter[0], 1,
                                      __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
     }
     __syncthreads();
@@ -73,7 +73,7 @@ void clean_low_latency_buffer(int64_t* clean_0, int num_clean_int_0,
         clean_0, num_clean_int_0, clean_1, num_clean_int_1);
 }
 
-template <bool kUseFP8, int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden>
+template <bool kUseFP8, int kNumWarpGroups, int kNumWarpsPerGroup, int kHidden, bool multinode>
 __global__ __launch_bounds__(kNumWarpGroups * kNumWarpsPerGroup * kWarpSize, 1) void
 dispatch(void* packed_recv_x, float* packed_recv_x_scales,
          int* packed_recv_src_info, int64_t* packed_recv_layout_range,
@@ -90,7 +90,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
     const auto thread_id = static_cast<int>(threadIdx.x);
     const auto warp_id = thread_id / kWarpSize, lane_id = get_lane_id();
     const auto num_sms = static_cast<int>(gridDim.x);
-    const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
+    constexpr auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
     const auto num_local_experts = num_experts / num_ranks;
     const auto warp_group_id = warp_id / kNumWarpsPerGroup;
     const auto sub_warp_id = warp_id % kNumWarpsPerGroup;
@@ -98,7 +98,8 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
 
 #if !defined(ROCM_DISABLE_CTX)
     __shared__ internode::shmem_ctx_t ctx;
-    internode::shmem_wg_ctx_create(&ctx);
+    if constexpr (multinode)
+        internode::shmem_wg_ctx_create(&ctx);
 #endif
 
     // FP8 staffs
@@ -120,18 +121,6 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
     const size_t num_int4_per_msg = num_bytes_per_msg / sizeof(int4);
     EP_DEVICE_ASSERT(num_bytes_per_msg % sizeof(int4) == 0);
 
-#ifdef USE_ROCM
-    // 16 is the max possible number of warps in AMD GPUs
-    constexpr int kMaxNumWarps = 1024 / kWarpSize;
-    constexpr int num_sync_large_iteration = kMaxNumWarps;
-    __shared__ volatile uint8_t sync_large_warp_counters[num_sync_large_iteration];
-
-    #pragma unroll
-    for (int i = thread_id; i < num_sync_large_iteration; i += blockDim.x) {
-        sync_large_warp_counters[i] = 0;
-    }
-    __syncthreads();
-#endif
 
     // Sending phase
     if ((phases & LOW_LATENCY_SEND_PHASE) == 0)
         goto LOW_LATENCY_DISPATCH_RECV;
@@ -146,7 +135,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
         constexpr int kNumElemsPerRead = sizeof(int4) / sizeof(gpu_bfloat16_t);
         EP_DEVICE_ASSERT(kHidden % kNumElemsPerRead == 0);
         EP_STATIC_ASSERT(kNumElemsPerRead * kWarpSize % kNumPerChannels == 0, "Invalid vectorization");
-        const auto num_threads = (num_warps - 1) * kWarpSize;
+        constexpr int num_threads = kNumWarpGroups * kNumWarpsPerGroup * kWarpSize;
         const size_t hidden_bf16_int4 = kHidden / kNumElemsPerRead;
 
         for (int token_idx = sm_id; token_idx < num_tokens; token_idx += num_sms) {
@@ -165,7 +154,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
 
             // Read
             auto int4_value = __ldg(x_int4 + i);
-            if (kUseFP8) {
+            if constexpr (kUseFP8) {
                 // Calculate local amax
                 auto bf16_values = reinterpret_cast<gpu_bfloat16_t*>(&int4_value);
                 float fp32_values[kNumElemsPerRead];
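
Why the `if (kUseFP8)` to `if constexpr (kUseFP8)` change above matters: `kUseFP8` is a template parameter, so `if constexpr` resolves the branch at compile time and the untaken side is discarded from the instantiation entirely, keeping the per-element hot loop branch-free. A minimal sketch of the pattern (illustration only; the scaling math is a placeholder, not the kernel's amax-based FP8 quantization):

    #include <hip/hip_runtime.h>

    template <bool kUseFP8>
    __global__ void quantize_demo(const float* in, float* out, int n) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i >= n)
            return;
        if constexpr (kUseFP8) {
            out[i] = in[i] * 0.5f;  // stand-in for amax scaling + FP8 cast
        } else {
            out[i] = in[i];         // BF16 path: plain pass-through
        }
    }
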
@@ -233,20 +222,24 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
             if (dst_rank != rank) {
 #ifdef USE_ROCM
+                if constexpr (!multinode) {
+                    internode::shmemx_int8_put_nbi_warp(reinterpret_cast<int8_t*>(dst_ptr), reinterpret_cast<int8_t*>(src_ptr), num_bytes_per_msg, dst_rank);
+                } else {
+
 #if defined(ROCM_DISABLE_CTX)
-                internode::shmemx_int8_put_nbi_warp(
-#else
-                internode::shmem_ctx_schar_put_nbi_warp(ctx,
+                internode::shmemx_int8_put_nbi_warp(reinterpret_cast<int8_t*>(dst_ptr), reinterpret_cast<int8_t*>(src_ptr), num_bytes_per_msg, dst_rank);
+#else // DISABLE_CTX
+                internode::shmem_ctx_schar_put_nbi_warp(ctx, reinterpret_cast<int8_t*>(dst_ptr), reinterpret_cast<int8_t*>(src_ptr), num_bytes_per_msg, dst_rank);
 #endif
-                reinterpret_cast<int8_t*>(dst_ptr), reinterpret_cast<int8_t*>(src_ptr), num_bytes_per_msg, dst_rank);
-#else
+                }
+#else // USE_ROCM
                 nvshmemi_ibgda_put_nbi_warp(dst_ptr, src_ptr, num_bytes_per_msg, dst_rank, dst_expert_local_idx, lane_id, slot_idx);
 #endif
             } else {
                 // NOTES: only 2 load iterations for 7K hidden with 8 unrolls
                 const auto* src_int4_ptr = reinterpret_cast<const int4*>(src_ptr);
                 const auto* dst_int4_ptr = reinterpret_cast<int4*>(dst_ptr);
-                UNROLLED_WARP_COPY(8, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
+                UNROLLED_WARP_COPY(4, lane_id, num_int4_per_msg, dst_int4_ptr, src_int4_ptr, ld_nc_global, st_na_global);
             }
 
             // Increase counter after finishing
@@ -265,12 +258,12 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
             #pragma unroll
             for (int i = lane_id; i < num_next_clean_int; i += kWarpSize)
                 next_clean[i] = 0;
-
-            // Notify before executing `int_p`
-            syncwarp();
-            #pragma unroll
-            for (int i = lane_id; i < num_experts; i += kWarpSize)
-                atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG);
+            if constexpr (multinode) {
+                syncwarp();
+                #pragma unroll
+                for (int i = lane_id; i < num_experts; i += kWarpSize)
+                    atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG);
+            }
         }
         // This SM should be responsible for some destination experts, read `topk_idx` for them
         int expert_count[kNumWarpGroups] = {0};
@@ -278,7 +271,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
         const auto expert_end_idx = min(expert_begin_idx + kNumWarpGroups, num_experts);
 
         // Per lane count
-        #pragma unroll 8
+        #pragma unroll 2
         for (int i = lane_id; i < num_tokens * num_topk; i += kWarpSize) {
             auto idx = static_cast<int>(__ldg(topk_idx + i));
             if (idx >= expert_begin_idx and idx < expert_end_idx)
@@ -286,24 +279,29 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
         }
 
         // Warp reduce
-        #pragma unroll
+        #pragma unroll 2
        for (int i = expert_begin_idx; i < expert_end_idx; ++ i) {
             auto sum = warp_reduce_sum(expert_count[i - expert_begin_idx]);
             if (lane_id == 0) {
                 shared_num_tokens_sent_per_expert[i - expert_begin_idx] = sum;
-                atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG - sum);
+                if constexpr (multinode) {
+                    atomic_add_release_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG - sum);
+                } else {
+                    atomic_add_relaxed_global(atomic_finish_counter_per_expert + i, FINISHED_SUM_TAG - sum);
+                }
             }
         }
     }
-    if (thread_id == 0 and num_ranks > 8) {
-#if defined(ROCM_DISABLE_CTX)
-        internode::shmem_fence();
-#else
-        internode::shmem_ctx_quiet(ctx);
-#endif
-    }
-    //revert sync_large_warp_counters to 0 for next sync
+//    if constexpr (multinode) {
+//        if (thread_id == 0) {
+//#if defined(ROCM_DISABLE_CTX)
+//            internode::shmem_fence();
+//#else
+//            internode::shmem_ctx_quiet(ctx);
+//#endif
+//        }
+//    }
     __syncthreads();
 
     // Issue count sends
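
The single-node (`!multinode`) specialization above drops the release ordering on the finish counter and skips the up-front `FINISHED_SUM_TAG` pre-add, which also changes the waiter's target in the next hunk: assuming, as in upstream DeepEP, each completed send adds 1 to the counter, multinode mode accumulates FINISHED_SUM_TAG (pre-add) + (FINISHED_SUM_TAG - sum) (counting warp) + sum (sends) = FINISHED_SUM_TAG * 2, while the single-node path reaches only FINISHED_SUM_TAG. The release/acquire pairing the multinode path keeps is the standard publish pattern; a minimal sketch using the same HIP atomic builtins `grid_barrier` uses (hypothetical demo kernels, not from the patch):

    #include <hip/hip_runtime.h>

    __device__ int payload;

    __global__ void producer(int* flag) {
        payload = 42;  // write the data first
        // Release: all prior writes become visible before the flag increment does
        __hip_atomic_fetch_add(flag, 1, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_AGENT);
    }

    __global__ void consumer(const int* flag, int* out) {
        // Acquire: once the flag is observed, the payload write is guaranteed visible
        while (__hip_atomic_load(flag, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_AGENT) == 0)
            ;
        *out = payload;  // reads 42
    }
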
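
On the unroll-factor change (`UNROLLED_WARP_COPY(8, ...)` to `UNROLLED_WARP_COPY(4, ...)`): the macro performs a lane-strided warp copy that batches several loads before the matching stores, so the factor trades register pressure against the number of memory requests in flight. A rough functional equivalent (hypothetical helper; the repo's macro additionally takes custom load/store primitives such as `ld_nc_global`/`st_na_global`):

    template <int kUnroll, typename T>
    __device__ void unrolled_warp_copy(T* dst, const T* src, int n,
                                       int lane_id, int warp_size) {
        int i = lane_id;
        // Main loop: issue kUnroll loads back-to-back, then the kUnroll stores,
        // keeping more outstanding memory requests per iteration.
        for (; i + (kUnroll - 1) * warp_size < n; i += kUnroll * warp_size) {
            T vals[kUnroll];
            #pragma unroll
            for (int j = 0; j < kUnroll; ++j)
                vals[j] = src[i + j * warp_size];
            #pragma unroll
            for (int j = 0; j < kUnroll; ++j)
                dst[i + j * warp_size] = vals[j];
        }
        // Tail: remaining elements, one per lane per step
        for (; i < n; i += warp_size)
            dst[i] = src[i];
    }
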
@@ -313,21 +311,36 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
         const auto num_tokens_sent = shared_num_tokens_sent_per_expert[responsible_expert_idx - sm_id * kNumWarpGroups];
 
         // Wait local sends issued and send expert counts
-        while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
+        if constexpr (multinode) {
+            while (ld_acquire_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG * 2);
+        } else {
+            while (ld_volatile_global(atomic_finish_counter_per_expert + responsible_expert_idx) != FINISHED_SUM_TAG);
+        }
         if (dst_rank != rank) {
 #ifdef USE_ROCM
+            if constexpr (!multinode) {
+                rocshmem::rocshmem_long_p(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank);
+            } else {
 #if defined(ROCM_DISABLE_CTX)
-            internode::shmem_long_atomic_add(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank);
+                internode::shmem_long_atomic_add(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank);
 #else
-            internode::shmem_ctx_long_atomic_add(ctx, rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank);
+                internode::shmem_ctx_long_atomic_add(ctx, rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank);
 #endif
+            }
 #else //CUDA
             nvshmemi_ibgda_amo_nonfetch_add(rdma_recv_count + dst_expert_local_idx * num_ranks + rank, -num_tokens_sent - 1, dst_rank, dst_expert_local_idx);
 #endif
         } else {
             st_na_release(reinterpret_cast<int*>(rdma_recv_count + dst_expert_local_idx * num_ranks + rank), -num_tokens_sent - 1);
         }
-
+        if constexpr (multinode) {
+#if defined(ROCM_DISABLE_CTX)
+            internode::shmem_fence();
+#else
+            internode::shmem_ctx_quiet(ctx);
+#endif
+        }
+
         // Clean workspace for next use
         atomic_counter_per_expert[responsible_expert_idx] = 0;
         atomic_finish_counter_per_expert[responsible_expert_idx] = 0;
@@ -336,13 +349,17 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
         if (dst_rank == 0)
             packed_recv_count[dst_expert_local_idx] = 0;
     }
-    syncwarp();
+    //syncwarp();
 
     // Receiving phase
     LOW_LATENCY_DISPATCH_RECV:
-    if ((phases & LOW_LATENCY_RECV_PHASE) == 0)
+    if ((phases & LOW_LATENCY_RECV_PHASE) == 0) {
+#if !defined(ROCM_DISABLE_CTX)
+        if constexpr (multinode)
+            internode::shmem_wg_ctx_destroy(&ctx);
+#endif
         return;
-
+    }
     // For send-and-recv kernels, we need a grid sync for making `packed_recv_count` visible
     if (phases & LOW_LATENCY_SEND_PHASE) {
         grid_barrier(global_atomic_counter, num_sms);
@@ -367,8 +384,21 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
         // NOTES: using sub-warp 1 to overlap with sub-warp 0
         int num_recv_tokens, recv_token_begin_idx;
         EP_STATIC_ASSERT(kNumWarpsPerGroup > 1, "Requires more than one warp per group");
-        if (sub_warp_id == 1 and lane_id == 0) {
-            while ((num_recv_tokens = ld_acquire_global(reinterpret_cast<int*>(rdma_recv_count + local_expert_idx * num_ranks + src_rank))) == 0);
+        if (sub_warp_id == 0 and lane_id == 0) {
+            auto start_time = clock64();
+            if constexpr (multinode) {
+                while ((num_recv_tokens = ld_acquire_global(reinterpret_cast<int*>(rdma_recv_count + local_expert_idx * num_ranks + src_rank))) == 0) {
+                    if ((clock64() - start_time) >= NUM_TIMEOUT_CYCLES) {
+                        printf("dispatch receive timed out\n");
+                    }
+                }
+            } else {
+                while ((num_recv_tokens = ld_volatile_global(reinterpret_cast<int*>(rdma_recv_count + local_expert_idx * num_ranks + src_rank))) == 0) {
+                    if ((clock64() - start_time) >= NUM_TIMEOUT_CYCLES) {
+                        printf("dispatch receive (single node) timed out\n");
+                    }
+                }
+            }
             num_recv_tokens = -num_recv_tokens - 1;
             recv_token_begin_idx = atomicAdd(packed_recv_count + local_expert_idx, num_recv_tokens);
             shared_num_recv_tokens[warp_group_id] = num_recv_tokens;
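
The receive loops above wrap the spin-wait in a `clock64()`-based watchdog. As written, the `printf` fires on every iteration once the deadline passes; a sketch of the same pattern with a re-armed timer so the diagnostic prints once per timeout window (`NUM_TIMEOUT_CYCLES` is assumed to be defined elsewhere in the repo; the fallback value here is arbitrary):

    #include <hip/hip_runtime.h>

    #ifndef NUM_TIMEOUT_CYCLES
    #define NUM_TIMEOUT_CYCLES 1000000000LL  // placeholder budget, not the repo's value
    #endif

    __device__ int poll_with_timeout(volatile int* flag) {
        auto start = clock64();
        int value;
        while ((value = *flag) == 0) {
            if (clock64() - start >= NUM_TIMEOUT_CYCLES) {
                printf("poll timed out, still waiting\n");
                start = clock64();  // re-arm so the message is not repeated every iteration
            }
        }
        return value;
    }
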
@@ -376,14 +406,7 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
             recv_range[src_rank] = pack2(num_recv_tokens, recv_token_begin_idx);
         }
 #ifdef USE_ROCM
-        // no needs to reset because there is no iteration
-        if (lane_id == 0) {
-            volatile int ret = __hip_atomic_fetch_add(
-                &sync_large_warp_counters[warp_group_id], 1,
-                __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP);
-        }
-        syncwarp();
-        while (sync_large_warp_counters[warp_group_id] < (kNumWarpsPerGroup));
+        __syncthreads();
 #else
         asm volatile("bar.sync %0, %1;" :: "r"(warp_group_id + 2), "r"(kNumWarpsPerGroup * 32));
 #endif
@@ -403,10 +426,10 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
             // NOTES: only 2 load iterations for 7K hidden with 7 unrolls
             const auto src_data = reinterpret_cast<int4*>(reinterpret_cast<uint8_t*>(src_src_idx) + sizeof(int4));
             const auto dst_data = recv_x_int4 + (recv_token_begin_idx + i) * hidden_int4;
-            UNROLLED_WARP_COPY(7, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global);
+            UNROLLED_WARP_COPY(8, lane_id, hidden_int4, dst_data, src_data, ld_nc_global, st_na_global);
 
             // Copy scales
-            if (kUseFP8) {
+            if constexpr (kUseFP8) {
                 const auto src_scales = reinterpret_cast<float*>(reinterpret_cast<uint8_t*>(src_data) + hidden_bytes);
                 const auto dst_scales = reinterpret_cast<float*>(recv_x_scales + recv_token_begin_idx + i);
                 const auto scale_stride = num_ranks * num_max_dispatch_tokens_per_rank;
@@ -418,7 +441,8 @@ dispatch(void* packed_recv_x, float* packed_recv_x_scales,
         }
     }
 #if !defined(ROCM_DISABLE_CTX)
-    internode::shmem_wg_ctx_destroy(&ctx);
+    if constexpr (multinode)
+        internode::shmem_wg_ctx_destroy(&ctx);
 #endif
 }
 
@@ -434,14 +458,14 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
               void* workspace, cudaStream_t stream, int phases) {
 #ifdef USE_ROCM
-    constexpr int kNumWarpsPerGroup = 5;
+    constexpr int kNumWarpsPerGroup = 8;
     constexpr int kNumWarpGroups = 2;
 #else
     constexpr int kNumWarpsPerGroup = 10;
     constexpr int kNumWarpGroups = 3;
 #endif
     constexpr int kNumMaxTopK = 9;
-    EP_STATIC_ASSERT(kNumMaxTopK + 1 <= kNumWarpGroups * kNumWarpsPerGroup, "Too many top-k selections");
+    // EP_STATIC_ASSERT(kNumMaxTopK + 1 <= kNumWarpGroups * kNumWarpsPerGroup, "Too many top-k selections");
 
     const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
     const auto num_sms = cell_div(num_experts, kNumWarpGroups);
@@ -451,10 +475,17 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
     auto atomic_counter_per_expert = reinterpret_cast<int*>(workspace);
     auto atomic_finish_counter_per_expert = atomic_counter_per_expert + num_experts;
     EP_HOST_ASSERT(num_experts * sizeof(int) * 2 <= NUM_WORKSPACE_BYTES);
-
+
+    bool multinode = (num_ranks > 8);
 #define DISPATCH_LAUNCH_CASE(hidden) { \
 auto dispatch_func = use_fp8 ? dispatch
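
The host side computes `bool multinode = (num_ranks > 8)` at runtime, but `multinode` is a template parameter of the kernel, so the `DISPATCH_LAUNCH_CASE` macro (truncated above) must map the runtime flag to a compile-time constant at the launch site. A sketch of the usual pattern with hypothetical names, since the real macro body and kernel signature are cut off:

    #include <cstdio>

    template <bool kUseFP8, bool kMultinode>
    void launch_dispatch() {
        // stand-in for the real kernel launch, e.g.
        // dispatch<kUseFP8, ..., kMultinode><<<grid, block, 0, stream>>>(...);
        std::printf("launch: use_fp8=%d multinode=%d\n", kUseFP8, kMultinode);
    }

    template <bool kUseFP8>
    void launch_dispatch(bool multinode) {
        multinode ? launch_dispatch<kUseFP8, true>()
                  : launch_dispatch<kUseFP8, false>();
    }

    void launch_dispatch(bool use_fp8, bool multinode) {
        use_fp8 ? launch_dispatch<true>(multinode)
                : launch_dispatch<false>(multinode);
    }

    int main() {
        launch_dispatch(/*use_fp8=*/true, /*multinode=*/false);
        return 0;
    }
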