
The input tensor for F.scaled_dot_product_attention has the wrong dimension ordering. It should be (B, H, L, E), but the code is currently passing (B, L, H, E). #217

Open

JIANG-CX wants to merge 1 commit into NVlabs:master from JIANG-CX:jcx

Conversation

JIANG-CX commented Dec 3, 2025

According to the PyTorch documentation (https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html), the input tensors for F.scaled_dot_product_attention have the wrong dimension ordering: they should be (B, H, L, E), but the code is currently passing (B, L, H, E).
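Per the linked documentation, query, key, and value are expected as (N, ..., L, E), (N, ..., S, E), and (N, ..., S, Ev), i.e. attention runs over the last two dimensions. A quick illustration of why the wrong ordering is a silent bug (shapes here are made up for the example):

import torch
import torch.nn.functional as F

B, H, L, E = 2, 4, 16, 8
q = torch.randn(B, H, L, E)  # heads before sequence, as the docs expect
k = torch.randn(B, H, L, E)
v = torch.randn(B, H, L, E)

out = F.scaled_dot_product_attention(q, k, v)  # (B, H, L, E): attends over L

# Passing (B, L, H, E) also runs without error, but attention is then
# computed across the H (heads) axis instead of the sequence axis:
wrong = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
print(out.shape, wrong.shape)  # both run; only the first is correct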

Copilot AI review requested due to automatic review settings, December 3, 2025 11:23

Copilot AI left a comment


Pull request overview

This PR fixes a critical bug in the FlashMultiheadAttention module where tensors were being passed to F.scaled_dot_product_attention in the wrong dimension order. The PyTorch API expects (B, H, L, E) format, but the original code was passing (B, L, H, E).

Key Changes:

  • Added .transpose(1, 2) operations to convert Q, K, V tensors from (B, L, H, Dh) to (B, H, L, Dh) before attention
  • Added .permute(0, 2, 1, 3) to convert attention output back from (B, H, L, Dh) to (B, L, H, Dh)
  • Updated tensor reshaping to use explicit B, L variables instead of dynamic .size() calls
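
For concreteness, here is a minimal, self-contained sketch of the corrected flow. The shapes and variable names are illustrative only, not the repo's exact FlashMultiheadAttention code:

import torch
import torch.nn.functional as F

B, L, H, Dh = 2, 16, 4, 8
x = torch.randn(B, L, H * Dh)

# Split heads: (B, L, H*Dh) -> (B, L, H, Dh) -> (B, H, L, Dh)
Q = x.view(B, L, H, Dh).transpose(1, 2)
K = x.view(B, L, H, Dh).transpose(1, 2)
V = x.view(B, L, H, Dh).transpose(1, 2)

attn_output = F.scaled_dot_product_attention(Q, K, V)       # (B, H, L, Dh)
attn_output = attn_output.permute(0, 2, 1, 3).contiguous()  # (B, L, H, Dh)
attn_output = attn_output.reshape(B, L, -1)                 # (B, L, H*Dh)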


K = K.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
V = V.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)

attn_output = F.scaled_dot_product_attention(Q, K, V).permute(0, 2, 1, 3)

Copilot AI Dec 3, 2025


After permute(), the tensor may not be contiguous in memory. Calling .view() on a non-contiguous tensor raises an error, and .reshape() falls back to silently copying the data. Consider calling .contiguous() before .reshape() to make the copy explicit.

Suggested fix:

attn_output = F.scaled_dot_product_attention(Q, K, V).permute(0, 2, 1, 3).contiguous()

attn_output = attn_output.reshape(B, L, -1)

This pattern is used elsewhere in the codebase (see dinov2/dinov2/layers/attention.py:77).

Suggested change
attn_output = F.scaled_dot_product_attention(Q, K, V).permute(0, 2, 1, 3)
attn_output = F.scaled_dot_product_attention(Q, K, V).permute(0, 2, 1, 3).contiguous()
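
For illustration, a tiny standalone check of the contiguity point (assumed shapes only):

import torch

x = torch.randn(2, 4, 16, 8)  # (B, H, L, Dh)
y = x.permute(0, 2, 1, 3)     # (B, L, H, Dh): a non-contiguous view

# .view() needs compatible strides and raises a RuntimeError here;
# .reshape() succeeds but silently copies. .contiguous() makes the copy explicit.
try:
    y.view(2, 16, -1)
except RuntimeError as err:
    print("view failed:", err)

z = y.contiguous().reshape(2, 16, -1)  # (B, L, H*Dh)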

Q = Q.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
K = K.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
V = V.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)


Copilot AI Dec 3, 2025


[nitpick] Remove trailing whitespace at the end of this line for consistency with code style best practices.

Suggested change
