
Commit 82a766b

wip
1 parent 8a8d02d commit 82a766b

File tree: 2 files changed, +281 -0 lines changed


cookbook/sft/multi_lora.py

Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
import os
from peft import LoraConfig

import twinkle
from twinkle import DeviceMesh, get_device_placement, get_logger
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.model import MultiLoraTransformersModel
from twinkle.preprocessor import SelfCognitionProcessor

logger = get_logger()

MODEL_ID = os.getenv('MODEL_ID', 'ms://Qwen/Qwen2.5-7B-Instruct')
DATASET_ID = os.getenv('DATASET_ID', 'ms://swift/self-cognition')
OUTPUT_DIR = os.getenv('OUTPUT_DIR', 'output/multi_lora')

TRAIN_SAMPLES = int(os.getenv('TRAIN_SAMPLES', '1000'))
BATCH_SIZE = int(os.getenv('BATCH_SIZE', '8'))
EPOCHS = int(os.getenv('EPOCHS', '1'))
GRAD_ACC_STEPS = int(os.getenv('GRAD_ACC_STEPS', '2'))
MAX_LENGTH = int(os.getenv('MAX_LENGTH', '1024'))
MAX_LORAS = int(os.getenv('MAX_LORAS', '4'))
MAX_R = int(os.getenv('MAX_R', '32'))

LOG_INTERVAL = int(os.getenv('LOG_INTERVAL', '20'))
SAVE_EVERY_EPOCH = os.getenv('SAVE_EVERY_EPOCH', '1') == '1'


def build_device_mesh():
    world_size = int(os.getenv('WORLD_SIZE', '1'))
    if world_size <= 1:
        return None
    # MultiLora + FSDP path: an FSDP dimension > 1 forces the native-FSDP path in the model.
    return DeviceMesh.from_sizes(world_size=world_size, fsdp_size=world_size, dp_size=1)


def create_dataloader(device_mesh):
    dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(TRAIN_SAMPLES)))
    dataset.set_template('Template', model_id=MODEL_ID, max_length=MAX_LENGTH)
    # Fill in the dataset's name/author placeholders ('twinkle模型' = 'twinkle model',
    # 'twinkle团队' = 'twinkle team').
    dataset.map(SelfCognitionProcessor('twinkle模型', 'twinkle团队'))
    dataset.encode(batched=True)
    return DataLoader(dataset=dataset, batch_size=BATCH_SIZE, device_mesh=device_mesh)


def setup_multi_lora_model(device_mesh, total_steps):
    model = MultiLoraTransformersModel(
        model_id=MODEL_ID,
        device_mesh=device_mesh,
        max_loras=MAX_LORAS,
        max_r=MAX_R,
    )

    # Two tenants with independent optimizer/scheduler states.
    tenant_settings = {
        'tenant_a': {
            'config': LoraConfig(r=8, lora_alpha=32, target_modules='all-linear'),
            'lr': 1e-4,
        },
        'tenant_b': {
            'config': LoraConfig(r=16, lora_alpha=32, target_modules='all-linear'),
            'lr': 8e-5,
        },
    }

    # Ceil division: each adapter gets roughly total_steps / num_adapters optimizer steps.
    steps_per_adapter = max(1, (total_steps + len(tenant_settings) - 1) // len(tenant_settings))
    warmup_steps = max(1, steps_per_adapter // 10)
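    # Worked example with the defaults above: TRAIN_SAMPLES=1000 / BATCH_SIZE=8 gives
    # 125 steps per epoch, so total_steps=125 for EPOCHS=1; with two tenants this yields
    # steps_per_adapter = (125 + 1) // 2 = 63 and warmup_steps = 63 // 10 = 6.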

    for adapter_name, settings in tenant_settings.items():
        model.add_adapter_to_model(
            adapter_name,
            settings['config'],
            gradient_accumulation_steps=GRAD_ACC_STEPS,
        )
        model.set_template('Template', model_id=MODEL_ID, max_length=MAX_LENGTH, adapter_name=adapter_name)
        model.set_processor('InputProcessor', padding_side='right', adapter_name=adapter_name)
        model.set_loss('CrossEntropyLoss', adapter_name=adapter_name)
        model.set_optimizer('AdamW', lr=settings['lr'], adapter_name=adapter_name)
        model.set_lr_scheduler(
            'CosineWarmupScheduler',
            num_warmup_steps=warmup_steps,
            num_training_steps=steps_per_adapter,
            adapter_name=adapter_name,
        )

    return model, list(tenant_settings.keys())


def train():
    device_mesh = build_device_mesh()
    twinkle.initialize(mode='local', global_device_mesh=device_mesh, lazy_collect=False)
    dataloader = create_dataloader(device_mesh)
    total_steps = len(dataloader) * EPOCHS
    model, adapters = setup_multi_lora_model(device_mesh, total_steps=total_steps)

    logger.info(get_device_placement())
    for adapter_name in adapters:
        logger.info(model.get_train_configs(adapter_name=adapter_name))

    global_step = 0
    for epoch in range(EPOCHS):
        for batch in dataloader:
            # Round-robin: adapters take turns, one optimizer step each.
            adapter_name = adapters[global_step % len(adapters)]
            loss = model.forward_backward(inputs=batch, adapter_name=adapter_name)
            model.clip_grad_and_step(max_grad_norm=1.0, adapter_name=adapter_name)

            if global_step % LOG_INTERVAL == 0:
                metric = model.calculate_metric(is_training=True, adapter_name=adapter_name)
                logger.info(
                    f'epoch={epoch}, global_step={global_step}, adapter={adapter_name}, '
                    f'loss={loss}, metric={metric}'
                )
            global_step += 1

        if SAVE_EVERY_EPOCH:
            for adapter_name in adapters:
                model.save(
                    name=f'{adapter_name}-epoch-{epoch}',
                    output_dir=OUTPUT_DIR,
                    save_optimizer=True,
                    adapter_name=adapter_name,
                )

    for adapter_name in adapters:
        model.save(
            name=f'{adapter_name}-final',
            output_dir=OUTPUT_DIR,
            save_optimizer=True,
            adapter_name=adapter_name,
        )


if __name__ == '__main__':
    train()
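
If the saved tenants are standard PEFT adapter folders, they can be loaded for inference with peft directly. A minimal sketch, assuming the final checkpoint lands at output/multi_lora/tenant_a-final (the exact layout model.save produces is not shown in this diff) and using the Hugging Face hub id for the base model:

from peft import PeftModel
from transformers import AutoModelForCausalLM

# Hypothetical checkpoint path; adjust to whatever model.save actually wrote.
base = AutoModelForCausalLM.from_pretrained('Qwen/Qwen2.5-7B-Instruct')
model = PeftModel.from_pretrained(base, 'output/multi_lora/tenant_a-final')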
Lines changed: 148 additions & 0 deletions
@@ -0,0 +1,148 @@
import os
from peft import LoraConfig

import twinkle
from twinkle import DeviceMesh, get_device_placement, get_logger
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.model import MultiLoraMegatronModel
from twinkle.preprocessor import SelfCognitionProcessor

logger = get_logger()

MODEL_ID = os.getenv('MODEL_ID', 'ms://Qwen/Qwen2.5-7B-Instruct')
DATASET_ID = os.getenv('DATASET_ID', 'ms://swift/self-cognition')
OUTPUT_DIR = os.getenv('OUTPUT_DIR', 'output/multi_lora_megatron')

TRAIN_SAMPLES = int(os.getenv('TRAIN_SAMPLES', '1000'))
BATCH_SIZE = int(os.getenv('BATCH_SIZE', '16'))
EPOCHS = int(os.getenv('EPOCHS', '1'))
GRAD_ACC_STEPS = int(os.getenv('GRAD_ACC_STEPS', '1'))
MAX_LENGTH = int(os.getenv('MAX_LENGTH', '1024'))
MAX_LORAS = int(os.getenv('MAX_LORAS', '4'))
MAX_R = int(os.getenv('MAX_R', '32'))
LOG_INTERVAL = int(os.getenv('LOG_INTERVAL', '10'))
SWITCH_EVERY = int(os.getenv('SWITCH_EVERY', '1'))
SAVE_EVERY_EPOCH = os.getenv('SAVE_EVERY_EPOCH', '1') == '1'
MIXED_PRECISION = os.getenv('MIXED_PRECISION', 'bf16')

DP_SIZE = int(os.getenv('DP_SIZE', '1'))
TP_SIZE = int(os.getenv('TP_SIZE', '1'))
PP_SIZE = int(os.getenv('PP_SIZE', '1'))
CP_SIZE = int(os.getenv('CP_SIZE', '1'))
EP_SIZE = int(os.getenv('EP_SIZE', '1'))
SEQUENCE_PARALLEL = os.getenv('SEQUENCE_PARALLEL', '0') == '1'
USE_DISTRIBUTED_OPTIMIZER = os.getenv('USE_DISTRIBUTED_OPTIMIZER', '1') == '1'


def build_device_mesh() -> DeviceMesh:
    kwargs = dict(
        dp_size=DP_SIZE,
        tp_size=TP_SIZE,
        pp_size=PP_SIZE,
        cp_size=CP_SIZE,
        sequence_parallel=SEQUENCE_PARALLEL,
    )
    # Only pass ep_size when expert parallelism is actually requested.
    if EP_SIZE > 1:
        kwargs['ep_size'] = EP_SIZE
    return DeviceMesh.from_sizes(**kwargs)
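

# Note (assumption, not verified in this diff): with Megatron-style parallelism the
# product DP_SIZE * TP_SIZE * PP_SIZE * CP_SIZE is normally expected to match the
# launch world size; whether DeviceMesh.from_sizes validates this itself is not shown here.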


def create_dataloader(device_mesh: DeviceMesh):
    dataset = Dataset(dataset_meta=DatasetMeta(DATASET_ID, data_slice=range(TRAIN_SAMPLES)))
    dataset.set_template('Template', model_id=MODEL_ID, max_length=MAX_LENGTH)
    dataset.map(SelfCognitionProcessor('twinkle模型', 'twinkle团队'))
    dataset.encode(batched=True)
    return DataLoader(dataset=dataset, batch_size=BATCH_SIZE, device_mesh=device_mesh)


def setup_model(device_mesh: DeviceMesh, total_steps: int):
    model = MultiLoraMegatronModel(
        model_id=MODEL_ID,
        device_mesh=device_mesh,
        mixed_precision=MIXED_PRECISION,
        max_loras=MAX_LORAS,
        max_r=MAX_R,
        use_distributed_optimizer=USE_DISTRIBUTED_OPTIMIZER,
    )

    # Two tenants with independent optimizer/scheduler states.
    tenant_settings = {
        'tenant_a': {
            'config': LoraConfig(r=8, lora_alpha=32, target_modules='all-linear'),
            'lr': 1e-4,
        },
        'tenant_b': {
            'config': LoraConfig(r=16, lora_alpha=32, target_modules='all-linear'),
            'lr': 8e-5,
        },
    }
    steps_per_adapter = max(1, (total_steps + len(tenant_settings) - 1) // len(tenant_settings))
    warmup_steps = max(1, steps_per_adapter // 10)

    for adapter_name, settings in tenant_settings.items():
        model.add_adapter_to_model(
            adapter_name,
            settings['config'],
            gradient_accumulation_steps=GRAD_ACC_STEPS,
        )
        model.set_template('Template', model_id=MODEL_ID, max_length=MAX_LENGTH, adapter_name=adapter_name)
        model.set_processor('InputProcessor', padding_side='right', adapter_name=adapter_name)
        model.set_loss('CrossEntropyLoss', adapter_name=adapter_name)
        model.set_optimizer('default', lr=settings['lr'], adapter_name=adapter_name)
        model.set_lr_scheduler(
            'default',
            lr_warmup_steps=warmup_steps,
            lr_decay_steps=steps_per_adapter,
            adapter_name=adapter_name,
        )

    return model, list(tenant_settings.keys())


def train():
    device_mesh = build_device_mesh()
    twinkle.initialize(mode='local', global_device_mesh=device_mesh, lazy_collect=False)

    dataloader = create_dataloader(device_mesh)
    total_steps = len(dataloader) * EPOCHS
    model, adapters = setup_model(device_mesh, total_steps)

    logger.info(get_device_placement())
    for adapter_name in adapters:
        logger.info(model.get_train_configs(adapter_name=adapter_name))

    global_step = 0
    for epoch in range(EPOCHS):
        for batch in dataloader:
            # Rotate adapters in blocks of SWITCH_EVERY consecutive steps.
            adapter_name = adapters[(global_step // SWITCH_EVERY) % len(adapters)]
            loss = model.forward_backward(inputs=batch, adapter_name=adapter_name)
            model.clip_grad_and_step(adapter_name=adapter_name)

            if global_step % LOG_INTERVAL == 0:
                metric = model.calculate_metric(is_training=True, adapter_name=adapter_name)
                logger.info(
                    f'epoch={epoch}, global_step={global_step}, adapter={adapter_name}, '
                    f'loss={loss}, metric={metric}'
                )
            global_step += 1

        if SAVE_EVERY_EPOCH:
            for adapter_name in adapters:
                checkpoint_dir = model.save(
                    name=f'{adapter_name}-epoch-{epoch}',
                    output_dir=OUTPUT_DIR,
                    adapter_name=adapter_name,
                )
                logger.info(f'saved checkpoint: {checkpoint_dir}')

    for adapter_name in adapters:
        checkpoint_dir = model.save(
            name=f'{adapter_name}-final',
            output_dir=OUTPUT_DIR,
            adapter_name=adapter_name,
        )
        logger.info(f'saved checkpoint: {checkpoint_dir}')


if __name__ == '__main__':
    train()
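
Unlike the first script's fixed per-step round-robin, SWITCH_EVERY lets each adapter keep the GPU for a block of consecutive steps before handing over. A standalone sketch of the schedule logic (plain Python, no twinkle dependency; the adapter names mirror the tenants above):

adapters = ['tenant_a', 'tenant_b']

def adapter_for_step(step, switch_every=2):
    # Same indexing as the training loop above: blocks of `switch_every`
    # consecutive steps go to one adapter before rotating to the next.
    return adapters[(step // switch_every) % len(adapters)]

# steps 0-1 -> tenant_a, steps 2-3 -> tenant_b, steps 4-5 -> tenant_a, ...
print([adapter_for_step(s) for s in range(6)])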
