Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cuda/fastermoe/smart_schedule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ torch::Tensor _smart_sch_backward(
auto global_grad_in = grad_out.new_zeros({global_batch_size, d_model});
auto grad_in = grad_out.new_zeros({buf_batch_size, d_model});

AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_out.scalar_type(),
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,grad_out.scalar_type(),
"fmoe_cuda_smartsch_backward", ([&] {
fmoe_cuda_fused_backward_impl(
backward_fn,
Expand Down
6 changes: 3 additions & 3 deletions fmoe/gates/gshard_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@

class GShardGate(NaiveGate):
def __init__(self, d_model, num_expert, world_size,
topk=2, capacity=(1.2, 2.4), random_routing=True):
assert topk == 2, 'topk should be 2 in gshard'
super().__init__(d_model, num_expert, world_size, top_k=2)
topk=2, capacity=(1, 1), random_routing=False):
# assert topk == 2, 'topk should be 2 in gshard'
super().__init__(d_model, num_expert, world_size, top_k=topk)
self.capacity = capacity
self.random_routing = random_routing

Expand Down
3 changes: 2 additions & 1 deletion fmoe/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def __init__(
gate_hook=None,
mask=None,
mask_dict=None,
gate_kwargs={}
):
super().__init__()
self.num_expert = num_expert
Expand Down Expand Up @@ -145,7 +146,7 @@ def __init__(
else:
self.experts_fused = True

self.gate = gate(d_model, num_expert, world_size, top_k)
self.gate = gate(d_model, num_expert, world_size, top_k, **gate_kwargs)
self.gate_hook = gate_hook
self.mask = mask
self.mask_dict = mask_dict
Expand Down
20 changes: 13 additions & 7 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
ext_libs = []

authors = [
'Jiaao He',
'Jiezhong Qiu',
'Aohan Zeng',
'Tiago Antunes',
'Jinjun Peng',
'Jiaao He',
'Jiezhong Qiu',
'Aohan Zeng',
'Tiago Antunes',
'Jinjun Peng',
'Qin Li',
'Mingshu Zhai'
]
Expand All @@ -37,6 +37,11 @@
else:
define_macros=[]

include_dirs = []
if os.environ.get("NCCL_PATH"):
include_dirs.append(os.environ.get("NCCL_PATH")+'/include')
nccl_lib_path = os.environ.get("NCCL_PATH")+'/lib'
os.environ['LIBRARY_PATH'] = nccl_lib_path+':'+os.environ.get('LIBRARY_PATH','')

if __name__ == '__main__':
setuptools.setup(
Expand All @@ -50,7 +55,7 @@
packages=['fmoe', 'fmoe.megatron', 'fmoe.gates', 'fmoe.fastermoe'],
ext_modules=[
CUDAExtension(
name='fmoe_cuda',
name='fmoe_cuda',
sources=[
'cuda/stream_manager.cpp',
'cuda/local_exchange.cu',
Expand All @@ -65,7 +70,8 @@
'cxx': cxx_flags,
'nvcc': cxx_flags
},
libraries=ext_libs
libraries=ext_libs,
include_dirs=include_dirs
)
],
cmdclass={
Expand Down