Hi, I am wondering why you concatenate the feature vectors of all the slices before passing them to gumbel_softmax(), i.e., in the code below
class Tokenizer(nn.Module):
def __init__(self, rep_dim, vocab_size):
super(Tokenizer, self).__init__()
self.center = nn.Linear(rep_dim, vocab_size)
def forward(self, x):
bs, length, dim = x.shape
probs = self.center(x.view(-1, dim))
ret = F.gumbel_softmax(probs)
indexes = ret.max(-1, keepdim=True)[1]
return indexes.view(bs, length)
why is view(-1, dim) called to convert shape from (bs, length, dim) to (bs * length, dim)?