Hi, thanks for releasing MagiAttention -- it’s very nice!
I have a question about how to make sure the attention slices have a stable shape to avoid torch.compile recompilations when the attention pattern is dynamic per batch:
- Different examples have different numbers of "attention slices" (flex ranges).
- These are stored in the `q_ranges`, `k_ranges`, and `attn_type_map` tensors as documented.
- The total number of tokens (the shapes of `q`, `k`, `v`) is fixed, but the number of ranges may vary per batch.
With torch.compile, changing the shape of `q_ranges` etc. will generally cause a new graph to be compiled. To avoid excessive recompilations, I'd like to keep the shapes of these tensors fixed and only vary their values.
This naturally leads to the idea of "dummy" or "padded" ranges to signal to FFA that those ranges are unused.
Roughly, I'd like to do something like:

```python
R_max = 16  # Max number of flex slices per batch.

def make_padded_ranges(raw_q_ranges, raw_k_ranges, raw_types, total_tokens):
    num_real = raw_q_ranges.shape[0]

    q_ranges = torch.zeros(
        (R_max, 2),
        dtype=raw_q_ranges.dtype,
        device=raw_q_ranges.device,
    )
    k_ranges = torch.zeros(
        (R_max, 2),
        dtype=raw_k_ranges.dtype,
        device=raw_k_ranges.device,
    )
    attn_type_map = torch.zeros(
        (R_max,),
        dtype=raw_types.dtype,
        device=raw_types.device,
    )

    # Fill the first `num_real` rows with "real" slices.
    q_ranges[:num_real] = raw_q_ranges
    k_ranges[:num_real] = raw_k_ranges
    attn_type_map[:num_real] = raw_types

    # The remaining rows are "dummy" slices.
    # Question: what is the safe way to define these?
    return q_ranges, k_ranges, attn_type_map
```
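For concreteness, here is a runnable sketch of what I have in mind, assuming (unconfirmed!) that zero-length ranges (start == end) would be treated as no-ops by the kernel; `R_max` and the helper name are my own, not part of the MagiAttention API:

```python
import torch

R_max = 16  # assumed max number of flex slices per batch

def pad_ranges_zero_length(raw_q_ranges, raw_k_ranges, raw_types):
    """Pad range tensors to a fixed R_max rows with zero-length dummies.

    Dummy rows have start == end == 0; the (unverified) assumption is
    that such empty ranges select no tokens and are ignored by FFA.
    """
    num_real = raw_q_ranges.shape[0]
    assert num_real <= R_max, "increase R_max"

    # torch.zeros already yields start == end == 0 for every row,
    # i.e. empty ranges, so padding is just overwriting the real rows.
    q_ranges = torch.zeros((R_max, 2), dtype=raw_q_ranges.dtype)
    k_ranges = torch.zeros((R_max, 2), dtype=raw_k_ranges.dtype)
    attn_type_map = torch.zeros((R_max,), dtype=raw_types.dtype)

    q_ranges[:num_real] = raw_q_ranges
    k_ranges[:num_real] = raw_k_ranges
    attn_type_map[:num_real] = raw_types
    return q_ranges, k_ranges, attn_type_map

# Example: 3 real slices padded out to R_max rows.
q = torch.tensor([[0, 4], [4, 8], [8, 16]], dtype=torch.int32)
k = torch.tensor([[0, 8], [0, 8], [8, 16]], dtype=torch.int32)
t = torch.tensor([0, 0, 1], dtype=torch.int32)
pq, pk, pt = pad_ranges_zero_length(q, k, t)
print(pq.shape, pk.shape, pt.shape)  # torch.Size([16, 2]) torch.Size([16, 2]) torch.Size([16])
```

The padded shapes are constant across batches, so only tensor values change between calls.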
and then, in a compiled function, call:

```python
@torch.compile
def attn_step(q, k, v, q_ranges, k_ranges, attn_type_map):
    return flex_flash_attn_func(
        q,
        k,
        v,
        q_ranges=q_ranges,
        k_ranges=k_ranges,
        attn_type_map=attn_type_map,
        # other args kept static
    )
```
From torch.compile’s perspective, the argument shapes are now static, so it won’t recompile just because a batch has more or fewer ranges.
Question:
Is there any notion of “dummy” / no-op ranges for flex_flash_attn_func? For example zero-length ranges (start = end) or padded ranges that are guaranteed to be ignored.