- How does the following `argmax` backpropagate, given that `argmax` itself has no gradient?

```python
gate = self.router(h_fine=h_fine, h_median=h_median, h_coarse=h_coarse, entropy=x_entropy)  # b, h, w, 3 (l m s)
if self.training:  # there may be a gap between training and sampling
    gate = F.gumbel_softmax(gate, tau=1, dim=-1, hard=True)
gate = gate.permute(0, 3, 1, 2)  # b, 3, h, w
indices = gate.argmax(dim=1)  # b, h, w
```
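
For context, `F.gumbel_softmax(..., hard=True)` already implements the straight-through estimator: the returned tensor is one-hot in the forward pass but carries the soft sample's gradient in the backward pass, so the subsequent `argmax` is only used to read off discrete indices, not as the gradient path. Below is a minimal sketch of what `hard=True` does internally (PyTorch computes `y_hard - y_soft.detach() + y_soft`); the tensor names here are illustrative, not from this repository.

```python
import torch
import torch.nn.functional as F

logits = torch.randn(2, 3, requires_grad=True)

# Soft Gumbel-Softmax sample (differentiable).
y_soft = F.gumbel_softmax(logits, tau=1, hard=False, dim=-1)

# hard=True internally builds a one-hot tensor from the soft sample ...
index = y_soft.argmax(dim=-1, keepdim=True)
y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)

# ... and reattaches the soft gradient: one-hot forward, soft backward.
y = y_hard - y_soft.detach() + y_soft

loss = (y * torch.arange(3.0)).sum()
loss.backward()
print(logits.grad)  # non-zero: gradients reach the logits despite the argmax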
- What does the following code mean?

```python
if self.training:
    gate_grad = gate.max(dim=1, keepdim=True)[0]
    gate_grad = gate_grad.repeat_interleave(4, dim=-1).repeat_interleave(4, dim=-2)
    # h_coarse = h_coarse * gate_grad
    # h_median = h_median * gate_grad
    # h_fine = h_fine * gate_grad
    h_triple = h_triple * gate_grad
```
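
One plausible reading: since the hard gate is one-hot along `dim=1`, its per-position max is exactly 1.0, so multiplying `h_triple` by `gate_grad` leaves the forward values unchanged while inserting the gate into the autograd graph, which lets the router receive gradients despite the hard routing decision; the two `repeat_interleave(4, ...)` calls upsample the per-position gate 4x in both spatial dimensions to match the feature-map resolution. A self-contained sketch under those assumptions (shapes and names are illustrative, not from the repository):

```python
import torch
import torch.nn.functional as F

b, h, w = 1, 2, 2
logits = torch.randn(b, h, w, 3, requires_grad=True)  # stand-in for router output

gate = F.gumbel_softmax(logits, tau=1, dim=-1, hard=True)  # b, h, w, 3 (one-hot)
gate = gate.permute(0, 3, 1, 2)                            # b, 3, h, w

gate_grad = gate.max(dim=1, keepdim=True)[0]               # b, 1, h, w; all ones
gate_grad = gate_grad.repeat_interleave(4, dim=-1).repeat_interleave(4, dim=-2)

h_triple = torch.randn(b, 8, 4 * h, 4 * w)  # feature map at 4x spatial resolution
out = h_triple * gate_grad                  # values unchanged, gate attached to graph

out.sum().backward()
print(torch.allclose(out, h_triple))   # True: forward pass is a no-op
print(logits.grad.abs().sum() > 0)     # True: router logits still get gradients
```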