diff --git a/src/pruna/algorithms/flash_attn3.py b/src/pruna/algorithms/flash_attn3.py index 6c8839d6..459fbc16 100644 --- a/src/pruna/algorithms/flash_attn3.py +++ b/src/pruna/algorithms/flash_attn3.py @@ -15,7 +15,7 @@ import functools from collections.abc import Iterable -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional import torch from aenum import extend_enum @@ -127,7 +127,7 @@ def import_algorithm_packages(self) -> Dict[str, Any]: Dict[str, Any] The algorithm packages. """ - flash_attention_3 = get_kernel("kernels-community/flash-attn3", version="<0.1.0") + flash_attention_3 = get_kernel("kernels-community/flash-attn3") packages = {"flash_attention_3": flash_attention_3} if Version(diffusers_version) >= Version("0.35.0.dev0"): @@ -221,7 +221,7 @@ def _flash_attention_3( enable_gqa=enable_gqa, ) else: - out, _, *_ = torch.ops.flash_attn_pruna._flash_attn_forward( + out = torch.ops.flash_attn_pruna._flash_attn_forward( q=query, # type: ignore k=key, # type: ignore v=value, # type: ignore @@ -290,7 +290,7 @@ def __torch_function__(self, func, types, args=(), kwargs=None): # noqa: D105 def _flash_attention3(query, key, value, *, is_causal=False, softmax_scale=None, kernel=None): # convert (B, H, S, D) → (B, S, H, D) q, k, v = [x.transpose(1, 2).contiguous() for x in (query, key, value)] - out, _ = torch.ops.flash_attn_pruna._flash_attn_forward(q, k, v, causal=is_causal, softmax_scale=softmax_scale) # type: ignore + out = torch.ops.flash_attn_pruna._flash_attn_forward(q, k, v, causal=is_causal, softmax_scale=softmax_scale) # type: ignore # back to (B, H, S, D) for the rest of the pipeline return out.transpose(1, 2) @@ -336,9 +336,12 @@ def _flash_attn_forward( v: torch.Tensor, softmax_scale: float | None = None, causal: bool = False, - ) -> Tuple[torch.Tensor, torch.Tensor]: - out, lse = flash_attn_cuda(q, k, v, softmax_scale=softmax_scale or None, causal=causal, deterministic=False) - return out, lse.permute(0, 2, 1) # (B,H,S) → (B,S,H) + ) -> torch.Tensor: + result = flash_attn_cuda(q, k, v, softmax_scale=softmax_scale or None, causal=causal, deterministic=False) + # Some kernel builds return (out, lse), others return just out, depending on torch and cuda version + if isinstance(result, tuple): + return result[0] + return result @torch.library.register_fake("flash_attn_pruna::_flash_attn_forward") def _flash_attn_forward_fake( @@ -347,6 +350,5 @@ def _flash_attn_forward_fake( v: torch.Tensor, softmax_scale: float | None = None, causal: bool = False, - ) -> Tuple[torch.Tensor, torch.Tensor]: - b, s, h, _ = q.shape - return torch.empty_like(q), q.new_empty((b, s, h)) + ) -> torch.Tensor: + return torch.empty_like(q) diff --git a/src/pruna/algorithms/sage_attn.py b/src/pruna/algorithms/sage_attn.py index 5fa68f83..05d94863 100644 --- a/src/pruna/algorithms/sage_attn.py +++ b/src/pruna/algorithms/sage_attn.py @@ -19,6 +19,7 @@ import torch from diffusers import DiffusionPipeline +from diffusers.models.attention_dispatch import AttentionBackendName, _maybe_download_kernel_for_backend from typing_extensions import cast from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase @@ -95,6 +96,10 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: target_modules = self.get_model_dependent_hyperparameter_defaults(model, smash_config)["target_modules"] target_modules = cast(TARGET_MODULES_TYPE, target_modules) + # Diffusers has two set_attention_backend methods, one for the whole model and one for the submodules + # The submodule-level method does not trigger the download therefore we need to pre-load the kernel once + _maybe_download_kernel_for_backend(AttentionBackendName.SAGE_HUB) + def apply_sage_attn( root_name: str | None, root_nn_module: torch.nn.Module,