diff --git a/deeplink_ext/internevo_ops/_flash_attention_npu.py b/deeplink_ext/internevo_ops/_flash_attention_npu.py index 37110b9..c68ffd6 100644 --- a/deeplink_ext/internevo_ops/_flash_attention_npu.py +++ b/deeplink_ext/internevo_ops/_flash_attention_npu.py @@ -12,6 +12,43 @@ "flash_attn_varlen_kvpacked_func", ] +# construct a global attention mask for npu +_GLOBAL_ATTN_MASK = None + + +def get_attention_mask(seqlen, causal, window_size): + global _GLOBAL_ATTN_MASK + + if _GLOBAL_ATTN_MASK is not None: + return _GLOBAL_ATTN_MASK + + # causal attention + if causal: + if seqlen > 2048: + _GLOBAL_ATTN_MASK = torch.triu( + torch.ones([2048, 2048], dtype=bool, device=torch.npu.current_device()), + diagonal=1, + ) + else: + _GLOBAL_ATTN_MASK = torch.triu( + torch.ones( + [seqlen, seqlen], dtype=bool, device=torch.npu.current_device() + ), + diagonal=1, + ) + + # sliding window attention + if window_size[0] >= 0 or window_size[1] >= 0: + _GLOBAL_ATTN_MASK = torch.tril( + torch.ones([seqlen, seqlen], dtype=bool, device=torch.npu.current_device()), + diagonal=-((seqlen - 1 if window_size[0] < 0 else window_size[0]) + 1), + ) + torch.triu( + torch.ones([seqlen, seqlen], dtype=bool, device=torch.npu.current_device()), + diagonal=(seqlen - 1 if window_size[1] < 0 else window_size[1]) + 1, + ) + + return _GLOBAL_ATTN_MASK + def flash_attn_func( q, @@ -32,22 +69,15 @@ def flash_attn_func( seqlen_k = k.shape[1] head_num = q.shape[-2] - if seqlen_q == seqlen_k and seqlen_q < 2048 and seqlen_k < 2048: - sparse_mode = 0 - else: - sparse_mode = 2 + assert seqlen_q == seqlen_k, "Npu currently only supports seqlen_q = seqlen_k." + attention_mask = get_attention_mask(seqlen_q, causal, window_size) + sparse_mode = 0 if attention_mask is None or seqlen_q <= 2048 else 4 - seqlen_q = min(seqlen_q, 2048) - seqlen_k = min(seqlen_k, 2048) - - attention_mask = ( - torch.triu( - torch.ones([seqlen_q, seqlen_k], dtype=torch.bool, device=q.device), - diagonal=1, - ) - if causal - else None - ) + pre_tokens = seqlen_q - 1 + next_tokens = 0 + if window_size[0] >= 0 or window_size[1] >= 0: + pre_tokens = seqlen_q - 1 if window_size[0] < 0 else window_size[0] + next_tokens = seqlen_q - 1 if window_size[1] < 0 else window_size[1] out = torch_npu.npu_fusion_attention( q, @@ -58,8 +88,8 @@ def flash_attn_func( atten_mask=attention_mask, scale=softmax_scale, keep_prob=1 - dropout_p, - pre_tockens=seqlen_q, - next_tockens=0, + pre_tockens=pre_tokens, + next_tockens=next_tokens, sparse_mode=sparse_mode, )[0] @@ -89,22 +119,18 @@ def flash_attn_varlen_func( cu_seqlens_q = cu_seqlens_q[1:].tolist() cu_seqlens_k = cu_seqlens_k[1:].tolist() - seqlen_q = min(max_seqlen_q, 2048) - seqlen_k = min(max_seqlen_k, 2048) - - if max_seqlen_q < 2048: - sparse_mode = 0 - else: - sparse_mode = 2 - - attention_mask = ( - torch.triu( - torch.ones([seqlen_q, seqlen_k], dtype=torch.bool, device=q.device), - diagonal=1, - ) - if causal - else None - ) + + assert ( + max_seqlen_q == max_seqlen_k + ), "Npu currently only supports max_seqlen_q = max_seqlen_k." + attention_mask = get_attention_mask(max_seqlen_q, causal, window_size) + sparse_mode = 0 if attention_mask is None or max_seqlen_q <= 2048 else 4 + + pre_tokens = max_seqlen_q - 1 + next_tokens = 0 + if window_size[0] >= 0 or window_size[1] >= 0: + pre_tokens = max_seqlen_q - 1 if window_size[0] < 0 else window_size[0] + next_tokens = max_seqlen_q - 1 if window_size[1] < 0 else window_size[1] out = torch_npu.npu_fusion_attention( q, @@ -114,8 +140,8 @@ def flash_attn_varlen_func( "TND", atten_mask=attention_mask, scale=softmax_scale, - pre_tockens=q.shape[0], # seq_len - next_tockens=0, # 0 + pre_tockens=pre_tokens, + next_tockens=next_tokens, keep_prob=1 - dropout_p, sparse_mode=sparse_mode, actual_seq_qlen=cu_seqlens_q, @@ -143,21 +169,14 @@ def flash_attn_qkvpacked_func( seqlen_qkv = qkv.shape[1] head_num = q.shape[-2] - if seqlen_qkv < 2048: - sparse_mode = 0 - else: - sparse_mode = 2 - - seqlen_qkv = min(qkv.shape[1], 2048) + attention_mask = get_attention_mask(seqlen_qkv, causal, window_size) + sparse_mode = 0 if attention_mask is None or seqlen_qkv <= 2048 else 4 - attention_mask = ( - torch.triu( - torch.ones([seqlen_qkv, seqlen_qkv], dtype=torch.bool, device=q.device), - diagonal=1, - ) - if causal - else None - ) + pre_tokens = seqlen_qkv - 1 + next_tokens = 0 + if window_size[0] >= 0 or window_size[1] >= 0: + pre_tokens = seqlen_qkv - 1 if window_size[0] < 0 else window_size[0] + next_tokens = seqlen_qkv - 1 if window_size[1] < 0 else window_size[1] out = torch_npu.npu_fusion_attention( q, @@ -168,8 +187,8 @@ def flash_attn_qkvpacked_func( atten_mask=attention_mask, scale=softmax_scale, keep_prob=1 - dropout_p, - pre_tockens=seqlen_qkv, - next_tockens=0, + pre_tockens=pre_tokens, + next_tockens=next_tokens, sparse_mode=sparse_mode, )[0] @@ -192,26 +211,19 @@ def flash_attn_kvpacked_func( k = kv[:, :, 0] v = kv[:, :, 1] - s0 = q.shape[1] - s1 = kv.shape[1] + seqlen_q = q.shape[1] + seqlen_kv = kv.shape[1] head_num = q.shape[-2] - if s0 == s1 and s0 < 2048 and s1 < 2048: - sparse_mode = 0 - else: - sparse_mode = 2 - - seqlen_q = min(s0, 2048) - seqlen_k = min(s1, 2048) + assert seqlen_q == seqlen_kv, "Npu currently only supports seqlen_q = seqlen_kv." + attention_mask = get_attention_mask(seqlen_q, causal, window_size) + sparse_mode = 0 if attention_mask is None or seqlen_q <= 2048 else 4 - attention_mask = ( - torch.triu( - torch.ones([seqlen_q, seqlen_k], dtype=torch.bool, device=q.device), - diagonal=1, - ) - if causal - else None - ) + pre_tokens = seqlen_q - 1 + next_tokens = 0 + if window_size[0] >= 0 or window_size[1] >= 0: + pre_tokens = seqlen_q - 1 if window_size[0] < 0 else window_size[0] + next_tokens = seqlen_q - 1 if window_size[1] < 0 else window_size[1] out = torch_npu.npu_fusion_attention( q, @@ -222,8 +234,8 @@ def flash_attn_kvpacked_func( atten_mask=attention_mask, scale=softmax_scale, keep_prob=1 - dropout_p, - pre_tockens=seqlen_k, - next_tockens=0, + pre_tockens=pre_tokens, + next_tockens=next_tokens, sparse_mode=sparse_mode, )[0] @@ -247,32 +259,30 @@ def flash_attn_varlen_qkvpacked_func( q = qkv[:, 0] k = qkv[:, 1] v = qkv[:, 2] - n = q.shape[1] - if max_seqlen > 2048: - sparse_mode = 2 - else: - sparse_mode = 0 + head_num = q.shape[1] + cu_seqlens_q = cu_seqlens[1:].tolist() cu_seqlens_k = cu_seqlens[1:].tolist() - seqlen = min(max_seqlen, 2048) - attention_mask = ( - torch.triu( - torch.ones([seqlen, seqlen], dtype=torch.bool, device=q.device), - diagonal=1, - ) - if causal - else None - ) + + attention_mask = get_attention_mask(max_seqlen, causal, window_size) + sparse_mode = 0 if attention_mask is None or max_seqlen <= 2048 else 4 + + pre_tokens = max_seqlen - 1 + next_tokens = 0 + if window_size[0] >= 0 or window_size[1] >= 0: + pre_tokens = max_seqlen - 1 if window_size[0] < 0 else window_size[0] + next_tokens = max_seqlen - 1 if window_size[1] < 0 else window_size[1] + out = torch_npu.npu_fusion_attention( q, k, v, - n, + head_num, "TND", atten_mask=attention_mask, scale=softmax_scale, - pre_tockens=q.shape[0], # seq_len - next_tockens=0, # 0 + pre_tockens=pre_tokens, + next_tockens=next_tokens, keep_prob=1 - dropout_p, sparse_mode=sparse_mode, actual_seq_qlen=cu_seqlens_q, @@ -300,35 +310,32 @@ def flash_attn_varlen_kvpacked_func( softmax_scale = q.shape[-1] ** (-0.5) k = kv[:, 0] v = kv[:, 1] - n = q.shape[1] + head_num = q.shape[1] cu_seqlens_q = cu_seqlens_q[1:].tolist() cu_seqlens_k = cu_seqlens_k[1:].tolist() - seqlen_q = min(max_seqlen_q, 2048) - seqlen_k = min(max_seqlen_k, 2048) - - if max_seqlen_q > 2048: - sparse_mode = 2 - else: - sparse_mode = 0 - - attention_mask = ( - torch.triu( - torch.ones([seqlen_q, seqlen_k], dtype=torch.bool, device=q.device), - diagonal=1, - ) - if causal - else None - ) + + assert ( + max_seqlen_q == max_seqlen_k + ), "Npu currently only supports max_seqlen_q = max_seqlen_k." + attention_mask = get_attention_mask(max_seqlen_q, causal, window_size) + sparse_mode = 0 if attention_mask is None or max_seqlen_q <= 2048 else 4 + + pre_tokens = max_seqlen_q - 1 + next_tokens = 0 + if window_size[0] >= 0 or window_size[1] >= 0: + pre_tokens = max_seqlen_q - 1 if window_size[0] < 0 else window_size[0] + next_tokens = max_seqlen_k - 1 if window_size[1] < 0 else window_size[1] + out = torch_npu.npu_fusion_attention( q, k, v, - n, + head_num, "TND", atten_mask=attention_mask, scale=softmax_scale, - pre_tockens=q.shape[0], # seq_len - next_tockens=0, # 0 + pre_tockens=pre_tokens, + next_tockens=next_tokens, keep_prob=1 - dropout_p, sparse_mode=sparse_mode, actual_seq_qlen=cu_seqlens_q, diff --git a/deeplink_ext/internevo_ops/_rotary_embedding_npu.py b/deeplink_ext/internevo_ops/_rotary_embedding_npu.py index 4e27d04..ba2237f 100644 --- a/deeplink_ext/internevo_ops/_rotary_embedding_npu.py +++ b/deeplink_ext/internevo_ops/_rotary_embedding_npu.py @@ -1,8 +1,8 @@ # Copyright (c) 2024, DeepLink. import torch -import torch_npu -from einops import rearrange +from einops import repeat +from mindspeed.ops.npu_rotary_position_embedding import npu_rotary_position_embedding __all__ = ["ApplyRotaryEmb"] @@ -38,38 +38,73 @@ def forward( assert seqlen <= rotary_seqlen assert sin.shape == (rotary_seqlen, rotary_dim // 2) - re_cos = rearrange(cos[:seqlen], "s d -> s 1 d") - re_sin = rearrange(sin[:seqlen], "s d -> s 1 d") - - cat_cos = torch.cat([re_cos, re_cos], -1) - cat_sin = torch.cat([re_sin, re_sin], -1) + if interleaved: + cos = repeat(cos[:seqlen], "... d -> 1 ... 1 (d 2)") + sin = repeat(sin[:seqlen], "... d -> 1 ... 1 (d 2)") + else: + cos = repeat(cos[:seqlen], "... d -> 1 ... 1 (2 d)") + sin = repeat(sin[:seqlen], "... d -> 1 ... 1 (2 d)") - rot = torch_npu.npu_rotary_mul(x[..., :rotary_dim], cat_cos, cat_sin) - ctx.save_for_backward(cat_cos, cat_sin) + ctx.save_for_backward(cos, sin) ctx.interleaved = interleaved ctx.in_place = in_place - if in_place: - x[..., :rotary_dim].copy_(rot) - return x + + if interleaved: + x_ro = x[..., :rotary_dim] + out_ro = npu_rotary_position_embedding(x_ro, cos, sin, 1) + if in_place: + x[..., :rotary_dim].copy_(out_ro) + return x + if rotary_dim < head_dim: + out = torch.empty_like(x) + out[..., :rotary_dim].copy_(out_ro) + out[..., rotary_dim:].copy_(x[..., rotary_dim:]) + return out + return out_ro else: - out = x.detach().clone() - if rotary_dim < head_dim and not in_place: + x_ro = x[..., :rotary_dim] + out_ro = npu_rotary_position_embedding(x_ro, cos, sin, 0) + if in_place: + x[..., :rotary_dim].copy_(out_ro) + return x + if rotary_dim < head_dim: + out = torch.empty_like(x) + out[..., :rotary_dim].copy_(out_ro) out[..., rotary_dim:].copy_(x[..., rotary_dim:]) - return out + return out + return out_ro @staticmethod - def backward(ctx, do): - cat_cos, cat_sin = ctx.saved_tensors - *_, seqlen, _, head_dim = do.shape - rotary_dim = cat_cos.shape[-1] + def backward(ctx, grad_out): + cos, sin = ctx.saved_tensors + rotary_dim = cos.shape[-1] + head_dim = grad_out.shape[-1] - dx_out = torch_npu.npu_rotary_mul( - do[..., :rotary_dim], cat_cos, torch.neg(cat_sin) - ) - if ctx.in_place: - do[..., :rotary_dim].copy_(dx_out) - return do, None, None, None, None + if ctx.interleaved: + grad_out_ro = grad_out[..., :rotary_dim] + grad_input_ro = npu_rotary_position_embedding( + grad_out_ro, cos, torch.neg(sin), 1 + ) + if ctx.in_place: + grad_out[..., :rotary_dim].copy_(grad_input_ro) + return grad_out, None, None, None, None + if rotary_dim < head_dim: + grad_input = torch.empty_like(grad_out) + grad_input[..., :rotary_dim].copy_(grad_input_ro) + grad_input[..., rotary_dim:].copy_(grad_out[..., rotary_dim:]) + return grad_input, None, None, None, None + return grad_input_ro, None, None, None, None else: - dx = do.detach().clone() - dx[..., :rotary_dim].copy_(dx_out) - return dx, None, None, None, None + grad_out_ro = grad_out[..., :rotary_dim] + grad_input_ro = npu_rotary_position_embedding( + grad_out_ro, cos, torch.neg(sin), 0 + ) + if ctx.in_place: + grad_out[..., :rotary_dim].copy_(grad_input_ro) + return grad_out, None, None, None, None + if rotary_dim < head_dim: + grad_input = torch.empty_like(grad_out) + grad_input[..., :rotary_dim].copy_(grad_input_ro) + grad_input[..., rotary_dim:].copy_(grad_out[..., rotary_dim:]) + return grad_input, None, None, None, None + return grad_input_ro, None, None, None, None diff --git a/deeplink_ext/internevo_ops/rotary_embedding.py b/deeplink_ext/internevo_ops/rotary_embedding.py index 1a2a36d..7764b9b 100644 --- a/deeplink_ext/internevo_ops/rotary_embedding.py +++ b/deeplink_ext/internevo_ops/rotary_embedding.py @@ -4,8 +4,7 @@ platform_type = deeplink_ext_get_platform_type() if platform_type == PlatformType.TORCH_NPU: - # from ._rotary_embedding_npu import ApplyRotaryEmb - from .rotary_embedding_fallback import ApplyRotaryEmbTorch as ApplyRotaryEmb + from ._rotary_embedding_npu import ApplyRotaryEmb elif platform_type == PlatformType.TORCH_DIPU: from ._rotary_embedding_dipu import ApplyRotaryEmb else: diff --git a/deeplink_ext/interntrain_ops/rms_norm.py b/deeplink_ext/interntrain_ops/rms_norm.py index 301ab9e..e6834cb 100644 --- a/deeplink_ext/interntrain_ops/rms_norm.py +++ b/deeplink_ext/interntrain_ops/rms_norm.py @@ -4,13 +4,9 @@ platform_type = deeplink_ext_get_platform_type() if platform_type == PlatformType.TORCH_NPU: - # from ._mixed_rms_norm_npu import MixedFusedRMSNorm - # Due to the accuracy problem of the npu fused operator, a torch combination is used as an alternative. - from .rms_norm_fallback import MixedRMSNormTorch as MixedFusedRMSNorm + from ._mixed_rms_norm_npu import MixedFusedRMSNorm elif platform_type == PlatformType.TORCH_DIPU: - # from ._mixed_rms_norm_dipu import MixedFusedRMSNorm - # Due to the accuracy problem of the npu fused operator, a torch combination is used as an alternative. - from .rms_norm_fallback import MixedRMSNormTorch as MixedFusedRMSNorm + from ._mixed_rms_norm_dipu import MixedFusedRMSNorm else: raise ImportError diff --git a/tests/internevo/test_flash_attention.py b/tests/internevo/test_flash_attention.py index 5126551..b4c4771 100644 --- a/tests/internevo/test_flash_attention.py +++ b/tests/internevo/test_flash_attention.py @@ -14,6 +14,15 @@ flash_attn_func, ) +def clear_global_attn_mask_for_npu(): + # clear the global attention mask set by the latest test case + from deeplink_ext.utils import PlatformType, deeplink_ext_get_platform_type + platform_type = deeplink_ext_get_platform_type() + if platform_type == PlatformType.TORCH_NPU: + import deeplink_ext.internevo_ops._flash_attention_npu + deeplink_ext.internevo_ops._flash_attention_npu._GLOBAL_ATTN_MASK = None + else: + pass def test_flash_attn_qkvpacked_func_mha(): batch, seqlen, num_heads, headdim = [8, 32, 32, 64] @@ -46,6 +55,7 @@ def test_flash_attn_qkvpacked_func_mha(): assert allclose(ouput_forward_cpu, ouput_forward_gpu, rtol=1e-3, atol=1e-3) assert allclose(grads_cpu, grads_gpu, rtol=1e-3, atol=1e-3) + clear_global_attn_mask_for_npu() def test_flash_attn_kvpacked_func_gqa(): @@ -83,6 +93,7 @@ def test_flash_attn_kvpacked_func_gqa(): assert allclose(ouput_forward_cpu, ouput_forward_gpu, rtol=1e-3, atol=1e-3) assert allclose(grads_cpu, grads_gpu, rtol=1e-3, atol=1e-3) + clear_global_attn_mask_for_npu() def test_flash_attn_func_gqa(): @@ -128,3 +139,4 @@ def test_flash_attn_func_gqa(): assert allclose(ouput_forward_cpu, ouput_forward_gpu, rtol=1e-3, atol=1e-3) assert allclose(grads_cpu, grads_gpu, rtol=1e-3, atol=1e-3) + clear_global_attn_mask_for_npu() diff --git a/tests/internevo/test_rotary_embedding.py b/tests/internevo/test_rotary_embedding.py index 981c2f0..a03bc95 100644 --- a/tests/internevo/test_rotary_embedding.py +++ b/tests/internevo/test_rotary_embedding.py @@ -8,40 +8,41 @@ def test_ApplyRotaryEmb(): input_dtype_list = [torch.float16, torch.bfloat16] - interleaved = False in_place_options = [False, True] + interleaved_options = [False, True] for input_dtype in input_dtype_list: for in_place in in_place_options: - input_ref = torch.randn( - 1, 64, 32, 64, dtype=input_dtype, device="cuda", requires_grad=True - ) - input_ext = input_ref.clone().detach().requires_grad_() - cos = torch.randn(64, 32, dtype=input_dtype, device="cuda") - sin = torch.randn(64, 32, dtype=input_dtype, device="cuda") + for interleaved in interleaved_options: + input_ref = torch.randn( + 1, 64, 32, 64, dtype=input_dtype, device="cuda", requires_grad=True + ) + input_ext = input_ref.clone().detach().requires_grad_() + cos = torch.randn(64, 32, dtype=input_dtype, device="cuda") + sin = torch.randn(64, 32, dtype=input_dtype, device="cuda") - output_ref, grad_ref = call_autograd_func( - ApplyRotaryEmbTorch, - "cuda", - input_dtype, - input_ref, - cos, - sin, - interleaved, - in_place, - ) - output_ext, grad_ext = call_autograd_func( - ApplyRotaryEmb, - "cuda", - input_dtype, - input_ext, - cos, - sin, - interleaved, - in_place, - ) - assert allclose( - output_ref, output_ext, rtol=1e-2, atol=5e-2 - ), f"When input dtype is {input_dtype} and in_place is {in_place}, ApplyRotaryEmb fails to pass the forward test!" - assert allclose( - grad_ref, grad_ext - ), f"When input dtype is {input_dtype} and in_place is {in_place}, ApplyRotaryEmb fails to pass the backward test!" + output_ref, grad_ref = call_autograd_func( + ApplyRotaryEmbTorch, + "cuda", + input_dtype, + input_ref, + cos, + sin, + interleaved, + in_place, + ) + output_ext, grad_ext = call_autograd_func( + ApplyRotaryEmb, + "cuda", + input_dtype, + input_ext, + cos, + sin, + interleaved, + in_place, + ) + assert allclose( + output_ref, output_ext + ), f"When input dtype is {input_dtype} and in_place is {in_place}, ApplyRotaryEmb fails to pass the forward test!" + assert allclose( + grad_ref, grad_ext + ), f"When input dtype is {input_dtype} and in_place is {in_place}, ApplyRotaryEmb fails to pass the backward test!" diff --git a/tests/internevo/test_varlen_flash_attention.py b/tests/internevo/test_varlen_flash_attention.py index 97b8d64..d127241 100644 --- a/tests/internevo/test_varlen_flash_attention.py +++ b/tests/internevo/test_varlen_flash_attention.py @@ -14,6 +14,15 @@ flash_attn_varlen_func, ) +def clear_global_attn_mask_for_npu(): + # clear the global attention mask set by the latest test case + from deeplink_ext.utils import PlatformType, deeplink_ext_get_platform_type + platform_type = deeplink_ext_get_platform_type() + if platform_type == PlatformType.TORCH_NPU: + import deeplink_ext.internevo_ops._flash_attention_npu + deeplink_ext.internevo_ops._flash_attention_npu._GLOBAL_ATTN_MASK = None + else: + pass # fmt: off # latest sequence length is 20206-16110=4096 @@ -65,6 +74,7 @@ def test_flash_attn_varlen_qkvpacked_func_mha(): assert allclose(ouput_forward_cpu, ouput_forward_gpu, rtol=1e-5, atol=1e-5) assert allclose(grads_cpu, grads_gpu, rtol=1e-5, atol=1e-2) + clear_global_attn_mask_for_npu() def test_flash_attn_varlen_qkvpacked_func_mha_long_max_seqlen(): @@ -109,6 +119,7 @@ def test_flash_attn_varlen_qkvpacked_func_mha_long_max_seqlen(): assert allclose(ouput_forward_cpu, ouput_forward_gpu, rtol=1e-5, atol=1e-5) assert allclose(grads_cpu, grads_gpu, rtol=1e-5, atol=1e-2) + clear_global_attn_mask_for_npu() def test_flash_attn_varlen_kvpacked_func_gqa(): @@ -165,6 +176,7 @@ def test_flash_attn_varlen_kvpacked_func_gqa(): assert allclose(ouput_forward_cpu, ouput_forward_gpu, rtol=1e-5, atol=1e-5) assert allclose(grads_cpu, grads_gpu, rtol=1e-3, atol=1e-2) + clear_global_attn_mask_for_npu() def test_flash_attn_varlen_kvpacked_func_gqa_long_max_seqlen(): @@ -223,6 +235,7 @@ def test_flash_attn_varlen_kvpacked_func_gqa_long_max_seqlen(): assert allclose(ouput_forward_cpu, ouput_forward_gpu, rtol=1e-5, atol=1e-5) assert allclose(grads_cpu, grads_gpu, rtol=1e-3, atol=1e-2) + clear_global_attn_mask_for_npu() def test_flash_attn_varlen_func_gqa(): @@ -287,6 +300,7 @@ def test_flash_attn_varlen_func_gqa(): assert allclose(ouput_forward_cpu, ouput_forward_gpu, rtol=1e-5, atol=1e-5) assert allclose(grads_cpu, grads_gpu, rtol=1e-5, atol=1e-2) + clear_global_attn_mask_for_npu() def test_flash_attn_varlen_func_gqa_long_max_seqlen(): @@ -353,3 +367,4 @@ def test_flash_attn_varlen_func_gqa_long_max_seqlen(): assert allclose(ouput_forward_cpu, ouput_forward_gpu, rtol=1e-5, atol=1e-5) assert allclose(grads_cpu, grads_gpu, rtol=1e-5, atol=1e-2) + clear_global_attn_mask_for_npu()