Skip to content

Commit ed00c1b

Browse files
committed
wip
1 parent f4fe545 commit ed00c1b

File tree

6 files changed

+44
-64
lines changed

6 files changed

+44
-64
lines changed

cookbook/rl/dpo_lora.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858
from twinkle.dataset import Dataset, DatasetMeta
5959
from twinkle.loss import DPOLoss
6060
from twinkle.metric import DPOMetric
61-
from twinkle.model import TransformersModel
61+
from twinkle.model import MegatronModel
6262
from twinkle.preprocessor import EmojiDPOProcessor
6363
from twinkle.processor import InputProcessor
6464

@@ -157,15 +157,15 @@ def main():
157157
lora_dropout=0.05,
158158
)
159159

160-
policy_model = TransformersModel(
160+
policy_model = MegatronModel(
161161
model_id=MODEL_ID,
162162
device_mesh=policy_mesh,
163163
remote_group='policy',
164164
)
165165
MAX_STEPS = len(dataloader)
166166
policy_model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS)
167-
policy_model.set_optimizer('AdamW', lr=LEARNING_RATE, weight_decay=0.01)
168-
policy_model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=LEARNING_RATE * 0.1)
167+
policy_model.set_optimizer('default', lr=LEARNING_RATE, weight_decay=0.01)
168+
policy_model.set_lr_scheduler('default', lr_decay_steps=MAX_STEPS)
169169

