Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pytlib/configuration/multobjectdet_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

classes = ['Car']
def get_loader():
source = KITTISource('/home/ray/Data/KITTI/training',max_frames=5000)
source = KITTISource('/home/ray/Data/KITTI/training',max_frames=10000)
return MultiObjectDetectionLoader(source,crop_size=[1024,320],obj_types=classes,max_objects=100)

# loader = (get_loader,dict())
Expand Down
12 changes: 11 additions & 1 deletion pytlib/data_loading/loaders/multiobject_detection_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
from image.random_perturber import RandomPerturber
from image.image_utils import draw_objects_on_np_image
from image.affine_transforms import resize_image_center_crop,apply_affine_to_frame
import torch.nn.functional as F
import numpy as np
import random
import torch
from networks.mask_block import mask_function
from interface import implements
from visualization.image_visualizer import ImageVisualizer
from networks.multi_object_detector import MultiObjectDetector
Expand All @@ -35,8 +37,14 @@ def __convert_to_objects(self,boxes,classes):
def visualize(self,parameters={}):
image_original = PTImage.from_cwh_torch(self.data[0])
drawing_image = image_original.to_order_and_class(Ordering.HWC,ValueClass.BYTE0255).get_data().copy()
boxes,classes = self.output[1],self.output[2]
mask0 = self.output[3][0]
mask0 = mask_function(mask0)
# import ipdb;ipdb.set_trace()
mask_image = PTImage.from_cwh_torch(mask0).to_order_and_class(Ordering.CHW,ValueClass.BYTE0255)
mask_image.get_data()[mask_image.get_data()>0]=255
ImageVisualizer().set_image(mask_image,parameters.get('title','') + ' : Mask')

boxes,classes = self.output[1:]
# Nx4 boxes and N class tensor
valid_boxes, valid_classes = MultiObjectDetector.post_process_boxes(self.data[0],boxes,classes,len(self.class_lookup))
# convert targets
Expand Down Expand Up @@ -89,6 +97,8 @@ def next(self):
# 2) generate a random perturbation and perturb the frame
perturb_params = {'translation_range':[-0.1,0.1],
'scaling_range':[0.9,1.1]}
# perturb_params = {'translation_range':[0.0,0.0],
# 'scaling_range':[1.0,1.0]}
perturbed_frame = RandomPerturber.perturb_frame(frame,perturb_params)
crop_affine = resize_image_center_crop(perturbed_frame.image,self.crop_size)
output_size = [self.crop_size[1],self.crop_size[0]]
Expand Down
19 changes: 17 additions & 2 deletions pytlib/loss_functions/multi_object_detector_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from loss_functions.box_loss import box_loss
from utils.logger import Logger
from utils.batch_box_utils import rescale_boxes, euc_distance_cost, generate_region_meshgrid
from networks.mask_block import mask_function
import numpy as np

def preprocess_targets_and_preds(targets, box_preds, class_preds, original_image):
Expand Down Expand Up @@ -65,7 +66,8 @@ def assign_targets(box_preds, box_targets, dummy_target_masks=None):

def multi_object_detector_loss(original_image,
box_preds,
class_preds,
class_preds,
masks,
targets,
pos_to_neg_class_weight_ratio=0.25,
class_loss_weight=2.0,
Expand Down Expand Up @@ -114,4 +116,17 @@ def multi_object_detector_loss(original_image,
total_negative_targets = torch.sum(F.softmax(class_preds.flatten(start_dim=2).transpose(1,2),dim=2)[:,:,1]>0.5)
Logger().set('loss_component.total_negative_targets',total_negative_targets.item())
Logger().set('loss_component.total_positive_targets',total_positive_targets.item())
return total_loss

total_mask_loss = 0
count = 0
mask_loss_factor = 0.1
for i,mask in enumerate(masks):
applied_mask = mask_function(mask)
total_mask_loss+=torch.sum(torch.abs(applied_mask))
count+=applied_mask.numel()
non_zeros = torch.nonzero(applied_mask).size(0)
Logger().set('loss_component.count_nonzero_fraction_{}'.format(i),(non_zeros/float(mask.numel())))

Logger().set('loss_component.mask_loss',mask_loss_factor*(total_mask_loss/count).item())

return total_loss + mask_loss_factor*total_mask_loss/count
50 changes: 50 additions & 0 deletions pytlib/networks/mask_block.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.resnet import conv3x3

def relu_E(x, epislon=1e-2):
    """Shifted ReLU: subtract a small offset, then clamp negatives to zero.

    Values whose magnitude is below the offset are squashed exactly to 0.
    (Note: parameter name kept as-is for keyword-caller compatibility.)
    """
    shifted = x - epislon
    return torch.clamp(shifted, min=0)

def sigmoid_T(x, temp=100):
    """Temperature-sharpened sigmoid that approximates a hard step function."""
    scaled = temp * x
    return torch.sigmoid(scaled)

def mask_function(x):
    """Soft binary gate: sharp sigmoid followed by an epsilon-shifted ReLU,
    so near-zero (and negative) activations map exactly to zero."""
    sharp = sigmoid_T(x)
    return relu_E(sharp)

# Resnet-style basic block that consumes and produces an attention mask:
# the input is gated by the incoming mask, and conv1 emits one extra
# channel that is split off as the mask for the next block.
class MaskConvBlock(nn.Module):
    def __init__(self, inchans, outchans, stride=1, downsample=None):
        super(MaskConvBlock, self).__init__()
        # +1 output channel: the extra channel becomes the outgoing mask.
        self.conv1 = conv3x3(inchans, outchans+1, stride)
        self.bn1 = nn.BatchNorm2d(outchans)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(outchans, outchans)
        self.bn2 = nn.BatchNorm2d(outchans)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x, mask):
        # Squash the raw incoming mask into a soft 0..~1 gate before applying it.
        gate = mask_function(mask)
        identity = x

        joint = self.conv1(gate * x)
        # Channel dim is 1: last channel is the next mask, the rest are features.
        feat_chans = joint.shape[1] - 1
        next_mask = joint[:, feat_chans, :, :].unsqueeze(1)
        features = joint[:, :feat_chans, :, :]

        features = self.relu(self.bn1(features))
        features = self.bn2(self.conv2(features))

        # Downsample the (unmasked) input so the skip connection matches shape.
        if self.downsample is not None:
            identity = self.downsample(x)

        features += identity
        return self.relu(features), next_mask
59 changes: 59 additions & 0 deletions pytlib/networks/maskresnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import torch
from networks.mask_block import MaskConvBlock
from torchvision.models.resnet import BasicBlock,Bottleneck
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.nn import ModuleList

class MaskResnetCNN(nn.Module):
    """Resnet-like CNN built from MaskConvBlocks.

    Each block consumes the running attention mask and emits a new one.
    forward() returns the final feature map together with every
    intermediate mask; the mask produced by the very last block is
    dropped because no downstream block consumes it.
    """

    # NOTE: default changed from a mutable list to a tuple — same values,
    # avoids the shared-mutable-default pitfall, callers may still pass lists.
    def __init__(self, layers=(3, 4, 23, 3), initchans=3):
        self.initchans = initchans
        self.inplanes = 64
        super(MaskResnetCNN, self).__init__()
        self.all_layers = ModuleList()
        # Stem stage maps the raw input channels up to 64 feature channels.
        self.all_layers.append(self._make_layer(MaskConvBlock, 64, 2, stride=2, inplanes=initchans))
        self.all_layers.append(self._make_layer(MaskConvBlock, 64, layers[0], stride=2))
        self.all_layers.append(self._make_layer(MaskConvBlock, 128, layers[1], stride=2))
        self.all_layers.append(self._make_layer(MaskConvBlock, 256, layers[2], stride=2))
        self.all_layers.append(self._make_layer(MaskConvBlock, 512, layers[3], stride=2))

        # Standard He-style init for convs, unit-affine init for batchnorms.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, inplanes=None):
        """Build one stage of `blocks` mask blocks; the first may downsample."""
        # Explicit None check: `inplanes or self.inplanes` would silently
        # fall back to self.inplanes if a caller ever passed inplanes=0.
        first_inplanes = self.inplanes if inplanes is None else inplanes
        downsample = None
        if stride != 1:
            # Project the identity path so it matches the strided first block.
            downsample = nn.Sequential(
                nn.Conv2d(first_inplanes, planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes),
            )

        layers = ModuleList()
        layers.append(block(first_inplanes, planes, stride, downsample))
        self.inplanes = planes
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return layers

    def forward(self, x):
        # Seed the mask chain with an all-ones (pass-everything) mask.
        mask = torch.ones_like(x)
        all_masks = []
        for blocks in self.all_layers:
            for block in blocks:
                # isinstance instead of the fragile __class__.__name__ string
                # comparison (which breaks under subclassing or renames).
                if isinstance(block, MaskConvBlock):
                    x, mask = block(x, mask)
                    all_masks.append(mask)
                else:
                    x = block(x)
        return x, all_masks[:-1]  # dont return the last mask since its not actually used


15 changes: 9 additions & 6 deletions pytlib/networks/multi_object_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
import torch.nn.functional as F
from torch.nn import ModuleList
from torch.autograd import Variable
from networks.resnetcnn import ResNetCNN
from utils.batch_box_utils import rescale_boxes, generate_region_meshgrid
from networks.maskresnet import MaskResnetCNN
from utils.batch_box_utils import rescale_boxes, generate_region_meshgrid, batch_nms
import numpy as np

class MultiObjectDetector(nn.Module):
def __init__(self, nboxes_per_pixel=5, num_classes=2):
# num_classes to predict, includes background
super(MultiObjectDetector, self).__init__()
self.feature_map_generator = ResNetCNN()
self.feature_map_generator = MaskResnetCNN()
self.register_parameter('box_predictor_weights', None)
self.register_parameter('class_predictor_weights', None)
self.nboxes_per_pixel = nboxes_per_pixel
Expand Down Expand Up @@ -70,15 +70,18 @@ def post_process_boxes(cls, original_image, boxes, classes, num_classes):
# only select those that are non-background
valid_boxes = flatten_boxes[:,mask].transpose(0,1)
valid_classes = argmax_classes[mask]
return valid_boxes, valid_classes

nms_boxes, mask = batch_nms(valid_boxes)
nms_classes = valid_classes[mask]
return nms_boxes, nms_classes


def forward(self, x):
# CNN Compute, outputs BCHW order on cudnn
feature_maps = self.feature_map_generator.forward(x)
feature_maps,masks = self.feature_map_generator.forward(x)
if self.class_predictor_weights is None:
self.__init_weights(feature_maps)
boxes = self.__box_predictor(feature_maps)
classes = self.__class_predictor(feature_maps)
return x, boxes, classes
return x, boxes, classes, masks