-
Notifications
You must be signed in to change notification settings - Fork 0
训练适配GMM NZ特性 #1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
训练适配GMM NZ特性 #1
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,91 @@ | ||
| from typing import List, Tuple | ||
| import os | ||
| import torch | ||
|
|
||
|
|
||
# Fragment library that extends the existing "fsdp" op namespace, so the
# collective ops below can be re-registered for the NPU (PrivateUse1) backend.
lib = torch.library.Library("fsdp", "FRAGMENT")
|
|
||
|
|
||
@torch.library.impl(lib, "chunk_cat", "PrivateUse1")
def chunk_cat(
    tensors: List[torch.Tensor],
    dim: int,
    num_chunks: int,
    out: torch.Tensor,
) -> None:
    """NPU (PrivateUse1) override of ``fsdp::chunk_cat``.

    Chunks every tensor in ``tensors`` into ``num_chunks`` pieces along
    ``dim`` and concatenates the chunks into ``out`` via the private
    ``torch._chunk_cat`` op. Inputs are made contiguous first — presumably
    the NPU kernel requires contiguous memory (TODO confirm).
    """
    tensors = [tensor.contiguous() for tensor in tensors]
    if out.is_contiguous():
        torch._chunk_cat(tensors, dim, num_chunks, out=out)
    else:
        # BUG FIX: the original did `out = out.contiguous()`; for a
        # non-contiguous `out`, .contiguous() allocates a *copy*, so the
        # kernel wrote into the copy and the caller's tensor was silently
        # left unmodified. Write into contiguous scratch, then copy back.
        scratch = out.contiguous()
        torch._chunk_cat(tensors, dim, num_chunks, out=scratch)
        out.copy_(scratch)
|
|
||
|
|
||
@torch.library.impl(lib, "all_gather_copy_in", "PrivateUse1")
def all_gather_copy_in_npu(
    all_gather_inputs: List[torch.Tensor],
    inp_split_sizes: List[int],
    all_gather_input_numel: int,
    world_size: int,
    rank: int,
    dtype: torch.dtype,
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """NPU (PrivateUse1) override of ``fsdp::all_gather_copy_in``.

    Allocates the flat all-gather output buffer, views this rank's slot
    inside it, and packs ``all_gather_inputs`` into that slot split by
    ``inp_split_sizes``.

    Returns:
        Tuple of (this rank's input view, the full output buffer); the
        first element is a ``narrow`` view into the second.
    """
    total_numel = all_gather_input_numel * world_size
    all_gather_output = torch.empty((total_numel,), dtype=dtype, device=device)
    rank_offset = all_gather_input_numel * rank
    all_gather_input = all_gather_output.narrow(
        0, rank_offset, all_gather_input_numel
    )
    copy_dsts = torch.split(all_gather_input, inp_split_sizes)
    with torch.no_grad():
        if copy_dsts[0].device == all_gather_inputs[0].device:
            # Same-device pack: allow the copy to overlap asynchronously.
            torch._foreach_copy_(copy_dsts, all_gather_inputs, non_blocking=True)
        else:
            # Cross-device copy (e.g. host -> NPU) stays blocking.
            torch._foreach_copy_(copy_dsts, all_gather_inputs)
    return all_gather_input, all_gather_output
|
|
||
@torch.library.impl(lib, "split_with_sizes_copy", "PrivateUse1")
def split_with_sizes_copy(
    all_gather_output: torch.Tensor,
    all_gather_input_split_sizes: List[int],
    dim: int,
    out: List[torch.Tensor],
    num_expert: int = 128,
    hidden_size: int = 4096,
    moe_intermediate_size: int = 1536,
) -> None:
    """NPU (PrivateUse1) override of ``fsdp::split_with_sizes_copy``.

    Splits ``all_gather_output`` along ``dim`` into ``out`` according to
    ``all_gather_input_split_sizes``. When the GMM NZ fast path is enabled,
    the last two splits (the GMM gate_up_proj and down_proj weights) are
    instead sliced and converted to NZ format by a fused custom op.

    The gmm_nz optimization is enabled if and only if ALL of:
      1. the GROUPMM_NZ_TRANSPOSE env switch parses to a non-zero int;
      2. there is more than one split;
      3. the last split matches the GMM down_proj weight size
         (num_expert * hidden_size * moe_intermediate_size elements).
    """
    # NOTE: int() raises ValueError on non-numeric values of
    # GROUPMM_NZ_TRANSPOSE; callers are expected to export "0"/"1".
    enable_gmm_nz = (
        int(os.getenv("GROUPMM_NZ_TRANSPOSE", "0"))
        and len(all_gather_input_split_sizes) > 1
        # ROBUSTNESS FIX: the original indexed .shape[1] unconditionally,
        # raising IndexError for non-2D outputs when the switch is on;
        # fall back to the regular split copy instead.
        and out[-1].dim() == 2
        and out[-1].shape[0] * out[-1].shape[1]
        == num_expert * hidden_size * moe_intermediate_size
    )

    if enable_gmm_nz:
        from special_op import npu_special_slice

        num_rank = out[0].shape[0]
        total_size = sum(all_gather_input_split_sizes)

        # The last two splits are the GMM weights. NAMING FIX: per the
        # enable-check above and the resize_ shapes below, out[-1] is the
        # down_proj weight and out[-2] is the gate_up_proj weight — the
        # original named these sizes `up_size`/`down_size`, swapped.
        down_size = out[-1].shape[1]
        gate_up_size = out[-2].shape[1]

        down_start = total_size - down_size
        gate_up_start = down_start - gate_up_size

        out[-1].resize_(num_expert, moe_intermediate_size, hidden_size)
        out[-2].resize_(num_expert, hidden_size, moe_intermediate_size * 2)

        # Fused custom op: slice the GMM weight and convert it to NZ
        # format in a single pass.
        npu_special_slice(all_gather_output, dim, down_start, total_size, out[-1])
        npu_special_slice(all_gather_output, dim, gate_up_start, down_start, out[-2])

        # All remaining (non-GMM) tensors take the regular split copy.
        other_tensors = all_gather_output[:, :gate_up_start].view(num_rank, -1)
        torch.split_with_sizes_copy(
            other_tensors, all_gather_input_split_sizes[:-2], dim=dim, out=out[:-2]
        )

        return

    torch.split_with_sizes_copy(
        all_gather_output, all_gather_input_split_sizes, dim=dim, out=out
    )
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| from torch_npu import npu_special_slice | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个文件是否必要?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 必要的,需要对PTA内部copy_out实现做补丁时调用npu_special_slice的接口,这个接口本身也适配在PTA里。由于不能自调用,所以在外部做了一个import动作 |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,37 @@ | ||
| set -x | ||
|
|
||
| CANN_DIR=/usr/local/Ascend/ascend-toolkit # 默认CANN安装地址 | ||
| CANN_OPS_DIR=/tmp/cann-ops # GMM_NZ使能补丁的CANN-ops安装地址 | ||
| PTA_FSDP_DIR=/usr/local/lib/python3.11/site-packages/torch_npu/distributed/fsdp # PTA的FSDP补丁位置,方便后续替换slice算子的实现 | ||
|
|
||
| mkdir ${CANN_OPS_DIR} | ||
| ./pta_patch/CANN-custom_ops--linux.aarch64.run --install-path=${CANN_OPS_DIR} | ||
| source ${CANN_OPS_DIR}/vendors/customize/bin/set_env.bash | ||
| source ${CANN_DIR}/set_env.sh | ||
|
|
||
| # 安装PTA 2.6.0版本GMM 切K轴补丁 | ||
| pip install /path/to/torch_npu-custom.whl --force-reinstall | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Q3的torch_npu已经支持切K轴了,建议改一下install路径,或者删掉
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 该pta包是适配自定义SliceNz融合算子的包,由于后续PTA主线可能不接受该算子,此处表明需要用户在训练环境中自行安装custom包来调用该接口 |
||
cp ./pta_patch/_fsdp_collectives.py "${PTA_FSDP_DIR}"

# Enable the GMM NZ switch (read by the patched split_with_sizes_copy).
export GROUPMM_NZ_TRANSPOSE=1

export QWEN3_MOE_PATH=/path/to/qwen3_moe_weights
export ALPACA_PATH=/path/to/alpaca_dataset

export XTUNER_USE_FA3="1"
export HCCL_RDMA_TC=132

# Custom switch: 1 = enabled, 0 = disabled.
export LINEAR_ONLY_SHARD=1

# BUG FIX: LOGS_DIR was never assigned, so the original `mkdir ${LOGS_DIR}`
# invoked bare `mkdir` (usage error) and `tee` wrote to /rank_*.log.
# Default it while keeping any caller-provided value; -p tolerates re-runs.
LOGS_DIR=${LOGS_DIR:-./logs}
mkdir -p "${LOGS_DIR}"

# NOTE(review): WORLD_SIZE is passed as --nnodes and RANK as --node_rank,
# i.e. they are expected to hold node-level values here — confirm the
# launcher environment exports them with that meaning.
torchrun --nproc-per-node 16 \
    --master_addr="$MASTER_ADDR" \
    --master_port="$MASTER_PORT" \
    --nnodes="$WORLD_SIZE" \
    --node_rank="$RANK" \
    ci/scripts/test_sft_trainer_235B.py \
    "${PROF_DIR}" | tee "${LOGS_DIR}/rank_${RANK}.log"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里命名是否应该更具有人类可读性?