170170
# Set up loss function and metrics
171171
loss_fn = DPOLoss(
@@ -191,13 +191,14 @@ def main():
191191

192192
# Get reference outputs using base model (without LoRA adapter)
193193
# disable_lora=True tells the model to skip LoRA and use base weights
194-
ref_outputs = policy_model.forward_only(inputs=dpo_batch, disable_lora=True)
194+
ref_outputs = policy_model.forward_only(inputs=dpo_batch, micro_batch_size=2, disable_lora=True)
195195

196196
# Forward-backward pass with DPO loss (using LoRA adapter)
197197
# ref_outputs is passed to loss which extracts logps internally
198198
policy_model.forward_backward(
199199
inputs=dpo_batch,
200200
ref_outputs=ref_outputs,
201+
micro_batch_size=2,
201202
)
202203

203204
# Gradient clipping and optimizer step

src/twinkle/loss/dpo.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -326,10 +326,7 @@ def __call__(
326326
reference_chosen_logps = torch.zeros_like(policy_chosen_logps)
327327
reference_rejected_logps = torch.zeros_like(policy_rejected_logps)
328328
else:
329-
raise ValueError(
330-
"ref_logps or (ref_chosen_logps, ref_rejected_logps) must be provided "
331-
"unless reference_free=True"
332-
)
329+
return LossOutput(loss=torch.tensor(0.0, device=chosen_logps.device), num_tokens=0)
333330

334331
# Compute DPO loss
335332
dpo_loss = self._compute_dpo_loss(

src/twinkle/metric/dpo.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -80,27 +80,18 @@ def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: M
8080
- kwargs['ref_outputs']: Optional reference model outputs with 'logps'
8181
"""
8282
import torch
83-
8483
logps = outputs.get('logps')
8584
if logps is None:
8685
return
8786

8887
# Get labels from inputs
8988
if isinstance(inputs, list):
90-
# Stack labels from list of inputs
91-
labels_list = [torch.as_tensor(inp['labels']) for inp in inputs]
92-
max_len = max(l.shape[0] for l in labels_list)
93-
padded = []
94-
for l in labels_list:
95-
if l.shape[0] < max_len:
96-
pad = torch.full((max_len - l.shape[0],), self.ignore_index, dtype=l.dtype)
97-
l = torch.cat([pad, l])
98-
padded.append(l)
99-
labels = torch.stack(padded)
100-
else:
101-
labels = torch.as_tensor(inputs['labels'])
102-
if labels.dim() == 1:
103-
labels = labels.unsqueeze(0)
89+
assert len(inputs) == 1
90+
inputs = inputs[0]
91+
92+
labels = torch.as_tensor(inputs['labels'])
93+
if labels.dim() == 1:
94+
labels = labels.unsqueeze(0)
10495

10596
# Ensure logps and labels have same device
10697
if logps.device != labels.device:
@@ -129,7 +120,6 @@ def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: M
129120
ref_logps = ref_outputs.get('logps')
130121
if ref_logps is not None:
131122
# Align ref_logps to match labels shape (handles different seq lengths)
132-
# breakpoint()
133123
ref_logps = self._align_logps(
134124
ref_logps, labels.shape, labels.device, logps.dtype
135125
)

src/twinkle/model/megatron/megatron.py

Lines changed: 30 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,13 @@ def post_loss_function(output_tensor, inputs, logps):
447447
def forward_step_func(data_iterator, model):
448448
batch = next(data_iterator)
449449
labels = batch.pop('labels', None)
450-
output_tensor = model(**batch)
450+
# Handle disable_lora for base model inference (e.g., reference in DPO)
451+
unwrapped_model = self.strategy.unwrap_model([model])[0]
452+
if disable_lora and isinstance(unwrapped_model, PeftModel):
453+
with unwrapped_model.disable_adapter():
454+
output_tensor = model(**batch)
455+
else:
456+
output_tensor = model(**batch)
451457
batch['labels'] = labels
452458
logps = None
453459
if labels is not None and mpu.is_pipeline_last_stage():
@@ -475,34 +481,17 @@ def forward_step_func(data_iterator, model):
475481

476482
self._accumulate_metric(optimizer_config, is_training=not forward_only)
477483

478-
# Handle disable_lora for base model inference (e.g., reference in DPO)
479-
def _set_disable_adapters(model, value: bool):
480-
model = self.strategy.unwrap_model(model)
481-
if isinstance(model, list):
482-
for m in model:
483-
if isinstance(m, PeftModel):
484-
m.disable_adapters = value
485-
elif isinstance(model, PeftModel):
486-
model.disable_adapters = value
487-
488-
if disable_lora:
489-
_set_disable_adapters(self.model, True)
490-
491-
try:
492-
# Run forward-backward with Megatron's scheduler
493-
# Megatron handles all communication internally using proper process groups
494-
losses = forward_backward_func(
495-
forward_step_func=forward_step_func,
496-
data_iterator=data_iter,
497-
model=self.model,
498-
num_microbatches=len(inputs),
499-
seq_length=seq_length,
500-
micro_batch_size=micro_batch_size,
501-
forward_only=forward_only,
502-
)
503-
finally:
504-
if disable_lora:
505-
_set_disable_adapters(self.model, False)
484+
# Run forward-backward with Megatron's scheduler
485+
# Megatron handles all communication internally using proper process groups
486+
losses = forward_backward_func(
487+
forward_step_func=forward_step_func,
488+
data_iterator=data_iter,
489+
model=self.model,
490+
num_microbatches=len(inputs),
491+
seq_length=seq_length,
492+
micro_batch_size=micro_batch_size,
493+
forward_only=forward_only,
494+
)
506495

507496
# Extract loss from results (only last PP stage returns non-empty)
508497
loss = torch.tensor(0.0).to(Platform.get_local_device())
@@ -559,9 +548,11 @@ def _set_disable_adapters(model, value: bool):
559548
if forward_only:
560549
optimizer_config.eval_status.inputs = inputs
561550
optimizer_config.eval_status.outputs = ModelOutput(logits=logits, loss=loss, logps=logps)
551+
optimizer_config.eval_status.forward_kwargs = kwargs
562552
else:
563553
optimizer_config.train_status.inputs = inputs
564554
optimizer_config.train_status.outputs = ModelOutput(logits=logits, loss=loss, logps=logps)
555+
optimizer_config.train_status.forward_kwargs = kwargs
565556
return ModelOutput(logits=logits, loss=loss, logps=logps)
566557

567558
@remote_function(dispatch='all')
@@ -692,6 +683,7 @@ def set_loss(self, loss_cls: Union[Loss, Type[Loss], str, Callable[[InputFeature
692683
optimizer_config = self.optimizer_group[adapter_name]
693684
optimizer_config.loss_instance = construct_class(loss_cls, Loss, twinkle.loss, **kwargs)
694685

686+
@remote_function()
695687
def add_metric(self, metric_cls: Union[Metric, str], is_training: Optional[bool] = None, **kwargs):
696688
"""Add an eval metric
697689
@@ -773,16 +765,16 @@ def _create_megatron_optimizer(self, **kwargs):
773765
opt_config = OptimizerConfig(
774766
optimizer='adam',
775767
lr=lr,
776-
min_lr=kwargs.get('min_lr', 0.0),
777-
weight_decay=kwargs.get('weight_decay', 0.01),
778-
adam_beta1=kwargs.get('adam_beta1', 0.9),
779-
adam_beta2=kwargs.get('adam_beta2', 0.999),
780-
adam_eps=kwargs.get('adam_eps', 1e-8),
781-
clip_grad=kwargs.get('clip_grad', 1.0),
782-
bf16=kwargs.get('bf16', True),
768+
min_lr=kwargs.pop('min_lr', 0.0),
769+
weight_decay=kwargs.pop('weight_decay', 0.01),
770+
adam_beta1=kwargs.pop('adam_beta1', 0.9),
771+
adam_beta2=kwargs.pop('adam_beta2', 0.999),
772+
adam_eps=kwargs.pop('adam_eps', 1e-8),
773+
clip_grad=kwargs.pop('clip_grad', 1.0),
774+
bf16=kwargs.pop('bf16', True),
783775
use_distributed_optimizer=use_distributed_optimizer,
784-
overlap_param_gather=kwargs.get('overlap_param_gather', False),
785-
log_num_zeros_in_grad=kwargs.get('log_num_zeros_in_grad', False),
776+
overlap_param_gather=kwargs.pop('overlap_param_gather', False),
777+
log_num_zeros_in_grad=kwargs.pop('log_num_zeros_in_grad', False),
786778
**kwargs,
787779
)
788780

src/twinkle/model/megatron/multi_lora_megatron.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,7 @@ def set_processor(self, processor_cls: Union[Type[InputProcessor], str, Callable
264264
self._check_adapter_valid(kwargs.get('adapter_name'))
265265
super().set_processor(processor_cls, **kwargs)
266266

267+
@remote_function()
267268
def add_metric(self, metric_cls: Union[Metric, str], is_training: Optional[bool] = None, **kwargs):
268269
self._check_adapter_valid(kwargs.get('adapter_name'))
269270
super().add_metric(metric_cls, is_training, **kwargs)

src/twinkle/model/transformers/transformers.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,6 @@ def __init__(
188188
}
189189
self.optimizer_group[_default_adapter_name].adapter_name = _default_adapter_name
190190
self.active_group = _default_adapter_name
191-
# breakpoint()
192191

193192
def _decide_strategy(self, strategy: Literal['accelerate', 'native_fsdp']):
194193
self._expert_parallel_config = self._fsdp_config.pop('expert_parallel', None)

0 commit comments

Comments (0)