@@ -1,7 +1,6 @@
 # Copyright (c) 2024, DeepLink.
 
 import torch
-import torch_npu
 from einops import repeat
 from mindspeed.ops.npu_rotary_position_embedding import npu_rotary_position_embedding
 
@@ -45,9 +44,11 @@ def forward(
         else:
             cos = repeat(cos[:seqlen], "... d -> 1 ... 1 (2 d)")
             sin = repeat(sin[:seqlen], "... d -> 1 ... 1 (2 d)")
+
         ctx.save_for_backward(cos, sin)
         ctx.interleaved = interleaved
         ctx.in_place = in_place
+
         if interleaved:
             x_ro = x[..., :rotary_dim]
             out_ro = npu_rotary_position_embedding(x_ro, cos, sin, 1)
@@ -62,7 +63,7 @@ def forward(
             return out_ro
         else:
             x_ro = x[..., :rotary_dim]
-            out_ro = torch_npu.npu_rotary_mul(x_ro, cos, sin)
+            out_ro = npu_rotary_position_embedding(x_ro, cos, sin, 0)
             if in_place:
                 x[..., :rotary_dim].copy_(out_ro)
                 return x
@@ -78,6 +79,7 @@ def backward(ctx, grad_out):
         cos, sin = ctx.saved_tensors
         rotary_dim = cos.shape[-1]
         head_dim = grad_out.shape[-1]
+
         if ctx.interleaved:
             grad_out_ro = grad_out[..., :rotary_dim]
             grad_input_ro = npu_rotary_position_embedding(
@@ -94,7 +96,9 @@ def backward(ctx, grad_out):
             return grad_input_ro, None, None, None, None
         else:
             grad_out_ro = grad_out[..., :rotary_dim]
-            grad_input_ro = torch_npu.npu_rotary_mul(grad_out_ro, cos, torch.neg(sin))
+            grad_input_ro = npu_rotary_position_embedding(
+                grad_out_ro, cos, torch.neg(sin), 0
+            )
             if ctx.in_place:
                 grad_out[..., :rotary_dim].copy_(grad_input_ro)
                 return grad_out, None, None, None, None
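
For context (not part of the patch): the non-interleaved branches swap torch_npu.npu_rotary_mul for npu_rotary_position_embedding with mode 0, and per the call sites above, mode 1 is used on the interleaved branch. Below is a minimal reference sketch of what the mode-0 path is expected to compute, assuming the standard rotate-half formulation of RoPE; the helper names are illustrative, not from this diff.

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension into two halves and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rope_reference(x, cos, sin):
    # x: (batch, seqlen, nheads, rotary_dim); cos/sin broadcast against x.
    # Both npu_rotary_position_embedding(x, cos, sin, 0) and the removed
    # torch_npu.npu_rotary_mul(x, cos, sin) should agree with this form
    # (an assumption about the kernels, stated here for review purposes).
    return x * cos + rotate_half(x) * sin

This reading also explains the backward hunk: the gradient of a rotation is the inverse rotation, which only flips the sign of the sine term, hence the torch.neg(sin) argument to the same kernel.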