Commit bebe60e

wip

1 parent ed00c1b commit bebe60e
6 files changed, +29 -23 lines changed

cookbook/rl/dpo_lora.py (18 additions, 16 deletions)

@@ -58,7 +58,7 @@
 from twinkle.dataset import Dataset, DatasetMeta
 from twinkle.loss import DPOLoss
 from twinkle.metric import DPOMetric
-from twinkle.model import MegatronModel
+from twinkle.model import MultiLoraMegatronModel
 from twinkle.preprocessor import EmojiDPOProcessor
 from twinkle.processor import InputProcessor

@@ -68,7 +68,7 @@
 MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen2.5-7B-Instruct')
 DATASET_ID = os.environ.get('DATASET_ID', 'ms://hjh0119/shareAI-Llama3-DPO-zh-en-emoji')

-MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 2))
+MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 8))

 BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 2))  # Number of preference pairs
 MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 2))

@@ -137,7 +137,7 @@ def main():
         DeviceGroup(name='policy', ranks=list(range(MODEL_GPUS)), device_type='GPU'),
     ]

-    policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS)
+    policy_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=1, pp_size=2, cp_size=2, tp_size=2)
     twinkle.initialize(mode='ray', nproc_per_node=8, groups=device_groups)

     # ── DataLoader Setup ──────────────────────────────────────────────────────

@@ -157,15 +157,17 @@ def main():
         lora_dropout=0.05,
     )

-    policy_model = MegatronModel(
+    policy_model = MultiLoraMegatronModel(
         model_id=MODEL_ID,
         device_mesh=policy_mesh,
         remote_group='policy',
     )
     MAX_STEPS = len(dataloader)
     policy_model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS)
-    policy_model.set_optimizer('default', lr=LEARNING_RATE, weight_decay=0.01)
-    policy_model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS)
+    # policy_model.set_optimizer('AdamW', lr=LEARNING_RATE, weight_decay=0.01, adapter_name=ADAPTER_NAME)
+    # policy_model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, adapter_name=ADAPTER_NAME)
+    policy_model.set_optimizer('default', lr=LEARNING_RATE, weight_decay=0.01, adapter_name=ADAPTER_NAME)
+    policy_model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS, adapter_name=ADAPTER_NAME)

     # Set up loss function and metrics
     loss_fn = DPOLoss(

@@ -174,10 +176,10 @@ def main():
         reference_free=False,  # We use base model as reference via disable_lora=True
         sft_weight=SFT_WEIGHT,
     )
-    policy_model.set_loss(loss_fn)
-    policy_model.add_metric(DPOMetric, beta=DPO_BETA)
-    policy_model.set_processor(InputProcessor)
-    policy_model.set_template('Template', model_id=MODEL_ID)
+    policy_model.set_loss(loss_fn, adapter_name=ADAPTER_NAME)
+    policy_model.add_metric(DPOMetric, beta=DPO_BETA, adapter_name=ADAPTER_NAME)
+    policy_model.set_processor(InputProcessor, adapter_name=ADAPTER_NAME)
+    policy_model.set_template('Template', model_id=MODEL_ID, adapter_name=ADAPTER_NAME)

     optim_step = 0
     logger.info(get_device_placement())

@@ -191,32 +193,32 @@ def main():

         # Get reference outputs using base model (without LoRA adapter)
         # disable_lora=True tells the model to skip LoRA and use base weights
-        ref_outputs = policy_model.forward_only(inputs=dpo_batch, micro_batch_size=2, disable_lora=True)
-
+        ref_outputs = policy_model.forward_only(inputs=dpo_batch, micro_batch_size=2, disable_lora=True, adapter_name=ADAPTER_NAME)
         # Forward-backward pass with DPO loss (using LoRA adapter)
         # ref_outputs is passed to loss which extracts logps internally
         policy_model.forward_backward(
             inputs=dpo_batch,
             ref_outputs=ref_outputs,
             micro_batch_size=2,
+            adapter_name=ADAPTER_NAME
         )

         # Gradient clipping and optimizer step
-        policy_model.clip_grad_and_step()
+        policy_model.clip_grad_and_step(adapter_name=ADAPTER_NAME)
         optim_step += 1

         # Logging
         if optim_step % 1 == 0:
-            metrics = policy_model.calculate_metric(is_training=True)
+            metrics = policy_model.calculate_metric(is_training=True, adapter_name=ADAPTER_NAME)
            logger.info(f'[Step {optim_step}/{MAX_STEPS}] {metrics}')

         # Checkpointing
         if optim_step % SAVE_STEPS == 0:
-            policy_model.save(f'dpo-lora-checkpoint-{optim_step}')
+            policy_model.save(f'dpo-lora-checkpoint-{optim_step}', adapter_name=ADAPTER_NAME)

     # ── Save Final Checkpoint ─────────────────────────────────────────────────
     logger.info(f'Training completed. Total steps: {optim_step}')
-    policy_model.save('dpo-lora-final-checkpoint')
+    policy_model.save('dpo-lora-final-checkpoint', adapter_name=ADAPTER_NAME)


 if __name__ == '__main__':
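
Note on the mesh change above: dropping dp_size=MODEL_GPUS in favor of an explicit pipeline/context/tensor split means the eight ranks now host a single sharded replica rather than eight data-parallel copies. A quick sanity check of the arithmetic (a sketch; it assumes DeviceMesh.from_sizes requires the parallel sizes to factor world_size exactly):

    # dp * pp * cp * tp must account for every rank in the group
    dp_size, pp_size, cp_size, tp_size = 1, 2, 2, 2
    world_size = dp_size * pp_size * cp_size * tp_size  # 1 * 2 * 2 * 2 = 8
    assert world_size == 8  # consistent with the new MODEL_GPUS default

That factoring is presumably why MODEL_GPUS moves from 2 to 8 in the same commit: the mesh needs all eight GPUs for one replica.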

src/twinkle/infra/collectors.py (1 addition, 1 deletion)

@@ -61,7 +61,7 @@ def _pad_and_stack_tensors(tensors: List['torch.Tensor'], pad_value: float = -20
         raise ValueError('Empty tensor list')

     if len(tensors) == 1:
-        return tensors[0].unsqueeze(0)
+        return tensors[0]

     max_ndim = max(t.ndim for t in tensors)
     expanded_tensors = []
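
The collector change is a shape fix: with a single tensor in the list (e.g. when only one data-parallel rank reports), the helper now returns the tensor unchanged instead of prepending a stacking dimension. Illustrated with shapes only:

    import torch
    t = torch.zeros(4, 7)
    # before this commit: _pad_and_stack_tensors([t]) -> shape (1, 4, 7)
    # after this commit:  _pad_and_stack_tensors([t]) -> shape (4, 7), the tensor itself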

src/twinkle/model/megatron/megatron.py (3 additions, 1 deletion)

@@ -400,7 +400,9 @@ def forward_backward(self,
             seq_length = original_seq_length + (divisor - original_seq_length % divisor)
         else:
             seq_length = original_seq_length
-
+
+        if 'ref_outputs' in kwargs:
+            breakpoint()
         num_microbatches = len(inputs)
         loss_extra_kwargs_per_mb = []
         if num_microbatches <= 1:
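
The new block drops into the interactive debugger whenever reference outputs are threaded through forward_backward, which is presumably temporary instrumentation given the "wip" commit message. Since a bare breakpoint() opens pdb on every rank that reaches it, an opt-in guard is the usual way to keep such a trap from firing on normal runs (a sketch; the TWINKLE_DEBUG variable is hypothetical, not part of this commit):

    import os
    if 'ref_outputs' in kwargs and os.environ.get('TWINKLE_DEBUG'):
        breakpoint()  # only fires when explicitly requested via the env var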

src/twinkle/model/megatron/multi_lora_megatron.py (1 addition, 1 deletion)

@@ -129,7 +129,7 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T
         with self.multi_adapter.adapter(adapter_name, disable_lora=disable_lora):
             return super().forward_only(inputs=inputs, **kwargs)

-    @remote_function(dispatch='slice_dp', collect='mean', sync=True)
+    @remote_function(dispatch='slice_dp', collect=collect_tensor_dict, sync=True)
     def forward_backward(self,
                          *,
                          inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]],
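
Switching collect from 'mean' to collect_tensor_dict means forward_backward results from the data-parallel ranks are merged key by key instead of being averaged down to a scalar, which lets per-sequence values such as reference logps survive collection. A hypothetical shape for such a collector (the real one lives in src/twinkle/infra/collectors.py; this sketch assumes it receives one output dict per rank and that ragged sequence dimensions are padded much like _pad_and_stack_tensors above):

    from typing import Dict, List
    import torch

    def collect_tensor_dict(results: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        merged = {}
        for key in results[0]:
            # Concatenate along the batch dimension; the real implementation
            # would pad mismatched sequence lengths before combining.
            merged[key] = torch.cat([r[key] for r in results], dim=0)
        return merged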

src/twinkle/model/multi_lora.py (2 additions, 2 deletions)

@@ -205,7 +205,7 @@ def _linear_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:

         lora_A_keys = self.lora_A.keys()
         for active_adapter in self.active_adapters:
-            if active_adapter not in lora_A_keys:
+            if active_adapter not in lora_A_keys or self.disable_adapters:
                 continue
             _lora = _self.find_lora(active_adapter)
             target_modules = _lora.tenant_config.target_modules

@@ -238,7 +238,7 @@ def _embedding_forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:

         lora_embedding_A_keys = self.lora_embedding_A.keys()
         for active_adapter in self.active_adapters:
-            if active_adapter not in lora_embedding_A_keys:
+            if active_adapter not in lora_embedding_A_keys or self.disable_adapters:
                 continue
             _lora = self.find_lora(active_adapter)
             target_modules = _lora.tenant_config.target_modules
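
Both guards give disable_adapters real effect in these forward paths: when it is set, every active adapter is skipped and the wrapped linear or embedding returns only the base projection. That is what lets the DPO cookbook reuse the policy weights as a frozen reference via forward_only(..., disable_lora=True). A standalone toy showing the same control flow (illustrative only, not the twinkle API):

    import torch

    class ToyLoraLinear(torch.nn.Module):
        def __init__(self, base: torch.nn.Linear, rank: int = 8):
            super().__init__()
            self.base = base
            self.lora_A = torch.nn.Linear(base.in_features, rank, bias=False)
            self.lora_B = torch.nn.Linear(rank, base.out_features, bias=False)
            self.disable_adapters = False

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            out = self.base(x)
            if not self.disable_adapters:  # mirrors the `continue` guard above
                out = out + self.lora_B(self.lora_A(x))
            return out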

src/twinkle/model/transformers/multi_lora_transformers.py (4 additions, 2 deletions)

@@ -8,6 +8,7 @@
 from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union

 from twinkle import DeviceMesh, remote_class, remote_function, template
+from twinkle.infra import collect_tensor_dict
 from twinkle.data_format import InputFeature, Trajectory
 from twinkle.hub import HubOperation
 from twinkle.loss import Loss

@@ -88,7 +89,7 @@ def unregister_mm_forward_hook(self, optimizer_group: OptimizerGroup):
     def _lazy_wrap_model(self):
         pass

-    @remote_function(dispatch='slice_dp', collect='mean')
+    @remote_function(dispatch='slice_dp', collect=collect_tensor_dict)
     def forward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], **kwargs):
         self._check_adapter_valid(kwargs.get('adapter_name'))
         optimizer_config = self.optimizer_group[kwargs.get('adapter_name')]

@@ -104,7 +105,7 @@ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory,
         with self.multi_adapter.adapter(kwargs.get('adapter_name')):
             return super().forward(inputs=inputs, **kwargs)

-    @remote_function(dispatch='slice_dp', collect='flatten')
+    @remote_function(dispatch='slice_dp', collect=collect_tensor_dict)
     def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajectory]], **kwargs):
         adapter_name = kwargs.get('adapter_name')
         disable_lora = kwargs.get('disable_lora', False)

@@ -246,6 +247,7 @@ def set_grad_scaler(self, **kwargs):
         self._check_adapter_valid(kwargs.get('adapter_name'))
         super().set_grad_scaler(**kwargs)

+    @remote_function()
     def add_metric(self, metric_cls: Union[Metric, str], is_training: Optional[bool] = None, **kwargs):
         self._check_adapter_valid(kwargs.get('adapter_name'))
         super().add_metric(metric_cls, is_training, **kwargs)
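
The bare @remote_function() on add_metric presumably routes the call through the same remote dispatch machinery as the other decorated methods, only without a custom dispatch or collect strategy, so metric registration reaches the workers rather than running only on the local handle. That matches how the cookbook now invokes it:

    # From cookbook/rl/dpo_lora.py after this commit
    policy_model.add_metric(DPOMetric, beta=DPO_BETA, adapter_name=ADAPTER_NAME)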
