Skip to content

Commit d69a864

Browse files
authored
[feat] support ep_fsdp (#71)
1 parent 85e4f7d commit d69a864

10 files changed

Lines changed: 1427 additions & 606 deletions

File tree

cookbook/transformers/ep_fsdp_qwen3_moe.py

Lines changed: 35 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# Copyright (c) ModelScope Contributors. All rights reserved.
2-
import numpy as np
32
import os
43
from transformers import AutoConfig
54

@@ -17,15 +16,18 @@
1716
TEMPLATE_ID = os.environ.get('TEMPLATE_ID', 'Template')
1817
_num_layers_env = os.environ.get('NUM_LAYERS')
1918
NUM_LAYERS = int(_num_layers_env) if _num_layers_env is not None else None
19+
BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '4'))
20+
GRAD_ACCUM_STEPS = int(os.environ.get('GRAD_ACCUM_STEPS', '4'))
21+
LR = float(os.environ.get('LR', '1e-5'))
22+
MAX_GRAD_NORM = float(os.environ.get('MAX_GRAD_NORM', '1.0'))
23+
KEEP_ROUTER_LOGITS = os.environ.get('KEEP_ROUTER_LOGITS', '0') == '1'
2024

21-
# 4 gpus, dp=2, ep=2
22-
dp_size = 2
23-
ep_size = 2
24-
25-
device_mesh = DeviceMesh(
25+
# 8 gpus, dp=1, fsdp=8 (data parallel), ep_size=8 (expert parallel)
26+
device_mesh = DeviceMesh.from_sizes(
27+
fsdp_size=8,
28+
dp_size=1,
29+
ep_size=8,
2630
device_type=Platform.get_platform().device_prefix(),
27-
mesh=np.arange(dp_size * ep_size).reshape(dp_size, ep_size),
28-
mesh_dim_names=('dp', 'ep'),
2931
)
3032

3133
twinkle.initialize(
@@ -41,7 +43,7 @@ def train():
4143
if hasattr(config, 'use_cache'):
4244
config.use_cache = False
4345

44-
dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000)))
46+
dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(1000)))
4547
try:
4648
dataset.set_template(TEMPLATE_ID, model_id=MODEL_ID)
4749
except ValueError:
@@ -51,11 +53,10 @@ def train():
5153
dataset.encode(batched=True)
5254
dataloader = DataLoader(
5355
dataset=dataset,
54-
batch_size=4,
56+
batch_size=BATCH_SIZE,
5557
device_mesh=device_mesh,
5658
)
5759

58-
grad_accum_steps = 4
5960
model = TransformersModel(
6061
model_id=MODEL_ID,
6162
config=config,
@@ -64,29 +65,43 @@ def train():
6465
'expert_parallel': {
6566
'enabled': True,
6667
'router_dtype': 'fp32',
67-
'all_to_all': 'torch',
68-
'keep_router_logits': False,
68+
'keep_router_logits': KEEP_ROUTER_LOGITS,
6969
}
7070
},
7171
)
7272
# Disable foreach to avoid DTensor mixed-type errors in EP runs.
73-
model.set_optimizer('AdamW', foreach=False)
73+
model.set_optimizer('AdamW', lr=LR, foreach=False)
74+
model.set_lr_scheduler(
75+
scheduler_cls='CosineWarmupScheduler',
76+
num_warmup_steps=5,
77+
num_training_steps=len(dataloader),
78+
)
7479

7580
logger.info(get_device_placement())
7681
logger.info(model.get_train_configs())
82+
logger.info(
83+
f'Total steps: {len(dataloader)}, batch_size={BATCH_SIZE}, grad_accum={GRAD_ACCUM_STEPS}, '
84+
f'lr={LR:.2e}, max_grad_norm={MAX_GRAD_NORM}, '
85+
f'keep_router_logits={KEEP_ROUTER_LOGITS}')
7786

7887
for step, batch in enumerate(dataloader):
7988
if callable(batch):
8089
batch = batch()
81-
model.forward_backward(inputs=batch, gradient_accumulation_steps=grad_accum_steps)
82-
model.clip_grad_and_step(gradient_accumulation_steps=grad_accum_steps)
83-
if step % grad_accum_steps == 0:
90+
model.forward_backward(inputs=batch, gradient_accumulation_steps=GRAD_ACCUM_STEPS)
91+
model.clip_grad_and_step(
92+
max_grad_norm=MAX_GRAD_NORM,
93+
gradient_accumulation_steps=GRAD_ACCUM_STEPS,
94+
)
95+
96+
is_sync_step = ((step + 1) % GRAD_ACCUM_STEPS == 0)
97+
if is_sync_step:
98+
optimizer_step = (step + 1) // GRAD_ACCUM_STEPS
8499
metric = model.calculate_metric(is_training=True)
85100
if callable(metric):
86101
metric = metric()
87-
logger.info(f'Current is step {step // grad_accum_steps}, metric: {metric}')
88-
if step > 0 and step % 50 == 0:
89-
model.save('./output')
102+
logger.info(f'Current optimizer_step {optimizer_step}, metric: {metric}')
103+
if optimizer_step > 0 and optimizer_step % 50 == 0:
104+
model.save(name=f'checkpoint-step-{optimizer_step}', output_dir='./output')
90105

91106

92107
if __name__ == '__main__':
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
# Copyright (c) ModelScope Contributors. All rights reserved.
2-
from .expert_parallel import apply_expert_parallel
2+
from .expert_parallel import ExpertShardingSpec, apply_expert_parallel
33

4-
__all__ = ['apply_expert_parallel']
4+
__all__ = ['ExpertShardingSpec', 'apply_expert_parallel']

0 commit comments

Comments (0)