# tp.py
import os
from peft import LoraConfig
from tqdm import tqdm
import twinkle
from twinkle import DeviceMesh, Platform, get_device_placement, get_logger
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.model import MegatronModel
from twinkle.preprocessor import SelfCognitionProcessor
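
# Overview (summary of the script below): fine-tune Qwen3.5-4B on the ModelScope
# `swift/self-cognition` dataset with 3D parallelism (dp=tp=pp=2) via twinkle.
# A LoRA adapter is trained by default; the model is evaluated every 20 steps and
# the checkpoint with the best eval loss is saved.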
# Construct a device mesh with dp=tp=pp=2 (8 ranks in total)
device_mesh = DeviceMesh.from_sizes(dp_size=2, tp_size=2, pp_size=2)
# Initialize twinkle in local (torchrun) mode
twinkle.initialize(mode='local', global_device_mesh=device_mesh)
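# Note: the mesh above implies dp * tp * pp = 8 processes. A typical launch command
# (an assumption, not taken from this repo) would be: torchrun --nproc_per_node=8 tp.py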
logger = get_logger()

def eval(model):
    # Evaluate on 100 samples of the self-cognition dataset
    dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100)))
    dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B')
    dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
    dataset.encode()
    dataloader = DataLoader(dataset=dataset, batch_size=16)
    for step, batch in tqdm(enumerate(dataloader)):
        # Forward pass only (no gradient computation)
        model.forward_only(inputs=batch)
    metrics = model.calculate_metric(is_training=False)
    return metrics
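
# eval() returns a metrics dict; train() below reads metrics['loss'] to decide when to checkpoint.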
def train():
    # Train on 1000 samples
    dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000)))
    # Set the template to prepare for encoding
    dataset.set_template('Qwen3_5Template', model_id='ms://Qwen/Qwen3.5-4B')
    # Preprocess the dataset into the standard format
    dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
    # Encode the dataset
    dataset.encode()
    # Batch size 16 per step (the device mesh uses dp_size=2)
    dataloader = DataLoader(dataset=dataset, batch_size=16)
    # Use a MegatronModel
    model = MegatronModel(model_id='ms://Qwen/Qwen3.5-4B')
    lora_config = LoraConfig(r=8, lora_alpha=32, target_modules='all-linear')
    # Add a LoRA adapter named `default` to the model
    # Comment this out to use full-parameter training
    model.add_adapter_to_model('default', lora_config)
    # Add an optimizer for the `default` LoRA
    model.set_optimizer(optimizer_cls='default', lr=1e-4)
    # Add an LR scheduler for the `default` LoRA
    model.set_lr_scheduler(scheduler_cls='default', lr_warmup_steps=5, lr_decay_steps=len(dataloader))
    logger.info(get_device_placement())
    # Print the training config
    logger.info(model.get_train_configs())
    logger.info(f'Total steps: {len(dataloader)}')
    # Track the best eval loss seen so far
    loss_metric = 99.0
    # Approximate memory usage: LoRA ~10 GB x 8 GPUs, full-parameter ~40 GB x 8 GPUs
    for step, batch in enumerate(dataloader):
        # Run forward and backward
        model.forward_backward(inputs=batch)
        # Clip gradients and take an optimizer step
        model.clip_grad_and_step()
        if step % 5 == 0:
            # Log the training metric
            metric = model.calculate_metric(is_training=True)
            logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}')
        if step > 0 and step % 20 == 0:
            metrics = eval(model)
            logger.info(f'Eval metric: {metrics}')
            metrics['step'] = step
            # Save a checkpoint whenever the eval loss improves
            if loss_metric > float(metrics['loss']):
                model.save(f'checkpoint-{step}')
                loss_metric = float(metrics['loss'])
    # Always save a final checkpoint after the last step
    model.save('last-checkpoint')


if __name__ == '__main__':
    train()