diff --git a/pta_patch/_fsdp_collectives.py b/pta_patch/_fsdp_collectives.py
new file mode 100755
index 000000000..49d27ec84
--- /dev/null
+++ b/pta_patch/_fsdp_collectives.py
@@ -0,0 +1,91 @@
+from typing import List, Tuple
+import os
+import torch
+
+
+lib = torch.library.Library("fsdp", "FRAGMENT")
+
+
+@torch.library.impl(lib, "chunk_cat", "PrivateUse1")
+def chunk_cat(
+    tensors: List[torch.Tensor],
+    dim: int,
+    num_chunks: int,
+    out: torch.Tensor,
+) -> None:
+    tensors = [tensor.contiguous() for tensor in tensors]
+    out = out.contiguous()
+    torch._chunk_cat(tensors, dim, num_chunks, out=out)
+
+
+@torch.library.impl(lib, "all_gather_copy_in", "PrivateUse1")
+def all_gather_copy_in_npu(
+    all_gather_inputs: List[torch.Tensor],
+    inp_split_sizes: List[int],
+    all_gather_input_numel: int,
+    world_size: int,
+    rank: int,
+    dtype: torch.dtype,
+    device: torch.device,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    all_gather_output = torch.empty(
+        (all_gather_input_numel * world_size,), dtype=dtype, device=device
+    )
+    all_gather_input = all_gather_output.narrow(
+        0, all_gather_input_numel * rank, all_gather_input_numel
+    )
+    foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes)
+    with torch.no_grad():
+        if foreach_copy_dsts[0].device == all_gather_inputs[0].device:
+            torch._foreach_copy_(foreach_copy_dsts, all_gather_inputs, non_blocking=True)
+        else:
+            torch._foreach_copy_(foreach_copy_dsts, all_gather_inputs)
+    return all_gather_input, all_gather_output
+
+
+@torch.library.impl(lib, "split_with_sizes_copy", "PrivateUse1")
+def split_with_sizes_copy(
+    all_gather_output: torch.Tensor,
+    all_gather_input_split_sizes: List[int],
+    dim: int,
+    out: List[torch.Tensor],
+    num_expert: int = 128,
+    hidden_size: int = 4096,
+    moe_intermediate_size: int = 1536,
+) -> None:
+    # Enable the gmm_nz optimization if and only if all of the following hold:
+    # 1. the GROUPMM_NZ_TRANSPOSE switch is on;
+    # 2. all_gather_input_split_sizes has more than one entry;
+    # 3. the last split output is the GMM down_proj weight.
+    enable_gmm_nz = int(os.getenv("GROUPMM_NZ_TRANSPOSE", "0")) \
+        and len(all_gather_input_split_sizes) > 1 \
+        and out[-1].shape[0] * out[-1].shape[1] == num_expert * hidden_size * moe_intermediate_size
+
+    if enable_gmm_nz:
+        from special_op import npu_special_slice
+
+        num_rank = out[0].shape[0]
+        total_size = sum(all_gather_input_split_sizes)
+
+        # The last two split outputs hold the GMM up_proj (fused gate/up) and down_proj weights.
+        up_size = out[-1].shape[1]
+        down_size = out[-2].shape[1]
+
+        up_start = total_size - up_size
+        down_start = up_start - down_size
+
+        out[-1].resize_(num_expert, moe_intermediate_size, hidden_size)
+        out[-2].resize_(num_expert, hidden_size, moe_intermediate_size * 2)
+
+        # Slice the GMM weights and convert them to NZ format with the fused operator.
+        npu_special_slice(all_gather_output, dim, up_start, total_size, out[-1])
+        npu_special_slice(all_gather_output, dim, down_start, up_start, out[-2])
+
+        other_tensors = all_gather_output[:, :down_start].view(num_rank, -1)
+        torch.split_with_sizes_copy(
+            other_tensors, all_gather_input_split_sizes[:-2], dim=dim, out=out[:-2]
+        )
+
+        return
+
+    torch.split_with_sizes_copy(
+        all_gather_output, all_gather_input_split_sizes, dim=dim, out=out
+    )
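The file above works through torch.library dispatch-key overrides: registering an implementation under the "PrivateUse1" key makes the corresponding fsdp::* op route to these functions whenever its tensors live on the backend that torch_npu claims as PrivateUse1. The toy sketch below only illustrates that registration pattern with a made-up mylib_demo::scale op and a CPU override; the namespace, op name, and schema are hypothetical and are not part of the patch.

import torch

# Hypothetical demo namespace and op, for illustrating the registration pattern only.
demo_lib = torch.library.Library("mylib_demo", "DEF")
demo_lib.define("scale(Tensor x, float alpha) -> Tensor")


@torch.library.impl(demo_lib, "scale", "CPU")
def scale_cpu(x: torch.Tensor, alpha: float) -> torch.Tensor:
    # Backend-specific kernel chosen by the dispatcher for CPU tensors; the patch
    # above registers its functions the same way, but under "PrivateUse1" (NPU)
    # and against the existing "fsdp" library opened with the "FRAGMENT" kind.
    return x * alpha


print(torch.ops.mylib_demo.scale(torch.ones(4), 2.0))  # tensor([2., 2., 2., 2.])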
diff --git a/special_op.py b/special_op.py
new file mode 100644
index 000000000..2a33ee775
--- /dev/null
+++ b/special_op.py
@@ -0,0 +1 @@
+from torch_npu import npu_special_slice
diff --git a/test_qwen3_235b_npu.sh b/test_qwen3_235b_npu.sh
new file mode 100644
index 000000000..8c4bbf53b
--- /dev/null
+++ b/test_qwen3_235b_npu.sh
@@ -0,0 +1,37 @@
+set -x
+
+CANN_DIR=/usr/local/Ascend/ascend-toolkit  # default CANN install location
+CANN_OPS_DIR=/tmp/cann-ops  # install location of the CANN-ops patch that enables GMM_NZ
+PTA_FSDP_DIR=/usr/local/lib/python3.11/site-packages/torch_npu/distributed/fsdp  # location of PTA's FSDP module, so the slice-op implementation can be replaced below
+
+mkdir ${CANN_OPS_DIR}
+./pta_patch/CANN-custom_ops--linux.aarch64.run --install-path=${CANN_OPS_DIR}
+source ${CANN_OPS_DIR}/vendors/customize/bin/set_env.bash
+source ${CANN_DIR}/set_env.sh
+
+# Install the PTA 2.6.0 patch for GMM K-axis splitting
+pip install /path/to/torch_npu-custom.whl --force-reinstall
+cp ./pta_patch/_fsdp_collectives.py ${PTA_FSDP_DIR}
+
+# Enable the GMM NZ switch
+export GROUPMM_NZ_TRANSPOSE=1
+
+export QWEN3_MOE_PATH=/path/to/qwen3_moe_weights
+export ALPACA_PATH=/path/to/alpaca_dataset
+
+export XTUNER_USE_FA3="1"
+export HCCL_RDMA_TC=132
+
+
+# User-configurable: 1 enables linear-only sharding, 0 disables it
+export LINEAR_ONLY_SHARD=1
+
+mkdir ${LOGS_DIR}
+
+torchrun --nproc-per-node 16 \
+    --master_addr=$MASTER_ADDR \
+    --master_port=$MASTER_PORT \
+    --nnodes=$WORLD_SIZE \
+    --node_rank=$RANK \
+    ci/scripts/test_sft_trainer_235B.py \
+    ${PROF_DIR} | tee ${LOGS_DIR}/rank_${RANK}.log
diff --git a/xtuner/v1/model/base.py b/xtuner/v1/model/base.py
index d558ce151..5b37a1ee1 100644
--- a/xtuner/v1/model/base.py
+++ b/xtuner/v1/model/base.py
@@ -39,7 +39,7 @@
 from xtuner.v1.utils.loader import HFCheckpointLoader
 
 from .utils import ModelForwardExtraLogInfo
-
+import os
 
 logger = get_logger()
 
@@ -229,6 +229,10 @@ def safetensors_to_params(
         start: int | None,
         end: int | None,
         dim: int | None,
+        flag: str | None,
+        num_expert: int = 128,
+        hidden_size: int = 4096,
+        moe_intermediate_size: int = 1536,
     ):
         if len(safetensors) > 1:
             assert dim is not None, "Internal Error dim must not be None when len(safetensors) > 1"
@@ -236,12 +240,23 @@ def safetensors_to_params(
         else:
             loaded_tensor = safetensors[0]
 
+        if int(os.getenv("GROUPMM_NZ_TRANSPOSE", "0")) == 1 and flag == 'fused':
+            out_feature, in_feature = safetensors[0].shape
+            out_feature = out_feature * 2 if out_feature == moe_intermediate_size else out_feature
+            loaded_tensor = loaded_tensor.view(-1, out_feature, in_feature).transpose(1, 2).contiguous().view(-1, out_feature).contiguous()
+
         if start is not None and end is not None:
             assert self.fsdp_config is not None, (
                 "Internal Error. fsdp_config must not be None when start and end is not None"
            )
             start = min(start, loaded_tensor.shape[self.FSDP_SHARD_DIM])
             end = min(end, loaded_tensor.shape[self.FSDP_SHARD_DIM])
+
+            # With GMM NZ and world_size >= num_expert, each group of 4 consecutive ranks loads one
+            # expert's fused weight; ranks 2 and 3 of the group shift their slice to the upper half of hidden_size.
+            if int(os.getenv("GROUPMM_NZ_TRANSPOSE", "0")) == 1 and torch.distributed.get_world_size() >= num_expert and loaded_tensor.shape[self.FSDP_SHARD_DIM] == hidden_size:
+                if torch.distributed.get_rank() % 4 >= 2:
+                    start += hidden_size // 2
+                    end += hidden_size // 2
+
             loaded_tensor_slice = loaded_tensor.index_select(
                 dim=self.FSDP_SHARD_DIM, index=torch.arange(start, end, dtype=torch.int64, device=loaded_tensor.device)
             )
@@ -956,12 +971,12 @@ def _load_same_hf_param(
             end = None
 
         self.safetensors_to_params(
-            [loaded_tensor], local_tensor, param_name=load_spec.name, start=start, end=end, dim=load_spec.dim
+            [loaded_tensor], local_tensor, param_name=load_spec.name, start=start, end=end, dim=load_spec.dim, flag='same',
         )
         return []
 
     def _load_fused_hf_param(
-        self, param: torch.Tensor, load_spec: LoadSpec, checkpoint_loader: HFCheckpointLoader
+        self, param: torch.Tensor, load_spec: LoadSpec, checkpoint_loader: HFCheckpointLoader, num_expert: int = 128
     ) -> list[str]:
         # For expert parallel
         # NOTE:
@@ -1004,6 +1019,10 @@ def _load_fused_hf_param(
         hf_keys_start = int(fsdp_start / hf_key_size)
         hf_keys_end = math.ceil(fsdp_end / hf_key_size)
 
+        # Only the fused gate & up weights take this path; the down weights must be excluded.
+        if int(os.getenv("GROUPMM_NZ_TRANSPOSE", "0")) == 1 and len(hf_keys) == num_expert * 2 and torch.distributed.get_world_size() >= num_expert:
+            hf_keys_start = int(torch.distributed.get_rank() // 4) * 2
+            hf_keys_end = hf_keys_start + 2
+
         # Empty pad by fsdp
         if hf_keys_start == hf_keys_end:
             return []
@@ -1038,7 +1057,7 @@ def _load_fused_hf_param(
             return missing_keys
 
         self.safetensors_to_params(
-            _loaded_tensor, local_tensor, param_name=load_spec.name, start=start, end=end, dim=load_spec.dim
+            _loaded_tensor, local_tensor, param_name=load_spec.name, start=start, end=end, dim=load_spec.dim, flag='fused',
         )
 
         return missing_keys
@@ -1084,6 +1103,7 @@ def _load_shard_hf_param(
             start=start,
             end=end,
             dim=load_spec.dim,
+            flag='shard',
         )
         return []
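To make the checkpoint re-layout above concrete, here is a small self-contained CPU check with toy sizes (not the real 128-expert / 4096 / 1536 dimensions): a tensor stacked as [num_expert * out_feature, in_feature] becomes [num_expert * in_feature, out_feature], with each expert block transposed in place. For the fused gate/up weight, out_feature is first doubled, exactly as in the branch above; this transposed stacking is the layout GroupedLinear allocates when GROUPMM_NZ_TRANSPOSE is set (see moe_group_linear.py below).

import torch

num_expert, out_feature, in_feature = 4, 6, 3  # toy sizes for illustration only
hf_tensor = torch.arange(num_expert * out_feature * in_feature, dtype=torch.float32)
hf_tensor = hf_tensor.view(num_expert * out_feature, in_feature)

# Same expression as the GROUPMM_NZ_TRANSPOSE branch in safetensors_to_params.
relayout = (
    hf_tensor.view(-1, out_feature, in_feature)
    .transpose(1, 2)
    .contiguous()
    .view(-1, out_feature)
    .contiguous()
)

assert relayout.shape == (num_expert * in_feature, out_feature)
for e in range(num_expert):
    orig_block = hf_tensor[e * out_feature:(e + 1) * out_feature]  # [out_feature, in_feature]
    new_block = relayout[e * in_feature:(e + 1) * in_feature]      # [in_feature, out_feature]
    assert torch.equal(new_block, orig_block.T)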
diff --git a/xtuner/v1/model/moe/moe.py b/xtuner/v1/model/moe/moe.py
index 41b35bc3d..449e51b14 100644
--- a/xtuner/v1/model/moe/moe.py
+++ b/xtuner/v1/model/moe/moe.py
@@ -47,7 +47,12 @@
 )
 from xtuner.v1.utils.activation_offload import async_save_on_cpu
 from xtuner.v1.utils.compile import maybe_compile
+import torch_npu
+import os
 
+if int(os.getenv("LINEAR_ONLY_SHARD", "0")) == 1:
+    from xtuner.v1.patch.fsdp_partial_shard import apply_fsdp_partial_shard_patch
+    apply_fsdp_partial_shard_patch()
 
 DEVICE = get_device()
 logger = get_logger()
@@ -167,7 +172,7 @@ def __init__(self, config: MoEConfig):
         else:
             self.z_loss = None
 
-        self.offload_stream = torch.cuda.Stream()
+        self.offload_stream = torch_npu.npu.Stream()  # was: torch.cuda.Stream()
 
     def _select_non_pad_router_logits(
         self,
@@ -688,6 +693,16 @@ def fully_shard(
         )
         num_recompute_layers = int(self.config.num_hidden_layers * self.fsdp_config.recompute_ratio)
 
+        def layer_to_fully_shard(layer):
+            # The only submodules that get their own FSDP parameter group under LINEAR_ONLY_SHARD.
+            return [
+                layer.self_attn.q_proj,
+                layer.self_attn.k_proj,
+                layer.self_attn.v_proj,
+                layer.self_attn.o_proj,
+                layer.experts.fused_w1w3,
+                layer.experts.fused_w2,
+            ]
+
         for layer_idx, layer in tqdm(self.layers.items(), desc="[FSDP Sharding]"):
             layer_idx = int(layer_idx)
             if layer_idx < num_recompute_layers - 1:
@@ -698,19 +713,26 @@ def fully_shard(
                 reshard_after_forward = False
             else:
                 reshard_after_forward = self.fsdp_config.reshard_after_forward
+            is_linear_only_shard = int(os.getenv("LINEAR_ONLY_SHARD", "0")) == 1
             fully_shard(
-                layer,
+                layer_to_fully_shard(layer) if is_linear_only_shard else layer,
                 mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh,
                 mp_policy=mp_policy,
                 reshard_after_forward=reshard_after_forward,
                 offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None,
+                **({'hook_module': layer} if is_linear_only_shard else {}),
             )
 
         for layer_cur, layer_next in zip(
             list(self.layers.values())[:-1],
             list(self.layers.values())[1:],
         ):
-            layer_cur.set_modules_to_forward_prefetch([layer_next])  # type: ignore
+            if not is_linear_only_shard:
+                layer_cur.set_modules_to_forward_prefetch([layer_next])  # type: ignore
+            else:
+                layer_cur_fully_shard_modules = layer_to_fully_shard(layer_cur)
+                layer_next_fully_shard_modules = layer_to_fully_shard(layer_next)
+                layer_cur_fully_shard_modules[0].set_modules_to_forward_prefetch([layer_next_fully_shard_modules[0]])
 
         fully_shard(
             self.embed_tokens,
diff --git a/xtuner/v1/module/grouped_linear/moe_group_linear.py b/xtuner/v1/module/grouped_linear/moe_group_linear.py
index d23331d94..93901683e 100644
--- a/xtuner/v1/module/grouped_linear/moe_group_linear.py
+++ b/xtuner/v1/module/grouped_linear/moe_group_linear.py
@@ -7,6 +7,9 @@
 from xtuner.v1.float8.float8_gmm_tile_wise import TileWiseFloat8GroupedLinear
 from xtuner.v1.ops import group_gemm
 
+from torch.autograd import Function
+import torch_npu
+import os
 
 class GroupedLinear(nn.Module):
     # TODO:Missng example docs
@@ -22,7 +25,11 @@ def __init__(
         self.in_features = in_features
         self.out_features = out_features
         self.num_routed_experts = num_routed_experts
-        weight = torch.empty(num_routed_experts * out_features, in_features)
+        # Training-time NZ switch; note that it only brings a performance gain on NPU.
+        if int(os.getenv("GROUPMM_NZ_TRANSPOSE", "0")):
+            weight = torch.empty(num_routed_experts * in_features, out_features)
+        else:
+            weight = torch.empty(num_routed_experts * out_features, in_features)
         self.ep_mesh = ep_mesh
 
         if self.ep_mesh is not None and self.ep_mesh.size() > 1:
@@ -40,8 +47,13 @@ def __init__(
 
     def forward(self, x: torch.Tensor, tokens_per_expert: torch.Tensor, decoding: bool = False):
         weight = self.weight.to_local() if isinstance(self.weight, DTensor) else self.weight
-        weight = weight.view(-1, self.out_features, self.in_features)
-        out = group_gemm(x, weight, tokens_per_expert)
+
+        if int(os.getenv("GROUPMM_NZ_TRANSPOSE", "0")):
+            weight = weight.view(-1, self.in_features, self.out_features)
+            out = NpuGMMOp.apply(weight, x, tokens_per_expert)
+        else:
+            weight = weight.view(-1, self.out_features, self.in_features)
+            out = group_gemm(x, weight, tokens_per_expert)
 
         if self.moe_bias:
             bias = self.bias.to_local() if isinstance(self.bias, DTensor) else self.bias
@@ -49,6 +61,33 @@ def forward(self, x: torch.Tensor, tokens_per_expert: torch.Tensor, decoding: bo
         return out
 
 
+class NpuGMMOp(Function):
+    @staticmethod
+    def forward(ctx, weight, x, tokens_per_expert):
+        if not int(os.getenv("GROUPMM_NZ_TRANSPOSE", "0")):
+            weight = torch.transpose(weight, 1, 2)
+        ctx.save_for_backward(weight, x, tokens_per_expert)
+        outs = torch_npu.npu_grouped_matmul(
+            [x], [weight], group_list=tokens_per_expert, group_type=0, group_list_type=1, split_item=2
+        )
+        return outs[0]
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        weight, input_tensor, tokens_per_expert = ctx.saved_tensors
+        weight = torch.transpose(weight, 1, 2)
+        grad_input = torch_npu.npu_grouped_matmul(
+            [grad_output], [weight], group_list=tokens_per_expert,
+            group_type=0, group_list_type=1, split_item=2
+        )[0]
+        grad_weight = torch_npu.npu_grouped_matmul(
+            [input_tensor.T], [grad_output], bias=None, group_list=tokens_per_expert,
+            split_item=3, group_type=2, group_list_type=1
+        )[0]
+        if not int(os.getenv("GROUPMM_NZ_TRANSPOSE", "0")):
+            grad_weight = torch.transpose(grad_weight, 1, 2)
+        return grad_weight, grad_input, None
+
+
 def build_grouped_linear(
     in_features: int,
     out_features: int,
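For reference, this is what the grouped matmul inside NpuGMMOp computes in the forward pass, written with plain CPU ops under my reading of the patch: tokens arrive grouped contiguously by expert, tokens_per_expert holds per-expert counts, and each group is multiplied by that expert's [in_features, out_features] weight. It is only a semantic sketch for shapes and grouping; it does not reproduce torch_npu.npu_grouped_matmul, its group_list/split_item arguments, or the NZ weight format.

import torch


def grouped_matmul_reference(x, weight, tokens_per_expert):
    # x: [total_tokens, in_features], rows grouped contiguously by expert
    # weight: [num_expert, in_features, out_features]
    # tokens_per_expert: per-expert token counts summing to total_tokens
    outs, start = [], 0
    for expert_id, count in enumerate(tokens_per_expert.tolist()):
        outs.append(x[start:start + count] @ weight[expert_id])  # [count, out_features]
        start += count
    return torch.cat(outs, dim=0)  # [total_tokens, out_features]


# Toy shapes, illustration only.
num_expert, in_features, out_features = 3, 4, 5
tokens_per_expert = torch.tensor([2, 0, 3])
x = torch.randn(int(tokens_per_expert.sum()), in_features)
weight = torch.randn(num_expert, in_features, out_features)
out = grouped_matmul_reference(x, weight, tokens_per_expert)
assert out.shape == (int(tokens_per_expert.sum()), out_features)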
diff --git a/xtuner/v1/patch/fsdp_partial_shard.py b/xtuner/v1/patch/fsdp_partial_shard.py
new file mode 100755
index 000000000..a462c4656
--- /dev/null
+++ b/xtuner/v1/patch/fsdp_partial_shard.py
@@ -0,0 +1,148 @@
+from typing import (
+    Callable,
+    List,
+    Optional,
+    Union,
+    Tuple,
+)
+
+import torch
+import torch.nn as nn
+
+from torch.distributed._composable import contract
+from torch.distributed._composable_state import _insert_module_state
+
+from torch.distributed.tensor import DeviceMesh, Shard
+from torch.distributed.utils import _get_root_modules
+from torch.distributed.device_mesh import _get_device_handle
+
+from torch.distributed.fsdp._fully_shard._fsdp_api import MixedPrecisionPolicy, OffloadPolicy
+from torch.distributed.fsdp._fully_shard._fsdp_common import FSDPMeshInfo, HSDPMeshInfo
+from torch.distributed.fsdp._fully_shard._fsdp_init import (
+    _get_device_from_mesh,
+    _get_managed_modules,
+    _get_managed_states,
+    _get_post_forward_mesh_info,
+    _init_default_fully_shard_mesh,
+    _move_states_to_device,
+)
+from torch.distributed.fsdp._fully_shard._fsdp_param_group import FSDPParamGroup
+from torch.distributed.fsdp._fully_shard._fsdp_state import FSDPState, _register_group_forward_hooks
+from torch.distributed.fsdp._fully_shard._fully_shard import (
+    cls_to_fsdp_cls,
+    _unimplemented_deepcopy,
+    FSDPModule,
+)
+
+
+@contract(state_cls=FSDPState)  # type: ignore[operator]
+def fully_shard(
+    module: Union[nn.Module, List[nn.Module]],
+    *,
+    hook_module: Optional[nn.Module] = None,
+    mesh: Optional[DeviceMesh] = None,
+    reshard_after_forward: Union[bool, int] = True,
+    shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]] = None,
+    mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(),
+    offload_policy: OffloadPolicy = OffloadPolicy(),
+):
+    if isinstance(module, (nn.ModuleList, nn.ModuleDict)):
+        raise ValueError(
+            f"fully_shard does not support containers that do not implement forward: {module}"
+        )
+    mesh = mesh or _init_default_fully_shard_mesh()
+    if mesh.ndim not in (1, 2):
+        raise ValueError(f"fully_shard expects a 1D or 2D DeviceMesh but got {mesh}")
+    elif mesh.ndim == 1:
+        mesh_info = FSDPMeshInfo(mesh, shard_mesh_dim=0)
+    else:
+        if mesh.mesh_dim_names is None:
+            raise AssertionError(
+                "Please init the 2D mesh for HSDP with mesh_dim_names specified"
+            )
+        mesh_info = HSDPMeshInfo(mesh, shard_mesh_dim=1, replicate_mesh_dim=0)
+    device = _get_device_from_mesh(mesh)
+    post_forward_mesh_info = _get_post_forward_mesh_info(
+        reshard_after_forward, mesh_info
+    )
+
+    arg_module = module
+    modules = (
+        (module,) if isinstance(module, nn.Module) else tuple(_get_root_modules(module))
+    )
+    state = fully_shard.state(modules[0])
+    state.init(modules, device, mp_policy, hook_module=hook_module)
+
+    managed_modules = _get_managed_modules(modules)
+    params, buffers = _get_managed_states(managed_modules)
+    _move_states_to_device(params, buffers, device)
+    if params:
+        state._fsdp_param_group = FSDPParamGroup(
+            params,
+            modules,
+            mesh_info,
+            post_forward_mesh_info,
+            device,
+            shard_placement_fn,
+            mp_policy,
+            offload_policy,
+        )
+
+    # For Dynamo
+    for managed_module in managed_modules:
+        managed_module._is_fsdp_managed_module = True  # type: ignore[assignment]
+        managed_module._fsdp_use_orig_params = True  # type: ignore[assignment]
+
+    # Place FSDP leftmost for highest priority in the method resolution order
+    for module in modules:
+        cls = module.__class__
+        new_cls = cls_to_fsdp_cls.get(cls, None)
+        if not new_cls:
+            dct = {"__deepcopy__": _unimplemented_deepcopy}
+            new_cls = type(f"FSDP{cls.__name__}", (FSDPModule, cls), dct)
+            cls_to_fsdp_cls[cls] = new_cls
+        module.__class__ = new_cls
+    return arg_module
+
+
+def fsdp_state_init(
+    self,
+    modules: Tuple[nn.Module, ...],
+    device: torch.device,
+    mp_policy: MixedPrecisionPolicy,
+    hook_module: Optional[nn.Module] = None,
+) -> None:
+    for module in modules:
+        _insert_module_state(module, self)
+    self._modules = modules
+    self._device = device
+    self._device_handle = _get_device_handle(device.type)
+    self._mp_policy = mp_policy
+
+    if len(modules) == 1:
+        # Same as the stock FSDPState.init: hook the single sharded module itself.
+        self._pre_forward_hook_handle = modules[0].register_forward_pre_hook(
+            self._pre_forward, prepend=True, with_kwargs=True
+        )
+        self._post_forward_hook_handle = modules[0].register_forward_hook(
+            self._post_forward, prepend=False
+        )
+    elif hook_module is not None:
+        # New behaviour: several modules share one parameter group, but unshard/reshard
+        # is driven by the forward of a single parent module (e.g. the decoder layer).
+        self._pre_forward_hook_handle = hook_module.register_forward_pre_hook(
+            self._pre_forward, prepend=True, with_kwargs=True
+        )
+        self._post_forward_hook_handle = hook_module.register_forward_hook(
+            self._post_forward, prepend=False
+        )
+    else:
+        hook_handle = _register_group_forward_hooks(
+            modules,
+            self._pre_forward,
+            self._post_forward,
+            self._modules_to_run_forward,
+        )
+        self._pre_forward_hook_handle = hook_handle
+        self._post_forward_hook_handle = hook_handle
+
+
+def apply_fsdp_partial_shard_patch():
+    torch.distributed.fsdp.fully_shard = fully_shard
+    torch.distributed.fsdp._fully_shard._fsdp_state.FSDPState.init = fsdp_state_init
\ No newline at end of file
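A hedged sketch of how this patch is meant to be driven, mirroring the LINEAR_ONLY_SHARD path added to moe.py earlier in this diff: only the large linear submodules get their own FSDP parameter group, while the unshard/reshard hooks are installed on the enclosing decoder layer passed as hook_module. The submodule names follow the xtuner MoE decoder layer used above; treat this as a call-pattern sketch (it assumes an initialized process group, a device mesh, and a real layer object), not a standalone script.

import torch.distributed.fsdp as dist_fsdp
from torch.distributed.fsdp import MixedPrecisionPolicy
from xtuner.v1.patch.fsdp_partial_shard import apply_fsdp_partial_shard_patch

apply_fsdp_partial_shard_patch()  # swaps in the fully_shard / FSDPState.init defined above


def shard_layer_linears(layer, mesh, mp_policy: MixedPrecisionPolicy):
    # Parameter groups are built only from the big linear weights ...
    linears = [
        layer.self_attn.q_proj,
        layer.self_attn.k_proj,
        layer.self_attn.v_proj,
        layer.self_attn.o_proj,
        layer.experts.fused_w1w3,
        layer.experts.fused_w2,
    ]
    # ... while the pre-/post-forward hooks that trigger all-gather and reshard
    # fire once per decoder layer instead of once per linear.
    dist_fsdp.fully_shard(
        linears,
        mesh=mesh,
        mp_policy=mp_policy,
        hook_module=layer,
    )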