forked from Mikoto10032/AutomaticWeightedLoss
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathAutomaticWeightedLoss.py
More file actions
31 lines (26 loc) · 842 Bytes
/
AutomaticWeightedLoss.py
File metadata and controls
31 lines (26 loc) · 842 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
class AutomaticWeightedLoss(nn.Module):
    """Automatically weighted multi-task loss.

    Holds one learnable scalar ``params[i]`` per task and combines the task
    losses as::

        sum_i  0.5 / params[i]**2 * loss_i + log(1 + params[i]**2)

    The ``log(1 + params[i]**2)`` term grows with ``|params[i]|`` and is never
    negative. Without it, the optimiser could shrink every task weight
    ``0.5 / params[i]**2`` towards zero simply by letting ``params[i]`` grow.

    Params:
        num: int, the number of losses (tasks) to combine. Every weight
            parameter is initialised to 1.
        x: the per-task losses passed positionally to ``forward``. The i-th
            argument is scaled by ``params[i]``.

    Examples:
        loss1 = 1
        loss2 = 2
        awl = AutomaticWeightedLoss(2)
        loss_sum = awl(loss1, loss2)
    """

    def __init__(self, num=2):
        super().__init__()
        # nn.Parameter already sets requires_grad=True, so the raw tensor
        # does not need it. The attribute name `params` is kept because it
        # is the key under which the weights appear in state_dict().
        self.params = nn.Parameter(torch.ones(num))

    def forward(self, *x):
        """Return the weighted sum of the given task losses.

        Args:
            *x: one loss per task (tensors or Python numbers). At most ``num``
                losses may be passed. Passing more raises IndexError, because
                ``self.params`` has no entry for the extra losses.

        Returns:
            The combined loss. With at least one loss this is a tensor; with
            no losses it is the integer ``0``.
        """
        loss_sum = 0
        for i, loss in enumerate(x):
            sq = self.params[i] ** 2
            loss_sum += 0.5 / sq * loss + torch.log(1 + sq)
        return loss_sum
if __name__ == '__main__':
    awl = AutomaticWeightedLoss(2)
    # Module.parameters() returns a generator. Materialise it so the actual
    # Parameter values are printed instead of "<generator object ...>".
    print(list(awl.parameters()))