
Commit 356612f: wip

Parent: b5c535b

File tree: 8 files changed, +106 −102 lines

cookbook/components/dataset.py

Lines changed: 4 additions & 0 deletions
from twinkle.dataset import Dataset


dataset = Dataset('ms://swift/self-cognition')
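
The component file only constructs the dataset; the fsdp2.py recipe below shows the intended preparation flow. As a minimal sketch reusing that recipe's calls (the template name and processor arguments are copied from it, not defined here):

from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset
from twinkle.preprocessor import SelfCognitionProcessor

dataset = Dataset('ms://swift/self-cognition')
# Attach a chat template so the samples can be encoded
dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-7B-Instruct')
# Rewrite the self-cognition records with a model name and vendor
dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
# Tokenize the mapped samples
dataset.encode()
dataloader = DataLoader(dataset=dataset, batch_size=4)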

cookbook/sft/single_program.py

Lines changed: 0 additions & 97 deletions
This file was deleted.

cookbook/transformers/fsdp2.py

Lines changed: 99 additions & 0 deletions
import os

from peft import LoraConfig
from tqdm import tqdm

import twinkle
from twinkle import DeviceMesh, Platform
from twinkle import get_device_placement, get_logger
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.model import TransformersModel
from twinkle.preprocessor import SelfCognitionProcessor

if Platform.get_rank() == 0:
    # Experiment tracking runs on rank 0 only
    import swanlab
    swanlab.login(api_key=os.environ['SWANLAB_API_KEY'], save=True)

    run = swanlab.init(
        project="megatron-swift",
    )


# Construct a device mesh: dp=2, fsdp=2 (4 GPUs in total)
device_mesh = DeviceMesh.from_sizes(dp_size=2, fsdp_size=2)
# Use torchrun-style local mode (one process per rank)
twinkle.initialize(mode='local', global_device_mesh=device_mesh)

logger = get_logger()


def evaluate(model):
    # 100 samples
    dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100)))
    dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-7B-Instruct')
    dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
    dataset.encode()
    dataloader = DataLoader(dataset=dataset, batch_size=4)
    for batch in tqdm(dataloader):
        model.forward_only(inputs=batch)
        model.calculate_loss()
    metrics = model.calculate_metric(is_training=False)
    return metrics


def train():
    # 1000 samples
    dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000)))
    # Set the template to prepare for encoding
    dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-7B-Instruct')
    # Preprocess the dataset into the standard format
    dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
    # Encode the dataset
    dataset.encode()
    # Global batch size = 4 across 4 GPUs, so 1 sample per GPU
    dataloader = DataLoader(dataset=dataset, batch_size=4)
    # Use a TransformersModel
    model = TransformersModel(model_id='ms://Qwen/Qwen2.5-7B-Instruct')

    lora_config = LoraConfig(
        r=8,
        lora_alpha=32,
        target_modules='all-linear'
    )

    # Add a LoRA adapter named `default` to the model
    model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=4)
    # Add an optimizer for the `default` adapter
    model.set_optimizer(optimizer_cls='AdamW', lr=1e-4)
    # Add an LR scheduler for the `default` adapter
    model.set_lr_scheduler(scheduler_cls='CosineWarmupScheduler', num_warmup_steps=5, num_training_steps=len(dataloader))
    logger.info(get_device_placement())
    # Print the training config
    logger.info(model.get_train_configs())
    logger.info(f'Total steps: {len(dataloader)}')
    loss_metric = 99.0
    for step, batch in enumerate(dataloader):
        # Forward and backward pass
        model.forward_backward(inputs=batch)
        # Clip gradients and take an optimizer step
        model.clip_grad_and_step()
        if step % 20 == 0:
            # Log training metrics
            metric = model.calculate_metric(is_training=True)
            if Platform.get_rank() == 0:
                swanlab.log(metric)
            logger.info(f'Step {step} of {len(dataloader)}, metric: {metric}')
        if step > 0 and step % 40 == 0:
            metrics = evaluate(model)
            logger.info(f'Eval metric: {metrics}')
            metrics['step'] = step
            # Save a checkpoint whenever the eval loss improves
            if loss_metric > float(metrics['loss']):
                model.save(f'checkpoint-{step}')
                loss_metric = float(metrics['loss'])
    model.save('last-checkpoint', adapter_name='default')


if __name__ == '__main__':
    train()
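
A note on launching: the mesh above is dp=2 × fsdp=2, so the recipe expects four processes, one per GPU, with SWANLAB_API_KEY exported for rank 0. Based on the script's own torchrun comment, an invocation like `torchrun --nproc_per_node=4 cookbook/transformers/fsdp2.py` would fit, though the exact launch command is an assumption and not part of this commit.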
File renamed without changes.
File renamed without changes.

src/twinkle/model/transformers/transformers.py

Lines changed: 3 additions & 5 deletions
@@ -200,6 +200,7 @@ def __init__(self, # noqa
         self.sp_strategy = None
         self._model_wrapped = False
         self.optimizer_group: Dict[str, OptimizerGroup] = {_default_adapter_name: self._construct_default_optimizer_group()}
+        self.active_group = _default_adapter_name

@@ -243,11 +244,7 @@ def _get_default_group(self):
         """Get the only group that has an optimizer, else return the default one"""
         if len(self.optimizer_group) == 1:
             return next(iter(self.optimizer_group))
-        names = [name for name, og in self.optimizer_group.items() if og.optimizer is not None]
-        if names:
-            assert len(names) == 1, 'Only one group is supported.'
-            return names[0]
-        return _default_adapter_name
+        return self.active_group

     @staticmethod
     def _not_encoded(inputs):

@@ -905,6 +902,7 @@ def _patch_adapter(self, adapter_name: str, config_or_dir: Union[PeftConfig, str
         _gas_default = kwargs.get('gradient_accumulation_steps', 1)
         self.optimizer_group[adapter_name].gradient_accumulation_steps = _gas_default
         self._default_tokenizer = self.optimizer_group[adapter_name].template.processor
+        self.active_group = adapter_name

     @remote_function()
     def add_adapter_to_model(self, adapter_name: str, config_or_dir: Union[PeftConfig, str], **kwargs):
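
The net effect of this change: instead of scanning optimizer_group for the single group that owns an optimizer, the model now records whichever adapter was patched in last and resolves the default group to it. A standalone sketch of the before/after resolution logic (OptimizerGroup here is a stand-in dataclass, and the free functions compress what are methods on TransformersModel):

from dataclasses import dataclass
from typing import Dict, Optional

_default_adapter_name = 'default'


@dataclass
class OptimizerGroup:  # stand-in for twinkle's OptimizerGroup
    optimizer: Optional[str] = None


def old_default(groups: Dict[str, OptimizerGroup]) -> str:
    # Pre-commit: scan for the single group that owns an optimizer
    if len(groups) == 1:
        return next(iter(groups))
    names = [n for n, og in groups.items() if og.optimizer is not None]
    if names:
        assert len(names) == 1, 'Only one group is supported.'
        return names[0]
    return _default_adapter_name


def new_default(groups: Dict[str, OptimizerGroup], active_group: str) -> str:
    # Post-commit: trust the most recently patched adapter
    if len(groups) == 1:
        return next(iter(groups))
    return active_group


groups = {_default_adapter_name: OptimizerGroup(), 'lora_a': OptimizerGroup('AdamW')}
assert old_default(groups) == 'lora_a'
assert new_default(groups, active_group='lora_a') == 'lora_a'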
