
Commit 5565fbc

Author: wangxing2
Message: optimize
1 parent: ee0baeb

File tree: 1 file changed (+7 −3 lines)

deeplink_ext/internevo_ops/_rotary_embedding_npu.py

Lines changed: 7 additions & 3 deletions
@@ -1,7 +1,6 @@
 # Copyright (c) 2024, DeepLink.
 
 import torch
-import torch_npu
 from einops import repeat
 from mindspeed.ops.npu_rotary_position_embedding import npu_rotary_position_embedding
 
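The change swaps the last torch_npu-specific rotary calls for mindspeed's npu_rotary_position_embedding, so the torch_npu import is no longer needed. For orientation, the rotate-half rotary position embedding that these fused NPU kernels implement can be written in a few lines of plain PyTorch. This is a minimal reference sketch, not part of the commit; rotate_half and rotary_ref are illustrative names:

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rotary_ref(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Rotate-half rotary embedding; the fused NPU kernels
    # (torch_npu.npu_rotary_mul before this commit, mindspeed's
    # npu_rotary_position_embedding after it) compute this in one shot.
    return x * cos + rotate_half(x) * sin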
@@ -45,9 +44,11 @@ def forward(
         else:
             cos = repeat(cos[:seqlen], "... d -> 1 ... 1 (2 d)")
             sin = repeat(sin[:seqlen], "... d -> 1 ... 1 (2 d)")
+
         ctx.save_for_backward(cos, sin)
         ctx.interleaved = interleaved
         ctx.in_place = in_place
+
         if interleaved:
             x_ro = x[..., :rotary_dim]
             out_ro = npu_rotary_position_embedding(x_ro, cos, sin, 1)
@@ -62,7 +63,7 @@ def forward(
             return out_ro
         else:
             x_ro = x[..., :rotary_dim]
-            out_ro = torch_npu.npu_rotary_mul(x_ro, cos, sin)
+            out_ro = npu_rotary_position_embedding(x_ro, cos, sin, 0)
             if in_place:
                 x[..., :rotary_dim].copy_(out_ro)
                 return x
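Both forward branches now go through the same mindspeed kernel; the interleaved branch passes 1 and the non-interleaved branch passes 0, so the trailing integer evidently selects the rotation layout. A hypothetical usage sketch, assuming an Ascend NPU environment with mindspeed installed; the shapes mirror the "1 ... 1 (2 d)" repeat of cos/sin in forward:

import torch
from mindspeed.ops.npu_rotary_position_embedding import npu_rotary_position_embedding

x = torch.randn(2, 128, 8, 64).npu()    # (batch, seqlen, heads, head_dim)
cos = torch.randn(1, 128, 1, 64).npu()  # broadcast over batch and heads
sin = torch.randn(1, 128, 1, 64).npu()

out = npu_rotary_position_embedding(x, cos, sin, 0)      # non-interleaved branch
out_i = npu_rotary_position_embedding(x, cos, sin, 1)    # interleaved branch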
@@ -78,6 +79,7 @@ def backward(ctx, grad_out):
         cos, sin = ctx.saved_tensors
         rotary_dim = cos.shape[-1]
         head_dim = grad_out.shape[-1]
+
         if ctx.interleaved:
             grad_out_ro = grad_out[..., :rotary_dim]
             grad_input_ro = npu_rotary_position_embedding(
@@ -94,7 +96,9 @@ def backward(ctx, grad_out):
             return grad_input_ro, None, None, None, None
         else:
             grad_out_ro = grad_out[..., :rotary_dim]
-            grad_input_ro = torch_npu.npu_rotary_mul(grad_out_ro, cos, torch.neg(sin))
+            grad_input_ro = npu_rotary_position_embedding(
+                grad_out_ro, cos, torch.neg(sin), 0
+            )
             if ctx.in_place:
                 grad_out[..., :rotary_dim].copy_(grad_input_ro)
                 return grad_out, None, None, None, None
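The backward pass reuses the forward kernel with sin negated: a 2-D rotation matrix R(θ) is orthogonal, so the transpose the gradient needs is R(θ)ᵀ = R(−θ), i.e. the same rotation with cos unchanged and sin negated. A quick check against the pure-PyTorch sketch above (rotate_half and rotary_ref come from the first sketch; the cos/sin halves must be duplicated, as the "(2 d)" repeat in forward guarantees, for the identity to hold):

import torch

x = torch.randn(2, 4, 2, 8, requires_grad=True)
cos_h = torch.randn(1, 4, 1, 4)
sin_h = torch.randn(1, 4, 1, 4)
cos = torch.cat([cos_h, cos_h], dim=-1)  # duplicated halves, like the "(2 d)" repeat
sin = torch.cat([sin_h, sin_h], dim=-1)

g = torch.randn(2, 4, 2, 8)
rotary_ref(x, cos, sin).backward(g)
# Autograd's gradient equals the forward rotation applied to g with -sin.
assert torch.allclose(x.grad, rotary_ref(g, cos, -sin), atol=1e-5)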
