diff --git a/eagle/traineagle3/cnets.py b/eagle/traineagle3/cnets.py
index 0cd63314..791d331a 100644
--- a/eagle/traineagle3/cnets.py
+++ b/eagle/traineagle3/cnets.py
@@ -718,7 +718,6 @@ def dataprepare(self, input_ids, attention_mask, loss_mask):
         hidden_states1 = outs.hidden_states[1]
         hidden_states2 = outs.hidden_states[2]
         hidden_states=torch.cat((hidden_states0,hidden_states1,hidden_states2),dim=-1)
-        # hidden_states=torch.cat((hidden_states0,hidden_states1),dim=-1)
         target = outs.logits
         target = padding(target, left=False)
         input_ids = padding(input_ids, left=False)
@@ -732,7 +731,6 @@ def dataprepare(self, input_ids, attention_mask, loss_mask):
 
     def forward(
             self,
-            # hidden_states,
             input_ids,
             attention_mask: Optional[torch.Tensor] = None,
             position_ids: Optional[torch.LongTensor] = None,
@@ -744,7 +742,53 @@ def forward(
     ):
         hidden_states, target, loss_mask, input_ids = self.dataprepare(input_ids, attention_mask, loss_mask)
+        if self.training and self.gradient_checkpointing:
+            plosses, acces = torch.utils.checkpoint.checkpoint(
+                self._run_midlayer_loop,
+                input_ids,
+                target,
+                hidden_states,
+                attention_mask,
+                position_ids,
+                past_key_values,
+                use_cache,
+                output_attentions,
+                output_hidden_states,
+                loss_mask,
+                use_reentrant=False  # keep False
+            )
+
+
+        else:
+            plosses, acces = self._run_midlayer_loop(
+                input_ids,
+                target,
+                hidden_states,
+                attention_mask,
+                position_ids,
+                past_key_values,
+                use_cache,
+                output_attentions,
+                output_hidden_states,
+                loss_mask)
+
+        return plosses, acces
+
+    def _run_midlayer_loop(self,
+                           input_ids,
+                           target,
+                           hidden_states,
+                           attention_mask: Optional[torch.Tensor] = None,
+                           position_ids: Optional[torch.LongTensor] = None,
+                           past_key_values: Optional[List[torch.FloatTensor]] = None,
+                           use_cache: Optional[bool] = None,
+                           output_attentions: Optional[bool] = None,
+                           output_hidden_states: Optional[bool] = None,
+                           loss_mask: Optional[torch.Tensor] = None):
+        cache_hidden = [[], []]
+        plosses = []
+        acces = []
 
 
         batch_size, seq_length, _ = hidden_states.shape
         seq_length_with_past = seq_length
         past_key_values_length = 0
@@ -781,55 +825,29 @@ def forward(
         if self.gradient_checkpointing and self.training:
             if use_cache:
                 use_cache = False
-
-        plosses = []
-        vlosses = []
-        acces = []
-        cache_hidden = [[], []]
-
         for idx in range(self.length):
             last = idx == self.length - 1
 
             inputs_embeds = self.embed_tokens(input_ids)
-            if self.training and self.gradient_checkpointing and not inputs_embeds.requires_grad:
-                inputs_embeds.requires_grad = True
             inputs_embeds = inputs_embeds.to(hidden_states.dtype)
-            if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, None, output_attentions)
-
-                    return custom_forward
-
-                layer_outputs, cache_hidden = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(self.midlayer),
-                    inputs_embeds,
-                    hidden_states,
-                    cache_hidden,
-                    attention_mask,
-                    position_ids,
-                )
-            else:
-
-                layer_outputs, cache_hidden = self.midlayer(
-                    input_emb=inputs_embeds,
-                    hidden_states=hidden_states,
-                    cache_hidden=cache_hidden,
-                    attention_mask=attention_mask,
-                    position_ids=position_ids,
-                    past_key_value=None,
-                    output_attentions=output_attentions,
-                    use_cache=True,
-                )
+            layer_outputs, cache_hidden = self.midlayer(
+                input_emb=inputs_embeds,
+                hidden_states=hidden_states,
+                cache_hidden=cache_hidden,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=None,
+                output_attentions=output_attentions,
+                use_cache=True,
+            )
 
             hidden_states_out = layer_outputs[0]
-            # cache_hidden.append(layer_outputs[1])
-            # kv_cahce = layer_outputs[-1]
+            hidden_states = hidden_states_out
+            hidden_states_out = self.norm(hidden_states_out)
+            logits = self.lm_head(hidden_states_out)
+            logits = logits.float()
 
             with torch.no_grad():
-                # hidden_states_target = padding(hidden_states, left=False)
                 target_head = target
                 target_max_token = target_head.argmax(-1)
                 # Move d2t to the same device as target_max_token
@@ -840,28 +858,17 @@ def custom_forward(*inputs):
                 target_head = target_head[..., self.t2d]
                 target_head = target_head.float()
                 target_p = nn.Softmax(dim=2)(target_head)
-                target_p = target_p.detach()
-
-
-
-            hidden_states = hidden_states_out
-
-            hidden_states_out = self.norm(hidden_states_out)
-
-            logits = self.lm_head(hidden_states_out)
-            logits = logits.float()
             out_logp = nn.LogSoftmax(dim=2)(logits)
             plogp = target_p * out_logp
-            loss = -torch.sum(position_mask * plogp, 2).mean()
+            sum_logit = torch.sum(position_mask * plogp, 2)
+            loss = -sum_logit.mean()
             plosses.append(loss)
-            with torch.no_grad():
-                acces.append(((logits.argmax(-1) == target_p.argmax(-1)) * position_mask.squeeze(-1)).sum().item() / (
-                    loss_mask.sum().item() + 1e-6))
+            acces.append(((logits.argmax(-1) == target_p.argmax(-1)) * position_mask.squeeze(-1)).sum().item() / (loss_mask.sum().item() + 1e-6))
             if not last:
                 input_ids = padding(input_ids, left=False)
                 target = padding(target, left=False)
                 loss_mask = padding(loss_mask, left=False)
+                seq_length = attention_mask.shape[-1]
                 ind=torch.arange(seq_length,device=attention_mask.device)
                 ind0=ind[idx:]
                 ind1=ind[:seq_length-idx]
@@ -869,6 +876,4 @@ def custom_forward(*inputs):
 
 
 
-        return plosses, vlosses, acces
-
-
+        return plosses, acces
diff --git a/eagle/traineagle3/main.py b/eagle/traineagle3/main.py
index 37012a16..aa2405f8 100644
--- a/eagle/traineagle3/main.py
+++ b/eagle/traineagle3/main.py
@@ -276,7 +276,7 @@ def find_max_state_with_file(directory, filename="zero_to_fp32.py"):
 
         model.zero_grad()
 
-        plosses, vlosses, acces = model_engine(input_ids=data["input_ids"].to(rank),
+        plosses, acces = model_engine(input_ids=data["input_ids"].to(rank),
                                       attention_mask=data["attention_mask"].to(rank),
                                       loss_mask=data["loss_mask"],
                                       )
@@ -321,7 +321,7 @@ def find_max_state_with_file(directory, filename="zero_to_fp32.py"):
         for batch_idx, data in enumerate(tqdm(test_loader)):
             with torch.no_grad():
 
-                plosses, vlosses, acces = model_engine(input_ids=data["input_ids"].to(rank),
+                plosses, acces = model_engine(input_ids=data["input_ids"].to(rank),
                                               attention_mask=data["attention_mask"].to(rank),
                                               loss_mask=data["loss_mask"],
                                               )
@@ -348,4 +348,4 @@ def find_max_state_with_file(directory, filename="zero_to_fp32.py"):
         model_engine.save_16bit_model(f"{args.savedir}/state_{epoch}", exclude_frozen_parameters=True)
 
     if epoch % 10 == 0:
-        deepspeed.DeepSpeedEngine.save_checkpoint(model_engine, save_dir=f"{args.savedir}/state_{epoch}")
+        deepspeed.DeepSpeedEngine.save_checkpoint(model_engine, save_dir=f"{args.savedir}/state_{epoch}")
\ No newline at end of file
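Note on the core change above: the per-iteration reentrant checkpoint (the removed `create_custom_forward` wrapper around each `midlayer` call) is replaced by one non-reentrant checkpoint over the entire midlayer loop, now factored into `_run_midlayer_loop`. A minimal, self-contained sketch of that pattern follows; `ToyLoop`, `dim`, and `steps` are hypothetical stand-ins for illustration, not names from the EAGLE code:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyLoop(nn.Module):
    """Hypothetical stand-in for a module that reuses one block several times."""

    def __init__(self, dim: int = 16, steps: int = 3):
        super().__init__()
        self.block = nn.Linear(dim, dim)
        self.steps = steps

    def _run_loop(self, x: torch.Tensor) -> torch.Tensor:
        # The whole multi-step loop lives in one function so that a single
        # checkpoint call can recompute all of it during backward.
        for _ in range(self.steps):
            x = torch.tanh(self.block(x))
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training:
            # Non-reentrant checkpointing drops the intermediate activations
            # and recomputes _run_loop on backward. Unlike the reentrant
            # variant, it does not require any checkpoint input to have
            # requires_grad=True, which is why the patch can also delete the
            # old `inputs_embeds.requires_grad = True` workaround.
            return checkpoint(self._run_loop, x, use_reentrant=False)
        return self._run_loop(x)


model = ToyLoop().train()
out = model(torch.randn(2, 16))
out.sum().backward()                         # _run_loop is recomputed here
print(model.block.weight.grad is not None)   # True: gradients reach the block
```

The trade-off is the usual one for checkpointing granularity: a single checkpoint over the full loop stores only the loop's inputs (the lowest memory footprint), but backward must recompute all `self.length` iterations, whereas the old per-call checkpoints recomputed one `midlayer` invocation at a time.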