c = p_a.shape[0] reshape_pa = p_a.reshape((c, -1)).permute(1, 0) # c h w -> c p reshape_pb = p_b.reshape((c, -1)).permute(1, 0) 在网络中p_a的维度不应该是(N,C,W,H)么,这里直接做c=p_a.shape感觉有些不太对呀