Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions megatron/core/distributed/distributed_data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,10 +490,18 @@ def start_param_sync(self, *unused, force_sync: bool = False, force_dispatch: bo
# in "finish_param_sync" stage after zeroing the shared gradient buffers.
if self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag:
for bucket in bucket_group.buckets:
is_bf16_weight_bucket = False
for param in bucket.params:
# Skip copying since bf16 weights in the mxfp8 model
# are already mapped to param.data.
if not is_float8tensor(param):
is_bf16_weight_bucket = True
break
param_start, param_end = bucket.param_to_index[param]
param_slice = bucket.param_data.view(-1)[param_start:param_end]
param.data.copy_(param_slice.view(param.data.shape))
if is_bf16_weight_bucket:
continue
# All-gathered params are not needed after being copied to param.data.
# Zero out the param buffer (shared with grad buffer) for gradient
# accumulation. We cannot zero out the entire grad buffer because one grad
Expand Down
14 changes: 11 additions & 3 deletions megatron/core/distributed/param_and_grad_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,10 +345,18 @@ def finish_param_sync(self, skip_next_bucket_dispatch: bool = False):
# after the param all-gather.
if self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag:
for bucket in self.buckets:
is_bf16_weight_bucket = False
for param in bucket.params:
# Skip copying since bf16 weights in the mxfp8 model
# are already mapped to param.data.
if not is_float8tensor(param):
is_bf16_weight_bucket = True
break
param_start, param_end = bucket.param_to_index[param]
param_slice = bucket.param_data.view(-1)[param_start:param_end]
param.data.copy_(param_slice.view(param.data.shape))
if is_bf16_weight_bucket:
continue
# All-gathered params are not needed after being copied to param.data.
# Zero out the param buffer (shared with grad buffer) for gradient accumulation.
# We cannot zero out the entire grad buffer because one grad buffer may
Expand Down Expand Up @@ -820,9 +828,9 @@ def _does_param_require_new_bucket(param):
cur_bucket_id = 0
for param in params[::-1]:
param_start_index, param_end_index, bucket_id = self.param_index_map[param]
# For MXFP8 param: we only need to map weight gradients to the buffer.
if not self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag:
# Assign param.data to appropriate segment of self.param_data.
# For MXFP8 param:
# we only need to map bf16 weights (layernorm, embedding, etc) to the buffer.
if not self.ddp_config.reuse_grad_buf_for_mxfp8_param_ag or not is_mxfp8tensor(param):
if self.param_data is not None:
new_param_data = self._get(
param.data.shape, param_start_index, buffer_type=BufferType.PARAM
Expand Down
24 changes: 23 additions & 1 deletion tests/unit_tests/test_fp8_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from megatron.core.distributed import DistributedDataParallel as DDP
from megatron.core.enums import ModelType
from megatron.core.fp8_utils import is_float8tensor
from megatron.core.fp8_utils import is_float8tensor, is_mxfp8tensor
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
from megatron.core.models.gpt.gpt_model import GPTModel
from megatron.core.num_microbatches_calculator import destroy_num_microbatches_calculator
Expand Down Expand Up @@ -246,6 +246,28 @@ def _run_test_helper(
if fp8_param_gather:
assert num_fp8_params == 4 * fp8_layers

# Verify that bf16 params (embedding, LN, etc.) in the MXFP8 model are mapped
# to the param buffer (shared with grad buffer) rather than allocated separately.
if args.reuse_grad_buf_for_mxfp8_param_ag:
for buffer in gpt_model[0].buffers:
if buffer.param_data is None:
continue
buf_start = buffer.param_data.data_ptr()
buf_end = buf_start + buffer.param_data.numel() * buffer.param_data.element_size()
for param in buffer.param_to_bucket:
if is_mxfp8tensor(param):
# MXFP8 params keep their own quantized storage.
assert not (
buf_start <= param.data.data_ptr() < buf_end
), "MXFP8 param should not be mapped to the param buffer"
else:
# BF16 params should be views into the param buffer
# (no double allocation).
assert buf_start <= param.data.data_ptr() < buf_end, (
"BF16 param should be a view into the param buffer "
"(no separate allocation)"
)

loss_list = []

for i in range(100):
Expand Down
Loading