How to change number of classes for already trained retinanet model #28

@Monk5088

I have trained my RetinaNet model from the object detection library, and now I want to change the number of classes for the next dataset.
I have found a way to do this in PyTorch by replacing the classification head of RetinaNet. Can anyone help me do the same for retinanet.py from object-detection-fastai?

    import math

    import torch
    from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights

    num_classes = 5  # placeholder value: set this to the number of classes in the new dataset

    weights = RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT
    # RetinaNet's threshold argument is score_thresh (box_score_thresh is the Faster R-CNN name)
    model = retinanet_resnet50_fpn_v2(weights=weights, score_thresh=0.7)

    # replace classification layer
    out_channels = model.head.classification_head.conv[0].out_channels
    num_anchors = model.head.classification_head.num_anchors
    model.head.classification_head.num_classes = num_classes

    cls_logits = torch.nn.Conv2d(out_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
    torch.nn.init.normal_(cls_logits.weight, std=0.01)  # as per pytorch code
    torch.nn.init.constant_(cls_logits.bias, -math.log((1 - 0.01) / 0.01))  # as per pytorch code
    # assign cls head to model
    model.head.classification_head.cls_logits = cls_logits
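
For context on the bias constant in the snippet above: `-math.log((1 - 0.01) / 0.01)` is the value whose sigmoid equals 0.01, so a freshly initialized classification layer predicts roughly a 1% foreground probability for every anchor and class. The sketch below checks that arithmetic with the standard library only. `prior_bias` and `sigmoid` are hypothetical helper names introduced for illustration and are not part of torchvision or object-detection-fastai.

```python
import math


def prior_bias(prior: float = 0.01) -> float:
    """Return the bias b for which sigmoid(b) == prior (hypothetical helper)."""
    return -math.log((1 - prior) / prior)


def sigmoid(x: float) -> float:
    """Plain logistic function, used here only to check the bias value."""
    return 1.0 / (1.0 + math.exp(-x))


bias = prior_bias(0.01)       # same expression as in the snippet above
initial_prob = sigmoid(bias)  # foreground probability implied by that bias

print(f"bias = {bias:.4f}, initial probability = {initial_prob:.4f}")
```

Any replacement classification layer, whichever library builds it, would presumably want the same kind of bias if it is trained with a sigmoid-based loss; whether retinanet.py in object-detection-fastai follows this convention is an assumption I have not verified.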
