-
Notifications
You must be signed in to change notification settings - Fork 20
Expand file tree
/
Copy pathwsddn.py
More file actions
72 lines (55 loc) · 1.85 KB
/
wsddn.py
File metadata and controls
72 lines (55 loc) · 1.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import numpy as np
import torch
import torchvision
import torch.utils.data as data
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import torch.nn.functional as F
from torchvision.ops import roi_pool, roi_align
class WSDDN(nn.Module):
    """Skeleton of the WSDDN weakly-supervised detection model.

    The layers and the loss are still unimplemented (see the ``TODO (Q2.1)``
    markers): every sub-module attribute is ``None``, and ``forward`` /
    ``build_loss`` currently return ``None``.

    The class attributes ``classes`` / ``n_classes`` hold 20 default category
    names (the PASCAL VOC label list). Passing ``classes`` to the constructor
    overrides both on the instance.
    """

    n_classes = 20
    classes = np.asarray([
        'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat',
        'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person',
        'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
    ])

    def __init__(self, classes=None):
        """Create the (not yet defined) network.

        :classes: optional sequence of category names; when given, it replaces
            the default class list and ``n_classes`` becomes its length.
        """
        super().__init__()
        if classes is not None:
            self.classes = classes
            self.n_classes = len(classes)
            print(classes)

        # TODO (Q2.1): Define the WSDDN model
        self.features = None
        self.roi_pool = None
        self.classifier = None
        self.score_fc = None
        self.bbox_fc = None

        # loss: set by the most recent training-mode forward(), read via `loss`
        self.cross_entropy = None

    @property
    def loss(self):
        """Loss stored by the last ``forward`` call made in training mode."""
        return self.cross_entropy

    def forward(self,
                image,
                rois=None,
                gt_vec=None,
                ):
        """Score the regions of interest of ``image``.

        :image: input image tensor
        :rois: region proposals for ``image``
        :gt_vec: image-level label vector with ``n_classes`` entries per image
            (e.g. shape ``(n_classes,)`` or ``(1, n_classes)``); required in
            training mode, ignored otherwise
        :returns: cls_prob, an N_roi x n_classes score tensor (currently None)
        :raises ValueError: if called in training mode without ``gt_vec``
        """
        # TODO (Q2.1): Use image and rois as input
        # compute cls_prob which are N_roi X 20 scores
        cls_prob = None

        if self.training:
            if gt_vec is None:
                raise ValueError(
                    "gt_vec is required when the model is in training mode")
            # Reshape to one row of n_classes labels per image (1 x n_classes
            # for a single image), the layout build_loss documents. An
            # n_classes x 1 column would silently broadcast against
            # 1 x n_classes image scores into an n_classes x n_classes tensor.
            label_vec = gt_vec.view(-1, self.n_classes)
            self.cross_entropy = self.build_loss(cls_prob, label_vec)
        return cls_prob

    def build_loss(self, cls_prob, label_vec):
        """Computes the loss

        :cls_prob: N_roi x n_classes output scores
        :label_vec: 1 x n_classes label vector (one row per image)
        :returns: loss (currently None; not yet implemented)
        """
        # TODO (Q2.1): Compute the appropriate loss using the cls_prob
        # that is the output of forward()
        # Checkout forward() to see how it is called
        loss = None
        return loss