Commit 266cf83

[feat] Add initial Megatron support #2
2 parents 3d72cfe + 2b7b4b8

File tree

22 files changed: 7314 additions, 6 deletions


.gitignore

Lines changed: 1 addition & 1 deletion
@@ -35,6 +35,7 @@ wheels/
 /package
 /temp
 MANIFEST
+.locks/

 # PyInstaller
 # Usually these files are written by a python script from a template
@@ -93,7 +94,6 @@ celerybeat-schedule
 *.sage.py

 # Environments
-.locks
 .env
 .venv
 env/

cookbook/megatron/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+# Copyright (c) twinkle authors. All rights reserved.

cookbook/megatron/lora.py

Lines changed: 195 additions & 0 deletions
@@ -0,0 +1,195 @@
# Copyright (c) twinkle authors. All rights reserved.
"""Megatron-Core LoRA training example.

Supports both local (torchrun) and Ray execution modes.

Usage (Local mode):
    torchrun --nproc_per_node=4 cookbook/megatron/lora.py --tp_size 2 --pp_size 2

Usage (Ray mode):
    TRUST_REMOTE_CODE=1 python cookbook/megatron/lora.py --mode ray --tp_size 2 --pp_size 2 --num_gpus 4
"""
import argparse
import os

import numpy as np
# CRITICAL: Set CUDA device before any CUDA imports (local mode only)
import torch
from peft import LoraConfig
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR

import twinkle
from twinkle import (DeviceGroup, DeviceMesh, Platform, get_device_placement,
                     get_logger)
from twinkle.dataloader import DataLoader
from twinkle.dataset import Dataset, DatasetMeta
from twinkle.loss import MegatronCrossEntropyLoss
from twinkle.model import MegatronModel
from twinkle.processor import InputProcessor

# Parse arguments first to determine mode
parser = argparse.ArgumentParser()
parser.add_argument('--mode',
                    type=str,
                    default='local',
                    choices=['local', 'ray'])
parser.add_argument('--tp_size', type=int, default=1)
parser.add_argument('--pp_size', type=int, default=1)
parser.add_argument('--cp_size', type=int, default=1)
parser.add_argument('--num_gpus',
                    type=int,
                    default=4,
                    help='Number of GPUs (Ray mode only)')
parser.add_argument('--max_steps', type=int, default=None)
parser.add_argument('--model',
                    type=str,
                    default='ms://Qwen/Qwen2.5-7B-Instruct')
GAS = 16  # gradient accumulation steps
args = parser.parse_args()

# Set mode in environment before importing twinkle
os.environ['TWINKLE_MODE'] = args.mode

if args.mode == 'local':
    LOCAL_RANK = int(os.environ.get('LOCAL_RANK', '0'))
    torch.cuda.set_device(LOCAL_RANK)

logger = get_logger()


def create_dataset():
    dataset = Dataset(
        dataset_meta=DatasetMeta('ms://modelscope/competition_math'))
    dataset.set_template('Qwen3Template',
                         model_id='ms://Qwen/Qwen2.5-7B-Instruct')
    dataset.map('CompetitionMathProcessor')
    dataset.encode(batched=True, load_from_cache_file=False)
    return dataset


def train():
    # Get parallelism config
    TP_SIZE = args.tp_size
    PP_SIZE = args.pp_size
    CP_SIZE = args.cp_size

    if args.mode == 'local':
        WORLD_SIZE = int(os.environ.get('WORLD_SIZE', '1'))
    else:
        WORLD_SIZE = args.num_gpus

    DP_SIZE = WORLD_SIZE // (TP_SIZE * PP_SIZE * CP_SIZE)

    # Device mesh: match Megatron's order "tp-cp-ep-dp-pp" from innermost to outermost
    device_mesh = DeviceMesh(
        device_type='cuda',
        mesh=np.arange(WORLD_SIZE).reshape(PP_SIZE, DP_SIZE, CP_SIZE, TP_SIZE),
        mesh_dim_names=('pp', 'dp', 'cp', 'tp'),
    )
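    # e.g. WORLD_SIZE=4 with --tp_size 2 --pp_size 2 (so DP_SIZE = CP_SIZE = 1):
    # np.arange(4).reshape(2, 1, 1, 2) puts ranks [0, 1] in pp group 0 and
    # ranks [2, 3] in pp group 1, with tp as the innermost (fastest-varying) axis.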

    # Device group name - used as remote_group in Ray mode
    GROUP_NAME = 'model'

    device_group = [
        DeviceGroup(
            name=GROUP_NAME,
            ranks=list(range(WORLD_SIZE)),
            device_type=Platform.get_platform().device_prefix(),
        )
    ]

    twinkle.initialize(
        mode=args.mode,
        nproc_per_node=WORLD_SIZE,
        groups=device_group,
        global_device_mesh=device_mesh,
        lazy_collect=False,
    )

    # Use smaller batch size for single GPU to avoid OOM
    batch_size = 2 if WORLD_SIZE == 1 else 8

    # In Ray mode, pass remote_group and device_mesh
    if args.mode == 'ray':
        dataloader = DataLoader(
            dataset=create_dataset,
            batch_size=batch_size,
            remote_group=GROUP_NAME,
            device_mesh=device_mesh,
        )
        model = MegatronModel(
            pretrained_model_name_or_path=args.model,
            tensor_model_parallel_size=TP_SIZE,
            pipeline_model_parallel_size=PP_SIZE,
            context_parallel_size=CP_SIZE,
            mixed_precision='bf16',
            recompute_granularity='full' if WORLD_SIZE <= 2 else 'selective',
            remote_group=GROUP_NAME,
            device_mesh=device_mesh,
        )
    else:
        dataloader = DataLoader(dataset=create_dataset, batch_size=batch_size)
        model = MegatronModel(
            pretrained_model_name_or_path=args.model,
            tensor_model_parallel_size=TP_SIZE,
            pipeline_model_parallel_size=PP_SIZE,
            context_parallel_size=CP_SIZE,
            mixed_precision='bf16',
            recompute_granularity='full' if WORLD_SIZE <= 2 else 'selective',
        )
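        # Same configuration as the Ray branch, minus remote_group/device_mesh,
        # which only apply when Ray coordinates the workers.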

    lora_config = LoraConfig(target_modules='all-linear')
    adapter_name = 'lora'
    model.add_adapter_to_model(adapter_name,
                               lora_config,
                               gradient_accumulation_steps=GAS)
    model.set_template('Qwen3Template', adapter_name=adapter_name)
    model.set_processor(InputProcessor,
                        padding_side='right',
                        adapter_name=adapter_name)
    model.set_loss(MegatronCrossEntropyLoss, adapter_name=adapter_name)
    model.set_optimizer(AdamW, lr=1e-4, adapter_name=adapter_name)
    model.set_lr_scheduler(LinearLR, adapter_name=adapter_name)

    logger.info(get_device_placement())
    logger.info(model.get_train_configs(adapter_name=adapter_name))

    for step, batch in enumerate(dataloader):
        output = model.forward_backward(inputs=batch,
                                        adapter_name=adapter_name)
        if step % GAS == 0:
            logger.info(f'Step {step // GAS}, loss: {output}')
            model.clip_grad_norm(1.0, adapter_name=adapter_name)
            model.step(adapter_name=adapter_name)
            model.zero_grad(adapter_name=adapter_name)
            model.lr_step(adapter_name=adapter_name)
        if step > 0 and step % (100 * GAS) == 0:
            model.save('./output/megatron_lora', adapter_name=adapter_name)
        # Early stop for testing
        if args.max_steps and step >= args.max_steps * GAS:
            logger.info(f'Reached max_steps ({args.max_steps}), stopping.')
            break
    model.save('./output/megatron_lora', adapter_name=adapter_name)
    logger.info('Training completed!')


def cleanup():
    """Clean up distributed resources."""
    import torch.distributed as dist
    try:
        if dist.is_initialized():
            dist.barrier()
        from megatron.core import parallel_state as mpu
        if mpu.is_initialized():
            mpu.destroy_model_parallel()
    except Exception as e:
        logger.warning(f'Error during cleanup: {e}')
    if dist.is_initialized():
        dist.destroy_process_group()


if __name__ == '__main__':
    try:
        train()
    finally:
        cleanup()
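
A note on the cadence in train()'s loop: forward_backward runs on every micro-batch, the clip/step/zero_grad/lr_step group fires once every GAS micro-batches, and checkpoints land every 100 optimizer steps. A minimal sketch of that arithmetic in plain Python (the dp_size and per-rank batch-size figures are illustrative assumptions, not values stated by the diff):

    GAS = 16            # gradient accumulation steps, as in the script
    batch_size = 8      # the script's multi-GPU micro-batch size
    dp_size = 1         # e.g. 4 GPUs with tp=2, pp=2 -> 4 // (2 * 2)

    # Optimizer steps fire at micro-batch 0, 16, 32, ... (step % GAS == 0).
    optimizer_steps = [s for s in range(49) if s % GAS == 0]
    assert optimizer_steps == [0, 16, 32, 48]

    # If batch_size is per data-parallel rank (an assumption), each optimizer
    # step consumes batch_size * dp_size * GAS samples.
    assert batch_size * dp_size * GAS == 128

    # Checkpointing: step > 0 and step % (100 * GAS) == 0, i.e. every 100
    # optimizer steps; --max_steps N stops after N * GAS micro-batches.

For a quick smoke test, --max_steps caps the run, e.g. torchrun --nproc_per_node=1 cookbook/megatron/lora.py --max_steps 2.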
