diff --git a/Train.ipynb b/Train.ipynb
new file mode 100644
index 000000000000..78ca6a8a8306
--- /dev/null
+++ b/Train.ipynb
@@ -0,0 +1,35 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4fbd6fa3-a31c-4ba1-a917-6669fafcd904",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+    "!python -m torch.distributed.launch --nproc_per_node 2 train.py --img ",
+    "--batch --epochs --data --hyp data/hyps/hyp.scratch-low.yaml --weights yolov5m --cache --device 0,1 --name "
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.10"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/data/hyps/hyp.Objects365.yaml b/data/hyps/hyp.Objects365.yaml
index 74971740f7c7..ce9fd9c5f1f8 100644
--- a/data/hyps/hyp.Objects365.yaml
+++ b/data/hyps/hyp.Objects365.yaml
@@ -32,3 +32,10 @@ fliplr: 0.5
mosaic: 1.0
mixup: 0.0
copy_paste: 0.0
+
+# Custom Params
+area_threshold: 0.3 # area threshold for random perspective
+maximum_mistakes_size: 7680 # image size for mistakes mosaic
+maximum_mistakes_subplots: 64 # subimages in mistakes mosaic
+minimum_mistakes_iou: 0.5 # mistakes minimum iou for match
+minimum_mistakes_confidence: 0.5 # mistakes minimum confidence for checking
\ No newline at end of file
diff --git a/data/hyps/hyp.scratch-high.yaml b/data/hyps/hyp.scratch-high.yaml
index 123cc8407413..7df7fc548586 100644
--- a/data/hyps/hyp.scratch-high.yaml
+++ b/data/hyps/hyp.scratch-high.yaml
@@ -32,3 +32,10 @@ fliplr: 0.5 # image flip left-right (probability)
mosaic: 1.0 # image mosaic (probability)
mixup: 0.1 # image mixup (probability)
copy_paste: 0.1 # segment copy-paste (probability)
+
+# Custom Params
+area_threshold: 0.3 # area threshold for random perspective
+maximum_mistakes_size: 7680 # image size for mistakes mosaic
+maximum_mistakes_subplots: 64 # subimages in mistakes mosaic
+minimum_mistakes_iou: 0.5 # mistakes minimum iou for match
+minimum_mistakes_confidence: 0.5 # mistakes minimum confidence for checking
diff --git a/data/hyps/hyp.scratch-low.yaml b/data/hyps/hyp.scratch-low.yaml
index b9ef1d55a3b6..c3f02ddd5ccd 100644
--- a/data/hyps/hyp.scratch-low.yaml
+++ b/data/hyps/hyp.scratch-low.yaml
@@ -32,3 +32,10 @@ fliplr: 0.5 # image flip left-right (probability)
mosaic: 1.0 # image mosaic (probability)
mixup: 0.0 # image mixup (probability)
copy_paste: 0.0 # segment copy-paste (probability)
+
+# Custom Params
+area_threshold: 0.3 # area threshold for random perspective
+maximum_mistakes_size: 7680 # image size for mistakes mosaic
+maximum_mistakes_subplots: 64 # subimages in mistakes mosaic
+minimum_mistakes_iou: 0.5 # mistakes minimum iou for match
+minimum_mistakes_confidence: 0.5 # mistakes minimum confidence for checking
diff --git a/data/hyps/hyp.scratch-med.yaml b/data/hyps/hyp.scratch-med.yaml
index d6867d7557ba..5b1fd348cb6c 100644
--- a/data/hyps/hyp.scratch-med.yaml
+++ b/data/hyps/hyp.scratch-med.yaml
@@ -32,3 +32,10 @@ fliplr: 0.5 # image flip left-right (probability)
mosaic: 1.0 # image mosaic (probability)
mixup: 0.1 # image mixup (probability)
copy_paste: 0.0 # segment copy-paste (probability)
+
+# Custom Params
+area_threshold: 0.3 # area threshold for random perspective
+maximum_mistakes_size: 7680 # image size for mistakes mosaic
+maximum_mistakes_subplots: 64 # subimages in mistakes mosaic
+minimum_mistakes_iou: 0.5 # mistakes minimum iou for match
+minimum_mistakes_confidence: 0.5 # mistakes minimum confidence for checking
\ No newline at end of file
diff --git a/models/yolo.py b/models/yolo.py
index ed21c067ee93..c1328a0fd0a8 100644
--- a/models/yolo.py
+++ b/models/yolo.py
@@ -164,7 +164,7 @@ def _apply(self, fn):
class DetectionModel(BaseModel):
# YOLOv5 detection model
- def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None): # model, input channels, number of classes
+ def __init__(self, cfg='yolov5s.yaml', ch=1, nc=None, anchors=None): # model, input channels, number of classes
super().__init__()
if isinstance(cfg, dict):
self.yaml = cfg # model dict
@@ -188,6 +188,9 @@ def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None): # model, i
# Build strides, anchors
m = self.model[-1] # Detect()
+
+ ch = 1
+
if isinstance(m, (Detect, Segment)):
s = 256 # 2x min stride
m.inplace = self.inplace
@@ -298,11 +301,13 @@ def _from_yaml(self, cfg):
def parse_model(d, ch): # model_dict, input_channels(3)
# Parse a YOLOv5 model.yaml dictionary
+    ch = [1]  # force single-channel (grayscale) input: override whatever channel list the caller passed
LOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}")
anchors, nc, gd, gw, act = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation')
if act:
Conv.default_act = eval(act) # redefine default activation, i.e. Conv.default_act = nn.SiLU()
LOGGER.info(f"{colorstr('activation:')} {act}") # print
+
na = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchors
no = na * (nc + 5) # number of outputs = anchors * (classes + 5)
diff --git a/models/yolov5m.yaml b/models/yolov5m.yaml
index ad13ab370ff6..06c0e0fbf5ad 100644
--- a/models/yolov5m.yaml
+++ b/models/yolov5m.yaml
@@ -46,3 +46,4 @@ head:
[[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
]
+ch: 1
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 85eb839df8a0..11a728b19a0a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,8 +12,8 @@ psutil # system resources
PyYAML>=5.3.1
requests>=2.23.0
scipy>=1.4.1
-thop>=0.1.1 # FLOPs computation
-torch>=1.7.0 # see https://pytorch.org/get-started/locally (recommended)
+thop>=0.1.1 # FLOPs computation
+torch>=1.7.0
torchvision>=0.8.1
tqdm>=4.64.0
# protobuf<=3.20.1 # https://github.com/ultralytics/yolov5/issues/8012
diff --git a/train.py b/train.py
index 8b5446e58f2d..da49c38d3a96 100644
--- a/train.py
+++ b/train.py
@@ -21,6 +21,7 @@
import random
import sys
import time
+import json
from copy import deepcopy
from datetime import datetime
from pathlib import Path
@@ -75,8 +76,7 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
# Directories
w = save_dir / 'weights' # weights dir
(w.parent if evolve else w).mkdir(parents=True, exist_ok=True) # make dir
- last, best = w / 'last.pt', w / 'best.pt'
-
+ last, temp_best = w / 'last.pt', w / 'best.pt'
# Hyperparameters
if isinstance(hyp, str):
with open(hyp, errors='ignore') as f:
@@ -109,7 +109,25 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
init_seeds(opt.seed + 1 + RANK, deterministic=True)
with torch_distributed_zero_first(LOCAL_RANK):
data_dict = data_dict or check_dataset(data) # check if None
- train_path, val_path = data_dict['train'], data_dict['val']
+ multi_val = False
+ best = []
+ best_fitnesses = []
+ validation_paths = []
+ train_path, val_paths = data_dict['train'], data_dict['val']
+ if type(val_paths) == list:
+ print('Detected {} validation sets'.format(len(val_paths)))
+ multi_val = True
+
+ for val_path in (val_paths if multi_val else [val_paths]):
+ temp_val_path = Path(val_path).stem.replace("Valid", "").replace("Yolo", "")
+
+ best.append(w.joinpath(f'best_{temp_val_path}.pt'))
+ temp_val_path = Path(save_dir.joinpath('results_' + temp_val_path))
+ temp_val_path.mkdir(exist_ok=True)
+ temp_val_path.joinpath('classes').mkdir(exist_ok=True)
+ validation_paths.append(temp_val_path)
+ best_fitnesses.append(0)
+
nc = 1 if single_cls else int(data_dict['nc']) # number of classes
names = {0: 'item'} if single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
is_coco = isinstance(val_path, str) and val_path.endswith('coco/val2017.txt') # COCO dataset
@@ -120,8 +138,10 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
if pretrained:
with torch_distributed_zero_first(LOCAL_RANK):
weights = attempt_download(weights) # download if not found locally
+
ckpt = torch.load(weights, map_location='cpu') # load checkpoint to CPU to avoid CUDA memory leak
- model = Model(cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
+ model = Model(cfg or ckpt['model'].yaml, ch=1, nc=nc, anchors=hyp.get('anchors')).to(device) # create
+
exclude = ['anchor'] if (cfg or hyp.get('anchors')) and not resume else [] # exclude keys
csd = ckpt['model'].float().state_dict() # checkpoint state_dict as FP32
csd = intersect_dicts(csd, model.state_dict(), exclude=exclude) # intersect
@@ -130,7 +150,8 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
else:
model = Model(cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
amp = check_amp(model) # check AMP
-
+    amp = True  # NOTE(review): forces AMP on, overriding check_amp's hardware check — confirm this is intended
+
# Freeze
freeze = [f'model.{x}.' for x in (freeze if len(freeze) > 1 else range(freeze[0]))] # layers to freeze
for k, v in model.named_parameters():
@@ -198,26 +219,43 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
image_weights=opt.image_weights,
quad=opt.quad,
prefix=colorstr('train: '),
- shuffle=True)
+ shuffle=True,
+ rgb_mode=opt.rgb_mode)
+
+ if dataset.albumentations.transform is not None:
+        augmentations = {str(number+1):str(transform) for number,transform in enumerate(dataset.albumentations.transform.transforms) if transform.p}
+        augmentations_to_copy = 'self.transform = A.Compose(['
+        for i, transform in enumerate(augmentations.values()):
+            augmentations_to_copy += 'A.' + transform + ','
+        augmentations_to_copy = augmentations_to_copy[:-1]
+        augmentations_to_copy += '], bbox_params=A.BboxParams(format=\'yolo\', label_fields=[\'class_labels\']))'
+        augmentations['Whole composition (to copy)'] = augmentations_to_copy
+
+        # Save augmentations
+        with open(save_dir / "augmentations.json", "w") as file:
+            json.dump(augmentations, file)
+
labels = np.concatenate(dataset.labels, 0)
mlc = int(labels[:, 0].max()) # max label class
assert mlc < nc, f'Label class {mlc} exceeds nc={nc} in {data}. Possible class labels are 0-{nc - 1}'
# Process 0
- if RANK in {-1, 0}:
- val_loader = create_dataloader(val_path,
- imgsz,
- batch_size // WORLD_SIZE * 2,
- gs,
- single_cls,
- hyp=hyp,
- cache=None if noval else opt.cache,
- rect=True,
- rank=-1,
- workers=workers * 2,
- pad=0.5,
- prefix=colorstr('val: '))[0]
-
+ if RANK in [-1, 0]:
+ val_loaders = []
+ for val_path in (val_paths if multi_val else [val_paths]):
+ val_loaders.append(create_dataloader(val_path,
+ imgsz,
+ batch_size // WORLD_SIZE * 2,
+ gs,
+ single_cls,
+ hyp=hyp,
+ cache=None if noval else opt.cache,
+ rect=True,
+ rank=-1,
+ workers=workers,
+ pad=0.5,
+ prefix=colorstr('val: '),
+ rgb_mode=opt.rgb_mode)[0])
if not resume:
if not opt.noautoanchor:
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz) # run AutoAnchor
@@ -346,48 +384,58 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
callbacks.run('on_train_epoch_end', epoch=epoch)
ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
final_epoch = (epoch + 1 == epochs) or stopper.possible_stop
- if not noval or final_epoch: # Calculate mAP
- results, maps, _ = validate.run(data_dict,
- batch_size=batch_size // WORLD_SIZE * 2,
- imgsz=imgsz,
- half=amp,
- model=ema.ema,
- single_cls=single_cls,
- dataloader=val_loader,
- save_dir=save_dir,
- plots=False,
- callbacks=callbacks,
- compute_loss=compute_loss)
-
- # Update best mAP
- fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
- stop = stopper(epoch=epoch, fitness=fi) # early stop check
- if fi > best_fitness:
- best_fitness = fi
- log_vals = list(mloss) + list(results) + lr
- callbacks.run('on_fit_epoch_end', log_vals, epoch, best_fitness, fi)
-
- # Save model
- if (not nosave) or (final_epoch and not evolve): # if save
- ckpt = {
- 'epoch': epoch,
- 'best_fitness': best_fitness,
- 'model': deepcopy(de_parallel(model)).half(),
- 'ema': deepcopy(ema.ema).half(),
- 'updates': ema.updates,
- 'optimizer': optimizer.state_dict(),
- 'opt': vars(opt),
- 'git': GIT_INFO, # {remote, branch, commit} if a git repo
- 'date': datetime.now().isoformat()}
-
- # Save last, best and delete
- torch.save(ckpt, last)
- if best_fitness == fi:
- torch.save(ckpt, best)
- if opt.save_period > 0 and epoch % opt.save_period == 0:
- torch.save(ckpt, w / f'epoch{epoch}.pt')
- del ckpt
- callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi)
+
+ for val_index, val_loader in enumerate(val_loaders):
+ if not noval or final_epoch: # Calculate mAP
+ results, maps, _ = validate.run(data_dict,
+ batch_size=batch_size // WORLD_SIZE * 2,
+ imgsz=imgsz,
+ model=ema.ema,
+ single_cls=single_cls,
+ dataloader=val_loader,
+ save_dir=validation_paths[val_index],
+ plots=False,
+ callbacks=callbacks,
+ compute_loss=compute_loss,
+ rgb_mode=opt.rgb_mode,
+ maximum_mistakes_size=hyp['maximum_mistakes_size'],
+ maximum_mistakes_subplots=hyp['maximum_mistakes_subplots'],
+ minimum_mistakes_iou=hyp['minimum_mistakes_iou'],
+ minimum_mistakes_confidence=hyp['minimum_mistakes_confidence']
+ )
+
+ # Update best mAP
+ fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
+ stop = stopper(epoch=epoch, fitness=fi) # early stop check
+
+ if fi > best_fitnesses[val_index]:
+ best_fitnesses[val_index] = fi
+ if fi > best_fitness:
+ best_fitness = fi
+ log_vals = list(mloss) + list(results) + lr
+ callbacks.run('on_fit_epoch_end', log_vals, epoch, best_fitnesses[val_index], fi)
+
+ # Save model
+ if (not nosave) or (final_epoch and not evolve): # if save
+ ckpt = {
+ 'epoch': epoch,
+ 'best_fitness': best_fitness,
+ 'model': deepcopy(de_parallel(model)).half(),
+ 'ema': deepcopy(ema.ema).half(),
+ 'updates': ema.updates,
+ 'optimizer': optimizer.state_dict(),
+ 'opt': vars(opt),
+ 'git': GIT_INFO, # {remote, branch, commit} if a git repo
+ 'date': datetime.now().isoformat()}
+
+ # Save last, best and delete
+ torch.save(ckpt, last)
+ if best_fitnesses[val_index] == fi:
+ torch.save(ckpt, best[val_index])
+ if (epoch > 0) and (opt.save_period > 0) and (epoch % opt.save_period == 0):
+ torch.save(ckpt, w / f'epoch{epoch}.pt')
+ del ckpt
+ callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi)
# EarlyStopping
if RANK != -1: # if DDP training
@@ -402,29 +450,36 @@ def train(hyp, opt, device, callbacks): # hyp is path/to/hyp.yaml or hyp dictio
# end training -----------------------------------------------------------------------------------------------------
if RANK in {-1, 0}:
LOGGER.info(f'\n{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.')
- for f in last, best:
- if f.exists():
- strip_optimizer(f) # strip optimizers
- if f is best:
- LOGGER.info(f'\nValidating {f}...')
- results, _, _ = validate.run(
- data_dict,
- batch_size=batch_size // WORLD_SIZE * 2,
- imgsz=imgsz,
- model=attempt_load(f, device).half(),
- iou_thres=0.65 if is_coco else 0.60, # best pycocotools at iou 0.65
- single_cls=single_cls,
- dataloader=val_loader,
- save_dir=save_dir,
- save_json=is_coco,
- verbose=True,
- plots=plots,
- callbacks=callbacks,
- compute_loss=compute_loss) # val best model with plots
- if is_coco:
- callbacks.run('on_fit_epoch_end', list(mloss) + list(results) + lr, epoch, best_fitness, fi)
-
- callbacks.run('on_train_end', last, best, epoch, results)
+ for best_index, best_model in enumerate(best):
+ for f in last, best_model:
+ if f.exists():
+ strip_optimizer(f) # strip optimizers
+ if f is best_model:
+ LOGGER.info(f'\nValidating {f}...')
+ results, _, _ = validate.run(data_dict,
+ batch_size=batch_size // WORLD_SIZE * 2,
+ imgsz=imgsz,
+ model=attempt_load(f, device).half(),
+ iou_thres=0.65 if is_coco else 0.60, # best pycocotools results at 0.65
+ single_cls=single_cls,
+ dataloader=val_loaders[best_index],
+ save_dir=validation_paths[best_index],
+ save_json=is_coco,
+ verbose=True,
+ plots=plots,
+ callbacks=callbacks,
+ compute_loss=compute_loss,
+ rgb_mode=opt.rgb_mode,
+ maximum_mistakes_size=hyp['maximum_mistakes_size'],
+ maximum_mistakes_subplots=hyp['maximum_mistakes_subplots'],
+ minimum_mistakes_iou=hyp['minimum_mistakes_iou'],
+ minimum_mistakes_confidence=hyp['minimum_mistakes_confidence']
+ ) # val best model with plots
+ if is_coco:
+ callbacks.run('on_fit_epoch_end', list(mloss) + list(results) + lr, epoch, best_fitness, fi)
+
+ callbacks.run('on_train_end', last, best_model, epoch, results)
+ LOGGER.info(f"Results saved to {colorstr('bold', validation_paths[best_index])}")
torch.cuda.empty_cache()
return results
@@ -466,6 +521,7 @@ def parse_opt(known=False):
parser.add_argument('--save-period', type=int, default=-1, help='Save checkpoint every x epochs (disabled if < 1)')
parser.add_argument('--seed', type=int, default=0, help='Global training seed')
parser.add_argument('--local_rank', type=int, default=-1, help='Automatic DDP Multi-GPU argument, do not modify')
+ parser.add_argument('--rgb-mode', action='store_true', help='train model in rgb mode, with image_channels=3.')
# Logger arguments
parser.add_argument('--entity', default=None, help='Entity')
@@ -631,3 +687,4 @@ def run(**kwargs):
if __name__ == "__main__":
opt = parse_opt()
main(opt)
+
diff --git a/utils/augmentations.py b/utils/augmentations.py
index 1eae5db8f816..2a50d43f48d9 100644
--- a/utils/augmentations.py
+++ b/utils/augmentations.py
@@ -26,18 +26,18 @@ def __init__(self, size=640):
prefix = colorstr('albumentations: ')
try:
import albumentations as A
- check_version(A.__version__, '1.0.3', hard=True) # version requirement
-
- T = [
- A.RandomResizedCrop(height=size, width=size, scale=(0.8, 1.0), ratio=(0.9, 1.11), p=0.0),
- A.Blur(p=0.01),
- A.MedianBlur(p=0.01),
- A.ToGray(p=0.01),
- A.CLAHE(p=0.01),
- A.RandomBrightnessContrast(p=0.0),
- A.RandomGamma(p=0.0),
- A.ImageCompression(quality_lower=75, p=0.0)] # transforms
- self.transform = A.Compose(T, bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))
+ check_version(A.__version__, '1.0.3') # version requirement
+# self.transform = A.Compose([
+# A.RandomBrightnessContrast(always_apply=False, p=0.7, brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), brightness_by_max=True),
+# A.JpegCompression(always_apply=False, p=0.7, quality_lower=40, quality_upper=100),
+# A.GaussNoise(always_apply=False, p=0.6, mean=-21.0, var_limit=(40.0, 150)),
+# A.OneOf([
+# A.Blur(always_apply=False, p=0.5, blur_limit=(3, 5)),
+# A.MotionBlur(always_apply=False, p=0.5, blur_limit=(3, 5)),
+# ], p=0.7),
+# A.Cutout(always_apply=False, p=0.7, num_holes=32, max_h_size=8, max_w_size=8)
+# ], bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))
+            T = []; self.transform = A.Compose(T, bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))  # empty pipeline; T kept for the LOGGER line below
LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
except ImportError: # package not installed, skip
@@ -108,7 +108,7 @@ def replicate(im, labels):
return im, labels
-def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
+def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=False, scaleFill=False, scaleup=True, stride=32):
# Resize and pad image while meeting stride-multiple constraints
shape = im.shape[:2] # current shape [height, width]
if isinstance(new_shape, int):
@@ -149,7 +149,8 @@ def random_perspective(im,
scale=.1,
shear=10,
perspective=0.0,
- border=(0, 0)):
+ border=(0, 0),
+ area_threshold=0.01):
# torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(0.1, 0.1), scale=(0.9, 1.1), shear=(-10, 10))
# targets = [cls, xyxy]
@@ -230,7 +231,7 @@ def random_perspective(im,
new[:, [1, 3]] = new[:, [1, 3]].clip(0, height)
# filter candidates
- i = box_candidates(box1=targets[:, 1:5].T * s, box2=new.T, area_thr=0.01 if use_segments else 0.10)
+ i = box_candidates(box1=targets[:, 1:5].T * s, box2=new.T, area_thr=area_threshold if use_segments else area_threshold)
targets = targets[i]
targets[:, 1:5] = new[i]
diff --git a/utils/dataloaders.py b/utils/dataloaders.py
index 6d2b27ea5e60..38c93651de86 100644
--- a/utils/dataloaders.py
+++ b/utils/dataloaders.py
@@ -115,7 +115,8 @@ def create_dataloader(path,
image_weights=False,
quad=False,
prefix='',
- shuffle=False):
+ shuffle=False,
+ rgb_mode=False):
if rect and shuffle:
LOGGER.warning('WARNING ⚠️ --rect is incompatible with DataLoader shuffle, setting shuffle=False')
shuffle = False
@@ -132,7 +133,8 @@ def create_dataloader(path,
stride=int(stride),
pad=pad,
image_weights=image_weights,
- prefix=prefix)
+ prefix=prefix,
+ rgb_mode=rgb_mode)
batch_size = min(batch_size, len(dataset))
nd = torch.cuda.device_count() # number of CUDA devices
@@ -303,15 +305,18 @@ def __next__(self):
else:
# Read image
self.count += 1
- im0 = cv2.imread(path) # BGR
+            im0 = cv2.imread(path, cv2.IMREAD_GRAYSCALE)  # GRAY (single channel)
+            im0 = im0.reshape(im0.shape[0], im0.shape[1], 1)
assert im0 is not None, f'Image Not Found {path}'
s = f'image {self.count}/{self.nf} {path}: '
if self.transforms:
im = self.transforms(im0) # transforms
+ im = im.reshape(im.shape[0], im.shape[1], 1)
else:
im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0] # padded resize
- im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
+ # im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
+ im = im.reshape(im.shape[0], im.shape[1], 1) # GRAY
im = np.ascontiguousarray(im) # contiguous
return path, im, im0, self.cap, s
@@ -448,7 +453,8 @@ def __init__(self,
stride=32,
pad=0.0,
min_items=0,
- prefix=''):
+ prefix='',
+ rgb_mode=False):
self.img_size = img_size
self.augment = augment
self.hyp = hyp
@@ -459,6 +465,7 @@ def __init__(self,
self.stride = stride
self.path = path
self.albumentations = Albumentations(size=img_size) if augment else None
+ self.rgb_mode = rgb_mode
try:
f = [] # image files
@@ -590,7 +597,11 @@ def check_cache_ram(self, safety_margin=0.1, prefix=''):
b, gb = 0, 1 << 30 # bytes of cached images, bytes per gigabytes
n = min(self.n, 30) # extrapolate from 30 random images
for _ in range(n):
- im = cv2.imread(random.choice(self.im_files)) # sample image
+ if self.rgb_mode:
+ im = cv2.imread(random.choice(self.im_files))
+ else:
+ im = cv2.imread(random.choice(self.im_files), cv2.IMREAD_GRAYSCALE) # sample image
+ im = im.reshape(im.shape[0], im.shape[1], 1)
ratio = self.img_size / max(im.shape[0], im.shape[1]) # max(h, w) # ratio
b += im.nbytes * ratio ** 2
mem_required = b * self.n / n # GB required to cache dataset into RAM
@@ -683,7 +694,8 @@ def __getitem__(self, index):
translate=hyp['translate'],
scale=hyp['scale'],
shear=hyp['shear'],
- perspective=hyp['perspective'])
+ perspective=hyp['perspective'],
+ area_threshold=hyp['area_threshold'])
nl = len(labels) # number of labels
if nl:
@@ -695,7 +707,8 @@ def __getitem__(self, index):
nl = len(labels) # update after albumentations
# HSV color-space
- augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])
+ if self.rgb_mode:
+ augment_hsv(img, hgain=hyp['hsv_h'], sgain=hyp['hsv_s'], vgain=hyp['hsv_v'])
# Flip up-down
if random.random() < hyp['flipud']:
@@ -718,7 +731,10 @@ def __getitem__(self, index):
labels_out[:, 1:] = torch.from_numpy(labels)
# Convert
- img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
+ if self.rgb_mode:
+ img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
+ else:
+ img = img.reshape(1, img.shape[0], img.shape[1])
img = np.ascontiguousarray(img)
return torch.from_numpy(img), labels_out, self.im_files[index], shapes
@@ -730,13 +746,20 @@ def load_image(self, i):
if fn.exists(): # load npy
im = np.load(fn)
else: # read image
- im = cv2.imread(f) # BGR
+ if self.rgb_mode:
+ im = cv2.imread(f)
+ else:
+                im = cv2.imread(f, cv2.IMREAD_GRAYSCALE)  # GRAY (single channel)
+ im = im.reshape(im.shape[0],im.shape[1],1)
assert im is not None, f'Image Not Found {f}'
h0, w0 = im.shape[:2] # orig hw
r = self.img_size / max(h0, w0) # ratio
if r != 1: # if sizes are not equal
interp = cv2.INTER_LINEAR if (self.augment or r > 1) else cv2.INTER_AREA
im = cv2.resize(im, (int(w0 * r), int(h0 * r)), interpolation=interp)
+
+ if not self.rgb_mode:
+ im = im.reshape(im.shape[0],im.shape[1],1)
return im, (h0, w0), im.shape[:2] # im, hw_original, hw_resized
return self.ims[i], self.im_hw0[i], self.im_hw[i] # im, hw_original, hw_resized
@@ -744,7 +767,11 @@ def cache_images_to_disk(self, i):
# Saves an image as an *.npy file for faster loading
f = self.npy_files[i]
if not f.exists():
- np.save(f.as_posix(), cv2.imread(self.im_files[i]))
+ if self.rgb_mode:
+ np.save(f.as_posix(), cv2.imread(self.im_files[i]))
+ else:
+ image = cv2.imread(self.im_files[i],cv2.IMREAD_GRAYSCALE)
+ np.save(f.as_posix(), image.reshape(image.shape[0], image.shape[1], 1))
def load_mosaic(self, index):
# YOLOv5 4-mosaic loader. Loads 1 image + 3 random images into a 4-image mosaic
@@ -800,7 +827,8 @@ def load_mosaic(self, index):
scale=self.hyp['scale'],
shear=self.hyp['shear'],
perspective=self.hyp['perspective'],
- border=self.mosaic_border) # border to remove
+ border=self.mosaic_border,
+ area_threshold=self.hyp['area_threshold']) # border to remove
return img4, labels4
@@ -877,7 +905,8 @@ def load_mosaic9(self, index):
scale=self.hyp['scale'],
shear=self.hyp['shear'],
perspective=self.hyp['perspective'],
- border=self.mosaic_border) # border to remove
+ border=self.mosaic_border,
+ area_threshold=self.hyp['area_threshold']) # border to remove
return img9, labels9
@@ -1179,13 +1208,17 @@ def __init__(self, root, augment, imgsz, cache=False):
def __getitem__(self, i):
f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image
if self.cache_ram and im is None:
- im = self.samples[i][3] = cv2.imread(f)
+ img = cv2.imread(f,cv2.IMREAD_GRAYSCALE)
+ im = self.samples[i][3] = img.reshape(img.shape[0],img.shape[1],1)
elif self.cache_disk:
if not fn.exists(): # load npy
- np.save(fn.as_posix(), cv2.imread(f))
+ img = cv2.imread(f,cv2.IMREAD_GRAYSCALE)
+ np.save(fn.as_posix(), img.reshape(img.shape[0],img.shape[1],1))
im = np.load(fn)
+ im = im.reshape(im.shape[0], im.shape[1], 1)
else: # read image
- im = cv2.imread(f) # BGR
+ im = cv2.imread(f,cv2.IMREAD_GRAYSCALE)
+            im = im.reshape(im.shape[0],im.shape[1],1)  # GRAY: restore channel axis dropped by IMREAD_GRAYSCALE
if self.album_transforms:
sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))["image"]
else:
diff --git a/utils/loggers/__init__.py b/utils/loggers/__init__.py
index 22da87034f24..57901365531f 100644
--- a/utils/loggers/__init__.py
+++ b/utils/loggers/__init__.py
@@ -179,11 +179,13 @@ def on_train_batch_end(self, model, ni, imgs, targets, paths, vals):
# Callback runs on train batch end
# ni: number integrated batches (since train start)
if self.plots:
- if ni < 3:
- f = self.save_dir / f'train_batch{ni}.jpg' # filename
+ if ni < 16:
+ new_save_dir = self.save_dir.joinpath('train_batches')
+ new_save_dir.mkdir(exist_ok = True)
+ f = new_save_dir / f'train_batch{ni}.jpg' # filename
plot_images(imgs, targets, paths, f)
if ni == 0 and self.tb and not self.opt.sync_bn:
- log_tensorboard_graph(self.tb, model, imgsz=(self.opt.imgsz, self.opt.imgsz))
+ log_tensorboard_graph(self.tb, model, imgsz=(self.opt.imgsz, self.opt.imgsz),rgb_mode=self.opt.rgb_mode)
if ni == 10 and (self.wandb or self.clearml):
files = sorted(self.save_dir.glob('train*.jpg'))
if self.wandb:
@@ -390,12 +392,15 @@ def update_params(self, params):
wandb.run.config.update(params, allow_val_change=True)
-def log_tensorboard_graph(tb, model, imgsz=(640, 640)):
+def log_tensorboard_graph(tb, model, imgsz=(640, 640), rgb_mode=False):
# Log model graph to TensorBoard
try:
p = next(model.parameters()) # for device, type
imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz # expand
- im = torch.zeros((1, 3, *imgsz)).to(p.device).type_as(p) # input image (WARNING: must be zeros, not empty)
+ if rgb_mode:
+ im = torch.zeros((1, 3, *imgsz)).to(p.device).type_as(p)
+ else:
+ im = torch.zeros((1, 1, *imgsz)).to(p.device).type_as(p) # input image (WARNING: must be zeros, not empty)
with warnings.catch_warnings():
warnings.simplefilter('ignore') # suppress jit trace warning
tb.add_graph(torch.jit.trace(de_parallel(model), im, strict=False), [])
diff --git a/utils/plots.py b/utils/plots.py
index d2f232de0e97..edf5b3dd79e2 100644
--- a/utils/plots.py
+++ b/utils/plots.py
@@ -226,15 +226,13 @@ def output_to_target(output, max_det=300):
@threaded
-def plot_images(images, targets, paths=None, fname='images.jpg', names=None):
+def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max_size=1920, max_subplots=16):
# Plot image grid with labels
if isinstance(images, torch.Tensor):
images = images.cpu().float().numpy()
if isinstance(targets, torch.Tensor):
targets = targets.cpu().numpy()
- max_size = 1920 # max image size
- max_subplots = 16 # max image subplots, i.e. 4x4
bs, _, h, w = images.shape # batch size, _, height, width
bs = min(bs, max_subplots) # limit plot images
ns = np.ceil(bs ** 0.5) # number of subplots (square)
@@ -258,8 +256,9 @@ def plot_images(images, targets, paths=None, fname='images.jpg', names=None):
mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))
# Annotate
- fs = int((h + w) * ns * 0.01) # font size
+ fs = int((h + w) * ns * 0.005) # font size
annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=names)
+
for i in range(i + 1):
x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # borders
@@ -300,7 +299,7 @@ def plot_lr_scheduler(optimizer, scheduler, epochs=300, save_dir=''):
plt.plot(y, '.-', label='LR')
plt.xlabel('epoch')
plt.ylabel('LR')
- plt.grid()
+# plt.grid()
plt.xlim(0, epochs)
plt.ylim(0)
plt.savefig(Path(save_dir) / 'LR.png', dpi=200)
@@ -370,7 +369,7 @@ def plot_val_study(file='', dir='', x=None): # from utils.plots import *; plot_
alpha=.25,
label='EfficientDet')
- ax2.grid(alpha=0.2)
+# ax2.grid(alpha=0.2)
ax2.set_yticks(np.arange(20, 60, 5))
ax2.set_xlim(0, 57)
ax2.set_ylim(25, 55)
@@ -426,6 +425,7 @@ def plot_labels(labels, names=(), save_dir=Path('')):
plt.savefig(save_dir / 'labels.jpg', dpi=200)
matplotlib.use('Agg')
plt.close()
+ plt.style.use('default')
def imshow_cls(im, labels=None, pred=None, names=None, nmax=25, verbose=False, f=Path('images.jpg')):
diff --git a/val.py b/val.py
index e84249ed383f..498df591e06b 100644
--- a/val.py
+++ b/val.py
@@ -25,9 +25,11 @@
import sys
from pathlib import Path
+import pandas as pd
import numpy as np
import torch
from tqdm import tqdm
+import matplotlib.pyplot as plt
FILE = Path(__file__).resolve()
ROOT = FILE.parents[0] # YOLOv5 root directory
@@ -45,6 +47,8 @@
from utils.plots import output_to_target, plot_images, plot_val_study
from utils.torch_utils import select_device, smart_inference_mode
+from yolo_analyze_service import YoloAnalyzeService
+
def save_one_txt(predn, save_conf, shape, file):
# Save one txt result
@@ -124,7 +128,12 @@ def run(
plots=True,
callbacks=Callbacks(),
compute_loss=None,
-):
+ rgb_mode=False,
+ maximum_mistakes_size = 7680,
+ maximum_mistakes_subplots = 64,
+ minimum_mistakes_iou = 0.5,
+ minimum_mistakes_confidence = 0.5,
+ ):
# Initialize/load model and set device
training = model is not None
if training: # called by train.py
@@ -168,8 +177,12 @@ def run(
ncm = model.model.nc
assert ncm == nc, f'{weights} ({ncm} classes) trained on different --data than what you passed ({nc} ' \
f'classes). Pass correct combination of --weights and --data that are trained together.'
- model.warmup(imgsz=(1 if pt else batch_size, 3, imgsz, imgsz)) # warmup
+ if rgb_mode:
+            model.warmup(imgsz=(1 if pt else batch_size, 3, imgsz, imgsz))  # warmup: 3-channel input in RGB mode
+ else:
+ model.warmup(imgsz=(1 if pt else batch_size, 1, imgsz, imgsz)) # warmup
pad, rect = (0.0, False) if task == 'speed' else (0.5, pt) # square inference for benchmarks
+
task = task if task in ('train', 'val', 'test') else 'val' # path to train/val/test images
dataloader = create_dataloader(data[task],
imgsz,
@@ -179,7 +192,8 @@ def run(
pad=pad,
rect=rect,
workers=workers,
- prefix=colorstr(f'{task}: '))[0]
+ prefix=colorstr(f'{task}: '),
+ rgb_mode=rgb_mode)[0]
seen = 0
confusion_matrix = ConfusionMatrix(nc=nc)
@@ -267,6 +281,28 @@ def run(
plot_images(im, targets, paths, save_dir / f'val_batch{batch_i}_labels.jpg', names) # labels
plot_images(im, output_to_target(preds), paths, save_dir / f'val_batch{batch_i}_pred.jpg', names) # pred
+
+ if plots and batch_i < 16:
+ max_size=1920
+ max_subplots=16
+
+ valid_labels_path = save_dir.joinpath('labels')
+ valid_predictions_path = save_dir.joinpath('pred')
+ valid_labels_path.mkdir(exist_ok=True)
+ valid_predictions_path.mkdir(exist_ok=True)
+ f = valid_labels_path / f'val_batch{batch_i}_labels.jpg'
+ plot_images(im, targets, paths, f, names, max_size, max_subplots)
+ f = valid_predictions_path / f'val_batch{batch_i}_pred.jpg' # predictions
+ plot_images(im, output_to_target(preds), paths, f, names, max_size, max_subplots)
+
+ yolo_analyzer = YoloAnalyzeService(minimum_mistakes_iou, minimum_mistakes_confidence)
+ mistakes = yolo_analyzer.analyze_batch(targets.cpu().numpy(), output_to_target(preds))
+ mistakes_path = save_dir.joinpath("mistakes")
+ mistakes_path.mkdir(exist_ok=True)
+ f = mistakes_path / f'val_batch{batch_i}_mistakes.jpg'
+
+ plot_images(im, mistakes, paths, f, names, maximum_mistakes_size, maximum_mistakes_subplots)
+
callbacks.run('on_val_batch_end', batch_i, im, targets, paths, shapes, preds)
# Compute metrics
@@ -284,10 +320,51 @@ def run(
LOGGER.warning(f'WARNING ⚠️ no labels found in {task} set, can not compute metrics without labels')
# Print results per class
- if (verbose or (nc < 50 and not training)) and nc > 1 and len(stats):
+ if training and not verbose and (nc < 50 and nc > 1 and len(stats)):
+ for i, c in enumerate(ap_class):
+ save_path = save_dir.joinpath('classes').joinpath(names[c]+'.csv')
+ df = pd.DataFrame({'mAP.5': [ap50[i]], 'mAP': [ap[i]]})
+ df.to_csv(str(save_path), mode='a', index=False, header=not save_path.exists())
+
+ if (verbose or (nc < 50 and not training)) and nc > 1 and len(stats):
+ plt.figure(figsize=(16, 12), dpi=150)
+ last_results = []
for i, c in enumerate(ap_class):
LOGGER.info(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i]))
+ load_path = save_dir.joinpath('classes').joinpath(names[c]+'.csv')
+ df = pd.read_csv(str(load_path))
+ x = np.arange(df.shape[0])
+ last_results.append((names[c], np.array(df.tail(1))))
+ plt.plot(x, df['mAP.5'], label='mAP.5{}'.format(names[c]))
+ plt.plot(x, df['mAP'], label='mAP{}'.format(names[c]))
+ plt.legend(loc="upper left")
+ plt.savefig(load_path.parent.joinpath('AllClasses.png'))
+ reset_plot()
+
+ class_names = [item[0] for item in last_results]
+ mAP5s = np.multiply([item[1][0][0] for item in last_results], 100)
+ mAPs = np.multiply([item[1][0][1] for item in last_results], 100)
+
+ x = np.arange(len(class_names)) # the label locations
+ width = 0.35 # the width of the bars
+
+ fig, ax = plt.subplots(figsize=(20, 10))
+ fig.set_dpi(150)
+ rects1 = ax.bar(x - width/2, mAP5s, width, label='mAP.5')
+ rects2 = ax.bar(x + width/2, mAPs, width, label='mAP')
+
+ # Add some text for labels, title and custom x-axis tick labels, etc.
+ ax.set_ylabel('Percentage')
+ ax.set_title('mAP by class')
+ plt.xticks([r + (width * 0.1) for r in range(len(class_names))], class_names, rotation=90)
+ ax.legend()
+
+ ax.bar_label(rects1, padding=3)
+ ax.bar_label(rects2, padding=3)
+ plt.savefig(save_dir.joinpath('mAP_summary.png'))
+ reset_plot()
+
# Print speeds
t = tuple(x.t / seen * 1E3 for x in dt) # speeds per image
if not training:
@@ -335,7 +412,13 @@ def run(
maps[c] = ap[i]
return (mp, mr, map50, map, *(loss.cpu() / len(dataloader)).tolist()), maps, t
-
+def reset_plot():
+    # Close every figure created while plotting the per-class metrics so
+    # figures do not accumulate (and leak memory) across validation runs,
+    # then restore a default-sized current figure for subsequent plots.
+    plt.close('all')
+    plt.figure(figsize=(6.4, 4.8), dpi=100)
+
def parse_opt():
parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, default=ROOT / 'data/coco128.yaml', help='dataset.yaml path')
@@ -360,6 +443,7 @@ def parse_opt():
parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
+ parser.add_argument('--rgb-mode', action='store_true', help='train model in rgb mode, with image_channels=3.')
opt = parser.parse_args()
opt.data = check_yaml(opt.data) # check YAML
opt.save_json |= opt.data.endswith('coco.yaml')
diff --git a/yolo_analyze_service.py b/yolo_analyze_service.py
new file mode 100644
index 000000000000..fa2705f75f38
--- /dev/null
+++ b/yolo_analyze_service.py
@@ -0,0 +1,89 @@
+from typing import List
+
+import numpy as np
+
+from yolo_detections_matching_service import YoloDetectionsMatchingService
+
+
+class YoloAnalyzeService:
+    """Compares ground-truth targets with model detections and collects the
+    mistakes: spurious detections, wrong class labels and missed targets."""
+
+    def __init__(self, minimum_iou: float, minimum_confidence: float) -> None:
+        # minimum_iou: smallest IoU for a target/detection pair to count as a match
+        # minimum_confidence: rows with confidence below this are ignored
+        self.__minimum_iou = minimum_iou
+        self.__minimum_confidence = minimum_confidence
+        self.__detections_matching_service = YoloDetectionsMatchingService()
+
+    def analyze_batch(self, targets: np.ndarray, detections: np.ndarray) -> np.ndarray:
+        """Return every mistake in the batch as an (n, 7) array of
+        (batch_index, class, x, y, w, h, confidence) rows."""
+        wrong_detections = []
+        not_detected_targets = []
+        wrong_detections_labels = []
+
+        targets_batched = self.__parse_targets(targets)
+        detections_batched = self.__parse_targets(detections)
+
+        # Pad the shorter side so zip() does not silently drop trailing batches.
+        while len(targets_batched) < len(detections_batched):
+            targets_batched.append(np.empty((0, 7)))
+        while len(detections_batched) < len(targets_batched):
+            detections_batched.append(np.empty((0, 7)))
+
+        for targets_batch, detections_batch in zip(targets_batched, detections_batched):
+            rows_assignment, columns_assignments, cost_matrix = self.__detections_matching_service.find_detections_assignment(
+                targets_batch, detections_batch)
+
+            found_targets = set()
+            matched_detections = set()
+
+            for row, col in zip(rows_assignment, columns_assignments):
+                # cost_matrix holds IoUs; pairs below the threshold are not matches.
+                if cost_matrix[row][col] < self.__minimum_iou:
+                    continue
+
+                found_targets.add(row)
+                matched_detections.add(col)
+
+                # Matched spatially, but predicted a different class label.
+                if targets_batch[row][1] != detections_batch[col][1]:
+                    wrong_detections_labels.append(detections_batch[col])
+
+            not_detected_targets.extend(
+                targets_batch[i] for i in range(targets_batch.shape[0]) if i not in found_targets)
+            wrong_detections.extend(
+                detections_batch[i] for i in range(detections_batch.shape[0]) if i not in matched_detections)
+
+        all_mistakes = wrong_detections + wrong_detections_labels + not_detected_targets
+
+        if not all_mistakes:
+            return np.array([])
+        return np.concatenate(all_mistakes).reshape(-1, 7)
+
+    def __parse_targets(self, targets: np.ndarray) -> List[np.ndarray]:
+        """Group rows by batch index (column 0) and drop low-confidence rows.
+        Six-column rows (labels) get a confidence of 1.0 appended."""
+        targets_batched = []
+        last_batch_index = -1
+
+        for target in targets:
+            if target.shape[0] == 6:
+                target = np.append(target, [1.0])
+
+            if target[6] < self.__minimum_confidence:
+                continue
+
+            batch_index = int(target[0])
+
+            # Create empty groups for any batch indices that were skipped.
+            while batch_index != last_batch_index:
+                targets_batched.append([])
+                last_batch_index += 1
+
+            targets_batched[batch_index].append(target)
+
+        # Flatten each group to an (n, 7) array; empty groups become (0, 7) arrays.
+        return [np.concatenate(batch).reshape(-1, 7) if batch else np.empty((0, 7))
+                for batch in targets_batched]
diff --git a/yolo_detections_matching_service.py b/yolo_detections_matching_service.py
new file mode 100644
index 000000000000..f6e7cf0a5f15
--- /dev/null
+++ b/yolo_detections_matching_service.py
@@ -0,0 +1,48 @@
+from typing import Tuple
+
+import numpy as np
+from scipy.optimize import linear_sum_assignment
+
+
+class YoloDetectionsMatchingService:
+    """Matches ground-truth boxes to detections by maximising the total IoU
+    with the Hungarian algorithm."""
+
+    def find_detections_assignment(self, targets: np.ndarray, detections: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+        """Return (target_indices, detection_indices, iou_matrix) describing
+        the optimal one-to-one assignment between targets and detections."""
+        cost_matrix = self.__calculate_cost_matrix(targets, detections)
+
+        # linear_sum_assignment minimises cost, so negate to maximise IoU.
+        rows_assignment, cols_assignments = linear_sum_assignment(-cost_matrix)
+
+        return rows_assignment, cols_assignments, cost_matrix
+
+    def __calculate_cost_matrix(self, targets: np.ndarray, detections: np.ndarray) -> np.ndarray:
+        # cost_matrix[i][j] is the IoU between target i and detection j.
+        cost_matrix = np.zeros((targets.shape[0], detections.shape[0]))
+        for targets_index in range(targets.shape[0]):
+            for detections_index in range(detections.shape[0]):
+                cost_matrix[targets_index][detections_index] = self.__calculate_iou_yolo(
+                    targets[targets_index], detections[detections_index])
+
+        return cost_matrix
+
+    def __calculate_iou_yolo(self, target: np.ndarray, detection: np.ndarray) -> float:
+        # Rows are YOLO format (batch, class, x_center, y_center, width, height, ...);
+        # convert the centre coordinates to corner coordinates before intersecting.
+        lhs_x1, lhs_y1 = target[2] - target[4] / 2, target[3] - target[5] / 2
+        lhs_x2, lhs_y2 = target[2] + target[4] / 2, target[3] + target[5] / 2
+        rhs_x1, rhs_y1 = detection[2] - detection[4] / 2, detection[3] - detection[5] / 2
+        rhs_x2, rhs_y2 = detection[2] + detection[4] / 2, detection[3] + detection[5] / 2
+
+        intersection_width = min(lhs_x2, rhs_x2) - max(lhs_x1, rhs_x1)
+        intersection_height = min(lhs_y2, rhs_y2) - max(lhs_y1, rhs_y1)
+
+        if intersection_width <= 0 or intersection_height <= 0:
+            return 0
+
+        intersection = intersection_width * intersection_height
+        union = target[4] * target[5] + detection[4] * detection[5] - intersection
+
+        return intersection / union