-
Notifications
You must be signed in to change notification settings - Fork 88
fix: pre-download sage_attention kernel before applying backend, remove pinned fa3 kernel version #578
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
fix: pre-download sage_attention kernel before applying backend, remove pinned fa3 kernel version #578
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,7 +15,7 @@ | |
|
|
||
| import functools | ||
| from collections.abc import Iterable | ||
| from typing import Any, Dict, Optional, Tuple | ||
| from typing import Any, Dict, Optional | ||
|
|
||
| import torch | ||
| from aenum import extend_enum | ||
|
|
@@ -127,7 +127,7 @@ def import_algorithm_packages(self) -> Dict[str, Any]: | |
| Dict[str, Any] | ||
| The algorithm packages. | ||
| """ | ||
| flash_attention_3 = get_kernel("kernels-community/flash-attn3", version="<0.1.0") | ||
| flash_attention_3 = get_kernel("kernels-community/flash-attn3") | ||
| packages = {"flash_attention_3": flash_attention_3} | ||
|
|
||
| if Version(diffusers_version) >= Version("0.35.0.dev0"): | ||
|
|
@@ -221,7 +221,7 @@ def _flash_attention_3( | |
| enable_gqa=enable_gqa, | ||
| ) | ||
| else: | ||
| out, _, *_ = torch.ops.flash_attn_pruna._flash_attn_forward( | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we might need to keep this flexibile depending on the fa3 version we encounter - can we return and check whether the output is a tuple or a tensor and handle it accordingly?
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nevermind i get what you did now, this is great! |
||
| out = torch.ops.flash_attn_pruna._flash_attn_forward( | ||
| q=query, # type: ignore | ||
| k=key, # type: ignore | ||
| v=value, # type: ignore | ||
|
|
@@ -290,7 +290,7 @@ def __torch_function__(self, func, types, args=(), kwargs=None): # noqa: D105 | |
| def _flash_attention3(query, key, value, *, is_causal=False, softmax_scale=None, kernel=None): | ||
| # convert (B, H, S, D) → (B, S, H, D) | ||
| q, k, v = [x.transpose(1, 2).contiguous() for x in (query, key, value)] | ||
| out, _ = torch.ops.flash_attn_pruna._flash_attn_forward(q, k, v, causal=is_causal, softmax_scale=softmax_scale) # type: ignore | ||
| out = torch.ops.flash_attn_pruna._flash_attn_forward(q, k, v, causal=is_causal, softmax_scale=softmax_scale) # type: ignore | ||
| # back to (B, H, S, D) for the rest of the pipeline | ||
| return out.transpose(1, 2) | ||
|
|
||
|
|
@@ -336,9 +336,12 @@ def _flash_attn_forward( | |
| v: torch.Tensor, | ||
| softmax_scale: float | None = None, | ||
| causal: bool = False, | ||
| ) -> Tuple[torch.Tensor, torch.Tensor]: | ||
| out, lse = flash_attn_cuda(q, k, v, softmax_scale=softmax_scale or None, causal=causal, deterministic=False) | ||
| return out, lse.permute(0, 2, 1) # (B,H,S) → (B,S,H) | ||
| ) -> torch.Tensor: | ||
| result = flash_attn_cuda(q, k, v, softmax_scale=softmax_scale or None, causal=causal, deterministic=False) | ||
| # Some kernel builds return (out, lse), others return just out, depending on torch and cuda version | ||
| if isinstance(result, tuple): | ||
| return result[0] | ||
| return result | ||
|
|
||
| @torch.library.register_fake("flash_attn_pruna::_flash_attn_forward") | ||
| def _flash_attn_forward_fake( | ||
|
|
@@ -347,6 +350,5 @@ def _flash_attn_forward_fake( | |
| v: torch.Tensor, | ||
| softmax_scale: float | None = None, | ||
| causal: bool = False, | ||
| ) -> Tuple[torch.Tensor, torch.Tensor]: | ||
| b, s, h, _ = q.shape | ||
| return torch.empty_like(q), q.new_empty((b, s, h)) | ||
| ) -> torch.Tensor: | ||
| return torch.empty_like(q) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I know that this was an important fix at some point, so not sure about removing it. Please wait for @begumcig 's review on this, she tackled this back then