import torch
import torch.nn.functional as F
# a helper function for padding a batch of tokenized texts
from fairseq2.nn.padding import pad_seqs
from sonar.inference_pipelines.text import TextToEmbeddingModelPipeline, EmbeddingToTextModelPipeline
# define the loss computation function
def get_decoder_loss(
    decoder,
    batch_tokens,
    batch_embs,
):
    """
    Compute the per-sentence cross-entropy loss of a SONAR text decoder.

    Parameters
    ----------
    decoder:
        A SONAR text decoder module exposing ``decoder_frontend``,
        ``decoder`` and ``final_proj`` (e.g. ``dec.model.decoder``).
    batch_tokens:
        A list of 1-D integer token tensors, one per sentence. Each tensor
        must start with the EOS token (id 3) followed by the language tag,
        as produced by the tokenizer in "target" mode.
    batch_embs:
        A 2-D tensor of sentence embeddings, one row per sentence
        (presumably ``(batch, 1024)`` for SONAR — TODO confirm).

    Returns
    -------
    per_sent_loss:
        1-D tensor of summed (non-normalized) token losses per sentence.
    per_sent_toks:
        1-D tensor with the number of scored tokens per sentence, so the
        caller can optionally normalize the losses.

    Raises
    ------
    ValueError:
        If the first sentence does not start with the EOS token.
    """
    # `assert` is stripped under `python -O`; validate input explicitly.
    if int(batch_tokens[0][0]) != 3:
        raise ValueError("EOS TOKEN MUST BE PREPENDED WHEN TRAINING A DECODER")
    # Per-token losses; ignored positions are marked with -100 below.
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction="none")
    # Pad the ragged batch into one rectangular tensor plus a padding mask.
    padded, mask = pad_seqs(batch_tokens)
    # Move all inputs onto the same device as the model weights.
    device = next(decoder.parameters()).device
    padded = padded.to(device)
    if mask is not None:
        mask = mask.to(device)
    batch_embs = batch_embs.to(device)
    # Run the model in three steps: token embeddings, decoder body
    # (cross-attending to the sentence embedding), output projection.
    seqs, padding_mask = decoder.decoder_frontend(
        padded, padding_mask=mask,
    )
    # NOTE(review): the original `mask` (not the frontend's returned
    # `padding_mask`) is passed here; they look equivalent for this frontend
    # — confirm against fairseq2 before relying on it.
    decoder_output, decoder_padding_mask = decoder.decoder(
        seqs,
        mask,
        encoder_output=batch_embs.unsqueeze(1),
    )
    logits = decoder.final_proj(decoder_output)
    # Shifted targets: position t predicts token t+1, so drop the last logit
    # and take labels from the second token onward.
    shifted_logits = logits[:, :-1]
    labels = padded[:, 1:].clone()
    # Make the loss ignore the padding tokens (id 0) ...
    labels[labels == 0] = -100
    # ... and the first label, which is always the language tag and does not
    # have to be predicted.
    labels[:, 0] = -100
    loss = loss_fn(
        shifted_logits.reshape(-1, shifted_logits.size(-1)), labels.view(-1)
    )
    per_token_loss = loss.view(shifted_logits.shape[:2])
    # Per-sentence loss is the sum of its per-token losses.
    per_sent_loss = per_token_loss.sum(-1)
    # Count scored tokens so the total loss can be normalized by token count.
    per_sent_toks = (labels != -100).sum(1)
    return per_sent_loss, per_sent_toks
# --- Demo: score a small batch of sentences with the SONAR decoder ---
encoder_pipeline = TextToEmbeddingModelPipeline(
    encoder="text_sonar_basic_encoder",
    tokenizer="text_sonar_basic_encoder",
    device=torch.device("cuda"),
)
decoder_pipeline = EmbeddingToTextModelPipeline(
    decoder="text_sonar_basic_decoder",
    tokenizer="text_sonar_basic_encoder",
    device=torch.device("cuda"),
)
sample_sentences = [
    "hello world",
    "hello",
    "hello world. my name is jeff",
]
target_lang = "eng_Latn"
# Tokenize each sentence in "target" mode so the EOS / language-tag layout
# matches what the decoder expects; yields integer tensors of varying length.
tokenized_sentences = [
    decoder_pipeline.tokenizer.create_encoder(mode='target', lang=target_lang)(sentence)
    for sentence in sample_sentences
]
# Encode the same sentences into a 3*1024 matrix of sentence embeddings.
sentence_embeddings = encoder_pipeline.predict(sample_sentences, source_lang="eng_Latn")
# Score the batch without tracking gradients.
with torch.inference_mode():
    losses, n_toks = get_decoder_loss(
        decoder_pipeline.model.decoder, tokenized_sentences, sentence_embeddings
    )
print(losses)
# tensor([0.2440, 2.6020, 3.5046], device='cuda:0')
print(n_toks)
# tensor([3, 2, 9], device='cuda:0')
# The average per-token loss (the quantity normally optimized during training,
# directly related to text perplexity) is the sum of the sentence losses
# divided by the total number of scored tokens:
avg_loss = losses.sum() / n_toks.sum()
print(avg_loss)
# tensor(0.4536, device='cuda:0')
Dear Authors,
I hope this message finds you well. I would like to obtain a decoder that uses a tokenizer different from the original one. Could you please advise on how I should train or fine-tune the model in this case?
Would it be appropriate to directly replace the tokenizer in the provided demo with my own tokenizer, or would additional modifications be necessary?
I would greatly appreciate any guidance you could offer.
Thank you very much for your time and assistance.