
Understanding how to define key, query and value for the cross attention calculation #119

@neuronphysics

Hello,

I have trouble understanding how I can use this library to implement cross attention.

For instance, if the tensor x = torch.rand(100, 14, 64) is the key, the tensor y = torch.rand(100, 11, 64) is the value, and the tensor z = torch.rand(100, 14, 1) is the query, how can I use TransformerDecoderBuilder to compute cross attention for this example?
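
To make the shapes explicit, this is what I have (the dimension labels are just my interpretation):

import torch

x = torch.rand(100, 14, 64)  # key:   [batch, key length, feature dim]
y = torch.rand(100, 11, 64)  # value: [batch, value length, feature dim]
z = torch.rand(100, 14, 1)   # query: [batch, query length, feature dim]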

Here is how I built my encoder and decoder classes:

import math
import torch
import torch.nn as nn
import fast_transformers
from fast_transformers.builders import TransformerEncoderBuilder, TransformerDecoderBuilder
from collections import OrderedDict


class PositionalEncoding(nn.Module):
    def __init__(self, max_len, d_model, dropout_prob=0.0, series_dimensions=1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout_prob)
        self.d_model = d_model
        self.max_len = max_len
        self.series_dimensions = series_dimensions
        
        if self.series_dimensions == 1:
            if d_model % 2 != 0:
                raise ValueError("Cannot use sin/cos positional encoding with "
                                 "odd dim (got dim={:d})".format(d_model))
            pe = torch.zeros(self.max_len, d_model).float()
            pe.requires_grad = False
            position = torch.arange(0, self.max_len, dtype=torch.float).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
            pe[:, 0::2] = torch.sin(position * div_term)
            pe[:, 1::2] = torch.cos(position * div_term)
        elif self.series_dimensions > 1:
            if d_model % 4 != 0:
                raise ValueError("Cannot use 2D sin/cos positional encoding with "
                                 "a model dimension not divisible by 4 (got dim={:d})".format(d_model))
            height = self.series_dimensions
            width = self.max_len
            pe = torch.zeros(d_model, height, width).float()
            pe.requires_grad = False
            # Each spatial dimension uses half of d_model
            d_model = int(d_model / 2)
            div_term = torch.exp(torch.arange(0., d_model, 2) * -(math.log(10000.0) / d_model))
            pos_w = torch.arange(0., width).unsqueeze(1)
            pos_h = torch.arange(0., height).unsqueeze(1)
            pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
            pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
            pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
            pe[d_model + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
            pe = pe.view(2*d_model, height * width, -1).squeeze(-1) # Flattening it back to 1D series
            pe = pe.transpose(0, 1)
            
        pe = pe.unsqueeze(0) # Extending it by an extra leading dim for the batches
        self.register_buffer('pe', pe)

    # Expecting a flattened (1D) series
    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)


class LinearTransformerCausalEncoder(torch.nn.Module):
    def __init__(self, input_features, output_features, hidden_dim, sequence_length, 
                 attention_type='causal-linear', n_layers=2, n_heads=4,
                 dropout=0.1, softmax_temp=None, activation_fn="gelu",
                 attention_dropout=0.1,
                ):
        super(LinearTransformerCausalEncoder, self).__init__()
        #
        self.d_model=hidden_dim*n_heads
        #
        self.pos_embedding = PositionalEncoding(
                                               max_len=sequence_length,
                                               d_model=self.d_model, #hidden_dim*n_heads      
                                               )
        self.value_embedding = nn.Linear(
            input_features,
            self.d_model
        )
        self.builder_dict = OrderedDict({
            "attention_type": attention_type,
            "n_layers": n_layers,
            "n_heads": n_heads,
            "feed_forward_dimensions": self.d_model*2,
            "query_dimensions": hidden_dim,
            "value_dimensions": hidden_dim,
            "dropout": dropout,
            "softmax_temp": softmax_temp,
            "activation" : activation_fn,
            "attention_dropout": attention_dropout,
        })
        self.transformer = TransformerEncoderBuilder.from_dictionary(
            self.builder_dict,
            strict=True
        ).get()
        hidden_size = n_heads*hidden_dim
        ##
        self.predictor = torch.nn.Linear(
            hidden_size,
            output_features
        )
    def forward(self, x):
        # x: [batch_size, input_dim, sequence_length]
        x = x.permute(0,2,1)
        x = self.value_embedding(x) # x: [batch size, sequence_length, n_heads*hidden_size]
        x = self.pos_embedding(x) # x: [batch size, sequence_length, n_heads*hidden_size]
        triangular_mask = fast_transformers.masking.TriangularCausalMask(x.size(1), device=x.device) # triangular_mask: [ sequence_length,  sequence_length]       
        y_hat = self.transformer(x, attn_mask=triangular_mask) # y_hat: [batch size, sequence_length, n_heads*hidden_size]
        y_hat = self.predictor(y_hat) # y_hat: [batch size, sequence_length, output_size]
        return y_hat.permute(0,2,1)   # y_hat: [batch size, output_size, sequence_length]

class LinearTransformerCausalDecoder(torch.nn.Module):
    def __init__(self, output_features, hidden_dim, sequence_length, 
                 attention_type='causal-linear', n_layers=2, n_heads=4,
                 d_query=32, dropout=0.1, softmax_temp=None,activation_fn="gelu",
                 attention_dropout=0.1,):
        super(LinearTransformerCausalDecoder, self).__init__()
        self.d_model=hidden_dim*n_heads
        self.pos_embedding = PositionalEncoding(
            max_len=sequence_length,
            d_model=self.d_model, #hidden_dim*n_heads
        )
    
        self.value_embedding = torch.nn.Linear(
            output_features,
            self.d_model
        )
        self.builder_dict = OrderedDict({
            "cross_attention_type":attention_type,
            "self_attention_type":attention_type,
            "n_layers": n_layers,
            "n_heads": n_heads,
            "feed_forward_dimensions": self.d_model*2,
            "query_dimensions": hidden_dim,
            "value_dimensions": hidden_dim,
            "dropout": dropout,
            "softmax_temp": softmax_temp,
            "activation" : activation_fn,
            "attention_dropout": attention_dropout,
        })
        self.transformer = TransformerDecoderBuilder.from_dictionary(
            self.builder_dict,
            strict=True
        ).get()
        hidden_size = n_heads*hidden_dim
        
        self.predictor = torch.nn.Linear(
            hidden_size,
            output_features
        )
    def forward(self, target, memory, len_mask=None):
        
        x = target.permute(0,2,1) # x: [batch_size, sequence_length, input_dim]
        x = self.value_embedding(x) # x: [batch size, sequence_length, n_heads*hidden_size]
        x = self.pos_embedding(x) # x: [batch size, sequence_length, n_heads*hidden_size]
        triangular_mask = fast_transformers.masking.TriangularCausalMask(x.size(1), device=x.device) # triangular_mask: [sequence_length, sequence_length]
        y_hat = self.transformer(x, memory, x_mask=triangular_mask, x_length_mask=len_mask) # y_hat: [batch size, sequence_length, n_heads*hidden_size]
        y_hat = self.predictor(y_hat) # y_hat: [batch size, sequence_length, output_size]
        return y_hat.permute(0,2,1)   # y_hat: [batch size, output_size, sequence_length]

I have difficulty understanding how I can use LinearTransformerCausalDecoder to compute cross attention. I would appreciate it if anyone could clarify this for the example key, query, and value above. Thanks.
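
For concreteness, this is roughly how I imagined wiring the two modules together; the sizes below are only illustrative, and I am not sure this is the intended way to drive the cross attention:

# Illustrative only: hypothetical sizes chosen so that d_model = n_heads*hidden_dim = 64
batch, seq_len, feat = 100, 14, 64
n_heads, hidden_dim = 4, 16

encoder = LinearTransformerCausalEncoder(
    input_features=feat, output_features=n_heads*hidden_dim,
    hidden_dim=hidden_dim, sequence_length=seq_len, n_heads=n_heads)
decoder = LinearTransformerCausalDecoder(
    output_features=feat, hidden_dim=hidden_dim,
    sequence_length=seq_len, n_heads=n_heads)

src = torch.rand(batch, feat, seq_len)  # encoder input:  [batch, input_features, sequence_length]
tgt = torch.rand(batch, feat, seq_len)  # decoder target: [batch, output_features, sequence_length]

memory = encoder(src).permute(0, 2, 1)  # memory (keys/values for cross attention): [batch, sequence_length, d_model]
out = decoder(tgt, memory)              # out: [batch, output_features, sequence_length]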
