diff --git a/deepspeed/moe/experts.py b/deepspeed/moe/experts.py index 8cadb0c387fa..e3f7a794d308 100644 --- a/deepspeed/moe/experts.py +++ b/deepspeed/moe/experts.py @@ -33,3 +33,7 @@ def forward(self, inputs): expert_output = torch.cat(expert_outputs, dim=1) return expert_output + + def forward_single(self, idx, inputs): + expert_output = self.deepspeed_experts[idx](inputs) + return expert_output \ No newline at end of file diff --git a/deepspeed/moe/layer.py b/deepspeed/moe/layer.py index 89fe2bb46c3c..7d6bdc6963e9 100644 --- a/deepspeed/moe/layer.py +++ b/deepspeed/moe/layer.py @@ -12,6 +12,77 @@ from .experts import Experts import typing +from fmoe import FMoE +import fmoe +import copy +class VitFMoE(FMoE): + def __init__( + self, expert, d_model=1, top_k=1, num_local_experts=1, expert_group_name=None, moe_group=None, gate_kwargs=None + ): + assert expert_group_name is not None, 'expert_group_name should be provided' + world_size = torch.distributed.get_world_size(moe_group) + super().__init__( + num_expert=num_local_experts, + d_model=d_model, + world_size=world_size, + mp_group=None, + top_k=top_k, + moe_group=moe_group, + gate=fmoe.gates.GShardGate, + gate_kwargs=gate_kwargs if gate_kwargs is not None else {} + ) + + self.experts = torch.nn.ModuleList([copy.deepcopy(expert) for i in range(num_local_experts)]) + self.experts_fused = False + self.num_experts = num_local_experts + self.num_local_experts = num_local_experts + + for local_expert in self.experts: + for name, param in local_expert.named_parameters(): + param.allreduce = False + param.group_name = expert_group_name + + def expert_fn(self, inp, fwd_expert_count): + r""" + The default expert function which either calls the experts as a whole + or as separate experts. 
+ """ + if self.experts_fused: + return self.experts(inp, fwd_expert_count) + # guard against non-tensor counts so fwd_expert_count_cpu is always bound + fwd_expert_count_cpu = fwd_expert_count.cpu().numpy() if isinstance(fwd_expert_count, torch.Tensor) else fwd_expert_count + outputs = [] + base_idx = 0 + for i in range(self.num_experts): + batch_size = fwd_expert_count_cpu[i] + inp_slice = inp[base_idx : base_idx + batch_size] + # outputs.append(self.experts[i](inp_slice, torch.tensor([fwd_expert_count[i]]))) + outputs.append(self.experts[i](inp_slice)) + # outputs.append(self.experts.forward_single(i, inp_slice)) + base_idx += batch_size + return torch.cat(outputs, dim=0) + + def expert_fn_single(self, inp, fwd_expert_count, idx): + r""" + Forward a single expert, used by smart scheduling. + """ + assert not self.experts_fused, "should not use fused experts" + output = self.experts[idx](inp) + # output = self.experts.forward_single(idx, inp) + return output + + def _set_ep_group(self, group): + self.moe_group = group + + def forward(self, inp: torch.Tensor): + r""" + Flatten the input to (tokens, d_model), run the FMoE forward pass + and restore the original shape. + """ + original_shape = inp.shape + inp = inp.reshape(-1, self.d_model) + output = super().forward(inp) + return output.reshape(original_shape) class MoE(torch.nn.Module): """Initialize an MoE layer. 
@@ -47,7 +118,8 @@ def __init__(self, drop_tokens: bool = True, use_rts=True, use_tutel: bool = False, - enable_expert_tensor_parallelism: bool = False): + enable_expert_tensor_parallelism: bool = False, + use_fmoe: bool = False): super(MoE, self).__init__() @@ -58,6 +130,7 @@ def __init__(self, self.expert_group_name = f"ep_size_{self.ep_size}" self.num_experts = num_experts self.num_local_experts = num_experts // self.ep_size + self.use_fmoe = use_fmoe log_dist( f'Creating MoE layer with num_experts: {num_experts} | num_local_experts: {self.num_local_experts} | expert_parallel_size: {self.ep_size}', @@ -67,13 +140,23 @@ def __init__(self, 'Unsupported noisy_gate_policy: ' + noisy_gate_policy experts = Experts(expert, self.num_local_experts, self.expert_group_name) - self.deepspeed_moe = MOELayer(TopKGate(hidden_size, num_experts, k, capacity_factor, eval_capacity_factor, - min_capacity, noisy_gate_policy, drop_tokens, use_rts), - experts, - self.expert_group_name, - self.ep_size, - self.num_local_experts, - use_tutel=use_tutel) + if not use_fmoe: + self.deepspeed_moe = MOELayer(TopKGate(hidden_size, num_experts, k, capacity_factor, eval_capacity_factor, + min_capacity, noisy_gate_policy, drop_tokens, use_rts), + experts, + self.expert_group_name, + self.ep_size, + self.num_local_experts, + use_tutel=use_tutel) + else: + # Note: need to setup groups for moe before creating MOE layers + ep_group = groups._get_expert_parallel_group(self.expert_group_name) + self.deepspeed_moe = VitFMoE(experts, d_model=hidden_size, top_k=k, + num_local_experts=self.num_local_experts, + moe_group = ep_group, + expert_group_name = self.expert_group_name, + gate_kwargs = {'capacity':(capacity_factor, eval_capacity_factor)}) + if self.use_residual: self.mlp = expert # coefficient is used for weighted sum of the output of expert and mlp @@ -112,7 +195,12 @@ def forward(self, hidden_states, used_token=None): * exp_counts (int): expert count """ - output = self.deepspeed_moe(hidden_states, 
used_token) + if self.use_fmoe: + output = self.deepspeed_moe(hidden_states) + else: + # default DeepSpeed MoE path keeps the (hidden_states, used_token) call + output = self.deepspeed_moe(hidden_states, used_token) + + if self.use_residual: # Residual MoE output_mlp = self.mlp(hidden_states) @@ -121,4 +209,7 @@ def forward(self, hidden_states, used_token=None): coef = self.coefficient(hidden_states) coef = torch.nn.functional.softmax(coef, dim=-1) output = output * coef[..., 0:1] + output_mlp * coef[..., 1:] - return output, self.deepspeed_moe.l_aux, self.deepspeed_moe.exp_counts + if self.use_fmoe: + return output, None, None  # TODO(fmoe): surface l_aux/exp_counts from the fmoe gate + else: + return output, self.deepspeed_moe.l_aux, self.deepspeed_moe.exp_counts diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index e953938c06a4..82d7c2c01ed1 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -96,6 +96,8 @@ from deepspeed.runtime.config import DtypeEnum +from fmoe import FMoE + # Set to torch's distributed package or deepspeed.comm based inside DeepSpeedEngine dist = None @@ -113,6 +115,7 @@ # Fail silently so we don't spam logs unnecessarily if user isn't using amp APEX_INSTALLED = False +MOE_PARAM_PREFIX = '.deepspeed_moe.experts.' 
def split_half_float_double_sparse(tensors): device_type = get_accelerator().device_name() @@ -223,8 +226,8 @@ def __init__( self.gas_boundary_ctr = 0 self.dist_backend = get_accelerator().communication_backend_name() self.has_moe_layers = False - self.num_experts = [] - self.gate_modules = [] + self.num_experts = [] # for load and save checkpoint + self.gate_modules = [] # for time profile and print self.moe_layers = [] self._step_applied = False self._global_grad_norm = None @@ -1052,7 +1055,7 @@ def _configure_distributed_model(self, model): # MoE related initialization for _, module in self.module.named_modules(): - if isinstance(module, MoE): + if isinstance(module, MoE):# or isinstance(module, FMoE): self.has_moe_layers = True self.num_experts.append(module.num_experts) @@ -2384,7 +2387,7 @@ def load_moe_state_dict(checkpoint_path, map_location=torch.device('cpu')) # Updating global -> local expert ids - moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.' + moe_str_prefix = MOE_PARAM_PREFIX #'.deepspeed_moe.experts.deepspeed_experts.' for key in list(expert_state_dict.keys()): local_key = key.replace(f'{moe_str_prefix}{global_expert_id}', f'{moe_str_prefix}{local_expert_id}') @@ -2406,7 +2409,7 @@ def load_moe_state_dict(checkpoint_path, map_location=torch.device('cpu')) # print(expert_state_dict.keys()) # Updating global -> local expert ids - moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.' + moe_str_prefix = MOE_PARAM_PREFIX #'.deepspeed_moe.experts.deepspeed_experts.' 
for key in list(expert_state_dict.keys()): local_key = key.replace(f'{moe_str_prefix}{global_expert_id}', f'{moe_str_prefix}{local_expert_id}') @@ -2868,7 +2871,7 @@ def _get_non_moe_state_dict(self, full_state_dict): Get the state dict of the non-moe layers """ for key in list(full_state_dict.keys()): - if 'expert' in key and 'moe.gate.wg.weight' not in key: + if 'expert' in key and 'moe.gate' not in key: full_state_dict.pop(key) return full_state_dict @@ -2895,9 +2898,9 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}): # get all moe parameters moe_state_dict = {} for n, p in module.state_dict().items(): - if 'expert' in n and 'moe.gate.wg.weight' not in n: + if 'expert' in n and 'moe.gate' not in n: moe_state_dict[n_module + '.' + n] = p - moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.' + moe_str_prefix = MOE_PARAM_PREFIX #'.deepspeed_moe.experts.deepspeed_experts.' # print(moe_state_dict.keys()) # until now, everything is fine. So the bug happens at next few lines # Reorder the moe name rank, so that each checkpoint only has one expert experts_state_dict = defaultdict(dict) @@ -2940,6 +2943,7 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}): return # Save optimizer states. They are different across each exp parallel rank. + # for zero, this is None optimizer_state = { 'optimizer': self.optimizer.state_dict() if self.optimizer and not self.zero_optimization() else None }