ResNet-20 doesn't reach the expected accuracy on CIFAR-10 #5

@xiexiaona

Thanks for sharing your code. I have run into a problem when running the experiment.
I use ResNet-20 with expansion=6, the same setting as in the SSL paper. However, I only reach 89% accuracy with the SSL shift, while the original ResNet-20 with expansion=6 reaches 91%. I don't know what went wrong. Can anyone help?

class SSLBlock(nn.Module):
    def __init__(self, inplanes, planes, stride=1, expansion=1, downsample=None, norm_layer=None):
        super(SSLBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        # 1x1 conv expands the channels before the shift
        self.conv1 = conv1x1(inplanes, int(planes * expansion), bias=False)
        self.bn1 = norm_layer(int(planes * expansion))

        # 1x1 conv projects back to `planes` channels after the shift
        self.conv2 = conv1x1(int(planes * expansion), planes, bias=False)
        self.bn2 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

        # The skip connection is only used when input and output shapes match
        self.identity = inplanes == planes and self.stride == 1

        self.shift = Shift2d(in_channels=int(planes * expansion), init_shift=3, sparsity_term=0., active_flag=False)
        if self.stride > 1:
            # self.avg = nn.AvgPool2d(2)
            self.avg = nn.AvgPool2d(kernel_size=2, stride=stride, padding=1)

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        # Shift2d returns (output, loss). The loss is discarded here, but it should be added
        # to the total loss: it influences the sparsity, not the gradient update.
        out, _ = self.shift(out)
        if self.stride > 1:
            out = self.avg(out)
        
        out = self.conv2(out)
        out = self.bn2(out)

        if self.identity:
            out += residual

        # out = self.relu(out)
        return out
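
One detail in the block above that may be worth checking is the spatial size produced by the strided pooling layer. The sketch below only applies the output-size formula documented for `nn.AvgPool2d` with the default `ceil_mode=False`. The helper name `pool_output_size` is hypothetical and not part of the repository, and the 32x32 input size is an assumption based on CIFAR-10 images. Whether this difference explains the accuracy gap is not established here.

```python
def pool_output_size(size: int, kernel_size: int, stride: int, padding: int = 0) -> int:
    """Spatial output size of a 2D pooling layer along one dimension.

    Follows the formula floor((size + 2*padding - kernel_size) / stride) + 1,
    i.e. the ceil_mode=False case. Hypothetical helper, for illustration only.
    """
    return (size + 2 * padding - kernel_size) // stride + 1


if __name__ == "__main__":
    for size in (32, 16):
        # Setting used in the posted block: kernel_size=2, stride=2, padding=1
        padded = pool_output_size(size, kernel_size=2, stride=2, padding=1)
        # Commented-out alternative nn.AvgPool2d(2): kernel_size=2, stride=2, padding=0
        plain = pool_output_size(size, kernel_size=2, stride=2, padding=0)
        print(f"input {size}: padding=1 gives {padded}, padding=0 gives {plain}")
```

With `padding=1` a 32x32 feature map becomes 17x17 instead of 16x16, so the downsampling stages do not halve the resolution exactly.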
