Merged
13 changes: 9 additions & 4 deletions lightx2v_kernel/csrc/gemm/mxfp4_scaled_mm_kernels_sm120.cu
@@ -154,8 +154,8 @@ typename Mxfp4GemmSm120::Gemm::Arguments args_from_options_mxp4_mxfp4(
stride_D}};
auto& fusion_args = arguments.epilogue.thread;
fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
-static const float beta_zero = 0.0f;
-fusion_args.beta_ptr = &beta_zero;
+// static const float beta_zero = 0.0f;
+// fusion_args.beta_ptr = &beta_zero;
Comment on lines +157 to +158 (Contributor, medium):
This commented-out code is no longer needed and should be removed to improve code clarity.
fusion_args.bias_ptr = static_cast<Mxfp4GemmSm120::Gemm::ElementC const*>(bias->data_ptr());
fusion_args.dBias = StrideBias{};
return arguments;
@@ -180,8 +180,8 @@ typename Mxfp4GemmSm120::Gemm::Arguments args_from_options_mxp4_mxfp4(
stride_D}};
auto& fusion_args = arguments.epilogue.thread;
fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
-static const float beta_zero = 0.0f;
-fusion_args.beta_ptr = &beta_zero;
+// static const float beta_zero = 0.0f;
+// fusion_args.beta_ptr = &beta_zero;
Comment on lines +183 to +184 (Contributor, medium):
This commented-out code is no longer needed and should be removed to improve code clarity.
return arguments;
}
}
@@ -202,6 +202,11 @@ void runGemmMxfp4Sm120(
typename Mxfp4GemmSm120::Gemm gemm;

auto arguments = args_from_options_mxp4_mxfp4(D, A, B, A_sf, B_sf, alpha, bias, m, n, k);
+auto beta_dev = torch::zeros({1}, torch::TensorOptions()
+                                      .dtype(torch::kFloat32)
+                                      .device(A.device()));
+arguments.epilogue.thread.beta_ptr =
+    static_cast<float const*>(beta_dev.data_ptr());
Contributor (critical):
The beta_dev tensor is a local variable, and its lifetime is tied to this function. When the function returns, beta_dev is destroyed, and its underlying memory can be freed. However, the gemm.run call launches a kernel asynchronously. This creates a race condition where the kernel might access freed memory, leading to a critical use-after-free bug.

To fix this, you should inform the PyTorch caching allocator that the memory is being used by the stream, so it won't be reclaimed until the stream operations are complete. You can do this using c10::cuda::CUDACachingAllocator::recordStream. You will also need to include <c10/cuda/CUDACachingAllocator.h>.

      static_cast<float const*>(beta_dev.data_ptr());
  c10::cuda::CUDACachingAllocator::recordStream(beta_dev.storage().data_ptr(), stream);

size_t workspace_size = Mxfp4GemmSm120::Gemm::get_workspace_size(arguments);
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
auto workspace = torch::empty(workspace_size, workspace_options);
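The use-after-free the review flags above can be sketched in plain C++ (a minimal analogue with no CUDA; every name in it is illustrative, not from the patch): an asynchronous task holds a pointer into an allocation, so the allocation must be kept alive until the task finishes. Sharing ownership with the task is the standard-library counterpart of what recordStream does for a stream-ordered tensor allocation.

```cpp
#include <cassert>
#include <future>
#include <memory>

// The async task co-owns the allocation via the captured shared_ptr, so the
// float stays valid even after the caller releases its own handle. Without
// shared ownership (e.g. capturing a raw float*), the caller's buffer could
// be freed while the task still reads it: the same race the review describes.
float read_beta_async(std::shared_ptr<float> beta) {
    auto fut = std::async(std::launch::async, [beta] { return *beta; });
    return fut.get();
}

float launch_and_drop_owner() {
    auto beta = std::make_shared<float>(0.0f);
    return read_beta_async(std::move(beta));  // caller's handle is gone here
}
```

In the kernel-launch case there is no C++ object whose scope spans the asynchronous work, which is why the review reaches for recordStream: it defers reuse of beta_dev's allocation until the stream's queued work has completed.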
13 changes: 9 additions & 4 deletions lightx2v_kernel/csrc/gemm/mxfp6_mxfp8_scaled_mm_kernels_sm120.cu
@@ -154,8 +154,8 @@ typename Mxfp6Mxfp8GemmSm120::Gemm::Arguments args_from_options_mxfp6_mxfp8(
stride_D}};
auto& fusion_args = arguments.epilogue.thread;
fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
-static const float beta_zero = 0.0f;
-fusion_args.beta_ptr = &beta_zero;
+// static const float beta_zero = 0.0f;
+// fusion_args.beta_ptr = &beta_zero;
Comment on lines +157 to +158 (Contributor, medium):
This commented-out code is no longer needed and should be removed to improve code clarity.
fusion_args.bias_ptr = static_cast<Mxfp6Mxfp8GemmSm120::Gemm::ElementC const*>(bias->data_ptr());
fusion_args.dBias = StrideBias{};
return arguments;
@@ -180,8 +180,8 @@ typename Mxfp6Mxfp8GemmSm120::Gemm::Arguments args_from_options_mxfp6_mxfp8(
stride_D}};
auto& fusion_args = arguments.epilogue.thread;
fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
-static const float beta_zero = 0.0f;
-fusion_args.beta_ptr = &beta_zero;
+// static const float beta_zero = 0.0f;
+// fusion_args.beta_ptr = &beta_zero;
Comment on lines +183 to +184 (Contributor, medium):
This commented-out code is no longer needed and should be removed to improve code clarity.
return arguments;
}
}
@@ -202,6 +202,11 @@ void runGemmMxfp6Mxfp8Sm120(
typename Mxfp6Mxfp8GemmSm120::Gemm gemm;

auto arguments = args_from_options_mxfp6_mxfp8(D, A, B, A_sf, B_sf, alpha, bias, m, n, k);
+auto beta_dev = torch::zeros({1}, torch::TensorOptions()
+                                      .dtype(torch::kFloat32)
+                                      .device(A.device()));
+arguments.epilogue.thread.beta_ptr =
+    static_cast<float const*>(beta_dev.data_ptr());
Contributor (critical):
The beta_dev tensor is a local variable, and its lifetime is tied to this function. When the function returns, beta_dev is destroyed, and its underlying memory can be freed. However, the gemm.run call launches a kernel asynchronously. This creates a race condition where the kernel might access freed memory, leading to a critical use-after-free bug.

To fix this, you should inform the PyTorch caching allocator that the memory is being used by the stream, so it won't be reclaimed until the stream operations are complete. You can do this using c10::cuda::CUDACachingAllocator::recordStream. You will also need to include <c10/cuda/CUDACachingAllocator.h>.

      static_cast<float const*>(beta_dev.data_ptr());
  c10::cuda::CUDACachingAllocator::recordStream(beta_dev.storage().data_ptr(), stream);

size_t workspace_size = Mxfp6Mxfp8GemmSm120::Gemm::get_workspace_size(arguments);
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
auto workspace = torch::empty(workspace_size, workspace_options);
13 changes: 9 additions & 4 deletions lightx2v_kernel/csrc/gemm/mxfp8_scaled_mm_kernels_sm120.cu
@@ -154,8 +154,8 @@ typename Mxfp8GemmSm120::Gemm::Arguments args_from_options_mxfp8(
stride_D}};
auto& fusion_args = arguments.epilogue.thread;
fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
-static const float beta_zero = 0.0f;
-fusion_args.beta_ptr = &beta_zero;
+// static const float beta_zero = 0.0f;
+// fusion_args.beta_ptr = &beta_zero;
Comment on lines +157 to +158 (Contributor, medium):
This commented-out code is no longer needed and should be removed to improve code clarity.
fusion_args.bias_ptr = static_cast<Mxfp8GemmSm120::Gemm::ElementC const*>(bias->data_ptr());
fusion_args.dBias = StrideBias{};
return arguments;
@@ -180,8 +180,8 @@ typename Mxfp8GemmSm120::Gemm::Arguments args_from_options_mxfp8(
stride_D}};
auto& fusion_args = arguments.epilogue.thread;
fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
-static const float beta_zero = 0.0f;
-fusion_args.beta_ptr = &beta_zero;
+// static const float beta_zero = 0.0f;
+// fusion_args.beta_ptr = &beta_zero;
Comment on lines +183 to +184 (Contributor, medium):
This commented-out code is no longer needed and should be removed to improve code clarity.
return arguments;
}
}
@@ -202,6 +202,11 @@ void runGemmMxfp8Sm120(
typename Mxfp8GemmSm120::Gemm gemm;

auto arguments = args_from_options_mxfp8(D, A, B, A_sf, B_sf, alpha, bias, m, n, k);
+auto beta_dev = torch::zeros({1}, torch::TensorOptions()
+                                      .dtype(torch::kFloat32)
+                                      .device(A.device()));
+arguments.epilogue.thread.beta_ptr =
+    static_cast<float const*>(beta_dev.data_ptr());
Contributor (critical):
The beta_dev tensor is a local variable, and its lifetime is tied to this function. When the function returns, beta_dev is destroyed, and its underlying memory can be freed. However, the gemm.run call launches a kernel asynchronously. This creates a race condition where the kernel might access freed memory, leading to a critical use-after-free bug.

To fix this, you should inform the PyTorch caching allocator that the memory is being used by the stream, so it won't be reclaimed until the stream operations are complete. You can do this using c10::cuda::CUDACachingAllocator::recordStream. You will also need to include <c10/cuda/CUDACachingAllocator.h>.

      static_cast<float const*>(beta_dev.data_ptr());
  c10::cuda::CUDACachingAllocator::recordStream(beta_dev.storage().data_ptr(), stream);

size_t workspace_size = Mxfp8GemmSm120::Gemm::get_workspace_size(arguments);
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
auto workspace = torch::empty(workspace_size, workspace_options);
13 changes: 9 additions & 4 deletions lightx2v_kernel/csrc/gemm/nvfp4_scaled_mm_kernels_sm120.cu
@@ -154,8 +154,8 @@ typename Fp4GemmSm120::Gemm::Arguments args_from_options_nvfp4_nvfp4(
stride_D}};
auto& fusion_args = arguments.epilogue.thread;
fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
-static const float beta_zero = 0.0f;
-fusion_args.beta_ptr = &beta_zero;
+// static const float beta_zero = 0.0f;
+// fusion_args.beta_ptr = &beta_zero;
Comment on lines +157 to +158 (Contributor, medium):
This commented-out code is no longer needed and should be removed to improve code clarity.
fusion_args.bias_ptr = static_cast<Fp4GemmSm120::Gemm::ElementC const*>(bias->data_ptr());
fusion_args.dBias = StrideBias{};
return arguments;
@@ -180,8 +180,8 @@ typename Fp4GemmSm120::Gemm::Arguments args_from_options_nvfp4_nvfp4(
stride_D}};
auto& fusion_args = arguments.epilogue.thread;
fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
-static const float beta_zero = 0.0f;
-fusion_args.beta_ptr = &beta_zero;
+// static const float beta_zero = 0.0f;
+// fusion_args.beta_ptr = &beta_zero;
Comment on lines +183 to +184 (Contributor, medium):
This commented-out code is no longer needed and should be removed to improve code clarity.
return arguments;
}
}
@@ -202,6 +202,11 @@ void runGemmNvfp4Sm120(
typename Fp4GemmSm120::Gemm gemm;

auto arguments = args_from_options_nvfp4_nvfp4(D, A, B, A_sf, B_sf, alpha, bias, m, n, k);
+auto beta_dev = torch::zeros({1}, torch::TensorOptions()
+                                      .dtype(torch::kFloat32)
+                                      .device(A.device()));
+arguments.epilogue.thread.beta_ptr =
+    static_cast<float const*>(beta_dev.data_ptr());
Contributor (critical):
The beta_dev tensor is a local variable, and its lifetime is tied to this function. When the function returns, beta_dev is destroyed, and its underlying memory can be freed. However, the gemm.run call launches a kernel asynchronously. This creates a race condition where the kernel might access freed memory, leading to a critical use-after-free bug.

To fix this, you should inform the PyTorch caching allocator that the memory is being used by the stream, so it won't be reclaimed until the stream operations are complete. You can do this using c10::cuda::CUDACachingAllocator::recordStream. You will also need to include <c10/cuda/CUDACachingAllocator.h>.

      static_cast<float const*>(beta_dev.data_ptr());
  c10::cuda::CUDACachingAllocator::recordStream(beta_dev.storage().data_ptr(), stream);

size_t workspace_size = Fp4GemmSm120::Gemm::get_workspace_size(arguments);
auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
auto workspace = torch::empty(workspace_size, workspace_options);
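All four review threads flag the same pattern, so one alternative shape is worth noting (a hypothetical sketch, not what this patch does): since beta is a constant zero, it could be allocated once with process lifetime instead of once per call, removing the lifetime question entirely. The removed `static const float beta_zero` had the right lifetime but the wrong memory space, since a device-side epilogue cannot read host statics; a one-time device-side allocation would fix the address space while keeping the lifetime. A plain C++ stand-in for the "allocate once, never free" shape:

```cpp
#include <cassert>

// Illustrative only: hand out a pointer to a zero scalar that is initialized
// on first use and lives for the whole process, so no per-call object can be
// freed underneath in-flight asynchronous work. In the real kernels this
// would have to be a device allocation (e.g. a cached torch tensor or a
// one-time cudaMalloc), not a host static as sketched here.
const float* beta_zero_ptr() {
    static const float beta_zero = 0.0f;  // constructed once, never destroyed
    return &beta_zero;
}
```

Every call returns the same stable address, which is exactly the property a per-call `beta_dev` tensor lacks.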