forked from abhijitmishra/Thought2Text
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathloss.py
More file actions
42 lines (31 loc) · 1.37 KB
/
loss.py
File metadata and controls
42 lines (31 loc) · 1.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
import torch.nn as nn
import torch.nn.functional as F
class ContrastiveLoss(nn.Module):
    """Supervised contrastive loss over all pairs of two embedding batches.

    Pulls together embedding pairs whose class labels match and pushes
    apart pairs whose labels differ, up to ``margin``.

    Args:
        margin: Minimum desired distance between dissimilar pairs.
    """

    def __init__(self, margin=1.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, E1, E2, labels):
        """Compute the mean contrastive loss over all N*N embedding pairs.

        Args:
            E1: Tensor of shape (N, D) — first batch of embeddings.
            E2: Tensor of shape (N, D) — second batch of embeddings.
            labels: Tensor of shape (N,) of integer class labels.

        Returns:
            Scalar tensor: mean contrastive loss over all pairs.
        """
        # Full (N, N) distance matrix: entry (i, j) is the Euclidean
        # distance between E1[i] and E2[j].  The previous code used
        # F.pairwise_distance, which returns only the (N,) diagonal and
        # broadcast incorrectly against the (N, N) label matrix below,
        # so pair (i, j) was weighted by the distance of pair (j, j).
        distance = torch.cdist(E1, E2)
        # (N, N) matrix: 1.0 where the class labels match, else 0.0.
        same = (labels.unsqueeze(1) == labels.unsqueeze(0)).float()
        # Convention: 0 = similar pair, 1 = dissimilar pair.
        label = 1 - same
        # Similar pairs are penalized by squared distance; dissimilar
        # pairs are penalized only while they sit inside the margin.
        loss_contrastive = torch.mean(
            (1 - label) * torch.pow(distance, 2)
            + label
            * torch.pow(torch.clamp(self.margin - distance, min=0.0), 2)
        )
        return loss_contrastive
class MSELoss(nn.Module):
    """Mean squared error between two embedding batches.

    The ``labels`` argument is accepted but ignored, keeping this module
    call-compatible with ``ContrastiveLoss``.
    """

    def __init__(self):
        super(MSELoss, self).__init__()

    def forward(self, E1, E2, labels=None):
        """Compute the MSE between two same-shaped embedding tensors.

        Args:
            E1: Tensor of embeddings.
            E2: Tensor of embeddings; must match E1's shape.
            labels: Unused; present for interface compatibility.

        Returns:
            Scalar tensor with the mean squared error.

        Raises:
            ValueError: If E1 and E2 have different shapes.
        """
        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, and the original message was an f-string with no
        # placeholders, so it never reported the offending shapes.
        if E1.shape != E2.shape:
            raise ValueError(
                f"Embeddings must have the same shape, got {E1.shape} and {E2.shape}"
            )
        return F.mse_loss(E1, E2)