
Add torch scaled dot product attention (FlashAttention) #1798

Draft

dorian-K wants to merge 38 commits into master from doriank-sdpa

Conversation

@dorian-K
Contributor

No description provided.


@dorian-K
Contributor Author

dorian-K commented Dec 19, 2025

Todo:

  • tracers that use energy and att_weights are broken now, and there isn't a straightforward way to fix this: even if we automatically fall back to the returnn impl once we detect a tracer, that implementation lives in the backend class and not in rf.dot_attention, so current implementations don't find those variables. As far as I can tell this is only used in attention weight analyses and a test in test_rf_attention.py, so not a huge impact, but still annoying. Plan: add an internal flag to fall back to the vanilla implementation, duplicate some of the tests in test_rf_attention to test both the new and the vanilla implementation, and only test att weights / energy for the vanilla impl
  • write some tests to verify that torch scaled_dot_product_attention produces the same result as the returnn fallback impl
  • convert existing building blocks in attention to use the efficient is_causal=True parameter (see the sketch below)
  • maybe rename scaled_dot_product_attention to just dot_attention (API-incompatible)
  • there have been significant changes to torch scaled_dot_product_attention from version 2.0.0 until now; verify that all of them are compatible with the current impl
  • profile whether it is really faster, and whether we should always use it, or when to use it
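
For illustration, a minimal sketch (not the PR's actual code) of what the is_causal=True fast path corresponds to in plain PyTorch, compared against an explicit lower-triangular mask; the shapes are arbitrary toy values:

```python
import torch
import torch.nn.functional as F

# Toy self-attention inputs (hypothetical sizes): batch=2, heads=4, time T=6, head dim E=8.
q = torch.randn(2, 4, 6, 8)
k = torch.randn(2, 4, 6, 8)
v = torch.randn(2, 4, 6, 8)

# Fast path: the fused kernel applies the causal mask internally.
out_fast = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Equivalent explicit mask: True means "may attend"; lower-triangular for causality.
causal_mask = torch.ones(6, 6, dtype=torch.bool).tril()
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)

# Expect agreement up to kernel-dependent rounding.
print((out_fast - out_masked).abs().max())
```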

@albertz
Member

albertz commented Dec 19, 2025

tracers that use energy and att_weights are broken now, and there isn't a straightforward way to fix because even if we automatically fall back to the returnn impl once we detect a tracer, that implementation is in the backend class and not in rf.dot_attention so current implementations don't find those variables. As far as I can tell this is only used in attention weight analyses and a test in test_rf_attention.py, so not huge impact but still annoying

I think it is to be expected that using such tracers can never be reliable and stable. We don't guarantee that, and we don't need to guarantee that. So this is not really an issue.

We should still keep the existing tests. For that, there should be a flag (maybe only internal flag) to disable this and fall back to the current vanilla implementation.

@dorian-K
Contributor Author

I remember, earlier you told me, the Torch SDPA was actually slower than our implementation. Was that incorrect? Or did you do sth wrong? What did you do wrong?

This was without is_causal=True, and with torch amp. So I have to re-test this now.

Do you know which kernel was actually used by Torch SDPA? Does it use Flash-Attention?

It appears that it doesn't use FlashAttention/cuDNN here (those need the dtype to be 16-bit or lower), but rather the memory-efficient kernel:
https://pytorch.org/blog/accelerated-pytorch-2/
https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
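
One way to check this is to restrict which SDPA backend torch is allowed to use and see whether the call still succeeds; a minimal sketch via the (older) torch.backends.cuda.sdp_kernel context manager, assuming a CUDA device (newer PyTorch versions provide torch.nn.attention.sdpa_kernel for the same purpose):

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(2, 4, 1024, 64, device="cuda", dtype=torch.bfloat16)

# Allow only the FlashAttention kernel; if it cannot handle these inputs
# (e.g. a float32 dtype), the call raises instead of silently falling back
# to the memory-efficient or math kernel.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)
```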

@dorian-K
Contributor Author

dorian-K commented Jan 23, 2026

CUDA, with torch amp bfloat16, using flash attention:

----------------------------------------------------------------------------------------------- benchmark 'flashatt_causal_self_attention(cuda, seq_len=1024)': 2 tests ------------------------------------------------------------------------------------------------
Name (time in us)                                                                              Min                   Max                  Mean             StdDev                Median                IQR            Outliers         OPS            Rounds  Iterations
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_benchmark_flashatt_causal_self_attention[seq_len=1024-sdpa='torch_sdpa']             289.2460 (1.0)      1,289.8261 (1.0)        304.7318 (1.0)      17.9428 (1.0)        303.8545 (1.0)       8.3747 (1.0)         14;31  3,281.5742 (1.0)        3418           1
test_benchmark_flashatt_causal_self_attention[seq_len=1024-sdpa='returnn_fallback']     1,612.8628 (5.58)     2,196.4619 (1.70)     1,633.6496 (5.36)     24.3056 (1.35)     1,632.0341 (5.37)     13.0077 (1.55)          2;2    612.1264 (0.19)        625           1
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

--------------------------------------------------------------------------------------------- benchmark 'flashatt_causal_self_attention(cuda, seq_len=32)': 2 tests ----------------------------------------------------------------------------------------------
Name (time in us)                                                                          Min                   Max                Mean             StdDev              Median                IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_benchmark_flashatt_causal_self_attention[seq_len=32-sdpa='torch_sdpa']           293.3440 (1.0)        346.5691 (1.0)      306.5958 (1.0)       6.0141 (1.0)      306.0140 (1.0)       8.5917 (1.0)       1123;18        3.2616 (1.0)        3395           1
test_benchmark_flashatt_causal_self_attention[seq_len=32-sdpa='returnn_fallback']     727.2561 (2.48)     1,293.1759 (3.73)     767.7142 (2.50)     22.9645 (3.82)     766.9624 (2.51)     21.9559 (2.56)        172;5        1.3026 (0.40)       1388           1
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

------------------------------------------------------------------------------------------------ benchmark 'flashatt_self_attention(cuda, seq_len=1024)': 2 tests -----------------------------------------------------------------------------------------------
Name (time in us)                                                                       Min                   Max                  Mean             StdDev                Median                IQR            Outliers         OPS            Rounds  Iterations
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_benchmark_flashatt_self_attention[seq_len=1024-sdpa='torch_sdpa']             313.6480 (1.0)        463.1060 (1.0)        329.7658 (1.0)       7.7020 (1.0)        328.9190 (1.0)      10.5654 (1.21)       991;18  3,032.4549 (1.0)        3204           1
test_benchmark_flashatt_self_attention[seq_len=1024-sdpa='returnn_fallback']     1,221.0950 (3.89)     1,458.7201 (3.15)     1,233.6838 (3.74)     10.2281 (1.33)     1,232.5840 (3.75)      8.7003 (1.0)         64;10    810.5804 (0.27)        819           1
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

---------------------------------------------------------------------------------------------- benchmark 'flashatt_self_attention(cuda, seq_len=32)': 2 tests ---------------------------------------------------------------------------------------------
Name (time in us)                                                                   Min                   Max                Mean             StdDev              Median                IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_benchmark_flashatt_self_attention[seq_len=32-sdpa='torch_sdpa']           311.6240 (1.0)        448.6521 (1.0)      328.9464 (1.0)       7.3794 (1.0)      328.2240 (1.0)      10.0214 (1.0)       1007;12        3.0400 (1.0)        3196           1
test_benchmark_flashatt_self_attention[seq_len=32-sdpa='returnn_fallback']     455.9960 (1.46)     1,724.7652 (3.84)     480.2450 (1.46)     28.2483 (3.83)     479.0064 (1.46)     13.5989 (1.36)         6;10        2.0823 (0.68)       2218           1
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

So here we see a speedup of 5x over the naive returnn implementation in the most extreme case.
Also, the numbers for torch_sdpa are roughly the same regardless of sequence length, so the speedup may be even bigger but hidden by some fixed overhead.
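
The output above is in pytest-benchmark format; a simplified sketch of how such a comparison can be set up (not the actual test code from this PR; naive_attention is a hypothetical stand-in for the returnn fallback, and a CUDA device plus pytest-benchmark are assumed):

```python
import pytest
import torch
import torch.nn.functional as F


def naive_attention(q, k, v):
    # Stand-in for the returnn fallback: explicit matmul -> softmax -> matmul.
    scores = torch.matmul(q, k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)
    return torch.matmul(torch.softmax(scores, dim=-1), v)


@pytest.mark.parametrize("seq_len", [32, 1024])
@pytest.mark.parametrize("impl", ["torch_sdpa", "naive"])
def test_benchmark_self_attention(benchmark, seq_len, impl):
    q = k = v = torch.randn(8, 4, seq_len, 64, device="cuda", dtype=torch.bfloat16)
    fn = F.scaled_dot_product_attention if impl == "torch_sdpa" else naive_attention

    def run():
        out = fn(q, k, v)
        torch.cuda.synchronize()  # include the actual kernel time, not just the launch
        return out

    benchmark(run)
```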


@dorian-K
Contributor Author

The old attention code still allows att_dropout_broadcast=True; torch scaled_dot_product_attention doesn't support this.
I would suggest removing this functionality from returnn altogether, instead of silently falling back to the naive returnn impl or failing whenever the pytorch backend is used.
What do you think? @albertz

@albertz
Member

albertz commented Jan 23, 2026

The old attention code still allows att_dropout_broadcast=True, pytorch doesn't support this. I would suggest removing this functionality from returnn altogether, instead of silently falling back to the naive returnn impl or failing whenever the pytorch backend is used. What do you think? @albertz

But we anyway still keep the naive implementation around, right? We should. For reference. For testing. Also it should be possible to somehow enable it (maybe similar to how this is controlled in PyTorch itself).

And if we anyway have the naive implementation available, is it really so complicated to have this extra check for att_dropout_broadcast?

Maybe there are anyway also other situations where it makes sense to fall back to it?

I think I would defer the decision on whether to remove att_dropout_broadcast to another separate issue / PR. Yes, I also don't like it too much. But let's discuss the pro/contra arguments separately, and keep it here for now.

@dorian-K
Contributor Author

I added back att_dropout_broadcast.
I am still waiting for the training jobs to be scheduled on the RWTH cluster to get a better idea of the speedup, but I believe the PR itself is feature-complete.

@dorian-K dorian-K marked this pull request as ready for review January 23, 2026 14:19
@dorian-K dorian-K requested review from a team, NeoLegends and albertz as code owners January 23, 2026 14:19
@albertz
Member

albertz commented Jan 23, 2026

CUDA, with torch amp bfloat16, using flash attention: ...
So here we see a speedup of 5x over the naive returnn implementation in the most extreme case.

I don't understand. I thought this was the case earlier where the naive RETURNN implementation was faster? Or not? Or what exactly was the case where the naive impl was faster?

And do we correctly understand now, in case of AMP with bfloat16, what parts are computed in bfloat16, and what parts stay in float32? E.g. I was not sure about the softmax, maybe also the matmul of attention weights with values.

You earlier told me there were some numerical differences, not so small. How big were they? Is this still the case?

@dorian-K
Contributor Author

Or what exactly was the case where the naive impl was faster?

When I started this PR, I accidentally started a training with this new torch sdpa impl and observed that it was slower than on the master branch. But that was without is_causal=True, and there were some other bugs that I have since fixed.
I am still waiting on my RWTH cluster jobs to be scheduled to see whether this is now faster or still slower, but I assume it will be faster based on the results from my previous comments.
In the benchmarks from my previous comments, torch sdpa was always at least as fast as the returnn fallback.

And do we correctly understand now, in case of AMP with bfloat16, what parts are computed in bfloat16, and what parts stay in float32? E.g. I was not sure about the softmax, maybe also the matmul of attention weights with values.

I do not know; I would need to do some additional research. I do know that the FlashAttention and cuDNN kernels are only used if the input dtype is bf16 or similar, and torch falls back to the slower "memory efficient" kernel if float32 is passed as input. The float32 results are the ones in my first benchmark comment. But I don't know what happens inside the kernels.

You earlier told me there were some numerical differences, not so small. How big were they? Is this still the case?

There was a test failing where single vs. batched results were different, but this turned out to be a bug which I have since fixed.
I assume there are still numerical differences, but they are smaller than 1e-5 (as verified by the tests).
I can look into this further, but I don't think it's an issue.
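
For reference, a minimal sketch of this kind of closeness check in float32 (naive_attention is a hypothetical reference implementation, not the actual code in test_rf_attention.py):

```python
import torch
import torch.nn.functional as F


def naive_attention(q, k, v):
    # Reference path: explicit energy -> softmax -> weighted sum over values.
    energy = torch.matmul(q, k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)
    att_weights = torch.softmax(energy, dim=-1)
    return torch.matmul(att_weights, v)


q, k, v = (torch.randn(2, 4, 16, 32) for _ in range(3))
out_sdpa = F.scaled_dot_product_attention(q, k, v)
out_ref = naive_attention(q, k, v)

# In float32 on the same device, the two paths should agree to well below 1e-5.
torch.testing.assert_close(out_sdpa, out_ref, atol=1e-5, rtol=1e-5)
```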

@dorian-K
Contributor Author

dorian-K commented Jan 24, 2026

Hmm, there seems to be another issue: in the code I assert that the query and key tensors share all dimensions (except their spatial dims), but in cross-attention during beam search this isn't actually the case.
I did this because the pytorch documentation specifies that the query spatial dim needs to be at a specific index, so I search for it by assuming that the spatial dim is the only non-batch dim in the query tensor, but that doesn't work if there are multiple.
But I believe this is only important if an attention mask is used or is_causal=True, so I can relax this requirement.
So this is not ready for merge yet.
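
For context: torch expects the time dims as the second-to-last axes, i.e. query (..., L, E) and key/value (..., S, E), and is_causal only has a well-defined meaning when query and key share the same time axis. A minimal cross-attention sketch with different lengths and no mask (hypothetical shapes):

```python
import torch
import torch.nn.functional as F

# Hypothetical cross-attention shapes: batch=2, heads=4, query length L=5, key/value length S=7, head dim E=16.
q = torch.randn(2, 4, 5, 16)
k = torch.randn(2, 4, 7, 16)
v = torch.randn(2, 4, 7, 16)

# No is_causal and no mask here: masking is where the exact position of the
# query time dim matters, which is the requirement being relaxed above.
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([2, 4, 5, 16])
```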

@dorian-K dorian-K marked this pull request as draft January 24, 2026 08:52
@albertz
Member

albertz commented Jan 25, 2026

Hmm there seems to be another issue, in the code I assert that the query and key tensors share all dimensions (except their spatial dims), but in crossattention during beam search this isn't actually the case.

Make sure there is also a test case which covers this.

@dorian-K
Contributor Author

dorian-K commented Jan 30, 2026

I ran DLM training with both the master branch and this pull request, and it looks very similar:

torch sdpa (this pull request):
1 epoch: 44:05 (03:55 of which is dataset load at startup)
Epoch 1: Trained 4589 steps, 0:44:05 elapsed (82.9% computing time, hyps: 46.4%, real: 46.9% padding)
Memory usage (cuda): alloc cur 7.1GB alloc peak 45.6GB reserved cur 66.3GB reserved peak 66.3GB

master branch:
1 epoch: 41:25 (03:00 of which is dataset load at startup)
Epoch 1: Trained 4589 steps, 0:41:25 elapsed (91.6% computing time, hyps: 46.4%, real: 46.9% padding)
Memory usage (cuda): alloc cur 7.1GB alloc peak 47.4GB reserved cur 67.6GB reserved peak 67.6GB

So it appears that the master branch is still faster, but for torch sdpa the computing time % is lower.

The torch profile shows that forward+backward is ~50ms faster (of ~500ms total) than on the master branch, but in the torch sdpa profile a dataloader next() call shows up that takes between 150ms and 250ms in every step, while for the master branch this call is so fast that it is barely visible in the profile.

I don't know what's going on here; maybe there is an issue with my profiler (or my interpretation of it), maybe the forward/backward passes are now faster than the dataloader and there is some polling timeout going on, or something else entirely.

Anyway, even without this issue it would only be 10% faster, so I am not really sure if it is worth it for me to spend much more time on refining this pull request, and I would rather leave it stale for now. Maybe it will be more useful if we are training models with longer sequences in the future.

The only issue left to fix is the one mentioned in the comments above this one, but I don't see the point of even merging this if there is barely any improvement in training time (or a regression, like in my case).

@albertz
Member

albertz commented Jan 30, 2026

I think those numbers don't really tell much. The dataset processing time should be totally independent of the SDPA implementation. But the dataset processing time or computing time will very much influence the total time. If the dataset processing time in one case is larger than in the other, that could be e.g. because of dataset file caching or just totally random.

Ignore the first epoch. Compare the second, third or later epochs, and also only if their dataset processing time or computing time is similar.

For Torch profile, also skip the first few steps, for similar reasons.
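
A minimal torch.profiler sketch along those lines, using a schedule that skips the first steps (run_train_step is a placeholder, not a real function):

```python
import torch
from torch.profiler import ProfilerActivity, profile, schedule

# Skip warm-up/caching effects, then record a few representative steps.
prof_schedule = schedule(skip_first=10, wait=2, warmup=2, active=5, repeat=1)

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], schedule=prof_schedule) as prof:
    for step in range(25):
        run_train_step()  # placeholder for the actual training step
        prof.step()

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
```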

in the torch sdpa profile a dataloader next() call shows up that is between 150ms-250ms in every step, but for the master branch this call is so fast that it is barely visible in the profile.

I don't really understand that. The dataloader next() call is not at all influenced by any change here.
