
Conversation

@Ximingwang-09
Contributor

Motivation

This PR fixes an issue in the DFlash flex_attention implementation where the BlockMask was created with H=num_heads.
Since the mask function dflash_mask_fn(b, h, q_idx, kv_idx) ignores the h parameter, creating a separate mask per head is redundant. Using H=1 lets PyTorch automatically broadcast the single-head mask across all heads.
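A minimal sketch of the broadcast behavior this relies on. The real dflash_mask_fn lives in the DFlash code; here it is only assumed to be a mask that ignores b and h, and the shapes are illustrative:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def dflash_mask_fn(b, h, q_idx, kv_idx):
    # h is unused, so the mask is identical for every head.
    # Placeholder causal predicate for illustration only.
    return q_idx >= kv_idx

B, num_heads, seq_len, head_dim = 1, 8, 1024, 64

# H=1 is sufficient: flex_attention broadcasts the single-head
# BlockMask across all num_heads heads.
block_mask = create_block_mask(
    dflash_mask_fn, B=B, H=1, Q_LEN=seq_len, KV_LEN=seq_len, device="cuda"
)

q = torch.randn(B, num_heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)
out = flex_attention(q, k, v, block_mask=block_mask)
```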

Modifications

Changes

  1. Simplified BlockMask creation: changed H=num_heads to H=1 in the create_block_mask() call.
  2. Removed unnecessary cache key: removed _cached_num_heads from the cache invalidation logic.
  3. Cleaned up function signature: removed the num_heads parameter from _get_or_create_block_mask() (see the sketch after this list).
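A hypothetical sketch of what the simplified caching helper could look like after this change; the names mirror the PR description, but the actual DFlash implementation may differ in its cache key and arguments:

```python
from torch.nn.attention.flex_attention import create_block_mask

_cached_block_mask = None
_cached_seq_len = None  # num_heads is no longer part of the cache key

def _get_or_create_block_mask(mask_fn, seq_len, device):
    global _cached_block_mask, _cached_seq_len
    if _cached_block_mask is None or _cached_seq_len != seq_len:
        # H=1: the mask ignores the head index, so a single-head mask
        # is broadcast to all heads by flex_attention.
        _cached_block_mask = create_block_mask(
            mask_fn, B=1, H=1, Q_LEN=seq_len, KV_LEN=seq_len, device=device
        )
        _cached_seq_len = seq_len
    return _cached_block_mask
```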

Related Issues

#452

Accuracy Test

Benchmark & Profiling

Checklist

@gemini-code-assist
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@xiaomin-D
Contributor

cc @FrankLeeeee @sleepcoo

