Training adaptation for the GMM NZ feature #1
Conversation
The run package should not be committed directly to the code repository.

PRs should be split and merged at the feature level; also, the PR title should carry the key information.
Force-pushed from 3fb0844 to 46da32a
The latest commit has removed the run package.
The dcp.save optimization has been split out into another PR: #2
source ${CANN_DIR}/set_env.sh

# Install the GMM split-K-axis patch for PTA 2.6.0
pip install /path/to/torch_npu-custom.whl --force-reinstall
The Q3 torch_npu already supports splitting the K axis; suggest updating the install path, or removing this line.
This PTA package adapts the custom SliceNz fused operator. Since the PTA mainline may not accept this operator going forward, this line signals that users must install the custom package in their training environment themselves in order to call the interface.
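As an illustrative sanity check (not part of this PR; the attribute name follows the one-line import shown further below), a user can verify that the custom wheel is the one actually installed:

import torch_npu

# If this raises, the stock PTA build is installed rather than the custom
# wheel that carries the SliceNz fused operator.
assert hasattr(torch_npu, "npu_special_slice"), \
    "npu_special_slice not found: reinstall the custom torch_npu wheel"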
pta_patch/_fsdp_collectives.py (Outdated)
dim: int,
out: List[torch.Tensor],
) -> None:
if len(all_gather_input_split_sizes) > 1 and out[-1].shape[0] * out[-1].shape[1] >= 128 * 4096 * 1536:
Is this hard-coded because of how FSDP splits the sharding axis for the MoE distribution? This code can be reused by other models, so please express the condition in a human-readable way.
The latest commit expresses the condition as num_experts * hidden_size * moe_intermediate_size. It is used to determine whether the weights currently being copied out will subsequently be used in the GMM computation.
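As an illustration of that rewrite, a minimal sketch with hypothetical names (num_experts, hidden_size, and moe_intermediate_size are assumed to be threaded in from the model config; out and all_gather_input_split_sizes follow the diff above):

def is_gmm_weight(out, all_gather_input_split_sizes,
                  num_experts, hidden_size, moe_intermediate_size):
    # A copy_out target at least as large as one full set of expert
    # projection weights is assumed to feed the grouped matmul (GMM)
    # and therefore takes the fused NZ path.
    gmm_weight_numel = num_experts * hidden_size * moe_intermediate_size
    return (len(all_gather_input_split_sizes) > 1
            and out[-1].shape[0] * out[-1].shape[1] >= gmm_weight_numel)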
@@ -0,0 +1 @@
from torch_npu import npu_special_slice
Is this file necessary?
Yes. When patching PTA's internal copy_out implementation we need to call the npu_special_slice interface, and that interface is itself also adapted inside PTA. Since PTA cannot import itself, the import is done in this external file.
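A sketch of the external patching pattern this reply describes, with assumed names: the torch module path below is the private torch 2.6 location of FSDP2's copy_out helpers and may differ across versions, and the patched module is assumed to export a drop-in _split_with_sizes_copy matching the signature shown in the diff above:

import torch.distributed.fsdp._fully_shard._fsdp_collectives as _fsdp_collectives
import pta_patch._fsdp_collectives as _patched

# Route FSDP2's copy_out through the patched helper, which dispatches large
# GMM weights to npu_special_slice (export name assumed here).
_fsdp_collectives._split_with_sizes_copy = _patched._split_with_sizes_copy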
xtuner/v1/model/base.py (Outdated)
hf_keys_start = int(fsdp_start / hf_key_size)
hf_keys_end = math.ceil(fsdp_end / hf_key_size)

if int(os.getenv("GROUPMM_NZ_TRANSPOSE", "0")) == 1 and len(hf_keys) == 128*2 and torch.distributed.get_world_size() == 512:  # gate & up case; the down case must be excluded
Couldn't this be determined directly from the contents of hf_keys? The current check is prone to false positives. Also, how is a world size greater than 512 handled?
hf_keys is a list of module names, containing the names of each layer's expert gate, up, or down projection weights. The latest commit changes the check to whether len(hf_keys) equals 2 * num_experts (i.e., the gate & up case), and the branch is only taken when world_size >= num_experts.
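A sketch of the revised branch condition as described (num_experts is assumed to come from the model config; the env-var name follows the diff above):

import os
import torch.distributed as dist

def take_nz_gate_up_branch(hf_keys, num_experts):
    # hf_keys holds one module name per expert projection weight: gate & up
    # together yield 2 * num_experts entries while down yields num_experts,
    # so the length check excludes the down case. The world-size bound
    # replaces the previous hard-coded == 512.
    return (int(os.getenv("GROUPMM_NZ_TRANSPOSE", "0")) == 1
            and len(hf_keys) == 2 * num_experts
            and dist.get_world_size() >= num_experts)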
pta_patch/_fsdp_collectives.py (Outdated)
out[-1].resize_(num_exp, out_dim1, in_dim1)
out[-2].resize_(num_exp, out_dim2, in_dim2)

npu_special_slice(all_gather_output, dim, weight_1_start, total_size, out[-1])
Is the NZ conversion done inside this call? A name like "special slice" is too generic.
out[-2].resize_(num_expert, hidden_size, moe_intermediate_size*2)

# The fused operator both slices the GMM weights and converts them to NZ
npu_special_slice(all_gather_output, dim, up_start, total_size, out[-1])
Should the naming here be more human-readable?
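To make the threads above concrete, here is a non-fused reference for what the comments describe npu_special_slice doing: a sketch under the assumption that the fused kernel combines a narrow, a reshape, and an NZ format cast. Parameter roles are inferred, the real call also takes a total_size argument, and this reference returns a new tensor rather than writing into a preallocated output as the fused operator does:

import torch_npu

ACL_FORMAT_FRACTAL_NZ = 29  # torch_npu ACL format id for FRACTAL_NZ

def npu_special_slice_reference(all_gather_output, dim, start, out):
    # Slice this weight's flat segment out of the gathered buffer,
    segment = all_gather_output.narrow(dim, start, out.numel())
    # reshape it to out's (num_experts, out_dim, in_dim) layout,
    shaped = segment.reshape(out.shape).contiguous()
    # and cast it to the NPU-affine NZ layout; the fused operator performs
    # all of these steps in a single kernel.
    return torch_npu.npu_format_cast(shaped, ACL_FORMAT_FRACTAL_NZ)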
Feature description
This PR enables the GMM NZ training optimization in the xtuner framework: in FSDP2, the copy_out and NZ-conversion operations are fused, and GMM computation runs on the NPU-affine NZ format to obtain a performance gain.
Specific changes
User enablement
Refer to the test_qwen3_235b_npu.sh script: turn on the GROUPED_MATMUL_NZ_TRANSPOSE switch and the 512-alignment switch LINEAR_ONLY_SHARD.
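For illustration, the same switches set from Python before training starts (a sketch; the test_qwen3_235b_npu.sh script remains the authoritative reference):

import os

# Enable the fused copy_out + NZ conversion path for GMM weights. Note the
# diff in xtuner/v1/model/base.py reads GROUPMM_NZ_TRANSPOSE while the text
# above says GROUPED_MATMUL_NZ_TRANSPOSE; set the name your build checks.
os.environ["GROUPED_MATMUL_NZ_TRANSPOSE"] = "1"
# 512-alignment sharding switch required by this feature.
os.environ["LINEAR_ONLY_SHARD"] = "1"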
Testing
Verification results