Hi!

I have replaced my attention with FlashWindowAttention, but now the training does not converge. Did I do something wrong? My code looks like the following:
batch_size, num_heads, num_windows, seq_len, head_dim = query.shape
# Merge the window dimension into the batch dimension: (B, H, W, S, D) -> (B*W, H, S, D)
query = query.permute(0, 2, 1, 3, 4).reshape(batch_size * num_windows,
num_heads, seq_len, head_dim)
key = key.permute(0, 2, 1, 3, 4).reshape(batch_size * num_windows,
num_heads, seq_len, head_dim)
value = value.permute(0, 2, 1, 3, 4).reshape(batch_size * num_windows,
num_heads, seq_len, head_dim)
# Output is (B*W, H, S, D); restore the original (B, H, W, S, D) layout
o = flash_swin_attn_func(query, key, value, bias, scale_qk).reshape(batch_size,
num_windows, num_heads, seq_len, head_dim).permute(0, 2, 1, 3, 4)
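For what it's worth, I sanity-checked the permute/reshape round-trip on random tensors to rule out a layout bug (a minimal repro I wrote; the concrete shapes are made up):

```python
import torch

batch_size, num_heads, num_windows, seq_len, head_dim = 2, 4, 9, 49, 32
x = torch.randn(batch_size, num_heads, num_windows, seq_len, head_dim)

# Merge windows into batch: (B, H, W, S, D) -> (B*W, H, S, D)
merged = x.permute(0, 2, 1, 3, 4).reshape(batch_size * num_windows,
                                          num_heads, seq_len, head_dim)

# Undo it: (B*W, H, S, D) -> (B, H, W, S, D)
restored = merged.reshape(batch_size, num_windows, num_heads, seq_len,
                          head_dim).permute(0, 2, 1, 3, 4)

print(torch.equal(x, restored))  # True -> the reshape itself is lossless
```

So the layout transform round-trips exactly, which makes me think the mismatch comes from the attention kernel itself, not the reshaping.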
While previously I have just:

scores = torch.matmul(q, k.transpose(-2, -1)) / (d_k ** 0.5)
attn_weights = F.softmax(scores, dim=-1)
return torch.matmul(attn_weights, v)
While it is faster, the results do not match the matmul results at all. I am working with BFloat16.
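In case it helps, this is roughly how I compare the two paths. Since bf16 keeps only 8 mantissa bits, I would only expect agreement with the fp32 matmul path up to a fairly loose tolerance, so exact equality is not a realistic bar (the tolerance below is my own guess, and this sketch ignores the bias term):

```python
import torch
import torch.nn.functional as F

def ref_attention(q, k, v):
    # Plain scaled dot-product attention (the old matmul path)
    d_k = q.shape[-1]
    scores = torch.matmul(q, k.transpose(-2, -1)) / (d_k ** 0.5)
    return torch.matmul(F.softmax(scores, dim=-1), v)

torch.manual_seed(0)
q = torch.randn(18, 4, 49, 32)
k = torch.randn(18, 4, 49, 32)
v = torch.randn(18, 4, 49, 32)

out_fp32 = ref_attention(q, k, v)
# Same computation from bf16 inputs, compared back in fp32
out_bf16 = ref_attention(q.bfloat16(), k.bfloat16(), v.bfloat16()).float()

max_err = (out_fp32 - out_bf16).abs().max().item()
# Errors on roughly this scale are expected from bf16 rounding alone;
# much larger deviations would point at an actual bug in the kernel call.
print(max_err)
```

Is a per-element deviation on this order normal for bf16, or does the size of my mismatch indicate I am calling flash_swin_attn_func incorrectly?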