Skip to content

PlanningDecoder潜在可能Bug #42

@HaiYangLib

Description

@HaiYangLib

代码 pluto/src/models/pluto/modules/planning_decoder.py的59行附近,

    tgt = tgt.transpose(1, 2).reshape(bs * M, R, D)
    tgt2 = self.norm1(tgt)
    tgt2 = self.r2r_attn(
        tgt2, tgt2, tgt2, key_padding_mask=tgt_key_padding_mask.repeat(M, 1)
    )[0]
    tgt = tgt + self.dropout1(tgt2)

key_padding_mask=tgt_key_padding_mask.repeat(M, 1)的操作可能不符合物理意义,源代码中的操作实际上将Batch复制了M份。
作者实际上应该是想将参考线复制M份,可以写成 tgt_key_padding_mask.unsqueeze(2).repeat(1, 1, M).transpose(1, 2).reshape(B * M, R) ,这样才能和tgt2的物理意义对应。
可以通过下面代码验证:

`
import torch

B = 2
R = 3
M = 2
tgt_key_padding_mask = torch.tensor([
    [False, False,  True],  # sample 0
    [False, False, False]   # sample 1
])
assert tgt_key_padding_mask.shape[0] == B
assert tgt_key_padding_mask.shape[1] == R

mask1 = tgt_key_padding_mask.repeat(M, 1).reshape(B * M, R)
mask2 = tgt_key_padding_mask.unsqueeze(2).repeat(1, 1, M).transpose(1, 2).reshape(B * M, R)

ret = torch.logical_xor(mask1, mask2)
print(f"{ret}") 
期望输出全false, 实际输出不全为false, 所以 mask1和mask2不同
tensor([[False, False, False],
         [False, False,  True],
         [False, False,  True],
         [False, False, False]])

`

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions