Merged
91 changes: 91 additions & 0 deletions pta_patch/_fsdp_collectives.py
@@ -0,0 +1,91 @@
from typing import List, Tuple
import os
import torch


lib = torch.library.Library("fsdp", "FRAGMENT")


@torch.library.impl(lib, "chunk_cat", "PrivateUse1")
def chunk_cat(
tensors: List[torch.Tensor],
dim: int,
num_chunks: int,
out: torch.Tensor,
) -> None:
tensors = [tensor.contiguous() for tensor in tensors]
out = out.contiguous()
torch._chunk_cat(tensors, dim, num_chunks, out=out)


@torch.library.impl(lib, "all_gather_copy_in", "PrivateUse1")
def all_gather_copy_in_npu(
all_gather_inputs: List[torch.Tensor],
inp_split_sizes: List[int],
all_gather_input_numel: int,
world_size: int,
rank: int,
dtype: torch.dtype,
device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
all_gather_output = torch.empty(
(all_gather_input_numel * world_size,), dtype=dtype, device=device
)
all_gather_input = all_gather_output.narrow(
0, all_gather_input_numel * rank, all_gather_input_numel
)
foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
with torch.no_grad():
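# Use a non-blocking foreach copy only when the inputs already live on the same device as the gather buffer.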
if foreach_copy_dsts[0].device == all_gather_inputs[0].device:
torch._foreach_copy_(foreach_copy_dsts, all_gather_inputs, non_blocking=True)
else:
torch._foreach_copy_(foreach_copy_dsts, all_gather_inputs)
return all_gather_input, all_gather_output

@torch.library.impl(lib, "split_with_sizes_copy", "PrivateUse1")
def split_with_sizes_copy(
all_gather_output: torch.Tensor,
all_gather_input_split_sizes: List[int],
dim: int,
out: List[torch.Tensor],
num_expert: int = 128,
hidden_size: int = 4096,
moe_intermediate_size: int = 1536,
) -> None:
# Enable the gmm_nz optimization if and only if all of the following hold:
# 1. the GROUPMM_NZ_TRANSPOSE switch is on,
# 2. all_gather_input_split_sizes has more than one entry,
# 3. the last weight after splitting is the GMM down_proj weight.
enable_gmm_nz = int(os.getenv("GROUPMM_NZ_TRANSPOSE","0")) \
and len(all_gather_input_split_sizes) > 1 \
and out[-1].shape[0] * out[-1].shape[1] == num_expert * hidden_size * moe_intermediate_size

if enable_gmm_nz:
from special_op import npu_special_slice
num_rank = out[0].shape[0]
total_size = sum(all_gather_input_split_sizes)

# The last two weights after splitting are the GMM up_proj and down_proj weights.
up_size = out[-1].shape[1]
down_size = out[-2].shape[1]

up_start = total_size - up_size
down_start = up_start - down_size

out[-1].resize_(num_expert, moe_intermediate_size, hidden_size)
out[-2].resize_(num_expert, hidden_size, moe_intermediate_size * 2)

# Slice the GMM weights and convert them to NZ format with the fused operator.
npu_special_slice(all_gather_output, dim, up_start, total_size, out[-1])

Reviewer comment: Should the naming here be more human-readable?

npu_special_slice(all_gather_output, dim, down_start, up_start, out[-2])

other_tensors = all_gather_output[:, :down_start].view(num_rank, -1)
torch.split_with_sizes_copy(
other_tensors, all_gather_input_split_sizes[:-2], dim=dim, out=out[:-2]
)

return

torch.split_with_sizes_copy(
all_gather_output, all_gather_input_split_sizes, dim=dim, out=out
)
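For reference, the fallback branch above is the stock copy-out: each split of the gathered flat buffer is copied into the corresponding pre-allocated output tensor, and the NZ branch differs only in producing the last two splits through the fused npu_special_slice. A minimal CPU sketch of that baseline behavior (shapes and split sizes are made up for illustration):

import torch

gathered = torch.arange(12.0).reshape(2, 6)      # stand-in for all_gather_output
split_sizes = [2, 4]                             # stand-in for all_gather_input_split_sizes
outs = [torch.empty(2, s) for s in split_sizes]  # pre-allocated destination tensors
torch.split_with_sizes_copy(gathered, split_sizes, dim=1, out=outs)
assert torch.equal(outs[1], gathered[:, 2:])     # columns 2..5 were copied into the second output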
1 change: 1 addition & 0 deletions special_op.py
@@ -0,0 +1 @@
from torch_npu import npu_special_slice

Reviewer comment: Is this file necessary?

Author reply: It is necessary. The patch to PTA's internal copy_out implementation needs to call the npu_special_slice interface, and that interface is itself adapted inside PTA. Because it cannot call itself, the import is done through this external file.

37 changes: 37 additions & 0 deletions test_qwen3_235b_npu.sh
@@ -0,0 +1,37 @@
set -x

CANN_DIR=/usr/local/Ascend/ascend-toolkit # default CANN installation path
CANN_OPS_DIR=/tmp/cann-ops # installation path of the CANN-ops package carrying the GMM_NZ enablement patch
PTA_FSDP_DIR=/usr/local/lib/python3.11/site-packages/torch_npu/distributed/fsdp # PTA FSDP patch location, so the slice operator implementation can be replaced below

mkdir ${CANN_OPS_DIR}
./pta_patch/CANN-custom_ops--linux.aarch64.run --install-path=${CANN_OPS_DIR}
source ${CANN_OPS_DIR}/vendors/customize/bin/set_env.bash
source ${CANN_DIR}/set_env.sh

# Install the PTA 2.6.0 GMM K-axis-split patch
pip install /path/to/torch_npu-custom.whl --force-reinstall

Reviewer comment: The Q3 torch_npu already supports K-axis splitting; consider updating the install path or removing this step.

Author reply: This torch_npu package is the one adapted for the custom SliceNz fused operator. Since the PTA mainline may not accept this operator later, this step indicates that users need to install the custom package in their own training environment in order to call the interface.

cp ./pta_patch/_fsdp_collectives.py ${PTA_FSDP_DIR}

# Enable the GMM NZ switch
export GROUPMM_NZ_TRANSPOSE=1

export QWEN3_MOE_PATH=/path/to/qwen3_moe_weights
export ALPACA_PATH=/path/to/alpaca_dataset

export XTUNER_USE_FA3="1"
export HCCL_RDMA_TC=132


# Custom switch: 1 enables, 0 disables
export LINEAR_ONLY_SHARD=1

mkdir ${LOGS_DIR}

torchrun --nproc-per-node 16 \
--master_addr=$MASTER_ADDR \
--master_port=$MASTER_PORT \
--nnodes=$WORLD_SIZE \
--node_rank=$RANK \
ci/scripts/test_sft_trainer_235B.py \
${PROF_DIR} | tee ${LOGS_DIR}/rank_${RANK}.log
28 changes: 24 additions & 4 deletions xtuner/v1/model/base.py
@@ -39,7 +39,7 @@
from xtuner.v1.utils.loader import HFCheckpointLoader

from .utils import ModelForwardExtraLogInfo

import os

logger = get_logger()

@@ -229,19 +229,34 @@ def safetensors_to_params(
start: int | None,
end: int | None,
dim: int | None,
flag: str | None,
num_expert: int = 128,
hidden_size: int = 4096,
moe_intermediate_size: int = 1536,
):
if len(safetensors) > 1:
assert dim is not None, "Internal Error dim must not be None when len(safetensors) > 1"
loaded_tensor = torch.cat(safetensors, dim=dim)
else:
loaded_tensor = safetensors[0]

if int(os.getenv("GROUPMM_NZ_TRANSPOSE","0")) == 1 and flag == 'fused':
out_feature, in_feature = safetensors[0].shape
out_feature = out_feature * 2 if out_feature == moe_intermediate_size else out_feature
loaded_tensor = loaded_tensor.view(-1, out_feature, in_feature).transpose(1,2).contiguous().view(-1, out_feature).contiguous()

if start is not None and end is not None:
assert self.fsdp_config is not None, (
"Internal Error. fsdp_config must not be None when start and end is not None"
)
start = min(start, loaded_tensor.shape[self.FSDP_SHARD_DIM])
end = min(end, loaded_tensor.shape[self.FSDP_SHARD_DIM])

if int(os.getenv("GROUPMM_NZ_TRANSPOSE","0")) == 1 and torch.distributed.get_world_size() >= num_expert and loaded_tensor.shape[self.FSDP_SHARD_DIM] == hidden_size:
if torch.distributed.get_rank() % 4 >= 2:
start += hidden_size // 2
end += hidden_size // 2

loaded_tensor_slice = loaded_tensor.index_select(
dim=self.FSDP_SHARD_DIM, index=torch.arange(start, end, dtype=torch.int64, device=loaded_tensor.device)
)
@@ -956,12 +971,12 @@ def _load_same_hf_param(
end = None

self.safetensors_to_params(
[loaded_tensor], local_tensor, param_name=load_spec.name, start=start, end=end, dim=load_spec.dim
[loaded_tensor], local_tensor, param_name=load_spec.name, start=start, end=end, dim=load_spec.dim, flag='same',
)
return []

def _load_fused_hf_param(
self, param: torch.Tensor, load_spec: LoadSpec, checkpoint_loader: HFCheckpointLoader
self, param: torch.Tensor, load_spec: LoadSpec, checkpoint_loader: HFCheckpointLoader, num_expert: int = 128
) -> list[str]:
# For expert parallel
# NOTE:
@@ -1004,6 +1019,10 @@ def _load_fused_hf_param(
hf_keys_start = int(fsdp_start / hf_key_size)
hf_keys_end = math.ceil(fsdp_end / hf_key_size)

if int(os.getenv("GROUPMM_NZ_TRANSPOSE","0")) == 1 and len(hf_keys) == num_expert *2 and torch.distributed.get_world_size() >= num_expert: # gate & up的情况,down的情况需要排除
hf_keys_start = int(torch.distributed.get_rank() // 4) * 2
hf_keys_end = hf_keys_start + 2

# Empty pad by fsdp
if hf_keys_start == hf_keys_end:
return []
@@ -1038,7 +1057,7 @@ def _load_fused_hf_param(
return missing_keys

self.safetensors_to_params(
_loaded_tensor, local_tensor, param_name=load_spec.name, start=start, end=end, dim=load_spec.dim
_loaded_tensor, local_tensor, param_name=load_spec.name, start=start, end=end, dim=load_spec.dim, flag='fused',
)
return missing_keys

@@ -1084,6 +1103,7 @@ def _load_shard_hf_param(
start=start,
end=end,
dim=load_spec.dim,
flag='shard',
)
return []
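To make the NZ load-time transpose in safetensors_to_params concrete, here is a small CPU sketch of the shape math (expert count and feature sizes are made up, and the gate/up doubling step is omitted):

import torch

num_experts, out_feature, in_feature = 4, 6, 8
hf = torch.randn(num_experts * out_feature, in_feature)     # HF layout: experts stacked along dim 0, each (out, in)
nz = hf.view(-1, out_feature, in_feature).transpose(1, 2).contiguous().view(-1, out_feature)
assert nz.shape == (num_experts * in_feature, out_feature)  # matches the (E * in, out) weight GroupedLinear allocates under GROUPMM_NZ_TRANSPOSE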

28 changes: 25 additions & 3 deletions xtuner/v1/model/moe/moe.py
@@ -47,7 +47,12 @@
)
from xtuner.v1.utils.activation_offload import async_save_on_cpu
from xtuner.v1.utils.compile import maybe_compile
import torch_npu
import os

if int(os.getenv("LINEAR_ONLY_SHARD", "0")) == 1:
from xtuner.v1.patch.fsdp_partial_shard import apply_fsdp_partial_shard_patch
apply_fsdp_partial_shard_patch()

DEVICE = get_device()
logger = get_logger()
@@ -167,7 +172,7 @@ def __init__(self, config: MoEConfig):
else:
self.z_loss = None

self.offload_stream = torch.cuda.Stream()
self.offload_stream = torch_npu.npu.Stream()  # was torch.cuda.Stream()

def _select_non_pad_router_logits(
self,
@@ -688,6 +693,16 @@ def fully_shard(
)
num_recompute_layers = int(self.config.num_hidden_layers * self.fsdp_config.recompute_ratio)

# When LINEAR_ONLY_SHARD is enabled, only these attention projections and fused expert weights
# are handed to fully_shard; the rest of the decoder layer stays unsharded.
def layer_to_fully_shard(layer):
return [
layer.self_attn.q_proj,
layer.self_attn.k_proj,
layer.self_attn.v_proj,
layer.self_attn.o_proj,
layer.experts.fused_w1w3,
layer.experts.fused_w2,
]

for layer_idx, layer in tqdm(self.layers.items(), desc="[FSDP Sharding]"):
layer_idx = int(layer_idx)
if layer_idx < num_recompute_layers - 1:
@@ -698,19 +713,26 @@ def fully_shard(
reshard_after_forward = False
else:
reshard_after_forward = self.fsdp_config.reshard_after_forward
is_linear_only_shard = int(os.getenv("LINEAR_ONLY_SHARD", "0")) == 1
fully_shard(
layer,
layer_to_fully_shard(layer) if is_linear_only_shard else layer,
mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh,
mp_policy=mp_policy,
reshard_after_forward=reshard_after_forward,
offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None,
**({'hook_module': layer} if is_linear_only_shard else {}),
)

for layer_cur, layer_next in zip(
list(self.layers.values())[:-1],
list(self.layers.values())[1:],
):
layer_cur.set_modules_to_forward_prefetch([layer_next]) # type: ignore
if not is_linear_only_shard:
layer_cur.set_modules_to_forward_prefetch([layer_next]) # type: ignore
else:
# With linear-only sharding, forward prefetch is wired between the first fully-sharded submodule of consecutive layers.
layer_cur_fully_shard_modules = layer_to_fully_shard(layer_cur)
layer_next_fully_shard_modules = layer_to_fully_shard(layer_next)
layer_cur_fully_shard_modules[0].set_modules_to_forward_prefetch([layer_next_fully_shard_modules[0]])

fully_shard(
self.embed_tokens,
45 changes: 42 additions & 3 deletions xtuner/v1/module/grouped_linear/moe_group_linear.py
@@ -7,6 +7,9 @@
from xtuner.v1.float8.float8_gmm_tile_wise import TileWiseFloat8GroupedLinear
from xtuner.v1.ops import group_gemm

from torch.autograd import Function
import torch_npu
import os

class GroupedLinear(nn.Module):
# TODO: Missing example docs
@@ -22,7 +25,11 @@ def __init__(
self.in_features = in_features
self.out_features = out_features
self.num_routed_experts = num_routed_experts
weight = torch.empty(num_routed_experts * out_features, in_features)
# NZ training switch; note that it only brings a performance gain on NPU.
# With NZ enabled, each expert's weight is stored as (in_features, out_features) instead of (out_features, in_features).
if int(os.getenv("GROUPMM_NZ_TRANSPOSE","0")):
weight = torch.empty(num_routed_experts * in_features, out_features)
else:
weight = torch.empty(num_routed_experts * out_features, in_features)

self.ep_mesh = ep_mesh
if self.ep_mesh is not None and self.ep_mesh.size() > 1:
@@ -40,15 +47,47 @@ def __init__(

def forward(self, x: torch.Tensor, tokens_per_expert: torch.Tensor, decoding: bool = False):
weight = self.weight.to_local() if isinstance(self.weight, DTensor) else self.weight
weight = weight.view(-1, self.out_features, self.in_features)
out = group_gemm(x, weight, tokens_per_expert)

if int(os.getenv("GROUPMM_NZ_TRANSPOSE","0")):
weight = weight.view(-1, self.in_features, self.out_features)
out = NpuGMMOp.apply(weight, x, tokens_per_expert)
else:
weight = weight.view(-1, self.out_features, self.in_features)
out = group_gemm(x, weight, tokens_per_expert)

if self.moe_bias:
bias = self.bias.to_local() if isinstance(self.bias, DTensor) else self.bias
out = out + bias.repeat_interleave(tokens_per_expert, dim=0) # TODO: cannot be compiled
return out


class NpuGMMOp(Function):
@staticmethod
def forward(ctx, weight, x, tokens_per_expert):
if not int(os.getenv("GROUPMM_NZ_TRANSPOSE","0")):
weight = torch.transpose(weight, 1, 2)
ctx.save_for_backward(weight, x, tokens_per_expert)
outs = torch_npu.npu_grouped_matmul([x], [weight], group_list = tokens_per_expert, group_type = 0, group_list_type = 1, split_item = 2)
return outs[0]


@staticmethod
def backward(ctx, grad_output):
tensors = ctx.saved_tensors
weight = tensors[0]
input_tensor = tensors[1]
tokens_per_expert = tensors[2]
weight = torch.transpose(weight, 1, 2)
grad_input = torch_npu.npu_grouped_matmul([grad_output], [weight], group_list = tokens_per_expert,
group_type = 0, group_list_type = 1, split_item=2)[0]
grad_weight = torch_npu.npu_grouped_matmul([input_tensor.T], [grad_output], bias=None, group_list = tokens_per_expert,
split_item=3, group_type=2, group_list_type=1)[0]
if not int(os.getenv("GROUPMM_NZ_TRANSPOSE","0")):
grad_weight = torch.transpose(grad_weight, 1, 2)
return grad_weight, grad_input, None
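As a sanity reference for what NpuGMMOp's forward computes, here is a plain-PyTorch per-expert loop with the same semantics (assuming, as the call sites above suggest, that group_list carries per-expert token counts and that the result is a single stacked output; shapes below are illustrative):

import torch

def grouped_matmul_reference(x, weight, tokens_per_expert):
    # x: (total_tokens, in_features); weight: (num_experts, in_features, out_features), i.e. the NZ/transposed layout
    outs, offset = [], 0
    for e, n in enumerate(tokens_per_expert.tolist()):
        outs.append(x[offset:offset + n] @ weight[e])  # tokens routed to expert e
        offset += n
    return torch.cat(outs, dim=0)

x = torch.randn(7, 8)
w = torch.randn(3, 8, 5)                  # 3 experts, in_features=8, out_features=5
tokens = torch.tensor([2, 4, 1])          # per-expert token counts, summing to 7
print(grouped_matmul_reference(x, w, tokens).shape)  # torch.Size([7, 5])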



def build_grouped_linear(
in_features: int,
out_features: int,