28 changes: 11 additions & 17 deletions src/diffusers/models/attention_dispatch.py
@@ -126,7 +126,7 @@


if _CAN_USE_NPU_ATTN:
from torch_npu import npu_fusion_attention
from torch_npu import npu_fusion_attention, _npu_flash_attention_unpad
else:
npu_fusion_attention = None
_npu_flash_attention_unpad = None

@@ -1582,7 +1582,7 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
enable_gqa=enable_gqa,
return_lse=return_lse,
)
out = out.permute(0, 2, 1, 3)
out = out.permute(0, 2, 1, 3).contiguous()
return out


@@ -1602,23 +1602,17 @@ def _native_attention(
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if return_lse:
raise ValueError("Native attention backend does not support setting `return_lse=True`.")
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
out = torch.nn.functional.scaled_dot_product_attention(
query=query,
key=key,
value=value,
attn_mask=attn_mask,
dropout_p=dropout_p,
is_causal=is_causal,
scale=scale,
enable_gqa=enable_gqa,
)
out = out.permute(0, 2, 1, 3)
B, S, N, D = query.shape
# Flatten (B, S, N, D) into the unpadded (total_tokens, num_heads * head_dim) layout the NPU kernel expects.
query = query.view(B * S, N * D)
key = key.view(B * S, N * D)
value = value.view(B * S, N * D)
# Per-sample sequence lengths; every sample here has the full length S.
seq_len = torch.full((B,), S, dtype=torch.int32, device="cpu")
# Use a separate output buffer instead of aliasing `query`.
out = torch.empty_like(query)
_npu_flash_attention_unpad(query, key, value, seq_len, 1 / math.sqrt(D), N, N, out)

out = out.view(B, S, N, D).contiguous()
return out


@_AttentionBackendRegistry.register(
AttentionBackendName._NATIVE_CUDNN,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
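Reviewer note: a minimal sketch of the layout round trip this hunk relies on, using plain torch and made-up sizes. The `_npu_flash_attention_unpad` call mirrors the positional signature used in the diff (not verified against other torch_npu releases) and is guarded so the snippet still runs where torch_npu is unavailable.

```python
import math

import torch

B, S, N, D = 2, 16, 4, 8  # made-up sizes
query = torch.randn(B, S, N, D)
key = torch.randn(B, S, N, D)
value = torch.randn(B, S, N, D)

# Flatten to the unpadded (total_tokens, num_heads * head_dim) layout.
q2d = query.view(B * S, N * D)
k2d = key.view(B * S, N * D)
v2d = value.view(B * S, N * D)
seq_len = torch.full((B,), S, dtype=torch.int32)  # per-sample token counts

try:
    import torch_npu  # only present on Ascend builds

    out = torch.empty_like(q2d)
    # Positional signature as used in the diff above; treat it as an assumption.
    torch_npu._npu_flash_attention_unpad(q2d, k2d, v2d, seq_len, 1 / math.sqrt(D), N, N, out)
except ImportError:
    # CPU fallback so the shape round trip can still be checked without an NPU.
    out = torch.nn.functional.scaled_dot_product_attention(
        query.permute(0, 2, 1, 3), key.permute(0, 2, 1, 3), value.permute(0, 2, 1, 3)
    ).permute(0, 2, 1, 3).reshape(B * S, N * D)

out = out.view(B, S, N, D)
assert out.shape == (B, S, N, D)
```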
37 changes: 15 additions & 22 deletions src/diffusers/models/transformers/transformer_flux.py
@@ -18,6 +18,7 @@

import numpy as np
import torch
import torch_npu
import torch.nn as nn
import torch.nn.functional as F

@@ -130,7 +131,6 @@ def __call__(
) -> torch.Tensor:
if hasattr(self._parallel_config, "context_parallel_config") and \
self._parallel_config.context_parallel_config is not None:

return self._context_parallel_forward(
attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, pre_query, pre_key, cal_q
)
@@ -242,32 +242,25 @@ def _context_parallel_forward(
if image_rotary_emb is not None:
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
key_all = ulysses_preforward(key, group, world_size, B, S_KV_LOCAL, H, D, H_LOCAL)

value_all = _wait_tensor(value_all)
value_all = value_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous()
# Merge the per-rank sequence shards and move batch first: (B, world_size * S_KV_LOCAL, H_LOCAL, D).
value_all = value_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous()
# S = world_size * S_KV_LOCAL, N = H_LOCAL; the same flattened shape is reused for query and key below,
# which assumes S_Q_LOCAL == S_KV_LOCAL.
B, S, N, D = value_all.shape
value_all = value_all.view(B * S, N * D)

query_all = _wait_tensor(query_all)
query_all = query_all.reshape(world_size, S_Q_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous()

query_all = query_all.reshape(world_size, S_Q_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous()
query_all = query_all.view(B * S, N * D)

key_all = _wait_tensor(key_all)
key_all = key_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous()

out = npu_fusion_attention(
query_all,
key_all,
value_all,
H_LOCAL, # num_heads
input_layout="BNSD",
pse=None,
scale=1.0 / math.sqrt(D),
pre_tockens=65536,
next_tockens=65536,
keep_prob=1.0,
sync=False,
inner_precise=0,
)[0]
out = out.transpose(1, 2).contiguous()
key_all = key_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous()
key_all = key_all.view(B * S, N * D)

# Per-sample sequence lengths; every sample carries the full gathered length S.
seq_len = torch.full((B,), S, dtype=torch.int32, device="cpu")
# Use a dedicated output buffer rather than writing the result into `query_all`.
out = torch.empty_like(query_all)
torch_npu._npu_flash_attention_unpad(query_all, key_all, value_all, seq_len, 1 / math.sqrt(D), N, N, out)

out = out.view(B, S, N, D).contiguous()
out = out.reshape(B, world_size, S_Q_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
out = _all_to_all_single(out, group)
hidden_states = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
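Reviewer note: a small shape-only sketch of the Ulysses reshuffling in `_context_parallel_forward`. It replays the reshape/permute chain from this hunk on dummy tensors with made-up sizes, and skips the actual `_all_to_all_single` exchange and the NPU kernel, so it can be run anywhere to sanity-check the bookkeeping.

```python
import torch

# Made-up sizes; the real values come from the sharded Flux inputs.
world_size, B, H, D = 4, 2, 8, 16
H_LOCAL = H // world_size              # heads kept per rank after the all-to-all
S_Q_LOCAL = S_KV_LOCAL = 32            # per-rank sequence shards (equal for Flux)

# What _wait_tensor hands back: one (S_KV_LOCAL, B, H_LOCAL, D) shard per rank.
value_all = torch.randn(world_size, S_KV_LOCAL, B, H_LOCAL, D)

# Merge rank and sequence dims, then move batch first -> (B, S, H_LOCAL, D).
value_all = value_all.flatten(0, 1).permute(1, 0, 2, 3).contiguous()
B_, S, N, D_ = value_all.shape
assert (B_, S, N, D_) == (B, world_size * S_KV_LOCAL, H_LOCAL, D)

# Unpadded 2-D layout handed to the attention kernel.
flat = value_all.view(B * S, N * D)

# After attention: back to (B, S, N, D), then regroup for the reverse all-to-all.
out = flat.view(B, S, N, D)
out = out.reshape(B, world_size, S_Q_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
assert out.shape == (world_size, H_LOCAL, B, S_Q_LOCAL, D)

# _all_to_all_single would exchange dim 0 across ranks here without changing shapes.
hidden_states = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
assert hidden_states.shape == (B, S_Q_LOCAL, H, D)
```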