diff --git a/megatron/core/distributed/distributed_data_parallel.py b/megatron/core/distributed/distributed_data_parallel.py
index 55179ff3024..7f3c0993bd8 100644
--- a/megatron/core/distributed/distributed_data_parallel.py
+++ b/megatron/core/distributed/distributed_data_parallel.py
@@ -490,10 +490,18 @@ def start_param_sync(self, *unused, force_sync: bool = False, force_dispatch: bo
             # in "finish_param_sync" stage after zeroing the shared gardient buffers.
             if self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag:
                 for bucket in bucket_group.buckets:
+                    is_bf16_weight_bucket = False
                     for param in bucket.params:
+                        # Skip copying since bf16 weights in the mxfp8 model
+                        # are already mapped to param.data.
+                        if not is_float8tensor(param):
+                            is_bf16_weight_bucket = True
+                            break
                         param_start, param_end = bucket.param_to_index[param]
                         param_slice = bucket.param_data.view(-1)[param_start:param_end]
                         param.data.copy_(param_slice.view(param.data.shape))
+                    if is_bf16_weight_bucket:
+                        continue
                     # All-gathered params are not needed after being copied to param.data.
                     # Zero out the param buffer (shared with grad buffer) for gradient
                     # accumulation. We cannot zero out the entire grad buffer because one grad
diff --git a/megatron/core/distributed/param_and_grad_buffer.py b/megatron/core/distributed/param_and_grad_buffer.py
index 088374fbf13..9b74a8d7291 100644
--- a/megatron/core/distributed/param_and_grad_buffer.py
+++ b/megatron/core/distributed/param_and_grad_buffer.py
@@ -345,10 +345,18 @@ def finish_param_sync(self, skip_next_bucket_dispatch: bool = False):
             # after the param all-gather.
             if self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag:
                 for bucket in self.buckets:
+                    is_bf16_weight_bucket = False
                     for param in bucket.params:
+                        # Skip copying since bf16 weights in the mxfp8 model
+                        # are already mapped to param.data.
+                        if not is_float8tensor(param):
+                            is_bf16_weight_bucket = True
+                            break
                         param_start, param_end = bucket.param_to_index[param]
                         param_slice = bucket.param_data.view(-1)[param_start:param_end]
                         param.data.copy_(param_slice.view(param.data.shape))
+                    if is_bf16_weight_bucket:
+                        continue
                     # All-gathered params are not needed after being copied to param.data.
                     # Zero out the param buffer (shared with grad buffer) for gradient accumulation.
                     # We cannot zero out the entire grad buffer because one grad buffer may
@@ -820,9 +828,9 @@ def _does_param_require_new_bucket(param):
         cur_bucket_id = 0
         for param in params[::-1]:
             param_start_index, param_end_index, bucket_id = self.param_index_map[param]
-            # For MXFP8 param: we only need to map weight gradients to the buffer.
-            if not self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag:
-                # Assign param.data to appropriate segment of self.param_data.
+            # For MXFP8 param:
+            # we only need to map bf16 weights (layernorm, embedding, etc) to the buffer.
+            if not self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag or not is_mxfp8tensor(param):
                 if self.param_data is not None:
                     new_param_data = self._get(
                         param.data.shape, param_start_index, buffer_type=BufferType.PARAM
diff --git a/tests/unit_tests/test_fp8_param.py b/tests/unit_tests/test_fp8_param.py
index 361698f7127..34b504e21de 100644
--- a/tests/unit_tests/test_fp8_param.py
+++ b/tests/unit_tests/test_fp8_param.py
@@ -11,7 +11,7 @@
 
 from megatron.core.distributed import DistributedDataParallel as DDP
 from megatron.core.enums import ModelType
-from megatron.core.fp8_utils import is_float8tensor
+from megatron.core.fp8_utils import is_float8tensor, is_mxfp8tensor
 from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
 from megatron.core.models.gpt.gpt_model import GPTModel
 from megatron.core.num_microbatches_calculator import destroy_num_microbatches_calculator
@@ -246,6 +246,28 @@ def _run_test_helper(
 
     if fp8_param_gather:
         assert num_fp8_params == 4 * fp8_layers
+    # Verify that bf16 params (embedding, LN, etc.) in the MXFP8 model are mapped
+    # to the param buffer (shared with grad buffer) rather than allocated separately.
+    if args.reuse_grad_buf_for_mxfp8_param_ag:
+        for buffer in gpt_model[0].buffers:
+            if buffer.param_data is None:
+                continue
+            buf_start = buffer.param_data.data_ptr()
+            buf_end = buf_start + buffer.param_data.numel() * buffer.param_data.element_size()
+            for param in buffer.param_to_bucket:
+                if is_mxfp8tensor(param):
+                    # MXFP8 params keep their own quantized storage.
+                    assert not (
+                        buf_start <= param.data.data_ptr() < buf_end
+                    ), "MXFP8 param should not be mapped to the param buffer"
+                else:
+                    # BF16 params should be views into the param buffer
+                    # (no double allocation).
+                    assert buf_start <= param.data.data_ptr() < buf_end, (
+                        "BF16 param should be a view into the param buffer "
+                        "(no separate allocation)"
+                    )
+
     loss_list = []
     for i in range(100):