Add torch scaled dot product attention (FlashAttention) #1798
Conversation
Force-pushed from 1893a2d to 9a2cae2.
Todo:
I think it is to be expected that using such tracers can never be reliable and stable. We don't guarantee that, and we don't need to guarantee that. So this is not really an issue. We should still keep the existing tests. For that, there should be a flag (maybe only an internal flag) to disable this and fall back to the current vanilla implementation.
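For illustration, a minimal sketch of how such an internal flag could look. The names (`use_torch_sdpa`, `dot_attention`, `_naive_dot_attention`) are hypothetical and not the actual RETURNN API; the point is just that tests can flip one switch to compare both code paths.

```python
import torch
import torch.nn.functional as F

# Hypothetical internal flag; tests could set this to False to force the
# vanilla (naive) reference implementation.
use_torch_sdpa = True


def dot_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """q/k/v: [batch, heads, time, head_dim]."""
    if use_torch_sdpa:
        return F.scaled_dot_product_attention(q, k, v)
    return _naive_dot_attention(q, k, v)


def _naive_dot_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Vanilla reference implementation, kept around for testing/comparison."""
    att_weights = torch.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1)
    return att_weights @ v
```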
This was without
It appears that it doesn't use FlashAttention/cuDNN here (those need the dtype to be 16-bit or lower), but rather the memory-efficient kernel.
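For reference, one way to check which kernel PyTorch actually picks is to restrict the allowed backends via `torch.nn.attention.sdpa_kernel` (PyTorch ≥ 2.3; older versions have `torch.backends.cuda.sdp_kernel`). A small sketch, assuming a CUDA GPU that supports FlashAttention:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# [batch, heads, time, head_dim]
q = torch.randn(2, 8, 128, 64, device="cuda")
k = torch.randn(2, 8, 128, 64, device="cuda")
v = torch.randn(2, 8, 128, 64, device="cuda")

# Allow only the FlashAttention backend. With float32 inputs this raises,
# because FlashAttention needs fp16/bf16 -- consistent with the float32
# benchmark falling back to the memory-efficient kernel.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    try:
        F.scaled_dot_product_attention(q, k, v)
    except RuntimeError as exc:
        print("flash backend not usable for float32:", exc)

    # With bf16 inputs, the FlashAttention backend works.
    out = F.scaled_dot_product_attention(q.bfloat16(), k.bfloat16(), v.bfloat16())
    print(out.dtype)  # torch.bfloat16
```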
CUDA, with torch AMP bfloat16, using FlashAttention: so here we see a speedup of 5x over the naive RETURNN implementation in the most extreme case.
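Roughly how such a comparison can be measured (a self-contained sketch, not the actual benchmark script from this PR; shapes and iteration counts are made up):

```python
import time
import torch
import torch.nn.functional as F


def naive_attention(q, k, v):
    w = torch.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1)
    return w @ v


def bench(fn, q, k, v, n=50):
    for _ in range(5):  # warm-up
        fn(q, k, v)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(n):
        fn(q, k, v)
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / n


# [batch, heads, time, head_dim]
q, k, v = (torch.randn(16, 8, 2048, 64, device="cuda") for _ in range(3))

with torch.autocast("cuda", dtype=torch.bfloat16):
    print("naive attention:", bench(naive_attention, q, k, v))
    print("torch sdpa:     ", bench(F.scaled_dot_product_attention, q, k, v))
```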
The old attention code still allows |
But we anyway still keep the naive implementation around, right? We should. For reference. For testing. Also, it should be possible to somehow enable it (maybe similar to how this is controlled in PyTorch itself). And if we anyway have the naive implementation available, is it really so complicated to have this extra check? Maybe there are anyway also other situations where it makes sense to fall back to it? I think I would defer the decision on whether to remove that.
I added it back.
I don't understand. I thought this was the case earlier where the naive RETURNN implementation was faster? Or not? Or what exactly was the case where the naive impl was faster? And do we correctly understand now, in case of AMP with bfloat16, what parts are computed in bfloat16, and what parts stay in float32? E.g. I was not sure about the softmax, maybe also the matmul of attention weights with values. You earlier told me there were some numerical differences, not so small. How big were they? Is this still the case?
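One way to check this empirically, at least for the naive/eager computation (a minimal sketch; it says nothing about what happens inside the fused SDPA kernels):

```python
import torch

# [batch, heads, time, head_dim]
q = torch.randn(2, 8, 128, 64, device="cuda")
k = torch.randn(2, 8, 128, 64, device="cuda")
v = torch.randn(2, 8, 128, 64, device="cuda")

with torch.autocast("cuda", dtype=torch.bfloat16):
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5  # matmul: autocast to bf16
    att_weights = torch.softmax(scores, dim=-1)            # softmax: on autocast's float32 list
    out = att_weights @ v                                  # matmul: autocast to bf16

print(scores.dtype, att_weights.dtype, out.dtype)
```

On recent PyTorch versions this should print bfloat16 for the two matmuls and float32 for the softmax output, since softmax is on autocast's float32 op list, but it is worth confirming on the actual setup.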
When I started this PR, I started a training with this new torch sdpa impl by accident, and observed it was slower than on the master branch. But that was without
I do not know, I would need to do some additional research. I do know that the FlashAttention and Cudnn kernels are only used if the input dtype is bf16 or similar, and torch falls back to the slower "memory efficient" kernel if float32 is passed as input. The float32 results are seen in my first benchmark comment. But I don't know what happens in the kernels.
There was a test failing where single vs. batched results were different, but this turned out to be a bug which I have since fixed.
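The core of such a check, stripped down to plain PyTorch (the actual RETURNN test goes through the frontend API; this is just the idea):

```python
import torch
import torch.nn.functional as F


def test_sdpa_single_vs_batched():
    # Attention computed for the whole batch at once must match
    # attention computed one sequence at a time.
    torch.manual_seed(42)
    q = torch.randn(4, 8, 32, 64)  # [batch, heads, time, head_dim]
    k = torch.randn(4, 8, 32, 64)
    v = torch.randn(4, 8, 32, 64)
    batched = F.scaled_dot_product_attention(q, k, v)
    for b in range(q.shape[0]):
        single = F.scaled_dot_product_attention(q[b : b + 1], k[b : b + 1], v[b : b + 1])
        torch.testing.assert_close(single, batched[b : b + 1])
```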
Hmm, there seems to be another issue: in the code I assert that the query and key tensors share all dimensions (except their spatial dims), but in cross-attention during beam search this isn't actually the case.
Make sure there is also a test case which covers this.
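A sketch of the shape situation in plain PyTorch (not the actual RETURNN test, which would go through the frontend with dim tags): during beam search the decoder query carries a beam dimension that the encoder keys/values don't have, so one option is to expand the keys/values over the beam before calling SDPA.

```python
import torch
import torch.nn.functional as F


def test_cross_attention_with_beam():
    batch, beam, heads, enc_time, head_dim = 2, 4, 8, 50, 64
    # One decoder step per beam hypothesis; encoder keys/values have no beam dim.
    q = torch.randn(batch * beam, heads, 1, head_dim)
    k = torch.randn(batch, heads, enc_time, head_dim)
    v = torch.randn(batch, heads, enc_time, head_dim)

    # Tile keys/values over the beam so all batch dims match the query.
    k_exp = k.repeat_interleave(beam, dim=0)  # [batch * beam, heads, enc_time, head_dim]
    v_exp = v.repeat_interleave(beam, dim=0)
    out = F.scaled_dot_product_attention(q, k_exp, v_exp)
    assert out.shape == (batch * beam, heads, 1, head_dim)
```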
I ran DLM training with both the master branch and this pull request, and it looks very similar. So it appears that the master branch is still faster, but for torch sdpa the computing time % is lower.

The torch profile shows that forward+backward is ~50ms faster (of ~500ms total) than on the master branch, but in the torch sdpa profile a dataloader … I don't know what's going on here. Maybe there is an issue with my profiler (or my interpretation of it), maybe the forward/backward passes are now faster than the dataloader and there is some polling timeout going on, or something else entirely.

Anyway, even without this issue, it would only be ~10% faster, so I am not really sure if it is worth it for me to spend much more time on refining this pull request, and I would rather leave it stale for now. Maybe it will be more useful if we are training models with longer sequences in the future. The only issue left to fix is the one mentioned in the comments above this one, but I don't see the point of even merging this if there is barely any improvement in training time (or a worse one, like in my case).
I think those numbers don't really tell much. The dataset processing time should be totally independent from the SDPA implementation. But the dataset processing time or computing time will very much influence the total time. If the dataset processing time in one case is larger than in the other, that could be e.g. because of dataset file caching, or just totally random. Ignore the first epoch. Compare the second, third or so epoch, and also only if their dataset processing time or computing time is similar. For the Torch profile, also skip the first few steps, for similar reasons.
I don't really understand that. The dataloader …
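On skipping the first few steps in the Torch profile: `torch.profiler` supports this directly via a schedule. A self-contained sketch (dummy model and data just for illustration, assuming a CUDA setup):

```python
import torch
from torch.profiler import ProfilerActivity, profile, schedule

# Dummy model/data so the sketch runs on its own; in practice this would be
# the actual training step and data loader.
model = torch.nn.Linear(64, 64).cuda()
data = [torch.randn(32, 64, device="cuda") for _ in range(30)]

# wait=10 skips the first steps entirely, warmup=2 runs the profiler without
# recording, active=5 records the steady-state steps.
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=schedule(wait=10, warmup=2, active=5),
) as prof:
    for x in data:
        model(x).sum().backward()
        prof.step()

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
```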