-
Notifications
You must be signed in to change notification settings - Fork 92
Open
Description
代码 pluto/src/models/pluto/modules/planning_decoder.py的59行附近,
tgt = tgt.transpose(1, 2).reshape(bs * M, R, D)
tgt2 = self.norm1(tgt)
tgt2 = self.r2r_attn(
tgt2, tgt2, tgt2, key_padding_mask=tgt_key_padding_mask.repeat(M, 1)
)[0]
tgt = tgt + self.dropout1(tgt2)
key_padding_mask=tgt_key_padding_mask.repeat(M, 1)的操作可能不符合物理意义,源代码中的操作实际上将Batch复制了M份。
作者实际上应该是想将参考线复制M份,可以写成 tgt_key_padding_mask.unsqueeze(2).repeat(1, 1, M).transpose(1, 2).reshape(B * M, R) ,这样才能和tgt2的物理意义对应。
可以通过下面代码验证:
`
import torch
B = 2
R = 3
M = 2
tgt_key_padding_mask = torch.tensor([
[False, False, True], # sample 0
[False, False, False] # sample 1
])
assert tgt_key_padding_mask.shape[0] == B
assert tgt_key_padding_mask.shape[1] == R
mask1 = tgt_key_padding_mask.repeat(M, 1).reshape(B * M, R)
mask2 = tgt_key_padding_mask.unsqueeze(2).repeat(1, 1, M).transpose(1, 2).reshape(B * M, R)
ret = torch.logical_xor(mask1, mask2)
print(f"{ret}")
期望输出全false, 实际输出不全为false, 所以 mask1和mask2不同
tensor([[False, False, False],
[False, False, True],
[False, False, True],
[False, False, False]])
`
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels