def get_loader(data_root='/home/ray/Data/KITTI/training', max_frames=10000,
               crop_size=(1024, 320), max_objects=100):
    """Build the multi-object-detection loader over the KITTI training split.

    The dataset root, frame cap, crop size, and object cap were previously
    hard-coded; they are now keyword parameters with the same defaults, so
    existing zero-argument callers are unaffected.

    Args:
        data_root: filesystem path to the KITTI training data.
        max_frames: maximum number of frames the KITTISource will load.
        crop_size: (width, height) crop passed through to the loader.
        max_objects: maximum number of objects per frame.

    Returns:
        A MultiObjectDetectionLoader wrapping a KITTISource, restricted to
        the module-level `classes` list.
    """
    source = KITTISource(data_root, max_frames=max_frames)
    return MultiObjectDetectionLoader(source,
                                      crop_size=list(crop_size),
                                      obj_types=classes,
                                      max_objects=max_objects)
self.output[1:] # Nx4 boxes and N class tensor valid_boxes, valid_classes = MultiObjectDetector.post_process_boxes(self.data[0],boxes,classes,len(self.class_lookup)) # convert targets @@ -89,6 +97,8 @@ def next(self): # 2) generate a random perturbation and perturb the frame perturb_params = {'translation_range':[-0.1,0.1], 'scaling_range':[0.9,1.1]} + # perturb_params = {'translation_range':[0.0,0.0], + # 'scaling_range':[1.0,1.0]} perturbed_frame = RandomPerturber.perturb_frame(frame,perturb_params) crop_affine = resize_image_center_crop(perturbed_frame.image,self.crop_size) output_size = [self.crop_size[1],self.crop_size[0]] diff --git a/pytlib/loss_functions/multi_object_detector_loss.py b/pytlib/loss_functions/multi_object_detector_loss.py index 6dec031..e39eb5d 100644 --- a/pytlib/loss_functions/multi_object_detector_loss.py +++ b/pytlib/loss_functions/multi_object_detector_loss.py @@ -4,6 +4,7 @@ from loss_functions.box_loss import box_loss from utils.logger import Logger from utils.batch_box_utils import rescale_boxes, euc_distance_cost, generate_region_meshgrid +from networks.mask_block import mask_function import numpy as np def preprocess_targets_and_preds(targets, box_preds, class_preds, original_image): @@ -65,7 +66,8 @@ def assign_targets(box_preds, box_targets, dummy_target_masks=None): def multi_object_detector_loss(original_image, box_preds, - class_preds, + class_preds, + masks, targets, pos_to_neg_class_weight_ratio=0.25, class_loss_weight=2.0, @@ -114,4 +116,17 @@ def multi_object_detector_loss(original_image, total_negative_targets = torch.sum(F.softmax(class_preds.flatten(start_dim=2).transpose(1,2),dim=2)[:,:,1]>0.5) Logger().set('loss_component.total_negative_targets',total_negative_targets.item()) Logger().set('loss_component.total_positive_targets',total_positive_targets.item()) - return total_loss \ No newline at end of file + + total_mask_loss = 0 + count = 0 + mask_loss_factor = 0.1 + for i,mask in enumerate(masks): + applied_mask = 
def relu_E(x, epsilon=1e-2):
    """Shifted ReLU: elementwise max(x - epsilon, 0).

    Values that clear the small epsilon margin survive; everything at or
    below it is zeroed. (Parameter renamed from the original misspelling
    'epislon'; the function is new in this change, so no keyword callers
    exist yet.)
    """
    return F.relu(x - epsilon)


def sigmoid_T(x, temp=100):
    """Temperature-scaled sigmoid: sigmoid(temp * x).

    With the default temp=100 this approximates a hard 0/1 step while
    remaining differentiable.
    """
    return torch.sigmoid(x * temp)


def mask_function(x):
    """Soft binarization used for layer masks: relu_E(sigmoid_T(x)).

    Strongly negative inputs map to exactly 0 (the epsilon shift plus ReLU
    kills the tiny sigmoid output); positive inputs map to roughly
    1 - epsilon.
    """
    return relu_E(sigmoid_T(x))


# Extends a basic resnet block with an extra layer mask.
class MaskConvBlock(nn.Module):
    """Basic-resnet-style conv block that consumes and emits a spatial mask.

    conv1 produces one extra channel; that last channel is split off as the
    next mask, and the remaining channels flow through the usual
    bn/relu/conv/bn residual path.
    """

    def __init__(self, inchans, outchans, stride=1, downsample=None):
        super(MaskConvBlock, self).__init__()
        # +1 output channel: the final channel of conv1 becomes the new mask.
        self.conv1 = conv3x3(inchans, outchans + 1, stride)
        self.bn1 = nn.BatchNorm2d(outchans)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(outchans, outchans)
        self.bn2 = nn.BatchNorm2d(outchans)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x, mask):
        # Soft-binarize the incoming mask before gating the input with it.
        mask = mask_function(mask)

        identity = x
        out = self.conv1(mask * x)
        # Channel dim assumed to be 1 (NCHW); the last channel is the mask.
        mask_channel = out.shape[1] - 1
        new_mask = out[:, mask_channel, :, :].unsqueeze(1)
        new_out = out[:, 0:mask_channel, :, :]

        new_out = self.bn1(new_out)
        new_out = self.relu(new_out)
        new_out = self.conv2(new_out)
        new_out = self.bn2(new_out)

        # Strided blocks get a 1x1-conv downsample (built by the caller) so
        # the identity branch matches the residual branch's shape.
        if self.downsample is not None:
            identity = self.downsample(x)

        new_out += identity
        new_out = self.relu(new_out)
        return new_out, new_mask
class MaskResnetCNN(nn.Module):
    """ResNet-style backbone built from MaskConvBlocks that thread a
    spatial mask through every block.

    forward() returns the final feature map plus the list of intermediate
    masks; the very last mask is dropped since nothing consumes it.
    """

    def __init__(self, layers=(3, 4, 23, 3), initchans=3):
        # NOTE: default changed from a list to an equivalent tuple to avoid
        # a shared mutable default argument; it is only indexed, so
        # behavior is identical.
        self.initchans = initchans
        self.inplanes = 64
        super(MaskResnetCNN, self).__init__()
        self.all_layers = ModuleList()
        self.all_layers.append(self._make_layer(MaskConvBlock, 64, 2, stride=2, inplanes=initchans))
        self.all_layers.append(self._make_layer(MaskConvBlock, 64, layers[0], stride=2))
        self.all_layers.append(self._make_layer(MaskConvBlock, 128, layers[1], stride=2))
        self.all_layers.append(self._make_layer(MaskConvBlock, 256, layers[2], stride=2))
        self.all_layers.append(self._make_layer(MaskConvBlock, 512, layers[3], stride=2))

        # He-style init for convs; unit weight / zero bias for batchnorms.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, inplanes=None):
        """Build one stage of `blocks` mask blocks; only the first block of
        a strided stage gets a downsample branch."""
        downsample = None
        if stride != 1:
            # 1x1 strided conv + BN matches the identity branch to the new
            # spatial size and channel count.
            downsample = nn.Sequential(
                nn.Conv2d(inplanes or self.inplanes, planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes),
            )

        stage = ModuleList()
        stage.append(block(inplanes or self.inplanes, planes, stride, downsample))
        self.inplanes = planes
        for _ in range(1, blocks):
            stage.append(block(self.inplanes, planes))
        return stage

    def forward(self, x):
        # The initial mask passes everything through unchanged.
        mask = torch.ones_like(x)
        all_masks = []
        for stage in self.all_layers:
            for block in stage:
                # isinstance replaces the original fragile
                # __class__.__name__ == 'MaskConvBlock' string comparison;
                # behavior is unchanged since every block appended in
                # __init__ is a MaskConvBlock.
                if isinstance(block, MaskConvBlock):
                    x, mask = block(x, mask)
                    all_masks.append(mask)
                else:
                    x = block(x)
        # Don't return the last mask since it's not actually used.
        return x, all_masks[:-1]
def forward(self, x):
    """Run detection on a batch.

    Returns:
        Tuple (x, boxes, classes, masks): the unmodified input, the box and
        class prediction maps, and the list of per-layer masks produced by
        the MaskResnetCNN backbone.
    """
    # Call the backbone as a module (not .forward directly) so that any
    # registered forward hooks are honored; output is identical otherwise.
    # CNN Compute, outputs BCHW order on cudnn.
    feature_maps, masks = self.feature_map_generator(x)
    # Predictor weights are lazily sized from the first feature map seen.
    if self.class_predictor_weights is None:
        self.__init_weights(feature_maps)
    boxes = self.__box_predictor(feature_maps)
    classes = self.__class_predictor(feature_maps)
    return x, boxes, classes, masks