In `main.py`, during data preparation:

```python
def collate_fn(examples, device):
    token_ids = torch.tensor(
        [example['token_ids'] for example in examples], device=device)
    return {'input_ids': token_ids[:, :-1], 'labels': token_ids[:, 1:]}

def train_chunk(.......):
    ..........
    batch = collate_fn(
        examples=examples[i:i+per_device_batch_size], device=fabric.device)
    input_ids, labels = batch['input_ids'], batch['labels']
```
In `modeling_llama.py`, during the loss computation:

```python
class LlamaForCausalLM(LlamaPreTrainedModel):
    ....................
    if labels is not None:
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
```

Why is the one-position shift that aligns predictions with targets applied once when the sample data is prepared as model input, and then applied again during the loss computation inside the model?
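The effect of composing the two shifts can be checked with plain index arithmetic. Below is a minimal sketch using lists instead of tensors (position names `t0`..`t4` are illustrative, not from the repo); it traces what happens to a 5-token sequence when both the `collate_fn` shift and the in-model shift are applied:

```python
# A 5-token sequence; each string names the position it occupies.
token_ids = ['t0', 't1', 't2', 't3', 't4']

# Shift 1: collate_fn
input_ids = token_ids[:-1]   # ['t0', 't1', 't2', 't3']
labels    = token_ids[1:]    # ['t1', 't2', 't3', 't4']

# The model emits one logit per input position; logits[i] is the
# prediction made at input_ids[i]. Use the position name as a stand-in.
logits = input_ids

# Shift 2: inside LlamaForCausalLM when labels is not None
shift_logits = logits[:-1]   # predictions made at t0, t1, t2
shift_labels = labels[1:]    # targets t2, t3, t4

for pred_pos, target in zip(shift_logits, shift_labels):
    print(f'prediction at {pred_pos} is scored against {target}')
# prediction at t0 is scored against t2  (two positions ahead)
```

With both shifts applied, the prediction at position `i` is scored against token `i+2` rather than `i+1`, so only one of the two shifts should take effect: either pass the unshifted token ids as `labels` and let the model shift, or shift in `collate_fn` and compute the loss without re-shifting.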