-
Notifications
You must be signed in to change notification settings - Fork 13
Open
Description
def get_gt_loss(self, inputs, scale, outputs, disp, depth_gt, mask):
    """Compute the voxel-occupancy supervision loss for one scale.

    Only the volume-depth branch is visible in this snippet. When
    ``self.opt.volume_depth`` is truthy and ``self.opt.l1_voxel`` is not
    ``'No'``, densities in ``outputs[('density_center', 0)]`` are supervised
    towards 1 and densities in ``outputs[('all_empty', 0)]`` towards 0, using
    either an L1 loss (when ``'l1'`` is a substring of ``l1_voxel``) or
    ``self.criterion`` (when ``'ce'`` is a substring of ``l1_voxel``).

    NOTE(review): this snippet appears truncated. ``singel_scale_total_loss``
    is initialised and ``total_grid_loss`` computed, but neither is used or
    returned in the visible code, and ``inputs``/``disp``/``depth_gt``/``mask``
    are unused. Confirm against the full source before relying on the return
    value.
    """
    singel_scale_total_loss = 0
    if self.opt.volume_depth:
        if self.opt.l1_voxel != 'No':
            density_center = outputs[('density_center', 0)]
            label_true = torch.ones_like(density_center, requires_grad=False)
            all_empty = outputs[('all_empty', 0)]
            label_false = torch.zeros_like(all_empty, requires_grad=False)
            if 'l1' in self.opt.l1_voxel:
                # `size_average=True` is deprecated and warns on every call;
                # `reduction='mean'` is its exact replacement.
                surface_loss_true = F.l1_loss(density_center, label_true, reduction='mean')
                surface_loss_false = F.l1_loss(all_empty, label_false, reduction='mean')
                # Empty-space term is scaled by opt.empty_w; occupied term is unweighted.
                total_grid_loss = self.opt.empty_w * surface_loss_false + surface_loss_true
            elif 'ce' in self.opt.l1_voxel:
                # Positives first, then negatives, concatenated along dim 0.
                # Assumes both tensors agree in every trailing dim — TODO confirm.
                label = torch.cat((label_true, label_false))
                pred = torch.cat((density_center, all_empty))
                total_grid_loss = self.criterion(pred, label)
                # NOTE(review): indentation reconstructed from an unindented
                # paste; this log may originally sit outside the 'ce' branch.
                if self.local_rank == 0 and scale == 0:
                    print('ce loss:', total_grid_loss)
Metadata
Metadata
Assignees
Labels
No labels