Inside diff_forward_x0_constraint, three losses are computed: xt_noise, loss, and x0_loss.
The loss term corresponds to L3 in Eq. 21 of the paper. However, the way it is implemented in diff_forward_x0_constraint seems to prevent gradient backpropagation.
Printing loss.grad_fn and loss.requires_grad gives None and False, respectively, which means this term never contributes to network optimization:
mask = spatial_A_trans[id_sim[:, :-1], id_sim[:, 1:]] == 1e-10  # boolean comparison detaches the result from the autograd graph
loss = torch.sum(mask).float()  # summing a bool tensor carries no grad history; .float() cannot restore it
print('loss.grad_fn:', loss.grad_fn)      # None
print('can backward:', loss.requires_grad)  # False
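For reference, one way to make such a constraint differentiable is to apply it to the model's soft predictions rather than to hard indices, since both the integer indexing through id_sim and the == comparison detach the result from the graph. Below is a minimal sketch of that idea, assuming the network exposes per-step logits over the N graph nodes; soft_constraint_loss, logits, and the shapes are my own placeholders, not names from the repo:

import torch

def soft_constraint_loss(logits, spatial_A_trans, eps=1e-10):
    # logits: (B, T, N) unnormalized scores over N nodes per step (assumed)
    probs = torch.softmax(logits, dim=-1)         # differentiable, keeps grad_fn
    # constant 0/1 mask of disallowed transitions (1e-10 sentinel assumed)
    invalid = (spatial_A_trans == eps).float()
    p_prev, p_next = probs[:, :-1], probs[:, 1:]  # consecutive steps
    # expected probability mass on invalid transitions: p_t^T M p_{t+1}
    return torch.einsum('bti,ij,btj->bt', p_prev, invalid, p_next).mean()

Because this penalty is a smooth function of the softmax probabilities, loss.grad_fn stays populated and the term can actually influence optimization.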
Could you kindly let me know if this implementation is indeed problematic, or if I’ve overlooked anything?