
How to train the Dynamic Grained Coding module? #8


Description

@xuesongnie
1. How does the following `argmax` backpropagate, given that `argmax` itself has no gradient? (A sketch of the relevant behavior follows these questions.)

```python
gate = self.router(h_fine=h_fine, h_median=h_median, h_coarse=h_coarse, entropy=x_entropy)  # b, h, w, 3 (l m s)
if self.training:  # there may be a gap between training and sampling
    gate = F.gumbel_softmax(gate, tau=1, dim=-1, hard=True)
gate = gate.permute(0, 3, 1, 2)  # b, 3, h, w
indices = gate.argmax(dim=1)  # b, h, w
```
2. What does the following code do?

```python
if self.training:
    gate_grad = gate.max(dim=1, keepdim=True)[0]
    gate_grad = gate_grad.repeat_interleave(4, dim=-1).repeat_interleave(4, dim=-2)
    # h_coarse = h_coarse * gate_grad
    # h_median = h_median * gate_grad
    # h_fine = h_fine * gate_grad
    h_triple = h_triple * gate_grad
```
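
For context on both questions: in PyTorch, `F.gumbel_softmax(..., hard=True)` implements the straight-through estimator. The forward pass returns a one-hot gate, but the output is assembled as `y_hard - y_soft.detach() + y_soft`, so the backward pass uses the soft sample's gradient and the router remains trainable; the later `argmax` only extracts routing indices and is never differentiated. Below is a minimal, self-contained sketch of that behavior; the `logits` tensor is a hypothetical stand-in for the router output, not code from this repository.

```python
import torch
import torch.nn.functional as F

# Hypothetical stand-in for the router output: (b, h, w, 3) logits.
logits = torch.randn(1, 2, 2, 3, requires_grad=True)

# hard=True: the forward pass yields a one-hot gate, but it is built as
# y_hard - y_soft.detach() + y_soft, so the backward pass follows the
# soft Gumbel-softmax sample (straight-through estimator).
gate = F.gumbel_softmax(logits, tau=1, dim=-1, hard=True)
print(gate[0, 0, 0])             # one-hot, e.g. tensor([0., 1., 0.])

# argmax on the hard gate just picks the routing index per position;
# it never needs a gradient of its own.
indices = gate.argmax(dim=-1)    # (b, h, w)

# A loss on the hard gate still backpropagates into the logits.
gate.sum().backward()
print(logits.grad.abs().sum() > 0)  # tensor(True): the router is trainable
```

Read in that light, the second snippet appears to be the piece that carries this gradient into the features: `gate.max(dim=1, keepdim=True)[0]` is the selected gate value per position, which equals 1.0 in the forward pass (leaving the features unchanged) but carries the soft gradient backward; `repeat_interleave(4, ...)` upsamples it to the feature resolution, and multiplying it into `h_triple` is what lets the downstream loss update the router during training.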
