Skip to content

请问模型中的编码他车信息的NAT模型,为什么会引入随机性呢? #30

@Hailan-9

Description

@Hailan-9

麻烦作者大佬,想请问模型中的编码他车信息的NAT模型,为什么会引入随机性呢?我固定了各种随机种子:
random.seed(CUR_SEED)
np.random.seed(CUR_SEED)
torch.manual_seed(CUR_SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
但是因为编码他车的NAT模型的原因(将其输出全部置为全零张量,即: x_agent_tmp = torch.zeros_like(x_agent)
x_agent[:, 1:] = x_agent_tmp[:, 1:]每次训练结果就完全一致了),同样的训练条件下,训练结果不一致,但是我看了NAT的代码(即layers/embedding.py),没发现哪个地方有问题,不太确定from natten import NeighborhoodAttention1D引入的这个模型是不是有一些不受上面固定随机性代码的影响。
非常期待大佬的回复!
@jchengai @Rex-sys-hk

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions