Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions deepspeed/moe/experts.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,7 @@ def forward(self, inputs):

expert_output = torch.cat(expert_outputs, dim=1)
return expert_output

def forward_single(self, idx, inputs):
    """Run only the ``idx``-th local expert on ``inputs`` and return its output.

    Used by smart-scheduling paths that dispatch experts one at a time
    instead of iterating over the whole expert list.
    """
    return self.deepspeed_experts[idx](inputs)
111 changes: 101 additions & 10 deletions deepspeed/moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,77 @@
from .experts import Experts
import typing

from fmoe import FMoE
import fmoe
import copy
class VitFMoE(FMoE):
    """DeepSpeed-compatible wrapper around the FastMoE ``FMoE`` layer.

    Exposes the surface the DeepSpeed ``MoE`` front end expects
    (``num_local_experts`` replicated experts, per-parameter expert-parallel
    tags, an ``_set_ep_group`` hook) while delegating token routing and
    dispatch to FastMoE with a GShard gate.
    """

    def __init__(self,
                 expert,
                 d_model=1,
                 top_k=1,
                 num_local_experts=1,
                 expert_group_name=None,
                 moe_group=None,
                 gate_kwargs=None):
        """Build the FastMoE layer and replicate the prototype expert.

        Args:
            expert: prototype expert module, deep-copied ``num_local_experts``
                times so each local expert has independent parameters.
            d_model: hidden size of the tokens routed through the layer.
            top_k: number of experts each token is dispatched to.
            num_local_experts: experts hosted on this rank.
            expert_group_name: DeepSpeed expert-parallel group name; required,
                stamped onto every expert parameter for checkpoint/comm logic.
            moe_group: process group used for expert-parallel communication.
            gate_kwargs: extra keyword arguments forwarded to the GShard gate.
        """
        assert expert_group_name is not None, 'expert_group_name should be provided'
        # Fix: the original used a mutable default (``gate_kwargs={}``), which
        # is shared across instances; normalize a None sentinel instead.
        if gate_kwargs is None:
            gate_kwargs = {}
        world_size = torch.distributed.get_world_size(moe_group)
        super().__init__(
            num_expert=num_local_experts,
            d_model=d_model,
            world_size=world_size,
            mp_group=None,
            top_k=top_k,
            moe_group=moe_group,
            gate=fmoe.gates.GShardGate,
            gate_kwargs=gate_kwargs)

        # One independent copy of the prototype per local expert.
        self.experts = torch.nn.ModuleList(
            [copy.deepcopy(expert) for _ in range(num_local_experts)])
        self.experts_fused = False
        self.num_experts = num_local_experts
        self.num_local_experts = num_local_experts

        # Tag expert parameters so DeepSpeed skips the dense all-reduce and
        # handles them via their expert-parallel group instead.
        # (Loop variable renamed: the original ``for expert in self.experts``
        # shadowed the ``expert`` constructor argument.)
        for local_expert in self.experts:
            for _name, param in local_expert.named_parameters():
                param.allreduce = False
                param.group_name = expert_group_name

    def expert_fn(self, inp, fwd_expert_count):
        r"""Run the experts on their contiguous slices of ``inp``.

        ``inp`` is the token batch already sorted by destination expert;
        ``fwd_expert_count[i]`` is the number of tokens for expert ``i``.
        Either calls the fused expert module as a whole or each expert on
        its own slice.
        """
        if self.experts_fused:
            return self.experts(inp, fwd_expert_count)
        if isinstance(fwd_expert_count, torch.Tensor):
            fwd_expert_count_cpu = fwd_expert_count.cpu().numpy()
        else:
            # Fix: the original only assigned fwd_expert_count_cpu inside the
            # tensor branch, so a list/ndarray count raised NameError below.
            fwd_expert_count_cpu = fwd_expert_count
        outputs = []
        base_idx = 0
        for i in range(self.num_experts):
            batch_size = fwd_expert_count_cpu[i]
            inp_slice = inp[base_idx:base_idx + batch_size]
            outputs.append(self.experts[i](inp_slice))
            base_idx += batch_size
        return torch.cat(outputs, dim=0)

    def expert_fn_single(self, inp, fwd_expert_count, idx):
        r"""Forward a single expert ``idx`` on ``inp`` (smart scheduling)."""
        assert not self.experts_fused, "should not use fused experts"
        output = self.experts[idx](inp)
        return output

    def _set_ep_group(self, group):
        # DeepSpeed hook: late-bind the expert-parallel process group.
        self.moe_group = group

    def forward(self, inp: torch.Tensor):
        r"""Flatten ``inp`` to (tokens, d_model), route through FMoE, and
        restore the original shape.
        """
        original_shape = inp.shape
        inp = inp.reshape(-1, self.d_model)
        output = super().forward(inp)
        return output.reshape(original_shape)

class MoE(torch.nn.Module):
"""Initialize an MoE layer.
Expand Down Expand Up @@ -47,7 +118,8 @@ def __init__(self,
drop_tokens: bool = True,
use_rts=True,
use_tutel: bool = False,
enable_expert_tensor_parallelism: bool = False):
enable_expert_tensor_parallelism: bool = False,
use_fmoe: bool = False):

super(MoE, self).__init__()

Expand All @@ -58,6 +130,7 @@ def __init__(self,
self.expert_group_name = f"ep_size_{self.ep_size}"
self.num_experts = num_experts
self.num_local_experts = num_experts // self.ep_size
self.use_fmoe = use_fmoe

log_dist(
f'Creating MoE layer with num_experts: {num_experts} | num_local_experts: {self.num_local_experts} | expert_parallel_size: {self.ep_size}',
Expand All @@ -67,13 +140,23 @@ def __init__(self,
'Unsupported noisy_gate_policy: ' + noisy_gate_policy

experts = Experts(expert, self.num_local_experts, self.expert_group_name)
self.deepspeed_moe = MOELayer(TopKGate(hidden_size, num_experts, k, capacity_factor, eval_capacity_factor,
min_capacity, noisy_gate_policy, drop_tokens, use_rts),
experts,
self.expert_group_name,
self.ep_size,
self.num_local_experts,
use_tutel=use_tutel)
if not use_fmoe:
self.deepspeed_moe = MOELayer(TopKGate(hidden_size, num_experts, k, capacity_factor, eval_capacity_factor,
min_capacity, noisy_gate_policy, drop_tokens, use_rts),
experts,
self.expert_group_name,
self.ep_size,
self.num_local_experts,
use_tutel=use_tutel)
else:
# Note: need to setup groups for moe before creating MOE layers
ep_group = groups._get_expert_parallel_group(self.expert_group_name)
self.deepspeed_moe = VitFMoE(experts, d_model=hidden_size, top_k=k,
num_local_experts=self.num_local_experts,
moe_group = ep_group,
expert_group_name = self.expert_group_name,
gate_kwargs = {'capacity':(capacity_factor, eval_capacity_factor)})

if self.use_residual:
self.mlp = expert
# coefficient is used for weighted sum of the output of expert and mlp
Expand Down Expand Up @@ -112,7 +195,12 @@ def forward(self, hidden_states, used_token=None):

* exp_counts (int): expert count
"""
output = self.deepspeed_moe(hidden_states, used_token)
if self.use_fmoe:
output = self.deepspeed_moe(hidden_states)
else:
# import pdb;pdb.set_trace()
output = self.deepspeed_moe(hidden_states, used_token)

if self.use_residual:
# Residual MoE
output_mlp = self.mlp(hidden_states)
Expand All @@ -121,4 +209,7 @@ def forward(self, hidden_states, used_token=None):
coef = self.coefficient(hidden_states)
coef = torch.nn.functional.softmax(coef, dim=-1)
output = output * coef[..., 0:1] + output_mlp * coef[..., 1:]
return output, self.deepspeed_moe.l_aux, self.deepspeed_moe.exp_counts
if self.use_fmoe:
return output#,None,None # todo: fix this
else:
return output#, self.deepspeed_moe.l_aux, self.deepspeed_moe.exp_counts
20 changes: 12 additions & 8 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@

