diff --git a/projects/YOLOX_opt_elan/configs/cityscapes/YOLOX_opt-S-DynamicRecognition/yolox-s-opt-elan-semseg_960x960_300e_cityscapes.py b/projects/YOLOX_opt_elan/configs/cityscapes/YOLOX_opt-S-DynamicRecognition/yolox-s-opt-elan-semseg_960x960_300e_cityscapes.py
new file mode 100644
index 000000000..05eac75af
--- /dev/null
+++ b/projects/YOLOX_opt_elan/configs/cityscapes/YOLOX_opt-S-DynamicRecognition/yolox-s-opt-elan-semseg_960x960_300e_cityscapes.py
@@ -0,0 +1,288 @@
+_base_ = [
+    "../../../../../autoware_ml/configs/detection2d/default_runtime.py",
+    "../../../../../autoware_ml/configs/detection2d/schedules/schedule_1x.py",
+]
+
+custom_imports = dict(
+    imports=[
+        "projects.YOLOX_opt_elan.yolox",
+        "autoware_ml.detection2d.metrics",
+        "autoware_ml.detection2d.datasets",
+        "projects.YOLOX_opt_elan.yolox.models",
+        "projects.YOLOX_opt_elan.yolox.models.yolox_multitask",
+        "projects.YOLOX_opt_elan.yolox.transforms",
+        "mmseg.evaluation.metrics",  # pull in the segmentation evaluation metrics
+    ],
+    allow_failed_imports=False,
+)
+
+# parameter settings
+# IMG_SCALE = (960, 960)
+IMG_SCALE = (1024, 512)  # overrides the 960x960 in the file name; matches Cityscapes' 2:1 aspect ratio
+max_epochs = 300
+num_last_epochs = 15
+resume_from = None
+interval = 1
+batch_size = 16
+activation = "ReLU6"
+num_workers = 4
+base_lr = 0.001
+
+
+classes = ("person", "rider", "car", "truck", "bus", "train", "motorcycle", "bicycle")
+palette = [(220, 20, 60), (255, 0, 0), (0, 0, 142), (0, 0, 70), (0, 60, 100), (0, 80, 100), (0, 0, 230), (119, 11, 32)]
+
+seg_classes = (
+    "road",
+    "sidewalk",
+    "building",
+    "wall",
+    "fence",
+    "pole",
+    "traffic light",
+    "traffic sign",
+    "vegetation",
+    "terrain",
+    "sky",
+    "person",
+    "rider",
+    "car",
+    "truck",
+    "bus",
+    "train",
+    "motorcycle",
+    "bicycle",
+)
+seg_palette = [
+    (128, 64, 128),
+    (244, 35, 232),
+    (70, 70, 70),
+    (102, 102, 156),
+    (190, 153, 153),
+    (153, 153, 153),
+    (250, 170, 30),
+    (220, 220, 0),
+    (107, 142, 35),
+    (152, 251, 152),
+    (70, 130, 180),
+    (220, 20, 60),
+    (255, 0, 0),
+    (0, 0, 142),
+    (0, 0, 70),
+    (0, 60, 100),
+    (0, 80, 100),
+    (0, 0, 230),
+    (119, 11, 32),
+]
+
+# metainfo = dict(classes=classes, palette=palette)
+metainfo = dict(classes=seg_classes, palette=seg_palette)
+
+model = dict(
+    type="YOLOXMultiTask",
+    data_preprocessor=dict(
+        type="DetDataPreprocessor",
+        pad_size_divisor=32,
+        # batch_augments=[
+        #     dict(
+        #         type="BatchSyncRandomResize",
+        #         random_size_range=(480, 800),
+        #         size_divisor=32,
+        #         interval=10,
+        #     )
+        # ],
+    ),
+    backbone=dict(
+        type="ELANDarknet",
+        deepen_factor=2,
+        widen_factor=1,
+        out_indices=(2, 3, 4),
+        act_cfg=dict(type=activation),
+    ),
+    neck=dict(
+        type="YOLOXPAFPN_ELAN",
+        in_channels=[128, 256, 512],
+        out_channels=128,
+        num_elan_blocks=2,
+        act_cfg=dict(type=activation),
+    ),
+    bbox_head=dict(
+        type="YOLOXHead",
+        num_classes=8,
+        in_channels=128,
+        feat_channels=128,
+        act_cfg=dict(type=activation),
+        # all detection loss weights are 0.0, so this config trains only the segmentation head
+        loss_cls=dict(type="CrossEntropyLoss", use_sigmoid=True, loss_weight=0.0),
+        loss_bbox=dict(type="IoULoss", loss_weight=0.0),
+        loss_obj=dict(type="CrossEntropyLoss", use_sigmoid=True, loss_weight=0.0),
+        loss_l1=dict(type="L1Loss", loss_weight=0.0),
+    ),
+    mask_head=dict(
+        type="YOLOXSegHead",
+        in_channels=[128, 128, 128],
+        feat_channels=128,
+        num_classes=19,
+        act_cfg=dict(type=activation),
+        loss=dict(type="CrossEntropyLoss", use_sigmoid=False, loss_weight=1.0),
+    ),
+    train_cfg=dict(assigner=dict(type="SimOTAAssigner", center_radius=2.5)),
+    test_cfg=dict(score_thr=0.01, nms=dict(type="nms", iou_threshold=0.65)),
+)
+
+dataset_type = "CocoDataset"
+data_root = "data/cityscapes/"
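+
+# Optional sanity check: the detection branch covers the 8 Cityscapes instance
+# classes and the mask head the 19 trainIds, so each class tuple must match its
+# palette length and the corresponding num_classes above.
+assert len(classes) == len(palette) == 8
+assert len(seg_classes) == len(seg_palette) == 19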
+backend_args = None + +train_pipeline = [ + # dict(type="Mosaic", img_scale=IMG_SCALE, pad_val=114.0), + # dict(type="MixUp", img_scale=IMG_SCALE, ratio_range=(0.8, 1.6), pad_val=114.0), + dict(type="YOLOXHSVRandomAug"), + dict(type="RandomFlip", prob=0.5), + dict(type="Resize", scale=IMG_SCALE, keep_ratio=False), + dict( + type="Pad", + pad_to_square=False, + size_divisor=32, + pad_val=dict(img=(114.0, 114.0, 114.0), seg=255), + ), + dict(type="FilterAnnotations", min_gt_bbox_wh=(1, 1), keep_empty=False), + dict(type="PackDetInputs"), +] + +test_pipeline = [ + dict(type="LoadImageFromFile", backend_args=backend_args), + dict(type="FixCityscapesPath", data_root=data_root, split="val"), + dict(type="LoadAnnotations", with_bbox=True, with_seg=True), + dict(type="Resize", scale=IMG_SCALE, keep_ratio=False), + dict(type="Pad", pad_to_square=False, size_divisor=32, pad_val=dict(img=(114.0, 114.0, 114.0), seg=255)), + dict( + type="PackDetInputs", + meta_keys=("img_id", "img_path", "ori_shape", "img_shape", "scale_factor"), + ), +] + +train_dataset = dict( + type="MultiImageMixDataset", + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict(img="leftImg8bit/train", seg_map_path="gtFine/train"), + ann_file="annotations/instancesonly_filtered_gtFine_train.json", + pipeline=[ + dict(type="LoadImageFromFile", backend_args=backend_args), + dict(type="FixCityscapesPath", data_root=data_root, split="train"), + dict(type="LoadAnnotations", with_bbox=True, with_seg=True), + ], + filter_cfg=dict(filter_empty_gt=False, min_size=8), + backend_args=backend_args, + metainfo=metainfo, + ), + pipeline=train_pipeline, +) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=num_workers, + persistent_workers=True, + sampler=dict(type="DefaultSampler", shuffle=True), + dataset=train_dataset, +) + +val_dataloader = dict( + batch_size=batch_size, + num_workers=num_workers, + persistent_workers=True, + drop_last=False, + sampler=dict(type="DefaultSampler", shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict(img="leftImg8bit/val", seg_map_path="gtFine/val"), + ann_file="annotations/instancesonly_filtered_gtFine_val.json", + test_mode=True, + pipeline=test_pipeline, + backend_args=backend_args, + metainfo=metainfo, + ), +) +test_dataloader = val_dataloader + +val_evaluator = [ + dict(type="mmseg.IoUMetric", ignore_index=255, iou_metrics=["mIoU"], prefix="seg", classes=seg_classes) +] +test_evaluator = val_evaluator + +train_cfg = dict(max_epochs=max_epochs, val_interval=interval) + +optimizer = dict( + type="OptimWrapper", + optimizer=dict(type="SGD", lr=base_lr, momentum=0.9, weight_decay=5e-4, nesterov=True), + paramwise_cfg=dict(norm_decay_mult=0.0, bias_decay_mult=0.0), +) + +if max_epochs > 5: + param_scheduler = [ + dict( + type="mmdet.QuadraticWarmupLR", + by_epoch=True, + begin=0, + end=5, + convert_to_iter_based=True, + ), + dict( + type="CosineAnnealingLR", + eta_min=base_lr * 0.05, + begin=5, + T_max=max_epochs - num_last_epochs, + end=max_epochs - num_last_epochs, + by_epoch=True, + convert_to_iter_based=True, + ), + dict( + type="ConstantLR", + by_epoch=True, + factor=1, + begin=max_epochs - num_last_epochs, + end=max_epochs, + ), + ] +else: + param_scheduler = [] + +log_config = dict( + interval=1, + hooks=[dict(type="TextLoggerHook"), dict(type="TensorboardLoggerHook")], +) + +default_hooks = dict( + checkpoint=dict(interval=interval, max_keep_ckpts=3, save_best="seg/mIoU", rule="greater"), + visualization=dict( + 
type="DetVisualizationHook", draw=False, interval=50, show=False, wait_time=2, test_out_dir="vis_data" + ), +) + +custom_hooks = [ + dict(type="YOLOXModeSwitchHook", num_last_epochs=num_last_epochs, priority=48), + dict(type="SyncNormHook", priority=48), + dict( + type="EMAHook", + ema_type="ExpMomentumEMA", + momentum=0.0001, + update_buffers=True, + priority=4, + ), +] + +auto_scale_lr = dict(base_batch_size=batch_size) + +vis_backends = [ + dict(type="LocalVisBackend"), + dict(type="TensorboardVisBackend"), +] + +visualizer = dict( + type="DetLocalVisualizer", + vis_backends=[dict(type="LocalVisBackend"), dict(type="TensorboardVisBackend")], + name="visualizer", + alpha=0.5, +) diff --git a/projects/YOLOX_opt_elan/configs/t4dataset/YOLOX_opt-S-DynamicRecognition/yolox-s-opt-elan-semseg_960x960_300e_t4dataset.py b/projects/YOLOX_opt_elan/configs/t4dataset/YOLOX_opt-S-DynamicRecognition/yolox-s-opt-elan-semseg_960x960_300e_t4dataset.py new file mode 100644 index 000000000..dc78b0080 --- /dev/null +++ b/projects/YOLOX_opt_elan/configs/t4dataset/YOLOX_opt-S-DynamicRecognition/yolox-s-opt-elan-semseg_960x960_300e_t4dataset.py @@ -0,0 +1,354 @@ +_base_ = [ + "../../../../../autoware_ml/configs/detection2d/default_runtime.py", + "../../../../../autoware_ml/configs/detection2d/schedules/schedule_1x.py", + "../../../../../autoware_ml/configs/detection2d/dataset/t4dataset/comlops.py", +] + +custom_imports = dict( + imports=[ + "projects.YOLOX_opt_elan.yolox", + "autoware_ml.detection2d.metrics", + "autoware_ml.detection2d.datasets", + "projects.YOLOX_opt_elan.yolox.models", + "projects.YOLOX_opt_elan.yolox.models.yolox_multitask", + "projects.YOLOX_opt_elan.yolox.transforms", + ], + allow_failed_imports=False, +) + +IMG_SCALE = (960, 960) + +# parameter settings +img_scale = (960, 960) +max_epochs = 300 +num_last_epochs = 15 +resume_from = None +interval = 1 +batch_size = 12 +activation = "ReLU6" +num_workers = 4 + +base_lr = 0.001 + +# model settings +model = dict( + type="YOLOXMultiTask", + data_preprocessor=dict( + type="DetDataPreprocessor", + pad_size_divisor=32, + batch_augments=[ + dict( + type="BatchSyncRandomResize", + random_size_range=(480, 800), + size_divisor=32, + interval=10, + ) + ], + ), + backbone=dict( + type="ELANDarknet", + deepen_factor=2, + widen_factor=1, + out_indices=(2, 3, 4), + act_cfg=dict(type=activation), + ), + neck=dict( + type="YOLOXPAFPN_ELAN", + in_channels=[128, 256, 512], + out_channels=128, + num_elan_blocks=2, + act_cfg=dict(type=activation), + ), + bbox_head=dict( + type="YOLOXHead", + num_classes=40, + in_channels=128, + feat_channels=128, + act_cfg=dict(type=activation), + loss_cls=dict(type="CrossEntropyLoss", use_sigmoid=True, loss_weight=0.0), + loss_bbox=dict(type="IoULoss", loss_weight=0.0), + loss_obj=dict(type="CrossEntropyLoss", use_sigmoid=True, loss_weight=0.0), + loss_l1=dict(type="L1Loss", loss_weight=0.0), + ), + mask_head=dict( + type="YOLOXSegHead", + in_channels=[128, 128, 128], + feat_channels=128, + num_classes=40, + act_cfg=dict(type=activation), + loss=dict(type="CrossEntropyLoss", use_sigmoid=False, loss_weight=1.0), + ), + train_cfg=dict(assigner=dict(type="SimOTAAssigner", center_radius=2.5)), + test_cfg=dict(score_thr=0.01, nms=dict(type="nms", iou_threshold=0.65)), +) + +data_root = "" +anno_file_root = "./data/comlops/semseg/" +dataset_type = "T4Dataset" + +backend_args = None + +# pipeline +train_pipeline = [ + # dict(type="Mosaic", img_scale=IMG_SCALE, pad_val=114.0), + # dict( + # type="RandomAffine", + # 
scaling_ratio_range=(0.1, 2), + # border=(-IMG_SCALE[0] // 2, -IMG_SCALE[1] // 2), + # ), + # dict(type="MixUp", img_scale=IMG_SCALE, ratio_range=(0.8, 1.6), pad_val=114.0), + dict(type="LoadAnnotations", with_bbox=True, with_seg=True), + dict(type="YOLOXHSVRandomAug"), + dict(type="RandomFlip", prob=0.5), + dict(type="Resize", scale=IMG_SCALE, keep_ratio=False), + dict( + type="Pad", + pad_to_square=True, + pad_val=dict(img=(114.0, 114.0, 114.0), seg=255), + ), + dict(type="FilterAnnotations", min_gt_bbox_wh=(1, 1), keep_empty=False), + dict(type="PackDetInputs"), +] + +classes = ( + "animal", + "bicycle", + "building", + "bus", + "car", + "cone", + "construction", + "crosswalk", + "dashed_lane_marking", + "deceleration_line", + "gate", + "guide_post", + "laneline_dash_white", + "laneline_dash_yellow", + "laneline_solid_green", + "laneline_solid_red", + "laneline_solid_white", + "laneline_solid_yellow", + "marking_arrow", + "marking_character", + "marking_other", + "motorcycle", + "other_obstacle", + "other_pedestrian", + "other_vehicle", + "parking_lot", + "pedestrian", + "pole", + "road", + "road_debris", + "sidewalk", + "sky", + "stopline", + "striped_road_marking", + "traffic_light", + "traffic_sign", + "train", + "truck", + # "unknown", + "vegetation/terrain", + "wall/fence", +) + +palette = [ + (150, 120, 90), # 0: animal + (119, 11, 32), # 1: bicycle + (70, 70, 70), # 2: building + (0, 60, 100), # 3: bus + (0, 0, 142), # 4: car + (250, 170, 30), # 5: cone + (230, 150, 140), # 6: construction + (140, 140, 200), # 7: crosswalk + (255, 255, 255), # 8: dashed_lane_marking + (200, 200, 200), # 9: deceleration_line + (190, 153, 153), # 10: gate + (250, 170, 30), # 11: guide_post + (255, 255, 255), # 12: laneline_dash_white + (255, 255, 0), # 13: laneline_dash_yellow + (0, 255, 0), # 14: laneline_solid_green + (255, 0, 0), # 15: laneline_solid_red + (255, 255, 255), # 16: laneline_solid_white + (255, 215, 0), # 17: laneline_solid_yellow + (0, 255, 255), # 18: marking_arrow + (200, 0, 200), # 19: marking_character + (150, 0, 150), # 20: marking_other + (0, 0, 230), # 21: motorcycle + (80, 80, 80), # 22: other_obstacle + (250, 170, 160), # 23: other_pedestrian + (100, 80, 200), # 24: other_vehicle + (180, 165, 180), # 25: parking_lot + (220, 20, 60), # 26: pedestrian + (153, 153, 153), # 27: pole + (128, 64, 128), # 28: road + (110, 110, 110), # 29: road_debris + (244, 35, 232), # 30: sidewalk + (70, 130, 180), # 31: sky + (220, 220, 220), # 32: stopline + (160, 150, 180), # 33: striped_road_marking + (250, 170, 30), # 34: traffic_light + (220, 220, 0), # 35: traffic_sign + (0, 80, 100), # 36: train + (0, 0, 70), # 37: truck + (107, 142, 35), # 38: vegetation/terrain + (102, 102, 156), # 39: wall/fence +] +metainfo = dict(classes=classes, palette=palette) + +train_dataset = dict( + type="MultiImageMixDataset", + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file=anno_file_root + "comlops_infos_train_cleaned.json", + pipeline=[ + dict(type="LoadImageFromFile", backend_args=backend_args), + dict(type="LoadAnnotations", with_bbox=True, with_seg=True), + ], + filter_cfg=dict(filter_empty_gt=False, min_size=8), + backend_args=backend_args, + metainfo=metainfo, + ), + pipeline=train_pipeline, +) + +test_pipeline = [ + dict(type="LoadImageFromFile", backend_args=backend_args), + dict(type="LoadAnnotations", with_bbox=True, with_seg=True), + dict(type="Resize", scale=img_scale, keep_ratio=False), + dict(type="Pad", pad_to_square=True, pad_val=dict(img=(114.0, 114.0, 114.0), 
seg=255)), + dict( + type="PackDetInputs", + meta_keys=( + "img_id", + "img_path", + "ori_shape", + "img_shape", + "scale_factor", + ), + ), +] + +train_dataloader = dict( + batch_size=batch_size, + num_workers=num_workers, + persistent_workers=True, + sampler=dict(type="DefaultSampler", shuffle=True), + dataset=train_dataset, +) + +val_dataloader = dict( + batch_size=batch_size, + num_workers=16, + persistent_workers=True, + drop_last=False, + sampler=dict(type="DefaultSampler", shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file=anno_file_root + "comlops_infos_val_cleaned.json", + test_mode=True, + pipeline=test_pipeline, + backend_args=backend_args, + metainfo=metainfo, + indices=2000, + ), +) + +test_dataloader = val_dataloader + +val_evaluator = [ + dict(type="VOCMetric", metric="mAP", prefix="det"), + dict(type="mmseg.IoUMetric", ignore_index=255, iou_metrics=["mIoU"], prefix="seg"), +] + +test_evaluator = val_evaluator + +# train_cfg = dict(max_epochs=max_epochs, val_interval=interval) +train_cfg = dict(_delete_=True, type="IterBasedTrainLoop", max_iters=200000, val_interval=1000) + +# optimizer +optimizer = dict( + type="OptimWrapper", + optimizer=dict(type="SGD", lr=base_lr, momentum=0.9, weight_decay=5e-4, nesterov=True), + paramwise_cfg=dict(norm_decay_mult=0.0, bias_decay_mult=0.0), +) + +# learning rate scheduler +if max_epochs > 5: + param_scheduler = [ + dict( + type="mmdet.QuadraticWarmupLR", + by_epoch=True, + begin=0, + end=5, + convert_to_iter_based=True, + ), + dict( + type="CosineAnnealingLR", + eta_min=base_lr * 0.05, + begin=5, + T_max=max_epochs - num_last_epochs, + end=max_epochs - num_last_epochs, + by_epoch=True, + convert_to_iter_based=True, + ), + dict( + type="ConstantLR", + by_epoch=True, + factor=1, + begin=max_epochs - num_last_epochs, + end=max_epochs, + ), + ] +else: + param_scheduler = [] + +# logging +log_config = dict( + interval=1, + hooks=[dict(type="TextLoggerHook"), dict(type="TensorboardLoggerHook")], +) + +# default_hooks = dict( +# checkpoint=dict(interval=interval, max_keep_ckpts=3), +# ) +default_hooks = dict( + checkpoint=dict( + type="CheckpointHook", interval=1000, by_epoch=False, max_keep_ckpts=5, save_best="seg/mIoU", rule="greater" + ), + logger=dict(type="LoggerHook", interval=50), + visualization=dict( + type="DetVisualizationHook", draw=False, interval=100, show=False, wait_time=2, test_out_dir="vis_data" + ), +) + +custom_hooks = [ + dict(type="YOLOXModeSwitchHook", num_last_epochs=num_last_epochs, priority=48), + dict(type="SyncNormHook", priority=48), + dict( + type="EMAHook", + ema_type="ExpMomentumEMA", + momentum=0.0001, + update_buffers=True, + priority=4, + ), +] + +auto_scale_lr = dict(base_batch_size=batch_size) + + +vis_backends = [ + dict(type="LocalVisBackend"), + dict(type="TensorboardVisBackend"), +] + +visualizer = dict( + type="DetLocalVisualizer", + vis_backends=[dict(type="LocalVisBackend"), dict(type="TensorboardVisBackend")], + name="visualizer", + alpha=0.3, +) diff --git a/projects/YOLOX_opt_elan/yolox/models/heads/__init__.py b/projects/YOLOX_opt_elan/yolox/models/heads/__init__.py new file mode 100644 index 000000000..598422de0 --- /dev/null +++ b/projects/YOLOX_opt_elan/yolox/models/heads/__init__.py @@ -0,0 +1,3 @@ +from .seg_head import YOLOXSegHead + +__all__ = ("YOLOXSegHead",) diff --git a/projects/YOLOX_opt_elan/yolox/models/heads/seg_head.py b/projects/YOLOX_opt_elan/yolox/models/heads/seg_head.py new file mode 100644 index 000000000..866fc727d --- /dev/null +++ 
b/projects/YOLOX_opt_elan/yolox/models/heads/seg_head.py @@ -0,0 +1,122 @@ +from typing import Dict, List, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmdet.models.seg_heads.base_semantic_head import BaseSemanticHead +from mmengine.registry import MODELS +from torch import Tensor + +from ..layers.network_blocks import BaseConv, CSPLayer, DWConv + + +def get_activation(name="ReLU6"): + if name.lower() == "relu6": + return nn.ReLU6(inplace=True) + elif name.lower() == "relu": + return nn.ReLU(inplace=True) + elif name.lower() == "silu": + return nn.SiLU(inplace=True) + elif name.lower() == "lrelu": + return nn.LeakyReLU(0.1, inplace=True) + else: + raise AttributeError(f"Unsupported act type {name}") + + +@MODELS.register_module() +class YOLOXSegHead(nn.Module): + def __init__( + self, + in_channels, + num_classes, + feat_channels=None, + act_cfg=dict(type="ReLU6"), + width=1.0, + depthwise=False, + train_cfg=None, + test_cfg=None, + **kwargs, + ): + super().__init__() + self.num_classes = num_classes + self.width = width + # self.stem_channels = feat_channels if feat_channels is not None else int(64 * width) + self.stem_channels = sum(in_channels) + + act_type = act_cfg.get("type", "ReLU6") + self.act_fn = get_activation(act_type) + + self.train_cfg = train_cfg + + Conv = DWConv if depthwise else BaseConv + + # mask head layers + self.conv1 = Conv(self.stem_channels, self.stem_channels, 3, 1, act=act_type) + self.conv2 = Conv(self.stem_channels, self.stem_channels, 3, 1, act=act_type) + self.up1 = nn.Upsample(scale_factor=2, mode="nearest") + self.conv3 = Conv(self.stem_channels, self.stem_channels // 2, 3, 1, act=act_type) + self.up2 = nn.Upsample(scale_factor=2, mode="nearest") + self.conv4 = Conv(self.stem_channels // 2, self.stem_channels // 2, 3, 1, act=act_type) + self.up3 = nn.Upsample(scale_factor=2, mode="nearest") + self.out_conv = nn.Conv2d(self.stem_channels // 2, num_classes, kernel_size=1, stride=1, padding=0) + + def forward(self, feats): + """ + Args: + feats (list[Tensor] or Tensor): features from backbone+neck + Returns: + seg_pred (Tensor): [B, num_classes, H, W] + """ + if isinstance(feats, (list, tuple)): + target_size = feats[0].shape[2:] + up_feats = [F.interpolate(f, size=target_size, mode="bilinear", align_corners=False) for f in feats] + x = torch.cat(up_feats, dim=1) # [B, sum(C_i), H, W] + else: + x = feats + + x = self.conv1(x) + x = self.conv2(x) + x = self.up1(x) + x = self.conv3(x) + x = self.up2(x) + x = self.conv4(x) + x = self.up3(x) + seg_pred = self.out_conv(x) + return seg_pred + + def loss(self, seg_pred, gt_masks): + """ + Args: + seg_pred: [B, C, H, W] + gt_masks: [B, H, W] long + Returns: + dict: {'loss_mask': ...} + """ + return dict(loss_mask=F.cross_entropy(seg_pred, gt_masks.long(), ignore_index=255)) + + def predict(self, x: Union[Tensor, Tuple[Tensor]], batch_data_samples, rescale: bool = False) -> List[Tensor]: + + batch_img_metas = [data_sample.metainfo for data_sample in batch_data_samples] + seg_preds = self.forward(x) + + input_shape = batch_img_metas[0]["batch_input_shape"] + seg_preds = F.interpolate(seg_preds, size=input_shape, mode="bilinear", align_corners=False) + + result_list = [] + for i in range(len(batch_img_metas)): + img_meta = batch_img_metas[i] + h, w = img_meta["img_shape"] + + seg_pred = seg_preds[i][:, :h, :w] + + if rescale: + ori_h, ori_w = img_meta["ori_shape"] + seg_pred = F.interpolate( + seg_pred.unsqueeze(0), size=(ori_h, ori_w), mode="bilinear", align_corners=False + 
).squeeze(0)
+
+            seg_pred = seg_pred.argmax(dim=0).to(torch.int64)
+
+            result_list.append(seg_pred)
+
+        return result_list
diff --git a/projects/YOLOX_opt_elan/yolox/models/layers/network_blocks.py b/projects/YOLOX_opt_elan/yolox/models/layers/network_blocks.py
new file mode 100644
index 000000000..6edebdaae
--- /dev/null
+++ b/projects/YOLOX_opt_elan/yolox/models/layers/network_blocks.py
@@ -0,0 +1,435 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+# Copyright (c) Megvii Inc. All rights reserved.
+
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+
+
+class SiLU(nn.Module):
+    """export-friendly version of nn.SiLU()"""
+
+    @staticmethod
+    def forward(x):
+        return x * torch.sigmoid(x)
+
+
+class PactFix(nn.Module):
+    """Export-friendly PACT-style activation: clamp(x, 0, alpha) with a fixed alpha."""
+
+    @staticmethod
+    def forward(x, alpha=4.0):
+        y = torch.clamp(x, min=0, max=alpha)
+        return y
+
+
+class Pact(Function):
+    @staticmethod
+    def forward(ctx, x, alpha, k):
+        ctx.save_for_backward(x, alpha)
+        y = torch.clamp(x, min=0, max=alpha.item())
+        scale = (2**k - 1) / alpha
+        y_q = torch.round(y * scale) / scale
+        return y_q
+
+    @staticmethod
+    def backward(ctx, dLdy_q):
+        # Backward function, adapted from
+        # https://github.com/obilaniu/GradOverride/blob/master/functional.py
+        # We get dL / dy_q as a gradient
+        (
+            x,
+            alpha,
+        ) = ctx.saved_tensors
+        # Weight gradient is only valid when [0, alpha]
+        # Actual gradient for alpha:
+        # by applying the chain rule, we get dL / dy_q * dy_q / dy * dy / dalpha
+        # dL / dy_q = argument, dy_q / dy * dy / dalpha = 0, 1 with x value range
+        lower_bound = x < 0
+        upper_bound = x > alpha
+        x_range = ~(lower_bound | upper_bound)
+        grad_alpha = torch.sum(dLdy_q * torch.ge(x, alpha).float()).view(-1)
+        return dLdy_q * x_range.float(), grad_alpha, None
+
+
+def get_activation(name="silu", inplace=True):
+    name = name.lower()
+    if name == "silu":
+        module = nn.SiLU(inplace=inplace)
+    elif name == "relu":
+        module = nn.ReLU(inplace=inplace)
+    elif name == "lrelu":
+        module = nn.LeakyReLU(0.1, inplace=inplace)
+    elif name == "pact":
+        module = Pact.apply
+    elif name == "pactfix":
+        module = PactFix()
+    elif name == "relu6":
+        module = nn.ReLU6()
+    else:
+        raise AttributeError("Unsupported act type: {}".format(name))
+    return module
+
+
+K = 2
+
+
+class BaseConv(nn.Module):
+    """A Conv2d -> Batchnorm -> silu/leaky relu block"""
+
+    def __init__(self, in_channels, out_channels, ksize, stride, groups=1, bias=False, act="silu"):
+        super().__init__()
+        # same padding
+        pad = (ksize - 1) // 2
+        self.conv = nn.Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size=ksize,
+            stride=stride,
+            padding=pad,
+            groups=groups,
+            bias=bias,
+        )
+        self.bn = nn.BatchNorm2d(out_channels)
+        self.act = get_activation(act, inplace=True)
+        self.act_name = act
+        if self.act_name == "pact":
+            self.alpha = nn.Parameter(torch.tensor(20.0))
+
+    def forward(self, x):
+        if self.act_name == "pact":
+            return self.act(self.bn(self.conv(x)), self.alpha, 2)
+        else:
+            return self.act(self.bn(self.conv(x)))
+
+    def fuseforward(self, x):
+        return self.act(self.conv(x))
+
+
+class DWConv(nn.Module):
+    """Depthwise Conv + Conv"""
+
+    def __init__(self, in_channels, out_channels, ksize, stride=1, act="silu"):
+        super().__init__()
+        self.dconv = BaseConv(
+            in_channels,
+            in_channels,
+            ksize=ksize,
+            stride=stride,
groups=in_channels, + act=act, + ) + self.pconv = BaseConv(in_channels, out_channels, ksize=1, stride=1, groups=1, act=act) + + def forward(self, x): + x = self.dconv(x) + return self.pconv(x) + + +class Bottleneck(nn.Module): + # Standard bottleneck + def __init__(self, in_channels, out_channels, shortcut=True, expansion=0.5, depthwise=False, act="silu", kernel=3): + super().__init__() + hidden_channels = int(out_channels * expansion) + Conv = DWConv if depthwise else BaseConv + self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act) + self.conv2 = Conv(hidden_channels, out_channels, kernel, stride=1, act=act) + self.use_add = shortcut and in_channels == out_channels + + def forward(self, x): + y = self.conv2(self.conv1(x)) + if self.use_add: + y = y + x + return y + + +class BottleneckV8(nn.Module): + # Standard bottleneck + def __init__( + self, + in_channels, + out_channels, + shortcut=True, + expansion=0.5, + depthwise=False, + act="silu", + kernel=3, + ): + super().__init__() + hidden_channels = int(out_channels * expansion) + Conv = DWConv if depthwise else BaseConv + self.conv1 = BaseConv(in_channels, hidden_channels, kernel, stride=1, act=act) + self.conv2 = Conv(hidden_channels, out_channels, kernel, stride=1, act=act) + self.use_add = shortcut and in_channels == out_channels + + def forward(self, x): + y = self.conv2(self.conv1(x)) + if self.use_add: + y = y + x + return y + + +class Bottleneck_EFF(nn.Module): + # Standard bottleneck + def __init__( + self, + in_channels, + out_channels, + shortcut=True, + expansion=0.5, + depthwise=False, + act="silu", + kernel=3, + ): + super().__init__() + hidden_channels = int(out_channels * expansion) + Conv = DWConv if depthwise else BaseConv + self.conv1 = BaseConv(in_channels, hidden_channels, kernel, stride=1, act=act) + self.conv2 = Conv(hidden_channels, out_channels, 5, stride=1, act=act) + self.use_add = shortcut and in_channels == out_channels + + def forward(self, x): + y = self.conv2(self.conv1(x)) + if self.use_add: + y = y + x + return y + + +class ResLayer(nn.Module): + "Residual layer with `in_channels` inputs." + + def __init__(self, in_channels: int): + super().__init__() + mid_channels = in_channels // 2 + self.layer1 = BaseConv(in_channels, mid_channels, ksize=1, stride=1, act="lrelu") + self.layer2 = BaseConv(mid_channels, in_channels, ksize=3, stride=1, act="lrelu") + + def forward(self, x): + out = self.layer2(self.layer1(x)) + return x + out + + +class SPPBottleneck(nn.Module): + """Spatial pyramid pooling layer used in YOLOv3-SPP""" + + def __init__(self, in_channels, out_channels, kernel_sizes=(5, 9, 13), activation="silu"): + super().__init__() + hidden_channels = in_channels // 2 + self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=activation) + self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2) for ks in kernel_sizes]) + conv2_channels = hidden_channels * (len(kernel_sizes) + 1) + self.conv2 = BaseConv(conv2_channels, out_channels, 1, stride=1, act=activation) + + def forward(self, x): + x = self.conv1(x) + x = torch.cat([x] + [m(x) for m in self.m], dim=1) + x = self.conv2(x) + return x + + +class CSPLayer(nn.Module): + """C3 in yolov5, CSP Bottleneck with 3 convolutions""" + + def __init__( + self, + in_channels, + out_channels, + n=1, + shortcut=True, + expansion=0.5, + depthwise=False, + act="silu", + elan=False, + kernel=3, + ): + """ + Args: + in_channels (int): input channels. + out_channels (int): output channels. 
+ n (int): number of Bottlenecks. Default value: 1. + """ + # ch_in, ch_out, number, shortcut, groups, expansion + super().__init__() + hidden_channels = int(out_channels * expansion) # hidden channels + self.conv1 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act) + self.conv2 = BaseConv(in_channels, hidden_channels, 1, stride=1, act=act) + self.elan = elan + + if self.elan == True: + self.conv3 = BaseConv((n + 1) * hidden_channels, out_channels, 1, stride=1, act=act) + module_list = [ + BottleneckV8(hidden_channels, hidden_channels, shortcut, 1.0, depthwise, act=act, kernel=kernel) + for _ in range(n) + ] + else: + self.conv3 = BaseConv(2 * hidden_channels, out_channels, 1, stride=1, act=act) + module_list = [ + Bottleneck(hidden_channels, hidden_channels, shortcut, 1.0, depthwise, act=act, kernel=kernel) + for _ in range(n) + ] + self.m = nn.Sequential(*module_list) + + def forward(self, x): + x_1 = self.conv1(x) + x_2 = self.conv2(x) + el = [] + if self.elan == True: + x = x_1 + for m in self.m: + x = m(x) + el.append(x) + x = torch.cat([x_2] + [m for m in el], dim=1) + else: + x_1 = self.m(x_1) + x = torch.cat((x_1, x_2), dim=1) + return self.conv3(x) + + +class CSPLayer_EFF(nn.Module): + """C3 in yolov5, CSP Bottleneck with 3 convolutions""" + + def __init__( + self, + in_channels, + out_channels, + n=1, + shortcut=True, + # expansion=0.5, + expansion=1.0, + depthwise=False, + act="silu", + kernel=3, + elan=False, + ): + """ + Args: + in_channels (int): input channels. + out_channels (int): output channels. + n (int): number of Bottlenecks. Default value: 1. + """ + # ch_in, ch_out, number, shortcut, groups, expansion + super().__init__() + hidden_channels = int(out_channels * expansion) # hidden channels + self.conv1 = BaseConv(in_channels, hidden_channels, 3, stride=1, act=act) + self.conv2 = BaseConv(in_channels, hidden_channels, 3, stride=1, act=act) + self.elan = elan + + if self.elan == True: + self.conv3 = BaseConv((n + 1) * hidden_channels, out_channels, 1, stride=1, act=act) + module_list = [ + BottleneckV8(hidden_channels, hidden_channels, shortcut, 1.0, depthwise, act=act) for _ in range(n) + ] + else: + self.conv3 = BaseConv(2 * hidden_channels, out_channels, 3, stride=1, act=act) + + module_list = [ + Bottleneck_EFF(hidden_channels, hidden_channels, shortcut, 1.0, depthwise, act=act, kernel=kernel) + for _ in range(n) + ] + self.m = nn.Sequential(*module_list) + + def forward(self, x): + x_1 = self.conv1(x) + x_2 = self.conv2(x) + el = [] + if self.elan == True: + x = x_1 + for m in self.m: + x = m(x) + el.append(x) + x = torch.cat([x_2] + [m for m in el], dim=1) + else: + x_1 = self.m(x_1) + x = torch.cat((x_1, x_2), dim=1) + return self.conv3(x) + + +class ELAN(nn.Module): + """C3 in yolov5, CSP Bottleneck with 3 convolutions""" + + def __init__( + self, + in_channels, + out_channels, + n=3, + shortcut=True, + # expansion=0.5, + expansion=1.0, + depthwise=False, + act="silu", + kernel=3, + ): + """ + Args: + in_channels (int): input channels. + out_channels (int): output channels. + n (int): number of Bottlenecks. Default value: 1. 
+ """ + # ch_in, ch_out, number, shortcut, groups, expansion + super().__init__() + hidden_channels = int(out_channels * expansion) # hidden channels + self.conv1 = BaseConv(hidden_channels, out_channels, 3, stride=1, act=act) + self.conv_c = BaseConv(in_channels, hidden_channels, kernel, stride=1, act=act) + self.conv2 = BaseConv((n + 1) * hidden_channels, out_channels, 1, stride=1, act=act) + + module_list = [self.conv_c for _ in range(n)] + self.m = nn.Sequential(*module_list) + + def forward(self, x): + x_1 = self.conv1(x) + el = [] + for m in self.m: + x = m(x) + el.append(x) + x = torch.cat([x_1] + [m for m in el], dim=1) + return self.conv2(x) + + +class Focus(nn.Module): + """Focus width and height information into channel space.""" + + def __init__(self, in_channels, out_channels, ksize=1, stride=1, act="silu"): + super().__init__() + self.conv = BaseConv(in_channels * 4, out_channels, ksize, stride, act=act) + + def forward(self, x): + # shape of x (b,c,w,h) -> y(b,4c,w/2,h/2) + patch_top_left = x[..., ::2, ::2] + patch_top_right = x[..., ::2, 1::2] + patch_bot_left = x[..., 1::2, ::2] + patch_bot_right = x[..., 1::2, 1::2] + x = torch.cat( + ( + patch_top_left, + patch_bot_left, + patch_top_right, + patch_bot_right, + ), + dim=1, + ) + return self.conv(x) + + +class SimpleStem(nn.Module): + """Simple Stem for Acceleration on Embedded Devices""" + + def __init__(self, in_channels, out_channels, ksize=1, stride=1, act="silu"): + super().__init__() + # self.conv1 = BaseConv(in_channels, out_channels, ksize, stride, act=act) + # self.down1 = BaseConv(out_channels, out_channels, ksize, 2, act=act) + self.down1 = BaseConv(in_channels, out_channels, ksize, 2, act=act) + self.conv2 = BaseConv(out_channels, out_channels, ksize, stride, act=act) + # self.down2 = BaseConv(out_channels, out_channels, ksize, 2, act=act) + + def forward(self, x): + # x = self.conv1(x) + x = self.down1(x) + x = self.conv2(x) + # x = self.down2(x) + return x diff --git a/projects/YOLOX_opt_elan/yolox/models/yolox_multitask.py b/projects/YOLOX_opt_elan/yolox/models/yolox_multitask.py new file mode 100644 index 000000000..ac52e9cf1 --- /dev/null +++ b/projects/YOLOX_opt_elan/yolox/models/yolox_multitask.py @@ -0,0 +1,186 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmdet.models import BaseDetector +from mmdet.registry import MODELS +from mmdet.structures import DetDataSample, SampleList +from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig +from mmengine.logging import print_log +from mmengine.model import BaseModule +from mmengine.structures import InstanceData, PixelData +from torch import Tensor + +from .heads import YOLOXSegHead + + +@MODELS.register_module() +class YOLOXMultiTask(BaseDetector): + """ + YOLOX MultiTask detector + Supports bbox + mask heads. 
+ """ + + def __init__( + self, + backbone, + neck, + bbox_head, + mask_head=None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor=None, + init_cfg=None, + **kwargs, + ): + super().__init__(init_cfg=init_cfg) + self.backbone = MODELS.build(backbone) + self.neck = MODELS.build(neck) if neck is not None else None + if bbox_head is not None: + bbox_head.update(train_cfg=train_cfg) + bbox_head.update(test_cfg=test_cfg) + self.bbox_head = MODELS.build(bbox_head) + if mask_head is not None: + mask_head.update(train_cfg=train_cfg) + mask_head.update(test_cfg=test_cfg) + self.mask_head = MODELS.build(mask_head) if mask_head is not None else None + self.data_preprocessor = MODELS.build(data_preprocessor) if data_preprocessor else None + + def extract_feat(self, inputs): + x = self.backbone(inputs) + if self.neck is not None: + x = self.neck(x) + return x + + def _forward(self, imgs, **kwargs): + return self.forward(imgs, **kwargs) + + def forward_train(self, imgs, gt_bboxes, gt_labels, gt_masks=None, **kwargs): + feats = self.extract_feat(imgs) + losses = dict() + losses.update(self.bbox_head.loss(feats[-1], gt_bboxes, gt_labels)) + if self.mask_head is not None and gt_masks is not None: + mask_pred = self.mask_head(feats) + losses.update(self.mask_head.loss(mask_pred, gt_masks)) + return losses + + def forward_test(self, imgs, **kwargs): + feats = self.extract_feat(imgs) + bbox_results = self.bbox_head(feats[-1]) + mask_results = None + if self.mask_head is not None: + mask_results = self.mask_head(feats) + return dict(bboxes=bbox_results, masks=mask_results) + + def forward(self, inputs, data_samples=None, mode="tensor"): + """Forward function with training and testing mode.""" + feats = self.extract_feat(inputs) + + if mode == "tensor": + return self.bbox_head(feats) + elif mode == "loss": + s = self.loss(feats, data_samples) + return s + elif mode == "predict": + pred_instances = self.predict(inputs, data_samples) + + for pred, data_sample in zip(pred_instances, data_samples): + pred.gt_instances = data_sample.gt_instances + if hasattr(data_sample, "gt_sem_seg"): + pred.gt_sem_seg = data_sample.gt_sem_seg + + return pred_instances + else: + raise ValueError(f"Invalid mode {mode}") + + def loss(self, feats, data_samples): + loss = dict() + # bbox head forward + cls_scores, bbox_preds, objectnesses = self.bbox_head(feats) + batch_gt_instances = [d.gt_instances for d in data_samples] + batch_img_metas = [d.metainfo for d in data_samples] + + loss.update( + self.bbox_head.loss_by_feat(cls_scores, bbox_preds, objectnesses, batch_gt_instances, batch_img_metas) + ) + + # mask head + if self.mask_head is not None: + seg_pred = self.mask_head(feats) + target_size = data_samples[0].gt_sem_seg.sem_seg.shape[-2:] + if seg_pred.shape[-2:] != target_size: + seg_pred = F.interpolate(seg_pred, size=target_size, mode="bilinear", align_corners=False) + + gt_masks_tensor = [] + gt_masks = torch.stack([d.gt_sem_seg.sem_seg.squeeze(0) for d in data_samples], dim=0) # (B, H, W) + gt_masks = gt_masks.to(seg_pred.device) + + mask_loss_dict = self.mask_head.loss(seg_pred, gt_masks) + for k, v in mask_loss_dict.items(): + if torch.is_tensor(v): + loss[k] = v + else: + raise TypeError(f"mask loss '{k}' is not a tensor") + + return loss + + def predict( + self, batch_inputs: Tensor, batch_data_samples: SampleList, rescale: bool = True, **kwargs + ) -> SampleList: + + x = self.extract_feat(batch_inputs) + + if self.with_bbox: + bbox_results_list = self.bbox_head.predict(x, 
batch_data_samples, rescale=True) + else: + bbox_results_list = [InstanceData() for _ in batch_data_samples] + + seg_results_list = None + if self.with_mask: + seg_results_list = self.mask_head.predict(x, batch_data_samples, rescale=True) + + results = [] + for i, data_sample in enumerate(batch_data_samples): + data_sample.pred_instances = bbox_results_list[i] + + if seg_results_list is not None: + pixel_data = PixelData() + pixel_data.data = seg_results_list[i] + pixel_data.sem_seg = seg_results_list[i] + data_sample.pred_sem_seg = pixel_data + + img_h, img_w = data_sample.metainfo["img_shape"] + ori_h, ori_w = data_sample.metainfo["ori_shape"] + + if hasattr(data_sample, "gt_instances"): + + scale_factor = data_sample.metainfo["scale_factor"] # (w_scale, h_scale) + + scale_factor_bbox = [scale_factor[0], scale_factor[1], scale_factor[0], scale_factor[1]] + scale_tensor = data_sample.gt_instances.bboxes.new_tensor(scale_factor_bbox) + + data_sample.gt_instances.bboxes = data_sample.gt_instances.bboxes / scale_tensor + + if hasattr(data_sample, "gt_sem_seg") and data_sample.gt_sem_seg is not None: + gt_sem_seg_data = data_sample.gt_sem_seg.sem_seg # [H_pad, W_pad] + + gt_valid = gt_sem_seg_data[..., :img_h, :img_w] + + if gt_valid.shape[-2:] != (ori_h, ori_w): + gt_resized = ( + F.interpolate( + gt_valid.unsqueeze(0).float(), size=(ori_h, ori_w), mode="nearest" # [1, 1, h, w] + ) + .squeeze(0) + .long() + ) + + new_gt_pixel_data = PixelData() + new_gt_pixel_data.sem_seg = gt_resized + new_gt_pixel_data.data = gt_resized + data_sample.gt_sem_seg = new_gt_pixel_data + elif "data" not in data_sample.gt_sem_seg: + data_sample.gt_sem_seg.data = data_sample.gt_sem_seg.sem_seg + + results.append(data_sample) + + return results diff --git a/projects/YOLOX_opt_elan/yolox/transforms.py b/projects/YOLOX_opt_elan/yolox/transforms.py new file mode 100644 index 000000000..3aad09cff --- /dev/null +++ b/projects/YOLOX_opt_elan/yolox/transforms.py @@ -0,0 +1,39 @@ +import os.path as osp + +import torch +import torch.nn.functional as F +from mmcv.transforms import BaseTransform +from mmdet.registry import TRANSFORMS + + +@TRANSFORMS.register_module() +class ResizeSegMask: + def __init__(self, size): + self.size = size # (H_out, W_out) + + def __call__(self, results): + if "gt_seg_map" in results and results["gt_seg_map"] is not None: + seg = results["gt_seg_map"] # numpy array (H, W) + seg = torch.from_numpy(seg).unsqueeze(0).unsqueeze(0).float() # (1,1,H,W) + seg = F.interpolate(seg, size=self.size, mode="nearest") + results["gt_seg_map"] = seg.squeeze(0).squeeze(0).long().numpy() # back to (H_out,W_out) + return results + + +@TRANSFORMS.register_module() +class FixCityscapesPath(BaseTransform): + def __init__(self, data_root, split="train"): + self.data_root = data_root + self.split = split + + def transform(self, results): + img_path = results["img_path"] + filename = osp.basename(img_path) + + seg_filename = filename.replace("_leftImg8bit.png", "_gtFine_labelTrainIds.png") + city = filename.split("_")[0] + seg_path = osp.join(self.data_root, "gtFine", self.split, city, seg_filename) + + results["seg_map_path"] = seg_path + + return results
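+
+
+if __name__ == "__main__":
+    # Minimal usage sketch of FixCityscapesPath; the file name below is
+    # illustrative, not taken from the dataset.
+    fix = FixCityscapesPath(data_root="data/cityscapes/", split="val")
+    sample = {"img_path": "leftImg8bit/val/frankfurt/frankfurt_000000_000294_leftImg8bit.png"}
+    # Expected: data/cityscapes/gtFine/val/frankfurt/frankfurt_000000_000294_gtFine_labelTrainIds.png
+    print(fix.transform(sample)["seg_map_path"])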