125 changes: 65 additions & 60 deletions eagle/traineagle3/cnets.py
@@ -718,7 +718,6 @@ def dataprepare(self, input_ids, attention_mask, loss_mask):
hidden_states1 = outs.hidden_states[1]
hidden_states2 = outs.hidden_states[2]
hidden_states=torch.cat((hidden_states0,hidden_states1,hidden_states2),dim=-1)
# hidden_states=torch.cat((hidden_states0,hidden_states1),dim=-1)
target = outs.logits
target = padding(target, left=False)
input_ids = padding(input_ids, left=False)
@@ -732,7 +731,6 @@ def dataprepare(self, input_ids, attention_mask, loss_mask):

def forward(
self,
# hidden_states,
input_ids,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
@@ -744,7 +742,53 @@ def forward(

):
hidden_states, target, loss_mask, input_ids = self.dataprepare(input_ids, attention_mask, loss_mask)
if self.training and self.gradient_checkpointing:
plosses,acces = torch.utils.checkpoint.checkpoint(
self._run_midlayer_loop,
input_ids,
target,
hidden_states,
attention_mask,
position_ids,
past_key_values,
use_cache,
output_attentions,
output_hidden_states,
loss_mask,
use_reentrant=False  # keep False (non-reentrant checkpointing)
)


else:
plosses,acces = self._run_midlayer_loop(
input_ids,
target,
hidden_states,
attention_mask,
position_ids,
past_key_values,
use_cache,
output_attentions,
output_hidden_states,
loss_mask)

return plosses,acces

def _run_midlayer_loop(self,
input_ids,
target,
hidden_states,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
loss_mask: Optional[torch.Tensor] = None):

cache_hidden = [[], []]
plosses = []
acces = []
batch_size, seq_length, _ = hidden_states.shape
seq_length_with_past = seq_length
past_key_values_length = 0
@@ -781,55 +825,29 @@ def forward(
if self.gradient_checkpointing and self.training:
if use_cache:
use_cache = False

plosses = []
vlosses = []
acces = []
cache_hidden = [[], []]

for idx in range(self.length):
last = idx == self.length - 1
inputs_embeds = self.embed_tokens(input_ids)
if self.training and self.gradient_checkpointing and not inputs_embeds.requires_grad:
inputs_embeds.requires_grad = True
inputs_embeds = inputs_embeds.to(hidden_states.dtype)

if self.gradient_checkpointing and self.training:

def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, None, output_attentions)

return custom_forward

layer_outputs, cache_hidden = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.midlayer),
inputs_embeds,
hidden_states,
cache_hidden,
attention_mask,
position_ids,
)
else:

layer_outputs, cache_hidden = self.midlayer(
input_emb=inputs_embeds,
hidden_states=hidden_states,
cache_hidden=cache_hidden,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=None,
output_attentions=output_attentions,
use_cache=True,
)
layer_outputs, cache_hidden = self.midlayer(
input_emb=inputs_embeds,
hidden_states=hidden_states,
cache_hidden=cache_hidden,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=None,
output_attentions=output_attentions,
use_cache=True,
)

hidden_states_out = layer_outputs[0]
# cache_hidden.append(layer_outputs[1])
# kv_cahce = layer_outputs[-1]
hidden_states = hidden_states_out
hidden_states_out = self.norm(hidden_states_out)
logits = self.lm_head(hidden_states_out)
logits = logits.float()

with torch.no_grad():
# hidden_states_target = padding(hidden_states, left=False)
target_head = target
target_max_token = target_head.argmax(-1)
# Move d2t to the same device as target_max_token
@@ -840,35 +858,22 @@ def custom_forward(*inputs):
target_head = target_head[..., self.t2d]
target_head = target_head.float()
target_p = nn.Softmax(dim=2)(target_head)
target_p = target_p.detach()



hidden_states = hidden_states_out

hidden_states_out = self.norm(hidden_states_out)

logits = self.lm_head(hidden_states_out)
logits = logits.float()
out_logp = nn.LogSoftmax(dim=2)(logits)
plogp = target_p * out_logp
loss = -torch.sum(position_mask * plogp, 2).mean()
sum_logit = torch.sum(position_mask * plogp, 2)
loss = -sum_logit.mean()
plosses.append(loss)
with torch.no_grad():
acces.append(((logits.argmax(-1) == target_p.argmax(-1)) * position_mask.squeeze(-1)).sum().item() / (
loss_mask.sum().item() + 1e-6))

acces.append(((logits.argmax(-1) == target_p.argmax(-1)) * position_mask.squeeze(-1)).sum().item() / (loss_mask.sum().item() + 1e-6))
if not last:
input_ids = padding(input_ids, left=False)
target = padding(target, left=False)
loss_mask = padding(loss_mask, left=False)
seq_length = attention_mask.shape[-1]
ind=torch.arange(seq_length,device=attention_mask.device)
ind0=ind[idx:]
ind1=ind[:seq_length-idx]
attention_mask[:,:,ind0,ind1]=torch.finfo(attention_mask.dtype).min



return plosses, vlosses, acces


return plosses,acces
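
Note on the refactor above: the per-iteration midlayer loop is moved out of forward() into the helper _run_midlayer_loop, and during training the whole helper is wrapped in a single torch.utils.checkpoint.checkpoint call with use_reentrant=False, so the loop's activations are recomputed in the backward pass instead of being stored. A minimal, self-contained sketch of that pattern follows; TinyDraft, _run_loop, and the toy loss are illustrative stand-ins, not the actual EAGLE-3 draft model.

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class TinyDraft(nn.Module):
    # Simplified stand-in: one shared layer applied self.length times, like the draft-model loop.
    def __init__(self, hidden=32, length=3):
        super().__init__()
        self.layer = nn.Linear(hidden, hidden)
        self.length = length
        self.gradient_checkpointing = True

    def _run_loop(self, hidden_states):
        # The loop that previously lived in forward(); checkpointing now wraps it as a whole.
        losses = []
        for _ in range(self.length):
            hidden_states = torch.tanh(self.layer(hidden_states))
            losses.append(hidden_states.pow(2).mean())
        return torch.stack(losses)

    def forward(self, hidden_states):
        if self.training and self.gradient_checkpointing:
            # use_reentrant=False selects non-reentrant checkpointing (recompute during backward).
            return checkpoint(self._run_loop, hidden_states, use_reentrant=False)
        return self._run_loop(hidden_states)

model = TinyDraft()
losses = model(torch.randn(2, 8, 32))
losses.sum().backward()  # activations inside _run_loop are recomputed here
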
6 changes: 3 additions & 3 deletions eagle/traineagle3/main.py
@@ -276,7 +276,7 @@ def find_max_state_with_file(directory, filename="zero_to_fp32.py"):

model.zero_grad()

plosses, vlosses, acces = model_engine(input_ids=data["input_ids"].to(rank),
plosses, acces = model_engine(input_ids=data["input_ids"].to(rank),
attention_mask=data["attention_mask"].to(rank),
loss_mask=data["loss_mask"],
)
@@ -321,7 +321,7 @@ def find_max_state_with_file(directory, filename="zero_to_fp32.py"):

for batch_idx, data in enumerate(tqdm(test_loader)):
with torch.no_grad():
plosses, vlosses, acces = model_engine(input_ids=data["input_ids"].to(rank),
plosses, acces = model_engine(input_ids=data["input_ids"].to(rank),
attention_mask=data["attention_mask"].to(rank),
loss_mask=data["loss_mask"],
)
@@ -348,4 +348,4 @@ def find_max_state_with_file(directory, filename="zero_to_fp32.py"):

model_engine.save_16bit_model(f"{args.savedir}/state_{epoch}", exclude_frozen_parameters=True)
if epoch % 10 == 0:
deepspeed.DeepSpeedEngine.save_checkpoint(model_engine, save_dir=f"{args.savedir}/state_{epoch}")
deepspeed.DeepSpeedEngine.save_checkpoint(model_engine, save_dir=f"{args.savedir}/state_{epoch}")
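
From the caller's side the only visible change is the two-value return (vlosses is gone). A minimal sketch of a training step consuming it, with the caveat that model_engine, data, and rank are assumed to be set up as in main.py, and the equal-weight mean over plosses is an assumption rather than the repository's exact weighting:

import torch

plosses, acces = model_engine(input_ids=data["input_ids"].to(rank),
                              attention_mask=data["attention_mask"].to(rank),
                              loss_mask=data["loss_mask"])
loss = torch.stack(plosses).mean()  # collapse per-step losses into one scalar (assumed weighting)
model_engine.backward(loss)         # DeepSpeed engine handles loss scaling / accumulation
model_engine.step()
mean_acc = sum(acces) / len(acces)  # simple average of per-step top-1 agreement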