Hello,
I have trouble understanding how I can use this library to implement cross attention.
For instance, if tensor x = torch.rand(100, 14, 64) is the key, tensor y = torch.rand(100, 11, 64) is the value, and tensor z = torch.rand(100, 14, 1) is the query, how can I use TransformerDecoderBuilder to compute cross attention for this example?
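To make the shapes concrete, this is the setup I am starting from (my own naming; the batch size of 100 and the feature sizes are just placeholders):

```python
import torch

# Placeholder tensors from my example; the last dimension is the feature size
x = torch.rand(100, 14, 64)  # what I think of as the key:   [batch, 14, 64]
y = torch.rand(100, 11, 64)  # what I think of as the value: [batch, 11, 64]
z = torch.rand(100, 14, 1)   # what I think of as the query: [batch, 14, 1]
```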
Here is how I built my encoder and decoder classes:
```python
import math
from collections import OrderedDict

import torch
import torch.nn as nn

from fast_transformers.builders import TransformerEncoderBuilder, TransformerDecoderBuilder
from fast_transformers.masking import TriangularCausalMask


class PositionalEncoding(nn.Module):
    def __init__(self, max_len, d_model, dropout_prob=0.0, series_dimensions=1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout_prob)
        self.d_model = d_model
        self.max_len = max_len
        self.series_dimensions = series_dimensions
        if self.series_dimensions == 1:
            if d_model % 2 != 0:
                raise ValueError("Cannot use sin/cos positional encoding with "
                                 "odd dim (got dim={:d})".format(d_model))
            pe = torch.zeros(self.max_len, d_model).float()
            pe.requires_grad = False
            position = torch.arange(0, self.max_len, dtype=torch.float).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
            pe[:, 0::2] = torch.sin(position * div_term)
            pe[:, 1::2] = torch.cos(position * div_term)
        elif self.series_dimensions > 1:
            if d_model % 4 != 0:
                raise ValueError("Cannot use 2D sin/cos positional encoding with "
                                 "a dim that is not a multiple of 4 (got dim={:d})".format(d_model))
            height = self.series_dimensions
            width = self.max_len
            pe = torch.zeros(d_model, height, width).float()
            pe.requires_grad = False
            # Each spatial dimension uses half of d_model
            d_model = int(d_model / 2)
            div_term = torch.exp(torch.arange(0., d_model, 2) * -(math.log(10000.0) / d_model))
            pos_w = torch.arange(0., width).unsqueeze(1)
            pos_h = torch.arange(0., height).unsqueeze(1)
            pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
            pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
            pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
            pe[d_model + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
            pe = pe.view(2 * d_model, height * width, -1).squeeze(-1)  # Flatten back to a 1D series
            pe = pe.transpose(0, 1)
        pe = pe.unsqueeze(0)  # Add a leading batch dimension
        self.register_buffer('pe', pe)

    # Expects a flattened (1D) series
    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)
```
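For context, this is how I sanity-check the positional encoding on its own (a small sketch with made-up sizes; d_model has to be even for the 1D case):

```python
# Sketch: d_model = 256 and max_len = 14 are just example values
pos = PositionalEncoding(max_len=14, d_model=256)
dummy = torch.rand(100, 14, 256)  # [batch, sequence_length, d_model]
encoded = pos(dummy)
print(encoded.shape)              # torch.Size([100, 14, 256])
```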
```python
class LinearTransformerCausalEncoder(torch.nn.Module):
    def __init__(self, input_features, output_features, hidden_dim, sequence_length,
                 attention_type='causal-linear', n_layers=2, n_heads=4,
                 dropout=0.1, softmax_temp=None, activation_fn="gelu",
                 attention_dropout=0.1):
        super(LinearTransformerCausalEncoder, self).__init__()
        self.d_model = hidden_dim * n_heads
        self.pos_embedding = PositionalEncoding(
            max_len=sequence_length,
            d_model=self.d_model,  # hidden_dim * n_heads
        )
        self.value_embedding = nn.Linear(
            input_features,
            self.d_model
        )
        self.builder_dict = OrderedDict({
            "attention_type": attention_type,
            "n_layers": n_layers,
            "n_heads": n_heads,
            "feed_forward_dimensions": self.d_model * 2,
            "query_dimensions": hidden_dim,
            "value_dimensions": hidden_dim,
            "dropout": dropout,
            "softmax_temp": softmax_temp,
            "activation": activation_fn,
            "attention_dropout": attention_dropout,
        })
        self.transformer = TransformerEncoderBuilder.from_dictionary(
            self.builder_dict,
            strict=True
        ).get()
        hidden_size = n_heads * hidden_dim
        self.predictor = torch.nn.Linear(
            hidden_size,
            output_features
        )

    def forward(self, x):
        # x: [batch_size, input_features, sequence_length]
        x = x.permute(0, 2, 1)
        x = self.value_embedding(x)  # x: [batch_size, sequence_length, n_heads * hidden_dim]
        x = self.pos_embedding(x)    # x: [batch_size, sequence_length, n_heads * hidden_dim]
        triangular_mask = TriangularCausalMask(x.size(1), device=x.device)  # [sequence_length, sequence_length]
        y_hat = self.transformer(x, attn_mask=triangular_mask)  # [batch_size, sequence_length, n_heads * hidden_dim]
        y_hat = self.predictor(y_hat)  # [batch_size, sequence_length, output_features]
        return y_hat.permute(0, 2, 1)  # [batch_size, output_features, sequence_length]
```
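And this is roughly how I run the encoder to produce the memory I later pass to the decoder (again a sketch with placeholder sizes; here d_model = hidden_dim * n_heads = 256, and I set output_features to 256 so the encoder output matches that dimension, though I am not sure this is the intended way to build the memory):

```python
# Sketch: input_features=64, hidden_dim=64, n_heads=4 (default) -> d_model = 256
encoder = LinearTransformerCausalEncoder(
    input_features=64, output_features=256, hidden_dim=64, sequence_length=14
)
src = torch.rand(100, 64, 14)     # [batch, input_features, sequence_length]
memory = encoder(src)             # [batch, output_features, sequence_length] = [100, 256, 14]
memory = memory.permute(0, 2, 1)  # [batch, sequence_length, d_model] for the decoder
```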
```python
class LinearTransformerCausalDecoder(torch.nn.Module):
    def __init__(self, output_features, hidden_dim, sequence_length,
                 attention_type='causal-linear', n_layers=2, n_heads=4,
                 d_query=32, dropout=0.1, softmax_temp=None, activation_fn="gelu",
                 attention_dropout=0.1):
        super(LinearTransformerCausalDecoder, self).__init__()
        self.d_model = hidden_dim * n_heads
        self.pos_embedding = PositionalEncoding(
            max_len=sequence_length,
            d_model=self.d_model,  # hidden_dim * n_heads
        )
        self.value_embedding = torch.nn.Linear(
            output_features,
            self.d_model
        )
        self.builder_dict = OrderedDict({
            "cross_attention_type": attention_type,
            "self_attention_type": attention_type,
            "n_layers": n_layers,
            "n_heads": n_heads,
            "feed_forward_dimensions": self.d_model * 2,
            "query_dimensions": hidden_dim,
            "value_dimensions": hidden_dim,
            "dropout": dropout,
            "softmax_temp": softmax_temp,
            "activation": activation_fn,
            "attention_dropout": attention_dropout,
        })
        self.transformer = TransformerDecoderBuilder.from_dictionary(
            self.builder_dict,
            strict=True
        ).get()
        hidden_size = n_heads * hidden_dim
        self.predictor = torch.nn.Linear(
            hidden_size,
            output_features
        )

    def forward(self, target, memory, len_mask=None):
        x = target.permute(0, 2, 1)  # x: [batch_size, sequence_length, output_features]
        x = self.value_embedding(x)  # x: [batch_size, sequence_length, n_heads * hidden_dim]
        x = self.pos_embedding(x)    # x: [batch_size, sequence_length, n_heads * hidden_dim]
        triangular_mask = TriangularCausalMask(x.size(1), device=x.device)  # [sequence_length, sequence_length]
        y_hat = self.transformer(x, memory, x_mask=triangular_mask, x_length_mask=len_mask)  # [batch_size, sequence_length, n_heads * hidden_dim]
        y_hat = self.predictor(y_hat)  # [batch_size, sequence_length, output_features]
        return y_hat.permute(0, 2, 1)  # [batch_size, output_features, sequence_length]
```
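Finally, this is the kind of call I am attempting for the cross attention, assuming the target sequence supplies the queries and the memory supplies the keys and values inside the decoder. I am not even sure causal-linear is a sensible cross_attention_type here, or that it accepts the decoder's default full memory mask, so please treat this as a sketch of my intent rather than working code:

```python
# Sketch of my intent: output_features=1 so the decoder consumes a 1-feature series like z
decoder = LinearTransformerCausalDecoder(
    output_features=1, hidden_dim=64, sequence_length=14
)
z = torch.rand(100, 14, 1)         # my "query" tensor from above
memory = torch.rand(100, 14, 256)  # stand-in for the encoder output, [batch, seq, d_model]
target = z.permute(0, 2, 1)        # [batch, output_features, sequence_length] = [100, 1, 14]
out = decoder(target, memory)      # is this the right way to get cross attention over memory?
print(out.shape)                   # I expect torch.Size([100, 1, 14])
```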
I have difficulty understanding how I can use LinearTransformerCausalDecoder to compute cross attention. I would appreciate it if anyone could clarify this for the example key, query, and value above. Thanks.