@@ -447,7 +447,13 @@ def post_loss_function(output_tensor, inputs, logps):
         def forward_step_func(data_iterator, model):
             batch = next(data_iterator)
             labels = batch.pop('labels', None)
-            output_tensor = model(**batch)
+            # Handle disable_lora for base model inference (e.g., reference in DPO)
+            unwrapped_model = self.strategy.unwrap_model([model])[0]
+            if disable_lora and isinstance(unwrapped_model, PeftModel):
+                with unwrapped_model.disable_adapter():
+                    output_tensor = model(**batch)
+            else:
+                output_tensor = model(**batch)
             batch['labels'] = labels
             logps = None
             if labels is not None and mpu.is_pipeline_last_stage():
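The new branch above leans on PEFT's `disable_adapter()` context manager instead of toggling a flag on the model, so the base weights are used for exactly one forward pass and the adapters come back automatically afterwards. A minimal sketch of that pattern in isolation, where the model name and LoRA settings are illustrative assumptions rather than anything from this repo:

```python
# Sketch only: model choice and LoRA hyperparameters are illustrative assumptions.
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, PeftModel, get_peft_model

base = AutoModelForCausalLM.from_pretrained('gpt2')
peft_model = get_peft_model(base, LoraConfig(r=8, lora_alpha=16))

batch = {'input_ids': torch.tensor([[1, 2, 3]])}

# Adapters active: LoRA deltas contribute to the logits (the policy model in DPO).
policy_logits = peft_model(**batch).logits

# Adapters bypassed for the duration of the context manager: only the frozen
# base weights run, which is what the reference forward pass in DPO needs.
if isinstance(peft_model, PeftModel):
    with peft_model.disable_adapter():
        reference_logits = peft_model(**batch).logits
```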
@@ -475,34 +481,17 @@ def forward_step_func(data_iterator, model):
 
         self._accumulate_metric(optimizer_config, is_training=not forward_only)
 
-        # Handle disable_lora for base model inference (e.g., reference in DPO)
-        def _set_disable_adapters(model, value: bool):
-            model = self.strategy.unwrap_model(model)
-            if isinstance(model, list):
-                for m in model:
-                    if isinstance(m, PeftModel):
-                        m.disable_adapters = value
-            elif isinstance(model, PeftModel):
-                model.disable_adapters = value
-
-        if disable_lora:
-            _set_disable_adapters(self.model, True)
-
-        try:
-            # Run forward-backward with Megatron's scheduler
-            # Megatron handles all communication internally using proper process groups
-            losses = forward_backward_func(
-                forward_step_func=forward_step_func,
-                data_iterator=data_iter,
-                model=self.model,
-                num_microbatches=len(inputs),
-                seq_length=seq_length,
-                micro_batch_size=micro_batch_size,
-                forward_only=forward_only,
-            )
-        finally:
-            if disable_lora:
-                _set_disable_adapters(self.model, False)
+        # Run forward-backward with Megatron's scheduler
+        # Megatron handles all communication internally using proper process groups
+        losses = forward_backward_func(
+            forward_step_func=forward_step_func,
+            data_iterator=data_iter,
+            model=self.model,
+            num_microbatches=len(inputs),
+            seq_length=seq_length,
+            micro_batch_size=micro_batch_size,
+            forward_only=forward_only,
+        )
 
         # Extract loss from results (only last PP stage returns non-empty)
         loss = torch.tensor(0.0).to(Platform.get_local_device())
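For orientation, `forward_backward_func` here is the callable obtained from Megatron-Core's `get_forward_backward_func()`, and it returns one entry per microbatch only on the last pipeline-parallel stage (an empty list everywhere else), which is what the "Extract loss" step above consumes. A runnable sketch of that reduction, where the `'loss'` key and the plain mean are assumptions for illustration:

```python
import torch

def reduce_microbatch_losses(losses: list) -> torch.Tensor:
    """Average per-microbatch results; non-last PP stages receive an empty list."""
    loss = torch.tensor(0.0)
    if losses:
        loss = torch.stack([mb['loss'] for mb in losses]).mean()
    return loss

# Last PP stage: one dict per microbatch.
print(reduce_microbatch_losses([{'loss': torch.tensor(0.9)},
                                {'loss': torch.tensor(1.1)},
                                {'loss': torch.tensor(1.0)}]))  # ~tensor(1.)
# Every other stage gets [] and falls back to the zero tensor.
print(reduce_microbatch_losses([]))                             # tensor(0.)
```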
@@ -559,9 +548,11 @@ def _set_disable_adapters(model, value: bool):
         if forward_only:
             optimizer_config.eval_status.inputs = inputs
             optimizer_config.eval_status.outputs = ModelOutput(logits=logits, loss=loss, logps=logps)
+            optimizer_config.eval_status.forward_kwargs = kwargs
         else:
             optimizer_config.train_status.inputs = inputs
             optimizer_config.train_status.outputs = ModelOutput(logits=logits, loss=loss, logps=logps)
+            optimizer_config.train_status.forward_kwargs = kwargs
         return ModelOutput(logits=logits, loss=loss, logps=logps)
 
     @remote_function(dispatch='all')
@@ -692,6 +683,7 @@ def set_loss(self, loss_cls: Union[Loss, Type[Loss], str, Callable[[InputFeature
         optimizer_config = self.optimizer_group[adapter_name]
         optimizer_config.loss_instance = construct_class(loss_cls, Loss, twinkle.loss, **kwargs)
 
+    @remote_function()
     def add_metric(self, metric_cls: Union[Metric, str], is_training: Optional[bool] = None, **kwargs):
         """Add an eval metric
 
@@ -773,16 +765,16 @@ def _create_megatron_optimizer(self, **kwargs):
         opt_config = OptimizerConfig(
             optimizer='adam',
             lr=lr,
-            min_lr=kwargs.get('min_lr', 0.0),
-            weight_decay=kwargs.get('weight_decay', 0.01),
-            adam_beta1=kwargs.get('adam_beta1', 0.9),
-            adam_beta2=kwargs.get('adam_beta2', 0.999),
-            adam_eps=kwargs.get('adam_eps', 1e-8),
-            clip_grad=kwargs.get('clip_grad', 1.0),
-            bf16=kwargs.get('bf16', True),
+            min_lr=kwargs.pop('min_lr', 0.0),
+            weight_decay=kwargs.pop('weight_decay', 0.01),
+            adam_beta1=kwargs.pop('adam_beta1', 0.9),
+            adam_beta2=kwargs.pop('adam_beta2', 0.999),
+            adam_eps=kwargs.pop('adam_eps', 1e-8),
+            clip_grad=kwargs.pop('clip_grad', 1.0),
+            bf16=kwargs.pop('bf16', True),
             use_distributed_optimizer=use_distributed_optimizer,
-            overlap_param_gather=kwargs.get('overlap_param_gather', False),
-            log_num_zeros_in_grad=kwargs.get('log_num_zeros_in_grad', False),
+            overlap_param_gather=kwargs.pop('overlap_param_gather', False),
+            log_num_zeros_in_grad=kwargs.pop('log_num_zeros_in_grad', False),
             **kwargs,
         )
 
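The `.get()` to `.pop()` switch in this hunk is not cosmetic: the leftover `kwargs` are splatted into the same constructor via `**kwargs`, so any key read with `.get()` would also arrive a second time through the splat and raise a `TypeError`. A self-contained illustration using a hypothetical stand-in dataclass rather than Megatron's real `OptimizerConfig`:

```python
from dataclasses import dataclass

@dataclass
class FakeOptimizerConfig:  # hypothetical stand-in, not Megatron's OptimizerConfig
    lr: float = 1e-4
    weight_decay: float = 0.01
    clip_grad: float = 1.0

kwargs = {'weight_decay': 0.1, 'clip_grad': 0.5}

# .get() leaves the key in kwargs, so the splat passes it again -> TypeError.
try:
    FakeOptimizerConfig(weight_decay=kwargs.get('weight_decay', 0.01), **kwargs)
except TypeError as err:
    print(err)  # ... got multiple values for keyword argument 'weight_decay'

# .pop() consumes the key before **kwargs is expanded, so each key arrives once.
cfg = FakeOptimizerConfig(weight_decay=kwargs.pop('weight_decay', 0.01), **kwargs)
print(cfg)  # FakeOptimizerConfig(lr=0.0001, weight_decay=0.1, clip_grad=0.5)
```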