
L3 loss in diff_forward_x0_constraint does not back-propagate gradients #9


Description

@HQ-LV

Inside diff_forward_x0_constraint, three losses are computed: xt_noise, loss, and x0_loss.
The loss term corresponds to L3 in Eq. 21 of the paper. However, the way it is implemented in diff_forward_x0_constraint appears to prevent gradients from back-propagating.
Printing loss.grad_fn and loss.requires_grad gives None and False, respectively, which means this term never contributes to network optimization:

```python
mask = spatial_A_trans[id_sim[:, :-1], id_sim[:, 1:]] == 1e-10
loss = torch.sum(mask).float()
print('loss.grad_fn:', loss.grad_fn)        # None
print('can backward:', loss.requires_grad)  # False
```
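
For reference, here is a minimal, self-contained sketch of what I mean. The tensors spatial_A_trans and id_sim are toy stand-ins for the repo's variables (their shapes and values are assumptions), but they reproduce the same behavior: the indexing itself keeps the autograd graph, while the == comparison returns a bool tensor that detaches from it.

```python
import torch

# Toy stand-ins for the repo's tensors; shapes and values are assumptions.
spatial_A_trans = torch.full((5, 5), 1e-10, requires_grad=True)
id_sim = torch.randint(0, 5, (4, 6))

# Advanced indexing is differentiable, so the selected entries still
# carry requires_grad=True.
selected = spatial_A_trans[id_sim[:, :-1], id_sim[:, 1:]]
print(selected.requires_grad)               # True

# The == comparison, however, returns a boolean tensor with no grad_fn,
# so the summed count is a constant with respect to any parameters.
mask = selected == 1e-10
loss = torch.sum(mask).float()
print(loss.grad_fn, loss.requires_grad)     # None False
```

If I understand correctly, a differentiable version of this term would have to penalize the continuous transition values themselves (e.g. the mass placed on near-zero entries) rather than a hard boolean count, but I may be missing the intended formulation.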

Could you kindly let me know whether this implementation is indeed problematic, or whether I have overlooked something?
