diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 289c3e82955b..ee43b687ec6a 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -126,7 +126,7 @@
 if _CAN_USE_NPU_ATTN:
-    from torch_npu import npu_fusion_attention
+    from torch_npu import npu_fusion_attention, _npu_flash_attention_unpad
 else:
     npu_fusion_attention = None
@@ -1582,7 +1582,7 @@ def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
         enable_gqa=enable_gqa,
         return_lse=return_lse,
     )
-    out = out.permute(0, 2, 1, 3)
+    out = out.permute(0, 2, 1, 3).contiguous()
     return out
@@ -1602,23 +1602,17 @@ def _native_attention(
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
-    if return_lse:
-        raise ValueError("Native attention backend does not support setting `return_lse=True`.")
-    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
-    out = torch.nn.functional.scaled_dot_product_attention(
-        query=query,
-        key=key,
-        value=value,
-        attn_mask=attn_mask,
-        dropout_p=dropout_p,
-        is_causal=is_causal,
-        scale=scale,
-        enable_gqa=enable_gqa,
-    )
-    out = out.permute(0, 2, 1, 3)
+    B, S, N, D = query.shape
+    query = query.view(B * S, N * D)
+    key = key.view(B * S, N * D)
+    value = value.view(B * S, N * D)
+    seq_len = torch.full((B,), S, dtype=torch.int32, device='cpu')
+    out = query
+    _npu_flash_attention_unpad(query, key, value, seq_len, 1/math.sqrt(D), N, N, out)
+
+    out = out.view(B, S, N, D).contiguous()
     return out
-
 @_AttentionBackendRegistry.register(
     AttentionBackendName._NATIVE_CUDNN,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
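Note on the `_native_attention` hunk above: instead of calling `scaled_dot_product_attention`, the backend now flattens the `(B, S, N, D)` query/key/value into the token-major `(B * S, N * D)` view it hands to `_npu_flash_attention_unpad`, together with a CPU-side `seq_len` tensor holding the per-sample token count (the rewritten path no longer forwards `attn_mask`, `dropout_p`, `is_causal`, or `return_lse`). The sketch below is a CPU-only illustration of that layout and of the reference result the fused kernel is expected to reproduce; the sizes are hypothetical, and the commented NPU call only mirrors the argument order used in the diff, not a verified signature.

```python
import math

import torch
import torch.nn.functional as F

# Hypothetical Flux-like sizes: batch, tokens per sample, heads, head dim.
B, S, N, D = 2, 64, 4, 32
query = torch.randn(B, S, N, D)
key = torch.randn(B, S, N, D)
value = torch.randn(B, S, N, D)

# Layout handed to the unpad kernel: tokens stacked along dim 0, heads folded into
# the feature dim, plus per-sample token counts kept on the host.
q_flat = query.view(B * S, N * D)
k_flat = key.view(B * S, N * D)
v_flat = value.view(B * S, N * D)
seq_len = torch.full((B,), S, dtype=torch.int32, device="cpu")

# The flattening is a pure view, so it round-trips exactly.
assert torch.equal(q_flat.view(B, S, N, D), query)

# Reference output the fused kernel is expected to reproduce:
# heads-first SDPA, then back to (B, S, N, D).
ref = F.scaled_dot_product_attention(
    query.permute(0, 2, 1, 3),
    key.permute(0, 2, 1, 3),
    value.permute(0, 2, 1, 3),
    scale=1.0 / math.sqrt(D),
).permute(0, 2, 1, 3)
assert ref.shape == (B, S, N, D)

# On an Ascend device the patch runs (argument order as in the diff above; the patch
# reuses the query buffer itself as the output buffer):
#     out = query  # i.e. q_flat here
#     _npu_flash_attention_unpad(q_flat, k_flat, v_flat, seq_len, 1 / math.sqrt(D), N, N, out)
#     out = out.view(B, S, N, D).contiguous()
```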
diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 16662e8d8fe8..aa9a36409171 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -18,6 +18,7 @@
 import numpy as np
 import torch
+import torch_npu
 import torch.nn as nn
 import torch.nn.functional as F
@@ -130,7 +131,6 @@ def __call__(
     ) -> torch.Tensor:
         if hasattr(self._parallel_config, "context_parallel_config") and \
             self._parallel_config.context_parallel_config is not None:
-
             return self._context_parallel_forward(
                 attn, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, pre_query, pre_key, cal_q
             )
@@ -242,32 +242,25 @@ def _context_parallel_forward(
         if image_rotary_emb is not None:
             key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
         key_all = ulysses_preforward(key, group, world_size, B, S_KV_LOCAL, H, D, H_LOCAL)
-
+
         value_all = _wait_tensor(value_all)
-        value_all = value_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous()
+        value_all = value_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous()
+        B, S, N, D = value_all.shape
+        value_all = value_all.view(B * S, N * D)
         query_all = _wait_tensor(query_all)
-        query_all = query_all.reshape(world_size, S_Q_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous()
-
+        query_all = query_all.reshape(world_size, S_Q_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous()
+        query_all = query_all.view(B * S, N * D)
         key_all = _wait_tensor(key_all)
-        key_all = key_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).transpose(2,1).contiguous()
-
-        out = npu_fusion_attention(
-            query_all,
-            key_all,
-            value_all,
-            H_LOCAL,  # num_heads
-            input_layout="BNSD",
-            pse=None,
-            scale=1.0 / math.sqrt(D),
-            pre_tockens=65536,
-            next_tockens=65536,
-            keep_prob=1.0,
-            sync=False,
-            inner_precise=0,
-        )[0]
-        out = out.transpose(1, 2).contiguous()
+        key_all = key_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous()
+        key_all = key_all.view(B * S, N * D)
+
+        seq_len = torch.full((B,), S, dtype=torch.int32, device='cpu')
+        out = query_all
+        torch_npu._npu_flash_attention_unpad(query_all, key_all, value_all, seq_len, 1/math.sqrt(D), N, N, out)
+
+        out = out.view(B, S, N, D).contiguous()
         out = out.reshape(B, world_size, S_Q_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
         out = _all_to_all_single(out, group)
         hidden_states = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
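Note on the `_context_parallel_forward` hunk above: each gathered q/k/v buffer is reinterpreted as `(world_size, S_local, B, H_local, D)` and reshuffled into `(B, world_size * S_local, H_local, D)`; the old `transpose(2, 1)` that produced the `BNSD` layout for `npu_fusion_attention` is dropped because the result is instead flattened to the same `(B * S, N * D)` view used in `_native_attention`. The single-process sketch below checks only that reshuffle; the collectives (`ulysses_preforward`, `_wait_tensor`, `_all_to_all_single`) are not involved, and the per-peer chunk layout is an assumption inferred from the `reshape` in the patch.

```python
import torch

# Hypothetical sizes: world size, local sequence length, batch, local heads, head dim.
W, S_LOCAL, B, H_LOCAL, D = 4, 8, 2, 3, 16
full = torch.randn(B, W * S_LOCAL, H_LOCAL, D)  # reference: every token, this rank's heads

# Assumed per-peer layout after the all-to-all: each peer contributes its sequence slice
# as (S_local, B, H_local, D), i.e. what reshape(world_size, S_local, B, H_local, D) implies.
chunks = [full[:, p * S_LOCAL : (p + 1) * S_LOCAL].permute(1, 0, 2, 3) for p in range(W)]
gathered = torch.stack(chunks, dim=0)  # (W, S_local, B, H_local, D)

# The reshuffle from the patch: merge (peer, local position) into the global sequence
# axis, then move batch to the front.
restored = gathered.flatten(0, 1).permute(1, 0, 2, 3).contiguous()  # (B, W * S_local, H_local, D)
assert torch.equal(restored, full)

# From here the patch flattens to the same unpad layout as _native_attention.
B2, S2, N2, D2 = restored.shape
flat = restored.view(B2 * S2, N2 * D2)  # (B * S_full, H_local * D)
assert flat.shape == (B * W * S_LOCAL, H_LOCAL * D)
```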