diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index ab0d7102ee83..9e750425af0c 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -41,6 +41,7 @@
     is_torch_xla_version,
     is_xformers_available,
     is_xformers_version,
+    is_mindie_sd_available,
 )
 from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS, DIFFUSERS_ENABLE_HUB_KERNELS
 
@@ -63,6 +64,7 @@
 _CAN_USE_NPU_ATTN = is_torch_npu_available()
 _CAN_USE_XLA_ATTN = is_torch_xla_available() and is_torch_xla_version(">=", _REQUIRED_XLA_VERSION)
 _CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _REQUIRED_XFORMERS_VERSION)
+_CAN_USE_MINDIESD_ATTN = is_mindie_sd_available()
 
 
 if _CAN_USE_FLASH_ATTN:
@@ -142,6 +144,13 @@
 else:
     xops = None
 
+
+if _CAN_USE_MINDIESD_ATTN:
+    from mindiesd import attention_forward as mindie_sd_attn_forward
+else:
+    mindie_sd_attn_forward = None
+
+
 # Version guard for PyTorch compatibility - custom_op was added in PyTorch 2.4
 if torch.__version__ >= "2.4.0":
     _custom_op = torch.library.custom_op
@@ -215,6 +224,9 @@ class AttentionBackendName(str, Enum):
     # `xformers`
     XFORMERS = "xformers"
 
+    # mindie_sd
+    _MINDIE_SD_LASER = "_mindie_sd_la"
+
 
 class _AttentionBackendRegistry:
     _backends = {}
@@ -470,6 +482,12 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None:
             f"Xformers Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `xformers>={_REQUIRED_XFORMERS_VERSION}`."
         )
 
+    elif backend == AttentionBackendName._MINDIE_SD_LASER:
+        if not _CAN_USE_MINDIESD_ATTN:
+            raise RuntimeError(
+                f"MindIE SD attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `mindiesd`."
+            )
+
 
 @functools.lru_cache(maxsize=128)
 def _prepare_for_flash_attn_or_sage_varlen_without_mask(
@@ -893,6 +911,47 @@
     raise NotImplementedError("Backward pass is not implemented for Sage attention.")
 
 
+def _mindie_sd_laser_attn_forward_op(
+    ctx: torch.autograd.function.FunctionCtx,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    enable_gqa: bool = False,
+    return_lse: bool = False,
+    _save_ctx: bool = True,
+    _parallel_config: Optional["ParallelConfig"] = None,
+):
+    if enable_gqa:
+        raise ValueError("`enable_gqa` is not yet supported for MindIE SD Laser Attention.")
+    if return_lse:
+        raise ValueError("MindIE SD attention backend does not support setting `return_lse=True`.")
+
+    out = mindie_sd_attn_forward(
+        query,
+        key,
+        value,
+        opt_mode="manual",
+        op_type="ascend_laser_attention",
+        layout="BNSD",
+    )
+
+    # out = out.transpose(1, 2).contiguous()
+
+    return out
+
+def _mindie_sd_laser_attn_backward_op(
+    ctx: torch.autograd.function.FunctionCtx,
+    grad_out: torch.Tensor,
+    *args,
+    **kwargs,
+):
+    raise NotImplementedError("Backward pass is not implemented for MindIE SD Laser Attention.")
+
+
 # ===== Context parallel =====
 
 
@@ -2012,3 +2071,47 @@ def _xformers_attention(
     out = out.flatten(2, 3)
 
     return out
+
+
+@_AttentionBackendRegistry.register(
+    AttentionBackendName._MINDIE_SD_LASER,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _mindie_sd_laser_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    dropout_p: float = 0.0,
+    scale: Optional[float] = None,
+    return_lse: bool = False,
+    _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+    if return_lse:
+        raise ValueError("MindIE SD attention backend does not support setting `return_lse=True`.")
+    if _parallel_config is None:
+        # query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
+        out = mindie_sd_attn_forward(
+            query,
+            key,
+            value,
+            opt_mode="manual",
+            op_type="ascend_laser_attention",
+            layout="BNSD",
+        )
+        # out = out.transpose(1, 2).contiguous()
+    else:
+        out = _templated_context_parallel_attention(
+            query,
+            key,
+            value,
+            None,
+            dropout_p,
+            None,
+            scale,
+            None,
+            return_lse,
+            forward_op=_mindie_sd_laser_attn_forward_op,
+            backward_op=_mindie_sd_laser_attn_backward_op,
+            _parallel_config=_parallel_config,
+        )
+    return out
\ No newline at end of file
diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py
index cf77aaee8205..49758d1b2454 100644
--- a/src/diffusers/utils/__init__.py
+++ b/src/diffusers/utils/__init__.py
@@ -122,6 +122,7 @@
     is_wandb_available,
     is_xformers_available,
     is_xformers_version,
+    is_mindie_sd_available,
     requires_backends,
 )
 from .loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video
diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py
index adf8ed8b0694..985bb896fa20 100644
--- a/src/diffusers/utils/import_utils.py
+++ b/src/diffusers/utils/import_utils.py
@@ -229,6 +229,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b
 _aiter_available, _aiter_version = _is_package_available("aiter")
 _kornia_available, _kornia_version = _is_package_available("kornia")
 _nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)
+_mindie_sd_available, _mindie_sd_version = _is_package_available("mindiesd")
 
 
 def is_torch_available():
@@ -414,6 +415,9 @@ def is_aiter_available():
 def is_kornia_available():
     return _kornia_available
 
+def is_mindie_sd_available():
+    return _mindie_sd_available
+
 
 # docstyle-ignore
FLAX_IMPORT_ERROR = """
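
Usage note (not part of the patch): a minimal sketch of how the new backend could be selected once this change lands. It assumes diffusers' existing `set_attention_backend` helper on `ModelMixin`, a working `torch_npu`/`mindiesd` installation, and an illustrative Flux checkpoint; adjust to your environment.

import torch
from diffusers import FluxPipeline

# Load a pipeline in bf16 (the backend is constrained to bf16/fp16) and move it to the NPU device.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("npu")

# The string must match the new enum value AttentionBackendName._MINDIE_SD_LASER ("_mindie_sd_la").
pipe.transformer.set_attention_backend("_mindie_sd_la")

image = pipe("an astronaut riding a horse on the moon").images[0]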