From c67de430a784eef14b004548e30ec2b1b5442a71 Mon Sep 17 00:00:00 2001
From: wrx
Date: Tue, 27 Jan 2026 17:10:53 +0800
Subject: [PATCH] Fix SM120 scaled-mm beta_ptr on device

---
 .../csrc/gemm/mxfp4_scaled_mm_kernels_sm120.cu    | 13 +++++++++----
 .../gemm/mxfp6_mxfp8_scaled_mm_kernels_sm120.cu   | 13 +++++++++----
 .../csrc/gemm/mxfp8_scaled_mm_kernels_sm120.cu    | 13 +++++++++----
 .../csrc/gemm/nvfp4_scaled_mm_kernels_sm120.cu    | 13 +++++++++----
 4 files changed, 36 insertions(+), 16 deletions(-)

diff --git a/lightx2v_kernel/csrc/gemm/mxfp4_scaled_mm_kernels_sm120.cu b/lightx2v_kernel/csrc/gemm/mxfp4_scaled_mm_kernels_sm120.cu
index 18d03912a..87ab31f08 100644
--- a/lightx2v_kernel/csrc/gemm/mxfp4_scaled_mm_kernels_sm120.cu
+++ b/lightx2v_kernel/csrc/gemm/mxfp4_scaled_mm_kernels_sm120.cu
@@ -154,8 +154,8 @@ typename Mxfp4GemmSm120::Gemm::Arguments args_from_options_mxp4_mxfp4(
       stride_D}};
   auto& fusion_args = arguments.epilogue.thread;
   fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
-  static const float beta_zero = 0.0f;
-  fusion_args.beta_ptr = &beta_zero;
+  // static const float beta_zero = 0.0f;
+  // fusion_args.beta_ptr = &beta_zero;
   fusion_args.bias_ptr = static_cast<float const*>(bias->data_ptr());
   fusion_args.dBias = StrideBias{};
   return arguments;
@@ -180,8 +180,8 @@ typename Mxfp4GemmSm120::Gemm::Arguments args_from_options_mxp4_mxfp4(
       stride_D}};
   auto& fusion_args = arguments.epilogue.thread;
   fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
-  static const float beta_zero = 0.0f;
-  fusion_args.beta_ptr = &beta_zero;
+  // static const float beta_zero = 0.0f;
+  // fusion_args.beta_ptr = &beta_zero;
   return arguments;
 }
 }
@@ -202,6 +202,11 @@ void runGemmMxfp4Sm120(
   typename Mxfp4GemmSm120::Gemm gemm;
   auto arguments = args_from_options_mxp4_mxfp4(D, A, B, A_sf, B_sf, alpha, bias, m, n, k);
+  auto beta_dev = torch::zeros({1}, torch::TensorOptions()
+                                        .dtype(torch::kFloat32)
+                                        .device(A.device()));
+  arguments.epilogue.thread.beta_ptr =
+      static_cast<float const*>(beta_dev.data_ptr());
   size_t workspace_size = Mxfp4GemmSm120::Gemm::get_workspace_size(arguments);
   auto const workspace_options =
       torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
   auto workspace = torch::empty(workspace_size, workspace_options);
diff --git a/lightx2v_kernel/csrc/gemm/mxfp6_mxfp8_scaled_mm_kernels_sm120.cu b/lightx2v_kernel/csrc/gemm/mxfp6_mxfp8_scaled_mm_kernels_sm120.cu
index 7b2e178ec..9041f3583 100644
--- a/lightx2v_kernel/csrc/gemm/mxfp6_mxfp8_scaled_mm_kernels_sm120.cu
+++ b/lightx2v_kernel/csrc/gemm/mxfp6_mxfp8_scaled_mm_kernels_sm120.cu
@@ -154,8 +154,8 @@ typename Mxfp6Mxfp8GemmSm120::Gemm::Arguments args_from_options_mxfp6_mxfp8(
       stride_D}};
   auto& fusion_args = arguments.epilogue.thread;
   fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
-  static const float beta_zero = 0.0f;
-  fusion_args.beta_ptr = &beta_zero;
+  // static const float beta_zero = 0.0f;
+  // fusion_args.beta_ptr = &beta_zero;
   fusion_args.bias_ptr = static_cast<float const*>(bias->data_ptr());
   fusion_args.dBias = StrideBias{};
   return arguments;
@@ -180,8 +180,8 @@ typename Mxfp6Mxfp8GemmSm120::Gemm::Arguments args_from_options_mxfp6_mxfp8(
       stride_D}};
   auto& fusion_args = arguments.epilogue.thread;
   fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
-  static const float beta_zero = 0.0f;
-  fusion_args.beta_ptr = &beta_zero;
+  // static const float beta_zero = 0.0f;
+  // fusion_args.beta_ptr = &beta_zero;
   return arguments;
 }
 }