from deepspeed.runtime.config import DtypeEnum

from fmoe import FMoE

# Set to torch's distributed package or deepspeed.comm based inside DeepSpeedEngine init
dist = None

Expand All @@ -113,6 +115,7 @@
# Fail silently so we don't spam logs unnecessarily if user isn't using amp
APEX_INSTALLED = False

MOE_PARAM_PREFIX = '.deepspeed_moe.experts.'

def split_half_float_double_sparse(tensors):
device_type = get_accelerator().device_name()
Expand Down Expand Up @@ -223,8 +226,8 @@ def __init__(
self.gas_boundary_ctr = 0
self.dist_backend = get_accelerator().communication_backend_name()
self.has_moe_layers = False
self.num_experts = []
self.gate_modules = []
self.num_experts = [] # for load and save checkpoint
self.gate_modules = [] # for time profile and print
self.moe_layers = []
self._step_applied = False
self._global_grad_norm = None
Expand Down Expand Up @@ -1052,7 +1055,7 @@ def _configure_distributed_model(self, model):

# MoE related initialization
for _, module in self.module.named_modules():
if isinstance(module, MoE):
if isinstance(module, MoE):# or isinstance(module, FMoE):
self.has_moe_layers = True
self.num_experts.append(module.num_experts)

Expand Down Expand Up @@ -2384,7 +2387,7 @@ def load_moe_state_dict(checkpoint_path,
map_location=torch.device('cpu'))

# Updating global -> local expert ids
moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.'
moe_str_prefix = MOE_PARAM_PREFIX #'.deepspeed_moe.experts.deepspeed_experts.'
for key in list(expert_state_dict.keys()):
local_key = key.replace(f'{moe_str_prefix}{global_expert_id}',
f'{moe_str_prefix}{local_expert_id}')
Expand All @@ -2406,7 +2409,7 @@ def load_moe_state_dict(checkpoint_path,
map_location=torch.device('cpu'))
# print(expert_state_dict.keys())
# Updating global -> local expert ids
moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.'
moe_str_prefix = MOE_PARAM_PREFIX #'.deepspeed_moe.experts.deepspeed_experts.'
for key in list(expert_state_dict.keys()):
local_key = key.replace(f'{moe_str_prefix}{global_expert_id}',
f'{moe_str_prefix}{local_expert_id}')
Expand Down Expand Up @@ -2868,7 +2871,7 @@ def _get_non_moe_state_dict(self, full_state_dict):
Get the state dict of the non-moe layers
"""
for key in list(full_state_dict.keys()):
if 'expert' in key and 'moe.gate.wg.weight' not in key:
if 'expert' in key and 'moe.gate' not in key:
full_state_dict.pop(key)

return full_state_dict
Expand All @@ -2895,9 +2898,9 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}):
# get all moe parameters
moe_state_dict = {}
for n, p in module.state_dict().items():
if 'expert' in n and 'moe.gate.wg.weight' not in n:
if 'expert' in n and 'moe.gate' not in n:
moe_state_dict[n_module + '.' + n] = p
moe_str_prefix = '.deepspeed_moe.experts.deepspeed_experts.'
moe_str_prefix = MOE_PARAM_PREFIX #'.deepspeed_moe.experts.deepspeed_experts.'
# print(moe_state_dict.keys()) # until now, everything is fine. So the bug happens at next few lines
# Reorder the moe name rank, so that each checkpoint only has one expert
experts_state_dict = defaultdict(dict)
Expand Down Expand Up @@ -2940,6 +2943,7 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}):
return

# Save optimizer states. They are different across each exp parallel rank.
# for zero, this is None
optimizer_state = {
'optimizer': self.optimizer.state_dict() if self.optimizer and not self.zero_optimization() else None
}
Expand Down