From 0d941fd7aca0940bf6b90ef7fdb38ef1fce860d9 Mon Sep 17 00:00:00 2001 From: Jiayi Yan <1195343015@qq.com> Date: Sun, 2 Feb 2025 16:45:50 +0800 Subject: [PATCH 1/3] update parallel state --- README.md | 14 +- aicb.py | 8 +- log_analyzer/log.py | 40 +-- run_suites.py | 8 +- scripts/megatron_gpt.sh | 23 +- scripts/megatron_workload_with_aiob.sh | 15 +- training/tutorial.md | 10 +- utils/utils.py | 138 ++++++---- workload_applyer.py | 256 +++++++++++++----- .../AIOB_simAI_workload_generator.py | 16 +- .../generate_collective_test.py | 4 +- .../generate_deepspeed_stage1_2_workload.py | 2 +- .../generate_deepspeed_stage3_workload.py | 2 +- .../generate_megatron_workload.py | 42 +-- .../mocked_model/AiobMegatron.py | 2 +- .../mocked_model/MockedMegatron.py | 4 + workload_generator/workload_generator.py | 2 +- 17 files changed, 383 insertions(+), 203 deletions(-) diff --git a/README.md b/README.md index d6a09d6..2588d74 100755 --- a/README.md +++ b/README.md @@ -140,7 +140,7 @@ For the `Megatron parallel framework`, you can quickly start using the scripts/m ```bash sh scripts/megatron_gpt.sh \ --nnodes 1 --node_rank 0 --nproc_per_node 8 --master_addr localhost --master_port 29500 \ --m 7 --world_size 8 --tensor_model_parallel_size 2 --pipeline_model_parallel 1 \ +-m 7 --world_size 8 --tensor_model_parallel_size 2 --pipeline_model_parallel_size 1 \ --frame Megatron --global_batch 16 \ --micro_batch 1 --seq_length 2048 --swiglu --use_flash_attn --aiob_enable ``` @@ -150,7 +150,7 @@ For `Moe` , you can quickly start it using the [scripts/megatron_gpt.sh](scripts ```bash sh scripts/megatron_gpt.sh \ --nnodes 1 --node_rank 0 --nproc_per_node 8 --master_addr localhost --master_port 29500 \ --m moe --world_size 8 --tensor_model_parallel_size 4 --pipeline_model_parallel 1 \ +-m moe --world_size 8 --tensor_model_parallel_size 4 --pipeline_model_parallel_size 1 \ --moe_enable --expert_model_parallel_size 1 \ --frame Megatron --global_batch 16 \ --num_experts 4 --moe_router_topk 2 \ @@ -177,7 +177,7 @@ Note that the computation times are obtained through the execution of computatio The following commands does not generate the computation descrition file, but also run the workload in the real GPU cluster. ```bash sh scripts/megatron_gpt.sh \ --m 7 --world_size 8 --tensor_model_parallel_size 2 --pipeline_model_parallel 1 \ +-m 7 --world_size 8 --tensor_model_parallel_size 2 --pipeline_model_parallel_size 1 \ --frame Megatron --global_batch 16 \ --micro_batch 1 --seq_length 2048 \ --swiglu --use_flash_attn --aiob_enable @@ -187,7 +187,7 @@ Users can defined their own computation times or directly use the files we provi By specifying the computation description file with the `--comp_filepath` option, you can embed computation times before running the workload on a physical machine. 
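As a quick aside on the flags used in these examples: the effective data-parallel size and microbatch count follow from simple division, which scripts/megatron_gpt.sh and utils/utils.py also perform later in this patch. A minimal sketch of that arithmetic (the function name here is illustrative, not part of the repository):

```python
# Derive the sizes implied by -m 7 --world_size 8 --tensor_model_parallel_size 2
# --pipeline_model_parallel_size 1 --global_batch 16 --micro_batch 1.
def derive_parallel_sizes(world_size, tp, pp, cp, global_batch, micro_batch):
    model_size = tp * pp * cp                      # GPUs holding one model replica
    assert world_size % model_size == 0, "world_size must be divisible by tp*pp*cp"
    dp = world_size // model_size                  # data-parallel replicas
    num_microbatches = global_batch // (dp * micro_batch)
    return dp, num_microbatches

print(derive_parallel_sizes(8, tp=2, pp=1, cp=1, global_batch=16, micro_batch=1))
# -> (4, 4): 4-way data parallelism, 4 microbatches per iteration
```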
```bash sh scripts/megatron_gpt.sh \ --m 7 --world_size 8 --tensor_model_parallel_size 2 --pipeline_model_parallel 1 \ +-m 7 --world_size 8 --tensor_model_parallel_size 2 --pipeline_model_parallel_size 1 \ --frame Megatron --global_batch 16 --micro_batch 1 \ --seq_length 2048 --swiglu --use_flash_attn \ --aiob_enable \ @@ -206,7 +206,7 @@ Here, you can use the script [scripts/megatron_workload.sh](scripts/megatron_wor ```bash sh ./scripts/megatron_workload_with_aiob.sh \ -m 7 --world_size 4096 \ ---tensor_model_parallel_size 2 --pipeline_model_parallel 1 \ +--tensor_model_parallel_size 2 --pipeline_model_parallel_size 1 \ --frame Megatron --global_batch 8192 \ --micro_batch 1 --seq_length 4096 \ --swiglu --use_flash_attn --aiob_enable @@ -215,7 +215,7 @@ sh ./scripts/megatron_workload_with_aiob.sh \ ```bash sh ./scripts/megatron_workload_with_aiob.sh -m 7 \ ---world_size 4096 --tensor_model_parallel_size 2 --pipeline_model_parallel 1 \ +--world_size 4096 --tensor_model_parallel_size 2 --pipeline_model_parallel_size 1 \ --frame Megatron --global_batch 8192 \ --micro_batch 1 --seq_length 4096 --swiglu \ --use_flash_attn --aiob_enable \ @@ -225,7 +225,7 @@ sh ./scripts/megatron_workload_with_aiob.sh -m 7 \ For the Moe, you can also use [scripts/megatron_workload_with_aiob.sh](scripts/workload_megatron.sh) to generate the corresponding model's workload file. ```bash sh scripts/megatron_workload_with_aiob.sh \ --m moe --world_size 512 --tensor_model_parallel_size 2 --pipeline_model_parallel 1 --sp --ep 16 \ +-m moe --world_size 512 --tensor_model_parallel_size 2 --pipeline_model_parallel_size 1 --sp --ep 16 \ --num_experts 64 --moe_router_topk 2 --moe_grouped_gemm --moe_enable \ --frame Megatron --global_batch 1024 \ --micro_batch 1 --seq_length 4096 --swiglu \ diff --git a/aicb.py b/aicb.py index 6b0e3a7..6236219 100755 --- a/aicb.py +++ b/aicb.py @@ -57,15 +57,15 @@ else: filepath = get_aiob_path(args) torch.distributed.barrier() - compute_cache = extract_averages(filepath,args) + compute_cache = extract_averages(filepath, args) else: print("comp_filepath:", args.comp_filepath) - compute_cache = extract_averages(args.comp_filepath,args) + compute_cache = extract_averages(args.comp_filepath, args) workload = Comp_with_aiob(workload, compute_cache) if torch.distributed.get_rank() == 0: filename = f"{workload_generator.name}_{args.model_name}_sp_{args.enable_sequence_parallel}_iteration_{args.epoch_num}_computationEnable_{args.computation_enable}_{args.world_size}n.csv" workload.dump(filename) - if not args.workload_only : + if not args.workload_only: applyer = WorkloadApplyer(workload=workload, args=args) cpu_time = applyer.apply_workload() if torch.distributed.get_rank() == 0: @@ -76,7 +76,7 @@ if args.enable_visual: try: from visualize.generate import visualize_output - visualize_output(csv_filename,False) + visualize_output(csv_filename, False) except ImportError: print("visualize_output is not available because required library is not found") diff --git a/log_analyzer/log.py b/log_analyzer/log.py index 92ee39e..9cfffee 100755 --- a/log_analyzer/log.py +++ b/log_analyzer/log.py @@ -227,25 +227,26 @@ def _get_elapsed_time(self): return self.epoch_times def analyze_time(self, print_fn=print): - self.epoch_times.pop(0) - max_val = max(self.epoch_times) - min_val = min(self.epoch_times) - mean_val = sum(self.epoch_times) / len(self.epoch_times) - - variance = sum((x - mean_val) ** 2 for x in self.epoch_times) / len( - self.epoch_times - ) - variance = math.sqrt(variance) + if self.epoch_times: + 
self.epoch_times.pop(0) + max_val = max(self.epoch_times) + min_val = min(self.epoch_times) + mean_val = sum(self.epoch_times) / len(self.epoch_times) + + variance = sum((x - mean_val) ** 2 for x in self.epoch_times) / len( + self.epoch_times + ) + variance = math.sqrt(variance) - sorted_list = sorted(self.epoch_times) - p90_val = sorted_list[int(len(sorted_list) * 0.9)] - p99_val = sorted_list[int(len(sorted_list) * 0.99)] - header = f"{'Init time':<18} {'Max iteration time':<20} {'Min iteration time':<20} {'Avg iteration time':<20} {'P90 iteration time ':<20} {'Iteration time Std ':<20}\n" - separator = "-" * len(header) + "\n" - log_str = separator + header + separator - iteration_result = f"{self.epoch_times[0]:<18.2f} {max_val:<20.2f} {min_val:<20.2f} {mean_val:<20.2f} {p90_val:<20.2f} {variance:<20.2f}\n" - log_str += iteration_result - print_fn(f"\n\tDetailed info for AICB iteration time\n{log_str}") + sorted_list = sorted(self.epoch_times) + p90_val = sorted_list[int(len(sorted_list) * 0.9)] + p99_val = sorted_list[int(len(sorted_list) * 0.99)] + header = f"{'Init time':<18} {'Max iteration time':<20} {'Min iteration time':<20} {'Avg iteration time':<20} {'P90 iteration time ':<20} {'Iteration time Std ':<20}\n" + separator = "-" * len(header) + "\n" + log_str = separator + header + separator + iteration_result = f"{self.epoch_times[0]:<18.2f} {max_val:<20.2f} {min_val:<20.2f} {mean_val:<20.2f} {p90_val:<20.2f} {variance:<20.2f}\n" + log_str += iteration_result + print_fn(f"\n\tDetailed info for AICB iteration time\n{log_str}") class Workload: @@ -254,7 +255,8 @@ def __init__(self) -> None: def append(self, log_item: Union[LogItem, Dict]): if isinstance(log_item, LogItem): - self.workload.append(log_item) + if log_item.comm_group_size != 1: + self.workload.append(log_item) return if "stage" not in log_item: log_item["stage"] = log_item["operation"] if "operation" in log_item else "" diff --git a/run_suites.py b/run_suites.py index 941d5aa..c016cc4 100755 --- a/run_suites.py +++ b/run_suites.py @@ -76,11 +76,11 @@ def read_config(config): ) if int(megatron_conf["gpt_175B"]): running_command["megatron_gpt175B"] = ( - f"bash scripts/megatron_gpt.sh -m 175 --tensor_model_parallel_size 8 --epoch_num 10 --pipeline_model_parallel 2 --sp" + f"bash scripts/megatron_gpt.sh -m 175 --tensor_model_parallel_size 8 --epoch_num 10 --pipeline_model_parallel_size 2 --sp" ) if int(megatron_conf["gpt_175B_tp"]): running_command["megatron_gpt175B_tp"] = ( - f"bash scripts/megatron_gpt.sh -m 175 --tensor_model_parallel_size 8 --epoch_num 10 --pipeline_model_parallel 2" + f"bash scripts/megatron_gpt.sh -m 175 --tensor_model_parallel_size 8 --epoch_num 10 --pipeline_model_parallel_size 2" ) if int(megatron_conf["gpt_22B"]): running_command["megatron_gpt_22B"] = ( @@ -104,11 +104,11 @@ def read_config(config): ) if int(aiob_conf["gpt_175B_aiob"]): running_command["megatron_gpt175B_aiob"] = ( - f"bash scripts/megatron_gpt.sh -m 175 --tensor_model_parallel_size 8 --epoch_num 10 --aiob_enable --pipeline_model_parallel 2 --sp" + f"bash scripts/megatron_gpt.sh -m 175 --tensor_model_parallel_size 8 --epoch_num 10 --aiob_enable --pipeline_model_parallel_size 2 --sp" ) if int(aiob_conf["gpt_175B_tp_aiob"]): running_command["megatron_gpt175B_tp_aiob"] = ( - f"bash scripts/megatron_gpt.sh -m 175 --tensor_model_parallel_size 8 --epoch_num 10 --aiob_enable --pipeline_model_parallel 2 " + f"bash scripts/megatron_gpt.sh -m 175 --tensor_model_parallel_size 8 --epoch_num 10 --aiob_enable --pipeline_model_parallel_size 2 " 
) if int(aiob_conf["gpt_22B_aiob"]): running_command["megatron_gpt_22B_aiob"] = ( diff --git a/scripts/megatron_gpt.sh b/scripts/megatron_gpt.sh index e59080a..283f199 100755 --- a/scripts/megatron_gpt.sh +++ b/scripts/megatron_gpt.sh @@ -14,7 +14,8 @@ seq_length=2048 micro_batch=1 epoch_num=1 tensor_model_parallel_size=8 -pipeline_model_parallel=1 +pipeline_model_parallel_size=1 +context_parallel_size=1 vocab_size=50257 model_name=gpt_13b ga_num=2 @@ -32,7 +33,8 @@ usage() { --frame Communication framework: $frame --world_size World size (number of nodes): $WORLD_SIZE --tensor_model_parallel_size Tensor parallelism size: $tensor_model_parallel_size - --pipeline_model_parallel Pipeline parallelism size: $pipeline_model_parallel + --pipeline_model_parallel_size Pipeline parallelism size: $pipeline_model_parallel_size + --context_parallel_size Context parallelism size: $context_parallel_size --global_batch Global batch size: $global_batch --micro_batch Micro batch size: $micro_batch --num_layers Number of layers: $num_layers @@ -72,8 +74,10 @@ echo "Processing argument: $1" world_size=$2; shift;; --tensor_model_parallel_size|tp_num) tensor_model_parallel_size=$2; shift;; - --pipeline_model_parallel|pp_num) - pipeline_model_parallel=$2; shift;; + --pipeline_model_parallel_size|pp_num) + pipeline_model_parallel_size=$2; shift;; + --context_parallel_size|pp_num) + context_parallel_size=$2; shift;; --global_batch) global_batch=$2; shift;; --micro_batch) @@ -169,7 +173,7 @@ case $model_size in ffn_hidden_size=53248 num_attention_heads=128 tensor_model_parallel_size=8 - pipeline_model_parallel=16 + pipeline_model_parallel_size=16 ;; 65) model_name=llama_65B @@ -178,7 +182,7 @@ case $model_size in ffn_hidden_size=28672 num_attention_heads=64 tensor_model_parallel_size=8 - pipeline_model_parallel=2 + pipeline_model_parallel_size=2 ;; moe) model_name=Mixtral_8*7B @@ -199,8 +203,8 @@ case $model_size in ;; esac -dp_num=$((world_size/tensor_model_parallel_size/pipeline_model_parallel)) -global_batch=$((ga_num*dp_num*micro_batch)) +data_parallel_size=$((world_size/tensor_model_parallel_size/pipeline_model_parallel_size)) +global_batch=$((ga_num*data_parallel_size*micro_batch)) if [ $workload_only ]; then script="python -m workload_generator.generate_megatron_workload" else @@ -220,7 +224,8 @@ cmd="$script \ --num_attention_heads=$num_attention_heads \ --seq_length=$seq_length \ --vocab_size=$vocab_size \ - --pipeline_model_parallel=$pipeline_model_parallel \ + --pipeline_model_parallel_size=$pipeline_model_parallel_size \ + --context-parallel-size=$context_parallel_size \ --use-distributed-optimizer \ --max_position_embeddings=$max_position_embeddings \ ${aiob_enable} \ diff --git a/scripts/megatron_workload_with_aiob.sh b/scripts/megatron_workload_with_aiob.sh index 2c3ed94..97c9ec2 100755 --- a/scripts/megatron_workload_with_aiob.sh +++ b/scripts/megatron_workload_with_aiob.sh @@ -4,7 +4,8 @@ frame=Megatron world_size=32 tensor_model_parallel_size=8 -pipeline_model_parallel=1 +pipeline_model_parallel_size=1 +context_parallel_size=1 global_batch=1024 micro_batch=1 num_layers=40 @@ -31,7 +32,8 @@ usage() { --frame communication framework, defaults to $frame --world_size world size, defaults to $world_size --tensor_model_parallel_size tensor parallelism size, defaults to $tensor_model_parallel_size - --pipeline_model_parallel pipeline parallelism size, defaults to $pipeline_model_parallel + --pipeline_model_parallel_size pipeline parallelism size, defaults to $pipeline_model_parallel_size + 
--context_parallel_size context parallelism size, defaults to $context_parallel_size --global_batch global batch size, defaults to $global_batch --micro_batch micro batch size, defaults to $micro_batch --num_layers number of layers, defaults to $num_layers @@ -68,8 +70,10 @@ do world_size=$2; shift;; --tensor_model_parallel_size|--tp) tensor_model_parallel_size=$2; shift;; - --pipeline_model_parallel|--pp) - pipeline_model_parallel=$2; shift;; + --pipeline_model_parallel_size|--pp) + pipeline_model_parallel_size=$2; shift;; + --context_parallel_size|--cp) + context_parallel_size=$2; shift;; --global_batch) global_batch=$2; shift;; --micro_batch) @@ -181,7 +185,8 @@ cmd="python -m workload_generator.AIOB_simAI_workload_generator \ --frame=$frame \ --world_size=$world_size \ --tensor_model_parallel_size=$tensor_model_parallel_size \ - --pipeline_model_parallel=$pipeline_model_parallel \ + --pipeline_model_parallel_size=$pipeline_model_parallel_size \ + --context-parallel-size=$context_parallel_size \ --global_batch=$global_batch \ --micro_batch=$micro_batch \ --num_layers=$num_layers \ diff --git a/training/tutorial.md b/training/tutorial.md index cd5fc50..3e6e1ed 100755 --- a/training/tutorial.md +++ b/training/tutorial.md @@ -45,7 +45,7 @@ export WORLD_SIZE=1 export RANK=0 sh ./scripts/megatron_gpt.sh \ --m 13 --world_size 8 --tensor_model_parallel_size 8 --pipeline_model_parallel 1 \ +-m 13 --world_size 8 --tensor_model_parallel_size 8 --pipeline_model_parallel_size 1 \ --frame Megatron --global_batch 2 \ --micro_batch 1 --seq_length 4096 \ --swiglu --use_flash_attn --aiob_enable @@ -107,7 +107,7 @@ We provide four pre-existing models (7/13/22/175)B to quickly generate the corre Below is an example of generating a Workload with a model size of 7B, tp 4, pp 1, a total GPU count of 4096, gbs 8192, mbs 1, sequence length of 4096, with flash_attn, swiglu, and aiob enabled, and reading Example.txt as the computation time. ```bash sh ./scripts/megatron_workload_with_aiob.sh -m 7 \ ---world_size 4096 --tensor_model_parallel_size 4 --pipeline_model_parallel 1 \ +--world_size 4096 --tensor_model_parallel_size 4 --pipeline_model_parallel_size 1 \ --frame Megatron --global_batch 8192 \ --micro_batch 1 --seq_length 4096 --swiglu \ --use_flash_attn --aiob_enable \ @@ -140,7 +140,7 @@ The main parameters for AICB are as follows: | | max_position_embeddings | Maximum number of position embeddings to use. | | | ffn_hidden_size | Transformer Feed-Forward Network hidden size. | | Megatron parallel parameters | tensor_model_parallel_size | Degree of tensor model parallelism. | -| | pipeline_model_parallel | Degree of pipeline model parallelism. | +| | pipeline_model_parallel_size | Degree of pipeline model parallelism. | | | enable_sequence_parallel | Enable sequence parallel optimization. | | Megatron optimization parameters | use_flash_attn | Use FlashAttention implementation of attention. 
| | | swiglu | Use gated linear units and SiLU activation instead of default gelu | @@ -210,7 +210,7 @@ Here is an example: ```bash python -m workload_generator.AIOB_simAI_workload_generator \ --model_name GPT-13B --frame=Megatron \ - --world_size=16 --tensor_model_parallel_size=2 --pipeline_model_parallel=1 --global_batch=16 \ + --world_size=16 --tensor_model_parallel_size=2 --pipeline_model_parallel_size=1 --global_batch=16 \ --micro_batch=1 --num_layers=40 --seq_length=2048 \ --hidden_size=5120 --epoch_num=1 \ --use-distributed-optimizer --num_attention_heads=40 \ @@ -288,7 +288,7 @@ Here is a brief example of training process and workload item: ```python trainer.init() for _ in range(epoch_num): - if pipeline_model_parallel > 1: + if pipeline_model_parallel_size > 1: trainer.with_pipeline_forward_backward() else: for _ in range(num_microbatches): diff --git a/utils/utils.py b/utils/utils.py index 537c94e..e278eed 100755 --- a/utils/utils.py +++ b/utils/utils.py @@ -137,14 +137,16 @@ def decompose(index, shape, stride=None): ) ranks.append(rank) return ranks + class RankGenerator(object): - def __init__(self, tp: int, ep: int, dp: int, pp: int, cp: int, order: str) -> None: + def __init__(self, tp: int, ep: int, dp: int, pp: int, cp: int, order: str, rank_offset: int = 0) -> None: self.tp = tp self.ep = ep self.dp = dp self.pp = pp self.cp = cp - self.world_size = tp * dp * pp * cp + self.rank_offset = rank_offset + self.world_size = tp * dp * pp * cp * ep self.name_to_size = { "tp": self.tp, @@ -163,59 +165,49 @@ def __init__(self, tp: int, ep: int, dp: int, pp: int, cp: int, order: str) -> N for name in self.name_to_size.keys(): if name not in order and self.name_to_size[name] != 1: raise RuntimeError( - f"The size of ({name}) is ({self.name_to_size[name]}), but you haven't specified the order ({self.order})." + f"The size of ({name}) is ({self.name_to_size[name]}), but you haven't" + f"specified the order ({self.order})." ) elif name not in order: order = order + '-' + name - self.order_w_ep = order - self.order_wo_ep = '-'.join([token for token in order.split('-') if token != 'ep']) - self.ordered_size_wo_ep = [] - self.ordered_size_w_ep = [] + self.order = order + self.ordered_size = [] for token in order.split('-'): - if token == 'dp': - self.ordered_size_w_ep.append(self.dp // self.ep) - self.ordered_size_wo_ep.append(self.dp) - elif token == 'ep': - self.ordered_size_w_ep.append(self.ep) - else: - self.ordered_size_w_ep.append(self.name_to_size[token]) - self.ordered_size_wo_ep.append(self.name_to_size[token]) + self.ordered_size.append(self.name_to_size[token]) def get_mask(self, order: str, token: str): + """Create a mask for the specified tokens based on the given order. + + Args: + order (str): The order of parallelism types (e.g., 'tp-dp-pp'). + token (str): The specific parallelism types to include in the mask, + separated by hyphens (e.g., 'tp-dp'). + """ ordered_token = order.split('-') - token = token.split('-') + token_list = token.split('-') mask = [False] * len(ordered_token) - for t in token: + for t in token_list: mask[ordered_token.index(t)] = True return mask - def get_ranks(self, token, independent_ep=False): - '''Get rank group by input token. + def get_ranks(self, token): + """Get rank group by input token. - Arguments: + Args: token (str): Specify the ranks type that want to get. If we want to obtain multiple parallel types, we can use a hyphen '-' to separate them. For example, if we want to obtain the TP_DP group, the token should be 'tp-dp'. 
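To make the refactored `get_mask`/`get_ranks` pair easier to follow, here is a self-contained illustration of masked orthogonal rank grouping. This is a simplified stand-in for the repository's `generate_masked_orthogonal_rank_groups`, not the actual implementation: ranks that agree on every unmasked axis of the `'tp-cp-ep-dp-pp'` layout land in the same group, and the first axis in the order varies fastest, so tensor-parallel ranks are contiguous.

```python
from collections import defaultdict
from itertools import product

def masked_rank_groups(ordered_size, mask):
    """Group ranks that share coordinates on every unmasked axis (sketch)."""
    groups = defaultdict(list)
    axes = [range(s) for s in reversed(ordered_size)]   # last axis listed first
    for rank, coords in enumerate(product(*axes)):
        coords = tuple(reversed(coords))                 # back to tp-cp-ep-dp-pp order
        key = tuple(c for c, grouped in zip(coords, mask) if not grouped)
        groups[key].append(rank)
    return list(groups.values())

# Sizes tp=2, cp=1, ep=1, dp=4, pp=1 (world size 8):
print(masked_rank_groups([2, 1, 1, 4, 1], [True, False, False, False, False]))
# tp groups -> [[0, 1], [2, 3], [4, 5], [6, 7]]
print(masked_rank_groups([2, 1, 1, 4, 1], [False, False, False, True, False]))
# dp groups -> [[0, 2, 4, 6], [1, 3, 5, 7]]
```

The `rank_offset` introduced in this patch then simply shifts every rank in every returned group, which is how the decoder's groups are placed after the encoder's ranks.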
- - independent_ep (bool: True): - This flag controls whether we treat EP and DP independently. - EP shares ranks with DP, if we want to get ranks related to - EP, we should set the flag. For example, get_ranks('dp', True) - will get DP modulo EP group, and get_ranks('dp', False) will - get full DP group. - ''' - if independent_ep: - parallel_size = self.ordered_size_w_ep - order = self.order_w_ep - else: - parallel_size = self.ordered_size_wo_ep - order = self.order_wo_ep - mask = self.get_mask(order, token) - ranks = generate_masked_orthogonal_rank_groups(self.world_size, parallel_size, mask) + """ + mask = self.get_mask(self.order, token) + ranks = generate_masked_orthogonal_rank_groups(self.world_size, self.ordered_size, mask) + if self.rank_offset > 0: + for rank_group in ranks: + for i in range(len(rank_group)): + rank_group[i] += self.rank_offset return ranks def gelu_impl(x): @@ -468,15 +460,20 @@ def get_comm_type(cls, value): class CommGroup(str, Enum): """Enum class for possible comm groups""" - dp_group = "dp_group" - pp_group = "pp_group" + dp_cp_group = "dp_cp_group" + cp_group = "cp_group" + tp_pp_group = "tp_pp_group" tp_group = "tp_group" + pp_group = "pp_group" + tp_dp_cp_group = "tp_dp_cp_group" + tp_dp_group = "tp_dp_group" + tp_cp_group = "tp_cp_group" ep_group = "ep_group" - ep_dp_group = "ep_dp_group" - ep_tp_group = "ep_tp_group" - embedding_group = "embedding_group" - all = "all_nodes" + exp_tp_group = "exp_tp_group" + tp_ep_group = "tp_ep_group" + tp_ep_pp_group = "tp_ep_pp_group" + exp_dp_group = "exp_dp_group" class WorkloadWriter: @@ -508,10 +505,15 @@ def get_params(): help="Number of GPUs") parser.add_argument("--tensor_model_parallel_size", type=int, default=1, help='Degree of tensor model parallelism.') - parser.add_argument("--pipeline_model_parallel", type=int, default=1, + parser.add_argument("--pipeline_model_parallel_size", type=int, default=1, help='Degree of pipeline model parallelism.') + parser.add_argument("--encoder_tensor_model_parallel_size", type=int, default=0, + help='Degree of encoder tensor model parallelism.') + parser.add_argument("--encoder_pipeline_model_parallel_size", type=int, default=0, + help='Degree of encoder pipeline model parallelism.') parser.add_argument('--context-parallel-size', type=int, default=1, help='Degree of context parallelism.') + parser.add_argument("--pp_rank", type=int, default=-1, help='Rank where encoder and decoder should be split.') parser.add_argument("--global_batch", type=int, default=4, @@ -553,16 +555,13 @@ def get_params(): args = parser.parse_args() assert ( - args.world_size % (args.tensor_model_parallel_size * args.pipeline_model_parallel) == 0 - ), f"world size: {args.world_size}, tp: {args.tensor_model_parallel_size}, pp: {args.pipeline_model_parallel}" + args.world_size % (args.tensor_model_parallel_size * args.pipeline_model_parallel_size) == 0 + ), f"world size: {args.world_size}, tp: {args.tensor_model_parallel_size}, pp: {args.pipeline_model_parallel_size}" if args.moe_enable: assert ( args.moe_enable and args.enable_sequence_parallel ), f"moe must be enabled with sequence parallel" - args.dp_num = args.world_size // (args.tensor_model_parallel_size * args.pipeline_model_parallel) - # assert args.global_batch % (args.dp_num * args.micro_batch) == 0, \ - # f"global_batch: {args.global_batch}, dp: {args.dp_num}, micro_batch: {args.micro_batch}" - args.num_microbatches = args.global_batch // (args.dp_num * args.micro_batch) + if args.aiob_enable and not args.computation_enable: 
args.computation_enable = True @@ -594,8 +593,43 @@ def get_params(): "Expert parallelism is not supported with fp16 training." if args.moe_grouped_gemm: assert args.dtype == "bfloat16", 'Currently GroupedGEMM for MoE only supports bf16 dtype.' - if args.pipeline_model_parallel > 1 : - args.num_layers = int(args.num_layers//args.pipeline_model_parallel) + if args.pipeline_model_parallel_size > 1 : + args.num_layers = int(args.num_layers // args.pipeline_model_parallel_size) + + if args.encoder_pipeline_model_parallel_size == 0 and args.num_experts == 0: + assert args.encoder_tensor_model_parallel_size == args.tensor_model_parallel_size, "If non-MOE encoder shares first decoder pipeline rank it must have the same TP as the decoder." + + if args.encoder_tensor_model_parallel_size > 0: + assert args.num_attention_heads % args.encoder_tensor_model_parallel_size == 0 + assert args.encoder_tensor_model_parallel_size <= args.tensor_model_parallel_size, "We do not support encoders with more TP than the decoder." + + if args.encoder_pipeline_model_parallel_size > 0 and args.encoder_tensor_model_parallel_size == 0: + args.encoder_tensor_model_parallel_size = args.tensor_model_parallel_size + + encoder_model_size = args.encoder_tensor_model_parallel_size * args.encoder_pipeline_model_parallel_size * args.context_parallel_size + decoder_model_size = args.tensor_model_parallel_size * args.pipeline_model_parallel_size * args.context_parallel_size + total_model_size = encoder_model_size + decoder_model_size + + # Total model size. + assert args.world_size % total_model_size == 0, ( + f"world size ({args.world_size}) is not divisible by total_model_size ({encoder_model_size=} + {decoder_model_size=})" + ) + + args.data_parallel_size = args.world_size // total_model_size + + # assert args.global_batch % (args.data_parallel_size * args.micro_batch) == 0, \ + # f"global_batch: {args.global_batch}, dp: {args.data_parallel_size}, micro_batch: {args.micro_batch}" + args.num_microbatches = args.global_batch // (args.data_parallel_size * args.micro_batch) + + if args.expert_tensor_parallel_size is None: + args.expert_tensor_parallel_size = args.tensor_model_parallel_size + + + if args.seq_length is not None and args.context_parallel_size > 1: + assert args.seq_length % (args.context_parallel_size * 2) == 0, \ + 'seq-length should be a multiple of 2 * context-parallel-size ' \ + 'if context-parallel-size > 1.' + return args @@ -710,6 +744,8 @@ def get_simAI_workload_params(parser: argparse.ArgumentParser): def get_moe_params(parser: argparse.ArgumentParser): parser.add_argument('--moe_enable', action="store_true") parser.add_argument('--expert_model_parallel_size', type=int, default=1, help='Degree of expert model parallelism.') + parser.add_argument('--expert-tensor-parallel-size', type=int, default=None, + help='Degree of expert model parallelism. Default is None, which will be set to the value of --tensor-model-paralle-size.') parser.add_argument('--num_experts', type=int, default=1, help='Number of Experts in MoE (None means no MoE)') parser.add_argument('--moe_router_topk', type=int, default=1, help='Number of experts to route to for each token. 
The default is 2.') parser.add_argument('--moe_grouped_gemm', action='store_true', diff --git a/workload_applyer.py b/workload_applyer.py index dff4dbd..d9f5c46 100755 --- a/workload_applyer.py +++ b/workload_applyer.py @@ -17,6 +17,7 @@ from utils.utils import WorkloadWriter, CommGroup, CommType, ReduceOp from utils.benchmark_logger import bench_logger import utils.utils as utils +from itertools import cycle class WorkloadApplyer: @@ -42,7 +43,7 @@ def __init__(self, workload=None, args=None, filename=None) -> None: torch.cuda.set_device(self.device) self.device = torch.cuda.current_device() self.comm_group_info, self.pp_global_rank_info = ( - self._generate_dp_tp_pp_ep_groups() + self._generate_tp_cp_ep_dp_pp_groups() ) self.workload = workload self.comm_type_function = { @@ -59,7 +60,6 @@ def __init__(self, workload=None, args=None, filename=None) -> None: CommType.computation: self._apply_computation, CommType.all_to_all: self._apply_all_to_all, CommType.epoch_end: bench_logger.end_epoch, - } cal_tuple_num = lambda t: math.prod(t[0]) + math.prod(t[1]) @@ -84,104 +84,232 @@ def __init__(self, workload=None, args=None, filename=None) -> None: self.buffer = torch.empty( (max_msg_size,), dtype=torch.bfloat16, device=self.device ) - def _generate_dp_tp_pp_ep_groups(self): + + + + def _generate_tp_cp_ep_dp_pp_groups(self): """Borrow from Megatron-LM""" - all_data_parallel_group_ranks = [] world_size = self.args.world_size rank = torch.distributed.get_rank() self.rank = rank - tensor_model_parallel_size, pipeline_model_parallel_size, data_parallel_size,expert_model_parallel_size = ( + tensor_model_parallel_size, pipeline_model_parallel_size, data_parallel_size, expert_model_parallel_size, context_parallel_size = ( self.args.tensor_model_parallel_size, - self.args.pipeline_model_parallel, - self.args.dp_num, + self.args.pipeline_model_parallel_size, + self.args.data_parallel_size, self.args.expert_model_parallel_size, + self.args.context_parallel_size, + ) + order = 'tp-cp-ep-dp-pp' + expert_tensor_parallel_size = self.args.expert_tensor_parallel_size + encoder_tensor_model_parallel_size = self.args.encoder_tensor_model_parallel_size + encoder_pipeline_model_parallel_size = self.args.encoder_pipeline_model_parallel_size + world_size: int = torch.distributed.get_world_size() + + if encoder_tensor_model_parallel_size > 0: + assert ( + encoder_tensor_model_parallel_size <= tensor_model_parallel_size + ), "We do not support encoders with more TP than the decoder." 
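The lines that follow partition the flat world size into an encoder block and a decoder block, and the decoder's `RankGenerator` is then created with a `rank_offset` equal to the encoder's share. A hedged worked example of that arithmetic (standalone sketch; most configurations in this repository use no encoder stages, so the offset is simply 0):

```python
def split_world(world_size, tp, pp, cp, enc_tp=0, enc_pp=0):
    """Return (data_parallel_size, encoder_world_size) for a flat world (sketch)."""
    encoder_model_size = enc_tp * enc_pp * cp
    decoder_model_size = tp * pp * cp
    total_model_size = encoder_model_size + decoder_model_size
    assert world_size % total_model_size == 0, "world size must hold whole model replicas"
    dp = world_size // total_model_size
    encoder_world_size = encoder_model_size * dp   # used as the decoder's rank_offset
    return dp, encoder_world_size

print(split_world(8, tp=2, pp=1, cp=1))                       # (4, 0): decoder only
print(split_world(12, tp=2, pp=1, cp=1, enc_tp=1, enc_pp=1))  # (4, 4): ranks 0-3 host the encoder
```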
+ + encoder_model_size = ( + encoder_tensor_model_parallel_size + * encoder_pipeline_model_parallel_size + * context_parallel_size + ) + decoder_model_size = ( + tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size + ) + total_model_size = encoder_model_size + decoder_model_size + + if world_size % total_model_size != 0: + raise RuntimeError(f"world_size ({world_size}) is not divisible by {total_model_size}") + + data_parallel_size: int = world_size // total_model_size + + encoder_world_size = encoder_model_size * data_parallel_size + decoder_world_size = decoder_model_size * data_parallel_size + if encoder_world_size > 0: + encoder_rank_generator = utils.RankGenerator( + tp=encoder_tensor_model_parallel_size, + ep=1, + dp=data_parallel_size, + pp=encoder_pipeline_model_parallel_size, + cp=context_parallel_size, + order=order, + rank_offset=0, + ) + else: + encoder_rank_generator = None + + decoder_rank_generator = utils.RankGenerator( + tp=tensor_model_parallel_size, + ep=1, + dp=data_parallel_size, + pp=pipeline_model_parallel_size, + cp=context_parallel_size, + order=order, + rank_offset=encoder_world_size, ) - rank_generator = utils.RankGenerator( - tp=tensor_model_parallel_size, - ep=expert_model_parallel_size, - dp=data_parallel_size, - pp=pipeline_model_parallel_size, - cp=self.args.context_parallel_size, - order='tp-cp-ep-dp-pp', - ) - for ranks in rank_generator.get_ranks('ep', independent_ep=True): + + # Build expert rank generator + if expert_tensor_parallel_size is None: + expert_tensor_parallel_size = tensor_model_parallel_size + expert_tensor_model_pipeline_parallel_size = ( + expert_tensor_parallel_size * expert_model_parallel_size * pipeline_model_parallel_size + ) + expert_data_parallel_size = decoder_world_size // expert_tensor_model_pipeline_parallel_size + if decoder_world_size % expert_tensor_model_pipeline_parallel_size != 0: + raise RuntimeError( + f"decoder world_size ({decoder_world_size}) is not divisible by expert_tensor_model_pipeline_parallel size ({expert_tensor_model_pipeline_parallel_size})" + ) + + # TODO: support expert specific ordering + expert_decoder_rank_generator = utils.RankGenerator( + tp=expert_tensor_parallel_size, + ep=expert_model_parallel_size, + dp=expert_data_parallel_size, + pp=pipeline_model_parallel_size, + cp=1, + order=order, + rank_offset=encoder_world_size, + ) + + assert decoder_rank_generator.get_ranks("pp") == expert_decoder_rank_generator.get_ranks( + "pp" + ), f"Pipeline parallel groups are expected to be the same for Non-Expert and Expert part, \ + but got {decoder_rank_generator.get_ranks('pp')} and {expert_decoder_rank_generator.get_ranks('pp')}" + def generator_wrapper(group_type, is_expert=False, **kwargs): + """The `RankGenerator` class produces a hyper-rectangle for a given set of + tensor, pipeline, data, expert, and context parallelism. If we have an encoder, + in addition to the default decoder, we essentially instantiate two `RankGenerator` + classes to construct the parallelism for each module separately, and we then have + to stitch them together for the right groups. 
For now, this means pp and tp-pp.""" + if is_expert: + d_ranks = expert_decoder_rank_generator.get_ranks(group_type, **kwargs) + else: + d_ranks = decoder_rank_generator.get_ranks(group_type, **kwargs) + + if encoder_rank_generator is None: + for x in d_ranks: + yield x + return + e_ranks = encoder_rank_generator.get_ranks(group_type, **kwargs) + if group_type == 'pp': + # Map 1 encoder tp rank to several decoder tp ranks, because + # these won't be the same size. + for x, y in zip(cycle(e_ranks), d_ranks): + yield x + y + elif group_type == 'tp-pp': + # For this group, we can just return the concatenated + # groups together, because their sizes are the same. + assert len(e_ranks) == len(d_ranks) + for x, y in zip(e_ranks, d_ranks): + yield x + y + else: + for x in e_ranks: + yield x + for x in d_ranks: + yield x + + for ranks in generator_wrapper('dp'): group = torch.distributed.new_group( ranks ) if rank in ranks: - ep_group = group - for ranks in rank_generator.get_ranks('tp'): + dp_group = group + for ranks in generator_wrapper('dp-cp'): + group = torch.distributed.new_group( + ranks + ) + if rank in ranks: + dp_cp_group = group + for ranks in generator_wrapper('cp'): + group = torch.distributed.new_group( + ranks + ) + if rank in ranks: + cp_group = group + for ranks in generator_wrapper('tp-pp'): + group = torch.distributed.new_group( + ranks + ) + if rank in ranks: + tp_pp_group = group + for ranks in generator_wrapper('tp'): group = torch.distributed.new_group( ranks ) if rank in ranks: tp_group = group - for ranks in rank_generator.get_ranks('pp'): + for ranks in generator_wrapper('pp'): group = torch.distributed.new_group( ranks ) if rank in ranks: pp_group = group pp_global_rank = ranks - # Setup embedding group (to exchange gradients between - # first and last stages). 
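One detail worth calling out in the group-construction loops above and below: `torch.distributed.new_group` is collective over the default process group, so every rank has to execute the call for every rank list and then keep only the group it belongs to. A minimal runnable sketch of that pattern (toy single-process setup; the address, port, and backend here are illustrative):

```python
import os
import torch.distributed as dist

def build_group(rank_lists, my_rank):
    """Return the group containing my_rank; all ranks call new_group for every list."""
    my_group = None
    for ranks in rank_lists:
        group = dist.new_group(ranks)      # executed by members and non-members alike
        if my_rank in ranks:
            my_group = group
    return my_group

if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)
    print(build_group([[0]], dist.get_rank()))
    dist.destroy_process_group()
```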
- # if len(ranks) > 1: - # embedding_ranks = [ranks[0], ranks[-1]] - # position_embedding_ranks = [ranks[0]] - # if self.args.pipeline_model_parallel_split_rank is not None: - # if ranks[self.args.pipeline_model_parallel_split_rank] not in embedding_ranks: - # embedding_ranks = [ - # ranks[0], - # ranks[self.args.pipeline_model_parallel_split_rank], - # ranks[-1], - # ] - # if ranks[self.args.pipeline_model_parallel_split_rank] not in position_embedding_ranks: - # position_embedding_ranks = [ranks[0], ranks[self.args.pipeline_model_parallel_split_rank]] - # else: - # embedding_ranks = ranks - # position_embedding_ranks = ranks - - # group = torch.distributed.new_group( - # embedding_ranks - # ) - # if rank in embedding_ranks: - # _EMBEDDING_GROUP = group - # if rank in ranks: - # _EMBEDDING_GLOBAL_RANKS = embedding_ranks - - # group = torch.distributed.new_group( - # position_embedding_ranks, - - # ) - # if rank in position_embedding_ranks: - # _POSITION_EMBEDDING_GROUP = group - # if rank in ranks: - # _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks - for ranks in rank_generator.get_ranks('dp'): + for ranks in generator_wrapper('tp-dp-cp'): group = torch.distributed.new_group( ranks ) if rank in ranks: - dp_group = group - for ranks in rank_generator.get_ranks('tp-ep', independent_ep=True): + tp_dp_cp_group = group + for ranks in generator_wrapper('tp-dp'): + group = torch.distributed.new_group( + ranks + ) + if rank in ranks: + tp_dp_group = group + for ranks in generator_wrapper('tp-cp'): + group = torch.distributed.new_group( + ranks + ) + if rank in ranks: + tp_cp_group = group + for ranks in generator_wrapper('ep', is_expert=True): + group = torch.distributed.new_group( + ranks + ) + if rank in ranks: + ep_group = group + for ranks in generator_wrapper('tp', is_expert=True): + group = torch.distributed.new_group( + ranks + ) + if rank in ranks: + exp_tp_group = group + for ranks in generator_wrapper('tp-ep', is_expert=True): + group = torch.distributed.new_group( + ranks + ) + if rank in ranks: + tp_ep_group = group + for ranks in generator_wrapper('tp-ep-pp', is_expert=True): group = torch.distributed.new_group( ranks ) if rank in ranks: - ep_tp_group = group - for ranks in rank_generator.get_ranks('dp', independent_ep=True): + tp_ep_pp_group = group + for ranks in generator_wrapper('dp', is_expert=True): group = torch.distributed.new_group( ranks ) if rank in ranks: - ep_dp_group = group + exp_dp_group = group return { - CommGroup.tp_group: tp_group, CommGroup.dp_group: dp_group, + CommGroup.dp_cp_group: dp_cp_group, + CommGroup.cp_group: cp_group, + CommGroup.tp_pp_group: tp_pp_group, + CommGroup.tp_group: tp_group, CommGroup.pp_group: pp_group, + CommGroup.tp_dp_cp_group: tp_dp_cp_group, + CommGroup.tp_dp_group: tp_dp_group, + CommGroup.tp_cp_group: tp_cp_group, CommGroup.ep_group: ep_group, - CommGroup.ep_tp_group: ep_tp_group, - CommGroup.ep_dp_group: ep_dp_group, + CommGroup.exp_tp_group: exp_tp_group, + CommGroup.tp_ep_group: tp_ep_group, + CommGroup.tp_ep_pp_group: tp_ep_pp_group, + CommGroup.exp_dp_group: exp_dp_group, }, pp_global_rank def _get_pipeline_parallel_size(self): @@ -217,7 +345,7 @@ def _apply_p2pcommunication(self, item): else: pass if item.additional == "send_next": - if self._get_pipeline_parallel_rank() != self.args.pipeline_model_parallel - 1: + if self._get_pipeline_parallel_rank() != self.args.pipeline_model_parallel_size - 1: send_next_op = torch.distributed.P2POp( torch.distributed.isend, tensor, self._get_pipeline_next_rank() ) @@ -238,7 
+366,7 @@ def _apply_p2pcommunication(self, item): else: pass if item.additional == "recv_next": - if self._get_pipeline_parallel_rank() != self.args.pipeline_model_parallel - 1: + if self._get_pipeline_parallel_rank() != self.args.pipeline_model_parallel_size - 1: tensor_recv_next = torch.empty( item.msg_size // 2, dtype=torch.bfloat16, device=self.device ) diff --git a/workload_generator/AIOB_simAI_workload_generator.py b/workload_generator/AIOB_simAI_workload_generator.py index 73528c2..da207c9 100755 --- a/workload_generator/AIOB_simAI_workload_generator.py +++ b/workload_generator/AIOB_simAI_workload_generator.py @@ -146,7 +146,7 @@ def _get_total_params(self): def workload_generate_aiob(self): # args.world_size --> total gpus number - self.ga_num = self.args.global_batch // (self.args.micro_batch * self.args.dp_num) + self.ga_num = self.args.global_batch // (self.args.micro_batch * self.args.data_parallel_size) if self.ga_num < 1: print( "[WARN]: ga num < 1, please confirm global_batch num and micro_batch num" @@ -245,7 +245,7 @@ def workload_generate_aiob(self): dp_comm_size=0, ) ) - if self.args.expert_model_parallel_size != self.args.dp_num: + if self.args.expert_model_parallel_size != self.args.data_parallel_size: self.workload.append(Work_Item(name="moe_grad_norm1", forward_compute_time=default_compute_time, forward_comm = "NONE", forward_comm_size= 0, backward_compute_time=default_compute_time, backward_comm="NONE", backward_comm_size=0, @@ -543,7 +543,7 @@ def workload_generate_aiob(self): def workload_generate(self): # args.world_size --> total gpus number - self.ga_num = self.args.global_batch // (self.args.micro_batch * self.args.dp_num) + self.ga_num = self.args.global_batch // (self.args.micro_batch * self.args.data_parallel_size) if self.ga_num < 1: print( "[WARN]: ga num < 1, please confirm global_batch num and micro_batch num" @@ -592,7 +592,7 @@ def workload_generate(self): dp_comm_size=0, ) ) - if args.expert_model_parallel_size != args.dp_num: + if args.expert_model_parallel_size != args.data_parallel_size: self.workload.append(Work_Item(name="moe_grad_norm1", forward_compute_time=default_compute_time, forward_comm = "NONE", forward_comm_size= 0, backward_compute_time=default_compute_time, backward_comm="NONE", backward_comm_size=0, @@ -784,14 +784,14 @@ def dump_file(self, filename): pp_comm = ( f"pp_comm: {pp_comm_value}" - if self.args.pipeline_model_parallel != 1 + if self.args.pipeline_model_parallel_size != 1 else "pp_comm: 0" ) with open(filename, "w") as f: f.write(( f"HYBRID_TRANSFORMER_FWD_IN_BCKWD model_parallel_NPU_group: {self.args.tensor_model_parallel_size} " f"ep: {self.args.expert_model_parallel_size} " - f"pp: {self.args.pipeline_model_parallel} " + f"pp: {self.args.pipeline_model_parallel_size} " f"vpp: {self.args.num_layers} " f"ga: {self.ga_num} all_gpus: {self.args.world_size} " f"checkpoints: 0 checkpoint_initiates: 0 " @@ -857,7 +857,7 @@ def dump_file(self, filename): else: f.write( f"HYBRID_TRANSFORMER_FWD_IN_BCKWD model_parallel_NPU_group: {self.args.tensor_model_parallel_size} \ - expert_parallel_npu_group: {self.args.expert_model_parallel_size} pp: {self.args.pipeline_model_parallel} \ + expert_parallel_npu_group: {self.args.expert_model_parallel_size} pp: {self.args.pipeline_model_parallel_size} \ ga: {self.ga_num} all_gpus: {self.args.world_size} checkpoints: 0 checkpoint_initiates: 0" + "\n" ) @@ -876,7 +876,7 @@ def dump_file(self, filename): result_dir = "results/workload/" if not os.path.isdir(result_dir): os.makedirs(result_dir) - 
filename = f"{args.gpu_type}-{args.model_name}-world_size{args.world_size}-tp{args.tensor_model_parallel_size}-pp{args.pipeline_model_parallel}-ep{args.expert_model_parallel_size}-gbs{args.global_batch}-mbs{args.micro_batch}-seq{args.seq_length}-MOE-{args.moe_enable}-GEMM-{args.moe_grouped_gemm}-flash_attn-{args.use_flash_attn}" + filename = f"{args.gpu_type}-{args.model_name}-world_size{args.world_size}-tp{args.tensor_model_parallel_size}-pp{args.pipeline_model_parallel_size}-ep{args.expert_model_parallel_size}-gbs{args.global_batch}-mbs{args.micro_batch}-seq{args.seq_length}-MOE-{args.moe_enable}-GEMM-{args.moe_grouped_gemm}-flash_attn-{args.use_flash_attn}" filepath = os.path.join(result_dir, filename) params = model.parameters() # work = SIMAI_workload(model, args, GPU_Tensor_core.A100, "gpt13B") diff --git a/workload_generator/generate_collective_test.py b/workload_generator/generate_collective_test.py index f91de3c..9b62c23 100644 --- a/workload_generator/generate_collective_test.py +++ b/workload_generator/generate_collective_test.py @@ -30,7 +30,7 @@ def init(self): LogItem( comm_type=CommType.get_comm_type(self.args.test_comm), comm_group=CommGroup.dp_group, - comm_group_size=self.args.dp_num, + comm_group_size=self.args.data_parallel_size, msg_size=self.args.begin_size, stage="warmup", ) @@ -52,7 +52,7 @@ def step(self): LogItem( comm_type=test_comm, comm_group=CommGroup.dp_group, - comm_group_size=self.args.dp_num, + comm_group_size=self.args.data_parallel_size, msg_size=curr_size, stage="test_step", ) diff --git a/workload_generator/generate_deepspeed_stage1_2_workload.py b/workload_generator/generate_deepspeed_stage1_2_workload.py index eca30d9..8a12530 100644 --- a/workload_generator/generate_deepspeed_stage1_2_workload.py +++ b/workload_generator/generate_deepspeed_stage1_2_workload.py @@ -49,7 +49,7 @@ def __init__(self, args, model) -> None: ) self.allgather_bucket_size = args.allgather_bucket_size self.amp_enabled = args.amp_enabled - self.dp_world_size = args.dp_num + self.dp_world_size = args.data_parallel_size self.elem_size = 2 self.all_params = list(self.model.parameters()) diff --git a/workload_generator/generate_deepspeed_stage3_workload.py b/workload_generator/generate_deepspeed_stage3_workload.py index 2d114ff..4d8e7d3 100644 --- a/workload_generator/generate_deepspeed_stage3_workload.py +++ b/workload_generator/generate_deepspeed_stage3_workload.py @@ -35,7 +35,7 @@ def __init__(self, args, model) -> None: super().__init__(args, model) self.name = "deepspeed_stage3" self.amp_enabled = args.amp_enabled - self.dp_world_size = args.dp_num + self.dp_world_size = args.data_parallel_size self.batch_size = args.micro_batch self.seq_len = args.seq_length self.compute_enable = args.computation_enable diff --git a/workload_generator/generate_megatron_workload.py b/workload_generator/generate_megatron_workload.py index 13daab4..b135ee0 100644 --- a/workload_generator/generate_megatron_workload.py +++ b/workload_generator/generate_megatron_workload.py @@ -14,7 +14,7 @@ #!/bin/python """example of running megatron on gpt-7B python -m workload_generator.megatron_workload \ - --frame=Megatron --world_size=16 --tensor_model_parallel_size=8 --pipeline_model_parallel=1 --global_batch=64 --micro_batch=2 \ + --frame=Megatron --world_size=16 --tensor_model_parallel_size=8 --pipeline_model_parallel_size=1 --global_batch=64 --micro_batch=2 \ --num_layers=32 --seq_length=2048 --hidden_size=4096 --epoch_num=2 --use-distributed-optimizer --enable_sequence_parallel """ from utils.utils 
import CommGroup, CommType, get_params, WorkloadWriter @@ -50,7 +50,7 @@ def init(self): LogItem( comm_type=CommType.all_reduce, comm_group=CommGroup.dp_group, - comm_group_size=self.args.dp_num, + comm_group_size=self.args.data_parallel_size, msg_size=1 * 8, stage="init.model_setup", ) @@ -60,17 +60,17 @@ def init(self): LogItem( comm_type=CommType.all_reduce, comm_group=CommGroup.dp_group, - comm_group_size=self.args.dp_num, + comm_group_size=self.args.data_parallel_size, msg_size=1 * 8, stage="init.model_setup", ) ) - if args.pipeline_model_parallel > 1: + if args.pipeline_model_parallel_size > 1: self.workload.append( LogItem( comm_type=CommType.all_reduce, comm_group=CommGroup.pp_group, - comm_group_size=self.args.pipeline_model_parallel, + comm_group_size=self.args.pipeline_model_parallel_size, msg_size=1 * 8, stage="init.model_setup", ) @@ -80,7 +80,7 @@ def init(self): LogItem( comm_type=CommType.all_gather, comm_group=CommGroup.dp_group, - comm_group_size=self.args.dp_num, + comm_group_size=self.args.data_parallel_size, msg_size=4 * 8, stage="init.model_setup", ) @@ -97,7 +97,7 @@ def init(self): ) ) - if args.pp_rank == args.pipeline_model_parallel - 1 and args.pipeline_model_parallel > 1: + if args.pp_rank == args.pipeline_model_parallel_size - 1 and args.pipeline_model_parallel_size > 1: for p in self.model.embedding.parameters(): self.workload.append( LogItem( @@ -113,7 +113,7 @@ def init(self): LogItem( comm_type=CommType.all_gather, comm_group=CommGroup.dp_group, - comm_group_size=self.args.dp_num, + comm_group_size=self.args.data_parallel_size, msg_size=8 * 8, stage="init.model_setup", ) @@ -132,9 +132,9 @@ def with_pipeline_forward_backward(self): import torch rank = torch.distributed.get_rank() world_size = args.world_size - pp_rank = self.get_pp_rank(rank, world_size, args.pipeline_model_parallel) + pp_rank = self.get_pp_rank(rank, world_size, args.pipeline_model_parallel_size) pp_num_warmup_microbatches = min( - args.pipeline_model_parallel - pp_rank - 1, args.num_microbatches + args.pipeline_model_parallel_size - pp_rank - 1, args.num_microbatches ) num_microbatches_remaining = args.num_microbatches - pp_num_warmup_microbatches temp = self.model.forward() @@ -178,7 +178,7 @@ def with_pipeline_forward_backward(self): # for item in forward_comm: self.workload.extend(self.model.forward()) - if pp_rank != args.pipeline_model_parallel - 1: + if pp_rank != args.pipeline_model_parallel_size - 1: # send_next self.workload.append( LogItem( @@ -229,7 +229,7 @@ def with_pipeline_forward_backward(self): ) self.workload.extend(self.model.forward()) - if pp_rank != args.pipeline_model_parallel - 1: + if pp_rank != args.pipeline_model_parallel_size - 1: # recv next self.workload.append( LogItem( @@ -298,7 +298,7 @@ def with_pipeline_forward_backward(self): for _ in range(pp_num_warmup_microbatches): # recv next - if pp_rank != args.pipeline_model_parallel - 1: + if pp_rank != args.pipeline_model_parallel_size - 1: self.workload.append( LogItem( comm_type=CommType.irecv, @@ -367,7 +367,7 @@ def forward(self): LogItem( comm_type=CommType.all_reduce, comm_group=CommGroup.dp_group, - comm_group_size=self.args.dp_num, + comm_group_size=self.args.data_parallel_size, msg_size=1 * 4, stage="forward_step.average_losses_across_data_parallel_group", ) @@ -384,8 +384,8 @@ def step(self): LogItem( comm_type=CommType.reduce_scatter, comm_group=CommGroup.dp_group, - comm_group_size=self.args.dp_num, - msg_size=4 * self._get_total_params() // (args.pipeline_model_parallel), + 
comm_group_size=self.args.data_parallel_size, + msg_size=4 * self._get_total_params() // (args.pipeline_model_parallel_size), stage="step", ) ) @@ -393,8 +393,8 @@ def step(self): LogItem( comm_type=CommType.all_gather, comm_group=CommGroup.dp_group, - comm_group_size=self.args.dp_num, - msg_size=2 * self._get_total_params() // (args.pipeline_model_parallel), + comm_group_size=self.args.data_parallel_size, + msg_size=2 * self._get_total_params() // (args.pipeline_model_parallel_size), stage="step", ) ) @@ -404,8 +404,8 @@ def step(self): LogItem( comm_type=CommType.all_reduce, comm_group=CommGroup.dp_group, - comm_group_size=self.args.dp_num, - msg_size=4 * self._get_total_params() // (args.pipeline_model_parallel), + comm_group_size=self.args.data_parallel_size, + msg_size=4 * self._get_total_params() // (args.pipeline_model_parallel_size), stage="step.finish_grad_sync", ) ) @@ -415,7 +415,7 @@ def step(self): comm_type=CommType.all_reduce, comm_group=CommGroup.tp_group, comm_group_size=self.args.tensor_model_parallel_size, - msg_size=2 * self._get_layernorm_params() // (args.pipeline_model_parallel), + msg_size=2 * self._get_layernorm_params() // (args.pipeline_model_parallel_size), stage="step._allreduce_layernorm_grads", ) ) diff --git a/workload_generator/mocked_model/AiobMegatron.py b/workload_generator/mocked_model/AiobMegatron.py index 60e10d3..063c070 100755 --- a/workload_generator/mocked_model/AiobMegatron.py +++ b/workload_generator/mocked_model/AiobMegatron.py @@ -881,7 +881,7 @@ class Grad_param: def __init__(self, args=None): tp = args.tensor_model_parallel_size param = args.model_param - self.dp = args.dp_num + self.dp = args.data_parallel_size device = torch.cuda.current_device() dtype = torch.float32 diff --git a/workload_generator/mocked_model/MockedMegatron.py b/workload_generator/mocked_model/MockedMegatron.py index 38e0cc1..2d49376 100755 --- a/workload_generator/mocked_model/MockedMegatron.py +++ b/workload_generator/mocked_model/MockedMegatron.py @@ -264,6 +264,7 @@ def __init__( num_attention_heads, hidden_size, tp, + cp, seq_len, batch_size, layer_id, @@ -523,6 +524,7 @@ def __init__( hidden_size, ffn_hidden_size, tp, + cp, seq_len, batch_size, num_attention_heads, @@ -540,6 +542,7 @@ def __init__( num_attention_heads, hidden_size, tp, + cp, seq_len, batch_size, layer_id, @@ -645,6 +648,7 @@ def __init__(self, config): config.hidden_size, config.ffn_hidden_size, config.tensor_model_parallel_size, + config.context_parallel_size, config.seq_length, config.micro_batch, config.num_attention_heads, diff --git a/workload_generator/workload_generator.py b/workload_generator/workload_generator.py index 35b6d5b..a9e22ae 100644 --- a/workload_generator/workload_generator.py +++ b/workload_generator/workload_generator.py @@ -31,7 +31,7 @@ def __call__(self): self.init() self.workload.append(LogItem(comm_type=CommType.epoch_end)) for i in range(args.epoch_num): - if args.pipeline_model_parallel > 1 and args.frame != "collective_test": + if args.pipeline_model_parallel_size > 1 and args.frame != "collective_test": self.with_pipeline_forward_backward() self.step() else: From eac4527ed2ff46d45da67e185cefd0bb5046d5a1 Mon Sep 17 00:00:00 2001 From: Jiayi Yan <66017932+1195343015@users.noreply.github.com> Date: Fri, 14 Feb 2025 14:29:32 +0800 Subject: [PATCH 2/3] fix cp_num --- scripts/megatron_gpt.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/megatron_gpt.sh b/scripts/megatron_gpt.sh index 283f199..4351219 100755 --- a/scripts/megatron_gpt.sh +++ 
b/scripts/megatron_gpt.sh @@ -76,7 +76,7 @@ echo "Processing argument: $1" tensor_model_parallel_size=$2; shift;; --pipeline_model_parallel_size|pp_num) pipeline_model_parallel_size=$2; shift;; - --context_parallel_size|pp_num) + --context_parallel_size|cp_num) context_parallel_size=$2; shift;; --global_batch) global_batch=$2; shift;; From c450682ed828045eea59c3fa0412107b868b996a Mon Sep 17 00:00:00 2001 From: 1195343015 <1195343015@qq.com> Date: Fri, 30 May 2025 21:00:15 +0800 Subject: [PATCH 3/3] fix pp_comm log --- scripts/megatron_gpt.sh | 2 +- .../generate_megatron_workload.py | 20 +++++++++---------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/scripts/megatron_gpt.sh b/scripts/megatron_gpt.sh index 4351219..35cf4f7 100755 --- a/scripts/megatron_gpt.sh +++ b/scripts/megatron_gpt.sh @@ -214,7 +214,7 @@ fi cmd="$script \ --frame=$frame \ --model_name=$model_name \ - --world_size=$(($WORLD_SIZE * $NUM_GPUS)) \ + --world_size=$world_size \ --tensor_model_parallel_size=$tensor_model_parallel_size \ --micro_batch=$micro_batch \ --global_batch=$global_batch \ diff --git a/workload_generator/generate_megatron_workload.py b/workload_generator/generate_megatron_workload.py index b135ee0..9890649 100644 --- a/workload_generator/generate_megatron_workload.py +++ b/workload_generator/generate_megatron_workload.py @@ -147,7 +147,7 @@ def with_pipeline_forward_backward(self): LogItem( comm_type=CommType.irecv, comm_group=CommGroup.pp_group, - comm_group_size=1, + comm_group_size=self.args.pipeline_model_parallel_size, msg_size=2 * (args.hidden_size * args.seq_length * args.micro_batch), stage="forward_step", @@ -184,7 +184,7 @@ def with_pipeline_forward_backward(self): LogItem( comm_type=CommType.isend, comm_group=CommGroup.pp_group, - comm_group_size=1, + comm_group_size=self.args.pipeline_model_parallel_size, msg_size=2 * (args.hidden_size * args.seq_length * args.micro_batch), stage="forward_step", @@ -197,7 +197,7 @@ def with_pipeline_forward_backward(self): LogItem( comm_type=CommType.irecv, comm_group=CommGroup.pp_group, - comm_group_size=1, + comm_group_size=self.args.pipeline_model_parallel_size, msg_size=2 * (args.hidden_size * args.seq_length * args.micro_batch), stage="forward_step", @@ -235,7 +235,7 @@ def with_pipeline_forward_backward(self): LogItem( comm_type=CommType.irecv, comm_group=CommGroup.pp_group, - comm_group_size=1, + comm_group_size=self.args.pipeline_model_parallel_size, msg_size=2 * (args.hidden_size * args.seq_length * args.micro_batch), stage="forward_step", @@ -247,7 +247,7 @@ def with_pipeline_forward_backward(self): LogItem( comm_type=CommType.isend, comm_group=CommGroup.pp_group, - comm_group_size=1, + comm_group_size=self.args.pipeline_model_parallel_size, msg_size=2 * (args.hidden_size * args.seq_length * args.micro_batch), stage="forward_step", @@ -264,7 +264,7 @@ def with_pipeline_forward_backward(self): LogItem( comm_type=CommType.isend, comm_group=CommGroup.pp_group, - comm_group_size=1, + comm_group_size=self.args.pipeline_model_parallel_size, msg_size=2 * (args.hidden_size * args.seq_length * args.micro_batch), stage="backward_step", @@ -277,7 +277,7 @@ def with_pipeline_forward_backward(self): LogItem( comm_type=CommType.isend, comm_group=CommGroup.pp_group, - comm_group_size=1, + comm_group_size=self.args.pipeline_model_parallel_size, msg_size=2 * (args.hidden_size * args.seq_length * args.micro_batch), stage="backward_step", @@ -288,7 +288,7 @@ def with_pipeline_forward_backward(self): LogItem( comm_type=CommType.irecv, 
comm_group=CommGroup.pp_group, - comm_group_size=1, + comm_group_size=self.args.pipeline_model_parallel_size, msg_size=2 * (args.hidden_size * args.seq_length * args.micro_batch), stage="backward_step", @@ -303,7 +303,7 @@ def with_pipeline_forward_backward(self): LogItem( comm_type=CommType.irecv, comm_group=CommGroup.pp_group, - comm_group_size=1, + comm_group_size=self.args.pipeline_model_parallel_size, msg_size=2 * (args.hidden_size * args.seq_length * args.micro_batch), stage="backward_step", @@ -319,7 +319,7 @@ def with_pipeline_forward_backward(self): LogItem( comm_type=CommType.isend, comm_group=CommGroup.pp_group, - comm_group_size=1, + comm_group_size=self.args.pipeline_model_parallel_size, msg_size=2 * (args.hidden_size * args.seq_length * args.micro_batch), stage="backward_step",
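A quick check on the `msg_size` expression shared by all of these pipeline point-to-point items: it is the byte size of one bf16 activation tensor (2 bytes per element) for a single microbatch, and with this patch the logged `comm_group_size` now reflects the full pipeline group instead of 1. Using the GPT-13B figures quoted earlier in the patch (hidden size 5120, sequence length 2048, micro batch 1) as an illustrative input:

```python
# msg_size = 2 * hidden_size * seq_length * micro_batch   (bf16 -> 2 bytes per element)
def pp_activation_bytes(hidden_size, seq_length, micro_batch):
    return 2 * hidden_size * seq_length * micro_batch

size = pp_activation_bytes(hidden_size=5120, seq_length=2048, micro_batch=1)
print(size, size / 2**20)   # 20971520 bytes, i.e. 20.0 MiB per isend/irecv
```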