
Commit 04f9f71
Commit message: wip
1 parent 3fa6f17 commit 04f9f71

File tree: 3 files changed (+124, -19 lines)

cookbook/megatron/ddp.py

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
import os

from peft import LoraConfig
from tqdm import tqdm

import twinkle
from twinkle import DeviceMesh, Platform
from twinkle import get_device_placement, get_logger
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.model import MegatronModel
from twinkle.preprocessor import SelfCognitionProcessor

if Platform.get_rank() == 0:
    # rank 0 records the run
    import swanlab
    swanlab.login(api_key=os.environ['SWANLAB_API_KEY'], save=True)

    run = swanlab.init(
        project="twinkle",
    )


# Construct a device mesh: tp=pp=cp=2, dp=1
device_mesh = DeviceMesh.from_sizes(dp_size=1, tp_size=2, pp_size=2, cp_size=2)
# Use torchrun mode
twinkle.initialize(mode='local', global_device_mesh=device_mesh)

logger = get_logger()


def eval(model):
    # 100 samples
    dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(100)))
    dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-7B-Instruct')
    dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
    dataset.encode()
    dataloader = DataLoader(dataset=dataset, batch_size=1)
    for step, batch in tqdm(enumerate(dataloader)):
        model.forward_only(inputs=batch)
        model.calculate_loss()
    metrics = model.calculate_metric(is_training=False)
    return metrics


def train():
    # 1000 samples
    dataset = Dataset(dataset_meta=DatasetMeta('ms://swift/self-cognition', data_slice=range(1000)))
    # Set the template to prepare encoding
    dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-7B-Instruct')
    # Preprocess the dataset into the standard format
    dataset.map(SelfCognitionProcessor('twinkle大模型', 'ModelScope社区'))
    # Encode the dataset
    dataset.encode()
    # Global batch size = 1, dp_size = 1
    dataloader = DataLoader(dataset=dataset, batch_size=1)
    # Use a MegatronModel
    model = MegatronModel(model_id='ms://Qwen/Qwen2.5-7B-Instruct')

    lora_config = LoraConfig(
        r=8,
        lora_alpha=32,
        target_modules='all-linear'
    )

    # Add a LoRA adapter to the model, named `default`
    model.add_adapter_to_model('default', lora_config, gradient_accumulation_steps=16)
    # Add an optimizer for LoRA `default`
    model.set_optimizer(optimizer_cls='default', lr=1e-4)
    # Add an LR scheduler for LoRA `default`
    model.set_lr_scheduler(scheduler_cls='default', num_warmup_steps=5, num_training_steps=len(dataloader))
    logger.info(get_device_placement())
    # Print the training config
    logger.info(model.get_train_configs())
    logger.info(f'Total steps: {len(dataloader)}')
    loss_metric = 99.0
    for step, batch in enumerate(dataloader):
        # Forward and backward pass
        model.forward_backward(inputs=batch)
        # Clip gradients and step the optimizer
        model.clip_grad_and_step()
        if step % 20 == 0:
            # Print the training metric
            metric = model.calculate_metric(is_training=True)
            if Platform.get_rank() == 0:
                swanlab.log(metric)
            logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}')
        if step > 0 and step % 40 == 0:
            metrics = eval(model)
            logger.info(f'Eval metric: {metrics}')
            metrics['step'] = step
            if loss_metric > float(metrics['loss']):
                model.save(f'checkpoint-{step}')
                loss_metric = float(metrics['loss'])
    model.save(f'last-checkpoint')


if __name__ == '__main__':
    train()
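Note that once add_adapter_to_model registers the single LoRA adapter 'default', none of the later calls in train() pass adapter_name explicitly; with the MegatronModel change in this commit (see _get_default_group below), the adapter is resolved automatically whenever only one optimizer group exists. A sketch of the fully explicit form, using only the adapter_name keyword these methods already pop from kwargs:

    # Equivalent explicit calls (sketch); adapter_name is optional when only one adapter is registered
    model.set_optimizer(optimizer_cls='default', lr=1e-4, adapter_name='default')
    model.set_lr_scheduler(scheduler_cls='default', num_warmup_steps=5,
                           num_training_steps=len(dataloader), adapter_name='default')
    metric = model.calculate_metric(is_training=True, adapter_name='default')
    model.save(f'checkpoint-{step}', adapter_name='default')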

cookbook/transformers/fsdp2.py

Lines changed: 2 additions & 2 deletions
@@ -17,7 +17,7 @@
     swanlab.login(api_key=os.environ['SWANLAB_API_KEY'], save=True)

     run = swanlab.init(
-        project="megatron-swift",
+        project="twinkle",
     )


@@ -92,7 +92,7 @@ def train():
             if loss_metric > float(metrics['loss']):
                 model.save(f'checkpoint-{step}')
                 loss_metric = float(metrics['loss'])
-    model.save(f'last-checkpoint', adapter_name='default')
+    model.save(f'last-checkpoint')


 if __name__ == '__main__':

src/twinkle/model/megatron/megatron.py

Lines changed: 23 additions & 17 deletions
@@ -185,8 +185,8 @@ def __init__(
         self._model_wrapped = False
         # This correctly handles vocab sharding in Tensor Parallelism
         self.optimizer_group: Dict[str, MegatronOptimizerGroup] = {_default_adapter_name: self._construct_default_optimizer_group()}
-        MegatronPeft().patch()
-
+        self.active_group = _default_adapter_name
+        MegatronPeft().__call__()

     def _construct_default_optimizer_group(self):
         return MegatronOptimizerGroup(
@@ -230,6 +230,12 @@ def _lazy_wrap_model(self):
             self.model = self.strategy.wrap_model(self.model)
             self._model_wrapped = True

+    def _get_default_group(self):
+        """Return the only registered optimizer group if there is exactly one, otherwise the active group."""
+        if len(self.optimizer_group) == 1:
+            return next(iter(self.optimizer_group))
+        return self.active_group
+
     @staticmethod
     def _not_encoded(inputs):
         assert isinstance(inputs, dict)
@@ -299,7 +305,7 @@ def forward_backward(self,
         from megatron.core.pipeline_parallel import get_forward_backward_func
         from megatron.core import parallel_state as mpu

-        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        adapter_name = kwargs.pop('adapter_name', self._get_default_group())
         forward_only = kwargs.pop('forward_only', False)
         optimizer_config = self.optimizer_group[adapter_name]
         loss_instance = self.optimizer_group[adapter_name].loss_instance
@@ -465,7 +471,7 @@ def step(self, **kwargs):
         Args:
             **kwargs: Additional arguments.
         """
-        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        adapter_name = kwargs.pop('adapter_name', self._get_default_group())
         optimizer_config = self.optimizer_group[adapter_name]

         if not optimizer_config.do_grad_sync(
@@ -503,7 +509,7 @@ def zero_grad(self, **kwargs):
         Args:
             **kwargs: Additional arguments.
         """
-        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        adapter_name = kwargs.pop('adapter_name', self._get_default_group())
         optimizer_config = self.optimizer_group[adapter_name]

         # For DDP-wrapped models, ALWAYS zero the gradient buffer
@@ -528,7 +534,7 @@ def lr_step(self, **kwargs):
         Args:
             **kwargs: Additional arguments.
         """
-        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        adapter_name = kwargs.pop('adapter_name', self._get_default_group())
         optimizer_config = self.optimizer_group[adapter_name]

         if not optimizer_config.do_grad_sync(
@@ -557,7 +563,7 @@ def set_loss(self, loss_cls: Union[Loss, Type[Loss], str, Callable[[InputFeature
             loss_cls: Loss class or string name (not used for Megatron).
             **kwargs: Additional arguments.
         """
-        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        adapter_name = kwargs.pop('adapter_name', self._get_default_group())
         optimizer_config = self.optimizer_group[adapter_name]
         optimizer_config.loss_instance = construct_class(loss_cls, Loss, twinkle.loss, **kwargs)

@@ -571,7 +577,7 @@ def add_metric(self, metric_cls: Union[Metric, str], is_training: Optional[bool]
             adapter_name: Lora adapter name.
             Any parameters needed to construct the metric_cls instance.
         """
-        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        adapter_name = kwargs.pop('adapter_name', self._get_default_group())
         optimizer_config = self.optimizer_group[adapter_name]
         kwargs['device_mesh'] = self.device_mesh
         kwargs['process_group'] = optimizer_config._dp_group
@@ -593,7 +599,7 @@ def set_optimizer(self, optimizer_cls: Union[Optimizer, Type[Optimizer], str],
                 - For standard optimizers: lr, weight_decay, etc.
                 - For MegatronDistributed: use_distributed_optimizer, clip_grad, etc.
         """
-        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        adapter_name = kwargs.pop('adapter_name', self._get_default_group())
         optimizer_config = self.optimizer_group[adapter_name]
         if not self._model_wrapped:
             self.model = self.strategy.wrap_model(self.model)
@@ -611,7 +617,7 @@ def _accumulate_metric(optimizer_config: MegatronOptimizerGroup, is_training):

     @remote_function(collect='first', lazy_collect=False)
     def calculate_metric(self, is_training, **kwargs):
-        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        adapter_name = kwargs.pop('adapter_name', self._get_default_group())
         optimizer_config = self.optimizer_group[adapter_name]
         return optimizer_config.calculate_metrics(is_training)

@@ -715,7 +721,7 @@ def set_lr_scheduler(self, scheduler_cls: Union[LRScheduler, Type[LRScheduler],
             scheduler_cls: Scheduler class or string name.
             **kwargs: Additional arguments.
         """
-        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        adapter_name = kwargs.pop('adapter_name', self._get_default_group())
         optimizer_config = self.optimizer_group[adapter_name]
         optimizer = optimizer_config.optimizer
         if not scheduler_cls or scheduler_cls in ('OptimizerParamScheduler', 'default'):
@@ -738,7 +744,7 @@ def save(self, name: Optional[str] = None, output_dir: Optional[str] = None, int
             interval: Save each interval steps.
             **kwargs: Additional arguments.
         """
-        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        adapter_name = kwargs.pop('adapter_name', self._get_default_group())
         optimizer_config = self.optimizer_group[adapter_name]
         if optimizer_config.cur_step % interval != 0:
             return
@@ -772,7 +778,7 @@ def load(self, name: str, output_dir: Optional[str] = None, **kwargs):
             checkpoint_dir = HubOperation.download_model(name, token=token)
         else:
             checkpoint_dir = os.path.join(output_dir, name)
-        adapter_name = kwargs.get('adapter_name')
+        adapter_name = kwargs.get('adapter_name', self._get_default_group())
         bridge = self._bridge
         for _model in self.strategy.unwrap_model(self.model):
             bridge.load_weights(_model, checkpoint_dir, is_peft_format = (adapter_name != _default_adapter_name))
@@ -860,7 +866,7 @@ def get_state_dict(self, **kwargs):
         Returns:
             State dict of trainable parameters.
         """
-        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        adapter_name = kwargs.pop('adapter_name', self._get_default_group())
         return self._get_trainable_parameters(adapter_name)

     def get_hf_state_dict(self, adapter_name: str = '') -> Generator[Tuple[str, torch.Tensor], None, None]:
@@ -988,7 +994,7 @@ def set_template(self, template_cls: Union[Template, Type[Template], str], **kwa
             template_cls: Template class or string name.
             **kwargs: Additional arguments.
         """
-        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        adapter_name = kwargs.pop('adapter_name', self._get_default_group())
         optimizer_config = self.optimizer_group[adapter_name]
         optimizer_config.template = construct_class(template_cls, Template, twinkle.template, **kwargs)

@@ -1000,7 +1006,7 @@ def set_processor(self, processor_cls: Union[InputProcessor, Type[InputProcessor
             processor_cls: Processor class or string name.
             **kwargs: Additional arguments.
         """
-        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        adapter_name = kwargs.pop('adapter_name', self._get_default_group())
         optimizer_config = self.optimizer_group[adapter_name]
         kwargs['framework'] = 'megatron'
         optimizer_config.processor = construct_class(processor_cls, InputProcessor, twinkle.processor, **kwargs)
@@ -1015,7 +1021,7 @@ def get_train_configs(self, **kwargs):
         Returns:
             Configuration summary string.
         """
-        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        adapter_name = kwargs.pop('adapter_name', self._get_default_group())
         optimizer_config = self.optimizer_group[adapter_name]

         expr = f'Backend: Megatron-Core\n'
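Every hunk above swaps the hard-coded _default_adapter_name fallback for self._get_default_group(), so adapter-scoped calls (step, zero_grad, save, load, and so on) resolve to the single registered optimizer group when only one exists, and to self.active_group otherwise; an explicitly passed adapter_name still takes precedence. A minimal, self-contained sketch of that resolution logic (the _GroupResolver class and its step method are illustrative stand-ins; only _get_default_group mirrors the code above):

    class _GroupResolver:
        def __init__(self):
            # name -> optimizer group, mirroring MegatronModel.optimizer_group
            self.optimizer_group = {'default': object()}
            self.active_group = 'default'

        def _get_default_group(self):
            # Same logic as the method added in this commit
            if len(self.optimizer_group) == 1:
                return next(iter(self.optimizer_group))
            return self.active_group

        def step(self, **kwargs):
            # Same fallback pattern used by the patched methods above
            return kwargs.pop('adapter_name', self._get_default_group())


    resolver = _GroupResolver()
    assert resolver.step() == 'default'                        # one group: adapter_name can be omitted
    resolver.optimizer_group['second'] = object()
    resolver.active_group = 'second'
    assert resolver.step() == 'second'                         # several groups: falls back to active_group
    assert resolver.step(adapter_name='default') == 'default'  # an explicit name still wins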
