Fix SM120 scaled-mm beta_ptr on device #830
Changes from all commits
@@ -154,8 +154,8 @@ typename Mxfp4GemmSm120::Gemm::Arguments args_from_options_mxp4_mxfp4(
       stride_D}};
   auto& fusion_args = arguments.epilogue.thread;
   fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
-  static const float beta_zero = 0.0f;
-  fusion_args.beta_ptr = &beta_zero;
+  // static const float beta_zero = 0.0f;
+  // fusion_args.beta_ptr = &beta_zero;
   fusion_args.bias_ptr = static_cast<Mxfp4GemmSm120::Gemm::ElementC const*>(bias->data_ptr());
   fusion_args.dBias = StrideBias{};
   return arguments;

@@ -180,8 +180,8 @@ typename Mxfp4GemmSm120::Gemm::Arguments args_from_options_mxp4_mxfp4(
       stride_D}};
   auto& fusion_args = arguments.epilogue.thread;
   fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
-  static const float beta_zero = 0.0f;
-  fusion_args.beta_ptr = &beta_zero;
+  // static const float beta_zero = 0.0f;
+  // fusion_args.beta_ptr = &beta_zero;
   return arguments;
 }
 }

Contributor commented on lines +183 to +184: the commented-out code is no longer needed and should be removed.

@@ -202,6 +202,11 @@ void runGemmMxfp4Sm120(
   typename Mxfp4GemmSm120::Gemm gemm;

   auto arguments = args_from_options_mxp4_mxfp4(D, A, B, A_sf, B_sf, alpha, bias, m, n, k);
+  auto beta_dev = torch::zeros({1}, torch::TensorOptions()
+                                        .dtype(torch::kFloat32)
+                                        .device(A.device()));
+  arguments.epilogue.thread.beta_ptr =
+      static_cast<float const*>(beta_dev.data_ptr());
   size_t workspace_size = Mxfp4GemmSm120::Gemm::get_workspace_size(arguments);
   auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
   auto workspace = torch::empty(workspace_size, workspace_options);

Contributor: The `beta_dev` tensor comes from PyTorch's caching allocator and is read by an asynchronously launched GEMM, so its memory could be reclaimed before the kernel has finished with it. To fix this, you should inform the PyTorch caching allocator that the memory is being used by the stream, so it won't be reclaimed until the stream operations are complete. You can do this using the allocator's stream-recording API.
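
A minimal sketch of what the suggested stream-recording fix could look like, assuming the GEMM is enqueued on the current CUDA stream of `A`'s device; the helper name `make_device_beta_zero` and its placement are illustrative, not part of this PR. `Tensor::record_stream` tells PyTorch's caching allocator that a storage is in use on a stream, so the memory is not reused until work already enqueued on that stream has completed.

```cpp
#include <ATen/ATen.h>
#include <c10/cuda/CUDAStream.h>

// Illustrative helper (not in the PR): allocate a device-resident beta == 0.0f
// and tell the caching allocator its storage is in use by `stream`, so the
// memory is not handed out again until work already enqueued on that stream,
// including the asynchronous GEMM, has completed.
at::Tensor make_device_beta_zero(const at::Device& device,
                                 c10::cuda::CUDAStream stream) {
  auto beta_dev = at::zeros(
      {1}, at::TensorOptions().dtype(at::kFloat).device(device));
  beta_dev.record_stream(stream);  // mark the storage as used on `stream`
  return beta_dev;
}

// Usage sketch inside runGemmMxfp4Sm120 (and the other variants), assuming the
// kernel runs on the current stream:
//   auto stream = c10::cuda::getCurrentCUDAStream(A.device().index());
//   auto beta_dev = make_device_beta_zero(A.device(), stream);
//   arguments.epilogue.thread.beta_ptr =
//       static_cast<float const*>(beta_dev.data_ptr());
```

Keeping `beta_dev` alive until after the launch, as the diff already does, covers the case where allocation and compute share a stream; `record_stream` additionally protects the memory if the tensor's last reference is dropped while the kernel is still running, or if the kernel runs on a different stream.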
@@ -154,8 +154,8 @@ typename Mxfp6Mxfp8GemmSm120::Gemm::Arguments args_from_options_mxfp6_mxfp8(
       stride_D}};
   auto& fusion_args = arguments.epilogue.thread;
   fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
-  static const float beta_zero = 0.0f;
-  fusion_args.beta_ptr = &beta_zero;
+  // static const float beta_zero = 0.0f;
+  // fusion_args.beta_ptr = &beta_zero;
   fusion_args.bias_ptr = static_cast<Mxfp6Mxfp8GemmSm120::Gemm::ElementC const*>(bias->data_ptr());
   fusion_args.dBias = StrideBias{};
   return arguments;

Contributor commented on lines +157 to +158: remove the commented-out code.

@@ -180,8 +180,8 @@ typename Mxfp6Mxfp8GemmSm120::Gemm::Arguments args_from_options_mxfp6_mxfp8(
       stride_D}};
   auto& fusion_args = arguments.epilogue.thread;
   fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
-  static const float beta_zero = 0.0f;
-  fusion_args.beta_ptr = &beta_zero;
+  // static const float beta_zero = 0.0f;
+  // fusion_args.beta_ptr = &beta_zero;
   return arguments;
 }
 }

Contributor commented on lines +183 to +184: remove the commented-out code.

@@ -202,6 +202,11 @@ void runGemmMxfp6Mxfp8Sm120(
   typename Mxfp6Mxfp8GemmSm120::Gemm gemm;

   auto arguments = args_from_options_mxfp6_mxfp8(D, A, B, A_sf, B_sf, alpha, bias, m, n, k);
+  auto beta_dev = torch::zeros({1}, torch::TensorOptions()
+                                        .dtype(torch::kFloat32)
+                                        .device(A.device()));
+  arguments.epilogue.thread.beta_ptr =
+      static_cast<float const*>(beta_dev.data_ptr());
   size_t workspace_size = Mxfp6Mxfp8GemmSm120::Gemm::get_workspace_size(arguments);
   auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
   auto workspace = torch::empty(workspace_size, workspace_options);

Contributor: same concern as on the mxfp4 kernel above; tell the caching allocator that `beta_dev`'s memory is in use by the stream so it is not reclaimed before the asynchronous GEMM completes.
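
Because each of the four kernels allocates and zero-fills the same one-element tensor on every call, one possible refinement, sketched below under the assumption that a shared constant is acceptable, is to cache a device-resident zero per device; the helper name `device_beta_zero` is hypothetical and not part of this PR.

```cpp
#include <ATen/ATen.h>
#include <mutex>
#include <unordered_map>

// Hypothetical refinement (not in the PR): one cached device-resident 0.0f per
// device, so repeated GEMM calls don't allocate and zero-fill a fresh
// 1-element tensor each time. The cached tensors are never freed, so the
// caching-allocator lifetime concern raised in the review does not apply to
// them. The zero fill is enqueued on the caller's current stream at first use.
const float* device_beta_zero(const at::Device& device) {
  static std::mutex mu;
  static std::unordered_map<c10::DeviceIndex, at::Tensor> cache;

  std::lock_guard<std::mutex> lock(mu);
  auto it = cache.find(device.index());
  if (it == cache.end()) {
    auto zero = at::zeros(
        {1}, at::TensorOptions().dtype(at::kFloat).device(device));
    it = cache.emplace(device.index(), std::move(zero)).first;
  }
  return static_cast<const float*>(it->second.data_ptr());
}

// Usage sketch:
//   arguments.epilogue.thread.beta_ptr = device_beta_zero(A.device());
```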
@@ -154,8 +154,8 @@ typename Mxfp8GemmSm120::Gemm::Arguments args_from_options_mxfp8(
       stride_D}};
   auto& fusion_args = arguments.epilogue.thread;
   fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
-  static const float beta_zero = 0.0f;
-  fusion_args.beta_ptr = &beta_zero;
+  // static const float beta_zero = 0.0f;
+  // fusion_args.beta_ptr = &beta_zero;
   fusion_args.bias_ptr = static_cast<Mxfp8GemmSm120::Gemm::ElementC const*>(bias->data_ptr());
   fusion_args.dBias = StrideBias{};
   return arguments;

Contributor commented on lines +157 to +158: remove the commented-out code.

@@ -180,8 +180,8 @@ typename Mxfp8GemmSm120::Gemm::Arguments args_from_options_mxfp8(
       stride_D}};
   auto& fusion_args = arguments.epilogue.thread;
   fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
-  static const float beta_zero = 0.0f;
-  fusion_args.beta_ptr = &beta_zero;
+  // static const float beta_zero = 0.0f;
+  // fusion_args.beta_ptr = &beta_zero;
   return arguments;
 }
 }

Contributor commented on lines +183 to +184: remove the commented-out code.

@@ -202,6 +202,11 @@ void runGemmMxfp8Sm120(
   typename Mxfp8GemmSm120::Gemm gemm;

   auto arguments = args_from_options_mxfp8(D, A, B, A_sf, B_sf, alpha, bias, m, n, k);
+  auto beta_dev = torch::zeros({1}, torch::TensorOptions()
+                                        .dtype(torch::kFloat32)
+                                        .device(A.device()));
+  arguments.epilogue.thread.beta_ptr =
+      static_cast<float const*>(beta_dev.data_ptr());
   size_t workspace_size = Mxfp8GemmSm120::Gemm::get_workspace_size(arguments);
   auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
   auto workspace = torch::empty(workspace_size, workspace_options);

Contributor: same concern as above; tell the caching allocator that `beta_dev`'s memory is in use by the stream so it is not reclaimed before the asynchronous GEMM completes.
@@ -154,8 +154,8 @@ typename Fp4GemmSm120::Gemm::Arguments args_from_options_nvfp4_nvfp4(
       stride_D}};
   auto& fusion_args = arguments.epilogue.thread;
   fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
-  static const float beta_zero = 0.0f;
-  fusion_args.beta_ptr = &beta_zero;
+  // static const float beta_zero = 0.0f;
+  // fusion_args.beta_ptr = &beta_zero;
   fusion_args.bias_ptr = static_cast<Fp4GemmSm120::Gemm::ElementC const*>(bias->data_ptr());
   fusion_args.dBias = StrideBias{};
   return arguments;

Contributor commented on lines +157 to +158: remove the commented-out code.

@@ -180,8 +180,8 @@ typename Fp4GemmSm120::Gemm::Arguments args_from_options_nvfp4_nvfp4(
       stride_D}};
   auto& fusion_args = arguments.epilogue.thread;
   fusion_args.alpha_ptr = static_cast<float const*>(alpha.data_ptr());
-  static const float beta_zero = 0.0f;
-  fusion_args.beta_ptr = &beta_zero;
+  // static const float beta_zero = 0.0f;
+  // fusion_args.beta_ptr = &beta_zero;
   return arguments;
 }
 }

Contributor commented on lines +183 to +184: remove the commented-out code.

@@ -202,6 +202,11 @@ void runGemmNvfp4Sm120(
   typename Fp4GemmSm120::Gemm gemm;

   auto arguments = args_from_options_nvfp4_nvfp4(D, A, B, A_sf, B_sf, alpha, bias, m, n, k);
+  auto beta_dev = torch::zeros({1}, torch::TensorOptions()
+                                        .dtype(torch::kFloat32)
+                                        .device(A.device()));
+  arguments.epilogue.thread.beta_ptr =
+      static_cast<float const*>(beta_dev.data_ptr());
   size_t workspace_size = Fp4GemmSm120::Gemm::get_workspace_size(arguments);
   auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
   auto workspace = torch::empty(workspace_size, workspace_options);

Contributor: same concern as above; tell the caching allocator that `beta_dev`'s memory is in use by the stream so it is not reclaimed before the asynchronous GEMM completes.

On the commented-out `beta_zero` lines: This commented-out code is no longer needed and should be removed to improve code clarity.
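
The root cause is the same in all four files: `beta_ptr` used to point at a host-side `static const float`, while the fix (and the PR title) indicate it must point at device memory that the epilogue can dereference. As a sketch, a debug-only check like the following, using the CUDA runtime's pointer-attribute query, can catch that class of mistake before the kernel launches; `is_device_pointer` is a hypothetical helper, not part of the PR.

```cpp
#include <cuda_runtime.h>

// Hypothetical debug check (not in the PR): verify that a scalar pointer handed
// to the CUTLASS epilogue refers to device-accessible memory. A host
// `static const float`, as in the pre-fix code, fails this check.
inline bool is_device_pointer(const void* ptr) {
  cudaPointerAttributes attr{};
  if (cudaPointerGetAttributes(&attr, ptr) != cudaSuccess) {
    cudaGetLastError();  // older CUDA reports plain host pointers as an error; clear it
    return false;
  }
  return attr.type == cudaMemoryTypeDevice || attr.type == cudaMemoryTypeManaged;
}

// Usage sketch before gemm.run():
//   TORCH_CHECK(is_device_pointer(arguments.epilogue.thread.beta_ptr),
//               "beta_ptr must point to device memory");
```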