|
| 1 | +# Sequence Parallel(SP)实现原理 |
| 2 | + |
| 3 | +本节描述 Twinkle 的序列并行(Sequence Parallel, SP)在 Transformers 路径中的实现机制与关键设计点。 |
| 4 | + |
| 5 | +## 实现原理概览 |
| 6 | +- 核心思想:把 **序列维度** 切分到多个 SP rank,上游/下游保持张量语义一致;必要时通过 **all-to-all** 在注意力前后重排,保证每个 rank 都能看到完整序列上下文。 |
| 7 | +- 关键约束:SP 不改变 TP/PP/EP 等非数据并行维度,SP 只在数据并行维度内“分组”。 |
| 8 | +- 通信形态:主要是 **all_to_all_single**(注意力前后),以及 **all_gather / all_reduce**(position_ids、loss、logits)。 |
| 9 | +- 通用性来源:SP 不要求模型结构特化,而是 **在注意力实现层进行补丁(patch)**,统一拦截 Q/K/V 的布局与 mask/position_ids 的构造方式,从而让不同模型共享同一条 SP 数据路径。 |
| 10 | + |
| 11 | +## 实现原理细节 |
| 12 | +### 1) SP 进程组构造 |
| 13 | +- 若 device mesh 明确包含 `sp` 维度,直接用该维度构建进程组。 |
| 14 | +- 否则在 **数据并行维度(dp/fsdp)内部**按 `sp_size` 切分分组,确保 SP 不跨 TP/PP/EP。 |
| 15 | +- 该逻辑在 `SequenceParallel._get_sp_group_from_device_mesh()` 中实现,保证每个 SP 组大小为 `sp_size`。 |
| 16 | + |
| 17 | +### 2) 序列切分与对齐 |
| 18 | +- 训练/推理前先将序列长度 pad 到可被 `sp_size` 整除,再按序列维切分给各 SP rank。 |
| 19 | +- `real_position_ids` 用于保留原始长度信息,后续在 gather 或 mask 中还原真实序列。 |
| 20 | + |
| 21 | +### 3) 注意力的 all-to-all 交换 |
| 22 | +- 在注意力计算前,把 Q/K/V 从 “按序列切分” 重排为 “按 head 维/局部序列” 的布局; |
| 23 | +- 使用 `all_to_all_single` 在 SP 组内交换,使每个 rank 看到完整序列上下文; |
| 24 | +- 注意力计算完成后,再做一次逆向 all-to-all,将输出切回本地序列分片。 |
| 25 | +- 这部分由 `_SeqAllToAll`(autograd 包装)和 `DistributedAttention` 完成,保证前后向一致。 |
| 26 | +#### 为什么说 SP 是“通用”的 |
| 27 | +- SP **不直接改模型结构**,而是 **patch attention 路径**: |
| 28 | + - 对 FlashAttention2/SDPA 的入口做封装,接管 Q/K/V 的重排与 all-to-all; |
| 29 | + - 对 mask / cache_position / position_ids 的构造进行重建; |
| 30 | + - 这样即使上层模型结构不同,只要最终走到标准 attention 实现(FA2/SDPA),就能被 SP 统一接入。 |
| 31 | +- 这使得 SP 能跨多种模型复用,代价是注意力层必须遵循被 patch 的调用路径。 |
| 32 | + |
| 33 | +#### Attention 内部的序列聚合与头切分 |
| 34 | +下面以常见张量布局示意(忽略 batch 维): |
| 35 | +- 输入进入注意力前,每个 SP rank 只拥有局部序列 `Ls = Lp / S`,Q/K/V 形状为 `[Ls, num_heads, head_dim]`。 |
| 36 | +- 目标是让每个 rank 在注意力计算时看到 **完整序列长度 Lp**,但每个 rank 只负责 **部分 heads**,以保持计算量与通信量平衡。 |
| 37 | + |
| 38 | +步骤拆解: |
| 39 | +1. 重排为 all-to-all 友好的布局 |
| 40 | + 把 Q/K/V reshape/permute 成 `[S, Ls, num_heads / S, head_dim]`(概念上是把 head 维切成 S 份,把序列维切成 S 份),再做 `all_to_all_single`。 |
| 41 | +1. all-to-all 后的布局 |
| 42 | + 每个 rank 得到 `[Lp, num_heads / S, head_dim]`,即 **序列全量聚合** + **头维切分**。 |
| 43 | +1. 本地注意力计算 |
| 44 | + 在完整序列上做 attention,得到 `context`:`[Lp, num_heads / S, head_dim]`。 |
| 45 | +1. 逆向 all-to-all |
| 46 | + 将 `context` 通过反向 all-to-all 交换,恢复为本 rank 的局部序列输出 `[Ls, num_heads, head_dim]`。 |
| 47 | + |
| 48 | +这一流程保证: |
| 49 | +- 序列维在注意力计算时是全量的,避免跨 rank 缺失上下文。 |
| 50 | +- 头维被均匀切分,保持每个 rank 的计算和显存负载可控。 |
| 51 | + |
| 52 | +### 4) Mask/Position Id 适配 |
| 53 | +- FlashAttention2/SDPA 的 mask 与 cache_position 在 SP 下需要重建: |
| 54 | + - FA2:SP 场景下避免全 mask,必要时根据 `real_position_ids` 重建。 |
| 55 | + - SDPA:通过 `real_position_ids` 重建 cache_position(仅非 packed 场景)。 |
| 56 | +- packed 场景下必须使用 FA2 varlen 语义,否则会跨子序列错误注意力。 |
| 57 | + |
| 58 | +### 5) Loss/Metric 的“全局一致,梯度本地” |
| 59 | +- 指标需要全局一致:`sum/mean` 通过 all-reduce 汇总。 |
| 60 | +- 反向只来自本 rank:用 `out_metric = out.detach() / sp_size` + `out - out.detach()` 组合,避免重复梯度。 |
| 61 | +- MoE router loss 需要全序列 logits 时,先 gather 再裁剪。 |
| 62 | + |
| 63 | +## 目标与适用场景 |
| 64 | +- 目标:在不改变全局 world size 的前提下,沿 **序列维度**切分计算,降低单卡序列长度压力。 |
| 65 | +- 适用:**长序列训练**、**显存瓶颈在激活**的场景。 |
| 66 | +- 不适用/收益有限:短序列、小模型或通信占比高时,可能 **变慢** 且 **显存收益不明显**。 |
| 67 | + |
| 68 | +## 关键概念 |
| 69 | +- **SP 组**:在现有 device mesh 上构造的序列并行进程组;不新增 world size。 |
| 70 | +- **real_position_ids**:原始 position_ids(未 pad),用于还原真实序列长度。 |
| 71 | +- **packed**:PackingDataset 产生的多样本拼接序列(position_ids 多次归零)。 |
| 72 | + |
| 73 | +## 核心流程 |
| 74 | +### 1) 进程组构建 |
| 75 | +- 若 device_mesh 含 `sp` 维度,直接使用。 |
| 76 | +- 否则按 `dp/fsdp` 维度分块构建 SP 组(不跨 TP/PP/EP 等非数据维)。 |
| 77 | + |
| 78 | +### 2) 输入预处理(pad & split) |
| 79 | +- 先 **pad** 到 `seq_len % sp_size == 0`; |
| 80 | +- 再按序列维 **split** 到各 SP rank; |
| 81 | +- `real_position_ids` 保留原始长度,用于后续裁剪。 |
| 82 | + |
| 83 | +### 3) Attention 适配 |
| 84 | +为支持 SP 与 FlashAttention2/SDPA: |
| 85 | +- 对 FA2/SDPA 的 mask、cache_position 做适配; |
| 86 | +- **packed** 情况下,要求使用 **FlashAttention2**(SDPA 不支持 packed)。 |
| 87 | + |
| 88 | +### 4) 输出处理(可选) |
| 89 | +当 `gather_logits=True`: |
| 90 | +- 将各 rank 的 logits 按序列维 **gather** 回完整序列; |
| 91 | +- 使用 `real_position_ids` 裁剪回原始长度。 |
| 92 | + |
| 93 | +### 5) Loss 处理 |
| 94 | +SP 的 loss 需要“前向全局一致、反向本地化”: |
| 95 | +- `sum`:本地 `loss_sum` all-reduce 得到全局 sum; |
| 96 | +- `mean`:本地 `loss_sum` 和 `num_valid_tokens` 各自 all-reduce; |
| 97 | +- 使用 **stop-grad 拼接** 保证前向数值全局一致,反向梯度仅来自本 rank: |
| 98 | + - `out_metric = out.detach() / sp_size` |
| 99 | + - `return out_metric + (out - out.detach())` |
| 100 | + |
| 101 | +> 说明:该策略修正指标重复计数,同时不影响训练梯度。 |
| 102 | +
|
| 103 | +### 6) Packed 训练的边界处理 |
| 104 | +PackingDataset 会把多条样本拼成一条长序列,`position_ids` 在边界处重置为 0。 |
| 105 | +为了避免“跨样本预测”错误监督: |
| 106 | +- 在边界前一个 token 位置将 `labels` 置为 `-100`; |
| 107 | +- 最后一个 token 也置为 `-100`,避免 wrap-around。 |
| 108 | + |
| 109 | +### 7) MoE 辅助损失 |
| 110 | +MoE 的 `router_logits` 需要看到全序列: |
| 111 | +- 用 `GatherLoss` 将 SP 切分的 logits 拼回; |
| 112 | +- 再裁剪到原始长度以计算 aux loss。 |
| 113 | + |
| 114 | +## 数据流与 shape 变化(开启 SP 后) |
| 115 | +下面用符号说明(单卡视角 + SP 组视角): |
| 116 | +- `B`:batch size |
| 117 | +- `L`:原始序列长度 |
| 118 | +- `S`:`sp_size` |
| 119 | +- `Lp`:pad 后长度,`Lp = ceil(L / S) * S` |
| 120 | +- `Ls`:本 rank 序列长度,`Ls = Lp / S` |
| 121 | +- `H`:hidden size |
| 122 | +- `V`:vocab size |
| 123 | + |
| 124 | +```text |
| 125 | +输入侧(全局/逻辑视角) |
| 126 | +input_ids : [B, L] -> pad -> [B, Lp] |
| 127 | +position_ids : [B, L] -> pad -> [B, Lp] (real_position_ids 记录未 pad 长度) |
| 128 | +attention_mask : [B, L] -> pad -> [B, Lp] |
| 129 | +labels : [B, L] -> pad -> [B, Lp] (packed 场景会额外把边界位置置 -100) |
| 130 | +
|
| 131 | +按序列维 split 到各 SP rank |
| 132 | +input_ids : [B, Lp] --split--> [B, Ls] |
| 133 | +position_ids : [B, Lp] --split--> [B, Ls] |
| 134 | +attention_mask : [B, Lp] --split--> [B, Ls] |
| 135 | +labels : [B, Lp] --split--> [B, Ls] |
| 136 | +
|
| 137 | +本 rank 计算路径 |
| 138 | +Embedding(input_ids) -> hidden_states: [B, Ls, H] |
| 139 | +
|
| 140 | +Attention 前重排(all-to-all 1) |
| 141 | +Q/K/V: [B, Ls, num_heads, head_dim] |
| 142 | + -> all_to_all_single (SP 组内交换) |
| 143 | + -> [B, Lp, num_heads / S, head_dim] (每 rank 拥有完整序列但更少 heads) |
| 144 | +
|
| 145 | +Attention 计算 |
| 146 | +context: [B, Lp, num_heads / S, head_dim] |
| 147 | +
|
| 148 | +Attention 后重排(all-to-all 2) |
| 149 | +context -> all_to_all_single |
| 150 | +output: [B, Ls, H] |
| 151 | +
|
| 152 | +FFN/残差等后续 |
| 153 | +hidden_states: [B, Ls, H] |
| 154 | +
|
| 155 | +LM Head / logits |
| 156 | +logits: [B, Ls, V] |
| 157 | +
|
| 158 | +可选:gather_logits=True |
| 159 | +logits: [B, Ls, V] --gather--> [B, Lp, V] --crop--> [B, L, V] |
| 160 | +
|
| 161 | +Loss/metrics |
| 162 | +labels: [B, Ls] 对齐本地 logits 计算 |
| 163 | +loss_sum / token_count: 本地 -> all-reduce -> 全局 |
| 164 | +mean/sum 通过 stop-grad 拼接保证前向一致、反向本地 |
| 165 | +``` |
| 166 | + |
| 167 | +说明: |
| 168 | +- SP 只切分 **序列维**,不会改变 batch 维或 hidden 维的语义。 |
| 169 | +- 注意力阶段需要 all-to-all 是为了让每个 rank 在计算 attention 时看到完整序列上下文。 |
| 170 | +- packed 场景下,`position_ids` 中存在多段 0..n 重置;SP 对此要求使用 FA2 的 varlen 语义。 |
| 171 | + |
| 172 | +## 关键配置 |
| 173 | +- `ulysses_size`:SP 大小。 |
| 174 | +- `gather_logits`:是否在输出阶段聚合 logits(推理/评估常用)。 |
| 175 | +- `loss_reduction`:`sum` / `mean`,必须与上游 loss 语义一致。 |
| 176 | + |
| 177 | +## 局限与注意事项 |
| 178 | +- **packed + SDPA 不支持**:packed 批次需使用 FlashAttention2。 |
| 179 | +- **性能**:SP 增加通信,短序列可能更慢。 |
| 180 | +- **显存**:主要减少激活峰值,对参数显存影响不大。 |
| 181 | + |
| 182 | +## 与 FSDP/DP 关系 |
| 183 | +- SP 不新增 world size,不参与 FSDP/DP 的梯度规约。 |
| 184 | +- FSDP/DP 负责参数梯度规约;SP 仅影响序列维度的切分与聚合。 |
| 185 | + |
| 186 | +## 与 FSDP 的结合与“特效” |
| 187 | +### 1) 组合原则 |
| 188 | +- SP 只切分 **序列维**,FSDP 切分 **参数/梯度**;两者是正交的。 |
| 189 | +- SP 的进程组在 **dp/fsdp 维度内部**构建,因此不会跨 FSDP shard 边界。 |
| 190 | + |
| 191 | +### 2) 主要收益 |
| 192 | +- **激活显存更省**:SP 减小每卡序列长度,FSDP 又把参数/梯度切碎,组合后显存峰值更低。 |
| 193 | +- **更稳的长序列训练**:FSDP 把参数压力降下来,SP 把激活压力降下来,长序列更易跑通。 |
| 194 | + |
| 195 | +### 3) 通信/性能侧影响 |
| 196 | +- SP 引入 all-to-all(注意力前后)以及 loss/logits 的 all-reduce/gather; |
| 197 | +- FSDP 引入参数分片通信(all-gather/reshard); |
| 198 | +- 组合时通信叠加,**短序列场景可能更慢**,长序列场景通常收益更明显。 |
| 199 | + |
| 200 | +### 4) 需要注意的点 |
| 201 | +- SP 组按 dp/fsdp 分组,所以 **sp_size 必须整除 data_world_size(dp * fsdp)**。 |
| 202 | +- 如果使用 packed + FA2,仍需遵守 SP 对 packed 的限制(SDPA 不支持 packed)。 |
0 commit comments