diff --git a/examples/rl/environments/math/math_agent.py b/examples/rl/environments/math/math_agent.py
index 027ef242285..e7e113e2bf5 100644
--- a/examples/rl/environments/math/math_agent.py
+++ b/examples/rl/environments/math/math_agent.py
@@ -117,6 +117,24 @@ def compute_score(self, response: str, golden: dict, golden_key: str = "answer")
# Did not format the answer correctly
return self.negative_reward
+ # def make_prefix(self, problem_key: str = "problem", **kwargs) -> str:
+ # """Take a string math problem and return the prompt. Supports requesting tagged or boxed answers. Supports chat mode prompts."""
+ # if self.answer_format == "boxed":
+ # answer_format = "Please reason step by step and provide your answer between \\boxed{} tags, for example \\boxed{20\\sqrt{3}}."
+ # elif self.answer_format == "tagged":
+    #         answer_format = "Please reason step by step and provide your answer between <answer></answer> tags, for example <answer>20\\sqrt{3}</answer>. Do not include an = sign."
+ # else:
+ # raise ValueError(f"Invalid answer format: {self.answer_format}")
+
+ # if self.chat_mode:
+ # prefix = f"""{kwargs[problem_key]}\n{answer_format}"""
+ # else:
+ # prefix = f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
+    # The question will be a word math problem. Show your work in <think> tags.
+ # {answer_format}
+ # User: {kwargs[problem_key]}
+ # {self.assistant_suffix}"""
+ # return prefix
def make_prefix(self, problem_key: str = "problem", **kwargs) -> str:
"""Take a string math problem and return the prompt. Supports requesting tagged or boxed answers. Supports chat mode prompts."""
if self.answer_format == "boxed":
@@ -126,6 +144,11 @@ def make_prefix(self, problem_key: str = "problem", **kwargs) -> str:
else:
raise ValueError(f"Invalid answer format: {self.answer_format}")
- prefix = f"""{kwargs[problem_key]}\n{answer_format}"""
+ # prefix = f"""{kwargs[problem_key]}\n{answer_format}"""
+ prefix = f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.
+    The question will be a word math problem. Show your work in <think> tags.
+ {answer_format}
+ User: {kwargs[problem_key]}
+ Assistant:"""
return prefix
diff --git a/examples/rl/model_configs/common.sh b/examples/rl/model_configs/common.sh
index 6198708c1bf..a84029b9b6c 100644
--- a/examples/rl/model_configs/common.sh
+++ b/examples/rl/model_configs/common.sh
@@ -16,14 +16,11 @@ COMMON_OPTIONS="\
--transformer-impl transformer_engine \
--${PRECISION:-bf16} \
--te-rng-tracker \
- --rl-offload-optimizer-during-inference \
- --inference-dynamic-batching-buffer-size-gb 20 \
--data-parallel-random-init \
--attention-backend flash \
--timing-log-level 1 \
--log-timers-to-tensorboard \
--save-retain-interval 160 \
- --inference-dynamic-batching-num-cuda-graphs 1 \
--inference-dynamic-batching-unified-memory-level 1 \
--adam-beta1 0.9 \
--adam-beta2 ${ADAM_BETA2:-0.95} \
diff --git a/examples/rl/model_configs/nemotron6_3b_moe.sh b/examples/rl/model_configs/nemotron6_3b_moe.sh
index 19891ad7b8b..87444552ddd 100644
--- a/examples/rl/model_configs/nemotron6_3b_moe.sh
+++ b/examples/rl/model_configs/nemotron6_3b_moe.sh
@@ -9,39 +9,21 @@ echo "Using Nemotron6 3B MOE model checkpoint"
SCRIPT_PATH="${BASH_SOURCE[0]}"
source $(dirname $SCRIPT_PATH)/common.sh
-# In all cases, one can override those values.
-# However, running without envs will give you some
-# good perf out of the box for established envs.
-if [ "$(basename "$ENV_CONFIG")" = "dapo.yaml" ]; then
- echo "Using DAPO environment config"
- GRPO_CLAMP_EPS_LOWER=${GRPO_CLAMP_EPS_LOWER:-0.2}
- GRPO_CLAMP_EPS_UPPER=${GRPO_CLAMP_EPS_UPPER:-0.28}
- MAX_INFERENCE_BS=${MAX_INFERENCE_BS:-32}
- GRPO_GROUP_SIZE=${GRPO_GROUP_SIZE:-16}
- GRPO_PROMPTS_PER_STEP=${GRPO_PROMPTS_PER_STEP:-64}
- GRPO_ITERATIONS=${GRPO_ITERATIONS:-1}
- GRPO_KL_BETA=${GRPO_KL_BETA:-"0.0"}
- TRAINING_BATCH_SIZE=${TRAINING_BATCH_SIZE:-1024}
- MICRO_BATCH_SIZE=${MICRO_BATCH_SIZE:-1}
- MAX_SEQ_LENGTH=${MAX_SEQ_LENGTH:-11999}
- EXIT_INTERVAL=${EXIT_INTERVAL:-20}
- CHKPT_SAVE_INTERVAL=${CHKPT_SAVE_INTERVAL:-20}
-else
- # Some default values if config is unsupported.
- echo "Undected environment config, using default values"
- GRPO_CLAMP_EPS_LOWER=${GRPO_CLAMP_EPS_LOWER:-0.2}
- GRPO_CLAMP_EPS_UPPER=${GRPO_CLAMP_EPS_UPPER:-0.28}
- MAX_INFERENCE_BS=${MAX_INFERENCE_BS:-64}
- GRPO_GROUP_SIZE=${GRPO_GROUP_SIZE:-2}
- GRPO_PROMPTS_PER_STEP=${GRPO_PROMPTS_PER_STEP:-16}
- GRPO_ITERATIONS=${GRPO_ITERATIONS:-1}
- GRPO_KL_BETA=${GRPO_KL_BETA:-"0.0"}
- TRAINING_BATCH_SIZE=${TRAINING_BATCH_SIZE:-32}
- MICRO_BATCH_SIZE=${MICRO_BATCH_SIZE:-1}
- MAX_SEQ_LENGTH=${MAX_SEQ_LENGTH:-1024}
- EXIT_INTERVAL=${EXIT_INTERVAL:-20}
- CHKPT_SAVE_INTERVAL=${CHKPT_SAVE_INTERVAL:-20}
-fi
+
+echo "Undetected environment config, using default values"
+GRPO_CLAMP_EPS_LOWER=${GRPO_CLAMP_EPS_LOWER:-0.2}
+GRPO_CLAMP_EPS_UPPER=${GRPO_CLAMP_EPS_UPPER:-0.28}
+MAX_INFERENCE_BS=${MAX_INFERENCE_BS:-64}
+GRPO_GROUP_SIZE=${GRPO_GROUP_SIZE:-16}
+GRPO_PROMPTS_PER_STEP=${GRPO_PROMPTS_PER_STEP:-64}
+GRPO_ITERATIONS=${GRPO_ITERATIONS:-1}
+GRPO_KL_BETA=${GRPO_KL_BETA:-"0.0"}
+TRAINING_BATCH_SIZE=${TRAINING_BATCH_SIZE:-1024}
+MICRO_BATCH_SIZE=${MICRO_BATCH_SIZE:-1}
+MAX_SEQ_LENGTH=${MAX_SEQ_LENGTH:-8192}
+EXIT_INTERVAL=${EXIT_INTERVAL:-15}
+CHKPT_SAVE_INTERVAL=${CHKPT_SAVE_INTERVAL:-20}
+
ENV_DEPENDENT="\
--micro-batch-size $MICRO_BATCH_SIZE \
@@ -56,14 +38,19 @@ ENV_DEPENDENT="\
MODEL_OPTIONS="\
--rl-skip-bos-token \
- --no-rl-use-sequence-packing \
+ --rl-use-sequence-packing \
--rl-partial-rollouts \
--rl-offload-optimizer-during-inference \
--moe-pad-experts-for-cuda-graph-inference \
- --inference-dynamic-batching-max-tokens 8192 \
+ --inference-dynamic-batching-num-cuda-graphs 4 \
--inference-dynamic-batching-max-requests 128 \
- --inference-dynamic-batching-num-cuda-graphs 2 \
- --decode-only-cuda-graphs \
+ --inference-dynamic-batching-paused-buffer-size-gb 5 \
+ --inference-dynamic-batching-buffer-size-gb 5 \
+ --inference-dynamic-batching-unified-memory-level 1 \
+ --rl-training-cuda-graphs \
+ --empty-unused-memory-level 0 \
+ --rl-parallel-generation-tasks 128 \
+ --inference-dynamic-batching-cuda-graph-mixed-prefill-count 0 \
--cuda-graph-impl local \
--cuda-graph-scope full \
--use-checkpoint-args \
@@ -118,4 +105,13 @@ MODEL_OPTIONS="\
--lr-warmup-samples 640 \
--lr-warmup-init 0.3e-7 \
--no-load-optim \
- --no-load-rng "
+ --no-load-rng \
+ --moe-permute-fusion \
+ --eval-interval 1000 \
+ --timing-log-level 2 \
+ "
+ # --inference-dynamic-batching-max-tokens 8192 \
+# --rl-training-cuda-graphs \
+ # --rl-training-cuda-graphs \
+# --empty-unused-memory-level 0 \ # try with the default value (=2)
+ # --inference-logging-step-interval 100 \
diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py
index 7dec7a14bea..ea9e474fde2 100644
--- a/megatron/core/inference/engines/dynamic_engine.py
+++ b/megatron/core/inference/engines/dynamic_engine.py
@@ -25,6 +25,7 @@
MaxSequenceLengthOverflowError,
TokenOverflowError,
)
+from megatron.core.inference.inference_flops import InferenceFLOPsCalculator
from megatron.core.inference.data_parallel_inference_coordinator import (
DataParallelInferenceCoordinator,
)
@@ -188,6 +189,21 @@ def __init__(self, controller: TextGenerationController, context: DynamicInferen
)
self.cuda_graph_impl = model_config.cuda_graph_impl
self.cuda_graph_scope = model_config.cuda_graph_scope
+
+ # Initialize inference FLOPs calculator and GPU peak for MFU reporting.
+ self.flops_calculator = None
+ self.gpu_peak_tflops = 0.0
+ self.cumulative_inference_flops = 0.0
+ self.cumulative_inference_time = 0.0
+ try:
+ from megatron.training.global_vars import get_args
+ from megatron.training.gpu_peak_flops import get_gpu_peak_tflops
+ args = get_args()
+ self.flops_calculator = InferenceFLOPsCalculator.from_args(args)
+ self.gpu_peak_tflops = get_gpu_peak_tflops()
+ except Exception as e:
+ logging.warning(f"Could not initialize inference FLOPs calculator: {e}")
+
# Initialize engine.
self.reset()
@@ -1487,6 +1503,33 @@ async def async_bookkeep(
self.socket_for_receiving_requests.send(payload)
range_pop()
+ # Compute inference FLOPs for this step.
+ step_flops_info = None
+ if self.flops_calculator is not None:
+ batch_dims = self.context.batch_dimensions
+ decode_tokens = batch_dims.decode_req_count if batch_dims else 0
+ prefill_reqs = batch_dims.prefill_req_count if batch_dims else 0
+ total_tokens = batch_dims.token_count if batch_dims else 0
+ prefill_tokens = total_tokens - decode_tokens
+
+ step_flops_info = self.flops_calculator.compute_step_flops(
+ decode_tokens=decode_tokens,
+ prefill_tokens=prefill_tokens,
+ total_tokens=total_tokens,
+ active_blocks=context_state["total_active_used_blocks"],
+ active_reqs=context_state["total_request_count"] - context_state["paused_request_count"],
+ num_prefill_reqs=prefill_reqs,
+ )
+ self.cumulative_inference_flops += step_flops_info['total_flops']
+ self.cumulative_inference_time += step_time
+ try:
+ from megatron.training.mfu_tracker import get_mfu_tracker
+ get_mfu_tracker().add_inference_flops(
+ step_flops_info['total_flops'], step_time, tokens=total_tokens
+ )
+ except Exception:
+ pass
+
# Log KV cache utilization stats to W&B
if context_state["kv_stats"] is not None:
# Prepare metrics dictionary with all stats
@@ -1499,6 +1542,29 @@ async def async_bookkeep(
'inference/waiting_queue_len': int(len(self.waiting_request_ids)),
'inference/total_requests_dict_size': int(len(self.requests)),
}
+
+ batch_dims = self.context.batch_dimensions
+ total_tokens = batch_dims.token_count if batch_dims else 0
+ if step_time > 0 and total_tokens > 0:
+ metrics['inference/tokens_per_sec_per_gpu'] = float(total_tokens / step_time)
+
+ if step_flops_info is not None:
+ step_tflops = step_flops_info['total_flops'] / 1e12
+ step_throughput = step_tflops / step_time if step_time > 0 else 0
+ metrics['inference/step_flops_tflop'] = float(step_tflops)
+ metrics['inference/throughput_tflops_per_gpu'] = float(step_throughput)
+ metrics['inference/t_avg'] = float(step_flops_info['t_avg'])
+ metrics['inference/cumulative_flops_tflop'] = float(self.cumulative_inference_flops / 1e12)
+ if self.gpu_peak_tflops > 0:
+ mfu = step_throughput / self.gpu_peak_tflops * 100.0
+ cumulative_throughput = (
+ (self.cumulative_inference_flops / 1e12) / self.cumulative_inference_time
+ if self.cumulative_inference_time > 0 else 0
+ )
+ cumulative_mfu = cumulative_throughput / self.gpu_peak_tflops * 100.0
+ metrics['inference/mfu_percent'] = float(mfu)
+ metrics['inference/cumulative_mfu_percent'] = float(cumulative_mfu)
+
# Add KV stats with inference/ prefix
# Convert utilization metrics from 0-1 range to 0-100 percentage range for better visualization
for key, value in context_state["kv_stats"].items():
@@ -1557,6 +1623,18 @@ async def async_bookkeep(
mem["reserved_bytes.all.current"] / (1024**3),
)
)
+ batch_dims = self.context.batch_dimensions
+ total_tokens = batch_dims.token_count if batch_dims else 0
+ if step_time > 0 and total_tokens > 0:
+ toks_per_sec_per_gpu = total_tokens / step_time
+ output_str += f" toks/s/GPU: {toks_per_sec_per_gpu:.0f},"
+ if step_flops_info is not None:
+ step_tflops = step_flops_info['total_flops'] / 1e12
+ step_throughput = step_tflops / step_time if step_time > 0 else 0
+ output_str += f" {step_throughput:.1f} TFLOP/s/GPU"
+ if self.gpu_peak_tflops > 0:
+ mfu = step_throughput / self.gpu_peak_tflops * 100.0
+ output_str += f", MFU: {mfu:.1f}%"
if context_state["is_decode_only"]:
output_str = f"\033[94m{output_str}\033[0m"
logging.info(output_str)
diff --git a/megatron/rl/rl_utils.py b/megatron/rl/rl_utils.py
index 3d818b5b0cd..1e8a423c098 100644
--- a/megatron/rl/rl_utils.py
+++ b/megatron/rl/rl_utils.py
@@ -487,7 +487,7 @@ def get_environment_rollouts(
nvtx_range = get_nvtx_range()
if args.rl_offload_optimizer_during_inference:
- with nvtx_range("offload-optimizer-state-and-grad-buffers-during-inference"):
+ with nvtx_range("offload-optimizer-before-inference", time=True):
if not args.rl_training_cuda_graphs:
model[0].offload_grad_buffers()
else:
@@ -496,6 +496,7 @@ def get_environment_rollouts(
)
optimizer.offload_to_cpu()
+
# If we have separate training and inference models we to refit weights from the training model to the inference model.
has_separate_inference_model = inference_model is not None
if has_separate_inference_model:
@@ -519,7 +520,7 @@ def get_environment_rollouts(
pg_size = get_pg_size(inference_pg_collection.ep)
assert (n_prompts % pg_size == 0), f"{n_prompts=} must be divisible by {pg_size=}"
- with nvtx_range("rollout-collection"):
+ with nvtx_range("rollout-collection", time=True):
loop = get_asyncio_loop()
with megatron_rl_inference_mode(
inference_model,
@@ -530,7 +531,7 @@ def get_environment_rollouts(
increment_staleness_on_suspend=True,
) as inference_interface:
- with nvtx_range("inference-setup"):
+ with nvtx_range("inference-setup", time=True):
# Asyncronously run inference and rollout collection
rollout_generator = get_rollout_generator(
args, inference_interface, n_prompts, samples_per_group
@@ -538,7 +539,7 @@ def get_environment_rollouts(
# NOTE(jbarker): we need to double check this when using PP>1
rank = torch.distributed.get_rank()
- with nvtx_range("collect-rollouts"):
+ with nvtx_range("collect-rollouts", time=True):
if rank == 0:
log_single_rank(
logger,
@@ -563,14 +564,14 @@ def get_environment_rollouts(
# Just set up space to collect the rollouts
rollouts = [[None for _ in range(samples_per_group)] for _ in range(n_prompts)]
- with nvtx_range("sync-rollouts"):
+ with nvtx_range("sync-rollouts", time=True):
# Wait for Rollouts to be collected
# TODO(jbarker): double check why this isn't causing rank 0 memory allocations
torch.distributed.broadcast_object_list(rollouts, src=0)
logger.debug(f"Got rollouts on rank {rank}")
if args.rl_offload_optimizer_during_inference:
- with nvtx_range("restore-optimizer-state-and-grad-buffers-after-inference"):
+ with nvtx_range("onload-optimizer-after-inference", time=True):
model[0].restore_grad_buffers()
optimizer.restore_from_cpu()
@@ -1282,8 +1283,8 @@ def prepare_data_for_update(
model = model[0]
dtype = torch.bfloat16 if args.bf16 else (torch.float16 if args.fp16 else torch.float32)
- with nvtx_range("prepare-data-for-update"):
- with nvtx_range("compute-group-stats"):
+ with nvtx_range("prepare-data-for-update", time=True):
+ with nvtx_range("compute-group-stats", time=True):
group_stats = compute_group_stats(rollouts, tokenizer, args.seq_length)
# TODO(vitalyk): why do we need global_advantages here? go inside packing
advantages = global_advantages = torch.tensor(group_stats.advantages, dtype=dtype).cuda()
@@ -1323,7 +1324,7 @@ def prepare_data_for_update(
# First we calculate them on a global level and then we split and recalculate on a local level.
# Sequence packing and reporting needs it global but non-packing wants it local.
- with nvtx_range("prepare_trajectories"):
+ with nvtx_range("prepare-trajectories", time=True):
trajs, generation_masks, inference_logprobs = prepare_trajectories(
rollouts, tokenizer, args.seq_length, sequence_packing, args.rl_skip_bos_token
)
@@ -1331,7 +1332,7 @@ def prepare_data_for_update(
packing_context = None
# Build trajectories based on sequence packing or standard processing
if sequence_packing:
- with nvtx_range("sequence_packing", time=True):
+ with nvtx_range("sequence-packing", time=True):
runtime_state.packing_context = packing_context = pack_all_trajectories(
trajs,
generation_masks,
@@ -1351,7 +1352,7 @@ def prepare_data_for_update(
logprobs_batch_size = 1
else:
# Always compute standard masks for the original data (we'll need them later)
- with nvtx_range("get_ltor_masks_and_position_ids"):
+ with nvtx_range("get-ltor-masks-and-position-ids", time=True):
_, original_loss_mask, original_position_ids = get_ltor_masks_and_position_ids(
trajs,
tokenizer.eod,
@@ -1370,7 +1371,7 @@ def prepare_data_for_update(
)
logprobs_batch_size = args.micro_batch_size
- with torch.no_grad(), nvtx_range("compute_logprobs", time=True):
+ with torch.no_grad(), nvtx_range("compute-logprobs", time=True):
# Before we can update the model, we need to get the logprobs for the \pi_{old} model.
# Wrap forward_backward_func for Full iteration CUDA graph
@@ -1388,7 +1389,7 @@ def prepare_data_for_update(
pg_collection = get_attr_wrapped_model(model, "pg_collection")
pp_group = pg_collection.pp
- with torch.no_grad(), nvtx_range("compute_old_logprobs", time=True):
+ with torch.no_grad(), nvtx_range("compute-old-logprobs", time=True):
old_logprobs = compute_logprobs_batch(
model=model,
data_loader=data_loader,
@@ -1403,7 +1404,7 @@ def prepare_data_for_update(
is_correction=args.rl_inference_logprobs_is_correction,
)
- with torch.no_grad(), nvtx_range("compute_ref_logprobs", time=True):
+ with torch.no_grad(), nvtx_range("compute-ref-logprobs", time=True):
# We need to load the ref model state dict and compute the logprobs for the ref model
cur_st_dict = {
k: (v.cpu() if v is not None else v) for k, v in model.state_dict().items()
@@ -1432,7 +1433,7 @@ def prepare_data_for_update(
if sequence_packing:
- with nvtx_range("pack_logprobs", time=True):
+ with nvtx_range("pack-logprobs", time=True):
# Store logprobs on gpu in packing context
# Since PackingContext is a dataclass, we add these as new attributes
packing_context.old_logprobs = old_logprobs.cuda()
@@ -1460,7 +1461,7 @@ def prepare_data_for_update(
packing_context.packed_inference_logprobs = packed_inference_logprobs.cuda()
# Only mark as having inference logprobs for IS correction if enabled
packing_context.has_inference_logprobs = args.rl_inference_logprobs_is_correction
- with nvtx_range("create_dataloader"):
+ with nvtx_range("create-dataloader", time=True):
# @vitalyk: This function also reconfigures the data loader to count the
# global_batch_size in the bins frame of reference.
# I think it will be a better design if we split the data loader creating and logic
@@ -1477,7 +1478,7 @@ def prepare_data_for_update(
)
loader = get_microbatch_dataloader(len(packing_context.packed_trajs), args.micro_batch_size)
else:
- with nvtx_range("align_inference_logprobs", time=True):
+ with nvtx_range("align-inference-logprobs", time=True):
if inference_logprobs is not None:
inference_logprobs = align_unpacked_inference_logprobs(
inference_logprobs=inference_logprobs,
@@ -1490,7 +1491,7 @@ def prepare_data_for_update(
# Nullify logprobs if not used in IS correction,
if not args.rl_inference_logprobs_is_correction:
inference_logprobs = None
- with nvtx_range("create_dataloader"):
+ with nvtx_range("create-dataloader", time=True):
# Because of multiturn, our batch sizes for non-sequence packed trajectories are not fixed anymore.
# As in sequence packing above, we need to reconfigure it too.
runtime_state.packing_context = None
@@ -1851,7 +1852,7 @@ def megatron_rl_inference_mode(
with torch.no_grad():
if offload_optimizer_during_inference:
- with nvtx_range("offload-optimizer-state-and-grad-buffers-before-inference"):
+ with nvtx_range("offload-optimizer-before-inference", time=True):
if not args.rl_training_cuda_graphs:
# Offload grad buffers from the training model (if separate inference model is used)
# or from the inference model (if they're the same model)
@@ -1872,7 +1873,7 @@ def megatron_rl_inference_mode(
logger.debug(f"[{dist.get_rank()}] Entered inference mode")
yield inference_interface
- with nvtx_range("suspend-engine"):
+ with nvtx_range("suspend-engine", time=True):
loop.run_until_complete(inference_interface.suspend())
if increment_staleness_on_suspend:
inference_interface.increment_staleness()
@@ -1886,7 +1887,7 @@ def megatron_rl_inference_mode(
_maybe_prefetch_separate_inference_model_weights(model_core, to_cpu=True)
if offload_optimizer_during_inference:
- with nvtx_range("onload-optimizer-state-and-grad-buffers-after-inference"):
+ with nvtx_range("onload-optimizer-after-inference", time=True):
# Restore grad buffers to the training model (if separate inference model is used)
# or to the inference model (if they're the same model)
model_for_grad_offload = training_model if training_model is not None else model
diff --git a/megatron/rl/sequence_packing_utils.py b/megatron/rl/sequence_packing_utils.py
index b641ecd85d0..ddbcdeffa2f 100644
--- a/megatron/rl/sequence_packing_utils.py
+++ b/megatron/rl/sequence_packing_utils.py
@@ -51,7 +51,6 @@ class PackingContext:
original_trajs: All trajectories before packing
packed_trajs: Packed trajectories tensor [num_bins, bin_size]
packed_position_ids: Position IDs for packed sequences [num_bins, bin_size]
- packed_attention_mask: Attention mask for packed sequences [num_bins, 1, bin_size, bin_size]
packed_loss_mask: Loss mask for packed sequences [num_bins, bin_size]
original_inference_logprobs: Inference logprobs for all sequences before packing (optional)
bin_advantages: List of advantage tensors for each bin
@@ -64,7 +63,6 @@ class PackingContext:
original_trajs: torch.Tensor
packed_trajs: torch.Tensor
packed_position_ids: torch.Tensor
- packed_attention_mask: torch.Tensor
packed_loss_mask: torch.Tensor
original_inference_logprobs: Optional[torch.Tensor] = None
bin_advantages: List[torch.Tensor] = field(default_factory=list)
@@ -314,9 +312,8 @@ def create_empty_bins(
packed_trajs : torch.Tensor,
packed_position_ids : torch.Tensor,
packed_loss_mask : torch.Tensor,
- packed_attention_mask : torch.Tensor,
tokenizer,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[Dict[str, Any]]]:
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[Dict[str, Any]]]:
"""Create empty bins for padding to ensure all ranks have the same number of bins.
Args:
@@ -325,11 +322,10 @@ def create_empty_bins(
packed_trajs: Packed trajectories tensor (for dtype/device reference)
packed_position_ids: Packed position IDs tensor (for dtype/device reference)
packed_loss_mask: Packed loss mask tensor (for dtype/device reference)
- packed_attention_mask: Packed attention mask tensor (can be None)
tokenizer: Tokenizer for pad token
Returns:
- Tuple of (empty_trajs, empty_position_ids, empty_loss_mask, empty_attention_mask, empty_packing_info_entries)
+ Tuple of (empty_trajs, empty_position_ids, empty_loss_mask, empty_packing_info_entries)
"""
device = packed_trajs.device
@@ -337,7 +333,6 @@ def create_empty_bins(
empty_bins = []
empty_position_ids_list = []
empty_loss_mask_list = []
- empty_attention_mask_list = []
empty_packing_info_entries = []
for i in range(num_empty_bins):
@@ -355,14 +350,6 @@ def create_empty_bins(
empty_loss = torch.zeros(1, bin_size, dtype=packed_loss_mask.dtype, device=device)
empty_loss_mask_list.append(empty_loss)
- # Zero attention mask if needed
- if packed_attention_mask is not None:
- # Attention mask is always 4D: [num_bins, 1, bin_size, bin_size]
- empty_attn = torch.zeros(
- 1, 1, bin_size, bin_size, dtype=packed_attention_mask.dtype, device=device
- )
- empty_attention_mask_list.append(empty_attn)
-
# Empty packing info entries
empty_packing_info_entries.append(
{
@@ -376,22 +363,15 @@ def create_empty_bins(
empty_trajs = torch.cat(empty_bins, dim=0)
empty_position_ids = torch.cat(empty_position_ids_list, dim=0)
empty_loss_mask = torch.cat(empty_loss_mask_list, dim=0)
- empty_attention_mask = (
- torch.cat(empty_attention_mask_list, dim=0)
- if packed_attention_mask is not None
- else None
- )
else:
empty_trajs = None
empty_position_ids = None
empty_loss_mask = None
- empty_attention_mask = None
return (
empty_trajs,
empty_position_ids,
empty_loss_mask,
- empty_attention_mask,
empty_packing_info_entries,
)
@@ -706,9 +686,6 @@ def pack_sequences(
position_ids = torch.zeros(
(num_bins, self.bin_size), dtype=torch.long, device=device, requires_grad=False
)
- attention_mask = torch.zeros(
- (num_bins, 1, self.bin_size, self.bin_size), dtype=torch.bool, device=device
- )
loss_mask = torch.zeros((num_bins, self.bin_size), dtype=torch.float, device=device)
# Track packing information for unpacking later
@@ -739,12 +716,6 @@ def pack_sequences(
len(seq), device=device, requires_grad=False
)
- # Causal attention mask within each sequence
- seq_len = end - start
- attention_mask[bin_idx, 0, start:end, start:end] = torch.tril(
- torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
- )
-
# Loss mask (excluding padding)
loss_mask[bin_idx, start:end] = 1.0
@@ -759,12 +730,6 @@ def pack_sequences(
seq_starts.append(current_pos)
seq_starts_dict[bin_idx] = seq_starts
- # Note: We'll store the actual padded length later when we know it
- # (it depends on the original trajectories passed to pack_sequences)
-
- # Invert attention mask, before inversion: (True = attend, False = mask)
- attention_mask.bitwise_not_()
-
# Create the PackingInfo dataclass
packing_info = PackingInfo(
bin_seq_indices=bin_seq_indices,
@@ -793,15 +758,14 @@ def pack_sequences(
)
log_single_rank(logger, logging.DEBUG, f" - First 20 bins: {seq_per_bin[:20]}")
- return packed_sequences, position_ids, attention_mask, loss_mask, packing_info
+ return packed_sequences, position_ids, loss_mask, packing_info
def distribute_packed_bins(
packed_trajs: torch.Tensor,
packed_position_ids: torch.Tensor,
- packed_attention_mask: torch.Tensor,
packed_loss_mask: torch.Tensor,
packing_info: PackingInfo,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, PackingInfo]:
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, PackingInfo]:
"""Distribute packed bins across the data parallel ranks."""
rank = mpu.get_data_parallel_rank()
world_size = mpu.get_data_parallel_world_size()
@@ -838,7 +802,6 @@ def distribute_packed_bins(
# Extract this rank's bins
my_packed_trajs = []
my_packed_position_ids = []
- my_packed_attention_mask = []
my_packed_loss_mask = []
my_bin_seq_indices = []
my_seq_starts = {}
@@ -848,8 +811,6 @@ def distribute_packed_bins(
for new_idx, old_idx in enumerate(my_bin_indices):
my_packed_trajs.append(packed_trajs[old_idx])
my_packed_position_ids.append(packed_position_ids[old_idx])
- if packed_attention_mask is not None:
- my_packed_attention_mask.append(packed_attention_mask[old_idx])
my_packed_loss_mask.append(packed_loss_mask[old_idx])
my_bin_seq_indices.append(packing_info.bin_seq_indices[old_idx])
my_seq_starts[new_idx] = packing_info.seq_starts[old_idx]
@@ -875,9 +836,6 @@ def distribute_packed_bins(
device=packed_position_ids.device,
)
)
- packed_attention_mask = (
- torch.stack(my_packed_attention_mask) if my_packed_attention_mask else None
- )
packed_loss_mask = (
torch.stack(my_packed_loss_mask)
if my_packed_loss_mask
@@ -935,7 +893,6 @@ def distribute_packed_bins(
empty_trajs,
empty_position_ids,
empty_loss_mask,
- empty_attention_mask,
empty_packing_entries,
) = create_empty_bins(
num_empty_bins,
@@ -943,7 +900,6 @@ def distribute_packed_bins(
packed_trajs,
packed_position_ids,
packed_loss_mask,
- packed_attention_mask,
tokenizer,
)
@@ -954,18 +910,13 @@ def distribute_packed_bins(
)
packed_loss_mask = torch.cat([packed_loss_mask, empty_loss_mask], dim=0)
- if packed_attention_mask is not None and empty_attention_mask is not None:
- packed_attention_mask = torch.cat(
- [packed_attention_mask, empty_attention_mask], dim=0
- )
-
# Add empty entries to packing_info
for i, entry in enumerate(empty_packing_entries):
bin_idx = current_bins + i
new_packing_info.bin_seq_indices.append(entry['bin_seq_indices'])
new_packing_info.seq_starts[bin_idx] = entry['seq_starts']
- return packed_trajs, packed_position_ids, packed_attention_mask, packed_loss_mask, new_packing_info
+ return packed_trajs, packed_position_ids, packed_loss_mask, new_packing_info
def pack_all_trajectories(trajs, generation_masks, inference_logprobs, global_advantages, bin_size, max_sequences_per_bin, packing_algo):
@@ -998,7 +949,6 @@ def _gather(data):
(
packed_trajs,
packed_position_ids,
- packed_attention_mask,
packed_loss_mask,
packing_info,
) = packer.pack_sequences(trajs, generation_masks)
@@ -1008,13 +958,11 @@ def _gather(data):
(
packed_trajs,
packed_position_ids,
- packed_attention_mask,
packed_loss_mask,
packing_info,
) = distribute_packed_bins(
packed_trajs,
packed_position_ids,
- packed_attention_mask,
packed_loss_mask,
packing_info,
)
@@ -1051,7 +999,6 @@ def _gather(data):
original_trajs=trajs,
packed_trajs=packed_trajs,
packed_position_ids=packed_position_ids,
- packed_attention_mask=packed_attention_mask,
packed_loss_mask=packed_loss_mask,
original_inference_logprobs=inference_logprobs,
bin_advantages=bin_advantages,
diff --git a/megatron/training/training.py b/megatron/training/training.py
index 8d0bfff3b3f..3a005347f21 100644
--- a/megatron/training/training.py
+++ b/megatron/training/training.py
@@ -1923,14 +1923,17 @@ def training_log(
])
# Add timers from RL loop if needed.
if getattr(args, 'perform_rl_step', False):
- timers_to_log.extend(['rollout-collection', 'inference-setup', 'collect-rollouts', 'postrollout-gc-collect',
- 'sync-rollouts', 'prepare-data-for-update', 'compute-group-stats',
- 'prepare-trajectories', 'get-ltor-masks-and-position-ids', 'create-logprobs-dataloader',
- 'compute-logprobs', 'compute-ref-logprobs', 'compute-prob-stats',
- 'prepare-advantages', 'create-dataloader', 'log-wandb-tb',
- 'offload-optimizer-before-inference', 'onload-kv-cache-before-inference',
- 'wait-for-decode-only', 'build-cuda-graphs', 'suspend-engine',
- 'offload-kv-cache-after-inference', 'onload-optimizer-after-inference'])
+ timers_to_log.extend([
+ 'rollout-collection', 'inference-setup', 'collect-rollouts',
+ 'sync-rollouts', 'prepare-data-for-update', 'compute-group-stats',
+ 'prepare-trajectories', 'get-ltor-masks-and-position-ids',
+ 'sequence-packing',
+ 'compute-logprobs', 'compute-old-logprobs', 'compute-ref-logprobs',
+ 'pack-logprobs', 'align-inference-logprobs',
+ 'create-dataloader', 'log-wandb-tb',
+ 'offload-optimizer-before-inference', 'onload-optimizer-after-inference',
+ 'suspend-engine',
+ ])
# Calculate batch size.
batch_size = args.micro_batch_size * args.data_parallel_size * get_num_microbatches()
@@ -2114,11 +2117,126 @@ def training_log(
)
if args.log_throughput:
log_string += f' throughput per GPU (TFLOP/s/GPU): {throughput:.1f} |'
+
+ tokens_this_iter = batch_size * args.seq_length
+
+ # Compute and log MFU (Model FLOPs Utilization)
+ if not hasattr(args, '_gpu_peak_tflops'):
+ try:
+ from megatron.training.gpu_peak_flops import get_gpu_peak_tflops
+ args._gpu_peak_tflops = get_gpu_peak_tflops()
+ except Exception:
+ args._gpu_peak_tflops = 0.0
+
+ training_mfu = 0.0
+ inference_mfu = 0.0
+ total_mfu = 0.0
+ has_tracker = False
+ iter_inference_tokens = 0
+ iter_inference_time = 0.0
+ iter_logprob_time = 0.0
+ training_only_time = elapsed_time_per_iteration
+ training_flops = 0.0
+ iter_inference_flops = 0.0
+ effective_tokens = tokens_this_iter
+
+ # Read compute-logprobs time from the existing Megatron timer
+ try:
+ iter_logprob_time = (
+ timers('compute-logprobs').elapsed(reset=False, barrier=False)
+ / total_iterations
+ )
+ except Exception:
+ pass
+
+ if args._gpu_peak_tflops > 0:
+ try:
+ from megatron.training.mfu_tracker import get_mfu_tracker
+ tracker = get_mfu_tracker()
+ training_flops = num_floating_point_operations(args, batch_size)
+ iter_inference_time = tracker.get_iter_inference_time()
+ iter_inference_flops = tracker.get_iter_inference_flops()
+ iter_inference_tokens = tracker.get_iter_inference_tokens()
+ real_training_tokens = tracker.get_iter_real_training_tokens()
+ if real_training_tokens > 0:
+ effective_tokens = real_training_tokens
+ training_only_time = max(
+ elapsed_time_per_iteration - iter_inference_time - iter_logprob_time, 1e-6
+ )
+ tracker.add_training_flops(
+ training_flops, training_only_time, tokens=effective_tokens
+ )
+ tracker.reset_iter()
+ has_tracker = True
+ except Exception:
+ has_tracker = False
+
+ training_mfu = throughput / args._gpu_peak_tflops * 100.0
+
+ ws = args.world_size
+
+ # Per-iteration toks/s/GPU breakdown (uses real tokens when seq packing is active)
+ train_tps = effective_tokens / (training_only_time * ws) if training_only_time > 0 else 0.0
+ inf_tps = iter_inference_tokens / (iter_inference_time * ws) if iter_inference_time > 0 else 0.0
+ total_tps = (effective_tokens + iter_inference_tokens) / (elapsed_time_per_iteration * ws)
+ e2e_tps = effective_tokens / (elapsed_time_per_iteration * ws)
+
+ if has_tracker:
+ log_string += (
+ f' toks/s/GPU: train {train_tps:.0f}'
+ f', infer {inf_tps:.0f}'
+ f', total {total_tps:.0f}'
+ f', e2e {e2e_tps:.0f} |'
+ )
+
+ # Per-iteration MFU breakdown
+ if args._gpu_peak_tflops > 0:
+ log_string += f' MFU: train {training_mfu:.1f}%'
+ if has_tracker:
+ if iter_inference_time > 0:
+ inference_mfu = (
+ iter_inference_flops / (iter_inference_time * ws)
+ / 1e12 / args._gpu_peak_tflops * 100.0
+ )
+ total_mfu = (
+ (training_flops + iter_inference_flops)
+ / (elapsed_time_per_iteration * ws)
+ / 1e12 / args._gpu_peak_tflops * 100.0
+ )
+ log_string += (
+ f', infer {inference_mfu:.1f}%'
+ f', total {total_mfu:.1f}%'
+ )
+ log_string += ' |'
+
if args.log_timers_to_tensorboard:
if writer:
writer.add_scalar('throughput', throughput, iteration)
+ writer.add_scalar('toks_per_sec_per_gpu/e2e', e2e_tps, iteration)
+ if has_tracker:
+ writer.add_scalar('toks_per_sec_per_gpu/training', train_tps, iteration)
+ writer.add_scalar('toks_per_sec_per_gpu/inference', inf_tps, iteration)
+ writer.add_scalar('toks_per_sec_per_gpu/total', total_tps, iteration)
+ if args._gpu_peak_tflops > 0:
+ writer.add_scalar('mfu/training_percent', training_mfu, iteration)
+ if has_tracker:
+ writer.add_scalar('mfu/inference_percent', inference_mfu, iteration)
+ writer.add_scalar('mfu/total_percent', total_mfu, iteration)
if wandb_writer:
- wandb_writer.log({'throughput': throughput}, iteration)
+ wandb_log = {
+ 'throughput': throughput,
+ 'toks_per_sec_per_gpu/e2e': e2e_tps,
+ }
+ if has_tracker:
+ wandb_log['toks_per_sec_per_gpu/training'] = train_tps
+ wandb_log['toks_per_sec_per_gpu/inference'] = inf_tps
+ wandb_log['toks_per_sec_per_gpu/total'] = total_tps
+ if args._gpu_peak_tflops > 0:
+ wandb_log['mfu/training_percent'] = training_mfu
+ if has_tracker:
+ wandb_log['mfu/inference_percent'] = inference_mfu
+ wandb_log['mfu/total_percent'] = total_mfu
+ wandb_writer.log(wandb_log, iteration)
if args.log_energy:
energy = (energy_monitor.lap() / total_iterations) / args.world_size
power = energy / elapsed_time_per_iteration
diff --git a/tests/unit_tests/rl/test_sequence_packing_utils.py b/tests/unit_tests/rl/test_sequence_packing_utils.py
index 06e63adf217..75a7981457d 100644
--- a/tests/unit_tests/rl/test_sequence_packing_utils.py
+++ b/tests/unit_tests/rl/test_sequence_packing_utils.py
@@ -98,13 +98,12 @@ def test_sequence_packing_basic():
rewards = torch.tensor([1.0, 2.0, 3.0, 4.0])
sequences_tensor = torch.stack(sequences)
- packed_trajs, packed_position_ids, packed_attention_mask, packed_loss_mask, packing_info = (
+ packed_trajs, packed_position_ids, packed_loss_mask, packing_info = (
packer.pack_sequences(sequences_tensor, generation_masks)
)
assert packed_trajs is not None
assert packed_position_ids is not None
- assert packed_attention_mask is not None
assert packed_loss_mask is not None
assert packing_info is not None
@@ -140,7 +139,7 @@ def test_sequence_packing_with_generation_masks():
)
padded_sequences_tensor = torch.stack(padded_sequences)
- packed_trajs, packed_position_ids, packed_attention_mask, packed_loss_mask, packing_info = (
+ packed_trajs, packed_position_ids, packed_loss_mask, packing_info = (
packer.pack_sequences(padded_sequences_tensor, generation_masks)
)
@@ -162,16 +161,14 @@ def test_sequence_packing_empty_bins():
)
packed_position_ids = torch.tensor([[0, 1, 2, 3, 0, 0, 0, 0]])
packed_loss_mask = torch.tensor([[1, 1, 1, 1, 0, 0, 0, 0]], dtype=torch.float)
- packed_attention_mask = torch.ones(1, bin_size, bin_size)
- empty_trajs, empty_position_ids, empty_loss_mask, empty_attention_mask, empty_packing_info = (
+ empty_trajs, empty_position_ids, empty_loss_mask, empty_packing_info = (
sequence_packing_utils.create_empty_bins(
num_empty_bins=num_empty_bins,
bin_size=bin_size,
packed_trajs=packed_trajs,
packed_position_ids=packed_position_ids,
packed_loss_mask=packed_loss_mask,
- packed_attention_mask=packed_attention_mask,
tokenizer=tokenizer,
)
)
@@ -220,7 +217,7 @@ def test_sequence_packing_integration():
]
sequences_tensor = torch.stack(sequences)
- packed_trajs, packed_position_ids, packed_attention_mask, packed_loss_mask, packing_info = (
+ packed_trajs, packed_position_ids, packed_loss_mask, packing_info = (
packer.pack_sequences(sequences_tensor, generation_masks)
)