Labels: question (Further information is requested)
Description
In INN.BatchNorm1d, the forward function is:
def forward(self, x, log_p=0, log_det_J=0):
    if self.compute_p:
        if not self.training:
            # if in self.eval()
            var = self.running_var  # [dim]
        else:
            # if in training
            # TODO: Do we need to add .detach() after var?
            var = torch.var(x, dim=0, unbiased=False)  # [dim]
        x = super(BatchNorm1d, self).forward(x)
        log_det = -0.5 * torch.log(var + self.eps)
        log_det = torch.sum(log_det, dim=-1)
        return x, log_p, log_det_J + log_det
    else:
        return super(BatchNorm1d, self).forward(x)

Do we need var to carry gradient information? It seems this does not affect the training of BatchNorm1d itself, but rather the training of the modules before it. Are there any references on this?