IndexError in generate protein notebook #152

@ekiefl

Description

Hello,

Thanks for the great tool. I'm excited to use ProtTrans to generate protein sequences, but I'm hitting an IndexError in the example notebook (https://github.com/agemagician/ProtTrans/blob/master/Generate/ProtXLNet.ipynb).

The error occurs when running cell 12:

output_ids = model.generate(
    input_ids=input_ids,
    max_length=max_length,
    temperature=temperature,
    top_k=k,
    top_p=p,
    repetition_penalty=repetition_penalty,
    do_sample=True,
    num_return_sequences=num_return_sequences,
)

Here's the full traceback:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[24], line 1
----> 1 output_ids = model.generate(
      2         input_ids=input_ids,
      3         max_length=max_length,
      4         temperature=temperature,
      5         top_k=k,
      6         top_p=p,
      7         repetition_penalty=repetition_penalty,
      8         do_sample=True,
      9         num_return_sequences=num_return_sequences,
     10     )

File ~/miniconda3/envs/genseq/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/miniconda3/envs/genseq/lib/python3.10/site-packages/transformers/generation/utils.py:1758, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   1750     input_ids, model_kwargs = self._expand_inputs_for_generation(
   1751         input_ids=input_ids,
   1752         expand_size=generation_config.num_return_sequences,
   1753         is_encoder_decoder=self.config.is_encoder_decoder,
   1754         **model_kwargs,
   1755     )
   1757     # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
-> 1758     result = self._sample(
   1759         input_ids,
   1760         logits_processor=prepared_logits_processor,
   1761         logits_warper=prepared_logits_warper,
   1762         stopping_criteria=prepared_stopping_criteria,
   1763         generation_config=generation_config,
   1764         synced_gpus=synced_gpus,
   1765         streamer=streamer,
   1766         **model_kwargs,
   1767     )
   1769 elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
   1770     # 11. prepare logits warper
   1771     prepared_logits_warper = (
   1772         self._get_logits_warper(generation_config) if generation_config.do_sample else None
   1773     )

File ~/miniconda3/envs/genseq/lib/python3.10/site-packages/transformers/generation/utils.py:2397, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, logits_warper, **model_kwargs)
   2394 model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
   2396 # forward pass to get next token
-> 2397 outputs = self(
   2398     **model_inputs,
   2399     return_dict=True,
   2400     output_attentions=output_attentions,
   2401     output_hidden_states=output_hidden_states,
   2402 )
   2404 if synced_gpus and this_peer_finished:
   2405     continue  # don't waste resources running the code we don't need

File ~/miniconda3/envs/genseq/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/genseq/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~/miniconda3/envs/genseq/lib/python3.10/site-packages/transformers/models/xlnet/modeling_xlnet.py:1440, in XLNetLMHeadModel.forward(self, input_ids, attention_mask, mems, perm_mask, target_mapping, token_type_ids, input_mask, head_mask, inputs_embeds, labels, use_mems, output_attentions, output_hidden_states, return_dict, **kwargs)
   1370 r"""
   1371 labels (`torch.LongTensor` of shape `(batch_size, num_predict)`, *optional*):
   1372     Labels for masked language modeling. `num_predict` corresponds to `target_mapping.shape[1]`. If
   (...)
   1436 ... )  # Logits have shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]
   1437 ```"""
   1438 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-> 1440 transformer_outputs = self.transformer(
   1441     input_ids,
   1442     attention_mask=attention_mask,
   1443     mems=mems,
   1444     perm_mask=perm_mask,
   1445     target_mapping=target_mapping,
   1446     token_type_ids=token_type_ids,
   1447     input_mask=input_mask,
   1448     head_mask=head_mask,
   1449     inputs_embeds=inputs_embeds,
   1450     use_mems=use_mems,
   1451     output_attentions=output_attentions,
   1452     output_hidden_states=output_hidden_states,
   1453     return_dict=return_dict,
   1454     **kwargs,
   1455 )
   1457 logits = self.lm_loss(transformer_outputs[0])
   1459 loss = None

File ~/miniconda3/envs/genseq/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/genseq/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~/miniconda3/envs/genseq/lib/python3.10/site-packages/transformers/models/xlnet/modeling_xlnet.py:1170, in XLNetModel.forward(self, input_ids, attention_mask, mems, perm_mask, target_mapping, token_type_ids, input_mask, head_mask, inputs_embeds, use_mems, output_attentions, output_hidden_states, return_dict, **kwargs)
   1168     word_emb_k = inputs_embeds
   1169 else:
-> 1170     word_emb_k = self.word_embedding(input_ids)
   1171 output_h = self.dropout(word_emb_k)
   1172 if target_mapping is not None:

File ~/miniconda3/envs/genseq/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/genseq/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~/miniconda3/envs/genseq/lib/python3.10/site-packages/torch/nn/modules/sparse.py:163, in Embedding.forward(self, input)
    162 def forward(self, input: Tensor) -> Tensor:
--> 163     return F.embedding(
    164         input, self.weight, self.padding_idx, self.max_norm,
    165         self.norm_type, self.scale_grad_by_freq, self.sparse)

File ~/miniconda3/envs/genseq/lib/python3.10/site-packages/torch/nn/functional.py:2264, in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   2258     # Note [embedding_renorm set_grad_enabled]
   2259     # XXX: equivalent to
   2260     # with torch.no_grad():
   2261     #   torch.embedding_renorm_
   2262     # remove once script supports set_grad_enabled
   2263     _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2264 return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)

IndexError: index out of range in self

I've reproduced this on two different machines.
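
In case it helps with triage: as far as I understand, an "IndexError: index out of range in self" raised inside torch.embedding usually means some token id in input_ids is greater than or equal to the number of rows in the model's embedding table, i.e. the tokenizer and the model vocabularies disagree. A quick sanity check along those lines, reusing input_ids and model from the cells above (just a sketch of how I'd narrow it down, not code from the notebook):

# Sketch of a sanity check; input_ids and model come from the earlier notebook cells.
# If the max token id is >= the number of embedding rows, the lookup must fail like this.
print("max token id  :", input_ids.max().item())
print("config vocab  :", model.config.vocab_size)
print("embedding rows:", model.transformer.word_embedding.num_embeddings)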

Thanks for taking a look.

Evan
