diff --git a/examples/supervise/model.py b/examples/supervise/model.py index fe2f432..1b08ca5 100644 --- a/examples/supervise/model.py +++ b/examples/supervise/model.py @@ -87,11 +87,12 @@ def get_loss_dict(self, render_results, batch) -> dict: ssim_loss = 1.0 - ssim(render_results['rgb'], gt_images) loss += (1.0 - self.cfg.lambda_ssim) * L1_loss loss += self.cfg.lambda_ssim * ssim_loss - # normal_loss = 0.1 * l1_loss(render_results['normal'], normal_images) - # loss += normal_loss + normal_loss = 0.1 * l1_loss(render_results['normal'], normal_images) + loss += normal_loss loss_dict = {"loss": loss, "L1_loss": L1_loss, - "ssim_loss": ssim_loss} + "ssim_loss": ssim_loss, + "normal_loss": normal_loss} return loss_dict @torch.no_grad()