How to decode the embeddings back to the original sequences #12

@samw0806

Description

`import model_utils  # found under the "tcr" folder

tcrbert_trb_cls = model_utils.load_classification_pipeline("wukevin/tcr-bert", device=0)

df = model_utils.reformat_classification_pipeline_preds(tcrbert_trb_cls([
    "C A S S P V T G G I Y G Y T F",  # Binds to NLVPMVATV CMV antigen
    "C A T S G R A G V E Q F F",  # Binds to GILGFVFTL flu antigen
]))`

Is there a decoder that can map the dataframe's output back to the original sequences? Maybe something like:

`import torch
from transformers import BertModel, BertTokenizer

model = BertModel.from_pretrained("wukevin/tcr-bert").to(device)
tokenizer = BertTokenizer.from_pretrained("wukevin/tcr-bert")
outputs = model.decoder(embedding_tensor)  # hypothetical call: this is the step I am asking about
logits = outputs.logits
predicted_ids = torch.argmax(logits, dim=-1)

decoded_seq = tokenizer.decode(predicted_ids[0], skip_special_tokens=True)
decoded_sequences.append(decoded_seq)`
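
To make the intended decoding step concrete, here is a toy, dependency-free sketch of "argmax over per-position logits, then map ids back to tokens". Everything in it is an assumption for illustration: the vocabulary, the logits, and the helper name `decode_logits` are made up and are not part of TCR-BERT or `model_utils`. As far as I know, Hugging Face's `BertModel` has no `decoder` attribute; `BertForMaskedLM` is the class that returns per-token `logits`. That step would presumably need per-position hidden states rather than the classification output in `df`.

```python
# Toy illustration only: the vocabulary and logits below are made up,
# and decode_logits is a hypothetical helper, not part of TCR-BERT.
SPECIAL_TOKENS = {"[PAD]", "[CLS]", "[SEP]"}
VOCAB = ["[PAD]", "[CLS]", "[SEP]", "C", "A", "S", "F"]


def decode_logits(logits, vocab=VOCAB, special=SPECIAL_TOKENS):
    """Pick the highest-scoring token at each position and drop special tokens."""
    # Index of the largest score in each row (a plain-Python argmax).
    ids = [max(range(len(row)), key=row.__getitem__) for row in logits]
    return " ".join(vocab[i] for i in ids if vocab[i] not in special)


# One row of scores per sequence position (5 positions, 7 vocabulary entries).
toy_logits = [
    [0.0, 9.0, 0.0, 0.1, 0.2, 0.0, 0.0],  # highest score: [CLS]
    [0.0, 0.0, 0.0, 5.0, 0.3, 0.1, 0.0],  # highest score: C
    [0.0, 0.0, 0.0, 0.2, 4.0, 0.1, 0.0],  # highest score: A
    [0.0, 0.0, 0.0, 0.1, 0.2, 3.0, 0.5],  # highest score: S
    [0.0, 0.0, 7.0, 0.1, 0.2, 0.0, 0.0],  # highest score: [SEP]
]

print(decode_logits(toy_logits))  # C A S
```

The sketch only covers the last two lines of the snippet above (argmax and token lookup). Whether a single pooled embedding can be mapped back to per-position logits at all is the open part of the question.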
