How to generate one token after the other with Scibert? #128

@junoriosity

I would like to use SciBERT to generate text one token at a time, feeding each sampled token back into the model together with past_key_values. Here is my code:

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM

# Fall back to the CPU when no GPU is available
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained('allenai/scibert_scivocab_uncased')
model = AutoModelForCausalLM.from_pretrained('allenai/scibert_scivocab_uncased').to(device)

input_sequence = "Hello, I'm a language model,"

# Tokenize once; return_tensors="pt" already yields tensors of shape (1, seq_len)
encoded = tokenizer(input_sequence, return_tensors="pt").to(device)
inputs = encoded.input_ids
attention_mask = encoded.attention_mask
past_key_values = None

count = 0
complete_token = []
with torch.no_grad():
    while count < 10:
        count += 1
        print("Iteration no.: " + str(count))
        if count > 1:
            inputs = input_token

        print(inputs.to(device))
        print(attention_mask)
        print(past_key_values[0][0].shape if past_key_values else None)

        model_out = model(input_ids=inputs.to(device), attention_mask=attention_mask, past_key_values=past_key_values)
        logits = model_out.logits[:, -1, :]
        past_key_values = model_out.past_key_values

        topk_values, topk_indices = torch.topk(logits, 5)

        # softmax returns probabilities (not log-probabilities) over the top-k candidates
        probs = F.softmax(topk_values, dim=-1)
        inputs_in_topk = torch.multinomial(probs, num_samples=1, replacement=True)
        input_token = torch.gather(topk_indices, 1, inputs_in_topk)
        attention_mask = torch.concat((attention_mask, torch.tensor([[1]]).to(attention_mask.device)), dim=1)
        complete_token.append(input_token)

However, past_key_values is None in every iteration: model_out.past_key_values never gets populated. I tried the same approach with other models, and there past_key_values is not None. How can I make the iteration work here, so that each step reuses the state from the previous iterations?
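
One possible explanation, offered as a minimal sketch rather than a confirmed fix: BERT-style models in transformers only return a key/value cache when their config has is_decoder=True, and the SciBERT checkpoint is an encoder-style BERT. The snippet below uses a tiny, randomly initialised BertLMHeadModel so that it runs without downloading weights; the config sizes and token ids are made-up illustration values, not SciBERT's.

```python
import torch
from transformers import BertConfig, BertLMHeadModel

# Tiny random BERT; all sizes are arbitrary illustration values (assumption),
# chosen only so the example runs quickly without downloading a checkpoint.
config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
    is_decoder=True,  # enables the causal mask and the key/value cache
)
model = BertLMHeadModel(config).eval()

input_ids = torch.tensor([[2, 11, 12, 13]])  # made-up token ids
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    # Step 1: run the full prompt; there is no cache yet.
    out = model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)
    cache_returned = out.past_key_values is not None

    # Step 2: feed only the newly chosen token, the cache from step 1,
    # and an attention mask extended by one position.
    next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
    attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=1)
    out = model(
        input_ids=next_token,
        attention_mask=attention_mask,
        past_key_values=out.past_key_values,
        use_cache=True,
    )
    step2_logits_shape = tuple(out.logits.shape)

print(cache_returned, step2_logits_shape)
```

For the real checkpoint, the corresponding change would presumably be passing is_decoder=True to AutoModelForCausalLM.from_pretrained('allenai/scibert_scivocab_uncased', ...); I have not verified this against this repository. Note also that SciBERT was pretrained as a bidirectional masked language model, so left-to-right sampling from it may produce low-quality text even once the cache works.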
