From 6eb62d9c984ac3ae765308ac7915f18e2a0077c7 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Tue, 5 Sep 2023 20:44:15 +0000 Subject: [PATCH 1/9] add Alibi --- .../modules/layers/position_embedding.py | 98 +++++++++++++++++++ 1 file changed, 98 insertions(+) diff --git a/torchmultimodal/modules/layers/position_embedding.py b/torchmultimodal/modules/layers/position_embedding.py index 7920ce1cc..c703c81dc 100644 --- a/torchmultimodal/modules/layers/position_embedding.py +++ b/torchmultimodal/modules/layers/position_embedding.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import itertools +import math from typing import Tuple import torch @@ -169,3 +170,100 @@ def forward(self, t: Tensor) -> Tensor: if self.embed_dim % 2 == 1: embeddings = nn.functional.pad(embeddings, (0, 1)) return embeddings + + +class AlibiPositionEmbeddings(nn.Module): + """Attention with Linear Biases (ALiBi) + + # Softmax(qiKT + m ยท [-(i - 1), ..., -2, -1, 0]), + where m = fixed specific slope per head + + as proposed in: + https://arxiv.org/abs/2108.12409 + Train Short, Test Long: Attention with Linear Biases + Enables Input Length Extrapolation + + derived from Ofir Press (alibi author) codebase: + https://github.com/ofirpress/attention_with_linear_biases + + """ + + def __init__( + self, + max_seq_len: int, + num_heads: int, + ) -> None: + """recommended usage: create alibi mask before transformer block loop and integrate + Alibi should be applied before the sqrt scaling of the attention values + + Example: + before Transformer block loop: + from alibi_embeddings import AlibiPE + self.alibi = AlibiPE(config.max_seq_len, config.num_heads) + pass a reference to the alibi class to each transformer layer + then in forward of transformer layer: + alibi_mask = self.alibi.get_attention_mask(N) # N = seq length of this batch + ... 
+ attn = q @ k.transpose( -2, -1) + attn += alibi_mask + attn *= 1.0 / math.sqrt(k.size(-1)) + + """ + super().__init__() + + self.num_heads = num_heads + self.max_seq_len = max_seq_len + + self.causal_mask = self.build_causal_attention_mask( + self.max_seq_len, self.num_heads + ) + self.alibi_mask_base = self.build_alibi_mask(self.max_seq_len, self.num_heads) + self.decoder_mask = self.causal_mask + self.alibi_mask_base + self.register_buffer("alibi_mask", self.decoder_mask, persistent=False) + + def get_attention_mask(self, curr_seq_len: int) -> torch.tensor: + """returns the alibi mask, clipped to the current batch seq len""" + return self.alibi_mask[:, :curr_seq_len, :curr_seq_len] + + def build_causal_attention_mask(self, seq_len: int, num_heads: int) -> torch.Tensor: + """builds a generic causal attention mask""" + causal_mask = torch.ones(seq_len, seq_len).tril() + causal_mask = causal_mask.masked_fill(causal_mask == 0, -float("inf")) + attn_mask = causal_mask.repeat(num_heads, 1, 1) + return attn_mask + + def build_alibi_mask(self, seq_len: int, num_heads: int) -> torch.Tensor: + """generate the alibi mask by computing a distance matrix multiplied by each head's m (slope)""" + distance_matrix = torch.arange(seq_len) - torch.arange(seq_len).view(-1, 1) + slope_per_head = Tensor(self.get_slopes(num_heads)).view(-1, 1, 1) + alibi_mask = distance_matrix * slope_per_head + return alibi_mask + + def get_slopes(self, num_heads: int) -> torch.Tensor: + """for n heads, a range from (0,1) and is the geometric sequence + that starts at 2^(-8/n) and uses this same value as its ratio + + example: num_heads =4 + result: [0.25, 0.0625, 0.015625, 0.00390625] + + """ + + def get_slopes_power_of_2(n): + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(num_heads).is_integer(): + return get_slopes_power_of_2(num_heads) + + # paper authors note they only trained models that have 2^a heads for some a. + # This has beneficial properties related to input being power of 2. 
+ # Closest power of 2 below is workaround for when num of heads is not power of 2 + + closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) + return ( + get_slopes_power_of_2(closest_power_of_2) + + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][ + : num_heads - closest_power_of_2 + ] + ) From 0c52d1134eaa2f330945c0c9862117d2a51cd0a8 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Wed, 6 Sep 2023 01:09:40 +0000 Subject: [PATCH 2/9] add unit tests for alibi --- .../modules/layers/test_position_embedding.py | 81 ++++++++++++++++++- 1 file changed, 80 insertions(+), 1 deletion(-) diff --git a/tests/modules/layers/test_position_embedding.py b/tests/modules/layers/test_position_embedding.py index 583394601..e580cee63 100644 --- a/tests/modules/layers/test_position_embedding.py +++ b/tests/modules/layers/test_position_embedding.py @@ -7,14 +7,20 @@ import pytest import torch -from tests.test_utils import assert_expected +from tests.test_utils import assert_expected, set_rng_seed from torch import nn from torchmultimodal.modules.layers.position_embedding import ( + AlibiPositionEmbeddings, BroadcastedPositionEmbedding, SinusoidalPositionEmbeddings, ) +@pytest.fixture(autouse=True) +def random(): + set_rng_seed(2023) + + class TestBroadcastedPositionEmbedding: @pytest.fixture(scope="class") def pos_emb(self): @@ -112,3 +118,76 @@ def test_forward(self, data, emb): actual = emb(data) expected = torch.Size([3, 5]) assert_expected(actual.shape, expected) + + +class TestAlibiPositionEmbedding: + @pytest.fixture + def max_seq_len(self): + return 16 + + @pytest.fixture + def embedding_dim(self): + return 32 + + @pytest.fixture + def num_heads(self): + return 8 + + @pytest.fixture + def data(self, max_seq_len, embedding_dim): + return torch.randn(1, max_seq_len, embedding_dim) # bs, seq_len, emb_dim + + def test_alibi_mask( + self, + data, + max_seq_len, + num_heads, + ): + alibi_class = AlibiPositionEmbeddings( + max_seq_len=max_seq_len, num_heads=num_heads + ) + base_mask = alibi_class.get_attention_mask(max_seq_len) + + # verify mask shape + expected_shape = torch.Size((num_heads, max_seq_len, max_seq_len)) + assert_expected(base_mask.shape, expected_shape) + + # verify alibi mask components + expected_last_head_row = torch.tensor( + [ + 0.9414, + 0.9453, + 0.9492, + 0.9531, + 0.9570, + 0.9609, + 0.9648, + 0.9688, + 0.9727, + 0.9766, + 0.9805, + 0.9844, + 0.9883, + 0.9922, + 0.9961, + 1.0000, + ] + ) + + expected_first_head_first_row_first_entry = torch.tensor( + 1.0000, + ) + + assert_expected( + base_mask[0][0][0], + expected_first_head_first_row_first_entry, + rtol=0, + atol=1e-4, + ) + + assert_expected( + base_mask[num_heads - 1][max_seq_len - 1], + expected_last_head_row, + rtol=0, + atol=1e-4, + ) From f263b23579548c2554a15ff54568a8b52c207fe4 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Wed, 6 Sep 2023 01:10:30 +0000 Subject: [PATCH 3/9] remove unused fixture --- tests/modules/layers/test_position_embedding.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/modules/layers/test_position_embedding.py b/tests/modules/layers/test_position_embedding.py index e580cee63..1a4d868b1 100644 --- a/tests/modules/layers/test_position_embedding.py +++ b/tests/modules/layers/test_position_embedding.py @@ -133,10 +133,6 @@ def embedding_dim(self): def num_heads(self): return 8 - @pytest.fixture - def data(self, max_seq_len, embedding_dim): - return torch.randn(1, max_seq_len, embedding_dim) # bs, seq_len, emb_dim - def test_alibi_mask( self, data, From 
926b3fb519ff69012757c7487c15a1ba0a9a017a Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Wed, 6 Sep 2023 03:46:36 +0000 Subject: [PATCH 4/9] update typedef for slopes inline function --- torchmultimodal/modules/layers/position_embedding.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchmultimodal/modules/layers/position_embedding.py b/torchmultimodal/modules/layers/position_embedding.py index c703c81dc..151c4b17a 100644 --- a/torchmultimodal/modules/layers/position_embedding.py +++ b/torchmultimodal/modules/layers/position_embedding.py @@ -6,7 +6,7 @@ import itertools import math -from typing import Tuple +from typing import List, Tuple import torch from torch import nn, Tensor @@ -221,7 +221,7 @@ def __init__( self.decoder_mask = self.causal_mask + self.alibi_mask_base self.register_buffer("alibi_mask", self.decoder_mask, persistent=False) - def get_attention_mask(self, curr_seq_len: int) -> torch.tensor: + def get_attention_mask(self, curr_seq_len: int) -> torch.Tensor: """returns the alibi mask, clipped to the current batch seq len""" return self.alibi_mask[:, :curr_seq_len, :curr_seq_len] @@ -239,7 +239,7 @@ def build_alibi_mask(self, seq_len: int, num_heads: int) -> torch.Tensor: alibi_mask = distance_matrix * slope_per_head return alibi_mask - def get_slopes(self, num_heads: int) -> torch.Tensor: + def get_slopes(self, num_heads: int) -> List[float]: """for n heads, a range from (0,1) and is the geometric sequence that starts at 2^(-8/n) and uses this same value as its ratio From 8aadfb4b6af30dbc2dae72240e2a6eb8fdcad2e3 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Wed, 6 Sep 2023 04:03:10 +0000 Subject: [PATCH 5/9] all typedefs for function get_slopes_power_of_2 --- torchmultimodal/modules/layers/position_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmultimodal/modules/layers/position_embedding.py b/torchmultimodal/modules/layers/position_embedding.py index 151c4b17a..e637e009e 100644 --- a/torchmultimodal/modules/layers/position_embedding.py +++ b/torchmultimodal/modules/layers/position_embedding.py @@ -248,7 +248,7 @@ def get_slopes(self, num_heads: int) -> List[float]: """ - def get_slopes_power_of_2(n): + def get_slopes_power_of_2(n: int) -> List[float]: start = 2 ** (-(2 ** -(math.log2(n) - 3))) ratio = start return [start * ratio**i for i in range(n)] From d409abc1a3600763dbd52eb97bb3f9fb14be57e4 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Wed, 6 Sep 2023 04:06:25 +0000 Subject: [PATCH 6/9] remove unused data fixture, param from alibi test --- tests/modules/layers/test_position_embedding.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/modules/layers/test_position_embedding.py b/tests/modules/layers/test_position_embedding.py index 1a4d868b1..df51b13c0 100644 --- a/tests/modules/layers/test_position_embedding.py +++ b/tests/modules/layers/test_position_embedding.py @@ -135,7 +135,6 @@ def num_heads(self): def test_alibi_mask( self, - data, max_seq_len, num_heads, ): From 82876cfa4920e35198575bbb8e2ffc37b99a7491 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Thu, 7 Sep 2023 14:20:51 +0000 Subject: [PATCH 7/9] update alibi mask --- .../modules/layers/test_position_embedding.py | 34 +++++++++---------- .../modules/layers/position_embedding.py | 28 +++++++++------ 2 files changed, 34 insertions(+), 28 deletions(-) diff --git a/tests/modules/layers/test_position_embedding.py b/tests/modules/layers/test_position_embedding.py index df51b13c0..2af1c52c5 100644 --- 
a/tests/modules/layers/test_position_embedding.py +++ b/tests/modules/layers/test_position_embedding.py @@ -150,27 +150,27 @@ def test_alibi_mask( # verify alibi mask components expected_last_head_row = torch.tensor( [ - 0.9414, - 0.9453, - 0.9492, - 0.9531, - 0.9570, - 0.9609, - 0.9648, - 0.9688, - 0.9727, - 0.9766, - 0.9805, - 0.9844, - 0.9883, - 0.9922, - 0.9961, - 1.0000, + -0.0586, + -0.0547, + -0.0508, + -0.0469, + -0.0430, + -0.0391, + -0.0352, + -0.0312, + -0.0273, + -0.0234, + -0.0195, + -0.0156, + -0.0117, + -0.0078, + -0.0039, + 0.0000, ] ) expected_first_head_first_row_first_entry = torch.tensor( - 1.0000, + 0.0000, ) assert_expected( diff --git a/torchmultimodal/modules/layers/position_embedding.py b/torchmultimodal/modules/layers/position_embedding.py index e637e009e..5fda2b6a8 100644 --- a/torchmultimodal/modules/layers/position_embedding.py +++ b/torchmultimodal/modules/layers/position_embedding.py @@ -223,23 +223,29 @@ def __init__( def get_attention_mask(self, curr_seq_len: int) -> torch.Tensor: """returns the alibi mask, clipped to the current batch seq len""" - return self.alibi_mask[:, :curr_seq_len, :curr_seq_len] + return self.alibi_mask[..., :curr_seq_len, :curr_seq_len] - def build_causal_attention_mask(self, seq_len: int, num_heads: int) -> torch.Tensor: + @classmethod + def build_causal_attention_mask(cls, seq_len: int, num_heads: int) -> torch.Tensor: """builds a generic causal attention mask""" - causal_mask = torch.ones(seq_len, seq_len).tril() - causal_mask = causal_mask.masked_fill(causal_mask == 0, -float("inf")) + causal_mask = torch.triu( + torch.ones(seq_len, seq_len) * float("-inf"), diagonal=1 + ) attn_mask = causal_mask.repeat(num_heads, 1, 1) return attn_mask - def build_alibi_mask(self, seq_len: int, num_heads: int) -> torch.Tensor: - """generate the alibi mask by computing a distance matrix multiplied by each head's m (slope)""" - distance_matrix = torch.arange(seq_len) - torch.arange(seq_len).view(-1, 1) - slope_per_head = Tensor(self.get_slopes(num_heads)).view(-1, 1, 1) - alibi_mask = distance_matrix * slope_per_head + @classmethod + def build_alibi_mask(cls, seq_len: int, num_heads: int) -> torch.Tensor: + """generate the alibi mask by computing a distance bias matrix multiplied by each head's m (slope)""" + distance_bias_matrix = -torch.abs( + torch.arange(seq_len) - torch.arange(seq_len).view(-1, 1) + ) + slope_per_head = Tensor(cls.get_slopes(num_heads)).view(-1, 1, 1) + alibi_mask = distance_bias_matrix * slope_per_head return alibi_mask - def get_slopes(self, num_heads: int) -> List[float]: + @staticmethod + def get_slopes(num_heads: int) -> List[float]: """for n heads, a range from (0,1) and is the geometric sequence that starts at 2^(-8/n) and uses this same value as its ratio @@ -256,7 +262,7 @@ def get_slopes_power_of_2(n: int) -> List[float]: if math.log2(num_heads).is_integer(): return get_slopes_power_of_2(num_heads) - # paper authors note they only trained models that have 2^a heads for some a. + # paper authors note that they only trained models that have 2^a heads for some a. # This has beneficial properties related to input being power of 2. 
# Closest power of 2 below is workaround for when num of heads is not power of 2 From 2289962107f424bbc33bcfd3368505f345424d5e Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Thu, 7 Sep 2023 14:33:43 +0000 Subject: [PATCH 8/9] update usage - alibi is applied after sqrt scaling --- torchmultimodal/modules/layers/position_embedding.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchmultimodal/modules/layers/position_embedding.py b/torchmultimodal/modules/layers/position_embedding.py index 5fda2b6a8..3f8f808ad 100644 --- a/torchmultimodal/modules/layers/position_embedding.py +++ b/torchmultimodal/modules/layers/position_embedding.py @@ -194,7 +194,7 @@ def __init__( num_heads: int, ) -> None: """recommended usage: create alibi mask before transformer block loop and integrate - Alibi should be applied before the sqrt scaling of the attention values + Alibi should be applied after the sqrt scaling of the attention values Example: before Transformer block loop: @@ -205,8 +205,8 @@ def __init__( alibi_mask = self.alibi.get_attention_mask(N) # N = seq length of this batch ... attn = q @ k.transpose( -2, -1) - attn += alibi_mask - attn *= 1.0 / math.sqrt(k.size(-1)) + att *= 1.0 / math.sqrt(k.size(-1)) + att += alibi_mask """ super().__init__() From ece9a0e9d6a946351e3884b11a3ad2fc5540e9a0 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Thu, 7 Sep 2023 23:44:46 +0000 Subject: [PATCH 9/9] return ordered slopes for non power 2, expand unit tests to cover non power 2 case --- .../modules/layers/test_position_embedding.py | 89 ++++++++++++++++++- .../modules/layers/position_embedding.py | 14 +-- 2 files changed, 96 insertions(+), 7 deletions(-) diff --git a/tests/modules/layers/test_position_embedding.py b/tests/modules/layers/test_position_embedding.py index 2af1c52c5..2f82aa91f 100644 --- a/tests/modules/layers/test_position_embedding.py +++ b/tests/modules/layers/test_position_embedding.py @@ -133,7 +133,11 @@ def embedding_dim(self): def num_heads(self): return 8 - def test_alibi_mask( + @pytest.fixture + def num_heads_non_power_2(self): + return 12 + + def test_alibi_mask_power_of_2( self, max_seq_len, num_heads, @@ -186,3 +190,86 @@ def test_alibi_mask( rtol=0, atol=1e-4, ) + + def test_alibi_mask_non_power_of_2( + self, + max_seq_len, + num_heads_non_power_2, + ): + alibi_class = AlibiPositionEmbeddings( + max_seq_len=max_seq_len, num_heads=num_heads_non_power_2 + ) + base_mask = alibi_class.get_attention_mask(max_seq_len) + + # verify mask shape + expected_shape = torch.Size((num_heads_non_power_2, max_seq_len, max_seq_len)) + assert_expected(base_mask.shape, expected_shape) + + # verify alibi mask components + expected_second_head_last_row = torch.tensor( + [ + -7.5000, + -7.0000, + -6.5000, + -6.0000, + -5.5000, + -5.0000, + -4.5000, + -4.0000, + -3.5000, + -3.0000, + -2.5000, + -2.0000, + -1.5000, + -1.0000, + -0.5000, + 0.0000, + ] + ) + + expected_third_head_last_row = torch.tensor( + [ + -5.3033, + -4.9497, + -4.5962, + -4.2426, + -3.8891, + -3.5355, + -3.1820, + -2.8284, + -2.4749, + -2.1213, + -1.7678, + -1.4142, + -1.0607, + -0.7071, + -0.3536, + 0.0000, + ] + ) + + expected_first_head_first_row_first_entry = torch.tensor( + 0.0000, + ) + + assert_expected( + base_mask[0][0][0], + expected_first_head_first_row_first_entry, + rtol=0, + atol=1e-4, + ) + + # verify 2nd and 3rd head to confirm non power 2 symmetry of slopes + assert_expected( + base_mask[1][max_seq_len - 1], + expected_second_head_last_row, + rtol=0, + atol=1e-4, + ) + + assert_expected( + 
base_mask[2][max_seq_len - 1], + expected_third_head_last_row, + rtol=0, + atol=1e-4, + ) diff --git a/torchmultimodal/modules/layers/position_embedding.py b/torchmultimodal/modules/layers/position_embedding.py index 3f8f808ad..32596a377 100644 --- a/torchmultimodal/modules/layers/position_embedding.py +++ b/torchmultimodal/modules/layers/position_embedding.py @@ -264,12 +264,14 @@ def get_slopes_power_of_2(n: int) -> List[float]: # paper authors note that they only trained models that have 2^a heads for some a. # This has beneficial properties related to input being power of 2. + # Closest power of 2 below is workaround for when num of heads is not power of 2 + # Slopes are returned in ordered sequence to keep symmetry. closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) - return ( - get_slopes_power_of_2(closest_power_of_2) - + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][ - : num_heads - closest_power_of_2 - ] - ) + + a = get_slopes_power_of_2(closest_power_of_2) + b = get_slopes_power_of_2(2 * closest_power_of_2)[0::2][ + : num_heads - closest_power_of_2 + ] + return [x for pair in zip(b, a) for x in pair] + a[len(b) :]
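
The interleaved slope ordering introduced in PATCH 9 can be checked directly against the rows asserted in test_alibi_mask_non_power_of_2: with 12 heads, the second head gets slope 0.5 and the third head gets slope 2**-1.5 ≈ 0.3536, which at a distance of 15 produce the -7.5000 and -5.3033 leading entries in the expected tensors above. A minimal sketch of that check follows; the print statements and rounding are illustrative only, not part of this patch series. It relies on get_slopes being a staticmethod and build_alibi_mask a classmethod, as introduced in PATCH 7.

    from torchmultimodal.modules.layers.position_embedding import AlibiPositionEmbeddings

    # Per-head slopes for a non-power-of-2 head count (12 heads):
    # slopes from the closest power of 2 (8 heads) interleaved with the
    # extra slopes taken from the 16-head geometric sequence.
    slopes = AlibiPositionEmbeddings.get_slopes(12)
    print([round(s, 4) for s in slopes])
    # [0.7071, 0.5, 0.3536, 0.25, 0.1768, 0.125, 0.0884, 0.0625,
    #  0.0312, 0.0156, 0.0078, 0.0039]

    # Last three entries of the final row of the distance-bias mask for heads 1 and 2;
    # values are distance * slope, i.e. approximately [-1.0, -0.5, 0.0] (slope 0.5)
    # and [-0.7071, -0.3536, 0.0] (slope ~0.3536), matching the unit test above.
    mask = AlibiPositionEmbeddings.build_alibi_mask(seq_len=16, num_heads=12)
    print(mask[1, -1, -3:])
    print(mask[2, -1, -3:])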
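
The __init__ docstring (as updated in PATCH 8) describes the intended integration: build the ALiBi mask once before the transformer block loop, share it across layers, and add it to the attention scores after the 1/sqrt(d_head) scaling. Below is a minimal, self-contained sketch of that wiring. SimpleSelfAttention, its hyperparameters, and the __main__ block are illustrative assumptions for demonstration, not code from this patch series; only AlibiPositionEmbeddings and its get_attention_mask method come from the PR.

    import math

    import torch
    from torch import nn

    from torchmultimodal.modules.layers.position_embedding import AlibiPositionEmbeddings


    class SimpleSelfAttention(nn.Module):
        """Toy decoder-style self-attention layer showing where the ALiBi mask is applied."""

        def __init__(self, embed_dim: int, num_heads: int, alibi: AlibiPositionEmbeddings) -> None:
            super().__init__()
            self.num_heads = num_heads
            self.head_dim = embed_dim // num_heads
            self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
            self.proj = nn.Linear(embed_dim, embed_dim)
            # shared reference to the ALiBi module built once outside the layer stack
            self.alibi = alibi

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            bsz, seq_len, _ = x.shape
            q, k, v = self.qkv(x).chunk(3, dim=-1)
            # reshape to (bsz, num_heads, seq_len, head_dim)
            q = q.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
            k = k.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
            v = v.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

            attn = q @ k.transpose(-2, -1)          # (bsz, num_heads, seq_len, seq_len)
            attn = attn / math.sqrt(self.head_dim)  # scale first, per the updated docstring
            # The returned mask has shape (num_heads, seq_len, seq_len) and already
            # combines the causal -inf entries with the per-head linear distance bias,
            # so no separate causal mask is needed; it broadcasts over the batch dim.
            attn = attn + self.alibi.get_attention_mask(seq_len)
            attn = attn.softmax(dim=-1)
            out = (attn @ v).transpose(1, 2).reshape(bsz, seq_len, -1)
            return self.proj(out)


    if __name__ == "__main__":
        max_seq_len, embed_dim, num_heads = 16, 32, 8
        alibi = AlibiPositionEmbeddings(max_seq_len=max_seq_len, num_heads=num_heads)
        layer = SimpleSelfAttention(embed_dim, num_heads, alibi)
        x = torch.randn(2, 10, embed_dim)  # batch of 2, seq len 10 <= max_seq_len
        print(layer(x).shape)              # torch.Size([2, 10, 32])

Because the mask is precomputed for max_seq_len and clipped per batch via get_attention_mask, the same AlibiPositionEmbeddings instance can be passed by reference to every transformer layer, as the docstring suggests.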