2 changes: 1 addition & 1 deletion examples/deepspeed/moe_e/run-distributed.sh
@@ -67,7 +67,7 @@ train() {
     SaveDir="${OUT_DIR?}/checkpoints/${ARCH}-${RUN_NAME}"
     mkdir -p $SaveDir
 
-    python -m torch.distributed.launch \
+    torchrun \
        --nproc_per_node=${NUM_GPUS} \
        --node_rank=${NODE_RANK:-0} \
        --nnodes=${NODE_COUNT:-1} \
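python -m torch.distributed.launch is deprecated in recent PyTorch releases in favor of torchrun, which hands each worker its rank through environment variables instead of a --local_rank argument. A minimal sketch of reading those variables on the worker side (the variable names are the standard torchrun ones, nothing specific to this repo):

    import os

    # torchrun exports these for every worker process; the values arrive as
    # strings, so coerce them before use.
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))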
14 changes: 12 additions & 2 deletions examples/deepspeed/moe_e/user/modules/transformer_moe_layer.py
@@ -49,6 +49,14 @@

 DEFAULT_MIN_PARAMS_TO_WRAP = int(1e8)
 
+def get_device():
+    import deepspeed
+
+    def _get_version(_version):
+        return tuple(map(int, _version.split('.')))
+    if _get_version(deepspeed.__version__) >= _get_version("0.15.0"):
+        return "cuda"
+    return "cpu"
 
 class TransformerEncoderLayer_MOE(nn.Module):
     """Encoder layer block.
@@ -93,7 +101,8 @@ def __init__(self, args, index=-1):
         self.experts = None
         if (index + 1) % 2 == 0:
             from deepspeed.moe.layer import MoE
-            self.expert_counts = torch.zeros(1, args.num_experts, dtype=torch.int64).to('cpu')
+            dev = get_device()
+            self.expert_counts = torch.zeros(1, args.num_experts, dtype=torch.int64).to(dev)
             self.fc1 = self.build_fc1(
                 self.embed_dim,
                 args.encoder_ffn_embed_dim,
@@ -340,7 +349,8 @@ def __init__(
         self.experts = None
         if (index + 1) % 2 == 0:
             from deepspeed.moe.layer import MoE
-            self.expert_counts = torch.zeros(1, args.num_experts, dtype=torch.int64).to('cpu')
+            dev = get_device()
+            self.expert_counts = torch.zeros(1, args.num_experts, dtype=torch.int64).to(dev)
             self.fc1 = self.build_fc1(
                 self.embed_dim,
                 args.encoder_ffn_embed_dim,
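The new get_device() helper keys the placement of the expert_counts buffer off the installed DeepSpeed version; the assumption encoded in the diff is that DeepSpeed 0.15.0 and later want that buffer on the GPU, while older releases keep it on the CPU. A standalone sketch of the same check (the naive integer split would fail on version strings with local suffixes such as '0.15.0+abc', which would need stripping first):

    import deepspeed
    import torch

    def parse_version(v: str):
        # "0.15.2" -> (0, 15, 2); assumes a plain dotted release string
        return tuple(int(part) for part in v.split("."))

    device = "cuda" if parse_version(deepspeed.__version__) >= (0, 15, 0) else "cpu"
    # one slot per expert, as in the layer's __init__ (num_experts=8 is illustrative)
    expert_counts = torch.zeros(1, 8, dtype=torch.int64, device=device)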
4 changes: 2 additions & 2 deletions examples/deepspeed/moe_e/user/train.py
@@ -178,7 +178,7 @@ def main(cfg: FairseqConfig) -> None:
     # ckpt_path = f"{cfg.checkpoint.save_dir}/deepspeed_moe"
     trainer.ds_module = load_deepspeed_state_(
         cfg=cfg,
-        model=trainer.model.module.module,
+        model=trainer.model,
         weights_path=None)
     if cfg.model.ep_world_size > cfg.model.num_experts:
         raise ValueError(
@@ -438,7 +438,7 @@ def validate_and_save(
     from user.ds_utils import save_deepspeed_state_
     save_deepspeed_state_(
         cfg,
-        model=trainer.model.module.module,
+        model=trainer.model,
         trainer=trainer,
         ds_module=trainer.ds_module,
         ckpt_tag=None,
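Both call sites previously reached through two levels of .module to get the bare model, which hard-codes one particular wrapping (for example DDP around a fairseq wrapper); passing trainer.model and letting the save/load utilities handle it avoids that assumption. If explicit unwrapping were ever needed again, a hypothetical helper could peel wrappers generically rather than by a fixed depth:

    import torch.nn as nn

    def unwrap_model(model: nn.Module) -> nn.Module:
        # Follow ".module" attributes (DDP/FSDP-style wrappers) until the bare
        # module is reached, instead of hard-coding the nesting depth.
        while isinstance(getattr(model, "module", None), nn.Module):
            model = model.module
        return model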
4 changes: 2 additions & 2 deletions fairseq/data/indexed_dataset.py
@@ -118,7 +118,7 @@ def write_longs(f, a):
     3: np.int16,
     4: np.int32,
     5: np.int64,
-    6: np.float,
+    6: float,
     7: np.double,
     8: np.uint16,
     9: np.uint32,
@@ -325,7 +325,7 @@ class IndexedDatasetBuilder:
         np.int16: 2,
         np.int32: 4,
         np.int64: 8,
-        np.float: 4,
+        float: 4,
         np.double: 8,
     }
 
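np.float was only ever an alias for the builtin float (64-bit); NumPy deprecated the alias in 1.20 and removed it in 1.24, which is what breaks the old table. A quick check that the replacement is equivalent:

    import numpy as np

    # dtype=float and the removed dtype=np.float produce the same float64 arrays
    a = np.zeros(3, dtype=float)
    assert a.dtype == np.float64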
27 changes: 14 additions & 13 deletions fairseq/dataclass/configs.py
@@ -3,6 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+import os
 import sys
 from dataclasses import _MISSING_TYPE, dataclass, field
 from typing import Any, List, Optional
@@ -271,9 +272,9 @@ class DistributedTrainingConfig(FairseqDataclass):
         },
     )
     device_id: int = field(
-        default=0,
+        default=os.getenv("LOCAL_RANK", 0),
         metadata={
-            "help": "which GPU to use (usually configured automatically)",
+            "help": "which GPU to use (by default looks for $LOCAL_RANK, usually configured automatically)",
             "argparse_alias": "--local_rank",
         },
     )
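One caveat with this default: os.getenv returns a string whenever LOCAL_RANK is set, while device_id is typed as int. If the downstream argparse/OmegaConf conversion is not guaranteed to coerce it, an explicit cast keeps the type honest (a sketch, not what the diff does):

    import os

    # Coerce explicitly so device_id is always an int, whether or not
    # LOCAL_RANK is present in the environment.
    device_id = int(os.getenv("LOCAL_RANK", "0"))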
@@ -1023,16 +1024,16 @@ class EMAConfig(FairseqDataclass):

 @dataclass
 class FairseqConfig(FairseqDataclass):
-    common: CommonConfig = CommonConfig()
-    common_eval: CommonEvalConfig = CommonEvalConfig()
-    distributed_training: DistributedTrainingConfig = DistributedTrainingConfig()
-    dataset: DatasetConfig = DatasetConfig()
-    optimization: OptimizationConfig = OptimizationConfig()
-    checkpoint: CheckpointConfig = CheckpointConfig()
-    bmuf: FairseqBMUFConfig = FairseqBMUFConfig()
-    generation: GenerationConfig = GenerationConfig()
-    eval_lm: EvalLMConfig = EvalLMConfig()
-    interactive: InteractiveConfig = InteractiveConfig()
+    common: CommonConfig = field(default=CommonConfig)
+    common_eval: CommonEvalConfig = field(default=CommonEvalConfig)
+    distributed_training: DistributedTrainingConfig = field(default=DistributedTrainingConfig)
+    dataset: DatasetConfig = field(default=DatasetConfig)
+    optimization: OptimizationConfig = field(default=OptimizationConfig)
+    checkpoint: CheckpointConfig = field(default=CheckpointConfig)
+    bmuf: FairseqBMUFConfig = field(default=FairseqBMUFConfig)
+    generation: GenerationConfig = field(default=GenerationConfig)
+    eval_lm: EvalLMConfig = field(default=EvalLMConfig)
+    interactive: InteractiveConfig = field(default=InteractiveConfig)
     model: Any = MISSING
     task: Any = None
     criterion: Any = None
@@ -1041,4 +1042,4 @@ class FairseqConfig(FairseqDataclass):
     scoring: Any = None
     bpe: Any = None
     tokenizer: Any = None
-    ema: EMAConfig = EMAConfig()
+    ema: EMAConfig = field(default=EMAConfig)
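These sub-configs were previously instantiated at class-definition time, which Python 3.11+ rejects because dataclass instances are mutable (unhashable) defaults. The plain-dataclasses idiom for a fresh per-instance default is default_factory; whether passing the class itself via default= behaves the same here depends on how OmegaConf expands structured configs, so the following is the generic pattern rather than a claim about this PR (names are illustrative):

    from dataclasses import dataclass, field

    @dataclass
    class CommonConfig:
        seed: int = 1  # illustrative field, not the real CommonConfig

    @dataclass
    class FairseqConfigSketch:
        # default_factory constructs a new CommonConfig per instance,
        # avoiding the shared-mutable-default problem.
        common: CommonConfig = field(default_factory=CommonConfig)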
6 changes: 3 additions & 3 deletions fairseq/dataclass/utils.py
@@ -364,13 +364,13 @@ def override_module_args(args: Namespace) -> Tuple[List[str], List[str]]:

 class omegaconf_no_object_check:
     def __init__(self):
-        self.old_is_primitive = _utils.is_primitive_type
+        self.old_is_primitive = _utils.is_primitive_type_annotation
 
     def __enter__(self):
-        _utils.is_primitive_type = lambda _: True
+        _utils.is_primitive_type_annotation = lambda _: True
 
     def __exit__(self, type, value, traceback):
-        _utils.is_primitive_type = self.old_is_primitive
+        _utils.is_primitive_type_annotation = self.old_is_primitive
 
 
 def convert_namespace_to_omegaconf(args: Namespace) -> DictConfig:
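The rename tracks omegaconf's private helper, which newer releases expose as is_primitive_type_annotation instead of is_primitive_type. Since _utils is private API, a more defensive variant could patch whichever name the installed version provides; a sketch under that assumption, not what the diff does:

    from omegaconf import _utils

    # Pick whichever private name the installed omegaconf provides.
    _name = (
        "is_primitive_type_annotation"
        if hasattr(_utils, "is_primitive_type_annotation")
        else "is_primitive_type"
    )

    class omegaconf_no_object_check:
        def __init__(self):
            self._old = getattr(_utils, _name)

        def __enter__(self):
            setattr(_utils, _name, lambda _: True)

        def __exit__(self, exc_type, exc_value, traceback):
            setattr(_utils, _name, self._old)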
6 changes: 3 additions & 3 deletions fairseq/models/transformer/transformer_config.py
@@ -104,13 +104,13 @@ class TransformerConfig(FairseqDataclass):
         },
     )
     adaptive_input: bool = False
-    encoder: EncDecBaseConfig = EncDecBaseConfig()
+    encoder: EncDecBaseConfig = field(default=EncDecBaseConfig)
     # TODO should really be in the encoder config
     max_source_positions: int = field(
         default=DEFAULT_MAX_SOURCE_POSITIONS,
         metadata={"help": "Maximum input length supported by the encoder"},
     )
-    decoder: DecoderConfig = DecoderConfig()
+    decoder: DecoderConfig = field(default=DecoderConfig)
     # TODO should really be in the decoder config
     max_target_positions: int = field(
         default=DEFAULT_MAX_TARGET_POSITIONS,
@@ -182,7 +182,7 @@
         default=False, metadata={"help": "perform cross+self-attention"}
     )
     # args for Training with Quantization Noise for Extreme Model Compression ({Fan*, Stock*} et al., 2020)
-    quant_noise: QuantNoiseConfig = field(default=QuantNoiseConfig())
+    quant_noise: QuantNoiseConfig = field(default=QuantNoiseConfig)
     min_params_to_wrap: int = field(
         default=DEFAULT_MIN_PARAMS_TO_WRAP,
         metadata={
4 changes: 2 additions & 2 deletions setup.py
@@ -214,8 +214,8 @@ def do_setup(package_data):
"cffi",
"cython",
'dataclasses; python_version<"3.7"',
"hydra-core>=1.0.7,<1.1",
"omegaconf<2.1",
"hydra-core>=1.3.2",
"omegaconf",
'numpy<1.20.0; python_version<"3.7"',
'numpy; python_version>="3.7"',
"regex",