
Loss Function

Alex Zhuang edited this page Jun 27, 2022 · 18 revisions

Exploring our loss function

The original loss function in losses.py was unchanged from ACAR's original training on the AVA dataset. It applied softmax activation to the first 13 action class indices and sigmoid activation to the remaining classes, then computed BCE on the result. Our working theory for why this still worked reasonably well for us is that the highest-yield mAP classes (Amber, Green, Red, MovAway, MovTow, etc.) coincidentally have class indices within the first 13 and are more or less mutually exclusive among themselves. The source of truth for our baseline validation set performance is the evaluation of the best checkpoint from July 2021, which can be viewed here.
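A minimal NumPy sketch of the loss described above, written for illustration rather than copied from losses.py. The names `mixed_activation` and `bce_loss`, the clipping constant, and the mean-over-all-entries reduction are assumptions; NumPy is used instead of PyTorch only so the sketch runs standalone. Changing `num_softmax` also reproduces the full-sigmoid (`0`) and full-softmax (all 22 classes) variants tested in the experiments listed below.

```python
import numpy as np

NUM_CLASSES = 22  # number of classes in the per-class AP tables on this page
NUM_SOFTMAX = 13  # leading class indices that the original loss passes through softmax


def mixed_activation(logits, num_softmax=NUM_SOFTMAX):
    """Softmax over the first `num_softmax` classes, sigmoid over the remaining ones.

    num_softmax=0 gives full sigmoid; num_softmax=NUM_CLASSES gives full softmax.
    """
    probs = 1.0 / (1.0 + np.exp(-logits))  # start with sigmoid on every class
    if num_softmax > 0:
        head = logits[:, :num_softmax]
        head = np.exp(head - head.max(axis=1, keepdims=True))  # shift by the row max for stability
        probs[:, :num_softmax] = head / head.sum(axis=1, keepdims=True)
    return probs


def bce_loss(probs, targets, eps=1e-7):
    """Binary cross-entropy averaged over every entry of the (nbatch, nclasses) matrix."""
    probs = np.clip(probs, eps, 1.0 - eps)
    loss = -(targets * np.log(probs) + (1.0 - targets) * np.log(1.0 - probs))
    return loss.mean()


rng = np.random.default_rng(0)
logits = rng.normal(size=(4, NUM_CLASSES))
targets = (rng.random((4, NUM_CLASSES)) < 0.2).astype(float)  # random multi-hot labels
for name, k in [("original (13 softmax)", NUM_SOFTMAX), ("full sigmoid", 0), ("full softmax", NUM_CLASSES)]:
    print(name, bce_loss(mixed_activation(logits, k), targets))
```

Because softmax forces the first 13 probabilities in each row to sum to 1, at most one of those classes can be predicted with high confidence, which is only reasonable if those classes are close to mutually exclusive.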

The experiments I ran fall into the following categories:

  • Testing whether full softmax activation over all classes would still yield acceptable results

    • The results largely confirmed part of the working theory above: softmax does decently on its own even though it is not the correct activation for this multi-label problem. Overall, this experiment does not contribute much to our final loss function design.
  • Testing whether full sigmoid activation over all classes would yield better results

    • Sigmoid BCE gave better results than the original loss, with the best checkpoint adding about half a percentage point to validation mAP@0.5IOU. This run can be viewed here.
  • Weighted sigmoid BCE loss from the Argus++ paper

    • Using the weight formulas from the Argus++ paper as-is, the loss was extremely unstable, since the weight factors spanned a range on the order of 10^4 to 10^5.
    • I tried a couple of tricks to normalize these weights, to no avail. One thing I haven't tried yet is a non-linear transform of the weights, for example a log function, which might stabilize training with this loss.
    • These runs are named weighted_bce_argus on wandb.
    • I've emailed the Argus++ authors to ask for clarification about their implementation details.
  • Various iterations of focal loss

    • At first I tried many different focal loss implementations (e.g. from torchvision and kornia), but all of them yielded worse results. Adjusting the activation functions in those implementations still gave significantly worse results than our original loss.
    • Then I looked at the 3D-RetinaNet implementation of focal loss. One interesting difference is that it divides the per-annotation loss by the number of positive labels in the ground truth. I tried this and it yielded very good results, improving mAP by 3 full percentage points over the original loss.
    • However, while bringing my implementation in line with the one in 3D-RetinaNet, I realized that I was computing the (per-batch) scalar loss differently from previous attempts. Previously, I averaged over every value of the (nbatch, nclasses) loss; now I was summing along the class (2nd) dimension first and then averaging those sums over the batch. This sum-then-average reduction is exactly nclasses times the element-wise mean, so I wonder whether this change by itself, without the division step, would yield results similar to the above. That run is here, in progress.
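A minimal NumPy sketch of a sigmoid focal loss normalized the way the 3D-RetinaNet bullet describes: each annotation's loss is summed over classes, divided by that annotation's number of positive labels, and then averaged over the batch. This is not the 3D-RetinaNet code. The function name is hypothetical, `alpha=0.25` and `gamma=2.0` are the usual defaults from the original focal loss paper, and clamping the positive count to at least 1 (for annotations with no positive labels) is an assumption.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def focal_loss_per_positive(logits, targets, alpha=0.25, gamma=2.0, eps=1e-7):
    """Sigmoid focal loss, summed over classes and divided by each row's positive count.

    logits, targets: arrays of shape (nbatch, nclasses); targets are multi-hot 0/1.
    """
    p = np.clip(sigmoid(logits), eps, 1.0 - eps)
    # p_t is the predicted probability of the ground-truth outcome for each entry.
    p_t = targets * p + (1.0 - targets) * (1.0 - p)
    alpha_t = targets * alpha + (1.0 - targets) * (1.0 - alpha)
    loss = -alpha_t * (1.0 - p_t) ** gamma * np.log(p_t)  # shape (nbatch, nclasses)
    num_pos = np.maximum(targets.sum(axis=1), 1.0)  # assumed guard for rows with no positives
    per_annotation = loss.sum(axis=1) / num_pos  # shape (nbatch,)
    return per_annotation.mean()


rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 22))
targets = (rng.random((4, 22)) < 0.2).astype(float)
print(focal_loss_per_positive(logits, targets))
```

With `gamma=0` and `alpha=0.5` the focal term reduces to half of the plain element-wise BCE, which is a handy sanity check when comparing against an existing BCE implementation.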
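The two reductions in the last bullet differ only by a constant factor. If S is the sum of all entries of a (B, C) loss matrix, averaging every entry gives S / (BC), while summing over classes and then averaging over the batch gives S / B, i.e. exactly C times larger (C = 22 here). For plain SGD, a constant loss scale acts roughly like scaling the learning rate (weight decay aside), so the reduction alone could plausibly change results. A quick NumPy check with random stand-in losses (the helper names are hypothetical):

```python
import numpy as np


def mean_all(loss):
    """Previous reduction: average over every entry of the (nbatch, nclasses) loss."""
    return loss.mean()


def sum_classes_then_mean(loss):
    """New reduction: sum over the class dimension, then average over the batch."""
    return loss.sum(axis=1).mean()


rng = np.random.default_rng(0)
loss = rng.random((8, 22))  # stand-in per-entry losses for 8 annotations and 22 classes
ratio = sum_classes_then_mean(loss) / mean_all(loss)
print(ratio)  # approximately 22.0, i.e. nclasses
```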

Below is a summary of per-class AP scores for the best results we have so far and for the original loss.

Focal Loss

{ 'PascalBoxes_PerformanceByCategory/AP@0.5IOU/Amber': 0.9329179340645092,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/Brake': 0.160397576509561,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/Green': 0.9577322060554998,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/HazLit': 0.42975307114682015,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/IncatLft': 0.3110217925092924,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/IncatRht': 0.07301602875680031,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/Mov': 0.5773879766908123,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/MovAway': 0.7297124687411563,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/MovLft': 0.004658469257813191,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/MovRht': 0.003383006781602239,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/MovTow': 0.7723651745411098,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/Ovtak': 0.012287952821300688,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/PushObj': 0.0,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/Red': 0.9780924818827051,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/Rev': 0.010312371512586824,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/Stop': 0.8843413662541961,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/TurLft': 0.17978844847026057,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/TurRht': 0.16348627821358852,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/Wait2X': 0.3247158732241647,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/Xing': 0.10339383048358453,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/XingFmLft': 0.3546334355739871,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/XingFmRht': 0.3328907435184761,
  'PascalBoxes_Precision/mAP@0.5IOU': 0.3771040221368103}

{ 'PascalBoxes_PerformanceByCategory/AP@0.5IOU/Amber': 0.7942481416306967,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/Brake': 0.24639955720400736,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/Green': 0.9506546852918074,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/HazLit': 0.41751796475143477,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/IncatLft': 0.06207431062104334,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/IncatRht': 0.41807081258402506,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/Mov': 0.5634711371959854,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/MovAway': 0.7826908455607144,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/MovLft': 0.004726756314133784,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/MovRht': 0.002712461208870063,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/MovTow': 0.8215284990833897,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/Ovtak': 0.013405004494272876,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/PushObj': 0.0,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/Red': 0.9382332059083229,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/Rev': 0.007752995854135711,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/Stop': 0.8864995725967224,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/TurLft': 0.16998841573419826,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/TurRht': 0.19452584516811547,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/Wait2X': 0.5265274542624926,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/Xing': 0.1875141620128278,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/XingFmLft': 0.2588773825321833,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/XingFmRht': 0.325866818485017,
  'PascalBoxes_Precision/mAP@0.5IOU': 0.389694819477018}

{ 'PascalBoxes_PerformanceByCategory/AP@0.5IOU/Amber': 0.8309782270702902,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/Brake': 0.27649611843066846,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/Green': 0.9693532638135449,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/HazLit': 0.4638499590312983,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/IncatLft': 0.09518445049536675,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/IncatRht': 0.08144916949245797,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/Mov': 0.5394485341027369,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/MovAway': 0.8053834910728084,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/MovLft': 0.0053379821699517166,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/MovRht': 0.008163022094552693,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/MovTow': 0.8544985703661481,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/Ovtak': 0.04135926309813064,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/PushObj': 0.0,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/Red': 0.9790149529192005,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/Rev': 0.006601620823659995,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/Stop': 0.9063993433799805,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/TurLft': 0.1673016689757366,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/TurRht': 0.20484707920301856,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/Wait2X': 0.6693743086526533,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/Xing': 0.24070367000382953,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/XingFmLft': 0.42607584647759467,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/XingFmRht': 0.36939576431873233,
  'PascalBoxes_Precision/mAP@0.5IOU': 0.40641892299965277}

Original Loss

{ 'PascalBoxes_PerformanceByCategory/AP@0.5IOU/Amber': 0.8679948575380332,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/Brake': 0.26143032368936314,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/Green': 0.8865218399146948,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/HazLit': 0.292303491926961,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/IncatLft': 0.062484558500333215,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/IncatRht': 0.08768266104513127,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/Mov': 0.4808283380743277,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/MovAway': 0.6447430827566635,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/MovLft': 0.00478112958388003,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/MovRht': 0.003082992373510508,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/MovTow': 0.6799842185893896,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/Ovtak': 0.0054205047482294415,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/PushObj': 0.0,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/Red': 0.9588671526563031,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/Rev': 0.014933138413770771,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/Stop': 0.7723645249181748,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/TurLft': 0.07890639981645235,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/TurRht': 0.15995179247571847,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/Wait2X': 0.40660507579460375,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/Xing': 0.2859700155601819,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/XingFmLft': 0.3809936317540863,
  'PascalBoxes_PerformanceByCategory/AP@0.5IOU/XingFmRht': 0.29350611407319116,
  'PascalBoxes_Precision/mAP@0.5IOU': 0.3467889020092273}

Weighting factors calculated using the formulas in the Argus++ paper

These are in the order of the official ROAD dataset class indices.

import torch

# Per-class frequency weights (22 classes, in ROAD class-index order)
freq_weight = torch.tensor([1.4746e-02, 1.0005e-01, 3.0093e-02, 4.5750e-03, 5.2246e-03, 7.4671e-02,
        1.4900e+01, 3.2358e-02, 7.1825e-03, 1.8372e-01, 1.0777e-01, 2.7154e-01,
        1.4700e-01, 9.9196e-02, 2.2968e+00, 2.5265e+00, 7.7790e-01, 5.9937e-02,
        3.7043e-02, 5.3740e-02, 1.2106e-01, 1.4915e-01])

# Per-class pn_weight values (22 classes, same order); these span roughly 2.3 to 1.07e4
pn_weight = torch.tensor([9.5713e+00, 7.0724e+01, 2.0573e+01, 2.2797e+00, 2.7454e+00, 5.2530e+01,
        1.0680e+04, 2.2197e+01, 4.1490e+00, 1.3070e+02, 7.6258e+01, 1.9366e+02,
        1.0438e+02, 7.0112e+01, 1.6455e+03, 1.8102e+03, 5.5666e+02, 4.1968e+01,
        2.5555e+01, 3.7525e+01, 8.5786e+01, 1.0592e+02])
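One way to try the untested log-transform idea for the Argus++ weights is to compress them before use. Below is a minimal NumPy sketch using the pn_weight values listed above. The hypothetical helper `log_compress` maps each weight to 1 + log(w / min(w)), which preserves the ordering of the classes but shrinks the max/min spread from a few thousand to under 10. How Argus++ actually applies these weights inside BCE is not documented on this page, so `weighted_bce`, which simply multiplies each class column of the element-wise BCE by its weight, is also an assumption.

```python
import numpy as np

# pn_weight values copied from above (ROAD class-index order).
pn_weight = np.array([
    9.5713e+00, 7.0724e+01, 2.0573e+01, 2.2797e+00, 2.7454e+00, 5.2530e+01,
    1.0680e+04, 2.2197e+01, 4.1490e+00, 1.3070e+02, 7.6258e+01, 1.9366e+02,
    1.0438e+02, 7.0112e+01, 1.6455e+03, 1.8102e+03, 5.5666e+02, 4.1968e+01,
    2.5555e+01, 3.7525e+01, 8.5786e+01, 1.0592e+02,
])


def log_compress(weights):
    """Hypothetical non-linear rescaling: 1 + log(w / min(w)), so the smallest weight maps to 1."""
    return 1.0 + np.log(weights / weights.min())


def weighted_bce(probs, targets, class_weights, eps=1e-7):
    """Element-wise BCE with each class column scaled by its weight (assumed weighting scheme)."""
    probs = np.clip(probs, eps, 1.0 - eps)
    loss = -(targets * np.log(probs) + (1.0 - targets) * np.log(1.0 - probs))
    return (loss * class_weights).mean()


raw_spread = pn_weight.max() / pn_weight.min()
compressed = log_compress(pn_weight)
compressed_spread = compressed.max() / compressed.min()
print(f"raw max/min: {raw_spread:.0f}, log-compressed max/min: {compressed_spread:.2f}")
```

The `1 +` offset keeps every compressed weight at least 1, so no class is zeroed out; other monotone transforms (e.g. a square root) would trade off spread reduction differently.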
