llm_analysis/analysis.py (29 changes: 16 additions & 13 deletions)
@@ -475,9 +475,9 @@ def get_memory_optimizer_state_and_gradient_per_layer(
 
         memory_optimizer_state_others_per_layer = op_bytes_per_params * (
             (self.get_num_params_per_layer_attn() +
-             +self.get_num_params_per_layer_router() +
-             self.get_num_params_per_layer_layernorm())
-        ) / self.parallelism_config.tp_size / sharded_dp_size
+             +self.get_num_params_per_layer_router()) /
+            self.parallelism_config.tp_size +
+            self.get_num_params_per_layer_layernorm()) / sharded_dp_size
 
         memory_optimizer_state_per_layer = memory_optimizer_state_mlp_per_layer + memory_optimizer_state_others_per_layer
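
In the rewritten expression, only the attention and router parameters are divided by the tensor-parallel degree; the layernorm parameters are no longer divided by tp_size, consistent with them being replicated rather than sharded across TP ranks. A minimal numeric sketch of the old and new formulas, with made-up parameter counts (every value below is hypothetical, chosen only to show the difference):

# All values are hypothetical, for illustration only.
op_bytes_per_params = 8            # e.g. two fp32 Adam moments per parameter
attn_plus_router_params = 12e6     # parameters sharded across TP ranks
layernorm_params = 8e3             # parameters replicated across TP ranks
tp_size, sharded_dp_size = 4, 2

old = op_bytes_per_params * (attn_plus_router_params +
                             layernorm_params) / tp_size / sharded_dp_size
new = op_bytes_per_params * (attn_plus_router_params / tp_size +
                             layernorm_params) / sharded_dp_size

print(old, new)  # new is slightly larger: layernorm state is only sharded by DP
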

@@ -1218,9 +1218,9 @@ def get_latency_fwd_per_tp_comm(self, batch_size: int, seq_len: int,
         elems_per_all_reduce = (2 * batch_size * seq_len *
                                 self.model_config.hidden_dim * (tp_size - 1) /
                                 tp_size)
-        latency_per_all_reduce = (
-            elems_per_all_reduce * dtype_bytes /
-            (self.gpu_config.intra_node_bandwidth_in_GB_per_sec * 10**9))
+        # assuming tp_size <= number of GPUs per node, thus using intra-node bandwidth
+        latency_per_all_reduce = (elems_per_all_reduce * dtype_bytes /
+                                  (self.get_intra_node_bandwidth() * 10**9))
 
         return max(
             latency_per_all_reduce,
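
The elems_per_all_reduce expression above is the standard ring all-reduce volume, 2 * (tp_size - 1) / tp_size of the batch_size * seq_len * hidden_dim activation elements per rank; the change only swaps the raw config field for the get_intra_node_bandwidth() accessor. A standalone sketch of the same estimate, with hypothetical shapes and an assumed 300 GB/s intra-node link:

def tp_allreduce_latency(batch_size, seq_len, hidden_dim, tp_size,
                         dtype_bytes, intra_node_bw_GB_per_sec):
    # A ring all-reduce sends/receives about 2 * (n - 1) / n of the message per rank.
    elems = 2 * batch_size * seq_len * hidden_dim * (tp_size - 1) / tp_size
    return elems * dtype_bytes / (intra_node_bw_GB_per_sec * 10**9)

# Hypothetical example: batch 8, seq 2048, hidden 4096, tp 8, bf16 (2 bytes), 300 GB/s.
print(tp_allreduce_latency(8, 2048, 4096, 8, 2, 300))  # ~0.78 ms
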
@@ -1230,6 +1230,7 @@ def get_latency_fwd_per_tp_comm(self, batch_size: int, seq_len: int,
     def get_latency_fwd_per_layer_shared_dp_comm(self) -> float:
         dp_size = self.parallelism_config.dp_size
         ep_size = self.parallelism_config.ep_size
+        tp_size = self.parallelism_config.tp_size
 
         def time_allgather(S, n, B):
             # https://github.com/NVIDIA/nccl-tests/blob/master/doc/PERFORMANCE.md#allgather
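
The body of time_allgather falls outside this hunk. Based on the NCCL performance notes linked in the comment, a plausible form is the ring model below; this is an assumption for illustration, not necessarily the exact code in the file:

def time_allgather(S, n, B):
    # All-gather of S bytes across n ranks over bus bandwidth B:
    # each rank transfers roughly S * (n - 1) / n bytes.
    return S * (n - 1) / (n * B)
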
@@ -1243,15 +1244,17 @@ def time_allgather(S, n, B):
             self.get_num_params_per_layer_layernorm()
         ) * self.dtype_config.weight_bits / BITS_PER_BYTE
 
-        latency_allgather_params_mlp = time_allgather(
-            params_bytes_mlp, dp_size / ep_size,
-            (self.get_intra_node_bandwidth()
-             if dp_size <= 8 else self.get_inter_node_bandwidth()) * 10**9)
+        # assuming tp and dp are preferred when sharding intra node, pp is only applied across nodes
+        # when (dp_size * tp_size) <= 8, the data parallel processes are within a node
+        bandwidth = self.get_intra_node_bandwidth() if (
+            dp_size * tp_size) <= 8 else self.get_inter_node_bandwidth()
+
+        latency_allgather_params_mlp = time_allgather(params_bytes_mlp,
+                                                      dp_size / ep_size,
+                                                      bandwidth * 10**9)
 
         latency_allgather_params_non_mlp = time_allgather(
-            params_bytes_non_mlp, dp_size,
-            (self.get_intra_node_bandwidth()
-             if dp_size <= 8 else self.get_inter_node_bandwidth()) * 10**9)
+            params_bytes_non_mlp, dp_size, bandwidth * 10**9)
 
         latency_fwd_per_layer_shared_dp_comm = latency_allgather_params_mlp + latency_allgather_params_non_mlp
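
The net effect of this hunk is that the intra-node check for the data-parallel all-gather now accounts for tensor parallelism: with TP and DP sharded within a node first, the DP group only stays on one node when dp_size * tp_size fits there. A small sketch of the selection logic, assuming 8 GPUs per node (as the <= 8 threshold implies) and hypothetical bandwidth numbers:

def shared_dp_comm_bandwidth_GBps(dp_size, tp_size,
                                  intra_node_GBps=300.0, inter_node_GBps=25.0):
    # DP ranks stay within a node only if dp_size * tp_size ranks fit on one 8-GPU node.
    return intra_node_GBps if dp_size * tp_size <= 8 else inter_node_GBps

print(shared_dp_comm_bandwidth_GBps(2, 4))  # 8 ranks  -> 300.0 (intra-node)
print(shared_dp_comm_bandwidth_GBps(4, 4))  # 16 ranks -> 25.0 (inter-node)
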
