Error when running examples/transformer-xl/scripts/run_enwik8_base_moe.sh with Smart schedule enabled #207

@WhatBrain

Describe the bug
When I turn on Smart schedule with export FMOE_FASTER_SHADOW_ENABLE=1 and export FMOE_FASTER_SCHEDULE_ENABLE=1 and then run bash examples/transformer-xl/scripts/run_enwik8_base_moe.sh train, it fails with the following error:
Original Traceback (most recent call last):
  File "/data/anaconda3/envs/fast/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/data/anaconda3/envs/fast/lib/python3.9/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/data/jiang/fastmoe-master/fastmoe-master/examples/transformer-xl/mem_transformer.py", line 801, in forward
    hidden, new_mems = self._forward(data, mems=mems)
  File "/data/jiang/fastmoe-master/fastmoe-master/examples/transformer-xl/mem_transformer.py", line 732, in _forward
    core_out = layer(core_out, pos_emb, self.r_w_bias,
  File "/data/anaconda3/envs/fast/lib/python3.9/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/data/jiang/fastmoe-master/fastmoe-master/examples/transformer-xl/mem_transformer.py", line 481, in forward
    output = self.pos_ff(output)
  File "/data/anaconda3/envs/fast/lib/python3.9/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/data/jiang/fastmoe-master/fastmoe-master/examples/transformer-xl/mem_transformer.py", line 403, in forward
    core_out = super().forward(inp)
  File "/data/anaconda3/envs/fast/lib/python3.9/site-packages/fastmoe-1.1.0-py3.9-linux-x86_64.egg/fmoe/transformer.py", line 65, in forward
    output = super().forward(inp)
  File "/data/anaconda3/envs/fast/lib/python3.9/site-packages/fastmoe-1.1.0-py3.9-linux-x86_64.egg/fmoe/layers.py", line 251, in forward
    fwd = _fmoe_general_global_forward(
  File "/data/anaconda3/envs/fast/lib/python3.9/site-packages/fastmoe-1.1.0-py3.9-linux-x86_64.egg/fmoe/fastermoe/schedule.py", line 136, in _fmoe_general_global_forward
    stored_models = policy_fn(local_expert_count, global_expert_count,
  File "/data/anaconda3/envs/fast/lib/python3.9/site-packages/fastmoe-1.1.0-py3.9-linux-x86_64.egg/fmoe/fastermoe/shadow_policy.py", line 27, in global_policy
    dist.all_gather(agecs, local_expert_count, group=moe_group)
  File "/data/anaconda3/envs/fast/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 1862, in all_gather
    default_pg = _get_default_group()
  File "/data/anaconda3/envs/fast/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py", line 347, in _get_default_group
    raise RuntimeError("Default process group has not been initialized, "
RuntimeError: Default process group has not been initialized, please make sure to call init_process_group.

Do I need to add a call to init_process_group to the example code?
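For reference, here is a minimal sketch of what the error message asks for: initializing the default process group before any torch.distributed collective runs. It assumes a CPU-only, single-rank setup with the gloo backend and environment-variable rendezvous. The helper names init_default_group and gather_expert_counts are hypothetical, and this is not a fix confirmed for the FastMoE examples. Note that the traceback starts in torch/nn/parallel/parallel_apply.py, which suggests the script runs the model under nn.DataParallel in a single process, and DataParallel by itself does not create a process group.

```python
import os

import torch
import torch.distributed as dist


def init_default_group(backend: str = "gloo") -> None:
    """Initialize the default process group (hypothetical helper).

    Assumes a launcher may set RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT;
    otherwise it falls back to a single local rank on 127.0.0.1:29500.
    """
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    if not dist.is_initialized():
        dist.init_process_group(
            backend, init_method="env://", rank=rank, world_size=world_size
        )


def gather_expert_counts(local_expert_count: torch.Tensor) -> list:
    """Gather every rank's expert counts (hypothetical helper).

    Uses the same all_gather collective as the shadow policy in the
    traceback; without init_default_group() it raises the same
    "Default process group has not been initialized" RuntimeError.
    """
    world_size = dist.get_world_size()
    gathered = [torch.empty_like(local_expert_count) for _ in range(world_size)]
    dist.all_gather(gathered, local_expert_count)
    return gathered
```

Calling init_default_group() once before the first forward pass and then gather_expert_counts(torch.tensor([3, 1, 4])) returns a one-element list on a single rank. With several ranks (for example one process per GPU, using the nccl backend and CUDA tensors), each rank would receive every rank's counts.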
To Reproduce

Expected behavior

Logs

Platform

  • Device: NVIDIA A100
  • CUDA version: 11.1
  • NCCL version: 2.7.8
  • PyTorch version: 1.8.0
