Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions src/pruna/algorithms/flash_attn3.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import functools
from collections.abc import Iterable
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional

import torch
from aenum import extend_enum
Expand Down Expand Up @@ -127,7 +127,7 @@ def import_algorithm_packages(self) -> Dict[str, Any]:
Dict[str, Any]
The algorithm packages.
"""
flash_attention_3 = get_kernel("kernels-community/flash-attn3", version="<0.1.0")
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know that this was an important fix at some point, so I'm not sure about removing it. Please wait for @begumcig's review on this; she tackled this back then.

flash_attention_3 = get_kernel("kernels-community/flash-attn3")
packages = {"flash_attention_3": flash_attention_3}

if Version(diffusers_version) >= Version("0.35.0.dev0"):
Expand Down Expand Up @@ -221,7 +221,7 @@ def _flash_attention_3(
enable_gqa=enable_gqa,
)
else:
out, _, *_ = torch.ops.flash_attn_pruna._flash_attn_forward(
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might need to keep this flexible depending on the FA3 version we encounter — can we check whether the returned value is a tuple or a tensor and handle it accordingly?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Never mind — I get what you did now; this is great!

out = torch.ops.flash_attn_pruna._flash_attn_forward(
q=query, # type: ignore
k=key, # type: ignore
v=value, # type: ignore
Expand Down Expand Up @@ -290,7 +290,7 @@ def __torch_function__(self, func, types, args=(), kwargs=None): # noqa: D105
def _flash_attention3(query, key, value, *, is_causal=False, softmax_scale=None, kernel=None):
# convert (B, H, S, D) → (B, S, H, D)
q, k, v = [x.transpose(1, 2).contiguous() for x in (query, key, value)]
out, _ = torch.ops.flash_attn_pruna._flash_attn_forward(q, k, v, causal=is_causal, softmax_scale=softmax_scale) # type: ignore
out = torch.ops.flash_attn_pruna._flash_attn_forward(q, k, v, causal=is_causal, softmax_scale=softmax_scale) # type: ignore
# back to (B, H, S, D) for the rest of the pipeline
return out.transpose(1, 2)

Expand Down Expand Up @@ -336,9 +336,12 @@ def _flash_attn_forward(
v: torch.Tensor,
softmax_scale: float | None = None,
causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
out, lse = flash_attn_cuda(q, k, v, softmax_scale=softmax_scale or None, causal=causal, deterministic=False)
return out, lse.permute(0, 2, 1) # (B,H,S) → (B,S,H)
) -> torch.Tensor:
result = flash_attn_cuda(q, k, v, softmax_scale=softmax_scale or None, causal=causal, deterministic=False)
# Some kernel builds return (out, lse), others return just out, depending on torch and cuda version
if isinstance(result, tuple):
return result[0]
return result

@torch.library.register_fake("flash_attn_pruna::_flash_attn_forward")
def _flash_attn_forward_fake(
Expand All @@ -347,6 +350,5 @@ def _flash_attn_forward_fake(
v: torch.Tensor,
softmax_scale: float | None = None,
causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
b, s, h, _ = q.shape
return torch.empty_like(q), q.new_empty((b, s, h))
) -> torch.Tensor:
return torch.empty_like(q)
5 changes: 5 additions & 0 deletions src/pruna/algorithms/sage_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import torch
from diffusers import DiffusionPipeline
from diffusers.models.attention_dispatch import AttentionBackendName, _maybe_download_kernel_for_backend
from typing_extensions import cast

from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase
Expand Down Expand Up @@ -95,6 +96,10 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
target_modules = self.get_model_dependent_hyperparameter_defaults(model, smash_config)["target_modules"]
target_modules = cast(TARGET_MODULES_TYPE, target_modules)

# Diffusers has two set_attention_backend methods, one for the whole model and one for the submodules
# The submodule-level method does not trigger the download therefore we need to pre-load the kernel once
_maybe_download_kernel_for_backend(AttentionBackendName.SAGE_HUB)

def apply_sage_attn(
root_name: str | None,
root_nn_module: torch.nn.Module,
Expand Down
Loading