# sp_fsdp_dense.py
from functools import partial

import numpy as np
from peft import LoraConfig

import twinkle
from twinkle import DeviceGroup, DeviceMesh, Platform, get_logger
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.model import TransformersModel
from twinkle.preprocessor import SelfCognitionProcessor

logger = get_logger()

MODEL_ID = 'ms://Qwen/Qwen3.5-4B'
DATASETS = 'ms://swift/self-cognition'

device_group = [DeviceGroup(
    name='default',
    ranks=[0, 1, 2, 3],
    device_type=Platform.get_platform().device_prefix(),
)]

# FSDP + sequence-parallel validation over 4 GPUs: dp=2, fsdp=2.
# In the Transformers route, ulysses_size is the total sequence-parallel degree.
device_mesh = DeviceMesh(
    device_type=Platform.get_platform().device_prefix(),
    mesh=np.arange(4).reshape(2, 2),
    mesh_dim_names=('dp', 'fsdp'),
    ulysses_size=2,
)
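# mesh == [[0, 1], [2, 3]]: under the usual device-mesh convention (assumed
# here for twinkle's DeviceMesh), 'fsdp' shards parameters within rows
# ({0, 1} and {2, 3}) while 'dp' replicates across rows ({0, 2} and {1, 3});
# ulysses_size=2 additionally splits each sequence across two ranks.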

twinkle.initialize(
    mode='local',
    nproc_per_node=4,
    global_device_mesh=device_mesh,
    lazy_collect=False,
)
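# nproc_per_node=4 matches the four ranks declared in the mesh above;
# 'local' mode presumably runs all four workers on this single machine.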


def eval(model):
    # Forward-only validation pass over the first 100 samples.
    dataloader = DataLoader(
        dataset=partial(create_dataset, data_slice=range(100)),
        batch_size=4,
        device_mesh=device_mesh,
    )
    for batch in dataloader:
        model.forward_only(inputs=batch, adapter_name='default')
        model.calculate_loss(adapter_name='default')
    return model.calculate_metric(is_training=False, adapter_name='default')


def create_dataset(data_slice=None):
    # Honor the caller's slice; default to the first 500 samples otherwise.
    if data_slice is None:
        data_slice = range(500)
    dataset = Dataset(dataset_meta=DatasetMeta(DATASETS, data_slice=data_slice))
    dataset.set_template('Qwen3_5Template', model_id=MODEL_ID)
    # Rewrite the self-cognition samples with custom model/team names.
    dataset.map(SelfCognitionProcessor('twinkle模型', 'twinkle团队'))
    dataset.encode(batched=True)
    return dataset


def train():
    dataloader = DataLoader(
        dataset=partial(create_dataset, data_slice=None),
        batch_size=8,
        device_mesh=device_mesh,
    )
    model = TransformersModel(
        model_id=MODEL_ID,
        device_mesh=device_mesh,
        strategy='native_fsdp',
    )
    # Attach a LoRA adapter over every linear layer; no gradient accumulation.
    lora_config = LoraConfig(target_modules='all-linear')
    model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=1)
    model.set_optimizer('AdamW', lr=1e-4, adapter_name='default')
    model.set_lr_scheduler(
        scheduler_cls='CosineWarmupScheduler',
        num_warmup_steps=5,
        num_training_steps=len(dataloader),
        adapter_name='default',
    )
    logger.info(model.get_train_configs(adapter_name='default'))
    logger.info(f'Total steps: {len(dataloader)}')
    for step, batch in enumerate(dataloader):
        model.forward_backward(inputs=batch, adapter_name='default')
        model.clip_grad_and_step(adapter_name='default')
        if step % 20 == 0:
            metric = model.calculate_metric(is_training=True, adapter_name='default')
            logger.info(f'Step {step} of {len(dataloader)}, metric: {metric}')
        model.save('last-checkpoint', interval=1)
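
# Hypothetical follow-up (not wired up above): have train() return the model
# and validate it with the forward-only eval() defined earlier, e.g.
#
#     model = train()
#     logger.info(eval(model))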


if __name__ == '__main__':
    train()
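
# Launch sketch (an assumption about twinkle's local mode: initialize() is
# expected to spawn the worker processes itself, so no torchrun wrapper
# should be needed):
#
#     python sp_fsdp_dense.py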