From 000ded0596c81649bd9af079837b920c84854b9e Mon Sep 17 00:00:00 2001
From: Zhiyuan Chen
Date: Mon, 8 Nov 2021 17:09:49 +0800
Subject: [PATCH 1/2] Improve pos_embed calculation

Make `rp_bucket` an `nn.Parameter` to avoid copying it to the device in `forward`.
Use a single linear layer to compute `pos_q` and `pos_k` to increase forward speed.
---
 fairseq/modules/transformer_sentence_encoder.py | 13 +++++--------
 1 file changed, 5 insertions(+), 8 deletions(-)

diff --git a/fairseq/modules/transformer_sentence_encoder.py b/fairseq/modules/transformer_sentence_encoder.py
index 524a510..2926f3f 100644
--- a/fairseq/modules/transformer_sentence_encoder.py
+++ b/fairseq/modules/transformer_sentence_encoder.py
@@ -63,7 +63,7 @@ def relative_position_bucket(relative_position, bidirectional=True, num_buckets=
     val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
     ret += torch.where(is_small, n, val_if_large)
-    return ret
+    return nn.Parameter(ret, requires_grad=False)
 
 
 class TransformerSentenceEncoder(nn.Module):
     """
@@ -127,8 +127,7 @@ def __init__(
         self.attn_scale_factor = 2
         self.num_attention_heads = num_attention_heads
         self.pos = nn.Embedding(self.max_seq_len + 1, self.embedding_dim)
-        self.pos_q_linear = nn.Linear(self.embedding_dim, self.embedding_dim)
-        self.pos_k_linear = nn.Linear(self.embedding_dim, self.embedding_dim)
+        self.pos_proj = nn.Linear(self.embedding_dim, self.embedding_dim * 2)
         self.pos_scaling = float(self.embedding_dim / num_attention_heads * self.attn_scale_factor) ** -0.5
         self.pos_ln = LayerNorm(self.embedding_dim, export=export)
         self.layers = nn.ModuleList(
@@ -185,11 +184,9 @@ def __init__(
 
     def get_rel_pos_bias(self, x):
         # Assume the input is ordered. If your input token is permuted, you may need to update this accordingly
-        if self.rp_bucket.device != x.device:
-            self.rp_bucket = self.rp_bucket.to(x.device)
         seq_len = x.size(1)
         rp_bucket = self.rp_bucket[:seq_len, :seq_len]
-        values = F.embedding(rp_bucket, self.relative_attention_bias.weight)
+        values = self.relative_attention_bias(rp_bucket)
         values = values.permute([2, 0, 1])
         return values.contiguous()
 
@@ -227,8 +224,8 @@ def forward(
         # 0 is for other-to-cls 1 is for cls-to-other
         # Assume the input is ordered. If your input token is permuted, you may need to update this accordingly
         weight = self.pos_ln(self.pos.weight[:seq_len + 1, :])
-        pos_q = self.pos_q_linear(weight).view(seq_len + 1, self.num_attention_heads, -1).transpose(0, 1) * self.pos_scaling
-        pos_k = self.pos_k_linear(weight).view(seq_len + 1, self.num_attention_heads, -1).transpose(0, 1)
+        pos_q, pos_k = self.pos_proj(weight).reshape(seq_len + 1, 2, self.num_attention_heads, -1).permute(1, 2, 0, 3)
+        pos_q *= self.pos_scaling
         abs_pos_bias = torch.bmm(pos_q, pos_k.transpose(1, 2))
         # p_0 \dot p_0 is cls to others
         cls_2_other = abs_pos_bias[:, 0, 0]

From b5b1e0bc2ec8f12798c39ed6a5765d21acf24556 Mon Sep 17 00:00:00 2001
From: Zhiyuan Chen
Date: Mon, 8 Nov 2021 18:19:05 +0800
Subject: [PATCH 2/2] Fix bug caused by using in-place `*=` in a tensor operation

---
 fairseq/modules/transformer_sentence_encoder.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/fairseq/modules/transformer_sentence_encoder.py b/fairseq/modules/transformer_sentence_encoder.py
index 2926f3f..bb5f6ab 100644
--- a/fairseq/modules/transformer_sentence_encoder.py
+++ b/fairseq/modules/transformer_sentence_encoder.py
@@ -225,7 +225,7 @@ def forward(
         # Assume the input is ordered. If your input token is permuted, you may need to update this accordingly
         weight = self.pos_ln(self.pos.weight[:seq_len + 1, :])
         pos_q, pos_k = self.pos_proj(weight).reshape(seq_len + 1, 2, self.num_attention_heads, -1).permute(1, 2, 0, 3)
-        pos_q *= self.pos_scaling
+        pos_q = pos_q * self.pos_scaling
         abs_pos_bias = torch.bmm(pos_q, pos_k.transpose(1, 2))
         # p_0 \dot p_0 is cls to others
         cls_2_other = abs_pos_bias[:, 0, 0]
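
Note (illustrative sketch, not part of the patch series): PATCH 1/2 fuses `pos_q_linear` and `pos_k_linear` into one `pos_proj = nn.Linear(dim, dim * 2)` and recovers the per-head tensors with a reshape and permute. A minimal standalone check of that equivalence is below; the sizes `d`, `heads`, `seq_len` and the weight-copying step are made up here purely to compare the two paths, they are not part of the patch.

    import torch
    import torch.nn as nn

    d, heads, seq_len = 16, 4, 9                    # illustrative sizes only
    pos_q_linear = nn.Linear(d, d)                  # original separate projections
    pos_k_linear = nn.Linear(d, d)

    pos_proj = nn.Linear(d, d * 2)                  # fused projection as in PATCH 1/2
    with torch.no_grad():                           # stack the two weights so both paths use the same parameters
        pos_proj.weight.copy_(torch.cat([pos_q_linear.weight, pos_k_linear.weight], dim=0))
        pos_proj.bias.copy_(torch.cat([pos_q_linear.bias, pos_k_linear.bias], dim=0))

    weight = torch.randn(seq_len + 1, d)            # stand-in for self.pos_ln(self.pos.weight[:seq_len + 1, :])

    # original two-projection path (scaling omitted for the comparison)
    ref_q = pos_q_linear(weight).view(seq_len + 1, heads, -1).transpose(0, 1)
    ref_k = pos_k_linear(weight).view(seq_len + 1, heads, -1).transpose(0, 1)

    # fused path from the patch
    pos_q, pos_k = pos_proj(weight).reshape(seq_len + 1, 2, heads, -1).permute(1, 2, 0, 3)

    print(torch.allclose(pos_q, ref_q, atol=1e-6))  # True: fused path reproduces pos_q
    print(torch.allclose(pos_k, ref_k, atol=1e-6))  # True: fused path reproduces pos_k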
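
Note (illustrative sketch, not part of the patch series): PATCH 2/2 replaces the in-place `pos_q *= self.pos_scaling` because `pos_q` comes from unpacking the permuted tensor, which iterates via `Tensor.unbind`; in recent PyTorch versions such views reject in-place modification while gradients are being tracked. The sketch below shows the failure and the out-of-place fix; the `proj` layer and sizes are illustrative stand-ins, not the patch's code.

    import torch
    import torch.nn as nn

    proj = nn.Linear(8, 16)                         # stand-in for self.pos_proj
    x = torch.randn(5, 8)                           # stand-in for the layer-normed position embeddings

    # Unpacking iterates the first dimension via Tensor.unbind, so pos_q and
    # pos_k are views sharing storage with the projection output.
    pos_q, pos_k = proj(x).reshape(5, 2, 2, 4).permute(1, 2, 0, 3)

    try:
        pos_q *= 0.5                                # in-place on an unbind output while grads are tracked
    except RuntimeError as err:
        print("in-place multiply rejected:", err)

    pos_q = pos_q * 0.5                             # out-of-place: allocates a new tensor, no error
    print(pos_q.shape)                              # torch.Size([2, 5, 4])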