diff --git a/src/criterions/lovasz_losses.py b/src/criterions/lovasz_losses.py index d83f292..a090f11 100644 --- a/src/criterions/lovasz_losses.py +++ b/src/criterions/lovasz_losses.py @@ -23,7 +23,7 @@ def lovasz_grad(gt_sorted): p = len(gt_sorted) gts = gt_sorted.sum() intersection = gts.float() - gt_sorted.float().cumsum(0) - union = gts.float() + (1 - gt_sorted).float().cumsum(0) + union = gts.float() + (~(gt_sorted)).float().cumsum(0) jaccard = 1. - intersection / union if p > 1: # cover 1-pixel case jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]