diff --git a/metrics/perplexity/perplexity.py b/metrics/perplexity/perplexity.py
index ad307e8ad..557172cdb 100644
--- a/metrics/perplexity/perplexity.py
+++ b/metrics/perplexity/perplexity.py
@@ -166,7 +166,7 @@ def _compute(
             encoded_batch = encoded_texts[start_index:end_index]
             attn_mask = attn_masks[start_index:end_index]
 
-            if add_start_token:
+            if add_start_token and tokenizer.bos_token_id is not None:
                 bos_tokens_tensor = torch.tensor([[tokenizer.bos_token_id]] * encoded_batch.size(dim=0)).to(device)
                 encoded_batch = torch.cat([bos_tokens_tensor, encoded_batch], dim=1)
                 attn_mask = torch.cat(
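
For context (not part of the patch): `tokenizer.bos_token_id` is `None` for tokenizers that do not define a BOS token, so building `bos_tokens_tensor` from it would fail; the added check simply skips the prepend in that case. Below is a minimal sketch of that behavior, assuming the `torch` and `transformers` packages and using `gpt2` / `bert-base-uncased` purely as illustrative checkpoints (the helper name is hypothetical, not part of the metric):

```python
# Sketch only -- illustrates the case the new guard handles.
import torch
from transformers import AutoTokenizer

gpt2_tok = AutoTokenizer.from_pretrained("gpt2")
bert_tok = AutoTokenizer.from_pretrained("bert-base-uncased")

print(gpt2_tok.bos_token_id)  # 50256 -- GPT-2 defines a BOS token
print(bert_tok.bos_token_id)  # None  -- no BOS token, old branch would fail here

def maybe_prepend_bos(encoded_batch, tokenizer, add_start_token=True):
    """Prepend a BOS column only when the tokenizer actually has one."""
    if add_start_token and tokenizer.bos_token_id is not None:
        bos = torch.tensor([[tokenizer.bos_token_id]] * encoded_batch.size(dim=0))
        encoded_batch = torch.cat([bos, encoded_batch], dim=1)
    return encoded_batch

batch = torch.tensor([[31373, 995], [50256, 50256]])  # arbitrary token ids
print(maybe_prepend_bos(batch, gpt2_tok).shape)  # torch.Size([2, 3])
print(maybe_prepend_bos(batch, bert_tok).shape)  # torch.Size([2, 2]) -- unchanged
```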