Thanks for sharing the fast Triton kernel. I have a question about the existing implementation, which seems to require `head_dim` to be a multiple of 16/32.

Could you share how to extend it to a `head_dim` that is not a multiple of 32, such as 30? Triton's block-pointer and `arange` APIs impose constraints that make this difficult, and zero-padding Q, K, and V would work around them, but that leads to a slowdown.

I would appreciate it if you could share an efficient workaround or add support for this.
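For what it's worth, the usual workaround I've seen (a sketch of the general technique, not something from this repo): pad only the *block shape*, not the tensors. `tl.arange` needs a power-of-2 length, so pick `BLOCK_D = triton.next_power_of_2(head_dim)` and mask the tail lanes with `tl.load(..., mask=offs_d < head_dim, other=0.0)`. The zero padding then lives only in registers, so Q/K/V stay unpadded in global memory. A NumPy sketch of why the masked lanes are harmless for the dot product, using `head_dim=30` from the question:

```python
import numpy as np

def next_power_of_2(n: int) -> int:
    # Plays the role of triton.next_power_of_2: smallest power of 2 >= n.
    return 1 << (n - 1).bit_length()

head_dim = 30                         # not a multiple of 16/32
BLOCK_D = next_power_of_2(head_dim)   # 32: a legal tl.arange length

rng = np.random.default_rng(0)
q_row = rng.random(head_dim).astype(np.float32)
k_row = rng.random(head_dim).astype(np.float32)

# Emulate tl.load(ptr + offs_d, mask=offs_d < head_dim, other=0.0):
# lanes past head_dim are filled with 0.0 instead of being read.
offs_d = np.arange(BLOCK_D)
mask = offs_d < head_dim
q_blk = np.where(mask, np.pad(q_row, (0, BLOCK_D - head_dim)), 0.0)
k_blk = np.where(mask, np.pad(k_row, (0, BLOCK_D - head_dim)), 0.0)

# The zeroed lanes contribute nothing to q . k, so the attention
# score matches the unpadded computation exactly.
assert np.isclose(q_blk @ k_blk, q_row @ k_row)
```

Stores work the same way in reverse (`tl.store(..., mask=offs_d < head_dim)`), so no garbage is written past the real head dimension. This avoids materializing padded copies of Q, K, V, though the tensor-core tiles still compute over the full 32 lanes, so there is some wasted FLOP relative to a true multiple-of-32 head_dim.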