[Draft] Long Context Training VRAM Optimization #446
Changes from all commits: 6c778df, 2560f94, 5501b3a, 61cbcc1, 9c94079, 83b83d8, 79324f8, d00ee39, e4e5d02, 5412491, 79975fc, 5adf840, dd792b8
```diff
@@ -25,6 +25,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint
 from transformers.cache_utils import DynamicCache
 from yunchang import EXTRACT_FUNC_DICT
```
```diff
@@ -122,6 +123,7 @@ def forward(
             length=self.length,
         )
         del target
+        torch.cuda.empty_cache()

         # basic info
         batch_size, seq_length, _ = hidden_states.shape
```

Contributor: Explicitly calling
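As background on the `del target` / `torch.cuda.empty_cache()` pattern above: dropping the last reference frees the allocation inside PyTorch's caching allocator, while `empty_cache()` additionally returns the cached, unused blocks to the CUDA driver. A minimal self-contained sketch with illustrative sizes (not code from this PR):

```python
import torch

x = torch.empty(4096, 4096, device="cuda")  # large intermediate, ~64 MiB in fp32
print(torch.cuda.memory_allocated() // 2**20, "MiB allocated")

del x  # the allocation is freed, but the block stays cached by PyTorch
print(torch.cuda.memory_reserved() // 2**20, "MiB still reserved by the caching allocator")

torch.cuda.empty_cache()  # hand the cached, unused blocks back to the CUDA driver
print(torch.cuda.memory_reserved() // 2**20, "MiB reserved after empty_cache()")
```

Note that `empty_cache()` cannot free tensors that are still referenced; it only releases blocks that are already unused, at a small synchronization cost.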
```diff
@@ -166,7 +168,7 @@ def forward(
             dtype=torch.bool,
             device=hidden_states.device,
         )
-        if self.attention_backend in ("sdpa", "usp"):
+        if self.attention_backend == "sdpa":
            attention_mask = self.draft_model.prepare_decoder_attention_mask(
                attention_mask=attention_mask,
                hidden_states=hidden_states,
```
```diff
@@ -175,6 +177,24 @@ def forward(
                 past_key_values_length=past_key_values_length,
             )

+        def compute_loss_and_acc_checkpointed(hs, tgt_p, pos_mask, l_mask):
+            # 1. Compute Logits (the part that consumes the most VRAM.)
+            logits_ = self.draft_model.compute_logits(hs)
+            logits = gather_outputs_and_unpad(logits_, gather_dim=1)
+
+            # 2. Compute Loss
+            loss_val = LogSoftmaxLoss.apply(logits, tgt_p, pos_mask)
+
+            # 3. Compute Accuracy
+            with torch.no_grad():
+                acc_val = _compute_metric_acc(
+                    logits=logits,
+                    target_p=tgt_p,
+                    position_mask=pos_mask,
+                    loss_mask=l_mask,
+                )
+            return loss_val, acc_val
+
         # Step 5: run TTT
         plosses = []
         vlosses = []
```
```diff
@@ -217,24 +237,22 @@ def forward(
             # update hidden states for next step
             hidden_states = hidden_states_out

-            # Step 5.4: get logits
-            logits = self.draft_model.compute_logits(hidden_states)
-            logits = gather_outputs_and_unpad(logits, gather_dim=1)
-            # Step 5.5: record metrics first as we in-place modify logits
-            with torch.no_grad():
-                acces.append(
-                    _compute_metric_acc(
-                        logits=logits,
-                        target_p=target_p,
-                        position_mask=position_mask,
-                        loss_mask=loss_mask,
-                    )
-                )
-
-            # Step 5.6: calculate loss, in-place modifies logits!
-            loss = LogSoftmaxLoss.apply(logits, target_p, position_mask)
+            if hidden_states.requires_grad:
+                loss, acc = checkpoint(
+                    compute_loss_and_acc_checkpointed,
+                    hidden_states,
+                    target_p,
+                    position_mask,
+                    loss_mask,
+                    use_reentrant=False,
+                )
+            else:
+                loss, acc = compute_loss_and_acc_checkpointed(
+                    hidden_states, target_p, position_mask, loss_mask
+                )
             plosses.append(loss)

+            acces.append(acc)
             if not is_last:
                 # Step 5.7: we need to update the loss mask
                 global_input_ids = padding(global_input_ids, left=False)
```
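For readers unfamiliar with the pattern in the hunk above, here is a self-contained sketch of recomputing a memory-heavy logits projection under `torch.utils.checkpoint` so the full `[seq_len, vocab_size]` tensor is not kept alive between forward and backward. The module, sizes, and loss here are illustrative stand-ins, not the repository's actual classes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

hidden_size, vocab_size, seq_len = 1024, 32000, 4096  # illustrative sizes
lm_head = nn.Linear(hidden_size, vocab_size, bias=False).cuda()

def loss_from_hidden(hs, labels):
    # Recomputed during backward: the [seq_len, vocab_size] logits tensor
    # is not stored as part of the autograd graph between passes.
    logits = lm_head(hs)
    return F.cross_entropy(logits, labels)

hs = torch.randn(seq_len, hidden_size, device="cuda", requires_grad=True)
labels = torch.randint(vocab_size, (seq_len,), device="cuda")

if hs.requires_grad:
    # use_reentrant=False is the non-reentrant implementation recommended by
    # PyTorch; checkpointing is skipped entirely when no grad is needed.
    loss = checkpoint(loss_from_hidden, hs, labels, use_reentrant=False)
else:
    loss = loss_from_hidden(hs, labels)
loss.backward()
```

The trade-off is one extra forward pass through the logits projection during backward, in exchange for not holding the full logits tensor (and its softmax intermediates) in memory for every TTT step.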
What is the impact of this on performance? If it is large, maybe we can add a flag to control whether this runs on the GPU or the CPU.
We can simply compute it: vocab size (150,000) × seq_length (64k) will cost about 10G more.
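A back-of-the-envelope expansion of that figure; whether "10G" counts elements or bytes isn't stated in the thread, and the bf16/fp32 byte sizes below are my assumption:

```python
vocab_size = 150_000
seq_length = 64 * 1024

elements = vocab_size * seq_length                   # ≈ 9.8e9 logit entries ("10G")
print(f"{elements / 1e9:.1f}G elements")
print(f"{elements * 2 / 2**30:.1f} GiB if stored in bf16/fp16")
print(f"{elements * 4 / 2**30:.1f} GiB if upcast to fp32 for the softmax")
```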
The reason is that target_head's `preprocess` function uses `padding`, which generates an extra copy of the `target` in memory.
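To illustrate why that padding step briefly doubles the `target` footprint (a hypothetical stand-in for the described `preprocess`, not the repo's actual helper): `F.pad` always allocates a new tensor, so the original and the padded copy coexist until the original is dropped.

```python
import torch
import torch.nn.functional as F

# Hypothetical stand-in for the preprocess/padding step: shift targets by one
# token. F.pad allocates a brand-new tensor, so for a moment both `target`
# and the padded copy are resident on the GPU.
target = torch.randint(150_000, (1, 64 * 1024), device="cuda")
padded = F.pad(target, (1, 0))[:, :-1]  # new allocation of roughly the same size

del target                  # drop the original as soon as it is no longer needed
torch.cuda.empty_cache()    # return the freed block to the CUDA driver
```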
We can split the hidden state in the dataset's `__getitem__` for USP to reduce memory use.
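A sketch of how that could look: shard the stored hidden state per sequence-parallel rank inside the dataset's `__getitem__`, so each rank only ever materializes its own slice. All class and field names here are hypothetical, not the repository's actual API:

```python
import torch
from torch.utils.data import Dataset

class ShardedHiddenStateDataset(Dataset):
    """Hypothetical dataset: each SP rank loads only its slice of the sequence."""

    def __init__(self, samples, sp_rank: int, sp_world_size: int):
        self.samples = samples              # list of dicts with full-length tensors
        self.sp_rank = sp_rank
        self.sp_world_size = sp_world_size

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        item = self.samples[idx]
        hidden_state = item["hidden_state"]          # [seq_len, hidden_size]
        seq_len = hidden_state.shape[0]
        shard = seq_len // self.sp_world_size        # assumes seq_len is divisible
        start, end = self.sp_rank * shard, (self.sp_rank + 1) * shard
        # .clone() detaches the slice from the full-length storage so the big
        # tensor can be released instead of travelling through the dataloader.
        return {
            "hidden_state": hidden_state[start:end].clone(),
            "input_ids": item["input_ids"][start:end].clone(),
        }
```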
Yes, I think this is a better optimization. Can you help add it?
#454
Hi, I've finished the updates. Note that SP currently works with batch size 1. This seems reasonable for long-sequence scenarios to avoid OOM, but I'm open to feedback. Ready for review!
Great, I'll help review it.