From 0258c8a8bc9cec8c86f4ad52b39d6840988da2c8 Mon Sep 17 00:00:00 2001 From: Anshu Raina Date: Mon, 2 Mar 2026 13:39:47 -0800 Subject: [PATCH] docs: expand projection.md with memory projection and performance details --- docs/projection.md | 661 +++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 609 insertions(+), 52 deletions(-) diff --git a/docs/projection.md b/docs/projection.md index 5c5212026..9af860731 100644 --- a/docs/projection.md +++ b/docs/projection.md @@ -1,39 +1,64 @@ -# Performance Projection +# Projection -Primus includes a performance projection tool that benchmarks transformer layers on a single node and projects training iteration times to multi-node configurations. +Primus includes projection tools that estimate **memory requirements** and **training performance** for large-scale distributed training without requiring the full target cluster. Two projection modes are available: -- **User-facing entry**: `primus-cli … -- projection performance [options]` -- **Implementation entrypoint**: `primus/cli/subcommands/projection.py` -- **Core logic**: `primus/core/projection/performance_projection/projection.py` - -## Overview - -The performance projection tool: - -1. **Benchmarks** transformer layers on a single node to measure forward/backward pass times -2. **Simulates** pipeline parallelism scheduling (including zero-bubble optimization) -3. 
**Projects** performance to multi-node configurations by modeling: - - Data Parallelism (DP) scaling - - Gradient AllReduce communication overhead - - Expert Parallelism (EP) All-to-All communication overhead - - Inter-node vs intra-node communication differences +| Mode | Command | What it does | +|------|---------|--------------| +| **Memory** | `projection memory` | Estimates per-GPU memory (parameters, optimizer, activations) using analytical formulas | +| **Performance** | `projection performance` | Benchmarks layers on 1 node, then projects training time to multi-node clusters | -This allows you to estimate training performance on larger clusters without actually running on them. +- **User-facing entry**: `primus-cli … -- projection {memory,performance} [options]` +- **Implementation entrypoint**: `primus/cli/subcommands/projection.py` +- **Core logic**: + - Memory: `primus/core/projection/memory_projection/projection.py` + - Performance: `primus/core/projection/performance_projection/projection.py` + +--- + +## Table of Contents + +1. [Quick Start](#quick-start) +2. [Command Syntax](#command-syntax) +3. [Memory Projection](#memory-projection) + - [Overview](#memory-overview) + - [Architecture](#memory-architecture) + - [Parameter Estimation](#parameter-estimation) + - [Param + Optimizer Memory](#param--optimizer-memory) + - [Activation Memory](#activation-memory) + - [Pipeline Schedule Memory Scaling](#pipeline-schedule-memory-scaling) + - [Recomputation Support](#recomputation-support) + - [Memory Formulas Reference](#memory-formulas-reference) +4. [Performance Projection](#performance-projection) + - [Overview](#performance-overview) + - [How It Works](#how-it-works) + - [Scaling Mechanisms](#scaling-mechanisms) + - [Communication Modeling](#communication-modeling) + - [Pipeline Schedule Simulator](#pipeline-schedule-simulator) + - [Overall Performance Prediction Flow](#overall-performance-prediction-flow) +5. [Example Output](#example-output) +6. 
[Assumptions and Limitations](#assumptions-and-limitations) +7. [Tips](#tips) + +--- ## Quick Start -Run a basic performance projection for the minimum required nodes: +### Memory Projection + +Estimate per-GPU memory for a model configuration (no GPU needed for estimation, but the CLI currently requires torch distributed init): ```bash export NNODES=1 export HSA_NO_SCRATCH_RECLAIM=1 bash runner/primus-cli direct --script primus/cli/main.py -- \ - projection performance \ + projection memory \ --config examples/megatron/configs/MI300X/deepseek_v2_lite-BF16-pretrain.yaml ``` -Project performance to a specific number of nodes: +### Performance Projection + +Benchmark on 1 node and project to a target cluster: ```bash export NNODES=1 @@ -45,26 +70,32 @@ bash runner/primus-cli direct --script primus/cli/main.py -- \ --target-nodes 4 ``` +--- + ## Command Syntax ```bash -primus-cli [global-options] [mode-args] -- projection performance [options] +primus-cli [global-options] [mode-args] -- projection {memory,performance} [options] ``` -### Options +### Common Options | Option | Type | Description | |--------|------|-------------| | `--config` | string | Path to the Primus YAML configuration file (required) | + +### Performance-Only Options + +| Option | Type | Description | +|--------|------|-------------| | `--target-nodes` | int | Target number of nodes for projection. 
Defaults to minimum required by parallelism config | | `--hardware-config` | string | Path to YAML file with custom hardware parameters for communication modeling | ### Parallelism Overrides -You can override parallelism settings from the config file using environment variables or CLI overrides: +You can override parallelism settings from the config file: ```bash -# Using environment variables export PRIMUS_TP=1 export PRIMUS_PP=3 export PRIMUS_EP=8 @@ -75,30 +106,390 @@ bash runner/primus-cli direct --script primus/cli/main.py -- \ --target-nodes 6 ``` -## How It Works +--- + +## Memory Projection + + +### Overview + +The memory projection estimates **per-GPU memory** usage by analytically computing: + +1. **Parameter memory** — model weights stored on this GPU +2. **Optimizer state memory** — optimizer first/second moments, sharded across DP ranks +3. **Activation memory** — intermediate tensors stored for the backward pass + +It uses a hierarchical profiler system that mirrors the model's module structure, computing each component's contribution bottom-up. 
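At the top level, these three pieces combine into the projected per-GPU total. A minimal sketch, assuming the BF16 byte counts used elsewhere on this page (the function name and example numbers are illustrative, not the Primus API):

```python
def projected_memory_gb(params_on_gpu: float, activation_bytes: float, dp: int) -> float:
    # 2 B BF16 weights + 2 B BF16 grads + 10 B of optimizer state sharded
    # across DP, per the "Param + Optimizer Memory" section of this page.
    bytes_per_param = 2 + 2 + 10 / dp
    return (params_on_gpu * bytes_per_param + activation_bytes) / 1024**3

# Hypothetical inputs: 2B params on this GPU, 20 GiB of activations, DP=8
print(round(projected_memory_gb(2e9, 20 * 1024**3, 8), 2))  # 29.78
```

The same bottom-up accumulation happens inside the profiler hierarchy described next, component by component rather than in one formula.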
+ + +### Architecture + +``` +LanguageModelProfiler +├── EmbeddingProfiler — vocab embeddings (stage 0 only) +├── DenseTransformerLayerProfiler — for non-MoE layers +│ ├── LayerNormProfiler (×3) — pre-attn, pre-MLP, post-MLP +│ ├── AttentionProfiler — QKV projections + attention +│ ├── ResidualAddProfiler (×2) — skip connections +│ └── DenseMLPProfiler — standard SwiGLU/FFN +├── MoETransformerLayerProfiler — for MoE layers +│ ├── LayerNormProfiler (×3) +│ ├── AttentionProfiler +│ ├── ResidualAddProfiler (×2) +│ ├── RouterProfiler — expert routing logits +│ └── MoEMLPProfiler — routed experts + shared expert +├── LayerNormProfiler — final layer norm (last stage only) +├── OutputLayerProfiler — language model head (last stage only) +└── LossProfiler — cross-entropy loss (last stage only) +``` + +Each profiler implements two key methods: +- `estimated_num_params(rank)` — parameter count (total if `rank=None`, per-GPU if rank given) +- `estimated_activation_memory(batch_size, seq_len)` — activation bytes for one microbatch + + +### Parameter Estimation + +Parameters are computed per component and summed across all layers assigned to this GPU's pipeline stage. + +#### Layer Assignment + +Layers are distributed across `PP × VPP` virtual stages. Each physical PP rank hosts `VPP` virtual stages in an interleaved pattern: + +``` +PP rank 0 → virtual stages 0, PP, 2×PP, ... +PP rank 1 → virtual stages 1, PP+1, 2×PP+1, ... 
+``` + +#### Per-Component Parameter Formulas + +| Component | Formula | Notes | +|-----------|---------|-------| +| **Embedding** | `V × H` | `V` = padded vocab size, `H` = hidden size | +| **LayerNorm** | `2 × H` | gamma + beta per LayerNorm | +| **Attention (standard)** | `2 × H² × (1 + G/A) × P` | `A` = num heads, `G` = num KV groups, `P` = proj ratio | +| **Attention (MLA)** | `q_term + kv_term + pos + out` | DeepSeek-style multi-latent attention | +| **Dense MLP (SwiGLU)** | `3 × H × FFN` | gate, up, down projections | +| **Dense MLP (standard)** | `2 × H × FFN` | up, down projections | +| **Router** | `H × NE` | `NE` = number of experts | +| **MoE MLP** | `NE/EP × n_proj × H × FFN_e + shared` | Expert params sharded by EP | +| **Output Layer** | `V × H` | May share weights with embedding | + +Where: +- `H` = `hidden_size` +- `V` = `padded_vocab_size` +- `FFN` = `ffn_hidden_size` (dense MLP intermediate dimension) +- `FFN_e` = `moe_ffn_hidden_size` (per-expert intermediate dimension) +- `NE` = `num_experts` +- `EP` = `expert_model_parallel_size` +- `n_proj` = 3 for SwiGLU, 2 for standard FFN + +### Param + Optimizer Memory + +The total bytes per parameter depends on the training precision and optimizer sharding: + +``` +bytes_per_param = weight_bytes + gradient_bytes + optimizer_bytes + +Where: + weight_bytes = 2 (BF16 weights) + gradient_bytes = 2 (BF16 gradients) + optimizer_bytes = 10/DP (FP32 master weights + Adam m + Adam v, sharded across DP) + = (2 + 4 + 4) / DP +``` + +**DP calculation:** + +``` +DP = world_size / (EP × PP) +``` + +Note: TP and CP are not divided out because all TP/CP ranks within a DP group share the same optimizer partition. + +**Total param + optimizer memory per GPU:** + +``` +param_optimizer_memory = params_on_this_gpu × bytes_per_param +``` + +### Activation Memory -### 1. Configuration Reduction +Activation memory is the memory needed to store intermediate tensors for the backward pass. 
Each component estimates what it stores for backward. + +#### Base Tensor (sbh) + +The fundamental building block is the hidden state tensor: + +``` +sbh = MBS × (S / TP / CP) × H × 2 bytes (BF16) +``` + +Where `MBS` = micro batch size, `S` = sequence length. + +#### Per-Component Activation Formulas + +##### LayerNorm +Stores its input for backward: +``` +act = sbh = MBS × S/(TP×CP) × H × 2 +``` + +##### Residual Add +Stores the residual for backward: +``` +act = sbh +``` + +##### Router +Stores hidden states for routing weight gradients: +``` +act = sbh +``` + +##### Attention (standard, Flash Attention) + +Stores Q, K, V, output, and logsumexp for Flash Attention backward: + +```python +tokens_per_rank = MBS × S / (TP × CP) + +# activation width = Q + K + V + output + softmax stats +Q_width = kv_channels × num_heads # e.g. 128 × 64 = 8192 +KV_width = kv_channels × num_kv_groups # e.g. 128 × 1 = 128 (MQA) +output_width = hidden_size # 8192 +softmax_width = Q_width (with Flash Attention) # 8192 + +total_width = Q_width + 2×KV_width + output_width + softmax_width +act = tokens_per_rank × total_width × 2 (BF16) +``` + +For MQA with 64 heads and 1 KV group: `Q(256MB) + K(4MB) + V(4MB) + O(256MB) + LSE(4MB) ≈ 0.51 GB` + +##### Dense MLP (SwiGLU) + +For the SwiGLU computation `output = down_proj(silu(gate_proj(x)) ⊙ up_proj(x))`, stores: + +```python +tokens = MBS × S / (TP × CP) + +# SwiGLU stores gate, up, and hidden (silu×up) for backward +intermediate = 2 × tokens × FFN × 2 # gate_proj + up_proj outputs (BF16) +activation = tokens × FFN × 2 # silu(gate) ⊙ up (input to down_proj) +output = tokens × H × 2 # down_proj output + +act = intermediate + activation + output + = tokens × (3×FFN + H) × 2 +``` + +##### MoE MLP + +For MoE, each token is routed to `topk` experts, duplicating the activation: + +```python +tokens = MBS × S / (TP × CP) +topk_tokens = tokens × topk # total token-expert pairs + +# Routed experts: same SwiGLU structure per token-expert pair 
+routed_act = topk_tokens × (3×FFN_e + H) × 2 + +# Shared expert (if configured): processes ALL tokens +shared_act = tokens × (3×FFN_e + H) × 2 # same SwiGLU, one copy + +act = routed_act + shared_act + = tokens × (topk + N_shared) × (H + 3×FFN_e) × 2 +``` + +Where: +- `topk` = `moe_router_topk` (experts activated per token) +- `FFN_e` = `moe_ffn_hidden_size` (per-expert FFN intermediate dimension) +- `N_shared` = 1 if `moe_shared_expert_intermediate_size` is set, else 0 + +**Example (MoE 4.5T, MBS=4, S=16384, CP=4, H=8192, FFN_e=2048, topk=36):** +``` +tokens = 4 × 16384/4 = 16,384 +MoE MLP = 16,384 × (36+1) × (8192 + 3×2048) × 2 = 16.19 GB +``` + +##### Full Transformer Layer (without recompute) + +For a MoE layer, the total is the sum of all components: + +``` +layer_act = 3×LayerNorm + Attention + 2×ResidualAdd + Router + MoE_MLP + = 3×sbh + attn_act + 2×sbh + sbh + moe_mlp_act + = 6×sbh + attn_act + moe_mlp_act +``` + +For a dense layer: same but with Dense MLP instead of Router + MoE MLP. + +##### Full Layer Activation Summary + +| Component | Formula | Typical Size (MoE 4.5T) | +|-----------|---------|------------------------| +| LayerNorm (×3) | `3 × sbh` | 0.75 GB | +| Residual Add (×2) | `2 × sbh` | 0.50 GB | +| Router | `sbh` | 0.25 GB | +| Attention (Flash, MQA) | `tokens × (Q+2KV+O+softmax) × 2` | 0.51 GB | +| MoE MLP (SwiGLU) | `tokens × (topk+1) × (H+3×FFN_e) × 2` | 16.19 GB | +| **Full MoE layer** | **sum** | **18.20 GB** | +| With full recompute | `sbh` (checkpoint only) | 0.25 GB | + + +### Pipeline Schedule Memory Scaling + +With pipeline parallelism, multiple microbatches are in-flight simultaneously, each requiring stored activations. + +#### 1F1B Schedule + +In the 1F1B (one-forward-one-backward) schedule, the first pipeline stage (rank 0) accumulates `PP` microbatches during the warmup phase before starting any backward passes. This means peak activation memory requires storing activations for `PP` microbatches. 
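The `PP` figure is the rank-0 worst case: under 1F1B, rank `r` completes `PP − r` forwards before its first backward, so earlier stages hold more live microbatches. A tiny sketch of that scheduling fact (illustrative, not the Primus simulator code):

```python
def peak_in_flight(pp: int, rank: int) -> int:
    # Rank r runs (pp - rank) forwards before its first backward under 1F1B,
    # so that many microbatches' activations are live at once on that rank.
    return pp - rank

print([peak_in_flight(4, r) for r in range(4)])  # [4, 3, 2, 1]
```

The memory projection uses the rank-0 value, which is why peak activation memory scales by the full `PP` factor.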
+ +``` +base_activation = sum of per-layer activations across all layers on this rank +peak_activation = base_activation × PP +``` + +#### VPP (Virtual Pipeline Parallelism) Overhead + +With interleaved scheduling (VPP > 1), there is a small memory overhead because more microbatches can be partially in-flight: + +``` +interleaved_penalty = 1 + (PP - 1) / (PP × VPP) +``` + +For VPP=1: penalty = 1 + (PP-1)/PP (significant overhead) +For VPP=20: penalty = 1 + (PP-1)/(PP×20) ≈ 1.04 (nearly negligible) + +#### Gradient Accumulation Saving + +When gradient accumulation (GA) steps are fewer than PP stages, the pipeline isn't fully filled: + +```python +GA = GBS / (MBS × DP) +ga_saving = 1.0 if GA >= PP else GA / PP +``` + +#### Final Activation Memory Formula + +``` +total_activation = base_activation × PP × interleaved_penalty × ga_saving +``` + + +### Recomputation Support + +Activation recomputation trades compute for memory by discarding intermediate activations during forward and recomputing them during backward. + +#### Full Recompute (`recompute_granularity="full"`) + +When a layer is fully recomputed, only the **input tensor** is stored as a checkpoint: + +``` +recomputed_layer_act = sbh = MBS × S/(TP×CP) × H × 2 bytes +``` + +This is dramatically smaller than the full activation. For example, an MoE layer drops from ~18 GB to 0.25 GB. 
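Plugging the MoE example shape from the Activation Memory section into the `sbh` formula confirms the checkpoint size (illustrative helper; the shape values come from the example above):

```python
def full_recompute_checkpoint_bytes(mbs: int, seq: int, tp: int, cp: int, hidden: int) -> int:
    # Under full recompute only the BF16 layer input (the sbh tensor) is stored.
    return mbs * (seq // (tp * cp)) * hidden * 2

# MBS=4, S=16384, TP=1, CP=4, H=8192 -> 0.25 GB checkpoint per layer
print(full_recompute_checkpoint_bytes(4, 16384, 1, 4, 8192) / 1024**3)  # 0.25
```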
+ +#### Partial Recompute + +The `recompute_num_layers` setting controls how many layers per VPP stage are recomputed: + +```python +for each layer on this rank: + local_idx = layer_index % layers_per_vpp_stage + if recompute_granularity == "full" and local_idx < recompute_num_layers: + act += input_act_per_layer # just sbh (checkpoint) + else: + act += full_layer_act # all intermediates +``` + +#### With Recompute: Total Memory + +``` +total_with_recompute = (L/PP × sbh) × PP × interleaved_penalty × ga_saving + + recompute_working_memory (1 layer's full activation, temporary) + + embedding_act (stage 0 only) +``` + +The recompute working memory is transient — only one layer's full intermediates exist at a time during backward. + +### Memory Formulas Reference + +Summary of all memory components for one GPU: + +``` +┌─────────────────────────────────────────────────────────────────────┐ +│ Total GPU Memory │ +├─────────────────────────────────────────────────────────────────────┤ +│ │ +│ 1. Parameters (BF16) │ +│ = params_on_rank × 2 bytes │ +│ │ +│ 2. Gradients (BF16) │ +│ = params_on_rank × 2 bytes │ +│ │ +│ 3. Optimizer States (FP32, sharded across DP) │ +│ = params_on_rank × 10 / DP bytes │ +│ (master weights: 2B + Adam m: 4B + Adam v: 4B) │ +│ │ +│ 4. Activations │ +│ = Σ(per-layer act) × PP × VPP_penalty × GA_saving │ +│ + embedding/output activations (stage-dependent) │ +│ │ +│ 5. Transient buffers (not in projection) │ +│ - A2A dispatch/combine buffers │ +│ - Communication scratch space │ +│ - PyTorch allocator fragmentation │ +│ │ +│ Total = (1) + (2) + (3) + (4) │ +│ Reported as: Param+Optimizer + Activation = Projected Total │ +│ │ +└─────────────────────────────────────────────────────────────────────┘ +``` + +--- + +## Performance Projection + + +### Overview + +The performance projection tool: + +1. **Benchmarks** transformer layers on a single node to measure forward/backward pass times +2. 
**Simulates** pipeline parallelism scheduling (including zero-bubble optimization) +3. **Projects** performance to multi-node configurations by modeling: + - Data Parallelism (DP) scaling + - Gradient AllReduce communication overhead + - Expert Parallelism (EP) All-to-All communication overhead + - Inter-node vs intra-node communication differences + +This allows you to estimate training performance on larger clusters without actually running on them. + +### How It Works + +#### 1. Configuration Reduction If the parallelism configuration requires multiple nodes (e.g., PP=3 needs 3 nodes), the tool automatically reduces the config for single-node benchmarking: - **Pipeline Parallelism (PP)**: Reduced to fit on 1 node, PP overhead estimated analytically - **Expert Parallelism (EP)**: Reduced if needed, All-to-All overhead added back -### 2. Layer Benchmarking +#### 2. Layer Benchmarking The tool benchmarks each transformer layer type: - Dense attention layers - MoE (Mixture of Experts) layers - Measures forward and backward pass times separately +- Also benchmarks embedding and output layers -### 3. Pipeline Simulation +#### 3. Pipeline Simulation For PP > 1, the tool simulates the pipeline schedule to account for: - Pipeline bubble overhead - Microbatch interleaving - Zero-bubble scheduling (if enabled) -### 4. Data Parallel Scaling +#### 4. Data Parallel Scaling The projection models how performance scales with additional nodes: @@ -106,11 +497,9 @@ The projection models how performance scales with additional nodes: Projected Time = (Base Time / DP_scaling_factor) + Communication Overheads ``` -## Scaling Mechanisms +### Scaling Mechanisms -The tool models the following parallelism dimensions and their communication patterns: - -### Tensor Parallelism (TP) +#### Tensor Parallelism (TP) **What it does**: Splits individual layer weights across GPUs within a node. 
@@ -118,7 +507,7 @@ The tool models the following parallelism dimensions and their communication pat **Communication**: AllReduce within TP group (typically intra-node, fast). -### Pipeline Parallelism (PP) +#### Pipeline Parallelism (PP) **What it does**: Distributes layers across pipeline stages. Each stage processes microbatches in sequence. @@ -128,7 +517,7 @@ The tool models the following parallelism dimensions and their communication pat - Simulates the actual 1F1B or zero-bubble schedule with proper send/receive synchronization - Accounts for pipeline bubble overhead and microbatch interleaving -### Expert Parallelism (EP) +#### Expert Parallelism (EP) **What it does**: Distributes MoE experts across GPUs. Each GPU holds a subset of experts. @@ -143,7 +532,7 @@ The tool models the following parallelism dimensions and their communication pat All-to-All Message Size = tokens × hidden_size × top_k × 2 (BF16) ``` -### Data Parallelism (DP) +#### Data Parallelism (DP) **What it does**: Replicates the model across DP groups. Each group processes different data batches. @@ -159,13 +548,13 @@ Gradient AllReduce Size = num_params × 4 (FP32 gradients) **Optimization**: If `overlap_grad_reduce=True` (default), gradient AllReduce is overlapped with backward computation and not on the critical path. -### Context Parallelism (CP) +#### Context Parallelism (CP) **What it does**: Splits sequence length across GPUs for long-context training. **How it's modeled**: CP affects the GPU topology for communication routing. Currently included in minimum GPU requirements calculation. 
-## Communication Modeling +### Communication Modeling The tool uses analytical models to estimate collective communication times: @@ -176,33 +565,177 @@ The tool uses analytical models to estimate collective communication times: | P2P Send/Recv | PP (activations) | Point-to-point latency + bandwidth | Communication times differ significantly based on: -- **Intra-node**: Fast (e.g., NVLink, xGMI) -- **Inter-node**: Slower (e.g., InfiniBand, RoCE) +- **Intra-node**: Fast (e.g., NVLink, UALink, xGMI) +- **Inter-node**: Slower (e.g., InfiniBand, RoCE) + +Custom hardware parameters can be provided via `--hardware-config `. -### Key Concepts +#### Key Concepts -#### Minimum Nodes Required +##### Minimum Nodes Required -The minimum nodes required is determined by: ``` Min Nodes = ceil(TP × PP × EP × CP / GPUs_per_node) ``` -#### Scaling Behavior +##### Scaling Behavior - **DP scaling**: Linear speedup. Doubling DP halves iteration time (minus communication overhead). - **PP scaling**: Happens in multiples of pipeline replicas. With PP=3, you need 3, 6, 9... nodes to increase scaling. -- **EP scaling**: Divides the experts on EP nodes. +- **EP scaling**: Divides the experts across EP nodes. + +### Pipeline Schedule Simulator + +The pipeline simulator (`simulator.py`) simulates the execution of pipeline parallelism schedules to calculate step time and bubble ratio. + +#### Schedule Algorithms + +| Algorithm | Description | Use Case | +|-----------|-------------|----------| +| **1F1B** | Standard one-forward-one-backward | Default pipeline schedule | +| **Interleaved 1F1B** | Multiple chunks per rank (VPP > 1) | Reduced bubble ratio | +| **Zero-Bubble** | Splits backward into B + W | Minimal bubble overhead | + +#### Zero-Bubble Scheduling + +Zero-bubble minimizes pipeline bubbles by separating the backward pass: +- **B (Input Gradient):** Compute gradients w.r.t. input activations +- **W (Weight Gradient):** Compute gradients w.r.t. 
weights + +This allows more flexible scheduling because W doesn't depend on receiving gradients from the next stage. By default, backward time is split 50/50 between B and W. + +Two implementations are available: +1. **Simple Zero-Bubble Simulator** — basic F-B-W pattern with warmup/steady/cooldown phases +2. **Megatron ILP-Based Scheduler** — graph-based schedule optimization with memory-aware scheduling using Megatron's actual zero-bubble scheduler + +#### P2P Communication in Pipeline Simulation + +The pipeline simulator uses a **fixed small constant** (~0.1 ms) for P2P communication, NOT the analytical `sendrecv` model. This is because: +1. P2P communication is typically **overlapped with compute** in optimized schedules +2. The simulator focuses on **schedule ordering and bubble calculation** +3. P2P time is **small relative to F/B/W times** for large models + +However, when the benchmark PP differs from the target PP (e.g., benchmark PP=1, target PP=6), the **analytical `sendrecv` model** is used to estimate the PP communication overhead that was not captured in the benchmark: + +``` +PP overhead = additional_stages × 2 (fwd+bwd) × microbatches × sendrecv(activation_size) +``` + +P2P communication becomes significant when PP stages span nodes (inter-node P2P has 10-100× higher latency than intra-node). + +### Overall Performance Prediction Flow + +``` +1. Load Configuration + └── Parse YAML config, extract parallelism settings + +2. Single-Node Benchmarking + ├── If config requires multiple nodes: + │ └── Reduce PP to 1, possibly rescale EP to fit on 1 node + ├── Limit layers (1 dense + 1 MoE for efficiency) + └── Benchmark forward + backward times + +3. Extrapolate to Full Model + └── Multiply per-layer times by total layer count + +4. Pipeline Schedule Simulation (if PP > 1) + ├── Build chunk time matrix (per-rank, per-chunk) + ├── Select scheduler (1F1B, Interleaved, Zero-Bubble) + └── Simulate to get step_time_ms and bubble_ratio + +5. 
Add Communication Overhead (if config was reduced) + ├── PP overhead: P2P communication between stages + └── EP overhead: Additional All-to-All for larger EP + +6. Multinode Scaling Projection + ├── Calculate DP scaling factor + ├── Scale compute time: projected = base × (min_dp / target_dp) + ├── Add gradient AllReduce (if not overlapped) + └── Report projected iteration time and tokens/s +``` + +#### Performance Formula + +``` +Projected_Time = Base_Time × (Min_DP / Target_DP) + Communication_Overhead + +Where: + Base_Time = Pipeline simulation time (includes PP bubbles) + Min_DP = DP at minimum node configuration + Target_DP = DP at target node configuration + Communication_Overhead = Gradient AllReduce (if not overlapped) + + MoE All-to-All overhead (if EP spans nodes) +``` + +#### Example Calculation + +**Configuration:** DeepSeek V2 Lite — TP=1, PP=3, EP=8, CP=1 — GBS=640, MBS=4, Seq=4096 + +``` +Step 1: Minimum Nodes + GPUs required = 1 × 3 × 8 × 1 = 24 GPUs = 3 nodes + Min DP = 24 / (1 × 3 × 1) = 8 + +Step 2: Target Configuration (6 nodes) + Total GPUs = 48 + Target DP = 48 / (1 × 3 × 1) = 16 + DP Scaling = 16 / 8 = 2× + +Step 3: Projected Time + Base_Time (from pipeline simulation) = 5500 ms + Projected_Time = 5500 × (8 / 16) = 2750 ms + Tokens/s = (640 × 4096) / 2.750 = 953,251 tokens/s +``` + +--- ## Example Output +### Memory Projection + ``` ==================================================================================================== -[Primus:Performance Projection] Configuration Summary: - Benchmark Config: PP=1, EP=8, TP=1, CP=1, DP=1 (1 node) - Target Config: PP=1, EP=8, TP=1, CP=1, DP=4 (4 nodes) - Benchmark Microbatches: 160 (global_batch=640, micro_batch=4, benchmark_dp=1) +[Primus:Projection] Component-wise Profiling Results (Rank 0): +==================================================================================================== + + Total Number of Parameters: 15.654321 Billion (15,654,321,024) + + [embedding] + Params: 0.819200 
Billion (819,200,000) + Activation Memory: 0.2500 GB + + [dense_transformer_layer] + Params: 0.302000 Billion (302,000,000) + Activation Memory: 2.1250 GB + + [layer_norm] + Params: 0.000016 Billion (16,384) + Activation Memory: 0.2500 GB + + [self_attention] + Params: 0.134218 Billion (134,217,728) + Activation Memory: 0.5100 GB + + [mlp] + Params: 0.167772 Billion (167,772,160) + Activation Memory: 0.8650 GB + + [moe_transformer_layer] + Params: 1.001400 Billion (1,001,400,000) + Activation Memory: 18.2000 GB + +==================================================================================================== +[Primus:Projection] Memory Projection Summary on Rank 0: + Params: 20.850000 Billion (20,850,000,000) + Param+Optimizer Memory: 83.7400 GB + Activation Memory (per batch size 4, seq len 16384): 36.7500 GB + Projected Total Memory: 120.4900 GB +==================================================================================================== +``` +### Performance Projection + +``` ==================================================================================================== Multinode Scaling Projection Results ==================================================================================================== @@ -226,12 +759,36 @@ Multinode Scaling Projection Results ==================================================================================================== ``` +--- + +## Assumptions and Limitations + +### Assumptions + +1. **Linear DP Scaling** — Compute time scales linearly with 1/DP (ideal weak scaling) +2. **Communication Model** — Bandwidth efficiency is constant (default 91%); inter-node communication uses switch topology; all NICs are used in parallel for inter-node traffic +3. **Pipeline Bubbles** — B/W split is 50/50 for zero-bubble scheduling; P2P communication time is small relative to compute +4. **Gradient AllReduce** — By default overlapped with compute (`overlap_grad_reduce=True`); if disabled, added to critical path +5. 
**MoE All-to-All** — All-to-All time scales with EP size; per-peer latency overhead is constant + +### Limitations + +1. **Single-Node Benchmark Accuracy** — Benchmarking with reduced PP/EP may not capture all behaviors; layer timing variance is assumed uniform +2. **Communication Contention** — Model doesn't account for network contention; assumes dedicated bandwidth per collective +3. **Memory Pressure** — Memory impact on performance not fully modeled; activation recomputation overhead not considered in performance +4. **Hardware Heterogeneity** — Assumes homogeneous nodes; GPU frequency variations not modeled + +--- + ## Tips +- **Start with memory projection**: Run `projection memory` first to verify your model fits in GPU memory before benchmarking. - **Start with 1 node**: Always benchmark on 1 node first to establish baseline performance. -- **Understand scaling limits**: DP scaling is limited by global_batch_size / micro_batch_size. If you run out of microbatches, adding more nodes won't help. -- **Check minimum nodes**: If your config requires multiple nodes (e.g., PP=4 with 8 GPUs/node), projection will automatically reduce PP for benchmarking. +- **Understand scaling limits**: DP scaling is limited by `global_batch_size / micro_batch_size`. If you run out of microbatches, adding more nodes won't help. +- **Check minimum nodes**: If your config requires multiple nodes (e.g., PP=4 with 8 GPUs/node), the performance projection will automatically reduce PP for benchmarking. - **Pipeline scaling**: With PP > 1, you can only scale in multiples of the pipeline replica size. +- **Recomputation trade-off**: Full recompute dramatically reduces activation memory (e.g., 18 GB → 0.25 GB per MoE layer) at the cost of ~33% more compute. +- **MoE activation dominance**: For MoE models, the MoE MLP activation (scaled by `topk`) typically dominates the per-layer activation budget. Consider recomputation if memory is tight. ## Related Documentation