@@ -80,13 +80,13 @@ def __post_init__(self):
8080
8181 def _build_metrics (self ):
8282 self .train_metrics = [
83- LossMetric (self ._device_mesh , self ._dp_group , loss_reduction = 'mean ' ),
83+ LossMetric (self ._device_mesh , self ._dp_group , loss_reduction = 'sum ' ),
8484 Accuracy (self ._device_mesh , self ._dp_group ),
8585 TrainMetric (self ._device_mesh , self ._dp_group ),
8686 ]
8787
8888 self .eval_metrics = [
89- LossMetric (self ._device_mesh , self ._dp_group , loss_reduction = 'mean ' ),
89+ LossMetric (self ._device_mesh , self ._dp_group , loss_reduction = 'sum ' ),
9090 Accuracy (self ._device_mesh , self ._dp_group ),
9191 TrainMetric (self ._device_mesh , self ._dp_group ),
9292 ]
@@ -317,7 +317,7 @@ def _ensure_optimizer_dp_groups(self):
317317
318318 def _construct_default_optimizer_group (self ):
319319 return OptimizerGroup (
320- loss_instance = CrossEntropyLoss (reduction = 'mean ' ),
320+ loss_instance = CrossEntropyLoss (reduction = 'sum ' ),
321321 template = Template (self .tokenizer_id ),
322322 processor = InputProcessor (self .device_mesh ),
323323 _device_mesh = self .device_mesh ,
@@ -431,10 +431,9 @@ def calculate_loss(self, **kwargs):
431431 optimizer_config = self .optimizer_group [adapter_name ]
432432 optimizer_config .num_tokens += counts .item ()
433433 if self .sp_strategy is not None and 'labels' in inputs :
434- if "loss_reduction" not in self .sp_strategy .sp_config :
435- reduction = getattr (loss_instance , "reduction" , None )
436- if reduction is not None :
437- self .sp_strategy .sp_config ["loss_reduction" ] = str (reduction )
434+ reduction = getattr (loss_instance , "reduction" , None )
435+ if reduction is not None :
436+ self .sp_strategy .sp_config ["loss_reduction" ] = str (reduction )
438437 loss_value = self .sp_strategy .reduce_loss (loss_value , inputs ['labels' ])
439438 optimizer_config .loss_value += loss_value
440439 outputs ['loss' ] = optimizer_config .loss_value
0 commit comments