Skip to content

Commit 22a1119

Browse files
authored
[bugfix] fix 4d attention mask device (#85)
1 parent e668f39 commit 22a1119

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

src/twinkle/processor/base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,8 +246,9 @@ def _create_4d_attention_mask(attention_mask):
246246
import torch
247247
seq_lens = [s.shape[0] for s in attention_mask]
248248
max_len = max(seq_lens)
249-
attention_mask = torch.tril(torch.ones((len(seq_lens), max_len, max_len),
250-
dtype=torch.bool)).view(len(seq_lens), 1, max_len, max_len)
249+
device = attention_mask[0].device
250+
attention_mask = torch.tril(torch.ones((len(seq_lens), max_len, max_len), dtype=torch.bool,
251+
device=device)).view(len(seq_lens), 1, max_len, max_len)
251252
assert attention_mask.dtype is torch.bool, f'attention_mask.dtype: {attention_mask.dtype}'
252253
for i, seq_len in enumerate(seq_lens):
253254
attention_mask[i, :, :, seq_len:] = 0

0 commit comments

Comments (0)