@@ -74,24 +74,22 @@ def forward_backward(self, *, inputs: List[types.Datum], adapter_name: str, loss
         loss_kwargs = kwargs.copy()
         loss_kwargs.update(loss_values)
         # Megatron forward_backward returns loss directly
-        loss = super().forward_backward(inputs=input_features, adapter_name=adapter_name, **loss_kwargs)
-
-        # Get logits from outputs
-        optimizer_config = self.optimizer_group.get(adapter_name)
-        outputs = optimizer_config.outputs if optimizer_config else {}
+        outputs = super().forward_backward(inputs=input_features, adapter_name=adapter_name, **loss_kwargs)
+        loss = outputs.get('loss', None)
         logits_list = outputs.get('logits', [])
-        logps = outputs.get('logprobs', [])
-
+        logps = outputs.get('logps', [])
         # When PP enabled, only logits from last stage are available
-        if not logits_list and not logps:
+        if logits_list is None and logps is None:
             return [None, None]

-        # Process logits to match transformers output format
-        if isinstance(logits_list, torch.Tensor):
-            logits = logits_list.detach()
-        else:
-            # Concatenate logits from multiple microbatches
-            logits = torch.cat([logit.detach() for logit in logits_list], dim=0)
+        logits = None
+        if logits_list is not None:
+            # Process logits to match transformers output format
+            if isinstance(logits_list, torch.Tensor):
+                logits = logits_list.detach()
+            else:
+                # Concatenate logits from multiple microbatches
+                logits = torch.cat([logit.detach() for logit in logits_list], dim=0)
         logps = logps.detach().cpu()
         results = self._get_forward_output(inputs, logits, logps)

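For context, this change assumes that Megatron's forward_backward now returns a dict carrying 'loss', 'logits', and 'logps' directly, rather than stashing them on the adapter's optimizer config. The snippet below is a minimal, hypothetical sketch (not code from this repo; the helper name and tensor shapes are invented) of the logits-handling path: it accepts either a single tensor or a list of per-microbatch tensors, and returns None when the current pipeline stage produced no logits.

# Minimal sketch of the assumed outputs contract; not the repo's actual API.
from typing import Dict, List, Optional, Union

import torch


def extract_logits(
    outputs: Dict[str, Optional[Union[torch.Tensor, List[torch.Tensor]]]],
) -> Optional[torch.Tensor]:
    """Mirror the logits handling in the diff: accept a single tensor or a list
    of per-microbatch tensors, and return one detached tensor, or None when
    this pipeline-parallel stage produced no logits."""
    logits_list = outputs.get('logits', [])
    if logits_list is None:
        return None
    if isinstance(logits_list, torch.Tensor):
        return logits_list.detach()
    # Concatenate per-microbatch logits along the batch dimension
    return torch.cat([logit.detach() for logit in logits_list], dim=0)


# Hypothetical usage: two microbatches with shape [batch, seq, vocab]
outputs = {
    'loss': torch.tensor(1.23),
    'logits': [torch.randn(2, 8, 32), torch.randn(2, 8, 32)],
    'logps': torch.randn(4, 8),
}
logits = extract_logits(outputs)
assert logits is not None and logits.shape == (4, 8, 32)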