diff --git a/examples/rl/environments/math/math_agent.py b/examples/rl/environments/math/math_agent.py index 027ef242285..e7e113e2bf5 100644 --- a/examples/rl/environments/math/math_agent.py +++ b/examples/rl/environments/math/math_agent.py @@ -117,6 +117,24 @@ def compute_score(self, response: str, golden: dict, golden_key: str = "answer") # Did not format the answer correctly return self.negative_reward + # def make_prefix(self, problem_key: str = "problem", **kwargs) -> str: + # """Take a string math problem and return the prompt. Supports requesting tagged or boxed answers. Supports chat mode prompts.""" + # if self.answer_format == "boxed": + # answer_format = "Please reason step by step and provide your answer between \\boxed{} tags, for example \\boxed{20\\sqrt{3}}." + # elif self.answer_format == "tagged": + # answer_format = "Please reason step by step and provide your answer between tags, for example 20\\sqrt{3} . Do not include an = sign." + # else: + # raise ValueError(f"Invalid answer format: {self.answer_format}") + + # if self.chat_mode: + # prefix = f"""{kwargs[problem_key]}\n{answer_format}""" + # else: + # prefix = f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. + # The question will be a word math problem. Show your work in tags. + # {answer_format} + # User: {kwargs[problem_key]} + # {self.assistant_suffix}""" + # return prefix def make_prefix(self, problem_key: str = "problem", **kwargs) -> str: """Take a string math problem and return the prompt. Supports requesting tagged or boxed answers. 
Supports chat mode prompts.""" if self.answer_format == "boxed": @@ -126,6 +144,11 @@ def make_prefix(self, problem_key: str = "problem", **kwargs) -> str: else: raise ValueError(f"Invalid answer format: {self.answer_format}") - prefix = f"""{kwargs[problem_key]}\n{answer_format}""" + # prefix = f"""{kwargs[problem_key]}\n{answer_format}""" + prefix = f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. + The question will be a word math problem. Show your work in tags. + {answer_format} + User: {kwargs[problem_key]} + Assistant:""" return prefix diff --git a/examples/rl/model_configs/common.sh b/examples/rl/model_configs/common.sh index 6198708c1bf..a84029b9b6c 100644 --- a/examples/rl/model_configs/common.sh +++ b/examples/rl/model_configs/common.sh @@ -16,14 +16,11 @@ COMMON_OPTIONS="\ --transformer-impl transformer_engine \ --${PRECISION:-bf16} \ --te-rng-tracker \ - --rl-offload-optimizer-during-inference \ - --inference-dynamic-batching-buffer-size-gb 20 \ --data-parallel-random-init \ --attention-backend flash \ --timing-log-level 1 \ --log-timers-to-tensorboard \ --save-retain-interval 160 \ - --inference-dynamic-batching-num-cuda-graphs 1 \ --inference-dynamic-batching-unified-memory-level 1 \ --adam-beta1 0.9 \ --adam-beta2 ${ADAM_BETA2:-0.95} \ diff --git a/examples/rl/model_configs/nemotron6_3b_moe.sh b/examples/rl/model_configs/nemotron6_3b_moe.sh index 19891ad7b8b..87444552ddd 100644 --- a/examples/rl/model_configs/nemotron6_3b_moe.sh +++ b/examples/rl/model_configs/nemotron6_3b_moe.sh @@ -9,39 +9,21 @@ echo "Using Nemotron6 3B MOE model checkpoint" SCRIPT_PATH="${BASH_SOURCE[0]}" source $(dirname $SCRIPT_PATH)/common.sh -# In all cases, one can override those values. -# However, running without envs will give you some -# good perf out of the box for established envs. 
-if [ "$(basename "$ENV_CONFIG")" = "dapo.yaml" ]; then - echo "Using DAPO environment config" - GRPO_CLAMP_EPS_LOWER=${GRPO_CLAMP_EPS_LOWER:-0.2} - GRPO_CLAMP_EPS_UPPER=${GRPO_CLAMP_EPS_UPPER:-0.28} - MAX_INFERENCE_BS=${MAX_INFERENCE_BS:-32} - GRPO_GROUP_SIZE=${GRPO_GROUP_SIZE:-16} - GRPO_PROMPTS_PER_STEP=${GRPO_PROMPTS_PER_STEP:-64} - GRPO_ITERATIONS=${GRPO_ITERATIONS:-1} - GRPO_KL_BETA=${GRPO_KL_BETA:-"0.0"} - TRAINING_BATCH_SIZE=${TRAINING_BATCH_SIZE:-1024} - MICRO_BATCH_SIZE=${MICRO_BATCH_SIZE:-1} - MAX_SEQ_LENGTH=${MAX_SEQ_LENGTH:-11999} - EXIT_INTERVAL=${EXIT_INTERVAL:-20} - CHKPT_SAVE_INTERVAL=${CHKPT_SAVE_INTERVAL:-20} -else - # Some default values if config is unsupported. - echo "Undected environment config, using default values" - GRPO_CLAMP_EPS_LOWER=${GRPO_CLAMP_EPS_LOWER:-0.2} - GRPO_CLAMP_EPS_UPPER=${GRPO_CLAMP_EPS_UPPER:-0.28} - MAX_INFERENCE_BS=${MAX_INFERENCE_BS:-64} - GRPO_GROUP_SIZE=${GRPO_GROUP_SIZE:-2} - GRPO_PROMPTS_PER_STEP=${GRPO_PROMPTS_PER_STEP:-16} - GRPO_ITERATIONS=${GRPO_ITERATIONS:-1} - GRPO_KL_BETA=${GRPO_KL_BETA:-"0.0"} - TRAINING_BATCH_SIZE=${TRAINING_BATCH_SIZE:-32} - MICRO_BATCH_SIZE=${MICRO_BATCH_SIZE:-1} - MAX_SEQ_LENGTH=${MAX_SEQ_LENGTH:-1024} - EXIT_INTERVAL=${EXIT_INTERVAL:-20} - CHKPT_SAVE_INTERVAL=${CHKPT_SAVE_INTERVAL:-20} -fi + +echo "Undetected environment config, using default values" +GRPO_CLAMP_EPS_LOWER=${GRPO_CLAMP_EPS_LOWER:-0.2} +GRPO_CLAMP_EPS_UPPER=${GRPO_CLAMP_EPS_UPPER:-0.28} +MAX_INFERENCE_BS=${MAX_INFERENCE_BS:-64} +GRPO_GROUP_SIZE=${GRPO_GROUP_SIZE:-16} +GRPO_PROMPTS_PER_STEP=${GRPO_PROMPTS_PER_STEP:-64} +GRPO_ITERATIONS=${GRPO_ITERATIONS:-1} +GRPO_KL_BETA=${GRPO_KL_BETA:-"0.0"} +TRAINING_BATCH_SIZE=${TRAINING_BATCH_SIZE:-1024} +MICRO_BATCH_SIZE=${MICRO_BATCH_SIZE:-1} +MAX_SEQ_LENGTH=${MAX_SEQ_LENGTH:-8192} +EXIT_INTERVAL=${EXIT_INTERVAL:-15} +CHKPT_SAVE_INTERVAL=${CHKPT_SAVE_INTERVAL:-20} + ENV_DEPENDENT="\ --micro-batch-size $MICRO_BATCH_SIZE \ @@ -56,14 +38,19 @@ ENV_DEPENDENT="\ MODEL_OPTIONS="\ 
--rl-skip-bos-token \ - --no-rl-use-sequence-packing \ + --rl-use-sequence-packing \ --rl-partial-rollouts \ --rl-offload-optimizer-during-inference \ --moe-pad-experts-for-cuda-graph-inference \ - --inference-dynamic-batching-max-tokens 8192 \ + --inference-dynamic-batching-num-cuda-graphs 4 \ --inference-dynamic-batching-max-requests 128 \ - --inference-dynamic-batching-num-cuda-graphs 2 \ - --decode-only-cuda-graphs \ + --inference-dynamic-batching-paused-buffer-size-gb 5 \ + --inference-dynamic-batching-buffer-size-gb 5 \ + --inference-dynamic-batching-unified-memory-level 1 \ + --rl-training-cuda-graphs \ + --empty-unused-memory-level 0 \ + --rl-parallel-generation-tasks 128 \ + --inference-dynamic-batching-cuda-graph-mixed-prefill-count 0 \ --cuda-graph-impl local \ --cuda-graph-scope full \ --use-checkpoint-args \ @@ -118,4 +105,13 @@ MODEL_OPTIONS="\ --lr-warmup-samples 640 \ --lr-warmup-init 0.3e-7 \ --no-load-optim \ - --no-load-rng " + --no-load-rng \ + --moe-permute-fusion \ + --eval-interval 1000 \ + --timing-log-level 2 \ + " + # --inference-dynamic-batching-max-tokens 8192 \ +# --rl-training-cuda-graphs \ + # --rl-training-cuda-graphs \ +# --empty-unused-memory-level 0 \ # try with the default value (=2) + # --inference-logging-step-interval 100 \ diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index 7dec7a14bea..ea9e474fde2 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -25,6 +25,7 @@ MaxSequenceLengthOverflowError, TokenOverflowError, ) +from megatron.core.inference.inference_flops import InferenceFLOPsCalculator from megatron.core.inference.data_parallel_inference_coordinator import ( DataParallelInferenceCoordinator, ) @@ -188,6 +189,21 @@ def __init__(self, controller: TextGenerationController, context: DynamicInferen ) self.cuda_graph_impl = model_config.cuda_graph_impl self.cuda_graph_scope = 
model_config.cuda_graph_scope + + # Initialize inference FLOPs calculator and GPU peak for MFU reporting. + self.flops_calculator = None + self.gpu_peak_tflops = 0.0 + self.cumulative_inference_flops = 0.0 + self.cumulative_inference_time = 0.0 + try: + from megatron.training.global_vars import get_args + from megatron.training.gpu_peak_flops import get_gpu_peak_tflops + args = get_args() + self.flops_calculator = InferenceFLOPsCalculator.from_args(args) + self.gpu_peak_tflops = get_gpu_peak_tflops() + except Exception as e: + logging.warning(f"Could not initialize inference FLOPs calculator: {e}") + # Initialize engine. self.reset() @@ -1487,6 +1503,33 @@ async def async_bookkeep( self.socket_for_receiving_requests.send(payload) range_pop() + # Compute inference FLOPs for this step. + step_flops_info = None + if self.flops_calculator is not None: + batch_dims = self.context.batch_dimensions + decode_tokens = batch_dims.decode_req_count if batch_dims else 0 + prefill_reqs = batch_dims.prefill_req_count if batch_dims else 0 + total_tokens = batch_dims.token_count if batch_dims else 0 + prefill_tokens = total_tokens - decode_tokens + + step_flops_info = self.flops_calculator.compute_step_flops( + decode_tokens=decode_tokens, + prefill_tokens=prefill_tokens, + total_tokens=total_tokens, + active_blocks=context_state["total_active_used_blocks"], + active_reqs=context_state["total_request_count"] - context_state["paused_request_count"], + num_prefill_reqs=prefill_reqs, + ) + self.cumulative_inference_flops += step_flops_info['total_flops'] + self.cumulative_inference_time += step_time + try: + from megatron.training.mfu_tracker import get_mfu_tracker + get_mfu_tracker().add_inference_flops( + step_flops_info['total_flops'], step_time, tokens=total_tokens + ) + except Exception: + pass + # Log KV cache utilization stats to W&B if context_state["kv_stats"] is not None: # Prepare metrics dictionary with all stats @@ -1499,6 +1542,29 @@ async def async_bookkeep( 
'inference/waiting_queue_len': int(len(self.waiting_request_ids)), 'inference/total_requests_dict_size': int(len(self.requests)), } + + batch_dims = self.context.batch_dimensions + total_tokens = batch_dims.token_count if batch_dims else 0 + if step_time > 0 and total_tokens > 0: + metrics['inference/tokens_per_sec_per_gpu'] = float(total_tokens / step_time) + + if step_flops_info is not None: + step_tflops = step_flops_info['total_flops'] / 1e12 + step_throughput = step_tflops / step_time if step_time > 0 else 0 + metrics['inference/step_flops_tflop'] = float(step_tflops) + metrics['inference/throughput_tflops_per_gpu'] = float(step_throughput) + metrics['inference/t_avg'] = float(step_flops_info['t_avg']) + metrics['inference/cumulative_flops_tflop'] = float(self.cumulative_inference_flops / 1e12) + if self.gpu_peak_tflops > 0: + mfu = step_throughput / self.gpu_peak_tflops * 100.0 + cumulative_throughput = ( + (self.cumulative_inference_flops / 1e12) / self.cumulative_inference_time + if self.cumulative_inference_time > 0 else 0 + ) + cumulative_mfu = cumulative_throughput / self.gpu_peak_tflops * 100.0 + metrics['inference/mfu_percent'] = float(mfu) + metrics['inference/cumulative_mfu_percent'] = float(cumulative_mfu) + # Add KV stats with inference/ prefix # Convert utilization metrics from 0-1 range to 0-100 percentage range for better visualization for key, value in context_state["kv_stats"].items(): @@ -1557,6 +1623,18 @@ async def async_bookkeep( mem["reserved_bytes.all.current"] / (1024**3), ) ) + batch_dims = self.context.batch_dimensions + total_tokens = batch_dims.token_count if batch_dims else 0 + if step_time > 0 and total_tokens > 0: + toks_per_sec_per_gpu = total_tokens / step_time + output_str += f" toks/s/GPU: {toks_per_sec_per_gpu:.0f}," + if step_flops_info is not None: + step_tflops = step_flops_info['total_flops'] / 1e12 + step_throughput = step_tflops / step_time if step_time > 0 else 0 + output_str += f" {step_throughput:.1f} TFLOP/s/GPU" + 
if self.gpu_peak_tflops > 0: + mfu = step_throughput / self.gpu_peak_tflops * 100.0 + output_str += f", MFU: {mfu:.1f}%" if context_state["is_decode_only"]: output_str = f"\033[94m{output_str}\033[0m" logging.info(output_str) diff --git a/megatron/rl/rl_utils.py b/megatron/rl/rl_utils.py index 3d818b5b0cd..1e8a423c098 100644 --- a/megatron/rl/rl_utils.py +++ b/megatron/rl/rl_utils.py @@ -487,7 +487,7 @@ def get_environment_rollouts( nvtx_range = get_nvtx_range() if args.rl_offload_optimizer_during_inference: - with nvtx_range("offload-optimizer-state-and-grad-buffers-during-inference"): + with nvtx_range("offload-optimizer-before-inference", time=True): if not args.rl_training_cuda_graphs: model[0].offload_grad_buffers() else: @@ -496,6 +496,7 @@ ) optimizer.offload_to_cpu() + # If we have separate training and inference models we need to refit weights from the training model to the inference model. has_separate_inference_model = inference_model is not None if has_separate_inference_model: @@ -519,7 +520,7 @@ pg_size = get_pg_size(inference_pg_collection.ep) assert (n_prompts % pg_size == 0), f"{n_prompts=} must be divisible by {pg_size=}" - with nvtx_range("rollout-collection"): + with nvtx_range("rollout-collection", time=True): loop = get_asyncio_loop() with megatron_rl_inference_mode( inference_model, @@ -530,7 +531,7 @@ increment_staleness_on_suspend=True, ) as inference_interface: - with nvtx_range("inference-setup"): + with nvtx_range("inference-setup", time=True): # Asyncronously run inference and rollout collection rollout_generator = get_rollout_generator( args, inference_interface, n_prompts, samples_per_group @@ -538,7 +539,7 @@ # NOTE(jbarker): we need to double check this when using PP>1 rank = torch.distributed.get_rank() - with nvtx_range("collect-rollouts"): + with nvtx_range("collect-rollouts", time=True): if rank == 0: 
log_single_rank( logger, @@ -563,14 +564,14 @@ def get_environment_rollouts( # Just set up space to collect the rollouts rollouts = [[None for _ in range(samples_per_group)] for _ in range(n_prompts)] - with nvtx_range("sync-rollouts"): + with nvtx_range("sync-rollouts", time=True): # Wait for Rollouts to be collected # TODO(jbarker): double check why this isn't causing rank 0 memory allocations torch.distributed.broadcast_object_list(rollouts, src=0) logger.debug(f"Got rollouts on rank {rank}") if args.rl_offload_optimizer_during_inference: - with nvtx_range("restore-optimizer-state-and-grad-buffers-after-inference"): + with nvtx_range("onload-optimizer-after-inference", time=True): model[0].restore_grad_buffers() optimizer.restore_from_cpu() @@ -1282,8 +1283,8 @@ def prepare_data_for_update( model = model[0] dtype = torch.bfloat16 if args.bf16 else (torch.float16 if args.fp16 else torch.float32) - with nvtx_range("prepare-data-for-update"): - with nvtx_range("compute-group-stats"): + with nvtx_range("prepare-data-for-update", time=True): + with nvtx_range("compute-group-stats", time=True): group_stats = compute_group_stats(rollouts, tokenizer, args.seq_length) # TODO(vitalyk): why do we need global_advantages here? go inside packing advantages = global_advantages = torch.tensor(group_stats.advantages, dtype=dtype).cuda() @@ -1323,7 +1324,7 @@ def prepare_data_for_update( # First we calculate them on a global level and then we split and recalculate on a local level. # Sequence packing and reporting needs it global but non-packing wants it local. 
- with nvtx_range("prepare_trajectories"): + with nvtx_range("prepare-trajectories", time=True): trajs, generation_masks, inference_logprobs = prepare_trajectories( rollouts, tokenizer, args.seq_length, sequence_packing, args.rl_skip_bos_token ) @@ -1331,7 +1332,7 @@ def prepare_data_for_update( packing_context = None # Build trajectories based on sequence packing or standard processing if sequence_packing: - with nvtx_range("sequence_packing", time=True): + with nvtx_range("sequence-packing", time=True): runtime_state.packing_context = packing_context = pack_all_trajectories( trajs, generation_masks, @@ -1351,7 +1352,7 @@ def prepare_data_for_update( logprobs_batch_size = 1 else: # Always compute standard masks for the original data (we'll need them later) - with nvtx_range("get_ltor_masks_and_position_ids"): + with nvtx_range("get-ltor-masks-and-position-ids", time=True): _, original_loss_mask, original_position_ids = get_ltor_masks_and_position_ids( trajs, tokenizer.eod, @@ -1370,7 +1371,7 @@ def prepare_data_for_update( ) logprobs_batch_size = args.micro_batch_size - with torch.no_grad(), nvtx_range("compute_logprobs", time=True): + with torch.no_grad(), nvtx_range("compute-logprobs", time=True): # Before we can update the model, we need to get the logprobs for the \pi_{old} model. 
# Wrap forward_backward_func for Full iteration CUDA graph @@ -1388,7 +1389,7 @@ def prepare_data_for_update( pg_collection = get_attr_wrapped_model(model, "pg_collection") pp_group = pg_collection.pp - with torch.no_grad(), nvtx_range("compute_old_logprobs", time=True): + with torch.no_grad(), nvtx_range("compute-old-logprobs", time=True): old_logprobs = compute_logprobs_batch( model=model, data_loader=data_loader, @@ -1403,7 +1404,7 @@ def prepare_data_for_update( is_correction=args.rl_inference_logprobs_is_correction, ) - with torch.no_grad(), nvtx_range("compute_ref_logprobs", time=True): + with torch.no_grad(), nvtx_range("compute-ref-logprobs", time=True): # We need to load the ref model state dict and compute the logprobs for the ref model cur_st_dict = { k: (v.cpu() if v is not None else v) for k, v in model.state_dict().items() @@ -1432,7 +1433,7 @@ def prepare_data_for_update( if sequence_packing: - with nvtx_range("pack_logprobs", time=True): + with nvtx_range("pack-logprobs", time=True): # Store logprobs on gpu in packing context # Since PackingContext is a dataclass, we add these as new attributes packing_context.old_logprobs = old_logprobs.cuda() @@ -1460,7 +1461,7 @@ def prepare_data_for_update( packing_context.packed_inference_logprobs = packed_inference_logprobs.cuda() # Only mark as having inference logprobs for IS correction if enabled packing_context.has_inference_logprobs = args.rl_inference_logprobs_is_correction - with nvtx_range("create_dataloader"): + with nvtx_range("create-dataloader", time=True): # @vitalyk: This function also reconfigures the data loader to count the # global_batch_size in the bins frame of reference. 
# I think it will be a better design if we split the data loader creating and logic @@ -1477,7 +1478,7 @@ def prepare_data_for_update( ) loader = get_microbatch_dataloader(len(packing_context.packed_trajs), args.micro_batch_size) else: - with nvtx_range("align_inference_logprobs", time=True): + with nvtx_range("align-inference-logprobs", time=True): if inference_logprobs is not None: inference_logprobs = align_unpacked_inference_logprobs( inference_logprobs=inference_logprobs, @@ -1490,7 +1491,7 @@ def prepare_data_for_update( # Nullify logprobs if not used in IS correction, if not args.rl_inference_logprobs_is_correction: inference_logprobs = None - with nvtx_range("create_dataloader"): + with nvtx_range("create-dataloader", time=True): # Because of multiturn, our batch sizes for non-sequence packed trajectories are not fixed anymore. # As in sequence packing above, we need to reconfigure it too. runtime_state.packing_context = None @@ -1851,7 +1852,7 @@ def megatron_rl_inference_mode( with torch.no_grad(): if offload_optimizer_during_inference: - with nvtx_range("offload-optimizer-state-and-grad-buffers-before-inference"): + with nvtx_range("offload-optimizer-before-inference", time=True): if not args.rl_training_cuda_graphs: # Offload grad buffers from the training model (if separate inference model is used) # or from the inference model (if they're the same model) @@ -1872,7 +1873,7 @@ def megatron_rl_inference_mode( logger.debug(f"[{dist.get_rank()}] Entered inference mode") yield inference_interface - with nvtx_range("suspend-engine"): + with nvtx_range("suspend-engine", time=True): loop.run_until_complete(inference_interface.suspend()) if increment_staleness_on_suspend: inference_interface.increment_staleness() @@ -1886,7 +1887,7 @@ def megatron_rl_inference_mode( _maybe_prefetch_separate_inference_model_weights(model_core, to_cpu=True) if offload_optimizer_during_inference: - with nvtx_range("onload-optimizer-state-and-grad-buffers-after-inference"): + with 
nvtx_range("onload-optimizer-after-inference", time=True): # Restore grad buffers to the training model (if separate inference model is used) # or to the inference model (if they're the same model) model_for_grad_offload = training_model if training_model is not None else model diff --git a/megatron/rl/sequence_packing_utils.py b/megatron/rl/sequence_packing_utils.py index b641ecd85d0..ddbcdeffa2f 100644 --- a/megatron/rl/sequence_packing_utils.py +++ b/megatron/rl/sequence_packing_utils.py @@ -51,7 +51,6 @@ class PackingContext: original_trajs: All trajectories before packing packed_trajs: Packed trajectories tensor [num_bins, bin_size] packed_position_ids: Position IDs for packed sequences [num_bins, bin_size] - packed_attention_mask: Attention mask for packed sequences [num_bins, 1, bin_size, bin_size] packed_loss_mask: Loss mask for packed sequences [num_bins, bin_size] original_inference_logprobs: Inference logprobs for all sequences before packing (optional) bin_advantages: List of advantage tensors for each bin @@ -64,7 +63,6 @@ class PackingContext: original_trajs: torch.Tensor packed_trajs: torch.Tensor packed_position_ids: torch.Tensor - packed_attention_mask: torch.Tensor packed_loss_mask: torch.Tensor original_inference_logprobs: Optional[torch.Tensor] = None bin_advantages: List[torch.Tensor] = field(default_factory=list) @@ -314,9 +312,8 @@ def create_empty_bins( packed_trajs : torch.Tensor, packed_position_ids : torch.Tensor, packed_loss_mask : torch.Tensor, - packed_attention_mask : torch.Tensor, tokenizer, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[Dict[str, Any]]]: +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[Dict[str, Any]]]: """Create empty bins for padding to ensure all ranks have the same number of bins. 
Args: @@ -325,11 +322,10 @@ def create_empty_bins( packed_trajs: Packed trajectories tensor (for dtype/device reference) packed_position_ids: Packed position IDs tensor (for dtype/device reference) packed_loss_mask: Packed loss mask tensor (for dtype/device reference) - packed_attention_mask: Packed attention mask tensor (can be None) tokenizer: Tokenizer for pad token Returns: - Tuple of (empty_trajs, empty_position_ids, empty_loss_mask, empty_attention_mask, empty_packing_info_entries) + Tuple of (empty_trajs, empty_position_ids, empty_loss_mask, empty_packing_info_entries) """ device = packed_trajs.device @@ -337,7 +333,6 @@ def create_empty_bins( empty_bins = [] empty_position_ids_list = [] empty_loss_mask_list = [] - empty_attention_mask_list = [] empty_packing_info_entries = [] for i in range(num_empty_bins): @@ -355,14 +350,6 @@ def create_empty_bins( empty_loss = torch.zeros(1, bin_size, dtype=packed_loss_mask.dtype, device=device) empty_loss_mask_list.append(empty_loss) - # Zero attention mask if needed - if packed_attention_mask is not None: - # Attention mask is always 4D: [num_bins, 1, bin_size, bin_size] - empty_attn = torch.zeros( - 1, 1, bin_size, bin_size, dtype=packed_attention_mask.dtype, device=device - ) - empty_attention_mask_list.append(empty_attn) - # Empty packing info entries empty_packing_info_entries.append( { @@ -376,22 +363,15 @@ def create_empty_bins( empty_trajs = torch.cat(empty_bins, dim=0) empty_position_ids = torch.cat(empty_position_ids_list, dim=0) empty_loss_mask = torch.cat(empty_loss_mask_list, dim=0) - empty_attention_mask = ( - torch.cat(empty_attention_mask_list, dim=0) - if packed_attention_mask is not None - else None - ) else: empty_trajs = None empty_position_ids = None empty_loss_mask = None - empty_attention_mask = None return ( empty_trajs, empty_position_ids, empty_loss_mask, - empty_attention_mask, empty_packing_info_entries, ) @@ -706,9 +686,6 @@ def pack_sequences( position_ids = torch.zeros( (num_bins, 
self.bin_size), dtype=torch.long, device=device, requires_grad=False ) - attention_mask = torch.zeros( - (num_bins, 1, self.bin_size, self.bin_size), dtype=torch.bool, device=device - ) loss_mask = torch.zeros((num_bins, self.bin_size), dtype=torch.float, device=device) # Track packing information for unpacking later @@ -739,12 +716,6 @@ def pack_sequences( len(seq), device=device, requires_grad=False ) - # Causal attention mask within each sequence - seq_len = end - start - attention_mask[bin_idx, 0, start:end, start:end] = torch.tril( - torch.ones(seq_len, seq_len, dtype=torch.bool, device=device) - ) - # Loss mask (excluding padding) loss_mask[bin_idx, start:end] = 1.0 @@ -759,12 +730,6 @@ def pack_sequences( seq_starts.append(current_pos) seq_starts_dict[bin_idx] = seq_starts - # Note: We'll store the actual padded length later when we know it - # (it depends on the original trajectories passed to pack_sequences) - - # Invert attention mask, before inversion: (True = attend, False = mask) - attention_mask.bitwise_not_() - # Create the PackingInfo dataclass packing_info = PackingInfo( bin_seq_indices=bin_seq_indices, @@ -793,15 +758,14 @@ def pack_sequences( ) log_single_rank(logger, logging.DEBUG, f" - First 20 bins: {seq_per_bin[:20]}") - return packed_sequences, position_ids, attention_mask, loss_mask, packing_info + return packed_sequences, position_ids, loss_mask, packing_info def distribute_packed_bins( packed_trajs: torch.Tensor, packed_position_ids: torch.Tensor, - packed_attention_mask: torch.Tensor, packed_loss_mask: torch.Tensor, packing_info: PackingInfo, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, PackingInfo]: +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, PackingInfo]: """Distribute packed bins across the data parallel ranks.""" rank = mpu.get_data_parallel_rank() world_size = mpu.get_data_parallel_world_size() @@ -838,7 +802,6 @@ def distribute_packed_bins( # Extract this rank's bins my_packed_trajs = [] 
my_packed_position_ids = [] - my_packed_attention_mask = [] my_packed_loss_mask = [] my_bin_seq_indices = [] my_seq_starts = {} @@ -848,8 +811,6 @@ def distribute_packed_bins( for new_idx, old_idx in enumerate(my_bin_indices): my_packed_trajs.append(packed_trajs[old_idx]) my_packed_position_ids.append(packed_position_ids[old_idx]) - if packed_attention_mask is not None: - my_packed_attention_mask.append(packed_attention_mask[old_idx]) my_packed_loss_mask.append(packed_loss_mask[old_idx]) my_bin_seq_indices.append(packing_info.bin_seq_indices[old_idx]) my_seq_starts[new_idx] = packing_info.seq_starts[old_idx] @@ -875,9 +836,6 @@ def distribute_packed_bins( device=packed_position_ids.device, ) ) - packed_attention_mask = ( - torch.stack(my_packed_attention_mask) if my_packed_attention_mask else None - ) packed_loss_mask = ( torch.stack(my_packed_loss_mask) if my_packed_loss_mask @@ -935,7 +893,6 @@ def distribute_packed_bins( empty_trajs, empty_position_ids, empty_loss_mask, - empty_attention_mask, empty_packing_entries, ) = create_empty_bins( num_empty_bins, @@ -943,7 +900,6 @@ def distribute_packed_bins( packed_trajs, packed_position_ids, packed_loss_mask, - packed_attention_mask, tokenizer, ) @@ -954,18 +910,13 @@ def distribute_packed_bins( ) packed_loss_mask = torch.cat([packed_loss_mask, empty_loss_mask], dim=0) - if packed_attention_mask is not None and empty_attention_mask is not None: - packed_attention_mask = torch.cat( - [packed_attention_mask, empty_attention_mask], dim=0 - ) - # Add empty entries to packing_info for i, entry in enumerate(empty_packing_entries): bin_idx = current_bins + i new_packing_info.bin_seq_indices.append(entry['bin_seq_indices']) new_packing_info.seq_starts[bin_idx] = entry['seq_starts'] - return packed_trajs, packed_position_ids, packed_attention_mask, packed_loss_mask, new_packing_info + return packed_trajs, packed_position_ids, packed_loss_mask, new_packing_info def pack_all_trajectories(trajs, generation_masks, 
inference_logprobs, global_advantages, bin_size, max_sequences_per_bin, packing_algo): @@ -998,7 +949,6 @@ def _gather(data): ( packed_trajs, packed_position_ids, - packed_attention_mask, packed_loss_mask, packing_info, ) = packer.pack_sequences(trajs, generation_masks) @@ -1008,13 +958,11 @@ def _gather(data): ( packed_trajs, packed_position_ids, - packed_attention_mask, packed_loss_mask, packing_info, ) = distribute_packed_bins( packed_trajs, packed_position_ids, - packed_attention_mask, packed_loss_mask, packing_info, ) @@ -1051,7 +999,6 @@ def _gather(data): original_trajs=trajs, packed_trajs=packed_trajs, packed_position_ids=packed_position_ids, - packed_attention_mask=packed_attention_mask, packed_loss_mask=packed_loss_mask, original_inference_logprobs=inference_logprobs, bin_advantages=bin_advantages, diff --git a/megatron/training/training.py b/megatron/training/training.py index 8d0bfff3b3f..3a005347f21 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -1923,14 +1923,17 @@ def training_log( ]) # Add timers from RL loop if needed. 
if getattr(args, 'perform_rl_step', False): - timers_to_log.extend(['rollout-collection', 'inference-setup', 'collect-rollouts', 'postrollout-gc-collect', - 'sync-rollouts', 'prepare-data-for-update', 'compute-group-stats', - 'prepare-trajectories', 'get-ltor-masks-and-position-ids', 'create-logprobs-dataloader', - 'compute-logprobs', 'compute-ref-logprobs', 'compute-prob-stats', - 'prepare-advantages', 'create-dataloader', 'log-wandb-tb', - 'offload-optimizer-before-inference', 'onload-kv-cache-before-inference', - 'wait-for-decode-only', 'build-cuda-graphs', 'suspend-engine', - 'offload-kv-cache-after-inference', 'onload-optimizer-after-inference']) + timers_to_log.extend([ + 'rollout-collection', 'inference-setup', 'collect-rollouts', + 'sync-rollouts', 'prepare-data-for-update', 'compute-group-stats', + 'prepare-trajectories', 'get-ltor-masks-and-position-ids', + 'sequence-packing', + 'compute-logprobs', 'compute-old-logprobs', 'compute-ref-logprobs', + 'pack-logprobs', 'align-inference-logprobs', + 'create-dataloader', 'log-wandb-tb', + 'offload-optimizer-before-inference', 'onload-optimizer-after-inference', + 'suspend-engine', + ]) # Calculate batch size. 
batch_size = args.micro_batch_size * args.data_parallel_size * get_num_microbatches() @@ -2114,11 +2117,126 @@ def training_log( ) if args.log_throughput: log_string += f' throughput per GPU (TFLOP/s/GPU): {throughput:.1f} |' + + tokens_this_iter = batch_size * args.seq_length + + # Compute and log MFU (Model FLOPs Utilization) + if not hasattr(args, '_gpu_peak_tflops'): + try: + from megatron.training.gpu_peak_flops import get_gpu_peak_tflops + args._gpu_peak_tflops = get_gpu_peak_tflops() + except Exception: + args._gpu_peak_tflops = 0.0 + + training_mfu = 0.0 + inference_mfu = 0.0 + total_mfu = 0.0 + has_tracker = False + iter_inference_tokens = 0 + iter_inference_time = 0.0 + iter_logprob_time = 0.0 + training_only_time = elapsed_time_per_iteration + training_flops = 0.0 + iter_inference_flops = 0.0 + effective_tokens = tokens_this_iter + + # Read compute-logprobs time from the existing Megatron timer + try: + iter_logprob_time = ( + timers('compute-logprobs').elapsed(reset=False, barrier=False) + / total_iterations + ) + except Exception: + pass + + if args._gpu_peak_tflops > 0: + try: + from megatron.training.mfu_tracker import get_mfu_tracker + tracker = get_mfu_tracker() + training_flops = num_floating_point_operations(args, batch_size) + iter_inference_time = tracker.get_iter_inference_time() + iter_inference_flops = tracker.get_iter_inference_flops() + iter_inference_tokens = tracker.get_iter_inference_tokens() + real_training_tokens = tracker.get_iter_real_training_tokens() + if real_training_tokens > 0: + effective_tokens = real_training_tokens + training_only_time = max( + elapsed_time_per_iteration - iter_inference_time - iter_logprob_time, 1e-6 + ) + tracker.add_training_flops( + training_flops, training_only_time, tokens=effective_tokens + ) + tracker.reset_iter() + has_tracker = True + except Exception: + has_tracker = False + + training_mfu = throughput / args._gpu_peak_tflops * 100.0 + + ws = args.world_size + + # Per-iteration toks/s/GPU 
breakdown (uses real tokens when seq packing is active) + train_tps = effective_tokens / (training_only_time * ws) if training_only_time > 0 else 0.0 + inf_tps = iter_inference_tokens / (iter_inference_time * ws) if iter_inference_time > 0 else 0.0 + total_tps = (effective_tokens + iter_inference_tokens) / (elapsed_time_per_iteration * ws) + e2e_tps = effective_tokens / (elapsed_time_per_iteration * ws) + + if has_tracker: + log_string += ( + f' toks/s/GPU: train {train_tps:.0f}' + f', infer {inf_tps:.0f}' + f', total {total_tps:.0f}' + f', e2e {e2e_tps:.0f} |' + ) + + # Per-iteration MFU breakdown + if args._gpu_peak_tflops > 0: + log_string += f' MFU: train {training_mfu:.1f}%' + if has_tracker: + if iter_inference_time > 0: + inference_mfu = ( + iter_inference_flops / (iter_inference_time * ws) + / 1e12 / args._gpu_peak_tflops * 100.0 + ) + total_mfu = ( + (training_flops + iter_inference_flops) + / (elapsed_time_per_iteration * ws) + / 1e12 / args._gpu_peak_tflops * 100.0 + ) + log_string += ( + f', infer {inference_mfu:.1f}%' + f', total {total_mfu:.1f}%' + ) + log_string += ' |' + if args.log_timers_to_tensorboard: if writer: writer.add_scalar('throughput', throughput, iteration) + writer.add_scalar('toks_per_sec_per_gpu/e2e', e2e_tps, iteration) + if has_tracker: + writer.add_scalar('toks_per_sec_per_gpu/training', train_tps, iteration) + writer.add_scalar('toks_per_sec_per_gpu/inference', inf_tps, iteration) + writer.add_scalar('toks_per_sec_per_gpu/total', total_tps, iteration) + if args._gpu_peak_tflops > 0: + writer.add_scalar('mfu/training_percent', training_mfu, iteration) + if has_tracker: + writer.add_scalar('mfu/inference_percent', inference_mfu, iteration) + writer.add_scalar('mfu/total_percent', total_mfu, iteration) if wandb_writer: - wandb_writer.log({'throughput': throughput}, iteration) + wandb_log = { + 'throughput': throughput, + 'toks_per_sec_per_gpu/e2e': e2e_tps, + } + if has_tracker: + wandb_log['toks_per_sec_per_gpu/training'] = 
train_tps + wandb_log['toks_per_sec_per_gpu/inference'] = inf_tps + wandb_log['toks_per_sec_per_gpu/total'] = total_tps + if args._gpu_peak_tflops > 0: + wandb_log['mfu/training_percent'] = training_mfu + if has_tracker: + wandb_log['mfu/inference_percent'] = inference_mfu + wandb_log['mfu/total_percent'] = total_mfu + wandb_writer.log(wandb_log, iteration) if args.log_energy: energy = (energy_monitor.lap() / total_iterations) / args.world_size power = energy / elapsed_time_per_iteration diff --git a/tests/unit_tests/rl/test_sequence_packing_utils.py b/tests/unit_tests/rl/test_sequence_packing_utils.py index 06e63adf217..75a7981457d 100644 --- a/tests/unit_tests/rl/test_sequence_packing_utils.py +++ b/tests/unit_tests/rl/test_sequence_packing_utils.py @@ -98,13 +98,12 @@ def test_sequence_packing_basic(): rewards = torch.tensor([1.0, 2.0, 3.0, 4.0]) sequences_tensor = torch.stack(sequences) - packed_trajs, packed_position_ids, packed_attention_mask, packed_loss_mask, packing_info = ( + packed_trajs, packed_position_ids, packed_loss_mask, packing_info = ( packer.pack_sequences(sequences_tensor, generation_masks) ) assert packed_trajs is not None assert packed_position_ids is not None - assert packed_attention_mask is not None assert packed_loss_mask is not None assert packing_info is not None @@ -140,7 +139,7 @@ def test_sequence_packing_with_generation_masks(): ) padded_sequences_tensor = torch.stack(padded_sequences) - packed_trajs, packed_position_ids, packed_attention_mask, packed_loss_mask, packing_info = ( + packed_trajs, packed_position_ids, packed_loss_mask, packing_info = ( packer.pack_sequences(padded_sequences_tensor, generation_masks) ) @@ -162,16 +161,14 @@ def test_sequence_packing_empty_bins(): ) packed_position_ids = torch.tensor([[0, 1, 2, 3, 0, 0, 0, 0]]) packed_loss_mask = torch.tensor([[1, 1, 1, 1, 0, 0, 0, 0]], dtype=torch.float) - packed_attention_mask = torch.ones(1, bin_size, bin_size) - empty_trajs, empty_position_ids, empty_loss_mask, 
empty_attention_mask, empty_packing_info = ( + empty_trajs, empty_position_ids, empty_loss_mask, empty_packing_info = ( sequence_packing_utils.create_empty_bins( num_empty_bins=num_empty_bins, bin_size=bin_size, packed_trajs=packed_trajs, packed_position_ids=packed_position_ids, packed_loss_mask=packed_loss_mask, - packed_attention_mask=packed_attention_mask, tokenizer=tokenizer, ) ) @@ -220,7 +217,7 @@ def test_sequence_packing_integration(): ] sequences_tensor = torch.stack(sequences) - packed_trajs, packed_position_ids, packed_attention_mask, packed_loss_mask, packing_info = ( + packed_trajs, packed_position_ids, packed_loss_mask, packing_info = ( packer.pack_sequences(sequences_tensor, generation_masks) )