diff --git a/examples/mimo/data/mock.py b/examples/mimo/data/mock.py index a1eaa033175..2f02448eac4 100644 --- a/examples/mimo/data/mock.py +++ b/examples/mimo/data/mock.py @@ -125,9 +125,9 @@ def __getitem__(self, idx: int) -> Dict: "loss_mask": loss_mask, "position_ids": position_ids, "modality_inputs": { - "clip_encoder": { - "images": image, - } + "images": { + "clip_encoder": {'x': image}, + } }, } @@ -200,7 +200,7 @@ def get_mock_vlm_dataloader( dataloader = DataLoader( dataset, batch_size=batch_size, - shuffle=True, + shuffle=False, num_workers=num_workers, collate_fn=lambda batch: _collate_fn(batch), ) @@ -218,7 +218,7 @@ def _collate_fn(batch: List[Dict]) -> Dict[str, torch.Tensor]: Returns: Dictionary of batched tensors """ - images = torch.stack([item["images"] for item in batch]) + images = torch.stack([item["modality_inputs"]["images"]["clip_encoder"]['x'] for item in batch]) input_ids = torch.stack([item["input_ids"] for item in batch]) labels = torch.stack([item["labels"] for item in batch]) loss_mask = torch.stack([item["loss_mask"] for item in batch]) @@ -230,8 +230,8 @@ def _collate_fn(batch: List[Dict]) -> Dict[str, torch.Tensor]: "loss_mask": loss_mask, "position_ids": position_ids, "modality_inputs": { - "clip_encoder": { - "images": images, + "images": { + "clip_encoder": {'x': images}, } }, } @@ -291,6 +291,17 @@ def train_valid_test_datasets_provider(train_val_test_num_samples): for batch in dataloader: print("\nBatch from dataloader:") - for key, tensor in batch.items(): - print(f" {key}: {tensor.shape}") + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + print(f" {key}: {value.shape}") + elif isinstance(value, dict): + print(f" {key}: (nested dict)") + for subkey, subvalue in value.items(): + if isinstance(subvalue, torch.Tensor): + print(f" {subkey}: {subvalue.shape}") + elif isinstance(subvalue, dict): + print(f" {subkey}: (nested dict)") + for subsubkey, subsubvalue in subvalue.items(): + if 
isinstance(subsubvalue, torch.Tensor): + print(f" {subsubkey}: {subsubvalue.shape}") break diff --git a/examples/mimo/scripts/run_mock_train.sh b/examples/mimo/scripts/run_mock_train.sh index 2ed71cd5ede..e4aeae62453 100755 --- a/examples/mimo/scripts/run_mock_train.sh +++ b/examples/mimo/scripts/run_mock_train.sh @@ -99,7 +99,7 @@ else ${TOKENIZER_ARGS[@]} \ ${GPT_MODEL_ARGS[@]}" else - torchrun ${DISTRIBUTED_ARGS[@]} examples/mimo/train.py \ + uv run python -m torch.distributed.run ${DISTRIBUTED_ARGS[@]} --log-dir logs/mimo --redirects 3 --tee "0:3" examples/mimo/train.py \ ${TRAINING_ARGS[@]} \ ${MODEL_PARALLEL_ARGS[@]} \ ${EVAL_AND_LOGGING_ARGS[@]} \ diff --git a/examples/mimo/train.py b/examples/mimo/train.py index d1674cb3520..766269208ae 100644 --- a/examples/mimo/train.py +++ b/examples/mimo/train.py @@ -229,8 +229,12 @@ def model_provider( "image_special_token_id": image_special_token_id, "audio_special_token_id": audio_special_token_id, } + elif runtime_args.model_provider == "mock": + kwargs = { + "special_token_id": image_special_token_id, + } else: - raise ValueError(f"Unknown model provider: {runtime_args.model_provider}. Must be one of ['llava_vlm', 'llava_avlm', 'mock]") + raise ValueError(f"Unknown model provider: {runtime_args.model_provider}. 
Must be one of ['llava_vlm', 'llava_avlm', 'mock']") return builder_fn( pre_process, diff --git a/megatron/core/models/mimo/model/base.py b/megatron/core/models/mimo/model/base.py index 2f136a98466..efe56a9ed62 100644 --- a/megatron/core/models/mimo/model/base.py +++ b/megatron/core/models/mimo/model/base.py @@ -2,17 +2,41 @@ import logging import warnings -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, List +import dataclasses import torch +import torch.distributed as dist from megatron.core.models.mimo.config import MimoModelConfig +from megatron.core.pipeline_parallel.utils import is_pp_first_stage, is_pp_last_stage from megatron.core.transformer import MegatronModule from megatron.core.transformer.spec_utils import build_module +from megatron.core.utils import unwrap_model logger = logging.getLogger(__name__) +def _find_pg_collection_in_submodules(submodules) -> Optional[object]: + """Recursively search for pg_collection in nested submodules.""" + if isinstance(submodules, dict): + for nested_spec in submodules.values(): + if isinstance(nested_spec, dict): + # Handle {"clip_encoder": spec} + for spec in nested_spec.values(): + if hasattr(spec, 'params') and spec.params and 'pg_collection' in spec.params: + return spec.params['pg_collection'] + elif isinstance(nested_spec, list): + # Handle [spec1, spec2, ...] + for spec in nested_spec: + if hasattr(spec, 'params') and spec.params and 'pg_collection' in spec.params: + return spec.params['pg_collection'] + elif hasattr(nested_spec, 'params') and nested_spec.params and 'pg_collection' in nested_spec.params: + # Handle direct ModuleSpec + return nested_spec.params['pg_collection'] + return None + + class MimoModel(MegatronModule): """Multimodal In/Out Model supporting arbitrary combinations of modalities. 
@@ -62,7 +86,12 @@ def __init__(self, mimo_config: MimoModelConfig) -> None: self.modality_submodules = torch.nn.ModuleDict() self._initialize_submodules() self._initialize_language_model() - + + # Store input tensors for pipeline parallelism + # These will be set by set_input_tensor() and used in forward() + self.modality_input_tensors = {} + self.language_model_input_tensor = None + def align_embeddings_by_token_positions( self, modality_embeddings: Dict[str, torch.Tensor], # [num_embeddings, hidden_dim] @@ -134,6 +163,65 @@ def align_embeddings_by_token_positions( 0, 1 ).contiguous() # Shape: [seq_length, batch_size, hidden_dim] + + def _should_initialize_module(self, module_spec) -> bool: + """Determine if the current rank should initialize a module based on its process groups.""" + params = module_spec.params or {} + + # Find pg_collection in params or nested submodules + pg_collection = params.get('pg_collection') or ( + _find_pg_collection_in_submodules(module_spec.submodules) + if hasattr(module_spec, 'submodules') and module_spec.submodules else None + ) + + # No pg_collection means initialize on all ranks + if not pg_collection: + return True + + # Check if current rank is in any process group + current_rank = dist.get_rank() + for field in dataclasses.fields(pg_collection): + pg = getattr(pg_collection, field.name, None) + if pg and current_rank in dist.get_process_group_ranks(pg): + return True + + return False + + def _get_submodule_pp_group(self, module_spec): + """Get the pipeline parallel process group for a submodule.""" + params = module_spec.params or {} + + # Find pg_collection in params or nested submodules + pg_collection = params.get('pg_collection') or ( + _find_pg_collection_in_submodules(module_spec.submodules) + if hasattr(module_spec, 'submodules') and module_spec.submodules else None + ) + + if pg_collection and hasattr(pg_collection, 'pp'): + return pg_collection.pp + + return None + + def _is_submodule_pp_first_stage(self, 
modality_name: str) -> bool: + """Check if the current rank is at the first PP stage for a given submodule. + + Args: + modality_name: Name of the modality submodule + + Returns: + True if at first PP stage or if no PP is configured, False otherwise + """ + if modality_name not in self.mimo_config.modality_submodules_spec: + return True # Default to True if submodule not found + + submodule_spec = self.mimo_config.modality_submodules_spec[modality_name] + pp_group = self._get_submodule_pp_group(submodule_spec) + + if pp_group is None: + return True # No PP configured, treat as first stage + + return is_pp_first_stage(pp_group) + def _initialize_submodules(self) -> None: """Initialize modality submodules from the ModuleSpec configurations. @@ -142,9 +230,15 @@ def _initialize_submodules(self) -> None: """ for modality_name, submodule_spec in self.mimo_config.modality_submodules_spec.items(): + # Check if current rank should initialize this submodule + if not self._should_initialize_module(submodule_spec): + logger.debug(f"Rank {dist.get_rank()} skipping {modality_name} submodule initialization") + self.modality_submodules[modality_name] = None + continue + # Get the submodule class submodule_class = submodule_spec.module - logger.debug(f"Building {modality_name} submodule using {submodule_class.__name__}") + logger.debug(f"[Rank - {dist.get_rank()}] Building {modality_name} submodule using {submodule_class.__name__}") # Use from_spec to instantiate the submodule submodule = submodule_class.from_spec(submodule_spec) @@ -152,31 +246,63 @@ def _initialize_submodules(self) -> None: def _initialize_language_model(self) -> None: """Initialize the language model.""" + # Check if current rank should initialize the language model + if not self._should_initialize_module(self.mimo_config.language_model_spec): + logger.debug(f"Rank {dist.get_rank()} skipping language model initialization") + self.language_model = None + return + logger.debug( - f"Building language model using 
{self.mimo_config.language_model_spec.module.__name__}" + f"[Rank - {dist.get_rank()} Building language model using {self.mimo_config.language_model_spec.module.__name__}" ) self.language_model = build_module(self.mimo_config.language_model_spec) - def set_input_tensor(self, input_tensor): + def set_input_tensor(self, input_tensor: List[Dict[str, torch.Tensor]]): """Set input tensor for pipeline parallelism. - - This method is required by Megatron's pipeline parallel mechanism. - It passes the output tensor from the previous stage as input to this stage. - - Args: - input_tensor: Tensor or list of tensors passed between pipeline stages - - Returns: - None """ - # Handle case where input_tensor might be a list or a single tensor - if isinstance(input_tensor, list): - # For simplicity, just use the first tensor - input_tensor = input_tensor[0] - - # Pass the input tensor to the language model if it has a set_input_tensor method - if hasattr(self.language_model, 'set_input_tensor'): - self.language_model.set_input_tensor(input_tensor) + current_rank = dist.get_rank() + + assert isinstance(input_tensor, list), "Input tensor must be a list" + assert len(input_tensor) == 1, "Input tensor must be a list of length 1" + assert isinstance(input_tensor[0], dict), "Input tensor[0] must be a dictionary" + + input_dict = input_tensor[0] + + logger.debug( + f"[Rank {current_rank}][MimoModel][set_input_tensor] Received dict with keys: {list(input_dict.keys())}" + ) + + # Process each modality submodule + for modality_name, submodule in self.modality_submodules.items(): + if modality_name in input_dict: + tensor = input_dict[modality_name] + if isinstance(tensor, list): + tensor = tensor[0] + + self.modality_input_tensors[modality_name] = tensor + logger.debug( + f"[Rank {current_rank}][MimoModel][set_input_tensor][{modality_name}] " + f"Stored input tensor with shape: {tensor.shape}" + ) + + # If the submodule has its own set_input_tensor method, call it + if hasattr(submodule, 
'set_input_tensor'): + submodule.set_input_tensor(tensor) + + if self.language_model is not None and 'language_module' in input_dict: + lm_tensor = input_dict['language_module'] + if isinstance(lm_tensor, list): + lm_tensor = lm_tensor[0] + + self.language_model_input_tensor = lm_tensor + logger.debug( + f"[Rank {current_rank}][MimoModel][set_input_tensor][language_module] " + f"Stored LM intermediate tensor with shape: {lm_tensor.shape}" + ) + + # Pass to language model's set_input_tensor if it exists + if hasattr(unwrap_model(self.language_model), 'set_input_tensor'): + unwrap_model(self.language_model).set_input_tensor(lm_tensor) def get_text_embeddings( self, input_ids: torch.Tensor, position_ids: torch.Tensor, special_token_ids: Dict[str, int] @@ -193,6 +319,9 @@ def get_text_embeddings( Returns: torch.Tensor: Embeddings for text tokens, shape [num_text_tokens, hidden_dim]. """ + if self.language_model is None: + raise RuntimeError(f"Language model not initialized on rank {dist.get_rank()}") + text_mask = torch.ones_like(input_ids, dtype=torch.bool) # [b, s] for special_token_id in special_token_ids.values(): text_mask &= input_ids != special_token_id @@ -204,7 +333,8 @@ def get_text_embeddings( position_ids[batch_idx, seq_idx].unsqueeze(0) if position_ids is not None else None ) - text_embeddings = self.language_model.embedding( + + text_embeddings = unwrap_model(self.language_model).embedding( input_ids=input_ids_text, position_ids=position_ids_text ).squeeze( 1 @@ -238,53 +368,182 @@ def forward( "whisper_encoder": {"input_features": whisper_features} } } + Note: This is only required at the first PP stage of each modality. + At intermediate PP stages, the input comes from set_input_tensor. Returns: tuple: Tuple containing model outputs and loss mask """ + current_rank = dist.get_rank() + # 1. 
Process each modality to get embeddings modality_embeddings = {} for modality_name, submodule in self.modality_submodules.items(): - # Process the modality through its submodule - if ( - modality_inputs - and modality_name in modality_inputs - and modality_inputs[modality_name] is not None - ): - logger.debug(f"Processing {modality_name} modality") - # Get embeddings for this modality - embeddings = submodule.forward(encoder_inputs=modality_inputs[modality_name]) - if embeddings is not None: - # All embeddings are now in the format [num_tokens, hidden_dim] - modality_embeddings[modality_name] = embeddings + # Skip if submodule is None (not initialized on this rank) + if submodule is None: + continue + + # Determine input source based on PP stage + is_first_stage = self._is_submodule_pp_first_stage(modality_name) + + if is_first_stage: + # First PP stage: use modality_inputs provided to forward + if ( + modality_inputs + and modality_name in modality_inputs + and modality_inputs[modality_name] is not None + ): + logger.debug( + f"[Rank {current_rank}][MimoModel][forward][{modality_name}] " + f"First PP stage, using modality_inputs" + ) + # Get embeddings for this modality + embeddings = submodule.forward(encoder_inputs=modality_inputs[modality_name]) + if embeddings is not None: + # All embeddings are now in the format [num_tokens, hidden_dim] + modality_embeddings[modality_name] = embeddings + logger.debug( + f"[Rank {current_rank}][MimoModel][forward][{modality_name}] " + f"Generated embeddings with shape {embeddings.shape}" + ) + else: logger.debug( - f"Generated embeddings for {modality_name} with shape {embeddings.shape}" + f"[Rank {current_rank}][MimoModel][forward][{modality_name}] " + f"First PP stage but no modality inputs provided" ) + else: + # Intermediate PP stage: use stored input tensor from set_input_tensor + if modality_name in self.modality_input_tensors: + input_tensor = self.modality_input_tensors[modality_name] + logger.debug( + f"[Rank 
{current_rank}][MimoModel][forward][{modality_name}] " + f"Intermediate PP stage, using stored input tensor with shape {input_tensor.shape}" + ) + + # Pass through the submodule + # Note: The submodule's forward method should handle tensor inputs appropriately + embeddings = submodule({'_intermediate_tensor': input_tensor}) + if embeddings is not None: + modality_embeddings[modality_name] = embeddings + logger.debug( + f"[Rank {current_rank}][MimoModel][forward][{modality_name}] " + f"Generated embeddings with shape {embeddings.shape}" + ) + else: + logger.warning( + f"[Rank {current_rank}][MimoModel][forward][{modality_name}] " + f"Intermediate PP stage but no input tensor stored from set_input_tensor" + ) + + # Only process if language model is available on this rank + if self.language_model is None: + logger.debug( + f"[Rank {current_rank}][MimoModel][forward] " + f"No language model on this rank, returning modality embeddings with keys: {modality_embeddings.keys()}" + ) + # Return as dictionary for multimodule pipeline compatibility + # The keys should be the modality names (e.g., 'vision', 'audio') + return modality_embeddings, loss_mask + + # Check if we're at the first PP stage of the language model + lm_pp_group = self._get_submodule_pp_group(self.mimo_config.language_model_spec) + is_lm_first_stage = lm_pp_group is None or is_pp_first_stage(lm_pp_group) + + if is_lm_first_stage: + # First LM PP stage: process modality embeddings and text embeddings + logger.debug( + f"[Rank {current_rank}][MimoModel][forward] " + f"Language model first PP stage, processing text and modality embeddings" + ) + + # At LM first stage, modality submodules are None (not initialized on this rank) + # We need to use modality embeddings that were received from encoders + # via set_input_tensor (stored in self.modality_input_tensors) + # These are the final outputs from encoder PP stages + logger.debug( + f"[Rank {current_rank}][MimoModel][forward] " + f"Available stored modality 
embeddings: {list(self.modality_input_tensors.keys())}" + ) + + for modality_name, stored_embedding in self.modality_input_tensors.items(): + if stored_embedding is not None: + logger.debug( + f"[Rank {current_rank}][MimoModel][forward] " + f"Using stored {modality_name} embedding from set_input_tensor with shape {stored_embedding.shape}" + ) + modality_embeddings[modality_name] = stored_embedding + + # Check if we have any modality embeddings (excluding locally computed ones from first loop) + # If special tokens for modalities are present but we don't have embeddings, warn + if not self.modality_input_tensors and not modality_embeddings: + logger.warning( + f"[Rank {current_rank}][MimoModel][forward] " + f"LM first stage but no modality embeddings received via set_input_tensor. " + f"This may be expected if encoders are on the same rank, or an error if using multimodule pipeline." + ) + + # Get text embeddings + text_embeddings = self.get_text_embeddings(input_ids, position_ids, self.special_token_ids) + logger.debug( + f"[Rank {current_rank}][MimoModel][forward] " + f"Generated text embeddings with shape {text_embeddings.shape}" + ) - # Get text embeddings - text_embeddings = self.get_text_embeddings(input_ids, position_ids, self.special_token_ids) - logger.debug(f"Generated text embeddings with shape {text_embeddings.shape}") - - modality_embeddings["text"] = text_embeddings - - # 2. Merge embeddings from different modalities - logger.debug(f"Merging embeddings from {len(modality_embeddings)} modalities") - combined_embeddings = self.align_embeddings_by_token_positions( - modality_embeddings=modality_embeddings, # [num_tokens, hidden_dim] for each modality - input_ids=input_ids, # Pass in batch-first format [b, s] - special_token_ids=self.special_token_ids, - ) # [s, b, h] - logger.debug(f"Combined embeddings shape: {combined_embeddings.shape}") - - # 3. 
Forward pass through language model - lm_output = self.language_model( - input_ids=None, - position_ids=None, - decoder_input=combined_embeddings, - labels=labels, - attention_mask=attention_mask, + modality_embeddings["text"] = text_embeddings + + # Merge embeddings from different modalities + logger.debug( + f"[Rank {current_rank}][MimoModel][forward] " + f"Merging embeddings from {len(modality_embeddings)} modalities" + ) + combined_embeddings = self.align_embeddings_by_token_positions( + modality_embeddings=modality_embeddings, # [num_tokens, hidden_dim] for each modality + input_ids=input_ids, # Pass in batch-first format [b, s] + special_token_ids=self.special_token_ids, + ) # [s, b, h] + logger.debug( + f"[Rank {current_rank}][MimoModel][forward] " + f"Combined embeddings shape: {combined_embeddings.shape}" + ) + + # Forward pass through language model + lm_output = self.language_model( + input_ids=None, + position_ids=None, + decoder_input=combined_embeddings, + labels=labels, + attention_mask=attention_mask, + ) + else: + # Intermediate LM PP stage: use stored input tensor + if self.language_model_input_tensor is not None: + logger.debug( + f"[Rank {current_rank}][MimoModel][forward] " + f"Language model intermediate PP stage, using stored input tensor with shape {self.language_model_input_tensor.shape}" + ) + # The language model's set_input_tensor should have already been called + # Just do the forward pass + lm_output = self.language_model( + input_ids=None, + position_ids=None, + decoder_input=None, # The LM will use its stored input tensor + labels=labels, + attention_mask=attention_mask, + ) + else: + raise RuntimeError( + f"[Rank {current_rank}][MimoModel][forward] " + f"Language model at intermediate PP stage but no input tensor stored from set_input_tensor" + ) + + logger.debug( + f"[Rank {current_rank}][MimoModel][forward] " + f"Language model output shape: {lm_output.shape}" ) - logger.debug(f"Language model output shape: {lm_output.shape}") + + 
if not is_pp_last_stage(lm_pp_group): + return {'language_module': lm_output}, loss_mask + return lm_output, loss_mask diff --git a/megatron/core/models/mimo/submodules/base.py b/megatron/core/models/mimo/submodules/base.py index 8b11ba7fcb9..32b824cc6f1 100644 --- a/megatron/core/models/mimo/submodules/base.py +++ b/megatron/core/models/mimo/submodules/base.py @@ -7,7 +7,7 @@ import torch import torch.nn as nn - +from megatron.core.pipeline_parallel.utils import is_pp_last_stage from megatron.core.transformer.spec_utils import ModuleSpec, build_module # Initialize logger @@ -72,6 +72,18 @@ def from_spec(cls, module_spec: ModuleSpec) -> 'ModalitySubmodules': params = module_spec.params or {} submodules = module_spec.submodules or {} + # Check if at last PP stage for projections + # pg_collection is in encoder specs, not modality submodule spec + is_last_pp_stage = True # Default to True if no PP + if 'encoders' in submodules: + # Check the first encoder's pg_collection + for encoder_spec in submodules['encoders'].values(): + encoder_params = encoder_spec.params or {} + pg_collection = encoder_params.get('pg_collection') + if pg_collection and hasattr(pg_collection, 'pp'): + is_last_pp_stage = is_pp_last_stage(pg_collection.pp) + break # Only need to check one encoder + # Build component lists from submodules dictionary encoders = {} if 'encoders' in submodules: @@ -87,14 +99,16 @@ def from_spec(cls, module_spec: ModuleSpec) -> 'ModalitySubmodules': decoder = build_module(decoder_spec) decoders[decoder_name] = decoder + # Only build projections on last PP stage input_projections = [] if 'input_projections' in submodules: for proj_spec in submodules['input_projections']: logger.debug( - f"Building {cls.__name__} input projection: {proj_spec.module.__name__}" + f"Building {cls.__name__} input projection: {proj_spec.module.__name__} is_last_pp_stage: {is_last_pp_stage}" ) - projection = build_module(proj_spec) + projection = build_module(proj_spec) if is_last_pp_stage 
else None input_projections.append(projection) + output_projections = [] if 'output_projections' in submodules: diff --git a/megatron/core/models/mimo/submodules/vision.py b/megatron/core/models/mimo/submodules/vision.py index 795cb18a119..0903aba76ea 100644 --- a/megatron/core/models/mimo/submodules/vision.py +++ b/megatron/core/models/mimo/submodules/vision.py @@ -52,6 +52,10 @@ def __init__( len(self.output_projections) <= 1 ), "VisionModalitySubmodules currently supports only one output projection" + self.pre_process = False + self.post_process = False + self.share_embeddings_and_output_weights=False + def encode(self, encoders_data_batch: Dict) -> List[torch.Tensor]: """Encode image data batch into a list of tensors. @@ -81,18 +85,19 @@ def encode(self, encoders_data_batch: Dict) -> List[torch.Tensor]: # Process inputs through the encoder encoder_outputs = encoder(**encoder_inputs) logger.debug(f"Encoder '{name}' output shape: {encoder_outputs.shape}") - if encoder_outputs.ndim == 3: - # its b,s,h -> we need to flatten it to b*s,h - encoder_outputs = encoder_outputs.reshape(-1, encoder_outputs.size(-1)) - embeddings.append(encoder_outputs) - elif encoder_outputs.ndim == 2: - # its b*s,h -> encoder already returned the flattened output - embeddings.append(encoder_outputs) - else: - raise ValueError( - f"Encoder '{name}' output shape {encoder_outputs.shape} is not supported" - "Expected 3D (b,s,h) or 2D (b*s,h) tensor, got {encoder_outputs.ndim}D" - ) + embeddings.append(encoder_outputs) + # if encoder_outputs.ndim == 3: + # # its b,s,h -> we need to flatten it to b*s,h + # encoder_outputs = encoder_outputs.reshape(-1, encoder_outputs.size(-1)) + # embeddings.append(encoder_outputs) + # elif encoder_outputs.ndim == 2: + # # its b*s,h -> encoder already returned the flattened output + # embeddings.append(encoder_outputs) + # else: + # raise ValueError( + # f"Encoder '{name}' output shape {encoder_outputs.shape} is not supported" + # "Expected 3D (b,s,h) or 2D 
(b*s,h) tensor, got {encoder_outputs.ndim}D" + # ) return embeddings @@ -167,18 +172,30 @@ def forward(self, encoder_inputs: Dict[str, Any]) -> Optional[torch.Tensor]: encoder_inputs: Dictionary where keys match encoder names in self.encoders and values are dictionaries of encoder-specific parameters. Example: {"clip": {"pixel_values": images}, "vit": {"images": vit_images}} + For intermediate PP stages: {"_intermediate_tensor": tensor} Returns: Flattened image embeddings with shape [total_embeddings, hidden_dim], or None if no valid inputs were provided. """ - # Encode the images - embeddings = self.encode(encoder_inputs) + # Handle intermediate PP stage + # TODO: ykarnati we need a better design this is temp for testing. + if '_intermediate_tensor' in encoder_inputs: + embeddings = [self.encoders[name](encoder_inputs['_intermediate_tensor']) for name in self.encoders.keys()] + else:# Encode the images + embeddings = self.encode(encoder_inputs) # If no embeddings were produced, return None - if not embeddings: + if embeddings is None: return None + + # projection is only run on last PP stage + if self.input_projections[0] is None: + return embeddings[0] + + if embeddings[0].ndim == 3: + embeddings = [embedding.reshape(-1, embedding.size(-1)) for embedding in embeddings] projected = self.project_embeddings(embeddings, is_input=True) - logging.debug(f"Projected audio embeddings shape: {projected.shape}") + logging.debug(f"Projected vision embeddings shape: {projected.shape}") return projected # [total_embeddings, hidden_dim] diff --git a/megatron/core/optimizer/clip_grads.py b/megatron/core/optimizer/clip_grads.py index 70117858b77..81706415083 100644 --- a/megatron/core/optimizer/clip_grads.py +++ b/megatron/core/optimizer/clip_grads.py @@ -164,7 +164,10 @@ def clip_grad_by_total_norm_fp32( grads.append(to_local_if_dtensor(param.decoupled_grad).detach()) else: if param.grad is not None: - assert param.grad.type() == 'torch.cuda.FloatTensor' + # [TODO by shifangx] + 
# why should the grad be a FloatTensor? how to handle the bfloat16 grad? + # please refer to tests/unit_tests/test_optimizer.py. + # assert param.grad.type() == 'torch.cuda.FloatTensor', f"for debug: Rank {torch.distributed.get_rank()}, param.grad.type(): {param.grad.type()}" params.append(param) grads.append(to_local_if_dtensor(param.grad).detach()) diff --git a/megatron/core/optimizer/optimizer.py b/megatron/core/optimizer/optimizer.py index 1829cb424f1..55c75cded02 100644 --- a/megatron/core/optimizer/optimizer.py +++ b/megatron/core/optimizer/optimizer.py @@ -152,12 +152,22 @@ def get_main_grads_for_grad_norm(self) -> List[torch.Tensor]: grad = param.grad grad_not_none = grad is not None is_not_shared = param_is_not_shared(param) - is_not_tp_duplicate = tensor_parallel.param_is_not_tensor_parallel_duplicate(param) + is_not_tp_duplicate = self.param_is_not_tensor_parallel_duplicate(param) if grad_not_none and is_not_shared and is_not_tp_duplicate: grads_for_norm.append(grad) return grads_for_norm + def param_is_not_tensor_parallel_duplicate(self, param): + """Returns true if the passed-in parameter is not a duplicate parameter + on another TP rank.""" + # [TODO by shifangx] + # need to pass in the pg_collection to the optimizer, and then check if the param is a duplicate parameter on another TP rank. + temp_without_TP = False + return (hasattr(param, "tensor_model_parallel") and param.tensor_model_parallel) or ( + temp_without_TP + ) + def get_grad_stats_parallel_group(self) -> torch.distributed.ProcessGroup: """Process group for reducing gradient statistics (num_zeros & norm). diff --git a/megatron/core/pipeline_parallel/bridge_communicator.py b/megatron/core/pipeline_parallel/bridge_communicator.py index f1e74a2f16d..21b1524bb8c 100644 --- a/megatron/core/pipeline_parallel/bridge_communicator.py +++ b/megatron/core/pipeline_parallel/bridge_communicator.py @@ -13,13 +13,10 @@ class CommRole(Enum): """Communication role for ranks in bridge communication. 
- - SENDER: Leader tp-cp rank within each DP replica of source grid. - Sends data to destination grid receivers. - RECEIVER: Leader tp-cp rank within each DP replica of destination grid. - Receives data from source grid senders. - MEMBER: Non-leader ranks within DP replicas. - Participate in broadcasts from their local leader. + + SENDER: Leader tp-cp rank within each DP replica of source grid. Sends data to destination grid receivers. + RECEIVER: Leader tp-cp rank within each DP replica of destination grid. Receives data from source grid senders. + MEMBER: Non-leader ranks within DP replicas. Participate in broadcasts from their local leader. """ SENDER = "SENDER" @@ -90,14 +87,6 @@ def __init__( if dim_mapping is None: self.dim_mapping = {'s': 1, 'b': 0, 'h': 2} else: - assert set(dim_mapping.keys()) == { - 's', - 'b', - 'h', - }, f"dim_mapping must have keys 's', 'b', 'h', got {set(dim_mapping.keys())}" - assert all( - v in {0, 1, 2} for v in dim_mapping.values() - ), f"dim_mapping values must be 0, 1, or 2, got {list(dim_mapping.values())}" self.dim_mapping = dim_mapping self.src_grid_broadcast_pg = None @@ -147,14 +136,10 @@ def __init__( dist.barrier() def get_leader_rank(self, grid: HyperCommGrid, is_src: bool) -> List[int]: - """Get the leader rank for a given grid and direction. - - We elect leader rank for each dp replica, the first tp-cp rank in the group - in the last pp stage (for src grid) or first pp stage (for dest grid) is the leader. 
- """ + """Get the leader rank for a given grid and direction.""" leader_ranks = [] local_leader_rank = None - # grid.gen_rank_enum(["tp", "cp", "pp"]) # vary tp & cp, but same dp + # grid.gen_rank_enum(["tp", "cp", "pp"]) # vary tp & cp, freeze dp # returns a list of sublists, each sublist is a group of ranks # that have different tp & cp & pp, same dp per_dp_replica_ranks = grid._gen_rank_enum([x for x in grid.dim_names if x != "dp"]) @@ -254,23 +239,20 @@ def build_comm_map(self, src_tp_leaders: List[int], dest_tp_leaders: List[int]): for rank in all_ranks: self.comm_map[rank] = RankCommInfo(role=CommRole.MEMBER) - scale_factor = int(src_count / dest_count) + scale_factor = src_count / dest_count if scale_factor > 1: # Fan-in: multiple source leaders send to fewer destination leaders + scale_factor = int(scale_factor) for i, dest_rank in enumerate(dest_tp_leaders): # Each destination rank receives from scale_factor source ranks src_ranks = src_tp_leaders[i * scale_factor : (i + 1) * scale_factor] # Set up senders for src_rank in src_ranks: - self.comm_map[src_rank] = RankCommInfo( - role=CommRole.SENDER, send_to_ranks=[dest_rank] - ) + self.comm_map[src_rank] = RankCommInfo(role=CommRole.SENDER, send_to_ranks=[dest_rank]) # Set up receiver - self.comm_map[dest_rank] = RankCommInfo( - role=CommRole.RECEIVER, recv_from_ranks=src_ranks - ) + self.comm_map[dest_rank] = RankCommInfo(role=CommRole.RECEIVER, recv_from_ranks=src_ranks) else: # Fan-out: fewer source leaders send to more destination leaders scale_factor = int(dest_count / src_count) @@ -279,9 +261,7 @@ def build_comm_map(self, src_tp_leaders: List[int], dest_tp_leaders: List[int]): dest_ranks = dest_tp_leaders[i * scale_factor : (i + 1) * scale_factor] # Set up sender - self.comm_map[src_rank] = RankCommInfo( - role=CommRole.SENDER, send_to_ranks=dest_ranks - ) + self.comm_map[src_rank] = RankCommInfo(role=CommRole.SENDER, send_to_ranks=dest_ranks) # Set up receivers for dest_rank in dest_ranks: @@ 
-309,10 +289,12 @@ def send_forward(self, tensor_to_send: torch.Tensor): num_sends = len(rank_info.send_to_ranks) if num_sends > 0: tensor_splits = self._split_tensor_at_batch_dim(tensor_to_send, num_sends) + logging.debug(f"[Rank {self.current_rank} ][Bridge Comunicator] [send_forward] [src - {self.src_module_name}] [dest - {self.dest_module_name}] starting to communicate shapes") self._communicate_shapes(tensor_to_send_next=tensor_splits[0]) + logging.debug(f"[Rank {self.current_rank} ][Bridge Comunicator] [send_forward] [src - {self.src_module_name}] [dest - {self.dest_module_name}] communicate shapes DONE") for dest_rank, tensor_split in zip(rank_info.send_to_ranks, tensor_splits): logging.debug( - f"[Bridge Comunicator] [send_forward] Rank {self.current_rank} " + f"[Rank {self.current_rank} ][Bridge Comunicator] [send_forward] [src - {self.src_module_name}] [dest - {self.dest_module_name}] " f"send to rank {dest_rank}" ) dist.send(tensor_split, dst=dest_rank) @@ -335,19 +317,16 @@ def recv_forward(self) -> torch.Tensor: rank_info = self.comm_map.get(self.current_rank) assert rank_info is not None, f"Rank {self.current_rank} is not in the comm map" - logging.debug( - f"[Bridge Communicator] [receive_forward] Rank {self.current_rank} " - f"[src - {self.src_module_name}] [dest - {self.dest_module_name}] " - f"rank_info: {rank_info}" - ) + logging.debug(f"[Rank {self.current_rank} ][Bridge Communicator] [receive_forward] [src - {self.src_module_name}] [dest - {self.dest_module_name}] rank_info: {rank_info}") if rank_info.role == CommRole.RECEIVER: assert ( self.current_rank == self.dest_local_leader_rank ), f"Rank {self.current_rank} is not the leader rank" # p2p call to receive the tensor + logging.debug(f"[Rank {self.current_rank} ][Bridge Communicator] [receive_forward] [src - {self.src_module_name}] [dest - {self.dest_module_name}] starting to communicate shapes") recv_forward_shapes, recv_grad_shapes = self._communicate_shapes(recv_prev=True) logging.debug( - 
f"[Bridge Communicator] [receive_forward] Rank {self.current_rank} " + f"[Rank {self.current_rank} ][Bridge Communicator] [receive_forward] [src - {self.src_module_name}] [dest - {self.dest_module_name}] " f"received forward shapes {recv_forward_shapes} and grad shapes {recv_grad_shapes}" ) received_tensors_list = [] @@ -360,14 +339,14 @@ def recv_forward(self) -> torch.Tensor: ) dist.recv(tensor_to_recv, src=src_rank) logging.debug( - f"[Bridge Communicator] [receive_forward] Rank {self.current_rank} " + f"[Rank {self.current_rank} ][Bridge Communicator] [receive_forward] [src - {self.src_module_name}] [dest - {self.dest_module_name}] " f"received tensor from src rank {src_rank} " f"shape {tensor_to_recv.shape} sum {tensor_to_recv.sum()}" ) received_tensors_list.append(tensor_to_recv) aggregated_tensor = torch.cat(received_tensors_list, dim=self.dim_mapping['b']) logging.debug( - f"[Bridge Communicator] [receive_forward] Rank {self.current_rank} " + f"[Rank {self.current_rank} ][Bridge Communicator] [receive_forward] [src - {self.src_module_name}] [dest - {self.dest_module_name}] " f"broadcasting tensor {aggregated_tensor.shape} sum {aggregated_tensor.sum()}" ) @@ -384,15 +363,14 @@ def recv_forward(self) -> torch.Tensor: return aggregated_tensor - elif ( - rank_info.role == CommRole.MEMBER - and self.current_rank in self.dest_grid_broadcast_ranks - ): + elif rank_info.role == CommRole.MEMBER and self.current_rank in self.dest_grid_broadcast_ranks: # Non-leader rank - participate in broadcast shape_tensor = torch.empty((3), device=torch.cuda.current_device(), dtype=torch.int64) + logging.debug(f"[Rank {self.current_rank} ][Bridge Communicator] [receive_forward] [src - {self.src_module_name}] [dest - {self.dest_module_name}] MEMBER broadcasting shape_tensor to leader rank {self.dest_local_leader_rank}") dist.broadcast( shape_tensor, src=self.dest_local_leader_rank, group=self.dest_grid_broadcast_pg ) + logging.debug(f"[Rank {self.current_rank} ][Bridge 
Communicator] [receive_forward] [src - {self.src_module_name}] [dest - {self.dest_module_name}] MEMBER shape_tensor received: {shape_tensor}") received_shape = tuple(shape_tensor.tolist()) received_tensor = torch.empty( @@ -408,7 +386,7 @@ def recv_forward(self) -> torch.Tensor: ) logging.debug( - f"[Bridge Communicator] [receive_forward] Rank {self.current_rank} " + f"[Rank {self.current_rank} ][Bridge Communicator] [receive_forward] [src - {self.src_module_name}] [dest - {self.dest_module_name}] MEMBER " f"received tensor via broadcast, shape {received_tensor.shape}" ) return received_tensor @@ -423,7 +401,7 @@ def send_backward(self, grad_tensor: torch.Tensor): """ if not self.is_current_rank_in_grid(self.dest_grid): raise ValueError( - f"[Bridge Communicator] [send_backward] Rank {self.current_rank} " + f"[Rank {self.current_rank} ][Bridge Communicator] [send_backward] " "is not in the destination grid." ) @@ -442,7 +420,7 @@ def send_backward(self, grad_tensor: torch.Tensor): for src_rank, tensor_split in zip(rank_info.recv_from_ranks, tensor_splits): # Send the gradient split back to the source rank logging.debug( - f"[Bridge Communicator] [send_backward] Rank {self.current_rank} " + f"[Rank {self.current_rank} ][Bridge Communicator] [send_backward] " f"sending gradient to src rank {src_rank} " f"shape {tensor_split.shape} sum {tensor_split.sum()}" ) @@ -462,7 +440,7 @@ def recv_backward(self) -> torch.Tensor: # receive backward only gets called on the src grid if not self.is_current_rank_in_grid(self.src_grid): raise ValueError( - f"[Bridge Communicator] [receive_backward] Rank {self.current_rank} " + f"[Rank {self.current_rank} ][Bridge Communicator] [receive_backward] " "is not in the source grid." 
) @@ -475,7 +453,7 @@ def recv_backward(self) -> torch.Tensor: ), f"Rank {self.current_rank} is not the leader rank" recv_forward_shapes, recv_grad_shapes = self._communicate_shapes(recv_next=True) logging.debug( - f"[Bridge Communicator] [receive_backward] Rank {self.current_rank} " + f"[Rank {self.current_rank} ][Bridge Communicator] [receive_backward] " f"received forward shapes {recv_forward_shapes} and grad shapes {recv_grad_shapes}" ) # Receive gradient tensors from destination ranks @@ -487,7 +465,7 @@ def recv_backward(self) -> torch.Tensor: ) dist.recv(grad_tensor, src=dest_rank) logging.debug( - f"[Bridge Communicator] [receive_backward] Rank {self.current_rank} " + f"[Rank {self.current_rank} ][Bridge Communicator] [receive_backward] " f"received gradient from dest rank {dest_rank} " f"shape {grad_tensor.shape} sum {grad_tensor.sum()}" ) @@ -496,7 +474,7 @@ def recv_backward(self) -> torch.Tensor: # Concatenate received gradients aggregated_gradient = torch.cat(received_gradients_list, dim=0) logging.debug( - f"[Bridge Communicator] [receive_backward] Rank {self.current_rank} " + f"[Rank {self.current_rank} ][Bridge Communicator] [receive_backward] " f"agg grad shape {aggregated_gradient.shape} sum {aggregated_gradient.sum()}" ) @@ -511,9 +489,7 @@ def recv_backward(self) -> torch.Tensor: ) return aggregated_gradient - elif ( - rank_info.role == CommRole.MEMBER and self.current_rank in self.src_grid_broadcast_ranks - ): + elif rank_info.role == CommRole.MEMBER and self.current_rank in self.src_grid_broadcast_ranks: # Non-leader rank - participate in gather for gradients # Receive broadcasted tensor shape from leader rank shape_tensor = torch.empty((3), device=torch.cuda.current_device(), dtype=torch.int64) @@ -522,7 +498,7 @@ def recv_backward(self) -> torch.Tensor: ) logging.debug( - f"[Bridge Communicator] [receive_backward] Rank {self.current_rank} " + f"[Rank {self.current_rank} ][Bridge Communicator] [receive_backward] " f"received shape tensor 
{shape_tensor}" ) received_shape = tuple(shape_tensor.tolist()) @@ -534,7 +510,7 @@ def recv_backward(self) -> torch.Tensor: received_gradient, src=self.src_local_leader_rank, group=self.src_grid_broadcast_pg ) logging.debug( - f"[Bridge Communicator] [receive_backward] Rank {self.current_rank} " + f"[Rank {self.current_rank} ][Bridge Communicator] [receive_backward] " f"received gradient from scatter operation, shape {received_gradient.shape}" ) return received_gradient @@ -559,11 +535,7 @@ def send_forward_recv_backward( rank_info = self.comm_map.get(self.current_rank) assert rank_info is not None, f"Rank {self.current_rank} is not in the comm map" - logging.debug( - f"[Bridge Communicator] [send_forward_recv_backward] Rank {self.current_rank} " - f"[src - {self.src_module_name}] [dest - {self.dest_module_name}] " - f"rank_info: {rank_info}" - ) + logging.debug(f"[Rank {self.current_rank} ][Bridge Communicator] [send_forward_recv_backward] [src - {self.src_module_name}] [dest - {self.dest_module_name}] rank_info: {rank_info}") if rank_info.role == CommRole.SENDER: assert ( self.current_rank == self.src_local_leader_rank @@ -576,7 +548,7 @@ def send_forward_recv_backward( tensor_to_send_next=activation_splits[0], recv_next=True ) logging.debug( - f"[Bridge Communicator] [send_forward_recv_backward] Rank {self.current_rank} " + f"[Rank {self.current_rank} ][Bridge Communicator] [send_forward_recv_backward] " f"received forward shapes {recv_forward_shapes} and grad shapes {recv_grad_shapes}" ) @@ -607,7 +579,7 @@ def send_forward_recv_backward( ) logging.debug( - f"[Bridge Communicator] [send_forward_recv_backward] Rank {self.current_rank} " + f"[Rank {self.current_rank} ][Bridge Communicator] [send_forward_recv_backward] " f"executing {len(ops)} simultaneous P2P operations" ) reqs = torch.distributed.batch_isend_irecv(ops) @@ -617,7 +589,7 @@ def send_forward_recv_backward( # Concatenate received gradients aggregated_gradient = torch.cat(received_gradients_list, 
dim=0) logging.debug( - f"[Bridge Communicator] [send_forward_recv_backward] Rank {self.current_rank} " + f"[Rank {self.current_rank} ][Bridge Communicator] [send_forward_recv_backward] " f"agg grad shape {aggregated_gradient.shape} sum {aggregated_gradient.sum()}" ) # Broadcast tensor shape to all ranks in scatter_pg @@ -625,10 +597,11 @@ def send_forward_recv_backward( shape_tensor = torch.tensor( tensor_shape_to_broadcast, device=torch.cuda.current_device(), dtype=torch.int64 ) + logging.debug(f"[Rank {self.current_rank} ][Bridge Communicator] [send_forward_recv_backward] [src - {self.src_module_name}] [dest - {self.dest_module_name}] SENDER broadcasting shape_tensor {shape_tensor}") dist.broadcast( shape_tensor, src=self.current_rank, group=self.src_grid_broadcast_pg ) - + logging.debug(f"[Rank {self.current_rank} ][Bridge Communicator] [send_forward_recv_backward] [src - {self.src_module_name}] [dest - {self.dest_module_name}] SENDER broadcasting aggregated_gradient {aggregated_gradient.shape} sum {aggregated_gradient.sum()}") # Broadcast the tensors to all ranks in the group dist.broadcast( aggregated_gradient, src=self.current_rank, group=self.src_grid_broadcast_pg @@ -636,16 +609,14 @@ def send_forward_recv_backward( return aggregated_gradient - elif ( - rank_info.role == CommRole.MEMBER and self.current_rank in self.src_grid_broadcast_ranks - ): + elif rank_info.role == CommRole.MEMBER and self.current_rank in self.src_grid_broadcast_ranks: # participate in both gather for gradients # Receive gradient from leader using broadcast shape_tensor = torch.empty((3), device=torch.cuda.current_device(), dtype=torch.int64) dist.broadcast( shape_tensor, src=self.src_local_leader_rank, group=self.src_grid_broadcast_pg ) - + logging.debug(f"[Rank {self.current_rank} ][Bridge Communicator] [send_forward_recv_backward] [src - {self.src_module_name}] [dest - {self.dest_module_name}] MEMBER received shape tensor {shape_tensor}") # Use the received shape to create tensor 
for broadcast received_shape = tuple(shape_tensor.tolist()) received_gradient = torch.empty( @@ -655,7 +626,7 @@ def send_forward_recv_backward( received_gradient, src=self.src_local_leader_rank, group=self.src_grid_broadcast_pg ) logging.debug( - f"[Bridge Communicator] [send_forward_recv_backward] Rank {self.current_rank} " + f"[Rank {self.current_rank} ][Bridge Communicator] [send_forward_recv_backward] " f"received gradient from broadcast, shape {received_gradient.shape}" ) return received_gradient @@ -693,7 +664,7 @@ def send_backward_recv_forward( tensor_to_send_prev=gradient_splits[0], recv_prev=True ) logging.debug( - f"[Bridge Communicator] [send_backward_recv_backward] Rank {self.current_rank} " + f"[Rank {self.current_rank} ][Bridge Communicator] [send_backward_recv_forward] " f"received forward shapes {recv_forward_shapes} and grad shapes {recv_grad_shapes}" ) @@ -729,7 +700,7 @@ def send_backward_recv_forward( # Execute all operations simultaneously logging.debug( - f"[Bridge Communicator] [send_backward_recv_backward] Rank {self.current_rank} " + f"[Rank {self.current_rank} ][Bridge Communicator] [send_backward_recv_forward] " f"executing {len(ops)} simultaneous P2P operations" ) reqs = torch.distributed.batch_isend_irecv(ops) @@ -739,36 +710,36 @@ def send_backward_recv_forward( # Concatenate received activations aggregated_activation = torch.cat(received_activations_list, dim=0) logging.debug( - f"[Bridge Communicator] [send_backward_recv_backward] Rank {self.current_rank} " + f"[Rank {self.current_rank} ][Bridge Communicator] [send_backward_recv_forward] " f"agg act shape {aggregated_activation.shape} sum {aggregated_activation.sum()}" ) # Broadcast tensor shape to all ranks in scatter_pg tensor_shape_to_scatter = aggregated_activation.shape + logging.debug(f"[Rank {self.current_rank} ][Bridge Communicator] [send_backward_recv_forward] [src - {self.src_module_name}] [dest - {self.dest_module_name}] RECEIVER tensor_shape_to_scatter: 
{tensor_shape_to_scatter}") shape_tensor = torch.tensor( tensor_shape_to_scatter, device=torch.cuda.current_device(), dtype=torch.int64 ) dist.broadcast( shape_tensor, src=self.current_rank, group=self.dest_grid_broadcast_pg ) - + logging.debug(f"[Rank {self.current_rank} ][Bridge Communicator] [send_backward_recv_forward] [src - {self.src_module_name}] [dest - {self.dest_module_name}] RECEIVER broadcasting aggregated_activation {aggregated_activation.shape} sum {aggregated_activation.sum()}") # Scatter the tensors to all ranks in the group dist.broadcast( aggregated_activation, src=self.current_rank, group=self.dest_grid_broadcast_pg ) return aggregated_activation - elif ( - rank_info.role == CommRole.MEMBER - and self.current_rank in self.dest_grid_broadcast_ranks - ): + elif rank_info.role == CommRole.MEMBER and self.current_rank in self.dest_grid_broadcast_ranks: shape_tensor = torch.empty((3), device=torch.cuda.current_device(), dtype=torch.int64) dist.broadcast( shape_tensor, src=self.dest_local_leader_rank, group=self.dest_grid_broadcast_pg ) + logging.debug(f"[Rank {self.current_rank} ][Bridge Communicator] [send_backward_recv_forward] [src - {self.src_module_name}] [dest - {self.dest_module_name}] MEMBER received shape tensor {shape_tensor}") # Use the received shape to create tensor for scatter operation received_shape = tuple(shape_tensor.tolist()) + received_activation = torch.empty( received_shape, device=torch.cuda.current_device(), @@ -781,8 +752,8 @@ def send_backward_recv_forward( group=self.dest_grid_broadcast_pg, ) logging.debug( - f"[Bridge Communicator] [send_backward_recv_backward] Rank {self.current_rank} " - f"received activation from scatter operation, shape {received_activation.shape}" + f"[Rank {self.current_rank} ][Bridge Communicator] [send_backward_recv_forward] [src - {self.src_module_name}] [dest - {self.dest_module_name}] " + f" MEMBER received activation from scatter operation, shape {received_activation.shape}" ) return 
received_activation @@ -815,10 +786,6 @@ def _communicate_shapes( recv_forward_shapes = [] recv_grad_shapes = [] - logging.debug( - f"[Bridge Communicator] [communicate_shapes] Rank {self.current_rank} " - f"is a {rank_info.role} and is running the shape communication" - ) # Collect all P2P operations for batch execution ops = [] recv_forward_shape_tensors = [] @@ -896,6 +863,17 @@ def _communicate_shapes( shape = grad_shape_tensor.tolist() recv_grad_shapes.append(tuple(shape)) + import traceback + stack = traceback.extract_stack() + caller = stack[-2] # -1 is current line, -2 is caller + logging.debug( + f"[Rank {self.current_rank} ][Bridge Communicator] [communicate_shapes] [src - {self.src_module_name}] [dest - {self.dest_module_name}] " + f"is a {rank_info.role} and is running the shape communication " + f"[Called from {caller.filename}:{caller.lineno} in {caller.name}()]" + f"recv_forward_shapes: {recv_forward_shapes}" + f"recv_grad_shapes: {recv_grad_shapes}" + ) + return recv_forward_shapes, recv_grad_shapes def _split_tensor_at_batch_dim( @@ -915,5 +893,5 @@ def _split_tensor_at_batch_dim( batch_dim = self.dim_mapping['b'] splits = torch.tensor_split(aggregated_tensor, num_splits, dim=batch_dim) - # PyTorch p2p requires the tensors to be contiguous + # Ensure all splits are contiguous to avoid P2P communication issues return [split.contiguous() for split in splits] diff --git a/megatron/core/pipeline_parallel/multimodule_communicator.py b/megatron/core/pipeline_parallel/multimodule_communicator.py index dfda270ef76..b283a3eaa2c 100644 --- a/megatron/core/pipeline_parallel/multimodule_communicator.py +++ b/megatron/core/pipeline_parallel/multimodule_communicator.py @@ -1,12 +1,12 @@ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
-import logging from dataclasses import dataclass from typing import Dict, List, Optional, Union - +import time import torch import torch.distributed as dist - +import torch.cuda.nvtx as nvtx +import logging from megatron.core.hyper_comm_grid import HyperCommGrid from megatron.core.model_parallel_config import ModelParallelConfig from megatron.core.pipeline_parallel.bridge_communicator import BridgeCommunicator @@ -21,7 +21,7 @@ class RankModuleInfo: """Information about a rank in a module.""" # the stage of the current rank in the current module's pipeline. - pp_rank: int # the stage of the current rank in the current module's pipeline + pp_stage: int # the stage of the current rank in the current module's pipeline pp_size: int # the number of ranks in the current module's pipeline p2p_communicator: Optional[P2PCommunicator] # key is either the src or dst module name connected to the current module @@ -37,6 +37,32 @@ class RankModuleInfo: is_terminal_stage: Optional[bool] = True +def _ensure_3d_tensor(tensor): + """Ensure tensor is 3D for P2P/bridge communication. + + P2P and bridge communicators expect 3D tensors. + Handles both single tensors and lists of tensors (for VPP). + """ + if isinstance(tensor, list): + return [_ensure_3d_tensor(t) for t in tensor] + if isinstance(tensor, torch.Tensor) and tensor.ndim == 2: + return tensor.unsqueeze(-1) + return tensor + + +def _restore_tensor_shape(tensor): + """Restore original tensor shape after P2P/bridge communication. + + Remove the extra dimension added by _ensure_3d_tensor if it was singleton. + Handles both single tensors and lists of tensors (for VPP). 
+ """ + if isinstance(tensor, list): + return [_restore_tensor_shape(t) for t in tensor] + if isinstance(tensor, torch.Tensor) and tensor.ndim == 3 and tensor.shape[-1] == 1: + return tensor.squeeze(-1) + return tensor + + class MultiModulePipelineCommunicator: """Communicator for a multi-module pipeline.""" @@ -68,10 +94,6 @@ def __init__( 'generator': [] } config (ModelParallelConfig): A ModelParallelConfig object. - dim_mapping (Dict[str, List[int]]): Dimension mapping for sequence, batch, hidden. - Example: - dim_mapping = {'s': 0, 'h': 2, 'b': 1} - Default: None """ self.module_to_grid_map = module_to_grid_map self.topology = topology @@ -108,12 +130,12 @@ def is_pp_first_stage(self): """Return True if the current rank has the absolute first stage in the overall model. The absolute first stage is defined as: - 1. The current rank must be in the first PP stage (pp_rank == 0) of some module + 1. The current rank must be in the first PP stage (pp_stage == 0) of some module 2. That module must be a source module (no incoming connections in topology) """ for module_name, rank_module_info in self.rank_module_map.items(): # Check if this rank is at the first PP stage of this module - if rank_module_info.pp_rank == 0: + if rank_module_info.pp_stage == 0: # Check if this module is a source module (no incoming connections) if self._is_source_module(module_name): return True @@ -129,7 +151,7 @@ def is_pp_last_stage(self): """ for module_name, rank_module_info in self.rank_module_map.items(): # Check if this rank is at the last PP stage of this module - if rank_module_info.pp_rank == rank_module_info.pp_size - 1: + if rank_module_info.pp_stage == rank_module_info.pp_size - 1: # Check if this module is a sink module (no outgoing connections) if self._is_sink_module(module_name): return True @@ -185,11 +207,7 @@ def num_warmup_microbatches(self): assert ( current_stage <= total_stages ), f"current_stage: {current_stage} is greater than total_stages: {total_stages}" - 
logging.debug( - f"[Rank {dist.get_rank()} ][MultiModulePipelineCommunicator] " - f"current_stage: {current_stage} total_stages: {total_stages} " - f"num_warmup_microbatches: {total_stages - current_stage - 1}" - ) + logging.debug(f"[Rank {dist.get_rank()} ][MultiModulePipelineCommunicator] current_stage: {current_stage} total_stages: {total_stages} num_warmup_microbatches: {total_stages - current_stage - 1}") return total_stages - current_stage - 1 def _build_rank_module_info_map(self): @@ -204,31 +222,27 @@ def _build_rank_module_info_map(self): p2p_comm = P2PCommunicator(pp_group, self.config) pp_size = dist.get_world_size(pp_group) rank_in_pp_group = dist.get_group_rank(pp_group, self.current_rank) - pp_rank = rank_in_pp_group % pp_size + pp_stage = rank_in_pp_group % pp_size bridge_comms_as_dest_module = [] bridge_comms_as_src_module = [] # If first stage, check if the module has any incoming modules # If so, initialize bridge communicator - if pp_rank == 0: + if pp_stage == 0: for bridge_comm in self.bridge_comms: - if ( - bridge_comm.is_current_rank_in_grid(bridge_comm.dest_grid) - and bridge_comm.dest_module_name == module_name - ): + if (bridge_comm.is_current_rank_in_grid(bridge_comm.dest_grid) and + bridge_comm.dest_module_name == module_name): bridge_comms_as_dest_module.append(bridge_comm) # If last stage, check if the module has any outgoing modules # If so, initialize bridge communicator - if pp_rank == pp_size - 1: + if pp_stage == pp_size - 1: for bridge_comm in self.bridge_comms: - if ( - bridge_comm.is_current_rank_in_grid(bridge_comm.src_grid) - and bridge_comm.src_module_name == module_name - ): + if (bridge_comm.is_current_rank_in_grid(bridge_comm.src_grid) and + bridge_comm.src_module_name == module_name): bridge_comms_as_src_module.append(bridge_comm) # Build RankModuleInfo for the module rank_module_info = RankModuleInfo( - pp_rank=pp_rank, + pp_stage=pp_stage, pp_size=pp_size, p2p_communicator=p2p_comm, 
bridge_comms_as_dest_module=bridge_comms_as_dest_module, @@ -247,24 +261,34 @@ def recv_forward( Returns: A dictionary mapping module names to tensors. """ - logging.debug( - f"[Rank {dist.get_rank()} ][MultiModulePipelineCommunicator] " - f"[receive_forward] tensors_shape: {tensor_shape}, is_first_stage: {is_first_stage}" - ) - input_dict = {} - for module_name, rank_module_info in self.rank_module_map.items(): - - if rank_module_info.pp_rank == 0: - # If first stage, and has incoming modules, receive forward activation - # from incoming modules. - for bridge_comm in rank_module_info.bridge_comms_as_dest_module: - input_dict[bridge_comm.src_module_name] = bridge_comm.recv_forward() - else: - # If not first stage, receive forward activation tensor from P2P communicator. - input_dict[module_name] = rank_module_info.p2p_communicator.recv_forward( - tensor_shapes=tensor_shape, is_first_stage=False - ) - return input_dict + with nvtx.range(f"MultiModulePipelineCommunicator.recv_forward_rank:{self.current_rank}"): + input_dict = {} + for module_name, rank_module_info in self.rank_module_map.items(): + + if rank_module_info.pp_stage == 0: + # If first stage, and has incoming modules, receive forward activation + # from incoming modules. 
+ for bridge_comm in rank_module_info.bridge_comms_as_dest_module: + with nvtx.range(f"recv_forward_bridge_rank:{self.current_rank}:{bridge_comm.src_module_name}->{bridge_comm.dest_module_name}"): + logging.debug(f'[Rank {dist.get_rank()} ][MultiModulePipelineCommunicator] [receive_forward] bridge [src - {bridge_comm.src_module_name}] [dest - {bridge_comm.dest_module_name}], module_name: {module_name}') + received_tensor = bridge_comm.recv_forward() + received_tensor = _restore_tensor_shape(received_tensor) + input_dict[bridge_comm.src_module_name] = received_tensor + logging.debug(f'[Rank {dist.get_rank()} ][MultiModulePipelineCommunicator] [receive_forward] bridge [src - {bridge_comm.src_module_name}] [dest - {bridge_comm.dest_module_name}], module_name: {module_name} tensor shape: {input_dict[bridge_comm.src_module_name].shape} sum {input_dict[bridge_comm.src_module_name].sum()} DONE') + else: + # If not first stage, receive forward activation tensor from P2P communicator. + with nvtx.range(f"recv_forward_p2p_rank:{self.current_rank}:{module_name}"): + logging.debug(f'[Rank {dist.get_rank()} ][MultiModulePipelineCommunicator] [receive_forward] p2p comm module_name: {module_name} ') + received_tensor = rank_module_info.p2p_communicator.recv_forward( + tensor_shapes=tensor_shape, is_first_stage=False + ) + if isinstance(received_tensor, torch.Tensor): + received_tensor = _restore_tensor_shape(received_tensor) + elif isinstance(received_tensor, list): + received_tensor = [_restore_tensor_shape(t) for t in received_tensor] + input_dict[module_name] = received_tensor + logging.debug(f'[Rank {dist.get_rank()} ][MultiModulePipelineCommunicator] [receive_forward] p2p comm module_name: {module_name} tensor shape: {input_dict[module_name][0].shape} sum {input_dict[module_name][0].sum()} DONE') + return input_dict def send_forward(self, output_dict: Dict[str, torch.Tensor], is_last_stage: bool = False): """Send forward activation tensor. 
@@ -272,21 +296,27 @@ def send_forward(self, output_dict: Dict[str, torch.Tensor], is_last_stage: bool Args: output_dict: A dictionary mapping module names to tensors. """ - logging.debug( - f"[Rank {dist.get_rank()} ][MultiModulePipelineCommunicator] " - f"[send_forward] output_dict keys: {output_dict.keys()}, is_last_stage: {is_last_stage}" - ) - for module_name, rank_module_info in self.rank_module_map.items(): - if rank_module_info.pp_rank == rank_module_info.pp_size - 1: - # If last stage, and has outgoing modules, send forward activation - # by using bridge communicator. - for bridge_comm in rank_module_info.bridge_comms_as_src_module: - bridge_comm.send_forward(output_dict[module_name]) - else: - # If not last stage, send forward activation by using P2P communicator. - rank_module_info.p2p_communicator.send_forward( - output_dict[module_name], is_last_stage=False - ) + with nvtx.range(f"MultiModulePipelineCommunicator.send_forward_rank:{self.current_rank}"): + for module_name, rank_module_info in self.rank_module_map.items(): + if rank_module_info.pp_stage == rank_module_info.pp_size - 1: + # If last stage, and has outgoing modules, send forward activation + # by using bridge communicator. 
+ for bridge_comm in rank_module_info.bridge_comms_as_src_module: + with nvtx.range(f"send_forward_bridge_rank:{self.current_rank}:{bridge_comm.src_module_name}->{bridge_comm.dest_module_name}"): + logging.debug(f'[Rank {dist.get_rank()} ][MultiModulePipelineCommunicator] [send_forward] bridge [src - {bridge_comm.src_module_name}] [dest - {bridge_comm.dest_module_name}], module_name: {module_name} tensor shape: {output_dict[module_name].shape} sum {output_dict[module_name].sum()}') + tensor_to_send = _ensure_3d_tensor(output_dict[module_name]) + bridge_comm.send_forward(tensor_to_send) + # time.sleep(10) + logging.debug(f'[Rank {dist.get_rank()} ][MultiModulePipelineCommunicator] [send_forward] bridge [src - {bridge_comm.src_module_name}] [dest - {bridge_comm.dest_module_name} DONE') + else: + # If not last stage, send forward activation by using P2P communicator. + with nvtx.range(f"send_forward_p2p_rank:{self.current_rank}:{module_name}"): + logging.debug(f'[Rank {dist.get_rank()} ][MultiModulePipelineCommunicator] [send_forward] p2p comm module_name: {module_name} output dict keys: {output_dict.keys()}') + logging.debug(f'[Rank {dist.get_rank()} ][MultiModulePipelineCommunicator] [send_forward] p2p comm module_name: {module_name} tensor shape: {output_dict[module_name].shape} sum {output_dict[module_name].sum()}') + tensor_to_send = _ensure_3d_tensor(output_dict[module_name]) + rank_module_info.p2p_communicator.send_forward( + tensor_to_send, is_last_stage=False + ) def send_forward_recv_backward( self, @@ -303,29 +333,31 @@ def send_forward_recv_backward( Returns: A dictionary mapping module names to tensors. 
""" - logging.debug( - f"[Rank {dist.get_rank()} ][MultiModulePipelineCommunicator] " - f"[send_forward_recv_backward] output_dict keys: {output_dict.keys()}, " - f"tensor_shape: {tensor_shape}, is_last_stage: {is_last_stage}" - ) - grad_dict = {} - for module_name, rank_module_info in self.rank_module_map.items(): - if rank_module_info.pp_rank == rank_module_info.pp_size - 1: - # If last stage, and has outgoing modules, send forward activation and - # receive backward gradient by using bridge communicator. - for bridge_comm in rank_module_info.bridge_comms_as_src_module: - grad_dict[bridge_comm.src_module_name] = bridge_comm.send_forward_recv_backward( - output_dict[module_name] - ) - else: - # If not last stage, send forward activation and receive backward gradient - # by using P2P communicator. - grad_dict[module_name] = ( - rank_module_info.p2p_communicator.send_forward_recv_backward( - output_dict[module_name], tensor_shapes=tensor_shape, is_last_stage=False - ) - ) - return grad_dict + with nvtx.range(f"MultiModulePipelineCommunicator.send_forward_recv_backward_rank:{self.current_rank}"): + grad_dict = {} + for module_name, rank_module_info in self.rank_module_map.items(): + if rank_module_info.pp_stage == rank_module_info.pp_size - 1: + # If last stage, and has outgoing modules, send forward activation and + # receive backward gradient by using bridge communicator. 
+ for bridge_comm in rank_module_info.bridge_comms_as_src_module: + with nvtx.range(f"send_forward_recv_backward_bridge_rank:{self.current_rank}:{bridge_comm.src_module_name}->{bridge_comm.dest_module_name}"): + logging.debug(f'[Rank {dist.get_rank()} ][MultiModulePipelineCommunicator] [send_forward_recv_backward] bridge [src - {bridge_comm.src_module_name}] [dest - {bridge_comm.dest_module_name}], module_name: {module_name} output_dict tensor shape: {output_dict[module_name].shape} sum {output_dict[module_name].sum()}') + tensor_to_send = _ensure_3d_tensor(output_dict[module_name]) + grad = bridge_comm.send_forward_recv_backward(tensor_to_send) + grad_dict[bridge_comm.src_module_name] = _restore_tensor_shape(grad) + logging.debug(f'[Rank {dist.get_rank()} ][MultiModulePipelineCommunicator] [send_forward_recv_backward] bridge [src - {bridge_comm.src_module_name}] [dest - {bridge_comm.dest_module_name}], module_name: {module_name} grad_dict grad shape: {grad_dict[bridge_comm.src_module_name].shape} sum {grad_dict[bridge_comm.src_module_name].sum()} DONE') + else: + # If not last stage, send forward activation and receive backward gradient + # by using P2P communicator. 
+ with nvtx.range(f"send_forward_recv_backward_p2p_rank:{self.current_rank}:{module_name}"): + logging.debug(f'[Rank {dist.get_rank()} ][MultiModulePipelineCommunicator] [send_forward_recv_backward] p2p comm module_name: {module_name} output_dict tensor shape: {output_dict[module_name].shape} sum {output_dict[module_name].sum()}') + tensor_to_send = _ensure_3d_tensor(output_dict[module_name]) + grad = rank_module_info.p2p_communicator.send_forward_recv_backward( + tensor_to_send, tensor_shapes=tensor_shape, is_last_stage=False + ) + grad_dict[module_name] = _restore_tensor_shape(grad) + logging.debug(f'[Rank {dist.get_rank()} ][MultiModulePipelineCommunicator] [send_forward_recv_backward] p2p comm module_name: {module_name} grad_dict grad shape: {grad_dict[module_name].shape} sum {grad_dict[module_name].sum()} DONE') + return grad_dict def send_backward_recv_forward( self, @@ -342,31 +374,31 @@ def send_backward_recv_forward( Returns: A dictionary mapping module names to tensors. """ - logging.debug( - f"[Rank {dist.get_rank()} ][MultiModulePipelineCommunicator] " - f"[send_backward_recv_forward] grad_dict keys: {grad_dict.keys()}, " - f"tensor_shape: {tensor_shape}, is_first_stage: {is_first_stage}" - ) - input_dict = {} - for module_name, rank_module_info in self.rank_module_map.items(): - if rank_module_info.pp_rank == 0: - for bridge_comm in rank_module_info.bridge_comms_as_dest_module: - # If first stage, and has incoming modules, send backward gradient and - # receive forward activation by using bridge communicator. 
- input_dict[bridge_comm.src_module_name] = ( - bridge_comm.send_backward_recv_forward( - grad_dict[bridge_comm.src_module_name] + with nvtx.range(f"MultiModulePipelineCommunicator.send_backward_recv_forward_rank:{self.current_rank}"): + input_dict = {} + for module_name, rank_module_info in self.rank_module_map.items(): + if rank_module_info.pp_stage == 0: + for bridge_comm in rank_module_info.bridge_comms_as_dest_module: + # If first stage, and has incoming modules, send backward gradient and + # receive forward activation by using bridge communicator. + with nvtx.range(f"send_backward_recv_forward_bridge_rank:{self.current_rank}:{bridge_comm.src_module_name}->{bridge_comm.dest_module_name}"): + logging.debug(f'[Rank {dist.get_rank()} ][MultiModulePipelineCommunicator] [send_backward_recv_forward] bridge [src - {bridge_comm.src_module_name}] [dest - {bridge_comm.dest_module_name}], module_name: {module_name} grad_dict grad shape: {grad_dict[bridge_comm.src_module_name].shape} sum {grad_dict[bridge_comm.src_module_name].sum()}') + grad_to_send = _ensure_3d_tensor(grad_dict[bridge_comm.src_module_name]) + received_tensor = bridge_comm.send_backward_recv_forward(grad_to_send) + input_dict[bridge_comm.src_module_name] = _restore_tensor_shape(received_tensor) + logging.debug(f'[Rank {dist.get_rank()} ][MultiModulePipelineCommunicator] [send_backward_recv_forward] bridge [src - {bridge_comm.src_module_name}] [dest - {bridge_comm.dest_module_name}], module_name: {module_name} input_dict tensor shape: {input_dict[bridge_comm.src_module_name].shape} sum {input_dict[bridge_comm.src_module_name].sum()} DONE') + else: + # If not first stage, send backward gradient and receive forward activation + # by using P2P communicator. 
+ with nvtx.range(f"send_backward_recv_forward_p2p_rank:{self.current_rank}:{module_name}"): + logging.debug(f'[Rank {dist.get_rank()} ][MultiModulePipelineCommunicator] [send_backward_recv_forward] p2p comm module_name: {module_name} grad_dict grad shape: {grad_dict[module_name].shape} sum {grad_dict[module_name].sum()}') + grad_to_send = _ensure_3d_tensor(grad_dict[module_name]) + received_tensor = rank_module_info.p2p_communicator.send_backward_recv_forward( + grad_to_send, tensor_shapes=tensor_shape, is_first_stage=False ) - ) - else: - # If not first stage, send backward gradient and receive forward activation - # by using P2P communicator. - input_dict[module_name] = ( - rank_module_info.p2p_communicator.send_backward_recv_forward( - grad_dict[module_name], tensor_shapes=tensor_shape, is_first_stage=False - ) - ) - return input_dict + input_dict[module_name] = _restore_tensor_shape(received_tensor) + logging.debug(f'[Rank {dist.get_rank()} ][MultiModulePipelineCommunicator] [send_backward_recv_forward] p2p comm module_name: {module_name} input_dict tensor shape: {input_dict[module_name].shape} sum {input_dict[module_name].sum()} DONE') + return input_dict def recv_backward( self, tensor_shape: Optional[Shape] = None, is_last_stage: bool = False @@ -379,23 +411,29 @@ def recv_backward( Returns: A dictionary mapping module names to tensors. """ - logging.debug( - f"[Rank {dist.get_rank()} ][MultiModulePipelineCommunicator] " - f"[recv_backward] tensor_shape: {tensor_shape}, is_last_stage: {is_last_stage}" - ) - grad_dict = {} - for module_name, rank_module_info in self.rank_module_map.items(): - if rank_module_info.pp_rank == rank_module_info.pp_size - 1: - # If last stage, and has incoming modules, receive backward gradient - # by using bridge communicator. 
- for bridge_comm in rank_module_info.bridge_comms_as_src_module: - grad_dict[bridge_comm.src_module_name] = bridge_comm.recv_backward() - else: - # If not last stage, receive backward gradient by using P2P communicator. - grad_dict[module_name] = rank_module_info.p2p_communicator.recv_backward( - tensor_shapes=tensor_shape, is_last_stage=False - ) - return grad_dict + with nvtx.range(f"MultiModulePipelineCommunicator.recv_backward_rank:{self.current_rank}"): + grad_dict = {} + for module_name, rank_module_info in self.rank_module_map.items(): + if rank_module_info.pp_stage == rank_module_info.pp_size - 1: + # If last stage, and has incoming modules, receive backward gradient + # by using bridge communicator. + for bridge_comm in rank_module_info.bridge_comms_as_src_module: + with nvtx.range(f"recv_backward_bridge_rank:{self.current_rank}:{bridge_comm.src_module_name}->{bridge_comm.dest_module_name}"): + logging.debug(f'[Rank {dist.get_rank()} ][MultiModulePipelineCommunicator] [recv_backward] bridge [src - {bridge_comm.src_module_name}] [dest - {bridge_comm.dest_module_name}], module_name: {module_name} ') + grad = bridge_comm.recv_backward() + grad_dict[bridge_comm.src_module_name] = _restore_tensor_shape(grad) + # time.sleep(10) + logging.debug(f'[Rank {dist.get_rank()} ][MultiModulePipelineCommunicator] [recv_backward] bridge [src - {bridge_comm.src_module_name}] [dest - {bridge_comm.dest_module_name}], module_name: {module_name} grad shape: {grad_dict[bridge_comm.src_module_name].shape} sum {grad_dict[bridge_comm.src_module_name].sum()} DONE') + else: + # If not last stage, receive backward gradient by using P2P communicator. 
+ with nvtx.range(f"recv_backward_p2p_rank:{self.current_rank}:{module_name}"): + logging.debug(f'[Rank {dist.get_rank()} ][MultiModulePipelineCommunicator] [recv_backward] p2p comm module_name: {module_name} ') + grad = rank_module_info.p2p_communicator.recv_backward( + tensor_shapes=tensor_shape, is_last_stage=False + ) + grad_dict[module_name] = _restore_tensor_shape(grad) + logging.debug(f'[Rank {dist.get_rank()} ][MultiModulePipelineCommunicator] [recv_backward] p2p comm module_name: {module_name} DONE') + return grad_dict def send_backward(self, grad_dict: Dict[str, torch.Tensor], is_first_stage: bool = False): """Send backward activation tensor. @@ -403,21 +441,24 @@ def send_backward(self, grad_dict: Dict[str, torch.Tensor], is_first_stage: bool Args: grad_dict: A dictionary mapping module names to tensors. """ - logging.debug( - f"[Rank {dist.get_rank()} ][MultiModulePipelineCommunicator] " - f"[send_backward] grad_dict keys: {grad_dict.keys()}, is_first_stage: {is_first_stage}" - ) - for module_name, rank_module_info in self.rank_module_map.items(): - if rank_module_info.pp_rank == 0: - # If first stage, and has incoming modules, send backward activation - # by using bridge communicator. - for bridge_comm in rank_module_info.bridge_comms_as_dest_module: - bridge_comm.send_backward(grad_dict[bridge_comm.src_module_name]) - else: - # If not first stage, send backward activation by using P2P communicator. - rank_module_info.p2p_communicator.send_backward( - grad_dict[module_name], is_first_stage=False - ) + with nvtx.range(f"MultiModulePipelineCommunicator.send_backward_rank:{self.current_rank}"): + for module_name, rank_module_info in self.rank_module_map.items(): + if rank_module_info.pp_stage == 0: + # If first stage, and has incoming modules, send backward activation + # by using bridge communicator. 
+ for bridge_comm in rank_module_info.bridge_comms_as_dest_module: + with nvtx.range(f"send_backward_bridge_rank:{self.current_rank}:{bridge_comm.src_module_name}->{bridge_comm.dest_module_name}"): + logging.debug(f'[Rank {dist.get_rank()} ][MultiModulePipelineCommunicator] [send_backward] bridge [src - {bridge_comm.src_module_name}] [dest - {bridge_comm.dest_module_name}], module_name: {module_name} grad shape: {grad_dict[bridge_comm.src_module_name].shape} sum {grad_dict[bridge_comm.src_module_name].sum()}') + grad_to_send = _ensure_3d_tensor(grad_dict[bridge_comm.src_module_name]) + bridge_comm.send_backward(grad_to_send) + logging.debug(f'[Rank {dist.get_rank()} ][MultiModulePipelineCommunicator] [send_backward] bridge [src - {bridge_comm.src_module_name}] [dest - {bridge_comm.dest_module_name}], module_name: {module_name} DONE') + else: + # If not first stage, send backward activation by using P2P communicator. + with nvtx.range(f"send_backward_p2p_rank:{self.current_rank}:{module_name}"): + grad_to_send = _ensure_3d_tensor(grad_dict[module_name]) + rank_module_info.p2p_communicator.send_backward( + grad_to_send, is_first_stage=False + ) @staticmethod def compute_total_pipeline_stages( @@ -446,7 +487,7 @@ def compute_total_pipeline_stages( If ``rank`` is provided, the result is the total number of pipeline stages up to (and including) the PP stage that ``rank`` occupies inside its module. In this case, the - weight of the target module equals (pp_rank_index(rank) + 1) instead of the module's + weight of the target module equals (pp_stage_index(rank) + 1) instead of the module's full PP size; other modules still contribute their full PP sizes. If the rank belongs to multiple modules (colocation), pass ``module_name`` to disambiguate; otherwise the maximum across all candidate modules containing the rank is returned. 
diff --git a/megatron/core/pipeline_parallel/p2p_communication.py b/megatron/core/pipeline_parallel/p2p_communication.py index 63ee9d1f537..0713010862f 100644 --- a/megatron/core/pipeline_parallel/p2p_communication.py +++ b/megatron/core/pipeline_parallel/p2p_communication.py @@ -7,6 +7,7 @@ import torch.distributed as dist from megatron.core.model_parallel_config import ModelParallelConfig +from megatron.core.pipeline_parallel.utils import is_pp_first_stage, is_pp_last_stage from megatron.core.utils import nvtx_decorator # Types @@ -162,6 +163,21 @@ def __init__(self, pp_group: dist.ProcessGroup, config: ModelParallelConfig): else None ) + @property + def is_pp_first_stage(self): + """Return True if pp first stage.""" + return is_pp_first_stage(self.pp_group) + + @property + def is_pp_last_stage(self): + """Return True if pp last stage.""" + return is_pp_last_stage(self.pp_group) + + @property + def num_warmup_microbatches(self): + """Return number of warmup microbatches.""" + return self.pp_group.size() - self.pp_group.rank() - 1 + def _communicate_shapes(self, tensor_send_next, tensor_send_prev, recv_prev, recv_next): """Communicate tensor shapes between stages. Used to communicate tensor shapes before the actual tensor communication happens. 
@@ -214,22 +230,22 @@ def _communicate_shapes(self, tensor_send_next, tensor_send_prev, recv_prev, rec ops = [] if send_prev_shape_tensor is not None: send_prev_op = torch.distributed.P2POp( - torch.distributed.isend, send_prev_shape_tensor, self.prev_rank + torch.distributed.isend, send_prev_shape_tensor, self.prev_rank, self.pp_group ) ops.append(send_prev_op) if recv_prev_shape_tensor is not None: recv_prev_op = torch.distributed.P2POp( - torch.distributed.irecv, recv_prev_shape_tensor, self.prev_rank + torch.distributed.irecv, recv_prev_shape_tensor, self.prev_rank, self.pp_group ) ops.append(recv_prev_op) if send_next_shape_tensor is not None: send_next_op = torch.distributed.P2POp( - torch.distributed.isend, send_next_shape_tensor, self.next_rank + torch.distributed.isend, send_next_shape_tensor, self.next_rank, self.pp_group ) ops.append(send_next_op) if recv_next_shape_tensor is not None: recv_next_op = torch.distributed.P2POp( - torch.distributed.irecv, recv_next_shape_tensor, self.next_rank + torch.distributed.irecv, recv_next_shape_tensor, self.next_rank, self.pp_group ) ops.append(recv_next_op) if len(ops) > 0: diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index e83f8d90635..e4ae0a52d58 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -2,13 +2,18 @@ import contextlib from functools import partial -from typing import Callable, Iterator, List, Optional, Union +from typing import Callable, Dict, Iterator, List, Optional, Union import torch from torch.autograd.variable import Variable +import torch.distributed as dist +import logging from megatron.core import parallel_state from megatron.core.enums import ModelType +from megatron.core.pipeline_parallel.multimodule_communicator import ( + MultiModulePipelineCommunicator, +) from megatron.core.pipeline_parallel.p2p_communication import P2PCommunicator from 
megatron.core.pipeline_parallel.utils import ( is_pp_first_stage, @@ -146,6 +151,33 @@ def deallocate_output_tensor(out, deallocate_pipeline_outputs=False): out.data = torch.empty((1,), device=out.device, dtype=out.dtype) +def deallocate_output_tensor_container( + output_tensor: Union[torch.Tensor, List[torch.Tensor], Dict[str, torch.Tensor]], + deallocate_pipeline_outputs: bool = False, +): + '''Deallocate output tensor, handling both tensor and dict cases.''' + if not deallocate_pipeline_outputs: + return + + # Check if output_tensor is None or empty + if output_tensor is None: + return + if len(output_tensor) == 0: + return + + # Extract from list if needed + tensor = output_tensor[0] if isinstance(output_tensor, list) else output_tensor + + # Handle dict case - deallocate all tensor values + if isinstance(tensor, dict): + for value in tensor.values(): + if isinstance(value, torch.Tensor): + deallocate_output_tensor(value, deallocate_pipeline_outputs) + # Handle tensor case + elif isinstance(tensor, torch.Tensor): + deallocate_output_tensor(tensor, deallocate_pipeline_outputs) + + def custom_backward(output, grad_output): '''Directly call C++ autograd engine. @@ -260,9 +292,9 @@ def forward_step_calc_loss( if hasattr(config, 'num_moe_experts') and config.num_moe_experts is not None: # Calculate the loss scale based on the grad_scale_func if available, else default to 1. loss_scale = ( - config.grad_scale_func(torch.ones(1, device=output_tensor.device)) + config.grad_scale_func(torch.ones(1, device="cuda")) if config.grad_scale_func is not None - else torch.ones(1, device=output_tensor.device) + else torch.ones(1, device="cuda") ) # Set the loss scale if config.calculate_per_token_loss: @@ -428,6 +460,16 @@ def backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, c If last stage, output_tensor_grad is None, otherwise gradient of loss with respect to stage's output tensor. 
+ Supports both tensor and dictionary formats: + + - Tensor format (legacy): input_tensor, output_tensor, output_tensor_grad are tensors/lists + - Dictionary format (multi module case): tensors are dictionaries with module names as keys + + - input_tensor: dict with module names as keys + - output_tensor: dict with module names as keys (or scalar loss for last stage) + - output_tensor_grad: dict with module names as keys (or None for last stage) + - Returns: input_tensor_grad as dict with same keys as input_tensor + Returns gradient of loss with respect to input tensor (None if first stage).""" @@ -438,6 +480,24 @@ def backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, c if config.timers is not None: config.timers('backward-compute', log_level=2).start() + # Detect if we're using dictionary format + is_dict_format = isinstance(input_tensor, dict) or isinstance(output_tensor, dict) + + if is_dict_format: + # Handle dictionary format for multi-module pipeline + return _backward_step_dict( + input_tensor, output_tensor, output_tensor_grad, model_type, config + ) + else: + # Handle legacy tensor format + return _backward_step_tensor( + input_tensor, output_tensor, output_tensor_grad, model_type, config + ) + + +def _backward_step_tensor(input_tensor, output_tensor, output_tensor_grad, model_type, config): + """Backward step implementation when inputs/outputs are tensors.""" + # Retain the grad on the input_tensor. 
unwrap_input_tensor_grad = False if not isinstance(input_tensor, list): @@ -486,6 +546,65 @@ def backward_step(input_tensor, output_tensor, output_tensor_grad, model_type, c return input_tensor_grad +def _backward_step_dict(input_tensor, output_tensor, output_tensor_grad, model_type, config): + """Backward step implementation when inputs/outputs are dictionaries (multi module case).""" + + # Retain gradients on all input tensors + for module_name, tensor in input_tensor.items(): + if isinstance(tensor, list): + tensor = tensor[0] + if tensor is not None: + tensor.retain_grad() + + # Last stage: output_tensor is a scalar loss, wrap in dict for uniform handling + # Use the first input tensor key as the main module name + # for now last stage only has one module LLM + if not isinstance(output_tensor, dict): + all_keys = list(input_tensor.keys()) + assert len(all_keys) == 1, "Last stage only has one module - LLM" + main_module_key = all_keys[0] + output_tensor = {main_module_key: output_tensor} + + # Handle output_tensor_grad: None (last stage) or dict (intermediate stages) + if not output_tensor_grad: + # Last stage: no gradient from next stage + output_tensor_grad = {key: None for key in output_tensor.keys()} + + # Apply grad scaling if needed (for last stage) + for module_name in output_tensor.keys(): + if output_tensor_grad[module_name] is None and config.grad_scale_func is not None: + output_tensor[module_name] = config.grad_scale_func(output_tensor[module_name]) + + # Perform backward pass for each module + for module_name in output_tensor.keys(): + output_tensor_module = output_tensor[module_name] + output_tensor_grad_module = output_tensor_grad[module_name] + + # Skip backward if tensor doesn't require gradients + if output_tensor_module is not None and output_tensor_module.requires_grad: + if config.deallocate_pipeline_outputs: + custom_backward(output_tensor_module, output_tensor_grad_module) + else: + torch.autograd.backward( + output_tensor_module, 
grad_tensors=output_tensor_grad_module + ) + + # Collect gradients for input tensors + input_tensor_grad = {} + for module_name, tensor in input_tensor.items(): + if isinstance(tensor, list): + tensor = tensor[0] + if tensor is None: + input_tensor_grad[module_name] = None + else: + input_tensor_grad[module_name] = tensor.grad + + if config.timers is not None: + config.timers('backward-compute').stop() + + return input_tensor_grad + + def check_first_val_step(first_val_step, forward_only, cond): """Check if it is the first validation step.""" if (first_val_step is not None) and forward_only: @@ -1926,8 +2045,8 @@ def get_tensor_shapes( micro_batch_size: int, decoder_seq_length: int, config, - tp_group: torch.distributed.ProcessGroup, - cp_group: torch.distributed.ProcessGroup, + tp_group: Optional[torch.distributed.ProcessGroup] = None, + cp_group: Optional[torch.distributed.ProcessGroup] = None, ): """ Determine right tensor sizes (based on position of rank with respect to split rank) and @@ -1935,15 +2054,23 @@ def get_tensor_shapes( """ tensor_shapes = [] - # Use decoder_seq_length if provided, otherwise use seq_length - effective_seq_length = decoder_seq_length if decoder_seq_length is not None else seq_length - effective_seq_length = effective_seq_length // cp_group.size() + if config.variable_seq_lengths: + # this is actually not used + # with variable seq_lengths, ranks exchange the tensor shape with each other + tensor_shapes.append(()) + return tensor_shapes + else: + # Use decoder_seq_length if provided, otherwise use seq_length + assert cp_group is not None, "cp_group is required for non-variable seq_lengths" + assert tp_group is not None, "tp_group is required for non-variable seq_lengths" + effective_seq_length = decoder_seq_length if decoder_seq_length is not None else seq_length + effective_seq_length = effective_seq_length // cp_group.size() - if config.sequence_parallel: - effective_seq_length = effective_seq_length // tp_group.size() + if 
config.sequence_parallel: + effective_seq_length = effective_seq_length // tp_group.size() - tensor_shapes.append((effective_seq_length, micro_batch_size, config.hidden_size)) - return tensor_shapes + tensor_shapes.append((effective_seq_length, micro_batch_size, config.hidden_size)) + return tensor_shapes def forward_backward_pipelining_without_interleaving( @@ -1959,8 +2086,8 @@ def forward_backward_pipelining_without_interleaving( collect_non_loss_data: bool = False, first_val_step: Optional[bool] = None, adjust_tensor_shapes_fn: Optional[Callable] = None, - p2p_communicator: Optional[P2PCommunicator] = None, - pg_collection: Optional[ProcessGroupCollection] = None, + p2p_communicator: Optional[Union[P2PCommunicator, MultiModulePipelineCommunicator]] = None, + pg_collection: Optional[Union[ProcessGroupCollection, List[ProcessGroupCollection]]] = None, ): """Run non-interleaved 1F1B schedule, with communication between pipeline stages. Returns dictionary with losses if the last stage, empty dict otherwise.""" @@ -1981,7 +2108,7 @@ def forward_backward_pipelining_without_interleaving( raise ValueError( "Non-interleaved pipeline parallelism does not support overlapping p2p communication" ) - + tp_group, cp_group = None, None if p2p_communicator is None and pg_collection is None: p2p_communicator = P2PCommunicator( pp_group=parallel_state.get_pipeline_model_parallel_group(), config=config @@ -2001,33 +2128,21 @@ def forward_backward_pipelining_without_interleaving( pg_collection.dp_cp = parallel_state.get_data_parallel_group( with_context_parallel=True, partial_data_parallel=False ) + llm_cp_size = cp_group.size() elif p2p_communicator is not None and pg_collection is not None: - model_type = get_model_type(model) - assert model_type != ModelType.encoder_and_decoder, ( - "encoder PP stages not yet supported when passing custom process groups. " - "support coming soon!" 
- ) - assert hasattr(p2p_communicator, 'config'), "p2p_communicator must have a config" - assert hasattr(pg_collection, 'tp'), "pg_collection must have tp_group" - assert hasattr(pg_collection, 'cp'), "pg_collection must have cp_group" - assert hasattr(pg_collection, 'embd'), ( - "pg_collection must have a embd. In previous version, it is used default " - "`parallel_state.default_embedding_ranks` to create the process group. " - " If you are using the default process group, please use " - " `parallel_state.get_embedding_group()` " - "If you don't need embd_group, you need to explicitly set it to None." - ) - assert hasattr(pg_collection, 'pos_embd'), ( - "pg_collection must have a pos_embd. In previous version, it is used default " - "`parallel_state.default_position_embedding_ranks` to create the process group. " - " If you are using the default process group, please use " - " `parallel_state.get_position_embedding_group()` " - "If you don't need pos_embd_group, you need to explicitly set it to None." 
- ) - assert hasattr(pg_collection, 'pp'), "pg_collection must have pp_group" - assert hasattr(pg_collection, 'dp_cp'), "pg_collection must have dp_cp_group" - tp_group = pg_collection.tp - cp_group = pg_collection.cp + if isinstance(pg_collection, list): + # cases when multiple modules are colocated + assert config.variable_seq_lengths, "variable seq_lengths is required when multiple modules are colocated" + # when llm is colocated for now assume last collection in the list is the llm + # TODO: ykarnati: Have a better interface to handle this (without breaking backward compatibility) + assert hasattr(pg_collection[-1], 'cp'), "pg_collection must have cp_group" + llm_cp_size = pg_collection[-1].cp.size() + else: + assert hasattr(pg_collection, 'tp'), "pg_collection must have tp_group" + assert hasattr(pg_collection, 'cp'), "pg_collection must have cp_group" + tp_group = pg_collection.tp + cp_group = pg_collection.cp + llm_cp_size = pg_collection.cp.size() else: raise ValueError( "Invalid combination of p2p_communicator, pg_collection " @@ -2037,7 +2152,7 @@ def forward_backward_pipelining_without_interleaving( # Needed only when gradients are finalized in M-Core if config.finalize_model_grads_func is not None and not forward_only: embedding_module = clear_embedding_activation_buffer( - config, model, is_pp_last_stage(p2p_communicator.pp_group) + config, model, p2p_communicator.is_pp_last_stage ) if config.timers is not None: @@ -2066,9 +2181,7 @@ def enable_grad_sync(): disable_grad_sync() # Compute number of warmup microbatches. 
- num_warmup_microbatches = ( - p2p_communicator.pp_group.size() - p2p_communicator.pp_group.rank() - 1 - ) + num_warmup_microbatches = p2p_communicator.num_warmup_microbatches num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches) num_microbatches_remaining = num_microbatches - num_warmup_microbatches @@ -2086,7 +2199,7 @@ def enable_grad_sync(): model_type = get_model_type(model) - rank = p2p_communicator.pp_group.rank() + # rank = p2p_communicator.pp_group.rank() recv_tensor_shapes = get_tensor_shapes( seq_length=seq_length, micro_batch_size=micro_batch_size, @@ -2118,9 +2231,11 @@ def enable_grad_sync(): output_tensors = [] forward_data_store = [] + logging.debug(f"[Rank {dist.get_rank()} ][forward_backward_pipelining_without_interleaving] [warmup] num_warmup_microbatches {num_warmup_microbatches} num_microbatches_remaining {num_microbatches_remaining}") # Run warmup forward passes. for i in range(num_warmup_microbatches): # Decide to checkpoint all layers' activations of the current micro-batch + logging.debug(f"[Rank {dist.get_rank()} ][forward_backward_pipelining_without_interleaving] [warmup] currennt microbatch index i {i} ]") if max_outstanding_backprops is not None: checkpoint_activations_microbatch = ( i % max_outstanding_backprops @@ -2128,10 +2243,11 @@ def enable_grad_sync(): ) else: checkpoint_activations_microbatch = None - + logging.debug(f'[Rank {dist.get_rank()} ][forward_backward_pipelining_without_interleaving] [warmup] [recv_forward]') input_tensor = p2p_communicator.recv_forward( - recv_tensor_shapes, is_pp_first_stage(p2p_communicator.pp_group) + recv_tensor_shapes, p2p_communicator.is_pp_first_stage ) + logging.debug(f"[Rank {dist.get_rank()} ][forward_backward_pipelining_without_interleaving] [warmup] [forward_step]") output_tensor, num_tokens = forward_step( forward_step_func, data_iterator, @@ -2140,31 +2256,34 @@ def enable_grad_sync(): input_tensor, forward_data_store, config, - cp_group_size=pg_collection.cp.size(), 
+ cp_group_size=llm_cp_size, collect_non_loss_data=collect_non_loss_data, checkpoint_activations_microbatch=checkpoint_activations_microbatch, is_first_microbatch=check_first_val_step(first_val_step, forward_only, i == 0), current_microbatch=i, - is_last_stage=is_pp_last_stage(p2p_communicator.pp_group), + is_last_stage=p2p_communicator.is_pp_last_stage, ) - p2p_communicator.send_forward(output_tensor, is_pp_last_stage(p2p_communicator.pp_group)) + logging.debug(f'[Rank {dist.get_rank()} ][forward_backward_pipelining_without_interleaving] [warmup] [send_forward]') + p2p_communicator.send_forward(output_tensor, p2p_communicator.is_pp_last_stage) total_num_tokens += num_tokens if not forward_only: input_tensors.append(input_tensor) output_tensors.append(output_tensor) - deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs) + deallocate_output_tensor_container(output_tensor, config.deallocate_pipeline_outputs) # Before running 1F1B, need to receive first forward tensor. # If all microbatches are run in warmup / cooldown phase, then no need to # receive this tensor here. if num_microbatches_remaining > 0: + logging.debug(f'[Rank {dist.get_rank()} ][forward_backward_pipelining_without_interleaving] [after warmup] [recv_forward]') input_tensor = p2p_communicator.recv_forward( - recv_tensor_shapes, is_pp_first_stage(p2p_communicator.pp_group) + recv_tensor_shapes, p2p_communicator.is_pp_first_stage ) # Run 1F1B in steady state. 
for i in range(num_microbatches_remaining): + logging.debug(f"[Rank {dist.get_rank()} ][forward_backward_pipelining_without_interleaving] [steady state] current microbatch index i {i} num_microbatches_remaining {num_microbatches_remaining} num_warmup_microbatches {num_warmup_microbatches}") last_iteration = i == (num_microbatches_remaining - 1) # Decide to checkpoint all layers' activations of the current micro-batch @@ -2174,7 +2293,7 @@ def enable_grad_sync(): ) >= config.num_microbatches_with_partial_activation_checkpoints else: checkpoint_activations_microbatch = None - + logging.debug(f"[Rank {dist.get_rank()} ][forward_backward_pipelining_without_interleaving] [steady state] [forward_step]") output_tensor, num_tokens = forward_step( forward_step_func, data_iterator, @@ -2183,34 +2302,34 @@ def enable_grad_sync(): input_tensor, forward_data_store, config, - cp_group_size=pg_collection.cp.size(), + cp_group_size=llm_cp_size, collect_non_loss_data=collect_non_loss_data, checkpoint_activations_microbatch=checkpoint_activations_microbatch, is_first_microbatch=check_first_val_step( first_val_step, forward_only, (i == 0) and (num_warmup_microbatches == 0) ), current_microbatch=i + num_warmup_microbatches, - is_last_stage=is_pp_last_stage(p2p_communicator.pp_group), + is_last_stage=p2p_communicator.is_pp_last_stage, ) total_num_tokens += num_tokens if forward_only: - p2p_communicator.send_forward( - output_tensor, is_pp_last_stage(p2p_communicator.pp_group) - ) + logging.debug(f'[Rank {dist.get_rank()} ][forward_backward_pipelining_without_interleaving] [steady state] [send_forward]') + p2p_communicator.send_forward(output_tensor, p2p_communicator.is_pp_last_stage) if not last_iteration: input_tensor = p2p_communicator.recv_forward( - recv_tensor_shapes, is_pp_first_stage(p2p_communicator.pp_group) + recv_tensor_shapes, p2p_communicator.is_pp_first_stage ) else: + logging.debug(f'[Rank {dist.get_rank()} ][forward_backward_pipelining_without_interleaving] [steady 
state] [send_forward_recv_backward]') output_tensor_grad = p2p_communicator.send_forward_recv_backward( - output_tensor, send_tensor_shapes, is_pp_last_stage(p2p_communicator.pp_group) + output_tensor, send_tensor_shapes, p2p_communicator.is_pp_last_stage ) # Add input_tensor and output_tensor to end of list. input_tensors.append(input_tensor) output_tensors.append(output_tensor) - deallocate_output_tensor(output_tensor[0], config.deallocate_pipeline_outputs) + deallocate_output_tensor_container(output_tensor, config.deallocate_pipeline_outputs) # Pop input_tensor and output_tensor from the start of the list for # the backward pass. @@ -2220,28 +2339,29 @@ def enable_grad_sync(): # Enable grad sync for the last microbatch in the batch if the full # backward pass completes in the 1F1B stage. if num_warmup_microbatches == 0 and last_iteration: - if config.grad_sync_func is None or rank == 0: + if config.grad_sync_func is None or p2p_communicator.is_pp_first_stage: enable_grad_sync() - + logging.debug(f'[Rank {dist.get_rank()} ][forward_backward_pipelining_without_interleaving] [steady state] [backward_step] num_warmup_microbatches {num_warmup_microbatches} last_iteration {last_iteration} no_sync_context {no_sync_context}') input_tensor_grad = backward_step( input_tensor, output_tensor, output_tensor_grad, model_type, config ) if last_iteration: input_tensor = None + logging.debug(f'[Rank {dist.get_rank()} ][forward_backward_pipelining_without_interleaving] [steady state] [send_backward]') p2p_communicator.send_backward( - input_tensor_grad, is_pp_first_stage(p2p_communicator.pp_group) + input_tensor_grad, p2p_communicator.is_pp_first_stage ) else: + logging.debug(f'[Rank {dist.get_rank()} ][forward_backward_pipelining_without_interleaving] [steady state] [send_backward_recv_forward]') input_tensor = p2p_communicator.send_backward_recv_forward( - input_tensor_grad, - recv_tensor_shapes, - is_pp_first_stage(p2p_communicator.pp_group), + input_tensor_grad, 
recv_tensor_shapes, p2p_communicator.is_pp_first_stage ) # Run cooldown backward passes. if not forward_only: for i in range(num_warmup_microbatches): + logging.debug(f"[Rank {dist.get_rank()} ][forward_backward_pipelining_without_interleaving] [cooldown] current microbatch index i {i} num_warmup_microbatches {num_warmup_microbatches}") # Enable async grad reduction in the last backward pass # Note: If grad sync function is provided, only enable @@ -2249,23 +2369,21 @@ def enable_grad_sync(): # pipeline stages do grad reduction during pipeline # bubble. if i == num_warmup_microbatches - 1: - if config.grad_sync_func is None or rank == 0: + if config.grad_sync_func is None or p2p_communicator.is_pp_first_stage: enable_grad_sync() input_tensor = input_tensors.pop(0) output_tensor = output_tensors.pop(0) - + logging.debug(f'[Rank {dist.get_rank()} ][forward_backward_pipelining_without_interleaving] [cooldown] [recv_backward]') output_tensor_grad = p2p_communicator.recv_backward( - send_tensor_shapes, is_pp_last_stage(p2p_communicator.pp_group) + send_tensor_shapes, p2p_communicator.is_pp_last_stage ) - + logging.debug(f'[Rank {dist.get_rank()} ][forward_backward_pipelining_without_interleaving] [cooldown] [backward_step]') input_tensor_grad = backward_step( input_tensor, output_tensor, output_tensor_grad, model_type, config ) - - p2p_communicator.send_backward( - input_tensor_grad, is_pp_first_stage(p2p_communicator.pp_group) - ) + logging.debug(f'[Rank {dist.get_rank()} ][forward_backward_pipelining_without_interleaving] [cooldown] [send_backward]') + p2p_communicator.send_backward(input_tensor_grad, p2p_communicator.is_pp_first_stage) # Launch any remaining grad reductions. if no_sync_context is not None: @@ -2278,7 +2396,7 @@ def enable_grad_sync(): # If defer_embedding_wgrad_compute is enabled we need to do the # weight gradient GEMM's here. 
finish_embedding_wgrad_compute( - config, embedding_module, is_pp_last_stage(p2p_communicator.pp_group), tp_group + config, embedding_module, p2p_communicator.is_pp_last_stage, tp_group ) # Finalize model grads (perform full grad all-reduce / reduce-scatter for diff --git a/megatron/core/process_groups_config.py b/megatron/core/process_groups_config.py index 07c922ea685..d55b705a68a 100644 --- a/megatron/core/process_groups_config.py +++ b/megatron/core/process_groups_config.py @@ -140,6 +140,18 @@ def __init__(self, **kwargs): else: raise ValueError(f"Unknown attribute: {key}") + def __repr__(self): + """Return a concise representation showing which process groups exist and their sizes.""" + active_pgs = [] + for field_info in fields(self): + try: + pg = getattr(self, field_info.name, None) + if pg is not None and hasattr(pg, 'size'): + active_pgs.append(f"{field_info.name}({pg.size()})") + except AttributeError: + continue + return f"ProcessGroupCollection({', '.join(active_pgs)})" if active_pgs else "ProcessGroupCollection(empty)" + @classmethod def use_mpu_process_groups(cls, required_pgs: Optional[List[str]] = None): """ diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py index aead6133f22..9a6baaff4d4 100755 --- a/megatron/core/transformer/transformer_block.py +++ b/megatron/core/transformer/transformer_block.py @@ -555,7 +555,7 @@ def __call__(self, *args, **kwargs): def forward( self, hidden_states: Union[Tensor, WrappedTensor], - attention_mask: Optional[Tensor], + attention_mask: Optional[Tensor] = None, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, rotary_pos_emb: Optional[Tensor] = None, diff --git a/run.sh b/run.sh new file mode 100644 index 00000000000..333c3606c5a --- /dev/null +++ b/run.sh @@ -0,0 +1,3 @@ +export PYTHONPATH=${PWD}:${PYTHONPATH} +mkdir -p ../logs +torchrun --nproc_per_node=3 tests/unit_tests/models/heterogenous_parallel/train.py 2>&1 | tee 
# --- tests/unit_tests/models/heterogenous_parallel/data.py (new file) ---
from typing import Any, Dict, Iterator, List

import torch
from torch.utils.data import DataLoader

from examples.mimo.data.mock import MockVLMDataset
from megatron.core.pipeline_parallel.utils import is_pp_first_stage, is_pp_last_stage
from tests.unit_tests.models.heterogenous_parallel.parallel_utils import is_current_rank_in_grid


def _collate_fn(
    batch: List[Dict], image_seq_length: int = 1024, hidden_size: int = 1024
) -> Dict[str, torch.Tensor]:
    """Collate function for the DataLoader.

    Stacks the per-sample text tensors and synthesizes a random vision-encoder
    activation for the whole batch (this mock pipeline feeds precomputed
    ``hidden_states`` straight into the vision submodule instead of raw images).

    Args:
        batch: List of sample dictionaries produced by ``MockVLMDataset``.
        image_seq_length: Sequence length of the image token span.
        hidden_size: Hidden size of the mocked vision encoder output.

    Returns:
        Dictionary of batched tensors; modality inputs are nested under
        ``modality_inputs -> images -> clip_encoder``.
    """
    input_ids = torch.stack([item["input_ids"] for item in batch])
    labels = torch.stack([item["labels"] for item in batch])
    loss_mask = torch.stack([item["loss_mask"] for item in batch])
    position_ids = torch.stack([item["position_ids"] for item in batch])

    bsz = input_ids.shape[0]

    return {
        "input_ids": input_ids,
        "labels": labels,
        "loss_mask": loss_mask,
        "position_ids": position_ids,
        "modality_inputs": {
            "images": {
                "clip_encoder": {
                    # (seq, batch, hidden) layout per the randn argument order.
                    # NOTE(review): sampled fresh on every call with the global RNG,
                    # so batches are not reproducible across runs — confirm intended.
                    "hidden_states": torch.randn(
                        image_seq_length, bsz, hidden_size, dtype=torch.bfloat16
                    ),
                    "attention_mask": None,
                }
            }
        },
    }


def move_to_device(data, device):
    """Recursively move tensors in nested containers to ``device``.

    Handles tensors and dicts (as before) and, as a generalization, lists and
    tuples; any other object is returned unchanged.
    """
    if isinstance(data, torch.Tensor):
        return data.to(device)
    if isinstance(data, dict):
        return {key: move_to_device(value, device) for key, value in data.items()}
    if isinstance(data, (list, tuple)):
        return type(data)(move_to_device(item, device) for item in data)
    return data


def get_data_iterator(
    encoder_grid,
    llm_grid,
    image_seq_length,
    seq_length,
    image_special_token_id,
    batch_size,
    vocab_size,
    vision_hidden_size,
):
    """Build a dataloader iterator only on ranks that actually consume data.

    An iterator is created on the first PP stage of the encoder grid and on the
    first/last PP stages of the LLM grid; every other rank returns ``None``.

    Args:
        encoder_grid: HyperCommGrid of the vision encoder module.
        llm_grid: HyperCommGrid of the language module.
        image_seq_length: Image token span length passed to the dataset/collate.
        seq_length: Total sequence length for the mock dataset.
        image_special_token_id: Token id marking image positions.
        batch_size: Per-rank micro-batch size.
        vocab_size: Unused here; kept for interface compatibility with callers.
        vision_hidden_size: Hidden size used for the mocked encoder activations.

    Returns:
        ``iter(DataLoader)`` on data-consuming ranks, else ``None``.
    """
    data_iterator = None

    encoder_1_condition = is_current_rank_in_grid(encoder_grid) and is_pp_first_stage(
        encoder_grid.get_pg("pp")
    )
    llm_condition = is_current_rank_in_grid(llm_grid) and (
        is_pp_first_stage(llm_grid.get_pg("pp")) or is_pp_last_stage(llm_grid.get_pg("pp"))
    )

    if encoder_1_condition or llm_condition:
        dataset = MockVLMDataset(
            size=256,
            image_size=224,
            seq_len=seq_length,
            image_seq_length=image_seq_length,
            pad_token_id=0,
            image_token_id=image_special_token_id,
        )
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,  # deterministic order across ranks
            num_workers=0,
            collate_fn=lambda batch: _collate_fn(
                batch, image_seq_length=image_seq_length, hidden_size=vision_hidden_size
            ),
        )
        data_iterator = iter(dataloader)
    return data_iterator


def get_batch(data_iterator: Iterator[Dict[str, Any]]):
    """Fetch the next batch and move it to the current CUDA device.

    Returns ``None`` on ranks without a data iterator. ``next()`` yields a
    collated dict and never ``None``, so the previous post-``next`` None-check
    was dead code and has been removed; ``StopIteration`` on dataset exhaustion
    is deliberately allowed to propagate rather than silently reusing data.
    """
    if data_iterator is None:
        return None
    batch = next(data_iterator)
    return move_to_device(batch, torch.device("cuda"))
# --- tests/unit_tests/models/heterogenous_parallel/model_specs.py (new file) ---
import torch
import torch.distributed as dist

from examples.mimo.configs.llava_vlm import (
    get_llava_projection_config,
    get_llava_projection_layer_spec,
)
from megatron.core.distributed import DistributedDataParallel, DistributedDataParallelConfig
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec
from megatron.core.models.gpt.gpt_model import GPTModel
from megatron.core.models.mimo.config.base_configs import MimoModelConfig
from megatron.core.models.mimo.model.base import MimoModel
from megatron.core.models.mimo.submodules.vision import VisionModalitySubmodules
from megatron.core.models.vision.multimodal_projector import MultimodalProjector
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig
from tests.unit_tests.pipeline_parallel.test_multimodule_schedules import (
    _get_pg_collection_with_embedding_groups,
    create_hypercomm_grid,
)

# fix: absolute import for consistency — this module is imported by train.py via its
# full package path, so a bare `from parallel_utils import ...` only worked by accident
# (script-dir on sys.path when launched directly).
from tests.unit_tests.models.heterogenous_parallel.parallel_utils import _create_pg_collection


def get_language_model_spec(num_layers, num_moe_experts, hidden_size, vocab_size, seq_len, pg_collection):
    """Get the language model (GPT) ModuleSpec.

    pre_process/post_process are derived from this rank's position in the PP
    group, so only the boundary stages own embeddings / the output layer.
    """
    pp_rank = dist.get_rank(pg_collection.pp)
    # fix: pp_size was previously computed twice via two different APIs
    # (dist.get_world_size and pg.size()); compute it once.
    pp_size = pg_collection.pp.size() if pg_collection.pp is not None else 1
    tp_size = pg_collection.tp.size() if pg_collection.tp is not None else 1
    pre_process = pp_rank == 0
    post_process = pp_rank == pp_size - 1

    print(f"[get_language_model_spec] Rank {dist.get_rank()}: PP rank={pp_rank}/{pp_size}, "
          f"pre_process={pre_process}, post_process={post_process}")

    lm_config = TransformerConfig(
        num_layers=num_layers,
        num_moe_experts=num_moe_experts,
        hidden_size=hidden_size,
        num_attention_heads=4,
        use_cpu_initialization=True,
        variable_seq_lengths=True,
        moe_token_dispatcher_type='alltoall',
        tensor_model_parallel_size=tp_size,
        pipeline_model_parallel_size=pp_size,
        pipeline_dtype=torch.bfloat16,
        bf16=True,
        cross_entropy_loss_fusion=True,
        cross_entropy_fusion_impl='native',
    )
    language_layer_spec = get_gpt_layer_with_transformer_engine_spec()
    language_model_spec = ModuleSpec(
        module=GPTModel,
        params={
            "config": lm_config,
            "transformer_layer_spec": language_layer_spec,
            "vocab_size": vocab_size,
            "max_sequence_length": seq_len,
            "pre_process": pre_process,
            "post_process": post_process,
            "pg_collection": pg_collection,
        },
    )
    return language_model_spec


def get_vision_submodules_spec(num_layers, num_moe_experts, hidden_size, language_hidden_size, pg_collection):
    """Get the submodule spec for the vision modality.

    Args:
        num_layers: Number of transformer layers in the vision encoder.
        num_moe_experts: Number of MoE experts in the vision encoder.
        hidden_size: Hidden size of the vision encoder.
        language_hidden_size: Hidden size of the language model (projection output).
        pg_collection: Process group collection for the vision module.
    """
    vision_layer_spec = get_gpt_layer_with_transformer_engine_spec()

    tp_size = pg_collection.tp.size() if pg_collection.tp is not None else 1
    pp_size = pg_collection.pp.size() if pg_collection.pp is not None else 1

    vision_config = TransformerConfig(
        num_layers=num_layers,
        num_moe_experts=num_moe_experts,
        hidden_size=hidden_size,
        num_attention_heads=4,
        use_cpu_initialization=True,
        variable_seq_lengths=True,
        moe_token_dispatcher_type='alltoall',
        tensor_model_parallel_size=tp_size,
        pipeline_model_parallel_size=pp_size,
        pipeline_dtype=torch.bfloat16,
        bf16=True,
    )
    vision_encoder_spec = ModuleSpec(
        module=TransformerBlock,
        params={
            "config": vision_config,
            "spec": vision_layer_spec,
            "pg_collection": pg_collection,
            # The encoder is a single-stage block: it is both the first and
            # last stage within the vision grid's pipeline.
            "pre_process": True,
            "post_process": True,
        },
    )

    # Projects from the vision hidden size to the language hidden size.
    vision_projection_spec = ModuleSpec(
        module=MultimodalProjector,
        params={
            "config": get_llava_projection_config(
                hidden_size=language_hidden_size  # output must match the language model
            ),
            "submodules": get_llava_projection_layer_spec().submodules,
            "projector_type": "mlp",
            "input_size": vision_config.hidden_size,  # input from the vision encoder
            "tp_group": pg_collection.tp,
        },
    )

    vision_submodule_spec = ModuleSpec(
        module=VisionModalitySubmodules,
        submodules={
            "encoders": {"clip_encoder": vision_encoder_spec},
            "input_projections": [vision_projection_spec],
        },
    )

    return vision_submodule_spec


def get_vlm_mimo_model(
    vision_num_layers, vision_num_moe_experts, vision_hidden_size,
    language_num_layers, language_num_moe_experts, language_hidden_size,
    vocab_size, seq_len, special_token_ids,
    vision_tp, vision_pp, vision_dp, vision_cp, vision_ep,
    language_tp, language_pp, language_dp, language_cp, language_ep,
):
    """Build a DDP-wrapped MIMO VLM with disjoint vision/language rank grids.

    The vision grid occupies ranks [0, vision_grid_size) and the language grid
    the remaining ranks; together they must cover the whole world.

    Returns:
        (mimo_model, module_to_grid_map, topology,
         vision_pg_collection, language_pg_collection)
    """
    global_world_size = dist.get_world_size()
    vision_grid_size = vision_tp * vision_pp * vision_dp * vision_cp
    language_grid_size = language_tp * language_pp * language_dp * language_cp
    assert global_world_size == vision_grid_size + language_grid_size, (
        f"global_world_size ({global_world_size}) should be equal to "
        f"vision_grid_size ({vision_grid_size}) + language_grid_size ({language_grid_size})"
    )

    num_distributed_optimizer_instances = 1
    vision_pg_collection, vision_grid, vision_expert_grid = _create_pg_collection(
        vision_tp, vision_pp, vision_cp, vision_ep,
        num_distributed_optimizer_instances,
        rank_offset=0,
        world_size=vision_grid_size,
    )
    language_pg_collection, language_grid, language_expert_grid = _create_pg_collection(
        language_tp, language_pp, language_cp, language_ep,
        num_distributed_optimizer_instances,
        rank_offset=vision_grid_size,
        world_size=language_grid_size,
    )

    vision_submodule_spec = get_vision_submodules_spec(
        vision_num_layers, vision_num_moe_experts, vision_hidden_size,
        language_hidden_size, vision_pg_collection,
    )
    language_model_spec = get_language_model_spec(
        language_num_layers, language_num_moe_experts, language_hidden_size,
        vocab_size, seq_len, language_pg_collection,
    )

    mimo_config = MimoModelConfig(
        language_model_spec=language_model_spec,
        modality_submodules_spec={"images": vision_submodule_spec},
        special_token_ids=special_token_ids,
    )
    mimo_model = MimoModel(mimo_config)
    # TODO(shifangx): need map from model name to more than one grid,
    # or to one ProcessGroupCollection
    module_to_grid_map = {'images': vision_grid, 'language_module': language_grid}
    topology = {
        'images': ['language_module'],  # images sends forward results to language_module
        'language_module': [],          # language_module is the last stage here
    }

    mimo_model.to(torch.device("cuda")).to(torch.bfloat16)

    ddp_config = DistributedDataParallelConfig(overlap_grad_reduce=True, bucket_size=10000)
    if mimo_model.language_model is not None:
        mimo_model.language_model = DistributedDataParallel(
            config=mimo_model.language_model.config,
            ddp_config=ddp_config,
            module=mimo_model.language_model,
            pg_collection=language_pg_collection,
        )
    submodule = mimo_model.modality_submodules['images']
    if submodule is not None:
        submodule = DistributedDataParallel(
            config=submodule.encoders['clip_encoder'].config,
            ddp_config=ddp_config,
            module=submodule,
            pg_collection=vision_pg_collection,
        )
        mimo_model.modality_submodules['images'] = submodule

    return mimo_model, module_to_grid_map, topology, vision_pg_collection, language_pg_collection
# --- tests/unit_tests/models/heterogenous_parallel/parallel_utils.py (new file) ---
from contextlib import ExitStack, contextmanager
from typing import Optional

import torch
import torch.distributed as dist

from megatron.core.distributed.finalize_model_grads import (
    finalize_model_grads as _finalize_model_grads,
)
from megatron.core.hyper_comm_grid import HyperCommGrid
from megatron.core.process_groups_config import ProcessGroupCollection
from tests.unit_tests.pipeline_parallel.test_multimodule_schedules import (
    _get_pg_collection_with_embedding_groups,
    create_hypercomm_grid,
)


def is_current_rank_in_grid(grid) -> bool:
    """Check if the current rank falls inside ``grid``'s contiguous rank range."""
    return grid.rank_offset <= dist.get_rank() < (grid.rank_offset + grid.size)


def get_module_to_grid_tuple(mimo_model, vision_module_grid, language_module_grid):
    """Pair each trainable module of the MIMO model with the grid it lives on."""
    return [
        (mimo_model.modality_submodules['images'], vision_module_grid),
        (mimo_model.language_model, language_module_grid),
    ]


@contextmanager
def multimodule_no_sync(module_to_grid_tuple):
    """Disable grad sync on every module owned by the current rank.

    Uses ``ExitStack`` so that if entering a later module's ``no_sync`` context
    raises, the contexts already entered are unwound correctly (the previous
    hand-rolled enter/exit loop leaked them on partial failure).
    """
    with ExitStack() as stack:
        for module, grid in module_to_grid_tuple:
            if module is not None and is_current_rank_in_grid(grid):
                stack.enter_context(module.no_sync())
        yield


def get_pg_collections_for_rank(module_to_grid_map):
    """Get pg_collections for the modules hosted on the current rank."""
    return [
        _get_pg_collection_with_embedding_groups(grid)
        for grid in module_to_grid_map.values()
        if is_current_rank_in_grid(grid)
    ]


def finalize_model_grads(model, num_tokens=None, pg_collection=None, *, module_to_grid_tuple):
    """Call ``finalize_model_grads`` for each module hosted on the current rank.

    Args:
        model: Model list passed by the scheduler (unused — we iterate
            ``module_to_grid_tuple`` instead).
        num_tokens: Number of tokens for gradient scaling.
        pg_collection: Process group collection.
        module_to_grid_tuple: (module, grid) pairs; grads are finalized for
            each module whose grid contains the current rank.
    """
    for module, grid in module_to_grid_tuple:
        if module is not None and is_current_rank_in_grid(grid):
            _finalize_model_grads([module], num_tokens=num_tokens, pg_collection=pg_collection)


def zero_grad_buffer_for_multimodule(module_to_grid_tuple):
    """Reset gradient buffers for all DDP-wrapped modules on the current rank."""
    for module, grid in module_to_grid_tuple:
        if module is not None and is_current_rank_in_grid(grid):
            module.zero_grad_buffer()


def _create_pg_collection(
    tp_size: int,
    pp_size: int,
    cp_size: int,
    ep_size: int,
    num_distributed_optimizer_instances: int,
    rank_offset: int,
    world_size: int,
):
    """Create all process groups via HyperCommGrid for one module's rank range.

    Returns:
        (pg_collection, grid, expert_grid) — fix: the previous annotation
        claimed ``-> ProcessGroupCollection`` but a 3-tuple was returned.
    """
    model_size = tp_size * pp_size * cp_size
    if world_size % model_size != 0:
        raise RuntimeError(f"world_size ({world_size}) is not divisible by {model_size}")
    dp_size = world_size // model_size

    grid = HyperCommGrid(
        shape=[tp_size, cp_size, dp_size, pp_size],
        dim_names=["tp", "cp", "dp", "pp"],
        rank_offset=rank_offset,
        backend="nccl",
    )
    # Core groups
    tp_pg = grid.create_pg(["tp"])
    cp_pg = grid.create_pg(["cp"])
    pp_pg = grid.create_pg(["pp"])
    dp_pg = grid.create_pg(["dp"])
    mp_pg = grid.create_pg(["tp", "pp"])
    tp_cp_pg = grid.create_pg(["tp", "cp"])
    tp_dp_cp_pg = grid.create_pg(["tp", "dp", "cp"])
    dp_cp_pg = grid.create_pg(["dp", "cp"])

    # Expert/MoE related groups (refer to original parallel_state.initialize_model_parallel)
    expert_tp_size = 1  # TODO: add expert_tp_size as input argument
    # Expert data-parallel size folds CP into DP (as in original expert rank generator)
    expt_model_block = expert_tp_size * ep_size * pp_size
    if world_size % expt_model_block != 0:
        raise RuntimeError(
            f"world_size ({world_size}) is not divisible by expert_tensor_model_pipeline size ({expt_model_block})"
        )
    expt_dp_size = world_size // expt_model_block
    use_optimizer_instance_groups = num_distributed_optimizer_instances > 1
    inner_dp_dim: Optional[str] = None
    outer_dp_dim: Optional[str] = None
    if use_optimizer_instance_groups:
        assert expt_dp_size % num_distributed_optimizer_instances == 0, (
            "Expert DP size must be divisible by the number of optimizer instances."
        )
        inner_expt_dp_size = expt_dp_size // num_distributed_optimizer_instances
        expert_grid = HyperCommGrid(
            shape=[
                expert_tp_size,
                ep_size,
                inner_expt_dp_size,
                num_distributed_optimizer_instances,
                pp_size,
            ],
            dim_names=["tp", "ep", "inner_dp", "outer_dp", "pp"],
            rank_offset=rank_offset,
            backend="nccl",
        )
        dp_group_dims = ["inner_dp", "outer_dp"]
        inner_dp_dim = "inner_dp"
        outer_dp_dim = "outer_dp"
    else:
        expert_grid = HyperCommGrid(
            shape=[expert_tp_size, ep_size, expt_dp_size, pp_size],
            dim_names=["tp", "ep", "dp", "pp"],
            rank_offset=rank_offset,
            backend="nccl",
        )
        dp_group_dims = ["dp"]
    ep_pg = expert_grid.create_pg(["ep"])
    expt_tp_pg = expert_grid.create_pg(["tp"])
    tp_ep_pg = expert_grid.create_pg(["tp", "ep"])
    tp_ep_pp_pg = expert_grid.create_pg(["tp", "ep", "pp"])
    expt_dp_pg = expert_grid.create_pg(dp_group_dims)

    # Embedding and position-embedding groups
    embd_pg = None
    pos_embd_pg = None
    # Enumerate ranks per PP group; embedding ranks are the first and last pp
    # stage of each group (only the first when pp_size == 1).
    pp_rank_lists = grid._gen_rank_enum(["pp"])
    embedding_rank_lists = []
    pos_embedding_rank_lists = []
    for ranks in pp_rank_lists:
        if not ranks:
            continue
        embedding_rank_lists.append([ranks[0]] if len(ranks) == 1 else [ranks[0], ranks[-1]])
        pos_embedding_rank_lists.append([ranks[0]])
    if embedding_rank_lists:
        embd_pg, _ = torch.distributed.new_subgroups_by_enumeration(
            embedding_rank_lists, backend="nccl"
        )
    if pos_embedding_rank_lists:
        pos_embd_pg, _ = torch.distributed.new_subgroups_by_enumeration(
            pos_embedding_rank_lists, backend="nccl"
        )

    # Partial-Distributed-Optimizer groups for Expert DP with multiple instances.
    intra_expt_dp_pg = None
    inter_dist_opt_pg = None
    intra_dist_opt_pg = None
    if inner_dp_dim is not None and outer_dp_dim is not None:
        intra_expt_dp_pg = expert_grid.create_pg([inner_dp_dim])
        inter_dist_opt_pg = expert_grid.create_pg([outer_dp_dim])
        # Match distributed optimizer instance grouping from parallel_state:
        # combine tp-ep-pp ranks across the intra-partial DP slice.
        intra_dist_opt_pg = expert_grid.create_pg(["tp", "ep", inner_dp_dim, "pp"])

    pg_collection = ProcessGroupCollection(
        tp=tp_pg,
        pp=pp_pg,
        mp=mp_pg,
        embd=embd_pg,
        pos_embd=pos_embd_pg,
        cp=cp_pg,
        tp_cp=tp_cp_pg,
        hcp=None,
        ep=ep_pg,
        expt_tp=expt_tp_pg,
        tp_ep=tp_ep_pg,
        tp_ep_pp=tp_ep_pp_pg,
        tp_dp_cp=tp_dp_cp_pg,
        dp=dp_pg,
        dp_cp=dp_cp_pg,
        expt_dp=expt_dp_pg,
        intra_dp_cp=dp_cp_pg,
        intra_expt_dp=intra_expt_dp_pg if intra_expt_dp_pg is not None else expt_dp_pg,
        inter_dist_opt=inter_dist_opt_pg,
        intra_dist_opt=intra_dist_opt_pg,
    )
    return pg_collection, grid, expert_grid


# --- tests/unit_tests/models/heterogenous_parallel/train.py (new file, prologue) ---
import logging
from functools import partial

logging.basicConfig(
    level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', force=True
)

from tests.unit_tests.test_utilities import Utils
from tests.unit_tests.models.heterogenous_parallel.model_specs import get_vlm_mimo_model
from tests.unit_tests.models.heterogenous_parallel.parallel_utils import (
    get_module_to_grid_tuple,
    multimodule_no_sync,
    finalize_model_grads,
    get_pg_collections_for_rank,
    zero_grad_buffer_for_multimodule,
)
from tests.unit_tests.models.heterogenous_parallel.data import get_data_iterator, get_batch
from megatron.core.pipeline_parallel.multimodule_communicator import (
    MultiModulePipelineCommunicator,
)
import megatron.core.pipeline_parallel.schedules as schedule
+ + Args: + loss_mask: mask indicating which tokens contribute to the loss + output_tensor: model output tensor + + Returns: + tuple: (loss, num_tokens, metrics_dict) + """ + losses = output_tensor.float() + + loss_mask = loss_mask.contiguous().view(-1).float() + + total_tokens = loss_mask.sum().clone().detach().to(torch.int) + total_loss = torch.sum(losses.view(-1) * loss_mask) + reporting_loss = torch.cat([total_loss.clone().detach().view(1), total_tokens.view(1)]) + + return (total_loss, total_tokens, {'lm loss': (reporting_loss)}) + + +def forward_step(data_iterator, model): + """Forward step for MIMO model training. + + Args: + data_iterator: iterator over the dataset + model: MIMO model instance + + Returns: + tuple: (output_tensor, loss_function) + """ + data_batch = get_batch(data_iterator) + if data_batch is None: + data_batch = {'input_ids': None} + output_tensor, loss_mask = model(**data_batch) + # Return output and loss function + return output_tensor, partial(loss_func, loss_mask) + + +def test_1f_1b_schedule_vlm_mimo_model_custom_pgs( + vision_num_layers, vision_num_moe_experts, vision_hidden_size, + language_num_layers, language_num_moe_experts, language_hidden_size, + vocab_size, image_seq_length, seq_length, + special_token_ids, + vision_tp, vision_pp, vision_dp, vision_cp, vision_ep, + language_tp, language_pp, language_dp, language_cp, language_ep, + batch_size, num_microbatches, + num_iterations=1, profile_start_step=None, profile_end_step=None, enable_profiling=False +): + """Test 1F1B schedule with VLM MIMO model using custom process groups. 
+ + Args: + vision_num_layers: Number of layers in vision encoder + vision_hidden_size: Hidden size for vision encoder + language_num_layers: Number of layers in language model + language_hidden_size: Hidden size for language model + vocab_size: Vocabulary size + image_seq_length: Sequence length for images + seq_length: Total sequence length (text tokens = seq_length - image_seq_length) + special_token_ids: Dictionary of special token IDs + vision_tp, vision_pp, vision_dp, vision_cp, vision_ep: Vision model parallelism configs (TP, PP, DP, CP, EP) + language_tp, language_pp, language_dp, language_cp, language_ep: Language model parallelism configs (TP, PP, DP, CP, EP) + batch_size: Batch size for training + num_microbatches: Number of microbatches for pipeline parallelism + """ + logging.info("Creating VLM MIMO model...") + mimo_model, module_to_grid_map, topology, vision_pg_collection, language_pg_collection = get_vlm_mimo_model( + vision_num_layers=vision_num_layers, + vision_num_moe_experts=vision_num_moe_experts, + vision_hidden_size=vision_hidden_size, + language_num_layers=language_num_layers, + language_num_moe_experts=language_num_moe_experts, + language_hidden_size=language_hidden_size, + vocab_size=vocab_size, + seq_len=seq_length, + special_token_ids=special_token_ids, + vision_tp=vision_tp, + vision_pp=vision_pp, + vision_dp=vision_dp, + vision_cp=vision_cp, + vision_ep=vision_ep, + language_tp=language_tp, + language_pp=language_pp, + language_dp=language_dp, + language_cp=language_cp, + language_ep=language_ep, + ) + + logging.info(f"Rank {dist.get_rank()}: Model created successfully") + + # Set up module to grid tuple for no_sync and finalize_model_grads + module_to_grid_tuple = get_module_to_grid_tuple( + mimo_model, + module_to_grid_map['images'], + module_to_grid_map['language_module'] + ) + + # Configure no_sync and finalize_model_grads functions + mimo_model.config.no_sync_func = partial(multimodule_no_sync, 
module_to_grid_tuple=module_to_grid_tuple) + mimo_model.config.finalize_model_grads_func = partial(finalize_model_grads, module_to_grid_tuple=module_to_grid_tuple) + + # Create multimodule communicator + multimodule_communicator = MultiModulePipelineCommunicator( + module_to_grid_map, topology, mimo_model.config, dim_mapping={'b': 0, 's': 1, 'h': 2} + ) + + logging.info(f"Rank {dist.get_rank()}: Creating data iterator...") + + # Get data iterator + data_iterator = get_data_iterator( + encoder_grid=module_to_grid_map['images'], + llm_grid=module_to_grid_map['language_module'], + image_seq_length=image_seq_length, + seq_length=seq_length, + image_special_token_id=special_token_ids['images'], + batch_size=batch_size, + vocab_size=vocab_size, + vision_hidden_size=vision_hidden_size + ) + + # Set model type for unit test + mimo_model.model_type = 'unit-test' + + # Prepare common arguments for schedule + common_args = { + 'forward_step_func': forward_step, + 'data_iterator': data_iterator, + 'model': [mimo_model], + 'num_microbatches': num_microbatches, + 'seq_length': seq_length, + 'micro_batch_size': batch_size, + 'forward_only': False, + } + + # Get pg collections for modules that should be initialized on this rank + current_rank = dist.get_rank() + if current_rank < vision_tp*vision_pp*vision_dp: + pg_collection = vision_pg_collection + else: + pg_collection = language_pg_collection + print(f"for debug: Rank {dist.get_rank()}: pg_collection: {pg_collection}") + all_losses = [] + + from torch.utils.tensorboard import SummaryWriter + writer = SummaryWriter() + + from megatron.core.optimizer.optimizer_config import OptimizerConfig + from megatron.core.optimizer import get_megatron_optimizer + # Create optimizer config + optimizer_config = OptimizerConfig( + optimizer='adam', + lr=0.001, + weight_decay=0.01, + adam_beta1=0.9, + adam_beta2=0.999, + adam_eps=1e-8, + ) + model_chunks = [] + if mimo_model.modality_submodules is not None: + for submodule in 
mimo_model.modality_submodules.values(): + if submodule is not None: + model_chunks.append(submodule) + if mimo_model.language_model is not None: + if mimo_model.language_model is not None: + model_chunks.append(mimo_model.language_model) + # print(f"for debug: Rank {dist.get_rank()}, model_chunks used to create optimizer: {model_chunks}") + optimizer = get_megatron_optimizer( + config=optimizer_config, + model_chunks=model_chunks, + use_gloo_process_groups=False, # Required when using custom process groups + pg_collection=pg_collection, + ) + + for iteration in range(num_iterations): + # Start profiling if enabled + if enable_profiling and profile_start_step is not None and iteration == profile_start_step: + logging.info(f"Rank {dist.get_rank()}: Starting profiler at iteration {iteration}") + torch.cuda.cudart().cudaProfilerStart() + + logging.info(f"Rank {dist.get_rank()}: Iteration {iteration} - Starting 1F1B schedule...") + + # Run 1F1B schedule + losses_reduced = schedule.forward_backward_pipelining_without_interleaving( + p2p_communicator=multimodule_communicator, + pg_collection=pg_collection, + **common_args + ) + + all_losses.append(losses_reduced) + for idx, loss in enumerate(losses_reduced): + writer.add_scalar('training loss', loss['lm loss'][0], iteration) + writer.add_scalar('num tokens', loss['lm loss'][1], iteration) + logging.info(f"Rank {dist.get_rank()}: Iteration {iteration} - Losses: {losses_reduced}") + + # Update parameters. 
+ update_successful, grad_norm, num_zeros_in_grad = optimizer.step() + print(f"for debug: Rank {dist.get_rank()}, at iteration {iteration}, update_successful: {update_successful}, grad_norm: {grad_norm}, num_zeros_in_grad: {num_zeros_in_grad}") + + zero_grad_buffer_for_multimodule(module_to_grid_tuple) + + # Stop profiling if enabled + if enable_profiling and profile_end_step is not None and iteration == profile_end_step: + logging.info(f"Rank {dist.get_rank()}: Stopping profiler at iteration {iteration}") + torch.cuda.cudart().cudaProfilerStop() + + writer.flush() + logging.info(f"Rank {dist.get_rank()}: Training completed. All losses: {all_losses}") + + return all_losses + + +if __name__ == "__main__": + # Initialize distributed training + Utils.initialize_distributed() + + # Profiling configuration + enable_profiling = True + num_iterations = 6 + profile_start_step = 3 + profile_end_step = 5 + + # Model parameters + vision_num_layers = 16 + vision_num_moe_experts = 8 + vision_hidden_size = 1024 + language_num_layers = 16 + language_num_moe_experts = 8 + language_hidden_size = 2048 + + # Data parameters + vocab_size = 48000 + image_seq_length = 1024 + seq_length = 4096 # Total sequence length (text tokens = seq_length - image_seq_length) + special_token_ids = {"images": 32000} + + # Model parallelisms (CP and EP are hardcoded to 1 in model_specs.py) + vision_tp, vision_pp, vision_dp, vision_cp, vision_ep = 1, 1, 1, 1, 1 + language_tp, language_pp, language_dp, language_cp, language_ep = 1, 1, 2, 1, 2 + assert vision_cp == 1, \ + f"Do not support vision module with CP > 1 currently" + assert language_cp == 1, \ + f"Do not support language module with CP > 1 currently" + + # Training parameters + rank = dist.get_rank() + global_batch_size = 32 + num_microbatches = 16 + if rank < vision_tp*vision_pp*vision_dp: + assert global_batch_size%(num_microbatches * vision_dp)==0, \ + f"global_batch_size ({global_batch_size}) should be divisible by (num_microbatches 
({num_microbatches}) * vision_dp ({vision_dp}))" + batch_size = global_batch_size//(num_microbatches * vision_dp) + print(f"for debug: Rank {rank}, is in vision module, batch_size: {batch_size}") + else: + assert global_batch_size%(num_microbatches*language_dp)==0, \ + f"global_batch_size ({global_batch_size}) should be divisible by (num_microbatches ({num_microbatches}) * language_dp ({language_dp}))" + batch_size = global_batch_size// (num_microbatches*language_dp) + print(f"for debug: Rank {rank}, is in language module, batch_size: {batch_size}") + + + losses = test_1f_1b_schedule_vlm_mimo_model_custom_pgs( + vision_num_layers=vision_num_layers, + vision_num_moe_experts=vision_num_moe_experts, + vision_hidden_size=vision_hidden_size, + language_num_layers=language_num_layers, + language_num_moe_experts=language_num_moe_experts, + language_hidden_size=language_hidden_size, + vocab_size=vocab_size, + image_seq_length=image_seq_length, + seq_length=seq_length, + special_token_ids=special_token_ids, + vision_tp=vision_tp, + vision_pp=vision_pp, + vision_dp=vision_dp, + vision_cp=vision_cp, + vision_ep=vision_ep, + language_tp=language_tp, + language_pp=language_pp, + language_dp=language_dp, + language_cp=language_cp, + language_ep=language_ep, + batch_size=batch_size, + num_microbatches=num_microbatches, + num_iterations=num_iterations, + profile_start_step=profile_start_step, + profile_end_step=profile_end_step, + enable_profiling=enable_profiling, + ) + logging.info(f"Final losses: {losses}") + + dist.destroy_process_group() diff --git a/tests/unit_tests/pipeline_parallel/test_bridge_communicator.py b/tests/unit_tests/pipeline_parallel/test_bridge_communicator.py index 4b426b718eb..0598068986b 100644 --- a/tests/unit_tests/pipeline_parallel/test_bridge_communicator.py +++ b/tests/unit_tests/pipeline_parallel/test_bridge_communicator.py @@ -11,11 +11,7 @@ from megatron.core.hyper_comm_grid import HyperCommGrid from megatron.core.model_parallel_config import 
ModelParallelConfig from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec -from megatron.core.parallel_state import ( - get_context_parallel_group, - get_expert_model_parallel_rank, - get_tensor_model_parallel_rank, -) +from megatron.core.parallel_state import get_context_parallel_group, get_tensor_model_parallel_rank from megatron.core.pipeline_parallel.bridge_communicator import BridgeCommunicator from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear @@ -24,21 +20,19 @@ from megatron.core.transformer.transformer_config import TransformerConfig from tests.unit_tests.test_utilities import Utils +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s %(message)s", + stream=sys.stdout, + force=True, +) + def _create_transformer_block( dtype=torch.bfloat16, hidden_size=4096, pg_collection=None ) -> TransformerBlock: torch.manual_seed(12345) - model_parallel_cuda_manual_seed( - 123, - tp_rank=( - pg_collection.tp.rank() - if pg_collection is not None - else get_tensor_model_parallel_rank() - ), - ep_rank=torch.distributed.get_rank(), - etp_rank=torch.distributed.get_rank(), - ) + model_parallel_cuda_manual_seed(123) if pg_collection is not None: cp_size = pg_collection.cp.size() else: @@ -111,10 +105,6 @@ def _shard_and_copy_( def create_hypercomm_grid(offset=0, tp=1, cp=1, pp=1, dp=1): """Create a HyperCommGrid with tensor parallelism=2, context parallelism=2, and data parallelism=2.""" # Set up environment for world size 8 if not already set - if not dist.is_initialized(): - raise RuntimeError("Distributed process group is not initialized") - - # tests below assume a world size of 8 if "WORLD_SIZE" not in os.environ: os.environ["WORLD_SIZE"] = "8" @@ -181,6 +171,7 @@ def get_transformer_block_and_grid( class TestBridgeCommunicator: + """Test suite for BridgeCommunicator usage.""" @classmethod def 
setup_class(cls): @@ -205,8 +196,8 @@ def test_bridge_communicator_init(self): grid1 = create_hypercomm_grid(offset=0, tp=2, cp=1, pp=1, dp=2) grid2 = create_hypercomm_grid(offset=4, tp=2, cp=1, pp=1, dp=2) bridge_communicator = BridgeCommunicator(grid1, grid2) - assert bridge_communicator.src_grid is grid1 - assert bridge_communicator.dest_grid is grid2 + assert bridge_communicator.src_grid == grid1 + assert bridge_communicator.dest_grid == grid2 assert bridge_communicator.current_rank == dist.get_rank() assert bridge_communicator.comm_map is not None @@ -305,9 +296,6 @@ def test_bridge_communicator_with_transformer_blocks( (sequence_length, micro_batch_size, hidden_size), device="cuda" ).to(dtype) current_rank = dist.get_rank() - - # we compare output with transformer block with global parallel state - # so need to initialize model parallel state Utils.initialize_model_parallel( tensor_model_parallel_size=parallel_state_tp, create_gloo_process_groups=False ) @@ -404,6 +392,78 @@ def test_bridge_communicator_with_transformer_blocks( Utils.destroy_model_parallel() + @pytest.mark.skipif( + version.parse(torch.__version__) < version.parse('2.3.0'), + reason="Feature requires PyTorch 2.3 or later", + ) + def test_tranformer_block_with_different_parallelisms(self): + os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0" + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "0" + + torch.manual_seed(12345) + + # Initialize model parallel state (required for model_parallel_cuda_manual_seed) + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, create_gloo_process_groups=False + ) + + hidden_size = 2048 + dtype = torch.float32 + + sequence_length = 8192 + micro_batch_size = 2 + hidden_states = torch.randn( + (sequence_length, micro_batch_size, hidden_size), device="cuda" + ).to(dtype) + + ref_grid = create_hypercomm_grid(offset=0, tp=1, cp=1, pp=1, dp=8) + ref_pg_collection = _get_pg_collection_from_grid(ref_grid) + ref_block = 
_create_transformer_block( + dtype=dtype, hidden_size=hidden_size, pg_collection=ref_pg_collection + ) + _avg_params(ref_block, ref_grid.get_pg("dp")) + + # tp8 dp 1 grid + block1, grid_1 = get_transformer_block_and_grid( + tp_size=8, + cp_size=1, + pp_size=1, + dp_size=1, + ref_block=ref_block, + dtype=dtype, + hidden_size=hidden_size, + ) + + # tp4 dp 2 grid + block2, grid_2 = get_transformer_block_and_grid( + tp_size=4, + cp_size=1, + pp_size=1, + dp_size=2, + ref_block=ref_block, + hidden_size=hidden_size, + dtype=dtype, + ) + + dist.barrier() + + output_grid_1 = block1(hidden_states=hidden_states, attention_mask=None) + + output_grid_2 = block2(hidden_states=hidden_states, attention_mask=None) + + logging.debug( + f"Rank {dist.get_rank()}: shapes - grid 1 {output_grid_1.shape} grid 2 {output_grid_2.shape}" + ) + logging.debug( + f"Rank {dist.get_rank()}: sum - grid 1 {output_grid_1.sum()} grid 2 {output_grid_2.sum()}" + ) + + torch.testing.assert_close(output_grid_1, output_grid_2, rtol=1e-3, atol=1e-3) + + # Clean up model parallel state + Utils.destroy_model_parallel() + @pytest.mark.parametrize( "tp, cp, pp, dp, expected_src_ranks, expected_dest_ranks", [ diff --git a/tests/unit_tests/pipeline_parallel/test_multimodule_communicator.py b/tests/unit_tests/pipeline_parallel/test_multimodule_communicator.py index 73739859f42..acd7355c32c 100644 --- a/tests/unit_tests/pipeline_parallel/test_multimodule_communicator.py +++ b/tests/unit_tests/pipeline_parallel/test_multimodule_communicator.py @@ -1,27 +1,164 @@ -import logging import os -import sys +from unittest.mock import MagicMock, patch import pytest import torch import torch.distributed as dist -from packaging import version -from megatron.core import parallel_state from megatron.core.hyper_comm_grid import HyperCommGrid from megatron.core.model_parallel_config import ModelParallelConfig -from megatron.core.pipeline_parallel.multimodule_communicator import MultiModulePipelineCommunicator -from 
tests.unit_tests.pipeline_parallel.test_bridge_communicator import ( - _avg_params, - _create_transformer_block, - _get_pg_collection_from_grid, - create_hypercomm_grid, - get_transformer_block_and_grid, +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from megatron.core.parallel_state import get_context_parallel_group, get_tensor_model_parallel_rank +from megatron.core.pipeline_parallel.bridge_communicator import BridgeCommunicator +from megatron.core.pipeline_parallel.multimodule_communicator import ( + MultiModulePipelineCommunicator, ) +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from megatron.core.transformer.transformer_block import TransformerBlock +from megatron.core.transformer.transformer_config import TransformerConfig from tests.unit_tests.test_utilities import Utils +def _create_transformer_block( + dtype=torch.bfloat16, hidden_size=4096, pg_collection=None +) -> TransformerBlock: + torch.manual_seed(12345) + model_parallel_cuda_manual_seed(123) + if pg_collection is not None: + cp_size = pg_collection.cp.size() + else: + cp_size = get_context_parallel_group().size() + transformer_config = TransformerConfig( + num_layers=1, + hidden_size=hidden_size, + num_attention_heads=8, + use_cpu_initialization=True, + attention_dropout=0.0, + hidden_dropout=0.0, + bf16=dtype == torch.bfloat16, + context_parallel_size=cp_size, + ) + + block = ( + TransformerBlock( + transformer_config, + get_gpt_layer_with_transformer_engine_spec(), + pg_collection=pg_collection, + ) + .cuda() + .to(dtype) + ) + with torch.no_grad(): + for mod in block.modules(): + if hasattr(mod, "bias") and mod.bias is not None: + mod.bias.zero_() + return block + + +def _shard_and_copy_( + ref_block: TransformerBlock, tgt_block: TransformerBlock, tp_size: int, tp_rank: int +) -> None: + """Copy weights from *ref_block* into a 
tensor-parallel *tgt_block*.""" + + ref_sd = ref_block.state_dict() + tgt_sd = tgt_block.state_dict() + + for name, tgt_param in tgt_sd.items(): + full_param = ref_sd[name] + + # Exact match – just copy. + if full_param.shape == tgt_param.shape: + tgt_param.copy_(full_param) + continue + + # ColumnParallel: shard along dim-0. + if tgt_param.shape[0] * tp_size == full_param.shape[0]: + slice_ = torch.chunk(full_param, tp_size, dim=0)[tp_rank] + tgt_param.copy_(slice_) + continue + + # RowParallel: shard along dim-1. + if tgt_param.shape[1] * tp_size == full_param.shape[1]: + slice_ = torch.chunk(full_param, tp_size, dim=1)[tp_rank] + tgt_param.copy_(slice_) + continue + + raise RuntimeError( + f"Unhandled TP sharding for {name}: ref {full_param.shape} tgt {tgt_param.shape}" + ) + + +def create_hypercomm_grid(offset=0, tp=1, cp=1, pp=1, dp=1): + """Create a HyperCommGrid with tensor parallelism=2, context parallelism=2, and data parallelism=2.""" + # Set up environment for world size 8 if not already set + if "WORLD_SIZE" not in os.environ: + os.environ["WORLD_SIZE"] = "8" + + grid = HyperCommGrid( + shape=[tp, cp, pp, dp], + dim_names=["tp", "cp", "pp", "dp"], + rank_offset=offset, + backend="nccl", + ) + _ = grid.create_pg(["tp"]) + _ = grid.create_pg(["cp"]) + _ = grid.create_pg(["pp"]) + _ = grid.create_pg(["dp"]) + return grid + + +def _get_pg_collection_from_grid(grid): + pg_collection = ProcessGroupCollection() + pg_collection.tp = grid.get_pg("tp") + pg_collection.cp = grid.get_pg("cp") + pg_collection.pp = grid.get_pg("pp") + return pg_collection + + +def _avg_params(module: torch.nn.Module, group: dist.ProcessGroup = None) -> None: + world = dist.get_world_size(group=group or dist.group.WORLD) + for p in module.parameters(): + dist.all_reduce(p.data, op=dist.ReduceOp.SUM, group=group or dist.group.WORLD) + p.data.div_(world) + + +def get_transformer_block_and_grid( + ref_block, + tp_size=1, + cp_size=1, + pp_size=1, + dp_size=1, + grid_offset: int = 0, + 
use_global_parallel_state: bool = False, + hidden_size: int = 4096, + dtype: torch.dtype = torch.bfloat16, +): + """Utility to build a ``TransformerBlock`` for tests.""" + + current_rank = dist.get_rank() + if use_global_parallel_state: + block = _create_transformer_block(dtype=dtype, hidden_size=hidden_size) + _shard_and_copy_(ref_block, block, tp_size, get_tensor_model_parallel_rank()) + grid = None + else: + grid = create_hypercomm_grid( + offset=grid_offset, tp=tp_size, cp=cp_size, pp=pp_size, dp=dp_size + ) + if grid.rank_offset <= current_rank < grid.rank_offset + grid.size: + pg_collection = _get_pg_collection_from_grid(grid) + block = _create_transformer_block( + dtype=dtype, hidden_size=hidden_size, pg_collection=pg_collection + ) + _shard_and_copy_(ref_block, block, tp_size, pg_collection.tp.rank()) + else: + block = None + + return block, grid + + class TestMultiModulePipelineCommunicator: + """Test suite for MultiModulePipelineCommunicator usage.""" @classmethod def setup_class(cls): @@ -31,17 +168,12 @@ def setup_class(cls): if torch.cuda.is_available(): torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) - world_size = dist.get_world_size() - if world_size != 8: - pytest.skip( - f"These tests require 8 GPUs, but only {world_size} are available.", - allow_module_level=True, - ) - + @classmethod def teardown_class(cls): - Utils.destroy_model_parallel() + if dist.is_initialized(): + dist.destroy_process_group() - def test_multimodule_communicator_init(self): + def test_mllm_communicator_init(self): """Test MultiModulePipelineCommunicator initialization.""" # Create process group grids for each module @@ -73,7 +205,7 @@ def test_multimodule_communicator_init(self): assert mllm_comm.config == config assert mllm_comm.current_rank == dist.get_rank() - def test_compute_total_pipeline_stages(self): + def test_compute_total_pipeline_stages_overall_and_till_rank(self): """Test compute_total_pipeline_stages for overall chain and until specific ranks.""" # Create 
process group grids for each module @@ -109,8 +241,6 @@ def test_compute_total_pipeline_stages(self): def test_send_forward_recv_forward(self): """Test send_forward and recv_forward operations.""" - if not dist.is_initialized(): - pytest.skip("Distributed not initialized") # Create process group grids for each module image_encoder_grid = create_hypercomm_grid(offset=0, tp=1, cp=1, pp=1, dp=1) @@ -161,63 +291,8 @@ def test_send_forward_recv_forward(self): input_dict = mllm_comm.recv_forward() assert input_dict['llm'].shape == (1, 32, 128) - def test_send_forward_recv_forward_with_different_pp_size(self): - """Test for the case when pp(image_encoder) != pp(audio_encoder).""" - if not dist.is_initialized(): - pytest.skip("Distributed not initialized") - - # Create process group grids for each module - image_encoder_grid = create_hypercomm_grid(offset=0, tp=1, cp=1, pp=2, dp=1) - audio_encoder_grid = create_hypercomm_grid(offset=2, tp=2, cp=1, pp=1, dp=1) - llm_grid = create_hypercomm_grid(offset=4, tp=1, cp=1, pp=4, dp=1) - - # Set up module-grid mapping and topology - module_to_grid_map = { - 'image_encoder': image_encoder_grid, - 'audio_encoder': audio_encoder_grid, - 'llm': llm_grid, - } - topology = {'image_encoder': ['llm'], 'audio_encoder': ['llm'], 'llm': []} - config = ModelParallelConfig(pipeline_dtype=torch.float) - mllm_comm = MultiModulePipelineCommunicator(module_to_grid_map, topology, config) - - # Simulate forward communication for each module - if mllm_comm.is_current_rank_in_grid(image_encoder_grid): - output_dict = {'image_encoder': torch.randn(2, 8, 128).cuda()} - if dist.get_rank() == 0: - # Image encoder sends output forward - mllm_comm.send_forward(output_dict) - else: - # Image stage receives image outputs - input_dict = mllm_comm.recv_forward(tensor_shape=(2, 8, 128)) - assert input_dict['image_encoder'].shape == (2, 8, 128) - mllm_comm.send_forward(output_dict) - if mllm_comm.is_current_rank_in_grid(audio_encoder_grid): - # Audio encoder sends 
output forward - output_dict = {'audio_encoder': torch.randn(2, 16, 128).cuda()} - mllm_comm.send_forward(output_dict) - if mllm_comm.is_current_rank_in_grid(llm_grid): - output_dict = {'llm': torch.randn(2, 32, 128).cuda()} - if dist.get_rank() == 4: - # LLM stage receives both image and audio outputs - input_dict = mllm_comm.recv_forward() - assert input_dict['image_encoder'].shape == (2, 8, 128) - assert input_dict['audio_encoder'].shape == (2, 16, 128) - mllm_comm.send_forward(output_dict) - elif dist.get_rank() == 5 or dist.get_rank() == 6: - # LLM stage receives concatenated LLM outputs - input_dict = mllm_comm.recv_forward(tensor_shape=(2, 32, 128)) - assert input_dict['llm'].shape == (2, 32, 128) - mllm_comm.send_forward(output_dict) - elif dist.get_rank() == 7: - # LLM stage receives concatenated LLM outputs - input_dict = mllm_comm.recv_forward(tensor_shape=(2, 32, 128)) - assert input_dict['llm'].shape == (2, 32, 128) - def test_send_backward_recv_backward(self): """Test send_backward and recv_backward operations.""" - if not dist.is_initialized(): - pytest.skip("Distributed not initialized") # Create process group grids for each module image_encoder_grid = create_hypercomm_grid(offset=0, tp=1, cp=1, pp=1, dp=1) @@ -271,14 +346,8 @@ def test_send_backward_recv_backward(self): received_grad = mllm_comm.recv_backward() assert received_grad['audio_encoder'].shape == (2, 16, 128) - @pytest.mark.skipif( - version.parse(torch.__version__) < version.parse('2.3.0'), - reason="Feature requires PyTorch 2.3 or later", - ) def test_send_forward_recv_backward_send_backward_recv_forward(self): """Test send_forward_recv_backward and send_backward_recv_forward operations.""" - if not dist.is_initialized(): - pytest.skip("Distributed not initialized") # Create process group grids for each module image_encoder_grid = create_hypercomm_grid(offset=0, tp=1, cp=1, pp=1, dp=1) @@ -304,55 +373,54 @@ def test_send_forward_recv_backward_send_backward_recv_forward(self): # 
Simulate bidirectional send/recv for forward and backward in pipeline - # Encoder stages send forward to the first stage of LLM, and receive backward from the first stage of LLM + # Encoder stages: send forward tensor, receive backward gradient if mllm_comm.is_current_rank_in_grid(image_encoder_grid): - output_dict = {'image_encoder': torch.randn(2, 8, 128).cuda()} + output_dict = {'image_encoder': torch.randn(16, 256, 512).cuda()} received_grad = mllm_comm.send_forward_recv_backward(output_dict) - assert received_grad['image_encoder'].shape == (2, 8, 128) + assert received_grad['image_encoder'].shape == (16, 256, 512) if mllm_comm.is_current_rank_in_grid(audio_encoder_grid): - output_dict = {'audio_encoder': torch.randn(2, 16, 128).cuda()} + output_dict = {'audio_encoder': torch.randn(16, 128, 512).cuda()} received_grad = mllm_comm.send_forward_recv_backward(output_dict) - assert received_grad['audio_encoder'].shape == (2, 16, 128) + assert received_grad['audio_encoder'].shape == (16, 128, 512) + + # LLM: receives backward (from generator) then immediately receives forward (from encoders) if mllm_comm.is_current_rank_in_grid(llm_grid): if dist.get_rank() == 2 or dist.get_rank() == 3: grad_dict = { - 'image_encoder': torch.randn(2, 8, 128).cuda(), - 'audio_encoder': torch.randn(2, 16, 128).cuda(), + 'image_encoder': torch.randn(16, 256, 512).cuda(), + 'audio_encoder': torch.randn(16, 128, 512).cuda(), } input_dict = mllm_comm.send_backward_recv_forward(grad_dict) - assert input_dict['image_encoder'].shape == (2, 8, 128) - assert input_dict['audio_encoder'].shape == (2, 16, 128) + assert input_dict['image_encoder'].shape == (16, 256, 512) + assert input_dict['audio_encoder'].shape == (16, 128, 512) - # First stage of LLM sends forward to the second stage of LLM, and receive backward from the second stage of LLM + # LLM: send forward (as LLM) and receive backward if mllm_comm.is_current_rank_in_grid(llm_grid): if dist.get_rank() == 2 or dist.get_rank() == 3: - 
output_dict = {'llm': torch.randn(2, 32, 128).cuda()} + output_dict = {'llm': torch.randn(16, 128, 512).cuda()} received_grad = mllm_comm.send_forward_recv_backward( - output_dict, tensor_shape=(2, 32, 128) + output_dict, tensor_shape=(16, 128, 512) ) - assert received_grad['llm'].shape == (2, 32, 128) + assert received_grad['llm'].shape == (16, 128, 512) if dist.get_rank() == 4 or dist.get_rank() == 5: - grad_dict = {'llm': torch.randn(2, 32, 128).cuda()} + grad_dict = {'llm': torch.randn(16, 128, 512).cuda()} input_dict = mllm_comm.send_backward_recv_forward( - grad_dict, tensor_shape=(2, 32, 128) + grad_dict, tensor_shape=(16, 128, 512) ) - assert input_dict['llm'].shape == (2, 32, 128) + assert input_dict['llm'].shape == (16, 128, 512) - # Second stage of LLM sends forward to generator, and receive backward from generator + # LLM: send forward and get gradient (2nd stage) if mllm_comm.is_current_rank_in_grid(llm_grid): if dist.get_rank() == 4 or dist.get_rank() == 5: - output_dict = {'llm': torch.randn(2, 32, 128).cuda()} + output_dict = {'llm': torch.randn(16, 128, 512).cuda()} received_grad = mllm_comm.send_forward_recv_backward(output_dict) - assert received_grad['llm'].shape == (2, 32, 128) + assert received_grad['llm'].shape == (16, 128, 512) + # Generator: send backward gradient, receive forward activation if mllm_comm.is_current_rank_in_grid(generator_grid): - grad_dict = {'llm': torch.randn(1, 32, 128).cuda()} + grad_dict = {'llm': torch.randn(8, 128, 512).cuda()} input_dict = mllm_comm.send_backward_recv_forward(grad_dict) - assert input_dict['llm'].shape == (1, 32, 128) + assert input_dict['llm'].shape == (8, 128, 512) - @pytest.mark.skipif( - version.parse(torch.__version__) < version.parse('2.3.0'), - reason="Feature requires PyTorch 2.3 or later", - ) def test_send_forward_recv_forward_with_transformer_blocks(self): """Test send_forward and recv_forward operations.""" @@ -510,14 +578,14 @@ def 
test_send_forward_recv_forward_with_transformer_blocks(self): hidden_size=hidden_size, dtype=dtype, ) - global_llm_block_pp_rank_0, _ = get_transformer_block_and_grid( + global_llm_block_pp_stage_0, _ = get_transformer_block_and_grid( ref_block, tp_size=parallel_state_tp, use_global_parallel_state=True, hidden_size=hidden_size, dtype=dtype, ) - global_llm_block_pp_rank_1, _ = get_transformer_block_and_grid( + global_llm_block_pp_stage_1, _ = get_transformer_block_and_grid( ref_block, tp_size=parallel_state_tp, use_global_parallel_state=True, @@ -553,24 +621,24 @@ def test_send_forward_recv_forward_with_transformer_blocks(self): global_llm_input = torch.cat( [global_image_encoder_output, global_audio_encoder_output], dim=seq_dim ) - global_llm_pp_rank_0_output = global_llm_block_pp_rank_0( + global_llm_pp_stage_0_output = global_llm_block_pp_stage_0( hidden_states=global_llm_input, attention_mask=None ) if current_rank == 2 or current_rank == 3: torch.testing.assert_close( - global_llm_pp_rank_0_output, llm_output, rtol=1e-3, atol=1e-3 + global_llm_pp_stage_0_output, llm_output, rtol=1e-3, atol=1e-3 ) - global_llm_pp_rank_1_output = global_llm_block_pp_rank_1( - hidden_states=global_llm_pp_rank_0_output, attention_mask=None + global_llm_pp_stage_1_output = global_llm_block_pp_stage_1( + hidden_states=global_llm_pp_stage_0_output, attention_mask=None ) if current_rank == 4 or current_rank == 5: torch.testing.assert_close( - global_llm_pp_rank_1_output, llm_output, rtol=1e-3, atol=1e-3 + global_llm_pp_stage_1_output, llm_output, rtol=1e-3, atol=1e-3 ) # Generator output and comparison to distributed output (for each DP chunk) global_generator_block_output = global_generator_block( - hidden_states=global_llm_pp_rank_1_output, attention_mask=None + hidden_states=global_llm_pp_stage_1_output, attention_mask=None ) global_generator_block_chunks = torch.split( global_generator_block_output, global_generator_block_output.shape[1] // 2, dim=1 @@ -584,10 +652,9 @@ def 
test_send_forward_recv_forward_with_transformer_blocks(self): global_generator_block_chunks[1], generator_output, rtol=1e-3, atol=1e-3 ) - @pytest.mark.skipif( - version.parse(torch.__version__) < version.parse('2.3.0'), - reason="Feature requires PyTorch 2.3 or later", - ) + # ========== Clean up model-parallel state ========== + Utils.destroy_model_parallel() + @pytest.mark.parametrize( "grid1_tp, grid1_pp, grid1_dp, grid2_tp, grid2_pp, grid2_dp, parallel_state_tp", [ @@ -605,8 +672,8 @@ def test_send_forward_recv_forward_with_transformer_blocks_and_different_paralle ): """Test bridge communicator with two transformer blocks having different process group configurations.""" # Model and input configuration - hidden_size = 16 - sequence_length = 2 + hidden_size = 1024 + sequence_length = 16 micro_batch_size = 8 torch.manual_seed(12345) dtype = torch.float32 @@ -669,7 +736,7 @@ def test_send_forward_recv_forward_with_transformer_blocks_and_different_paralle # If current rank is in the first grid, run first block and send output if grid_1 is not None and mllm_comm.is_current_rank_in_grid(grid_1): rank_module_info = mllm_comm.rank_module_map['image_encoder'] - if rank_module_info.pp_rank == 0: + if rank_module_info.pp_stage == 0: hidden_states = block_grid_1(hidden_states=hidden_states, attention_mask=None) mllm_comm.send_forward({'image_encoder': hidden_states}) else: @@ -683,15 +750,15 @@ def test_send_forward_recv_forward_with_transformer_blocks_and_different_paralle # If current rank is in second grid, receive and run the second block if grid_2 is not None and mllm_comm.is_current_rank_in_grid(grid_2): rank_module_info = mllm_comm.rank_module_map['llm'] - if rank_module_info.pp_rank == 0: + if rank_module_info.pp_stage == 0: input_dict = mllm_comm.recv_forward() hidden_states = input_dict['image_encoder'] hidden_states = block_grid_2(hidden_states=hidden_states, attention_mask=None) - if rank_module_info.pp_rank == rank_module_info.pp_size - 1: + if 
rank_module_info.pp_stage == rank_module_info.pp_size - 1: output_grid_2 = hidden_states else: mllm_comm.send_forward({'llm': hidden_states}) - elif rank_module_info.pp_rank < rank_module_info.pp_size - 1: + elif rank_module_info.pp_stage < rank_module_info.pp_size - 1: input_dict = mllm_comm.recv_forward( tensor_shape=( sequence_length, @@ -754,27 +821,31 @@ def test_send_forward_recv_forward_with_transformer_blocks_and_different_paralle if ( grid_2 is not None and mllm_comm.is_current_rank_in_grid(grid_2) - and rank_module_info.pp_rank == rank_module_info.pp_size - 1 + and rank_module_info.pp_stage == rank_module_info.pp_size - 1 ): if grid1_dp == grid2_dp: # DP size matches: all outputs directly compared torch.testing.assert_close(hidden_states_ref, output_grid_2, rtol=1e-3, atol=1e-3) - elif grid1_dp < grid2_dp: - # If grid2 expands DP: each output_grid_2 chunk corresponds to a split of the reference output - grid2_dp_ranks = grid_2._gen_rank_enum([x for x in grid_2.dim_names if x != "dp"]) - global_block_2_chunks = torch.split( - hidden_states_ref, hidden_states_ref.shape[1] // (grid2_dp // grid1_dp), dim=1 - ) - relevant_chunk = None - for i, dp_ranks in enumerate(grid2_dp_ranks): - if current_rank in dp_ranks: - relevant_chunk = global_block_2_chunks[i % len(global_block_2_chunks)] - torch.testing.assert_close(relevant_chunk, output_grid_2, rtol=1e-3, atol=1e-3) - else: - # If DP shrinks (grid1_dp > grid2_dp): just compare the relevant first chunk - output_grid_2_first_chunk = torch.chunk(output_grid_2, grid1_dp // grid2_dp, dim=1)[ - 0 - ] - torch.testing.assert_close( - hidden_states_ref, output_grid_2_first_chunk, rtol=1e-3, atol=1e-3 - ) + # elif grid1_dp < grid2_dp: + # # If grid2 expands DP: each output_grid_2 chunk corresponds to a split of the reference output + # grid2_dp_ranks = grid_2._gen_rank_enum([x for x in grid_2.dim_names if x != "dp"]) + # global_block_2_chunks = torch.split( + # hidden_states_ref, + # hidden_states_ref.shape[1] // 
(grid2_dp // grid1_dp), + # dim=1, + # ) + # relevant_chunk = None + # for i, dp_ranks in enumerate(grid2_dp_ranks): + # if current_rank in dp_ranks: + # relevant_chunk = global_block_2_chunks[i % len(global_block_2_chunks)] + # torch.testing.assert_close(relevant_chunk, output_grid_2, rtol=1e-3, atol=1e-3) + # else: + # # If DP shrinks (grid1_dp > grid2_dp): just compare the relevant first chunk + # output_grid_2_first_chunk = torch.chunk(output_grid_2, grid1_dp // grid2_dp, dim=1)[ + # 0 + # ] + # torch.testing.assert_close( + # hidden_states_ref, output_grid_2_first_chunk, rtol=1e-3, atol=1e-3 + # ) + + Utils.destroy_model_parallel() # Clean up parallel context diff --git a/tests/unit_tests/pipeline_parallel/test_multimodule_schedules.py b/tests/unit_tests/pipeline_parallel/test_multimodule_schedules.py new file mode 100644 index 00000000000..6c4216cec5f --- /dev/null +++ b/tests/unit_tests/pipeline_parallel/test_multimodule_schedules.py @@ -0,0 +1,720 @@ +import logging +import os +from typing import Dict, List +from contextlib import contextmanager +import pytest +import torch +import torch.distributed as dist +from packaging import version +from pytest_mock import mocker + +import megatron.core.pipeline_parallel.schedules as schedule +from megatron.core import ModelParallelConfig +from megatron.core.distributed import DistributedDataParallel, DistributedDataParallelConfig +from megatron.core.distributed.finalize_model_grads import finalize_model_grads +from megatron.core.hyper_comm_grid import HyperCommGrid +from megatron.core.model_parallel_config import ModelParallelConfig +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from megatron.core.parallel_state import get_context_parallel_group, get_tensor_model_parallel_rank +from megatron.core.pipeline_parallel.multimodule_communicator import ( + MultiModulePipelineCommunicator, +) +from megatron.core.pipeline_parallel.p2p_communication import P2PCommunicator 
from megatron.core.pipeline_parallel.utils import is_pp_first_stage, is_pp_last_stage
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig
from tests.unit_tests.test_utilities import Utils

rank = Utils.rank
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')


class DataIterator:
    """Endless iterator that yields random bf16 activations of shape (s, b, h) on CUDA.

    Used as the first-pipeline-stage "data loader" so the schedule under test has
    something to feed into the encoder's first stage.
    """

    def __init__(self, hidden_size: int, seq_length: int, micro_batch_size: int):
        self.hidden_size = hidden_size
        self.seq_length = seq_length
        self.micro_batch_size = micro_batch_size

    def __iter__(self):
        return self

    def __next__(self):
        # (seq, batch, hidden) layout, matching dim_mapping={'s': 0, 'b': 1, 'h': 2}
        # used by the tests below.
        return torch.randn(
            self.seq_length,
            self.micro_batch_size,
            self.hidden_size,
            device='cuda',
            dtype=torch.bfloat16,
        )


class SingleEncoderModel(torch.nn.Module):
    """Two-module pipeline model: an encoder grid feeding an LLM grid.

    Each sub-module is a DDP-wrapped ``TransformerBlock`` that only exists on
    the ranks covered by its ``HyperCommGrid``; on all other ranks the module
    slot is ``None``.
    """

    def __init__(
        self,
        hidden_size,
        encoder_tp,
        encoder_pp,
        encoder_dp,
        llm_tp,
        llm_pp,
        llm_dp,
        llm_grid_offset,
    ):
        super().__init__()

        self.encoder, self.encoder_grid = get_transformer_block_and_grid(
            tp_size=encoder_tp,
            cp_size=1,
            pp_size=encoder_pp,
            dp_size=encoder_dp,
            hidden_size=hidden_size,
        )

        self.llm, self.llm_grid = get_transformer_block_and_grid(
            tp_size=llm_tp,
            cp_size=1,
            pp_size=llm_pp,
            dp_size=llm_dp,
            grid_offset=llm_grid_offset,
            hidden_size=hidden_size,
        )

        # Simple list for iteration over (module, grid) pairs local to this rank.
        self.modules_and_grids = [
            (self.encoder, self.encoder_grid),
            (self.llm, self.llm_grid),
        ]

        self.current_rank = dist.get_rank()
        self.encoder_input_tensor = None
        self.llm_input_tensor = None

    def finish_grad_sync(self):
        """Finish gradient synchronization for all active modules on this rank."""
        for module, grid in self.modules_and_grids:
            if module is not None and self.is_current_rank_in_grid(grid):
                module.finish_grad_sync()

    @contextmanager
    def no_sync(self):
        """Disable DDP grad sync for every module active on this rank.

        NOTE(review): the original source defined ``no_sync`` twice; the second,
        loop-based definition shadowed the first, so only this variant is kept.
        """
        contexts = []
        for module, grid in self.modules_and_grids:
            if module is not None and self.is_current_rank_in_grid(grid):
                contexts.append(module.no_sync())

        # Enter all contexts
        for ctx in contexts:
            ctx.__enter__()

        try:
            yield
        finally:
            # Exit all contexts in reverse order
            for ctx in reversed(contexts):
                ctx.__exit__(None, None, None)

    @property
    def ddp_config(self):
        # Try to get ddp_config from the first available module on this rank
        for module, grid in self.modules_and_grids:
            if module is not None and self.is_current_rank_in_grid(grid):
                return module.ddp_config
        raise AttributeError(
            f"No active modules with ddp_config found on rank {self.current_rank}"
        )

    def scale_gradients(self, scaling_factor: float):
        """Scale gradients for all active modules on this rank."""
        for module, grid in self.modules_and_grids:
            if module is not None and self.is_current_rank_in_grid(grid):
                module.scale_gradients(scaling_factor)

    def is_current_rank_in_grid(self, grid: HyperCommGrid) -> bool:
        """Check if the current rank is in the grid."""
        return grid.rank_offset <= self.current_rank < (grid.rank_offset + grid.size)

    def finalize_model_grads(self, module=None, num_tokens=None, pg_collection=None):
        """Finalize grads per local module against its own grid's process groups.

        The ``module``/``num_tokens``/``pg_collection`` parameters are accepted for
        interface compatibility with the ``finalize_model_grads_func`` hook but are
        ignored here. (The original loop variable was also named ``module``,
        shadowing the parameter; renamed to ``mod`` for clarity.)
        """
        for mod, grid in self.modules_and_grids:
            if mod is not None and self.is_current_rank_in_grid(grid):
                finalize_model_grads(
                    [mod],
                    num_tokens=None,
                    pg_collection=_get_pg_collection_with_embedding_groups(grid),
                )

    def set_input_tensor(self, input_tensor: List[Dict[str, torch.Tensor]]):
        """Stash the p2p-received activation for whichever module this rank runs.

        ``input_tensor`` is a one-element list holding a dict keyed by the
        *sending* module name; values may be a tensor or a one-element list.
        """
        if self.is_current_rank_in_grid(self.encoder_grid) and 'encoder' in input_tensor[0]:
            if isinstance(input_tensor[0]["encoder"], list):
                encoder_input_tensor = input_tensor[0]["encoder"][0]
            else:
                encoder_input_tensor = input_tensor[0]["encoder"]
            # Log the normalized tensor (the original indexed ['encoder'][0],
            # which logs a sliced shape when the value is a plain tensor).
            logging.debug(
                f"[Rank {dist.get_rank()} ][SingleEncoderModel] [set_input_tensor] [encoder] input tensor shape: {encoder_input_tensor.shape}"
            )
            self.encoder_input_tensor = encoder_input_tensor
        elif self.is_current_rank_in_grid(self.llm_grid):
            if 'llm' in input_tensor[0]:
                if isinstance(input_tensor[0]["llm"], list):
                    llm_input_tensor = input_tensor[0]["llm"][0]
                else:
                    llm_input_tensor = input_tensor[0]["llm"]
                logging.debug(
                    f"[Rank {dist.get_rank()} ][SingleEncoderModel] [set_input_tensor] [llm] input tensor shape: {llm_input_tensor.shape}"
                )
                self.llm_input_tensor = llm_input_tensor
            elif 'encoder' in input_tensor[0]:
                # First LLM stage receives the encoder's output directly.
                logging.debug(
                    f"[Rank {dist.get_rank()} ][SingleEncoderModel] [set_input_tensor] [encoder] input tensor shape: {input_tensor[0]['encoder'].shape}"
                )
                self.llm_input_tensor = input_tensor[0]["encoder"]
        else:
            raise ValueError(f"Rank {dist.get_rank()} is not valid")

    def forward(self, hidden_states):
        """Run whichever sub-module this rank owns; return ``{name: output}``."""
        current_rank = dist.get_rank()
        output_dict = {}
        if self.is_current_rank_in_grid(self.encoder_grid):
            # First pp stage of the encoder consumes fresh data; later stages
            # consume the tensor delivered via set_input_tensor().
            if is_pp_first_stage(self.encoder_grid.get_pg("pp")):
                input_tensor = hidden_states
            else:
                assert (
                    self.encoder_input_tensor is not None
                ), "Encoder input tensor is not provided for pp rank > 0"
                input_tensor = self.encoder_input_tensor
            logging.debug(
                f"[Rank {dist.get_rank()} ][SingleEncoderModel] [forward] [encoder] input tensor shape: {input_tensor.shape}"
            )
            output_dict["encoder"] = self.encoder(input_tensor, attention_mask=None)
        elif self.is_current_rank_in_grid(self.llm_grid):
            assert (
                self.llm_input_tensor is not None
            ), "LLM input tensor is not provided for pp rank > 0"
            input_tensor = self.llm_input_tensor
            logging.debug(
                f"[Rank {dist.get_rank()} ][SingleEncoderModel] [forward] [llm] input tensor shape: {input_tensor.shape}"
            )
            output_dict["llm"] = self.llm(input_tensor, attention_mask=None)
        else:
            raise ValueError(f"Rank {current_rank} is not valid")

        return output_dict


class DualEncoderModel(SingleEncoderModel):
    """Three-module pipeline model: two encoders whose outputs are concatenated
    along the sequence dimension and fed to an LLM grid.

    NOTE(review): ``super().__init__`` already builds an ``encoder`` and ``llm``
    (with their grids/process groups) that this subclass immediately replaces —
    redundant group creation that is preserved here for behavioral parity;
    consider refactoring the base class to make module construction optional.
    """

    def __init__(self, hidden_size, encoder_tp, encoder_pp, encoder_dp, llm_tp, llm_pp, llm_dp, llm_grid_offset):
        super().__init__(hidden_size, encoder_tp, encoder_pp, encoder_dp, llm_tp, llm_pp, llm_dp, llm_grid_offset)

        self.encoder_1, self.encoder_1_grid = get_transformer_block_and_grid(
            tp_size=encoder_tp,
            cp_size=1,
            pp_size=encoder_pp,
            dp_size=encoder_dp,
            hidden_size=hidden_size,
        )

        self.encoder_2, self.encoder_2_grid = get_transformer_block_and_grid(
            tp_size=encoder_tp,
            cp_size=1,
            pp_size=encoder_pp,
            dp_size=encoder_dp,
            hidden_size=hidden_size,
        )

        self.llm, self.llm_grid = get_transformer_block_and_grid(
            tp_size=llm_tp,
            cp_size=1,
            pp_size=llm_pp,
            dp_size=llm_dp,
            grid_offset=llm_grid_offset,
            hidden_size=hidden_size,
        )

        self.modules_and_grids = [
            (self.encoder_1, self.encoder_1_grid),
            (self.encoder_2, self.encoder_2_grid),
            (self.llm, self.llm_grid),
        ]

        self.current_rank = dist.get_rank()
        self.encoder_1_input_tensor = None
        self.encoder_2_input_tensor = None
        self.llm_input_tensor = None

        self.pre_process = False
        self.post_process = False
        self.share_embeddings_and_output_weights = False

    def set_input_tensor(self, input_tensor: List[Dict[str, torch.Tensor]]):
        """Stash received activations; LLM first stage concats both encoder outputs."""
        logging.debug(f" In DualEncoderModel set_input_tensor rank {dist.get_rank()} input_tensor keys: {input_tensor[0].keys()}")
        if self.is_current_rank_in_grid(self.encoder_1_grid) and 'encoder_1' in input_tensor[0]:
            if isinstance(input_tensor[0]["encoder_1"], list):
                self.encoder_1_input_tensor = input_tensor[0]["encoder_1"][0]
            else:
                self.encoder_1_input_tensor = input_tensor[0]["encoder_1"]
        if self.is_current_rank_in_grid(self.encoder_2_grid) and 'encoder_2' in input_tensor[0]:
            if isinstance(input_tensor[0]["encoder_2"], list):
                self.encoder_2_input_tensor = input_tensor[0]["encoder_2"][0]
            else:
                self.encoder_2_input_tensor = input_tensor[0]["encoder_2"]
        if self.is_current_rank_in_grid(self.llm_grid):
            if 'llm' in input_tensor[0]:
                if isinstance(input_tensor[0]["llm"], list):
                    self.llm_input_tensor = input_tensor[0]["llm"][0]
                else:
                    self.llm_input_tensor = input_tensor[0]["llm"]
            elif 'encoder_1' in input_tensor[0] and 'encoder_2' in input_tensor[0]:
                # concat across sequence dimension (s, b, h)
                logging.debug(f'In DualEncoderModel LLM set_input_tensor rank {dist.get_rank()} encoder_1 shape: {input_tensor[0]["encoder_1"].shape} encoder_2 shape: {input_tensor[0]["encoder_2"].shape}')
                self.llm_input_tensor = torch.concat([input_tensor[0]["encoder_1"], input_tensor[0]["encoder_2"]], dim=0)
                logging.debug(f" In DualEncoderModel LLM set_input_tensor rank {dist.get_rank()} llm_input_tensor shape: {self.llm_input_tensor.shape}")
            else:
                raise ValueError(f"Rank {dist.get_rank()} is not valid")

    def forward(self, hidden_states):
        """Run every sub-module active on this rank; return ``{name: output}``."""
        current_rank = dist.get_rank()
        output_dict = {}
        logging.debug(f" In DualEncoderModel forward rank {dist.get_rank()}")
        if self.is_current_rank_in_grid(self.encoder_1_grid):
            if is_pp_first_stage(self.encoder_1_grid.get_pg("pp")):
                input_tensor = hidden_states
            else:
                assert (
                    self.encoder_1_input_tensor is not None
                ), "Encoder input tensor is not provided for pp rank > 0"
                input_tensor = self.encoder_1_input_tensor
            output_dict["encoder_1"] = self.encoder_1(input_tensor, attention_mask=None)
        if self.is_current_rank_in_grid(self.encoder_2_grid):
            if is_pp_first_stage(self.encoder_2_grid.get_pg("pp")):
                input_tensor = hidden_states
            else:
                assert (
                    self.encoder_2_input_tensor is not None
                ), "Encoder input tensor is not provided for pp rank > 0"
                input_tensor = self.encoder_2_input_tensor
            output_dict["encoder_2"] = self.encoder_2(input_tensor, attention_mask=None)
        if self.is_current_rank_in_grid(self.llm_grid):
            assert (
                self.llm_input_tensor is not None
            ), "LLM input tensor is not provided for pp rank > 0"
            input_tensor = self.llm_input_tensor
            output_dict["llm"] = self.llm(input_tensor, attention_mask=None)
        logging.debug(f"[DualEncoderModel] model fwd pass in rank {dist.get_rank()} output_dict keys: {output_dict.keys()}")
        return output_dict


def _create_transformer_block(
    dtype=torch.bfloat16, hidden_size=4096, pg_collection=None
) -> TransformerBlock:
    """Build a single-layer ``TransformerBlock`` on CUDA with zeroed biases.

    NOTE(review): ``pg_collection`` is dereferenced unconditionally for the RNG
    seeding below, so the ``pg_collection is None`` fallback for ``cp_size`` is
    unreachable in practice — callers must always pass a collection.
    """
    torch.manual_seed(12345)
    model_parallel_cuda_manual_seed(
        123,
        tp_rank=pg_collection.tp.rank(),
        ep_rank=pg_collection.ep.rank(),
        etp_rank=torch.distributed.get_rank(),
    )
    if pg_collection is not None:
        cp_size = pg_collection.cp.size()
    else:
        cp_size = get_context_parallel_group().size()
    transformer_config = TransformerConfig(
        num_layers=1,
        hidden_size=hidden_size,
        num_attention_heads=8,
        use_cpu_initialization=True,
        attention_dropout=0.0,
        hidden_dropout=0.0,
        bf16=dtype == torch.bfloat16,
        context_parallel_size=cp_size,
    )

    block = (
        TransformerBlock(
            transformer_config,
            get_gpt_layer_with_transformer_engine_spec(),
            pg_collection=pg_collection,
        )
        .cuda()
        .to(dtype)
    )
    # Zero biases so outputs are reproducible across parallel layouts.
    with torch.no_grad():
        for mod in block.modules():
            if hasattr(mod, "bias") and mod.bias is not None:
                mod.bias.zero_()
    return block


def create_hypercomm_grid(offset=0, tp=1, cp=1, pp=1, dp=1, ep=1, etp=1):
    """Create a ``HyperCommGrid`` with the given parallel sizes at ``offset``,
    pre-creating the process groups the tests and DDP wrapper rely on."""
    # Set up environment for world size 8 if not already set
    if "WORLD_SIZE" not in os.environ:
        os.environ["WORLD_SIZE"] = "8"

    grid = HyperCommGrid(
        shape=[tp, cp, pp, dp, ep, etp],  # TODO: confirm whether etp belongs in the grid shape
        dim_names=["tp", "cp", "pp", "dp", "ep", "etp"],
        rank_offset=offset,
        backend="nccl",
    )
    _ = grid.create_pg(["tp"])
    _ = grid.create_pg(["cp"])
    _ = grid.create_pg(["pp"])
    _ = grid.create_pg(["dp"])
    _ = grid.create_pg(["ep"])
    # "etp" / "edp" groups intentionally not created (unused by these tests).
    _ = grid.create_pg(["tp", "pp"])
    _ = grid.create_pg(["dp", "cp"])
    _ = grid.create_pg(["tp", "cp"])
    _ = grid.create_pg(["tp", "dp", "cp"])
    _ = grid.create_pg(["tp", "ep", "pp"])
    return grid


def _get_pg_collection_from_grid(grid):
    """Assemble a ``ProcessGroupCollection`` from the groups created on ``grid``."""
    pg_collection = ProcessGroupCollection()
    pg_collection.tp = grid.get_pg("tp")
    pg_collection.cp = grid.get_pg("cp")
    pg_collection.pp = grid.get_pg("pp")
    pg_collection.ep = grid.get_pg("ep")
    pg_collection.dp = grid.get_pg("dp")
    pg_collection.dp_cp = grid.get_pg(["dp", "cp"])
    pg_collection.mp = grid.get_pg(["tp", "pp"])
    pg_collection.tp_cp = grid.get_pg(["tp", "cp"])
    pg_collection.tp_dp_cp = grid.get_pg(["tp", "dp", "cp"])
    pg_collection.tp_ep_pp = grid.get_pg(["tp", "ep", "pp"])
    pg_collection.expt_tp = None
    pg_collection.expt_dp = None
    return pg_collection


def get_transformer_block_and_grid(
    tp_size=1,
    cp_size=1,
    pp_size=1,
    dp_size=1,
    grid_offset: int = 0,
    hidden_size: int = 4096,
    dtype: torch.dtype = torch.bfloat16,
):
    """Utility to build a DDP-wrapped ``TransformerBlock`` for tests.

    Returns ``(block, grid)``; ``block`` is ``None`` on ranks outside the grid.
    """
    current_rank = dist.get_rank()
    grid = create_hypercomm_grid(offset=grid_offset, tp=tp_size, cp=cp_size, pp=pp_size, dp=dp_size)
    if grid.rank_offset <= current_rank < grid.rank_offset + grid.size:
        pg_collection = _get_pg_collection_from_grid(grid)
        block = _create_transformer_block(
            dtype=dtype, hidden_size=hidden_size, pg_collection=pg_collection
        )
        ddp_config = DistributedDataParallelConfig(overlap_grad_reduce=True, bucket_size=10000)
        block = DistributedDataParallel(
            config=block.config, ddp_config=ddp_config, module=block, pg_collection=pg_collection
        )
        # These blocks have no embeddings/heads; mark them as interior stages.
        block.pre_process = False
        block.post_process = False
        block.share_embeddings_and_output_weights = False
    else:
        block = None

    return block, grid


def _populate_embedding_and_position_groups(pp_group):
    """Create *new* embedding-related process groups from *pp_group* ranks."""
    pp_ranks = sorted(dist.get_process_group_ranks(pp_group))

    # Position embeddings live on the first pp rank; embedding grads are
    # shared between first and last pp ranks (when they differ).
    pos_embd_ranks = [pp_ranks[0]]
    embd_ranks = [pp_ranks[0]]
    if pp_ranks[-1] != pp_ranks[0]:
        embd_ranks.append(pp_ranks[-1])

    pos_embd_pg = dist.new_group(ranks=pos_embd_ranks)
    embd_pg = dist.new_group(ranks=embd_ranks)

    return pos_embd_pg, embd_pg


def _get_pg_collection_with_embedding_groups(grid):
    """Like ``_get_pg_collection_from_grid`` but with pos_embd/embd groups attached."""
    pg_collection = _get_pg_collection_from_grid(grid)
    if pg_collection.pp:
        pos_embd_pg, embd_pg = _populate_embedding_and_position_groups(pg_collection.pp)
        # Only first stage keeps the position-embedding group; only first/last
        # stages keep the embedding group.
        pos_embd_pg = pos_embd_pg if is_pp_first_stage(pg_collection.pp) else None
        embd_pg = (
            embd_pg
            if (is_pp_last_stage(pg_collection.pp) or is_pp_first_stage(pg_collection.pp))
            else None
        )
        pg_collection.pos_embd = pos_embd_pg
        pg_collection.embd = embd_pg

    return pg_collection


@pytest.mark.skipif(
    version.parse(torch.__version__) < version.parse('2.3.0'),
    reason="Device mesh feature requires PyTorch 2.3 or later",
)
@pytest.mark.parametrize(
    "encoder_tp,encoder_pp,encoder_dp,llm_tp,llm_pp,llm_dp,llm_grid_offset", [(2, 2, 1, 2, 2, 1, 4)]
)
def test_forward_backward_pipelining_without_interleaving_multi_module_single_encoder(
    mocker, encoder_tp, encoder_pp, encoder_dp, llm_tp, llm_pp, llm_dp, llm_grid_offset
):
    """1F1B schedule over an encoder→llm two-module pipeline on 8 ranks."""
    Utils.initialize_distributed()

    def step_func(data_iterator, model):

        def loss_func(output_tensor_dict: Dict[str, torch.Tensor]):
            # Loss is only defined on the final (llm) stage's output.
            assert (
                'llm' in output_tensor_dict
            ), f'llm is not in output_tensor_dict: {output_tensor_dict}'
            loss = output_tensor_dict['llm'].sum()
            return loss, {'loss_reduced': loss}

        if data_iterator is not None:
            input_tensor = next(data_iterator)
        else:
            input_tensor = None

        model_output = model(input_tensor)

        return model_output, loss_func

    sequence_length = 512
    micro_batch_size = 1
    hidden_size = 1024

    # Create model
    model = SingleEncoderModel(
        hidden_size=hidden_size,
        encoder_tp=encoder_tp,
        encoder_pp=encoder_pp,
        encoder_dp=encoder_dp,
        llm_tp=llm_tp,
        llm_pp=llm_pp,
        llm_dp=llm_dp,
        llm_grid_offset=llm_grid_offset,
    )
    model.model_type = 'unit-test'

    module_to_grid_map = {'encoder': model.encoder_grid, 'llm': model.llm_grid}
    topology = {
        'encoder': ['llm'],  # image_encoder sends forward results to llm
        'llm': [],  # llm is the last stage here
    }
    config = ModelParallelConfig(pipeline_dtype=torch.bfloat16)
    config.calculate_per_token_loss = False
    config.qk_layernorm = False
    config.sequence_parallel = False
    config.moe_router_enable_expert_bias = False
    config.moe_router_load_balancing_type = "aux_loss"
    config.variable_seq_lengths = True
    config.no_sync_func = model.no_sync
    config.finalize_model_grads_func = model.finalize_model_grads

    # Add grad scale function to convert float losses to tensors
    def grad_scale_func(loss):
        if isinstance(loss, (int, float)):
            return torch.tensor(loss, dtype=torch.float32, device='cuda', requires_grad=True)
        else:
            return loss  # Already a tensor

    config.grad_scale_func = grad_scale_func
    model.config = config
    config.hidden_size = hidden_size

    multimodule_communicator = MultiModulePipelineCommunicator(
        module_to_grid_map, topology, config, dim_mapping={'s': 0, 'h': 2, 'b': 1}
    )

    # Only the encoder's first pp stage reads real data.
    data_iterator = None
    if model.is_current_rank_in_grid(model.encoder_grid) and is_pp_first_stage(
        model.encoder_grid.get_pg("pp")
    ):
        data_iterator = DataIterator(
            hidden_size=hidden_size, seq_length=sequence_length, micro_batch_size=micro_batch_size
        )

    common_args = {
        'forward_step_func': step_func,
        'data_iterator': data_iterator,
        'model': [model],
        'num_microbatches': 16,
        'seq_length': sequence_length,
        'micro_batch_size': micro_batch_size,
        'forward_only': False,
    }

    # Ranks 0-3 run the encoder grid; ranks 4-7 run the llm grid.
    if 0 <= dist.get_rank() < 4:
        pg_collection = _get_pg_collection_with_embedding_groups(model.encoder_grid)
    elif 4 <= dist.get_rank() < 8:
        pg_collection = _get_pg_collection_with_embedding_groups(model.llm_grid)
    else:
        raise ValueError(f"Rank {dist.get_rank()} is not valid")

    losses_reduced_explicit = schedule.forward_backward_pipelining_without_interleaving(
        p2p_communicator=multimodule_communicator, pg_collection=pg_collection, **common_args
    )
    logging.info(f"Losses reduced explicit: {losses_reduced_explicit}")


@pytest.mark.skipif(
    version.parse(torch.__version__) < version.parse('2.3.0'),
    reason="Device mesh feature requires PyTorch 2.3 or later",
)
@pytest.mark.parametrize(
    "encoder_tp,encoder_pp,encoder_dp,llm_tp,llm_pp,llm_dp,llm_grid_offset", [(2, 2, 1, 2, 2, 1, 4)]
)
def test_forward_backward_pipelining_without_interleaving_multi_module_dual_encoder(
    mocker, encoder_tp, encoder_pp, encoder_dp, llm_tp, llm_pp, llm_dp, llm_grid_offset
):
    """1F1B schedule over a (encoder_1, encoder_2)→llm three-module pipeline."""
    Utils.initialize_distributed()

    def step_func(data_iterator, model):

        def loss_func(output_tensor_dict: Dict[str, torch.Tensor]):
            assert (
                'llm' in output_tensor_dict
            ), f'llm is not in output_tensor_dict: {output_tensor_dict}'
            loss = output_tensor_dict['llm'].sum()
            return loss, {'loss_reduced': loss}

        if data_iterator is not None:
            input_tensor = next(data_iterator)
        else:
            input_tensor = None

        model_output = model(input_tensor)

        return model_output, loss_func

    sequence_length = 512
    micro_batch_size = 1
    hidden_size = 1024

    # Create model
    model = DualEncoderModel(
        hidden_size=hidden_size,
        encoder_tp=encoder_tp,
        encoder_pp=encoder_pp,
        encoder_dp=encoder_dp,
        llm_tp=llm_tp,
        llm_pp=llm_pp,
        llm_dp=llm_dp,
        llm_grid_offset=llm_grid_offset,
    )
    model.model_type = 'unit-test'

    module_to_grid_map = {'encoder_1': model.encoder_1_grid, 'encoder_2': model.encoder_2_grid, 'llm': model.llm_grid}
    topology = {
        'encoder_1': ['llm'],  # encoder_1 sends forward results to llm
        'encoder_2': ['llm'],  # encoder_2 sends forward results to llm
        'llm': [],  # llm is the last stage here
    }
    config = ModelParallelConfig(pipeline_dtype=torch.bfloat16)
    config.finalize_model_grads_func = model.finalize_model_grads
    config.calculate_per_token_loss = False
    config.qk_layernorm = False
    config.sequence_parallel = False
    config.moe_router_enable_expert_bias = False
    config.moe_router_load_balancing_type = "aux_loss"
    config.variable_seq_lengths = True
    config.no_sync_func = model.no_sync

    # Add grad scale function to convert float losses to tensors
    def grad_scale_func(loss):
        if isinstance(loss, (int, float)):
            return torch.tensor(loss, dtype=torch.float32, device='cuda', requires_grad=True)
        else:
            return loss  # Already a tensor

    config.grad_scale_func = grad_scale_func
    model.config = config
    config.hidden_size = hidden_size

    multimodule_communicator = MultiModulePipelineCommunicator(
        module_to_grid_map, topology, config, dim_mapping={'s': 0, 'h': 2, 'b': 1}
    )

    data_iterator = None
    if model.is_current_rank_in_grid(model.encoder_1_grid) and is_pp_first_stage(
        model.encoder_1_grid.get_pg("pp")
    ):
        data_iterator = DataIterator(
            hidden_size=hidden_size, seq_length=sequence_length, micro_batch_size=micro_batch_size
        )

    common_args = {
        'forward_step_func': step_func,
        'data_iterator': data_iterator,
        'model': [model],
        'num_microbatches': 16,
        'seq_length': sequence_length,
        'micro_batch_size': micro_batch_size,
        'forward_only': False,
    }

    # Ranks 0-3 host both encoders (a collection per encoder); ranks 4-7 host the llm.
    if 0 <= dist.get_rank() < 4:
        pg_collection_encoder_1 = _get_pg_collection_with_embedding_groups(model.encoder_1_grid)
        pg_collection_encoder_2 = _get_pg_collection_with_embedding_groups(model.encoder_2_grid)
        pg_collection = [pg_collection_encoder_1, pg_collection_encoder_2]
    elif 4 <= dist.get_rank() < 8:
        pg_collection_llm = _get_pg_collection_with_embedding_groups(model.llm_grid)
        pg_collection = [pg_collection_llm]
    else:
        raise ValueError(f"Rank {dist.get_rank()} is not valid")

    losses_reduced_explicit = schedule.forward_backward_pipelining_without_interleaving(
        p2p_communicator=multimodule_communicator, pg_collection=pg_collection, **common_args
    )
    logging.info(f"Losses reduced explicit: {losses_reduced_explicit}")


if __name__ == "__main__":
    from unittest.mock import Mock

    # Set logging level to DEBUG
    logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')

    # Create a mock object that mimics pytest-mock's mocker
    mock_mocker = Mock()

    # Use the same parameters as defined in the pytest.mark.parametrize decorator
    test_forward_backward_pipelining_without_interleaving_multi_module_single_encoder(
        mock_mocker,
        encoder_tp=2,
        encoder_pp=2,
        encoder_dp=1,
        llm_tp=2,
        llm_pp=2,
        llm_dp=1,
        llm_grid_offset=4,
    )