TransformerDecoder: optional positional encoding and final matmul #93

Gerstenberger wants to merge 4 commits into main from …
Conversation
    num_output: int
    logits_bias: bool
    share_embedding: bool
    use_positional_encoding: bool = True
I wonder if, instead of a flag, this should be a configurable module, which you simply replace with a no-op if you don't want any positional encoding. This would also allow positional encoding schemes other than sinusoidal.
Yes, agreed, it would be better to have this more dynamic.
ConformerMHSARelPosV1._sinusoidal_pe should maybe be moved to a separate function, and then you would have positional_encoding=absolute_sinusoidal_positional_encoding as the default, with None also allowed.
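For illustration, a minimal sketch of what that could look like (the names here, e.g. PositionalEncodingFn, are hypothetical and not part of this PR): the boolean flag is replaced by an optional callable in the config, with the extracted sinusoidal function as the intended default and None disabling positional encoding.

```python
from dataclasses import dataclass
from typing import Callable, Optional

from torch import Tensor

# Assumed signature for the extracted PE function: absolute positions [T] plus the
# embedding dim map to an encoding [T, embed_dim] (see the sketch further down).
PositionalEncodingFn = Callable[[Tensor, int], Tensor]


@dataclass
class TransformerDecoderV1Config:
    # ... other existing fields unchanged ...
    num_output: int
    logits_bias: bool
    share_embedding: bool
    # replaces `use_positional_encoding: bool = True`; the intended default would be
    # the extracted absolute sinusoidal PE function, None means no positional encoding
    positional_encoding: Optional[PositionalEncodingFn] = None
```

In the decoder this would reduce the branching to a single `if self.cfg.positional_encoding is not None` check, and other schemes (learned, relative, no-op) could be plugged in without touching the decoder itself.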
    logits_bias: bool
    share_embedding: bool
    use_positional_encoding: bool = True
    do_output_embedding_matmul: bool = True
Perhaps

    - do_output_embedding_matmul: bool = True
    + embed_outputs_to_vocab_dim: bool = True

is clearer naming-wise?
I don't think it's clearer. But I also don't like the original name. And I'm also not sure whether I like the logic at all (see my separate comment on why to have out_logits at all if it is not used).
As a first comment (I will try to comment in more detail later): the same questions have been thought about in the RF implementation, for the Transformer encoder, decoder, and, very related, also the Conformer encoder (to make the frontend optional, etc.); see the current RF TransformerDecoder implementation. It already has the …
    @@ -190,13 +194,20 @@ def __init__(self, cfg: TransformerDecoderV1Config):
        else:
            self.out_logits = nn.Linear(self.model_dim, cfg.num_output, bias=cfg.logits_bias)
I just realized: this sharing is weird. I would always set self.out_logits. If sharing, you can just do self.out_logits.weight = self.input_embedding.weight. That would simplify the other code.
Also, self.out_logits should always be set (None if not used). But with my suggestion, you don't need to care about this.
And then you would also allow logits_bias=True together with share_embedding=True.
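A minimal sketch of that suggestion (the class and argument names are made up for illustration; only the tying logic matters):

```python
from torch import nn


class _SharedHeadSketch(nn.Module):
    """Illustrative only, not the actual decoder code."""

    def __init__(self, model_dim: int, num_output: int, logits_bias: bool, share_embedding: bool):
        super().__init__()
        self.input_embedding = nn.Embedding(num_output, model_dim)
        # always construct out_logits ...
        self.out_logits = nn.Linear(model_dim, num_output, bias=logits_bias)
        if share_embedding:
            # ... and tie the weights: nn.Linear.weight is [num_output, model_dim],
            # the same shape as the embedding matrix, so the Parameter can be shared.
            # The bias stays independent, so logits_bias=True still works with sharing.
            self.out_logits.weight = self.input_embedding.weight
```

With this, the forward path always goes through self.out_logits, regardless of sharing.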
    logits_bias: bool
    share_embedding: bool
    use_positional_encoding: bool = True
    do_output_embedding_matmul: bool = True
If this is False and cfg.share_embedding is not set, out_logits is not used at all. Does it make sense to even have it then?
I made an initial proposal, but I am not really sure about the changes. Maybe the proposal is too complicated / not straightforward enough, which is what I tend towards. In general, would a …? Please let me know.
        :param lengths: input lengths
        :param state: current state of positional encoding.
        """
        sinus_pe = ConformerMHSARelPosV1._sinusoidal_pe(
Maybe this should be moved to primitives?
        :param state: current state of positional encoding.
        """
        sinus_pe = ConformerMHSARelPosV1._sinusoidal_pe(
            torch.arange(labels.shape[-1], device=labels.device) + state["pos"], self.embed_dim
Do we have the constraint labels.shape[-1] == lengths.max()? If so, the labels input can be removed.
Or do we only return sinus_pe.unsqueeze(0) and apply the addition later?
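A sketch of the second variant, assuming the time axis can be derived from lengths and that only the encoding is returned for the caller to add (names are hypothetical; the actual _sinusoidal_pe may lay out sin/cos differently, e.g. interleaved instead of concatenated):

```python
import torch
from torch import Tensor


def positional_encoding_step(lengths: Tensor, pos_offset: Tensor, embed_dim: int) -> Tensor:
    """
    :param lengths: [B] sequence lengths of the current chunk
    :param pos_offset: absolute position of the first frame (the "pos" carried in the state)
    :param embed_dim: embedding dimension, assumed even here
    :return: [1, T, embed_dim] with T = lengths.max(), added to the embeddings by the caller
    """
    positions = torch.arange(int(lengths.max()), device=lengths.device) + pos_offset  # [T]
    inv_freq = 1.0 / (
        10000 ** (torch.arange(0, embed_dim, 2, device=lengths.device, dtype=torch.float32) / embed_dim)
    )
    angles = positions.to(torch.float32).unsqueeze(-1) * inv_freq  # [T, embed_dim // 2]
    sinus_pe = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)  # [T, embed_dim]
    return sinus_pe.unsqueeze(0)


# caller side, roughly:
#   pe = positional_encoding_step(lengths, state["pos"], self.embed_dim)
#   x = input_embedding + pe
```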
    block_state: List[TransformerDecoderBlockV1State]
    pos: Tensor
    pos_state: NotRequired[PositionalEncodingV1State]
Okay, maybe we should not change the names and the type here, as this breaks existing setups.
Changes to make both the positional encoding and the final matrix multiplication of the model output with the output embedding matrix optional.
This allows us to use the implementation for self-normalized LM Transformer training, where positional encoding is not required and the final matmul is replaced by another matmul in the sampling loss.

My only question is: should this be a TransformerDecoderV2 instead?
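To make the intended use concrete, here is a rough sketch of how a sampling loss could consume the raw decoder outputs when do_output_embedding_matmul=False (the function and argument names below are hypothetical, not from this PR, and the exact sampling scheme in the actual setup may differ):

```python
from typing import Tuple

import torch
from torch import Tensor


def sampling_loss_scores(hidden: Tensor, embedding: Tensor, targets: Tensor, sampled_ids: Tensor) -> Tuple[Tensor, Tensor]:
    """
    :param hidden: decoder outputs [B, T, F] (do_output_embedding_matmul=False)
    :param embedding: (shared) output embedding matrix [V, F]
    :param targets: gold label indices [B, T]
    :param sampled_ids: sampled negative label indices [K]
    :return: target scores [B, T] and sampled scores [B, T, K]
    """
    target_emb = embedding[targets]                # [B, T, F]
    target_scores = (hidden * target_emb).sum(-1)  # [B, T]
    sampled_emb = embedding[sampled_ids]           # [K, F]
    sampled_scores = hidden @ sampled_emb.T        # [B, T, K]
    return target_scores, sampled_scores
```

A self-normalized criterion (NCE- or sampled-softmax-style) is then computed from these scores, so the full [B, T, num_output] logits are never materialized.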