From 1acf8bdd9adc00ce8a853cba533fd00dac61e55b Mon Sep 17 00:00:00 2001
From: COLAZERO2
Date: Wed, 6 Aug 2025 19:53:19 +0800
Subject: [PATCH] fix: torch.utils.checkpoint.checkpoint() incorrectly wrapped
 a single forward step

With gradient checkpointing enabled, checkpointing each midlayer call
individually produced unstable gradients and the loss stayed high. The
checkpointed region now covers the whole multi-step midlayer loop
(_run_midlayer_loop, invoked with use_reentrant=False), so accuracy improves
as intended while keeping the memory savings of gradient checkpointing.
---
 eagle/traineagle3/cnets.py | 125 +++++++++++++++++++------------------
 eagle/traineagle3/main.py  |   6 +-
 2 files changed, 68 insertions(+), 63 deletions(-)

diff --git a/eagle/traineagle3/cnets.py b/eagle/traineagle3/cnets.py
index 0cd63314..791d331a 100644
--- a/eagle/traineagle3/cnets.py
+++ b/eagle/traineagle3/cnets.py
@@ -718,7 +718,6 @@ def dataprepare(self, input_ids, attention_mask, loss_mask):
         hidden_states1 = outs.hidden_states[1]
         hidden_states2 = outs.hidden_states[2]
         hidden_states=torch.cat((hidden_states0,hidden_states1,hidden_states2),dim=-1)
-        # hidden_states=torch.cat((hidden_states0,hidden_states1),dim=-1)
         target = outs.logits
         target = padding(target, left=False)
         input_ids = padding(input_ids, left=False)
@@ -732,7 +731,6 @@ def dataprepare(self, input_ids, attention_mask, loss_mask):
 
     def forward(
             self,
-            # hidden_states,
             input_ids,
             attention_mask: Optional[torch.Tensor] = None,
             position_ids: Optional[torch.LongTensor] = None,
@@ -744,7 +742,53 @@ def forward(
     ):
         hidden_states, target, loss_mask, input_ids = self.dataprepare(input_ids, attention_mask, loss_mask)
 
+        if self.training and self.gradient_checkpointing:
+            plosses,acces = torch.utils.checkpoint.checkpoint(
+                self._run_midlayer_loop,
+                input_ids,
+                target,
+                hidden_states,
+                attention_mask,
+                position_ids,
+                past_key_values,
+                use_cache,
+                output_attentions,
+                output_hidden_states,
+                loss_mask,
+                use_reentrant=False  # keep this False
+            )
+
+
+        else:
+            plosses,acces = self._run_midlayer_loop(
+                input_ids,
+                target,
+                hidden_states,
+                attention_mask,
+                position_ids,
+                past_key_values,
+                use_cache,
+                output_attentions,
+                output_hidden_states,
+                loss_mask)
+
+        return plosses,acces
+
+    def _run_midlayer_loop(self,
+                input_ids,
+                target,
+                hidden_states,
+                attention_mask: Optional[torch.Tensor] = None,
+                position_ids: Optional[torch.LongTensor] = None,
+                past_key_values: Optional[List[torch.FloatTensor]] = None,
+                use_cache: Optional[bool] = None,
+                output_attentions: Optional[bool] = None,
+                output_hidden_states: Optional[bool] = None,
+                loss_mask: Optional[torch.Tensor] = None):
+        cache_hidden = [[], []]
+        plosses = []
+        acces = []
 
         batch_size, seq_length, _ = hidden_states.shape
         seq_length_with_past = seq_length
         past_key_values_length = 0
@@ -781,55 +825,29 @@ def forward(
         if self.gradient_checkpointing and self.training:
             if use_cache:
                 use_cache = False
-
-        plosses = []
-        vlosses = []
-        acces = []
-        cache_hidden = [[], []]
-
         for idx in range(self.length):
             last = idx == self.length - 1
             inputs_embeds = self.embed_tokens(input_ids)
-            if self.training and self.gradient_checkpointing and not inputs_embeds.requires_grad:
-                inputs_embeds.requires_grad = True
             inputs_embeds = inputs_embeds.to(hidden_states.dtype)
-            if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, None, output_attentions)
-
-                    return custom_forward
-
-                layer_outputs, cache_hidden = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(self.midlayer),
-                    inputs_embeds,
-                    hidden_states,
-                    cache_hidden,
-                    attention_mask,
-                    position_ids,
-                )
-            else:
-
-                layer_outputs, cache_hidden = self.midlayer(
-                    input_emb=inputs_embeds,
-                    hidden_states=hidden_states,
-                    cache_hidden=cache_hidden,
-                    attention_mask=attention_mask,
-                    position_ids=position_ids,
-                    past_key_value=None,
-                    output_attentions=output_attentions,
-                    use_cache=True,
-                )
+            layer_outputs, cache_hidden = self.midlayer(
+                input_emb=inputs_embeds,
+                hidden_states=hidden_states,
+                cache_hidden=cache_hidden,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=None,
+                output_attentions=output_attentions,
+                use_cache=True,
+            )
             hidden_states_out = layer_outputs[0]
-            # cache_hidden.append(layer_outputs[1])
-            # kv_cahce = layer_outputs[-1]
+            hidden_states = hidden_states_out
+            hidden_states_out = self.norm(hidden_states_out)
+            logits = self.lm_head(hidden_states_out)
+            logits = logits.float()
             with torch.no_grad():
-                # hidden_states_target = padding(hidden_states, left=False)
                 target_head = target
                 target_max_token = target_head.argmax(-1)
                 # Move d2t to the same device as target_max_token
@@ -840,28 +858,17 @@ def custom_forward(*inputs):
                 target_head = target_head[..., self.t2d]
                 target_head = target_head.float()
                 target_p = nn.Softmax(dim=2)(target_head)
-                target_p = target_p.detach()
-
-
-
-            hidden_states = hidden_states_out
-
-            hidden_states_out = self.norm(hidden_states_out)
-
-            logits = self.lm_head(hidden_states_out)
-            logits = logits.float()
             out_logp = nn.LogSoftmax(dim=2)(logits)
             plogp = target_p * out_logp
-            loss = -torch.sum(position_mask * plogp, 2).mean()
+            sum_logit = torch.sum(position_mask * plogp, 2)
+            loss = -sum_logit.mean()
             plosses.append(loss)
-            with torch.no_grad():
-                acces.append(((logits.argmax(-1) == target_p.argmax(-1)) * position_mask.squeeze(-1)).sum().item() / (
-                        loss_mask.sum().item() + 1e-6))
-
+            acces.append(((logits.argmax(-1) == target_p.argmax(-1)) * position_mask.squeeze(-1)).sum().item() / (loss_mask.sum().item() + 1e-6))
             if not last:
                 input_ids = padding(input_ids, left=False)
                 target = padding(target, left=False)
                 loss_mask = padding(loss_mask, left=False)
+                seq_length = attention_mask.shape[-1]
                 ind=torch.arange(seq_length,device=attention_mask.device)
                 ind0=ind[idx:]
                 ind1=ind[:seq_length-idx]
@@ -869,6 +876,4 @@ def custom_forward(*inputs):
 
 
 
-        return plosses, vlosses, acces
-
-
+        return plosses,acces
diff --git a/eagle/traineagle3/main.py b/eagle/traineagle3/main.py
index 37012a16..aa2405f8 100644
--- a/eagle/traineagle3/main.py
+++ b/eagle/traineagle3/main.py
@@ -276,7 +276,7 @@ def find_max_state_with_file(directory, filename="zero_to_fp32.py"):
 
             model.zero_grad()
 
-            plosses, vlosses, acces = model_engine(input_ids=data["input_ids"].to(rank),
+            plosses, acces = model_engine(input_ids=data["input_ids"].to(rank),
                                           attention_mask=data["attention_mask"].to(rank),
                                           loss_mask=data["loss_mask"],
                                           )
@@ -321,7 +321,7 @@ def find_max_state_with_file(directory, filename="zero_to_fp32.py"):
     for batch_idx, data in enumerate(tqdm(test_loader)):
         with torch.no_grad():
 
-            plosses, vlosses, acces = model_engine(input_ids=data["input_ids"].to(rank),
+            plosses, acces = model_engine(input_ids=data["input_ids"].to(rank),
                                           attention_mask=data["attention_mask"].to(rank),
                                           loss_mask=data["loss_mask"],
                                           )
@@ -348,4 +348,4 @@ def find_max_state_with_file(directory, filename="zero_to_fp32.py"):
             model_engine.save_16bit_model(f"{args.savedir}/state_{epoch}", exclude_frozen_parameters=True)
 
     if epoch % 10 == 0:
-        deepspeed.DeepSpeedEngine.save_checkpoint(model_engine, save_dir=f"{args.savedir}/state_{epoch}")
+        deepspeed.DeepSpeedEngine.save_checkpoint(model_engine, save_dir=f"{args.savedir}/state_{epoch}")
\ No newline at end of file
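
Reviewer note: the sketch below is illustrative only. It shows the pattern this
patch adopts in cnets.py, namely checkpointing the whole multi-step loop once
with use_reentrant=False instead of wrapping each per-step module call, so the
backward pass recomputes the same sequence of intermediate states that the
forward pass produced. The names TinyStep, LoopModel, and _run_loop are
hypothetical and are not EAGLE classes.

# Minimal sketch: non-reentrant gradient checkpointing of a multi-step loop.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyStep(nn.Module):
    def __init__(self, dim: int = 16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.tanh(self.proj(x))


class LoopModel(nn.Module):
    def __init__(self, dim: int = 16, length: int = 4):
        super().__init__()
        self.step = TinyStep(dim)
        self.length = length
        self.use_checkpoint = True

    def _run_loop(self, x: torch.Tensor) -> torch.Tensor:
        # The entire multi-step loop is the checkpointed region: only its
        # inputs are stored, and the loop is re-run during backward.
        losses = []
        for _ in range(self.length):
            x = self.step(x)
            losses.append(x.pow(2).mean())
        return torch.stack(losses).sum()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training and self.use_checkpoint:
            # Checkpoint the whole loop once, not each self.step(x) call.
            return checkpoint(self._run_loop, x, use_reentrant=False)
        return self._run_loop(x)


if __name__ == "__main__":
    model = LoopModel().train()
    loss = model(torch.randn(2, 16))
    loss.backward()  # gradients flow through the recomputed loop
    print(loss.item())

Passing use_reentrant=False selects PyTorch's non-reentrant checkpointing
implementation, which is the variant the PyTorch documentation recommends and
which also works when the checkpointed inputs themselves do not require grad.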