@@ -185,8 +185,8 @@ def __init__(
         self._model_wrapped = False
         # This correctly handles vocab sharding in Tensor Parallelism
         self.optimizer_group: Dict[str, MegatronOptimizerGroup] = {_default_adapter_name: self._construct_default_optimizer_group()}
-        MegatronPeft().patch()
-
+        self.active_group = _default_adapter_name
+        MegatronPeft().__call__()
 
     def _construct_default_optimizer_group(self):
         return MegatronOptimizerGroup(
@@ -230,6 +230,12 @@ def _lazy_wrap_model(self):
         self.model = self.strategy.wrap_model(self.model)
         self._model_wrapped = True
 
+    def _get_default_group(self):
+        """Return the sole optimizer group when only one is registered, otherwise the active group."""
+        if len(self.optimizer_group) == 1:
+            return next(iter(self.optimizer_group))
+        return self.active_group
+
     @staticmethod
     def _not_encoded(inputs):
         assert isinstance(inputs, dict)
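
A minimal sketch of the fallback semantics this helper introduces (the standalone class is hypothetical; only the `optimizer_group` and `active_group` attributes mirror the trainer above):

    _default_adapter_name = 'default'

    class _Resolver:
        def __init__(self):
            self.optimizer_group = {_default_adapter_name: object()}
            self.active_group = _default_adapter_name

        def _get_default_group(self):
            # Exactly one registered group: unambiguous, so return it.
            if len(self.optimizer_group) == 1:
                return next(iter(self.optimizer_group))
            # Several groups: fall back to the explicitly tracked active one.
            return self.active_group

    r = _Resolver()
    assert r._get_default_group() == 'default'
    r.optimizer_group['lora_a'] = object()  # e.g. a LoRA adapter group
    r.active_group = 'lora_a'
    assert r._get_default_group() == 'lora_a'
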
@@ -299,7 +305,7 @@ def forward_backward(self,
         from megatron.core.pipeline_parallel import get_forward_backward_func
         from megatron.core import parallel_state as mpu
 
-        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        adapter_name = kwargs.pop('adapter_name', self._get_default_group())
         forward_only = kwargs.pop('forward_only', False)
         optimizer_config = self.optimizer_group[adapter_name]
         loss_instance = self.optimizer_group[adapter_name].loss_instance
@@ -465,7 +471,7 @@ def step(self, **kwargs):
         Args:
             **kwargs: Additional arguments.
         """
-        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        adapter_name = kwargs.pop('adapter_name', self._get_default_group())
         optimizer_config = self.optimizer_group[adapter_name]
 
         if not optimizer_config.do_grad_sync(
@@ -503,7 +509,7 @@ def zero_grad(self, **kwargs):
         Args:
             **kwargs: Additional arguments.
         """
-        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        adapter_name = kwargs.pop('adapter_name', self._get_default_group())
         optimizer_config = self.optimizer_group[adapter_name]
 
         # For DDP-wrapped models, ALWAYS zero the gradient buffer
@@ -528,7 +534,7 @@ def lr_step(self, **kwargs):
         Args:
             **kwargs: Additional arguments.
         """
-        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        adapter_name = kwargs.pop('adapter_name', self._get_default_group())
         optimizer_config = self.optimizer_group[adapter_name]
 
         if not optimizer_config.do_grad_sync(
@@ -557,7 +563,7 @@ def set_loss(self, loss_cls: Union[Loss, Type[Loss], str, Callable[[InputFeature
             loss_cls: Loss class or string name (not used for Megatron).
             **kwargs: Additional arguments.
         """
-        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        adapter_name = kwargs.pop('adapter_name', self._get_default_group())
         optimizer_config = self.optimizer_group[adapter_name]
         optimizer_config.loss_instance = construct_class(loss_cls, Loss, twinkle.loss, **kwargs)
 
@@ -571,7 +577,7 @@ def add_metric(self, metric_cls: Union[Metric, str], is_training: Optional[bool]
             adapter_name: LoRA adapter name.
             Any parameters needed to construct the metric_cls instance.
         """
-        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        adapter_name = kwargs.pop('adapter_name', self._get_default_group())
         optimizer_config = self.optimizer_group[adapter_name]
         kwargs['device_mesh'] = self.device_mesh
         kwargs['process_group'] = optimizer_config._dp_group
@@ -593,7 +599,7 @@ def set_optimizer(self, optimizer_cls: Union[Optimizer, Type[Optimizer], str],
             - For standard optimizers: lr, weight_decay, etc.
             - For MegatronDistributed: use_distributed_optimizer, clip_grad, etc.
         """
-        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        adapter_name = kwargs.pop('adapter_name', self._get_default_group())
         optimizer_config = self.optimizer_group[adapter_name]
         if not self._model_wrapped:
             self.model = self.strategy.wrap_model(self.model)
@@ -611,7 +617,7 @@ def _accumulate_metric(optimizer_config: MegatronOptimizerGroup, is_training):
 
     @remote_function(collect='first', lazy_collect=False)
     def calculate_metric(self, is_training, **kwargs):
-        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        adapter_name = kwargs.pop('adapter_name', self._get_default_group())
         optimizer_config = self.optimizer_group[adapter_name]
         return optimizer_config.calculate_metrics(is_training)
 
@@ -715,7 +721,7 @@ def set_lr_scheduler(self, scheduler_cls: Union[LRScheduler, Type[LRScheduler],
             scheduler_cls: Scheduler class or string name.
             **kwargs: Additional arguments.
         """
-        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        adapter_name = kwargs.pop('adapter_name', self._get_default_group())
         optimizer_config = self.optimizer_group[adapter_name]
         optimizer = optimizer_config.optimizer
         if not scheduler_cls or scheduler_cls in ('OptimizerParamScheduler', 'default'):
@@ -738,7 +744,7 @@ def save(self, name: Optional[str] = None, output_dir: Optional[str] = None, int
             interval: Save a checkpoint every `interval` steps.
             **kwargs: Additional arguments.
         """
-        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        adapter_name = kwargs.pop('adapter_name', self._get_default_group())
         optimizer_config = self.optimizer_group[adapter_name]
         if optimizer_config.cur_step % interval != 0:
             return
@@ -772,7 +778,7 @@ def load(self, name: str, output_dir: Optional[str] = None, **kwargs):
             checkpoint_dir = HubOperation.download_model(name, token=token)
         else:
             checkpoint_dir = os.path.join(output_dir, name)
-        adapter_name = kwargs.get('adapter_name')
+        adapter_name = kwargs.get('adapter_name', self._get_default_group())
         bridge = self._bridge
         for _model in self.strategy.unwrap_model(self.model):
             bridge.load_weights(_model, checkpoint_dir, is_peft_format=(adapter_name != _default_adapter_name))
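
This default also corrects a subtle bug in `load`: with the old `kwargs.get('adapter_name')`, an omitted argument resolved to None, so `adapter_name != _default_adapter_name` was always True and the PEFT load path was taken unconditionally. A hedged sketch of just that predicate (stub values, not the real bridge API):

    _default_adapter_name = 'default'

    def _is_peft_format(adapter_name):
        # Non-default adapter names load as PEFT (adapter) checkpoints.
        return adapter_name != _default_adapter_name

    assert _is_peft_format(None) is True                    # old behavior when omitted
    assert _is_peft_format(_default_adapter_name) is False  # new resolved default
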
@@ -860,7 +866,7 @@ def get_state_dict(self, **kwargs):
         Returns:
             State dict of trainable parameters.
         """
-        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        adapter_name = kwargs.pop('adapter_name', self._get_default_group())
         return self._get_trainable_parameters(adapter_name)
 
     def get_hf_state_dict(self, adapter_name: str = '') -> Generator[Tuple[str, torch.Tensor], None, None]:
@@ -988,7 +994,7 @@ def set_template(self, template_cls: Union[Template, Type[Template], str], **kwa
             template_cls: Template class or string name.
             **kwargs: Additional arguments.
         """
-        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        adapter_name = kwargs.pop('adapter_name', self._get_default_group())
         optimizer_config = self.optimizer_group[adapter_name]
         optimizer_config.template = construct_class(template_cls, Template, twinkle.template, **kwargs)
 
@@ -1000,7 +1006,7 @@ def set_processor(self, processor_cls: Union[InputProcessor, Type[InputProcessor
             processor_cls: Processor class or string name.
             **kwargs: Additional arguments.
         """
-        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        adapter_name = kwargs.pop('adapter_name', self._get_default_group())
         optimizer_config = self.optimizer_group[adapter_name]
         kwargs['framework'] = 'megatron'
         optimizer_config.processor = construct_class(processor_cls, InputProcessor, twinkle.processor, **kwargs)
@@ -1015,7 +1021,7 @@ def get_train_configs(self, **kwargs):
         Returns:
             Configuration summary string.
         """
-        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
+        adapter_name = kwargs.pop('adapter_name', self._get_default_group())
         optimizer_config = self.optimizer_group[adapter_name]
 
         expr = f'Backend: Megatron-Core\n'
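
The per-call resolution pattern repeated in every hunk above, reduced to its core (a hedged sketch; `pick_adapter` is a hypothetical free function, not library API):

    def pick_adapter(default_group, **kwargs):
        # Mirrors: kwargs.pop('adapter_name', self._get_default_group())
        return kwargs.pop('adapter_name', default_group)

    assert pick_adapter('default') == 'default'                       # omitted -> resolved default
    assert pick_adapter('lora_a', adapter_name='lora_b') == 'lora_b'  # explicit override wins

Because `pop` is used, `adapter_name` is consumed at each entry point and is not forwarded with the remaining kwargs to downstream constructors.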