The input tensor for F.scaled_dot_product_attention has the wrong dimension ordering. It should be (B, H, L, E), but the code is currently passing (B, L, H, E). #217
Conversation
Pull request overview
This PR fixes a critical bug in the FlashMultiheadAttention module where tensors were being passed to F.scaled_dot_product_attention in the wrong dimension order. The PyTorch API expects (B, H, L, E) format, but the original code was passing (B, L, H, E).
Key Changes:
- Added `.transpose(1, 2)` operations to convert the Q, K, V tensors from (B, L, H, Dh) to (B, H, L, Dh) before attention
- Added `.permute(0, 2, 1, 3)` to convert the attention output back from (B, H, L, Dh) to (B, L, H, Dh)
- Updated tensor reshaping to use explicit B, L variables instead of dynamic `.size()` calls
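As a quick sanity check of the shape transitions described above (the sizes below are arbitrary illustrations, not values from the PR):

```python
import torch

B, L, H, Dh = 2, 8, 4, 16  # arbitrary illustrative sizes

x = torch.randn(B, L, H, Dh)         # (B, L, H, Dh) layout produced by .view()
x_sdpa = x.transpose(1, 2)           # -> (B, H, L, Dh), the layout SDPA expects
x_back = x_sdpa.permute(0, 2, 1, 3)  # -> back to (B, L, H, Dh) after attention

print(x_sdpa.shape)  # torch.Size([2, 4, 8, 16])
print(x_back.shape)  # torch.Size([2, 8, 4, 16])
```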
```python
K = K.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
V = V.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)

attn_output = F.scaled_dot_product_attention(Q, K, V).permute(0, 2, 1, 3)
```
After permute(), the tensor is typically not contiguous in memory, so a subsequent .view() would raise a RuntimeError; .reshape() handles this by silently copying, but an explicit .contiguous() makes the copy intentional and matches the rest of the codebase. Consider calling .contiguous() before reshaping.
Suggested fix:

```python
attn_output = F.scaled_dot_product_attention(Q, K, V).permute(0, 2, 1, 3).contiguous()
attn_output = attn_output.reshape(B, L, -1)
```

This pattern is used elsewhere in the codebase (see dinov2/dinov2/layers/attention.py:77).
```diff
- attn_output = F.scaled_dot_product_attention(Q, K, V).permute(0, 2, 1, 3)
+ attn_output = F.scaled_dot_product_attention(Q, K, V).permute(0, 2, 1, 3).contiguous()
```
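To see why the suggested `.contiguous()` matters, here is a small standalone sketch (the tensor sizes are arbitrary illustrations, not values from the PR):

```python
import torch
import torch.nn.functional as F

B, H, L, Dh = 2, 4, 8, 16  # arbitrary illustrative sizes

Q = torch.randn(B, H, L, Dh)
K = torch.randn(B, H, L, Dh)
V = torch.randn(B, H, L, Dh)

attn = F.scaled_dot_product_attention(Q, K, V)  # (B, H, L, Dh)
perm = attn.permute(0, 2, 1, 3)                 # (B, L, H, Dh) view, not contiguous

# .view() on the non-contiguous permuted tensor would raise a RuntimeError;
# .contiguous() materialises the permuted layout so merging the last two
# dims is a plain reshape of contiguous memory.
out = perm.contiguous().reshape(B, L, H * Dh)
```

Here `perm.is_contiguous()` is `False`, so `perm.view(B, L, -1)` would fail, while the `.contiguous().reshape(...)` form always works.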
```python
Q = Q.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
K = K.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
V = V.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
```
[nitpick] Remove trailing whitespace at the end of this line for consistency with code style best practices.
According to https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html, the input tensors for F.scaled_dot_product_attention have the wrong dimension ordering: they should be (B, H, L, E), but the code is currently passing (B, L, H, E).
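Putting the fix together, a minimal self-contained sketch of the corrected attention path (the class name, projection layers, and sizes below are illustrative assumptions, not the PR's exact code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FlashMultiheadAttentionSketch(nn.Module):
    """Illustrative sketch of the corrected dimension handling."""

    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)   # assumed fused QKV projection
        self.proj = nn.Linear(embed_dim, embed_dim)      # assumed output projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, L, _ = x.shape
        Q, K, V = self.qkv(x).chunk(3, dim=-1)
        # (B, L, E) -> (B, L, H, Dh) -> (B, H, L, Dh): the layout SDPA expects
        Q = Q.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        # SDPA returns (B, H, L, Dh); permute back to (B, L, H, Dh) and
        # make contiguous before merging heads, per the review suggestion
        attn_output = F.scaled_dot_product_attention(Q, K, V).permute(0, 2, 1, 3).contiguous()
        return self.proj(attn_output.reshape(B, L, -1))


x = torch.randn(2, 10, 64)
out = FlashMultiheadAttentionSketch(embed_dim=64, num_heads=8)(x)
print(out.shape)  # torch.Size([2, 10, 64])
```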