@@ -202,6 +202,11 @@ void runGemmMxfp6Mxfp8Sm120(
   typename Mxfp6Mxfp8GemmSm120::Gemm gemm;
   auto arguments = args_from_options_mxfp6_mxfp8(D, A, B, A_sf, B_sf, alpha, bias, m, n, k);
+  auto beta_dev = torch::zeros({1}, torch::TensorOptions()
+                                        .dtype(torch::kFloat32)
+                                        .device(A.device()));
+  arguments.epilogue.thread.beta_ptr =
+      static_cast<float const*>(beta_dev.data_ptr());
   size_t workspace_size = Mxfp6Mxfp8GemmSm120::Gemm::get_workspace_size(arguments);
   auto const workspace_options =
       torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
   auto workspace = torch::empty(workspace_size, workspace_options);
diff --git a/lightx2v_kernel/csrc/gemm/mxfp8_scaled_mm_kernels_sm120.cu b/lightx2v_kernel/csrc/gemm/mxfp8_scaled_mm_kernels_sm120.cu
index f3a1558b6..8414295d0 100644
--- a/lightx2v_kernel/csrc/gemm/mxfp8_scaled_mm_kernels_sm120.cu
+++ b/lightx2v_kernel/csrc/gemm/mxfp8_scaled_mm_kernels_sm120.cu
@@ -154,8 +154,8 @@ typename Mxfp8GemmSm120::Gemm::Arguments args_from_options_mxfp8(
       stride_D}};
   auto& fusion_args = arguments.epilogue.thread;
   fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
-  static const float beta_zero = 0.0f;
-  fusion_args.beta_ptr = &beta_zero;
+  // static const float beta_zero = 0.0f;
+  // fusion_args.beta_ptr = &beta_zero;
   fusion_args.bias_ptr = static_cast<float const*>(bias->data_ptr());
   fusion_args.dBias = StrideBias{};
   return arguments;
@@ -180,8 +180,8 @@ typename Mxfp8GemmSm120::Gemm::Arguments args_from_options_mxfp8(
       stride_D}};
   auto& fusion_args = arguments.epilogue.thread;
   fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
-  static const float beta_zero = 0.0f;
-  fusion_args.beta_ptr = &beta_zero;
+  // static const float beta_zero = 0.0f;
+  // fusion_args.beta_ptr = &beta_zero;
   return arguments;
 }
 }
@@ -202,6 +202,11 @@ void runGemmMxfp8Sm120(
   typename Mxfp8GemmSm120::Gemm gemm;
   auto arguments = args_from_options_mxfp8(D, A, B, A_sf, B_sf, alpha, bias, m, n, k);
+  auto beta_dev = torch::zeros({1}, torch::TensorOptions()
+                                        .dtype(torch::kFloat32)
+                                        .device(A.device()));
+  arguments.epilogue.thread.beta_ptr =
+      static_cast<float const*>(beta_dev.data_ptr());
   size_t workspace_size = Mxfp8GemmSm120::Gemm::get_workspace_size(arguments);
   auto const workspace_options =
       torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
   auto workspace = torch::empty(workspace_size, workspace_options);
diff --git a/lightx2v_kernel/csrc/gemm/nvfp4_scaled_mm_kernels_sm120.cu b/lightx2v_kernel/csrc/gemm/nvfp4_scaled_mm_kernels_sm120.cu
index 8dd2838a2..12d1adebd 100644
--- a/lightx2v_kernel/csrc/gemm/nvfp4_scaled_mm_kernels_sm120.cu
+++ b/lightx2v_kernel/csrc/gemm/nvfp4_scaled_mm_kernels_sm120.cu
@@ -154,8 +154,8 @@ typename Fp4GemmSm120::Gemm::Arguments args_from_options_nvfp4_nvfp4(
       stride_D}};
   auto& fusion_args = arguments.epilogue.thread;
   fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
-  static const float beta_zero = 0.0f;
-  fusion_args.beta_ptr = &beta_zero;
+  // static const float beta_zero = 0.0f;
+  // fusion_args.beta_ptr = &beta_zero;
   fusion_args.bias_ptr = static_cast<float const*>(bias->data_ptr());
   fusion_args.dBias = StrideBias{};
   return arguments;
@@ -180,8 +180,8 @@ typename Fp4GemmSm120::Gemm::Arguments args_from_options_nvfp4_nvfp4(
       stride_D}};
   auto& fusion_args = arguments.epilogue.thread;
   fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
-  static const float beta_zero = 0.0f;
-  fusion_args.beta_ptr = &beta_zero;
+  // static const float beta_zero = 0.0f;
+  // fusion_args.beta_ptr = &beta_zero;
   return arguments;
 }
 }
@@ -202,6 +202,11 @@ void runGemmNvfp4Sm120(
   typename Fp4GemmSm120::Gemm gemm;
   auto arguments = args_from_options_nvfp4_nvfp4(D, A, B, A_sf, B_sf, alpha, bias, m, n, k);
+  auto beta_dev = torch::zeros({1}, torch::TensorOptions()
+                                        .dtype(torch::kFloat32)
+                                        .device(A.device()));
+  arguments.epilogue.thread.beta_ptr =
+      static_cast<float const*>(beta_dev.data_ptr());
   size_t workspace_size = Fp4GemmSm120::Gemm::get_workspace_size(arguments);
   auto const workspace_options =
       torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
   auto workspace = torch::empty(workspace_size, workspace_options);