diff --git a/tests/modules/layers/test_position_embedding.py b/tests/modules/layers/test_position_embedding.py
index 583394601..2f82aa91f 100644
--- a/tests/modules/layers/test_position_embedding.py
+++ b/tests/modules/layers/test_position_embedding.py
@@ -7,14 +7,20 @@
 import pytest
 import torch
 
-from tests.test_utils import assert_expected
+from tests.test_utils import assert_expected, set_rng_seed
 from torch import nn
 from torchmultimodal.modules.layers.position_embedding import (
+    AlibiPositionEmbeddings,
     BroadcastedPositionEmbedding,
     SinusoidalPositionEmbeddings,
 )
 
 
+@pytest.fixture(autouse=True)
+def random():
+    set_rng_seed(2023)
+
+
 class TestBroadcastedPositionEmbedding:
     @pytest.fixture(scope="class")
     def pos_emb(self):
@@ -112,3 +118,158 @@ def test_forward(self, data, emb):
         actual = emb(data)
         expected = torch.Size([3, 5])
         assert_expected(actual.shape, expected)
+
+
+class TestAlibiPositionEmbedding:
+    @pytest.fixture
+    def max_seq_len(self):
+        return 16
+
+    @pytest.fixture
+    def embedding_dim(self):
+        return 32
+
+    @pytest.fixture
+    def num_heads(self):
+        return 8
+
+    @pytest.fixture
+    def num_heads_non_power_2(self):
+        return 12
+
+    def test_alibi_mask_power_of_2(
+        self,
+        max_seq_len,
+        num_heads,
+    ):
+        alibi_class = AlibiPositionEmbeddings(
+            max_seq_len=max_seq_len, num_heads=num_heads
+        )
+        base_mask = alibi_class.get_attention_mask(max_seq_len)
+
+        # verify mask shape
+        expected_shape = torch.Size((num_heads, max_seq_len, max_seq_len))
+        assert_expected(base_mask.shape, expected_shape)
+
+        # verify alibi mask components
+        expected_last_head_row = torch.tensor(
+            [
+                -0.0586,
+                -0.0547,
+                -0.0508,
+                -0.0469,
+                -0.0430,
+                -0.0391,
+                -0.0352,
+                -0.0312,
+                -0.0273,
+                -0.0234,
+                -0.0195,
+                -0.0156,
+                -0.0117,
+                -0.0078,
+                -0.0039,
+                0.0000,
+            ]
+        )
+
+        expected_first_head_first_row_first_entry = torch.tensor(
+            0.0000,
+        )
+
+        assert_expected(
+            base_mask[0][0][0],
+            expected_first_head_first_row_first_entry,
+            rtol=0,
+            atol=1e-4,
+        )
+
+        assert_expected(
+            base_mask[num_heads - 1][max_seq_len - 1],
+            expected_last_head_row,
+            rtol=0,
+            atol=1e-4,
+        )
+
+    def test_alibi_mask_non_power_of_2(
+        self,
+        max_seq_len,
+        num_heads_non_power_2,
+    ):
+        alibi_class = AlibiPositionEmbeddings(
+            max_seq_len=max_seq_len, num_heads=num_heads_non_power_2
+        )
+        base_mask = alibi_class.get_attention_mask(max_seq_len)
+
+        # verify mask shape
+        expected_shape = torch.Size((num_heads_non_power_2, max_seq_len, max_seq_len))
+        assert_expected(base_mask.shape, expected_shape)
+
+        # verify alibi mask components
+        expected_second_head_last_row = torch.tensor(
+            [
+                -7.5000,
+                -7.0000,
+                -6.5000,
+                -6.0000,
+                -5.5000,
+                -5.0000,
+                -4.5000,
+                -4.0000,
+                -3.5000,
+                -3.0000,
+                -2.5000,
+                -2.0000,
+                -1.5000,
+                -1.0000,
+                -0.5000,
+                0.0000,
+            ]
+        )
+
+        expected_third_head_last_row = torch.tensor(
+            [
+                -5.3033,
+                -4.9497,
+                -4.5962,
+                -4.2426,
+                -3.8891,
+                -3.5355,
+                -3.1820,
+                -2.8284,
+                -2.4749,
+                -2.1213,
+                -1.7678,
+                -1.4142,
+                -1.0607,
+                -0.7071,
+                -0.3536,
+                0.0000,
+            ]
+        )
+
+        expected_first_head_first_row_first_entry = torch.tensor(
+            0.0000,
+        )
+
+        assert_expected(
+            base_mask[0][0][0],
+            expected_first_head_first_row_first_entry,
+            rtol=0,
+            atol=1e-4,
+        )
+
+        # verify 2nd and 3rd head to confirm non power 2 symmetry of slopes
+        assert_expected(
+            base_mask[1][max_seq_len - 1],
+            expected_second_head_last_row,
+            rtol=0,
+            atol=1e-4,
+        )
+
+        assert_expected(
+            base_mask[2][max_seq_len - 1],
+            expected_third_head_last_row,
+            rtol=0,
+            atol=1e-4,
+        )
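
Note (illustrative, not part of the patch): the expected rows in the tests above follow directly from the per-head slopes. With 8 heads the last head's slope is 2^-8 = 0.00390625, so its bias at distance 15 is -15 * 0.00390625 ≈ -0.0586; with 12 heads the slopes interleave the 8-head and 16-head geometric sequences, giving 0.5 for head index 1 and 2^-1.5 ≈ 0.3536 for head index 2. A minimal standalone sketch that reproduces these numbers (helper names below are illustrative, not from the patch):

    import torch

    def slopes_power_of_2(n: int) -> list:
        # first slope is 2^(-8/n); the same value is also the geometric ratio
        start = 2 ** (-8.0 / n)
        return [start ** (i + 1) for i in range(n)]

    distance = -torch.arange(15, -1, -1, dtype=torch.float)  # [-15, ..., -1, 0]

    slopes_8 = slopes_power_of_2(8)
    print(distance * slopes_8[-1])  # ~[-0.0586, -0.0547, ..., -0.0039, 0.0000]

    # 12 heads: interleave every other slope of the 16-head sequence with the 8-head sequence
    extra = slopes_power_of_2(16)[0::2][:4]
    slopes_12 = [s for pair in zip(extra, slopes_8) for s in pair] + slopes_8[4:]
    print(distance * slopes_12[1])  # ~[-7.5000, -7.0000, ..., -0.5000, 0.0000]
    print(distance * slopes_12[2])  # ~[-5.3033, -4.9497, ..., -0.3536, 0.0000]
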
diff --git a/torchmultimodal/modules/layers/position_embedding.py b/torchmultimodal/modules/layers/position_embedding.py
index 7920ce1cc..32596a377 100644
--- a/torchmultimodal/modules/layers/position_embedding.py
+++ b/torchmultimodal/modules/layers/position_embedding.py
@@ -5,7 +5,8 @@
 # LICENSE file in the root directory of this source tree.
 
 import itertools
-from typing import Tuple
+import math
+from typing import List, Tuple
 
 import torch
 from torch import nn, Tensor
@@ -169,3 +170,108 @@ def forward(self, t: Tensor) -> Tensor:
         if self.embed_dim % 2 == 1:
             embeddings = nn.functional.pad(embeddings, (0, 1))
         return embeddings
+
+
+class AlibiPositionEmbeddings(nn.Module):
+    """Attention with Linear Biases (ALiBi)
+
+    Softmax(q_i * K^T + m · [-(i - 1), ..., -2, -1, 0]),
+    where m is a fixed, head-specific slope
+
+    as proposed in:
+    https://arxiv.org/abs/2108.12409
+    Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation
+
+    derived from the codebase of Ofir Press (ALiBi author):
+    https://github.com/ofirpress/attention_with_linear_biases
+
+    """
+
+    def __init__(
+        self,
+        max_seq_len: int,
+        num_heads: int,
+    ) -> None:
+        """recommended usage: create the alibi mask once, before the transformer block loop,
+        and integrate it there. Alibi should be applied after the sqrt scaling of the attention values.
+
+        Example:
+            before the transformer block loop:
+            from alibi_embeddings import AlibiPE
+            self.alibi = AlibiPE(config.max_seq_len, config.num_heads)
+            pass a reference to the alibi class to each transformer layer,
+            then in the forward of the transformer layer:
+            alibi_mask = self.alibi.get_attention_mask(N)  # N = seq length of this batch
+            ...
+            attn = q @ k.transpose(-2, -1)
+            attn *= 1.0 / math.sqrt(k.size(-1))
+            attn += alibi_mask
+
+        """
+        super().__init__()
+
+        self.num_heads = num_heads
+        self.max_seq_len = max_seq_len
+
+        self.causal_mask = self.build_causal_attention_mask(
+            self.max_seq_len, self.num_heads
+        )
+        self.alibi_mask_base = self.build_alibi_mask(self.max_seq_len, self.num_heads)
+        self.decoder_mask = self.causal_mask + self.alibi_mask_base
+        self.register_buffer("alibi_mask", self.decoder_mask, persistent=False)
+
+    def get_attention_mask(self, curr_seq_len: int) -> torch.Tensor:
+        """returns the alibi mask, clipped to the current batch seq len"""
+        return self.alibi_mask[..., :curr_seq_len, :curr_seq_len]
+
+    @classmethod
+    def build_causal_attention_mask(cls, seq_len: int, num_heads: int) -> torch.Tensor:
+        """builds a generic causal attention mask"""
+        causal_mask = torch.triu(
+            torch.ones(seq_len, seq_len) * float("-inf"), diagonal=1
+        )
+        attn_mask = causal_mask.repeat(num_heads, 1, 1)
+        return attn_mask
+
+    @classmethod
+    def build_alibi_mask(cls, seq_len: int, num_heads: int) -> torch.Tensor:
+        """generates the alibi mask by computing a distance bias matrix multiplied by each head's m (slope)"""
+        distance_bias_matrix = -torch.abs(
+            torch.arange(seq_len) - torch.arange(seq_len).view(-1, 1)
+        )
+        slope_per_head = Tensor(cls.get_slopes(num_heads)).view(-1, 1, 1)
+        alibi_mask = distance_bias_matrix * slope_per_head
+        return alibi_mask
+
+    @staticmethod
+    def get_slopes(num_heads: int) -> List[float]:
+        """for n heads, returns a list of slopes in (0, 1): the geometric sequence
+        that starts at 2^(-8/n) and uses that same value as its ratio
+
+        example: num_heads = 4
+        result: [0.25, 0.0625, 0.015625, 0.00390625]
+
+        """
+
+        def get_slopes_power_of_2(n: int) -> List[float]:
+            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
+            ratio = start
+            return [start * ratio**i for i in range(n)]
+
+        if math.log2(num_heads).is_integer():
+            return get_slopes_power_of_2(num_heads)
+
+        # The paper's authors note that they only trained models with 2^a heads for some a;
+        # this has beneficial properties related to the input being a power of 2.
+
+        # The closest power of 2 below num_heads is a workaround for when the number of heads
+        # is not a power of 2.
+        # Slopes are returned in an interleaved sequence to keep symmetry.
+
+        closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
+
+        a = get_slopes_power_of_2(closest_power_of_2)
+        b = get_slopes_power_of_2(2 * closest_power_of_2)[0::2][
+            : num_heads - closest_power_of_2
+        ]
+        return [x for pair in zip(b, a) for x in pair] + a[len(b) :]
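
Usage sketch (illustrative, not part of the patch): this fleshes out the integration pattern described in the __init__ docstring, assuming the patch is applied. The batch size, sequence length, and head dimension below are arbitrary toy values.

    import math
    import torch
    from torchmultimodal.modules.layers.position_embedding import AlibiPositionEmbeddings

    num_heads, max_seq_len, head_dim = 8, 16, 64
    alibi = AlibiPositionEmbeddings(max_seq_len=max_seq_len, num_heads=num_heads)

    bsz, seq_len = 2, 10  # current batch is shorter than max_seq_len
    q = torch.randn(bsz, num_heads, seq_len, head_dim)
    k = torch.randn(bsz, num_heads, seq_len, head_dim)
    v = torch.randn(bsz, num_heads, seq_len, head_dim)

    attn = q @ k.transpose(-2, -1)             # (bsz, num_heads, seq_len, seq_len)
    attn *= 1.0 / math.sqrt(k.size(-1))        # sqrt scaling first, as the docstring notes
    attn += alibi.get_attention_mask(seq_len)  # causal -inf mask + per-head linear bias, broadcast over batch
    out = torch.softmax(attn, dim=-1) @ v      # (bsz, num_heads, seq_len, head_dim)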