From 0aa16f34e89247ae34041b7505a00b20a31b21eb Mon Sep 17 00:00:00 2001
From: qiyuw
Date: Tue, 17 Feb 2026 11:46:54 -0800
Subject: [PATCH 1/8] fix memory overheads

---
 .../distributed/distributed_data_parallel.py  |  8 ++++++-
 .../core/distributed/param_and_grad_buffer.py | 23 +++++++++++++++++--
 2 files changed, 28 insertions(+), 3 deletions(-)
 mode change 100644 => 100755 megatron/core/distributed/distributed_data_parallel.py
 mode change 100644 => 100755 megatron/core/distributed/param_and_grad_buffer.py

diff --git a/megatron/core/distributed/distributed_data_parallel.py b/megatron/core/distributed/distributed_data_parallel.py
old mode 100644
new mode 100755
index 55179ff3024..89b65741e05
--- a/megatron/core/distributed/distributed_data_parallel.py
+++ b/megatron/core/distributed/distributed_data_parallel.py
@@ -7,7 +7,7 @@
 import torch
 
 from ..config_logger import has_config_logger_enabled, log_config_to_disk
-from ..fp8_utils import is_float8tensor, post_all_gather_processing
+from ..fp8_utils import is_float8tensor, is_mxfp8tensor, post_all_gather_processing
 from ..process_groups_config import ProcessGroupCollection
 from ..transformer.cuda_graphs import is_graph_capturing
 from ..transformer.transformer_config import TransformerConfig
@@ -488,9 +488,13 @@ def start_param_sync(self, *unused, force_sync: bool = False, force_dispatch: bo
             # The paramaters are cast from bf16 to MXFP8 during copy.
             # In the case of "overlap_param_gather=True", the param copy is done
             # in "finish_param_sync" stage after zeroing the shared gardient buffers.
+            is_bf16_weight_group = False
             if self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag:
                 for bucket in bucket_group.buckets:
                     for param in bucket.params:
+                        if not is_mxfp8tensor(param) and not is_float8tensor(param):
+                            is_bf16_weight_group = True
+                            break
                         param_start, param_end = bucket.param_to_index[param]
                         param_slice = bucket.param_data.view(-1)[param_start:param_end]
                         param.data.copy_(param_slice.view(param.data.shape))
@@ -500,6 +504,8 @@ def start_param_sync(self, *unused, force_sync: bool = False, force_dispatch: bo
                     # buffer may correspond to multiple param buffers. If we zero out the entire
                     # grad buffer, it would clear the data of those param buffers that have not
                     # yet completed AG.
+                    if is_bf16_weight_group:
+                        break
                     bucket.param_data.zero_()
             else:
                 fp8_params = []

diff --git a/megatron/core/distributed/param_and_grad_buffer.py b/megatron/core/distributed/param_and_grad_buffer.py
old mode 100644
new mode 100755
index 088374fbf13..301046506d3
--- a/megatron/core/distributed/param_and_grad_buffer.py
+++ b/megatron/core/distributed/param_and_grad_buffer.py
@@ -343,9 +343,14 @@ def finish_param_sync(self, skip_next_bucket_dispatch: bool = False):
         # For the mxfp8_param with "reuse_grad_buf_for_mxfp8_param_ag=True",
         # we need to copy the param_data from the shared_param/grad_buffer to param.data
         # after the param all-gather.
+        is_bf16_weight_group = False
         if self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag:
             for bucket in self.buckets:
                 for param in bucket.params:
+                    # bf16 weights are already mapped to param.data
+                    if not is_mxfp8tensor(param) and not is_float8tensor(param):
+                        is_bf16_weight_group = True
+                        break
                     param_start, param_end = bucket.param_to_index[param]
                     param_slice = bucket.param_data.view(-1)[param_start:param_end]
                     param.data.copy_(param_slice.view(param.data.shape))
@@ -354,6 +359,8 @@ def finish_param_sync(self, skip_next_bucket_dispatch: bool = False):
                 # We cannot zero out the entire grad buffer because one grad buffer may
                 # correspond to multiple param buffers. If we zero out the entire grad buffer,
                 # it would clear the data of those param buffers that have not yet completed AG.
+                if is_bf16_weight_group:
+                    break
                 bucket.param_data.zero_()
         else:
             fp8_params = []
@@ -820,8 +827,20 @@ def _does_param_require_new_bucket(param):
         cur_bucket_id = 0
         for param in params[::-1]:
             param_start_index, param_end_index, bucket_id = self.param_index_map[param]
-            # For MXFP8 param: we only need to map weight gradients to the buffer.
-            if not self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag:
+            # For MXFP8 param: we only need to map bf16 weights (layernorm, embedding, etc) to the buffer.
+            if self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag:
+                if not is_mxfp8tensor(param) and not is_float8tensor(param):
+                    if self.param_data is not None:
+                        new_param_data = self._get(
+                            param.data.shape, param_start_index, buffer_type=BufferType.PARAM
+                        )
+                        old_param_data = param.data
+                        param.data = new_param_data
+                        assert old_param_data._base is None
+                        # Copy tensor values (from initialization or checkpoint).
+                        param.data.detach().copy_(old_param_data)
+                        del old_param_data
+            else:
                 # Assign param.data to appropriate segment of self.param_data.
                 if self.param_data is not None:
                     new_param_data = self._get(
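The core of the memory fix above is that bf16 weights (layernorm, embedding, etc.) are mapped into the shared flat buffer as views, so they no longer keep a second standalone allocation. A minimal sketch of that view-mapping idea, using toy names rather than Megatron's API:

    import torch

    flat_buffer = torch.empty(1024, dtype=torch.bfloat16)  # stands in for self.param_data

    def map_param_into_buffer(param: torch.Tensor, start: int) -> torch.Tensor:
        # Return a view of flat_buffer shaped like `param`, seeded with its values,
        # mirroring: param.data = new_param_data; param.data.detach().copy_(old_param_data).
        end = start + param.numel()
        view = flat_buffer[start:end].view(param.shape)
        view.copy_(param)  # preserve initialization / checkpoint values
        return view

    w = torch.ones(16, 16, dtype=torch.bfloat16)
    w = map_param_into_buffer(w, 0)
    # `w` now aliases the flat buffer, so no separate storage remains for it.
    assert w.data_ptr() == flat_buffer.data_ptr()
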
From 57bcc6a50ec62425b1d8fbd51566a26d8ac1396a Mon Sep 17 00:00:00 2001
From: qiyuw
Date: Tue, 17 Feb 2026 11:48:00 -0800
Subject: [PATCH 2/8] file permission

---
 megatron/core/distributed/distributed_data_parallel.py | 0
 megatron/core/distributed/param_and_grad_buffer.py     | 0
 2 files changed, 0 insertions(+), 0 deletions(-)
 mode change 100755 => 100644 megatron/core/distributed/distributed_data_parallel.py
 mode change 100755 => 100644 megatron/core/distributed/param_and_grad_buffer.py

diff --git a/megatron/core/distributed/distributed_data_parallel.py b/megatron/core/distributed/distributed_data_parallel.py
old mode 100755
new mode 100644
diff --git a/megatron/core/distributed/param_and_grad_buffer.py b/megatron/core/distributed/param_and_grad_buffer.py
old mode 100755
new mode 100644
From 04ddeb2eda5d52223a0fb8074c890a6c508f93ae Mon Sep 17 00:00:00 2001
From: qiyuw
Date: Wed, 18 Feb 2026 17:03:55 +0000
Subject: [PATCH 3/8] lint

---
 megatron/core/distributed/param_and_grad_buffer.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/megatron/core/distributed/param_and_grad_buffer.py b/megatron/core/distributed/param_and_grad_buffer.py
index 301046506d3..467e442ab67 100644
--- a/megatron/core/distributed/param_and_grad_buffer.py
+++ b/megatron/core/distributed/param_and_grad_buffer.py
@@ -347,7 +347,8 @@ def finish_param_sync(self, skip_next_bucket_dispatch: bool = False):
         if self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag:
             for bucket in self.buckets:
                 for param in bucket.params:
-                    # bf16 weights are already mapped to param.data
+                    # Skip copying since bf16 weights in the mxfp8 model
+                    # are already mapped to param.data.
                     if not is_mxfp8tensor(param) and not is_float8tensor(param):
                         is_bf16_weight_group = True
                         break
@@ -827,7 +828,8 @@ def _does_param_require_new_bucket(param):
         cur_bucket_id = 0
         for param in params[::-1]:
             param_start_index, param_end_index, bucket_id = self.param_index_map[param]
-            # For MXFP8 param: we only need to map bf16 weights (layernorm, embedding, etc) to the buffer.
+            # For MXFP8 param:
+            # we only need to map bf16 weights (layernorm, embedding, etc) to the buffer.
             if self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag:
                 if not is_mxfp8tensor(param) and not is_float8tensor(param):
                     if self.param_data is not None:

From c67d0980f9cdf5f6480181ff96ff5aa95f17da5c Mon Sep 17 00:00:00 2001
From: qiyuw
Date: Wed, 18 Feb 2026 17:06:15 +0000
Subject: [PATCH 4/8] add more comments

---
 megatron/core/distributed/distributed_data_parallel.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/megatron/core/distributed/distributed_data_parallel.py b/megatron/core/distributed/distributed_data_parallel.py
index 89b65741e05..fe6f29af40e 100644
--- a/megatron/core/distributed/distributed_data_parallel.py
+++ b/megatron/core/distributed/distributed_data_parallel.py
@@ -492,6 +492,8 @@ def start_param_sync(self, *unused, force_sync: bool = False, force_dispatch: bo
             if self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag:
                 for bucket in bucket_group.buckets:
                     for param in bucket.params:
+                        # Skip copying since bf16 weights in the mxfp8 model
+                        # are already mapped to param.data.
                         if not is_mxfp8tensor(param) and not is_float8tensor(param):
                             is_bf16_weight_group = True
                             break
From 36548acf01a4d250e15ed48da778defc496bc767 Mon Sep 17 00:00:00 2001
From: qiyuw
Date: Wed, 25 Feb 2026 19:52:46 +0000
Subject: [PATCH 5/8] address comments

Signed-off-by: qiyuw
---
 .../distributed/distributed_data_parallel.py  | 12 ++++----
 .../core/distributed/param_and_grad_buffer.py | 27 +++++--------------
 2 files changed, 13 insertions(+), 26 deletions(-)

diff --git a/megatron/core/distributed/distributed_data_parallel.py b/megatron/core/distributed/distributed_data_parallel.py
index fe6f29af40e..7f3c0993bd8 100644
--- a/megatron/core/distributed/distributed_data_parallel.py
+++ b/megatron/core/distributed/distributed_data_parallel.py
@@ -7,7 +7,7 @@
 import torch
 
 from ..config_logger import has_config_logger_enabled, log_config_to_disk
-from ..fp8_utils import is_float8tensor, is_mxfp8tensor, post_all_gather_processing
+from ..fp8_utils import is_float8tensor, post_all_gather_processing
 from ..process_groups_config import ProcessGroupCollection
 from ..transformer.cuda_graphs import is_graph_capturing
 from ..transformer.transformer_config import TransformerConfig
@@ -488,26 +488,26 @@ def start_param_sync(self, *unused, force_sync: bool = False, force_dispatch: bo
             # The paramaters are cast from bf16 to MXFP8 during copy.
             # In the case of "overlap_param_gather=True", the param copy is done
             # in "finish_param_sync" stage after zeroing the shared gardient buffers.
-            is_bf16_weight_group = False
             if self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag:
                 for bucket in bucket_group.buckets:
+                    is_bf16_weight_bucket = False
                     for param in bucket.params:
                         # Skip copying since bf16 weights in the mxfp8 model
                         # are already mapped to param.data.
-                        if not is_mxfp8tensor(param) and not is_float8tensor(param):
-                            is_bf16_weight_group = True
+                        if not is_float8tensor(param):
+                            is_bf16_weight_bucket = True
                             break
                         param_start, param_end = bucket.param_to_index[param]
                         param_slice = bucket.param_data.view(-1)[param_start:param_end]
                         param.data.copy_(param_slice.view(param.data.shape))
+                    if is_bf16_weight_bucket:
+                        continue
                     # All-gathered params are not needed after being copied to param.data.
                     # Zero out the param buffer (shared with grad buffer) for gradient
                     # accumulation. We cannot zero out the entire grad buffer because one grad
                     # buffer may correspond to multiple param buffers. If we zero out the entire
                     # grad buffer, it would clear the data of those param buffers that have not
                     # yet completed AG.
-                    if is_bf16_weight_group:
-                        break
                     bucket.param_data.zero_()
             else:
                 fp8_params = []

diff --git a/megatron/core/distributed/param_and_grad_buffer.py b/megatron/core/distributed/param_and_grad_buffer.py
index 467e442ab67..ed4a4f7bf30 100644
--- a/megatron/core/distributed/param_and_grad_buffer.py
+++ b/megatron/core/distributed/param_and_grad_buffer.py
@@ -342,26 +342,26 @@ def finish_param_sync(self, skip_next_bucket_dispatch: bool = False):
 
         # For the mxfp8_param with "reuse_grad_buf_for_mxfp8_param_ag=True",
         # we need to copy the param_data from the shared_param/grad_buffer to param.data
-        # after the param all-gather.
-        is_bf16_weight_group = False
+        # after the param all-gather. 
         if self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag:
             for bucket in self.buckets:
+                is_bf16_weight_bucket = False
                 for param in bucket.params:
                     # Skip copying since bf16 weights in the mxfp8 model
                     # are already mapped to param.data.
-                    if not is_mxfp8tensor(param) and not is_float8tensor(param):
-                        is_bf16_weight_group = True
+                    if not is_float8tensor(param):
+                        is_bf16_weight_bucket = True
                         break
                     param_start, param_end = bucket.param_to_index[param]
                     param_slice = bucket.param_data.view(-1)[param_start:param_end]
                     param.data.copy_(param_slice.view(param.data.shape))
+                if is_bf16_weight_bucket:
+                    continue
                 # All-gathered params are not needed after being copied to param.data.
                 # Zero out the param buffer (shared with grad buffer) for gradient accumulation.
                 # We cannot zero out the entire grad buffer because one grad buffer may
                 # correspond to multiple param buffers. If we zero out the entire grad buffer,
                 # it would clear the data of those param buffers that have not yet completed AG.
-                if is_bf16_weight_group:
-                    break
                 bucket.param_data.zero_()
         else:
             fp8_params = []
@@ -830,20 +830,7 @@ def _does_param_require_new_bucket(param):
             param_start_index, param_end_index, bucket_id = self.param_index_map[param]
             # For MXFP8 param:
             # we only need to map bf16 weights (layernorm, embedding, etc) to the buffer.
-            if self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag:
-                if not is_mxfp8tensor(param) and not is_float8tensor(param):
-                    if self.param_data is not None:
-                        new_param_data = self._get(
-                            param.data.shape, param_start_index, buffer_type=BufferType.PARAM
-                        )
-                        old_param_data = param.data
-                        param.data = new_param_data
-                        assert old_param_data._base is None
-                        # Copy tensor values (from initialization or checkpoint).
-                        param.data.detach().copy_(old_param_data)
-                        del old_param_data
-            else:
-                # Assign param.data to appropriate segment of self.param_data.
+            if not self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag or not is_mxfp8tensor(param):
                 if self.param_data is not None:
                     new_param_data = self._get(
                         param.data.shape, param_start_index, buffer_type=BufferType.PARAM
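The per-bucket flow this patch settles on: a bucket holding non-quantized bf16 weights is skipped with `continue`, while quantized buckets are copied out of the shared param/grad buffer and then zeroed for gradient accumulation. A condensed, runnable model with toy bucket objects, not Megatron's classes (the patch flags the first non-quantized param and breaks; `any` expresses the same intent here):

    import torch

    class ToyBucket:
        def __init__(self, params):
            self.params = params
            # Stands in for the flat param buffer that is shared with the grad buffer.
            self.param_data = torch.cat([p.reshape(-1) for p in params])
            self.param_to_index, offset = {}, 0
            for p in params:
                self.param_to_index[p] = (offset, offset + p.numel())
                offset += p.numel()

    def drain_buckets(buckets, is_quantized):
        for bucket in buckets:
            if any(not is_quantized(p) for p in bucket.params):
                continue  # bf16 bucket: params already hold their data, leave buffer alone
            for p in bucket.params:
                start, end = bucket.param_to_index[p]
                p.copy_(bucket.param_data[start:end].view(p.shape))
            # Zero only this bucket's slice so it can accumulate gradients next;
            # other buckets may not have finished their all-gather yet.
            bucket.param_data.zero_()

    params = [torch.ones(4), torch.ones(2, 2)]
    bucket = ToyBucket(params)
    drain_buckets([bucket], is_quantized=lambda p: True)
    assert bucket.param_data.sum() == 0                 # buffer ready for grad accumulation
    assert all(bool((p == 1).all()) for p in params)    # params kept their gathered values
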
From f9e95ee5fb23b7ed67448ac0a7291a7d3cb5df43 Mon Sep 17 00:00:00 2001
From: Kunlun Li <94586211+kunlunl@users.noreply.github.com>
Date: Thu, 26 Feb 2026 11:16:19 +0800
Subject: [PATCH 6/8] Delete extra spaces

---
 megatron/core/distributed/param_and_grad_buffer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/megatron/core/distributed/param_and_grad_buffer.py b/megatron/core/distributed/param_and_grad_buffer.py
index ed4a4f7bf30..9b74a8d7291 100644
--- a/megatron/core/distributed/param_and_grad_buffer.py
+++ b/megatron/core/distributed/param_and_grad_buffer.py
@@ -342,7 +342,7 @@ def finish_param_sync(self, skip_next_bucket_dispatch: bool = False):
 
         # For the mxfp8_param with "reuse_grad_buf_for_mxfp8_param_ag=True",
         # we need to copy the param_data from the shared_param/grad_buffer to param.data
-        # after the param all-gather. 
+        # after the param all-gather.
         if self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag:
             for bucket in self.buckets:
                 is_bf16_weight_bucket = False
From 6656cb2286a2f476ae564f440ec5a12d583ccfaf Mon Sep 17 00:00:00 2001
From: qiyuw
Date: Wed, 4 Mar 2026 06:17:07 +0000
Subject: [PATCH 7/8] extend unit test to check if bf16 weights are mapped into buffer or not

Signed-off-by: qiyuw
---
 tests/unit_tests/test_fp8_param.py | 27 ++++++++++++++++++++++++++-
 1 file changed, 26 insertions(+), 1 deletion(-)

diff --git a/tests/unit_tests/test_fp8_param.py b/tests/unit_tests/test_fp8_param.py
index 361698f7127..1d1bbf2830f 100644
--- a/tests/unit_tests/test_fp8_param.py
+++ b/tests/unit_tests/test_fp8_param.py
@@ -11,7 +11,7 @@
 from megatron.core.distributed import DistributedDataParallel as DDP
 from megatron.core.enums import ModelType
-from megatron.core.fp8_utils import is_float8tensor
+from megatron.core.fp8_utils import is_float8tensor, is_mxfp8tensor
 from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
 from megatron.core.models.gpt.gpt_model import GPTModel
 from megatron.core.num_microbatches_calculator import destroy_num_microbatches_calculator
@@ -246,6 +246,31 @@ def _run_test_helper(
     if fp8_param_gather:
         assert num_fp8_params == 4 * fp8_layers
 
+    # Verify that bf16 params (embedding, LN, etc.) in the MXFP8 model are mapped
+    # to the param buffer (shared with grad buffer) rather than allocated separately.
+    if args.reuse_grad_buf_for_mxfp8_param_ag:
+        for buffer in gpt_model[0].buffers:
+            if buffer.param_data is None:
+                continue
+            buf_start = buffer.param_data.data_ptr()
+            buf_end = (
+                buf_start
+                + buffer.param_data.numel() * buffer.param_data.element_size()
+            )
+            for param in buffer.param_to_bucket:
+                if is_mxfp8tensor(param):
+                    # MXFP8 params keep their own quantized storage.
+                    assert not (buf_start <= param.data.data_ptr() < buf_end), (
+                        "MXFP8 param should not be mapped to the param buffer"
+                    )
+                else:
+                    # BF16 params should be views into the param buffer
+                    # (no double allocation).
+                    assert buf_start <= param.data.data_ptr() < buf_end, (
+                        "BF16 param should be a view into the param buffer "
+                        "(no separate allocation)"
+                    )
+
     loss_list = []
     for i in range(100):

From 77a03712f3f0200c803e998008a6bcd746416020 Mon Sep 17 00:00:00 2001
From: qiyuw
Date: Wed, 4 Mar 2026 08:55:43 -0800
Subject: [PATCH 8/8] fix lint

Signed-off-by: qiyuw
---
 tests/unit_tests/test_fp8_param.py | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/tests/unit_tests/test_fp8_param.py b/tests/unit_tests/test_fp8_param.py
index 1d1bbf2830f..34b504e21de 100644
--- a/tests/unit_tests/test_fp8_param.py
+++ b/tests/unit_tests/test_fp8_param.py
@@ -253,16 +253,13 @@ def _run_test_helper(
             if buffer.param_data is None:
                 continue
             buf_start = buffer.param_data.data_ptr()
-            buf_end = (
-                buf_start
-                + buffer.param_data.numel() * buffer.param_data.element_size()
-            )
+            buf_end = buf_start + buffer.param_data.numel() * buffer.param_data.element_size()
             for param in buffer.param_to_bucket:
                 if is_mxfp8tensor(param):
                     # MXFP8 params keep their own quantized storage.
-                    assert not (buf_start <= param.data.data_ptr() < buf_end), (
-                        "MXFP8 param should not be mapped to the param buffer"
-                    )
+                    assert not (
+                        buf_start <= param.data.data_ptr() < buf_end
+                    ), "MXFP8 param should not be mapped to the param buffer"
                 else:
                     # BF16 params should be views into the param buffer
                     # (no double allocation).
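
The extended unit test decides whether a tensor lives inside the flat buffer by comparing raw storage pointers, and that check is easy to reproduce outside the test harness. A standalone illustration (names here are illustrative, not the test's own):

    import torch

    def is_view_into(buf: torch.Tensor, t: torch.Tensor) -> bool:
        # Same pointer-range containment the test asserts on.
        start = buf.data_ptr()
        end = start + buf.numel() * buf.element_size()
        return start <= t.data_ptr() < end

    buf = torch.zeros(256, dtype=torch.bfloat16)
    inside = buf[16:32].view(4, 4)                     # a "bf16 param" mapped into the buffer
    outside = torch.zeros(4, 4, dtype=torch.bfloat16)  # separately allocated, like an MXFP8 param

    assert is_view_into(buf, inside)
    assert not is_view_into(buf, outside)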