def compute(self, token_prediction_prob, tokens):
    """Compute the reconstruction loss and two accuracy metrics.

    Args:
        token_prediction_prob: Tensor of per-position scores; the last
            dimension is the one reduced by ``argmax`` and kept when the
            tensor is flattened to 2-D (presumably the vocabulary/item
            axis -- confirm against the caller).
        tokens: Tensor of target ids, compared element-wise against the
            argmax of ``token_prediction_prob`` over its last dimension.

    Returns:
        Tuple ``(reconstruct_loss, hits, NDCG10)``:
        ``reconstruct_loss`` is the result of ``self.ce`` on the flattened
        scores and targets; ``hits`` is a tensor counting the positions
        where the argmax equals the target; ``NDCG10`` is whatever
        ``recalls_and_ndcgs_for_ks`` returns for ``k=10``.
    """
    hits = torch.sum(torch.argmax(token_prediction_prob, dim=-1) == tokens)

    # Flatten once and reuse. ``reshape`` is used rather than ``view`` so that
    # non-contiguous inputs (e.g. after a transpose/permute upstream) do not
    # raise; for contiguous inputs it returns the same view as before.
    num_classes = token_prediction_prob.shape[-1]
    flat_scores = token_prediction_prob.reshape(-1, num_classes)

    # The metric helper is given the targets as a column, the loss as a 1-D
    # vector, exactly as in the original calls.
    NDCG10 = recalls_and_ndcgs_for_ks(flat_scores, tokens.reshape(-1, 1), 10)
    reconstruct_loss = self.ce(flat_scores, tokens.reshape(-1))
    return reconstruct_loss, hits, NDCG10