Skip to content

Commit e117893

Browse files
committed
doc
1 parent 6783a9a commit e117893

File tree

2 files changed

+399
-0
lines changed

2 files changed

+399
-0
lines changed
Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
# Sequence Parallel(SP)实现原理
2+
3+
本节描述 Twinkle 的序列并行(Sequence Parallel, SP)在 Transformers 路径中的实现机制与关键设计点。
4+
5+
## 实现原理概览
6+
- 核心思想:把 **序列维度** 切分到多个 SP rank,上游/下游保持张量语义一致;必要时通过 **all-to-all** 在注意力前后重排,保证每个 rank 都能看到完整序列上下文。
7+
- 关键约束:SP 不改变 TP/PP/EP 等非数据并行维度,SP 只在数据并行维度内“分组”。
8+
- 通信形态:主要是 **all_to_all_single**(注意力前后),以及 **all_gather / all_reduce**(position_ids、loss、logits)。
9+
- 通用性来源:SP 不要求模型结构特化,而是 **在注意力实现层进行补丁(patch)**,统一拦截 Q/K/V 的布局与 mask/position_ids 的构造方式,从而让不同模型共享同一条 SP 数据路径。
10+
11+
## 实现原理细节
12+
### 1) SP 进程组构造
13+
- 若 device mesh 明确包含 `sp` 维度,直接用该维度构建进程组。
14+
- 否则在 **数据并行维度(dp/fsdp)内部**`sp_size` 切分分组,确保 SP 不跨 TP/PP/EP。
15+
- 该逻辑在 `SequenceParallel._get_sp_group_from_device_mesh()` 中实现,保证每个 SP 组大小为 `sp_size`
16+
17+
### 2) 序列切分与对齐
18+
- 训练/推理前先将序列长度 pad 到可被 `sp_size` 整除,再按序列维切分给各 SP rank。
19+
- `real_position_ids` 用于保留原始长度信息,后续在 gather 或 mask 中还原真实序列。
20+
21+
### 3) 注意力的 all-to-all 交换
22+
- 在注意力计算前,把 Q/K/V 从 “按序列切分” 重排为 “按 head 维/局部序列” 的布局;
23+
- 使用 `all_to_all_single` 在 SP 组内交换,使每个 rank 看到完整序列上下文;
24+
- 注意力计算完成后,再做一次逆向 all-to-all,将输出切回本地序列分片。
25+
- 这部分由 `_SeqAllToAll`(autograd 包装)和 `DistributedAttention` 完成,保证前后向一致。
26+
#### 为什么说 SP 是“通用”的
27+
- SP **不直接改模型结构**,而是 **patch attention 路径**
28+
- 对 FlashAttention2/SDPA 的入口做封装,接管 Q/K/V 的重排与 all-to-all;
29+
- 对 mask / cache_position / position_ids 的构造进行重建;
30+
- 这样即使上层模型结构不同,只要最终走到标准 attention 实现(FA2/SDPA),就能被 SP 统一接入。
31+
- 这使得 SP 能跨多种模型复用,代价是注意力层必须遵循被 patch 的调用路径。
32+
33+
#### Attention 内部的序列聚合与头切分
34+
下面以常见张量布局示意(忽略 batch 维):
35+
- 输入进入注意力前,每个 SP rank 只拥有局部序列 `Ls = Lp / S`,Q/K/V 形状为 `[Ls, num_heads, head_dim]`
36+
- 目标是让每个 rank 在注意力计算时看到 **完整序列长度 Lp**,但每个 rank 只负责 **部分 heads**,以保持计算量与通信量平衡。
37+
38+
步骤拆解:
39+
1. 重排为 all-to-all 友好的布局
40+
把 Q/K/V reshape/permute 成 `[S, Ls, num_heads / S, head_dim]`(概念上是把 head 维切成 S 份,把序列维切成 S 份),再做 `all_to_all_single`
41+
1. all-to-all 后的布局
42+
每个 rank 得到 `[Lp, num_heads / S, head_dim]`,即 **序列全量聚合** + **头维切分**
43+
1. 本地注意力计算
44+
在完整序列上做 attention,得到 `context``[Lp, num_heads / S, head_dim]`
45+
1. 逆向 all-to-all
46+
`context` 通过反向 all-to-all 交换,恢复为本 rank 的局部序列输出 `[Ls, num_heads, head_dim]`
47+
48+
这一流程保证:
49+
- 序列维在注意力计算时是全量的,避免跨 rank 缺失上下文。
50+
- 头维被均匀切分,保持每个 rank 的计算和显存负载可控。
51+
52+
### 4) Mask/Position Id 适配
53+
- FlashAttention2/SDPA 的 mask 与 cache_position 在 SP 下需要重建:
54+
- FA2:SP 场景下避免全 mask,必要时根据 `real_position_ids` 重建。
55+
- SDPA:通过 `real_position_ids` 重建 cache_position(仅非 packed 场景)。
56+
- packed 场景下必须使用 FA2 varlen 语义,否则会跨子序列错误注意力。
57+
58+
### 5) Loss/Metric 的“全局一致,梯度本地”
59+
- 指标需要全局一致:`sum/mean` 通过 all-reduce 汇总。
60+
- 反向只来自本 rank:用 `out_metric = out.detach() / sp_size` + `out - out.detach()` 组合,避免重复梯度。
61+
- MoE router loss 需要全序列 logits 时,先 gather 再裁剪。
62+
63+
## 目标与适用场景
64+
- 目标:在不改变全局 world size 的前提下,沿 **序列维度**切分计算,降低单卡序列长度压力。
65+
- 适用:**长序列训练****显存瓶颈在激活**的场景。
66+
- 不适用/收益有限:短序列、小模型或通信占比高时,可能 **变慢****显存收益不明显**
67+
68+
## 关键概念
69+
- **SP 组**:在现有 device mesh 上构造的序列并行进程组;不新增 world size。
70+
- **real_position_ids**:原始 position_ids(未 pad),用于还原真实序列长度。
71+
- **packed**:PackingDataset 产生的多样本拼接序列(position_ids 多次归零)。
72+
73+
## 核心流程
74+
### 1) 进程组构建
75+
- 若 device_mesh 含 `sp` 维度,直接使用。
76+
- 否则按 `dp/fsdp` 维度分块构建 SP 组(不跨 TP/PP/EP 等非数据维)。
77+
78+
### 2) 输入预处理(pad & split)
79+
-**pad**`seq_len % sp_size == 0`
80+
- 再按序列维 **split** 到各 SP rank;
81+
- `real_position_ids` 保留原始长度,用于后续裁剪。
82+
83+
### 3) Attention 适配
84+
为支持 SP 与 FlashAttention2/SDPA:
85+
- 对 FA2/SDPA 的 mask、cache_position 做适配;
86+
- **packed** 情况下,要求使用 **FlashAttention2**(SDPA 不支持 packed)。
87+
88+
### 4) 输出处理(可选)
89+
`gather_logits=True`
90+
- 将各 rank 的 logits 按序列维 **gather** 回完整序列;
91+
- 使用 `real_position_ids` 裁剪回原始长度。
92+
93+
### 5) Loss 处理
94+
SP 的 loss 需要“前向全局一致、反向本地化”:
95+
- `sum`:本地 `loss_sum` all-reduce 得到全局 sum;
96+
- `mean`:本地 `loss_sum``num_valid_tokens` 各自 all-reduce;
97+
- 使用 **stop-grad 拼接** 保证前向数值全局一致,反向梯度仅来自本 rank:
98+
- `out_metric = out.detach() / sp_size`
99+
- `return out_metric + (out - out.detach())`
100+
101+
> 说明:该策略修正指标重复计数,同时不影响训练梯度。
102+
103+
### 6) Packed 训练的边界处理
104+
PackingDataset 会把多条样本拼成一条长序列,`position_ids` 在边界处重置为 0。
105+
为了避免“跨样本预测”错误监督:
106+
- 在边界前一个 token 位置将 `labels` 置为 `-100`
107+
- 最后一个 token 也置为 `-100`,避免 wrap-around。
108+
109+
### 7) MoE 辅助损失
110+
MoE 的 `router_logits` 需要看到全序列:
111+
-`GatherLoss` 将 SP 切分的 logits 拼回;
112+
- 再裁剪到原始长度以计算 aux loss。
113+
114+
## 数据流与 shape 变化(开启 SP 后)
115+
下面用符号说明(单卡视角 + SP 组视角):
116+
- `B`:batch size
117+
- `L`:原始序列长度
118+
- `S``sp_size`
119+
- `Lp`:pad 后长度,`Lp = ceil(L / S) * S`
120+
- `Ls`:本 rank 序列长度,`Ls = Lp / S`
121+
- `H`:hidden size
122+
- `V`:vocab size
123+
124+
```text
125+
输入侧(全局/逻辑视角)
126+
input_ids : [B, L] -> pad -> [B, Lp]
127+
position_ids : [B, L] -> pad -> [B, Lp] (real_position_ids 记录未 pad 长度)
128+
attention_mask : [B, L] -> pad -> [B, Lp]
129+
labels : [B, L] -> pad -> [B, Lp] (packed 场景会额外把边界位置置 -100)
130+
131+
按序列维 split 到各 SP rank
132+
input_ids : [B, Lp] --split--> [B, Ls]
133+
position_ids : [B, Lp] --split--> [B, Ls]
134+
attention_mask : [B, Lp] --split--> [B, Ls]
135+
labels : [B, Lp] --split--> [B, Ls]
136+
137+
本 rank 计算路径
138+
Embedding(input_ids) -> hidden_states: [B, Ls, H]
139+
140+
Attention 前重排(all-to-all 1)
141+
Q/K/V: [B, Ls, num_heads, head_dim]
142+
-> all_to_all_single (SP 组内交换)
143+
-> [B, Lp, num_heads / S, head_dim] (每 rank 拥有完整序列但更少 heads)
144+
145+
Attention 计算
146+
context: [B, Lp, num_heads / S, head_dim]
147+
148+
Attention 后重排(all-to-all 2)
149+
context -> all_to_all_single
150+
output: [B, Ls, H]
151+
152+
FFN/残差等后续
153+
hidden_states: [B, Ls, H]
154+
155+
LM Head / logits
156+
logits: [B, Ls, V]
157+
158+
可选:gather_logits=True
159+
logits: [B, Ls, V] --gather--> [B, Lp, V] --crop--> [B, L, V]
160+
161+
Loss/metrics
162+
labels: [B, Ls] 对齐本地 logits 计算
163+
loss_sum / token_count: 本地 -> all-reduce -> 全局
164+
mean/sum 通过 stop-grad 拼接保证前向一致、反向本地
165+
```
166+
167+
说明:
168+
- SP 只切分 **序列维**,不会改变 batch 维或 hidden 维的语义。
169+
- 注意力阶段需要 all-to-all 是为了让每个 rank 在计算 attention 时看到完整序列上下文。
170+
- packed 场景下,`position_ids` 中存在多段 0..n 重置;SP 对此要求使用 FA2 的 varlen 语义。
171+
172+
## 关键配置
173+
- `ulysses_size`:SP 大小。
174+
- `gather_logits`:是否在输出阶段聚合 logits(推理/评估常用)。
175+
- `loss_reduction``sum` / `mean`,必须与上游 loss 语义一致。
176+
177+
## 局限与注意事项
178+
- **packed + SDPA 不支持**:packed 批次需使用 FlashAttention2。
179+
- **性能**:SP 增加通信,短序列可能更慢。
180+
- **显存**:主要减少激活峰值,对参数显存影响不大。
181+
182+
## 与 FSDP/DP 关系
183+
- SP 不新增 world size,不参与 FSDP/DP 的梯度规约。
184+
- FSDP/DP 负责参数梯度规约;SP 仅影响序列维度的切分与聚合。
185+
186+
## 与 FSDP 的结合与“特效”
187+
### 1) 组合原则
188+
- SP 只切分 **序列维**,FSDP 切分 **参数/梯度**;两者是正交的。
189+
- SP 的进程组在 **dp/fsdp 维度内部**构建,因此不会跨 FSDP shard 边界。
190+
191+
### 2) 主要收益
192+
- **激活显存更省**:SP 减小每卡序列长度,FSDP 又把参数/梯度切碎,组合后显存峰值更低。
193+
- **更稳的长序列训练**:FSDP 把参数压力降下来,SP 把激活压力降下来,长序列更易跑通。
194+
195+
### 3) 通信/性能侧影响
196+
- SP 引入 all-to-all(注意力前后)以及 loss/logits 的 all-reduce/gather;
197+
- FSDP 引入参数分片通信(all-gather/reshard);
198+
- 组合时通信叠加,**短序列场景可能更慢**,长序列场景通常收益更明显。
199+
200+
### 4) 需要注意的点
201+
- SP 组按 dp/fsdp 分组,所以 **sp_size 必须整除 data_world_size(dp * fsdp)**
202+
- 如果使用 packed + FA2,仍需遵守 SP 对 packed 的限制(SDPA 不支持 packed)。

0 commit comments

Comments
 (0)