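```python
import model_utils  # found under the "tcr" folder

tcrbert_trb_cls = model_utils.load_classification_pipeline("wukevin/tcr-bert", device=0)
df = model_utils.reformat_classification_pipeline_preds(tcrbert_trb_cls([
    "C A S S P V T G G I Y G Y T F",  # Binds to NLVPMVATV CMV antigen
    "C A T S G R A G V E Q F F",  # Binds to GILGFVFTL flu antigen
]))
```

Is there a decoder that can decode the output in this DataFrame? For example, something like:

```python
import torch
from transformers import BertForMaskedLM, BertTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
# BertModel has no .decoder; BertForMaskedLM adds an MLM head (model.cls) that maps hidden
# states to vocab logits (randomly initialized, with a warning, if the checkpoint lacks it)
model = BertForMaskedLM.from_pretrained("wukevin/tcr-bert").to(device).eval()
tokenizer = BertTokenizer.from_pretrained("wukevin/tcr-bert")
inputs = tokenizer(["C A S S P V T G G I Y G Y T F"], return_tensors="pt").to(device)
with torch.no_grad():
    embedding_tensor = model.bert(**inputs).last_hidden_state  # (batch, seq_len, hidden)
    logits = model.cls(embedding_tensor)  # (batch, seq_len, vocab_size)
predicted_ids = torch.argmax(logits, dim=-1)  # most likely token ID at each position
decoded_sequences = [tokenizer.decode(ids, skip_special_tokens=True) for ids in predicted_ids]
```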
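The argmax-then-decode step in the snippet above is greedy decoding. Here is a minimal pure-Python sketch of just that step. It assumes a hypothetical seven-token vocabulary (`VOCAB`, not TCR-BERT's real one) and a hypothetical helper `greedy_decode`. It takes per-position logits of shape (seq_len, vocab_size), picks the highest-scoring token at each position, and drops special tokens roughly the way `skip_special_tokens=True` does. Note that it only makes sense for per-token logits like those an MLM head produces, not for a single score per class.

```python
# Hypothetical toy vocabulary; the real TCR-BERT vocabulary is not reproduced here.
VOCAB = ["[PAD]", "[CLS]", "[SEP]", "C", "A", "S", "F"]
SPECIAL = {"[PAD]", "[CLS]", "[SEP]"}


def greedy_decode(logits):
    """Decode per-position logits (a list of rows, one score per vocab entry)."""
    # argmax over the vocabulary dimension at each position
    ids = [max(range(len(row)), key=row.__getitem__) for row in logits]
    # map IDs back to tokens, skipping special tokens
    tokens = [VOCAB[i] for i in ids if VOCAB[i] not in SPECIAL]
    return " ".join(tokens)


toy_logits = [
    [0.1, 2.0, 0.0, 0.3, 0.2, 0.1, 0.0],  # highest score: [CLS]
    [0.0, 0.1, 0.0, 1.5, 0.2, 0.1, 0.0],  # highest score: C
    [0.0, 0.1, 0.0, 0.2, 1.1, 0.1, 0.0],  # highest score: A
    [0.0, 0.1, 0.0, 0.2, 0.3, 0.9, 0.0],  # highest score: S
    [0.0, 0.1, 1.7, 0.2, 0.3, 0.1, 0.0],  # highest score: [SEP]
]
print(greedy_decode(toy_logits))  # prints "C A S"
```

Greedy argmax picks each position independently, so the decoded sequence is only as meaningful as the per-position logits fed into it.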