From 168d9a3718076a7d1598988c013c4a435228aab5 Mon Sep 17 00:00:00 2001
From: Cheng Li
Date: Tue, 12 Nov 2024 23:38:59 -0800
Subject: [PATCH 1/2] fix allreduce latency and mem usage when tp is in use

---
 llm_analysis/analysis.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/llm_analysis/analysis.py b/llm_analysis/analysis.py
index 20100c5..32b1cb6 100644
--- a/llm_analysis/analysis.py
+++ b/llm_analysis/analysis.py
@@ -475,9 +475,9 @@ def get_memory_optimizer_state_and_gradient_per_layer(
 
         memory_optimizer_state_others_per_layer = op_bytes_per_params * (
             (self.get_num_params_per_layer_attn() +
-             +self.get_num_params_per_layer_router() +
-             self.get_num_params_per_layer_layernorm())
-        ) / self.parallelism_config.tp_size / sharded_dp_size
+             +self.get_num_params_per_layer_router()) /
+            self.parallelism_config.tp_size +
+            self.get_num_params_per_layer_layernorm()) / sharded_dp_size
 
         memory_optimizer_state_per_layer = memory_optimizer_state_mlp_per_layer + memory_optimizer_state_others_per_layer
 
@@ -1218,9 +1218,9 @@ def get_latency_fwd_per_tp_comm(self, batch_size: int, seq_len: int,
 
         elems_per_all_reduce = (2 * batch_size * seq_len *
                                 self.model_config.hidden_dim * (tp_size - 1) / tp_size)
-        latency_per_all_reduce = (
-            elems_per_all_reduce * dtype_bytes /
-            (self.gpu_config.intra_node_bandwidth_in_GB_per_sec * 10**9))
+        # assuming tp_size <= number of GPUs per node, thus using intra-node bandwidth
+        latency_per_all_reduce = (elems_per_all_reduce * dtype_bytes /
+                                  (self.get_intra_node_bandwidth() * 10**9))
 
         return max(
             latency_per_all_reduce,

From c3bb7eb990f21df704a781da1a95c9180b6eac55 Mon Sep 17 00:00:00 2001
From: Cheng Li
Date: Wed, 13 Nov 2024 00:19:29 -0800
Subject: [PATCH 2/2] update latency calculation in allgather

---
 llm_analysis/analysis.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/llm_analysis/analysis.py b/llm_analysis/analysis.py
index 32b1cb6..4740aea 100644
--- a/llm_analysis/analysis.py
+++ b/llm_analysis/analysis.py
@@ -1230,6 +1230,7 @@ def get_latency_fwd_per_tp_comm(self, batch_size: int, seq_len: int,
     def get_latency_fwd_per_layer_shared_dp_comm(self) -> float:
         dp_size = self.parallelism_config.dp_size
         ep_size = self.parallelism_config.ep_size
+        tp_size = self.parallelism_config.tp_size
 
         def time_allgather(S, n, B):
             # https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md#allgather
@@ -1243,15 +1244,17 @@ def time_allgather(S, n, B):
             self.get_num_params_per_layer_layernorm()
         ) * self.dtype_config.weight_bits / BITS_PER_BYTE
 
-        latency_allgather_params_mlp = time_allgather(
-            params_bytes_mlp, dp_size / ep_size,
-            (self.get_intra_node_bandwidth()
-             if dp_size <= 8 else self.get_inter_node_bandwidth()) * 10**9)
+        # assuming tp and dp are preferred when sharding intra-node; pp is only applied across nodes
+        # when (dp_size * tp_size) <= 8, the data parallel processes are within a node
+        bandwidth = self.get_intra_node_bandwidth() if (
+            dp_size * tp_size) <= 8 else self.get_inter_node_bandwidth()
+
+        latency_allgather_params_mlp = time_allgather(params_bytes_mlp,
+                                                      dp_size / ep_size,
+                                                      bandwidth * 10**9)
 
         latency_allgather_params_non_mlp = time_allgather(
-            params_bytes_non_mlp, dp_size,
-            (self.get_intra_node_bandwidth()
-             if dp_size <= 8 else self.get_inter_node_bandwidth()) * 10**9)
+            params_bytes_non_mlp, dp_size, bandwidth * 10**9)
 
         latency_fwd_per_layer_shared_dp_comm = latency_allgather_params_mlp + latency_allgather_params_non_mlp
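
Note: the sketch below restates the communication-latency model these two patches rely on as standalone Python, as an aid to reading the diff. It is an approximation under stated assumptions: GPUS_PER_NODE, the bandwidth constants, and the helper names are hypothetical and not taken from llm_analysis, and time_allgather uses the ring all-gather formula from the NCCL performance document referenced in the patch.

# A minimal, standalone sketch of the latency model adjusted by these patches.
# The node size, bandwidth constants, and function names below are
# illustrative assumptions, not values or APIs from llm_analysis itself.

GPUS_PER_NODE = 8                   # assumed number of GPUs per node
INTRA_NODE_BW_GB_PER_SEC = 300.0    # assumed NVLink-class intra-node bandwidth
INTER_NODE_BW_GB_PER_SEC = 50.0     # assumed network inter-node bandwidth


def time_allgather(s_bytes: float, n: float, bw_bytes_per_sec: float) -> float:
    # Ring all-gather: each of the n ranks moves (n - 1) / n of the data.
    # https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md#allgather
    return s_bytes * (n - 1) / (n * bw_bytes_per_sec)


def latency_tp_allreduce(batch_size: int, seq_len: int, hidden_dim: int,
                         tp_size: int, dtype_bytes: int = 2) -> float:
    # Per-all-reduce latency for tensor-parallel activation communication,
    # mirroring the patched hunk in get_latency_fwd_per_tp_comm: a ring
    # all-reduce moves 2 * (tp_size - 1) / tp_size of the activation elements,
    # and tp_size is assumed to fit within one node, so intra-node bandwidth
    # applies.
    elems = 2 * batch_size * seq_len * hidden_dim * (tp_size - 1) / tp_size
    return elems * dtype_bytes / (INTRA_NODE_BW_GB_PER_SEC * 10**9)


def dp_allgather_bandwidth(dp_size: int, tp_size: int) -> float:
    # Bandwidth choice for sharded-data-parallel all-gathers, mirroring the
    # second patch: tp and dp ranks are assumed to be packed within a node
    # first (pp only across nodes), so the dp group stays intra-node as long
    # as dp_size * tp_size does not exceed the node size.
    if dp_size * tp_size <= GPUS_PER_NODE:
        return INTRA_NODE_BW_GB_PER_SEC * 10**9
    return INTER_NODE_BW_GB_PER_SEC * 10**9


if __name__ == "__main__":
    # Illustrative numbers only: batch 4, seq 4096, hidden 8192, tp 8, dp 16.
    print(latency_tp_allreduce(4, 4096, 8192, tp_size=8))
    params_bytes_mlp = 5e8  # hypothetical per-layer MLP parameter bytes
    print(time_allgather(params_bytes_mlp, 16, dp_allgather_bandwidth(16, 8)))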