Thanks for sharing your code. I have run into some problems when running the experiment.
I use ResNet-20 with expansion=6, the same setting as in the SSL paper. However, I can only reach 89% accuracy with the SSL shift, while the original ResNet-20 with expansion=6 reaches 91%. I don't know what went wrong. Can anyone help me?
import torch.nn as nn
# conv1x1 and Shift2d are defined elsewhere in the repository.

class SSLBlock(nn.Module):
    def __init__(self, inplanes, planes, stride=1, expansion=1, downsample=None, norm_layer=None):
        super(SSLBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self.conv1 = conv1x1(inplanes, int(planes * expansion), bias=False)
        self.bn1 = nn.BatchNorm2d(int(planes * expansion))
        self.conv2 = conv1x1(int(planes * expansion), planes, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        if inplanes == planes and self.stride == 1:
            self.identity = True
        else:
            self.identity = False
        self.shift = Shift2d(in_channels=int(planes * expansion), init_shift=3, sparsity_term=0., active_flag=False)
        if self.stride > 1:
            # self.avg = nn.AvgPool2d(2)
            self.avg = nn.AvgPool2d(kernel_size=2, stride=stride, padding=1)

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        # Shift2d returns (output, loss); the sparsity loss should be added to the
        # total loss! It influences the sparsity but not the gradient update here.
        out, _ = self.shift(out)
        if self.stride > 1:
            out = self.avg(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.identity:
            out += residual
        # out = self.relu(out)
        return out
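For reference, this is roughly how I would fold that shift sparsity loss into the training objective. It is only a sketch: the attribute name shift_loss, the helper total_loss, and the weight sparsity_weight are placeholder names, not code from this repository, and it assumes each block keeps the value returned by its Shift2d instead of discarding it.

# Hypothetical sketch, not repository code.
# In SSLBlock.forward, keep the returned loss instead of discarding it:
#     out, self.shift_loss = self.shift(out)

def total_loss(model, logits, targets, criterion, sparsity_weight=1e-4):
    # Standard classification loss, e.g. criterion = nn.CrossEntropyLoss().
    loss = criterion(logits, targets)
    # Sum the sparsity losses stored by each SSLBlock during the forward pass.
    shift_loss = sum(m.shift_loss for m in model.modules() if isinstance(m, SSLBlock))
    return loss + sparsity_weight * shift_loss

I am not sure whether this term explains the accuracy gap, but it is the one thing the comment in the block explicitly warns about.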