-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcoco_detection.py
More file actions
49 lines (42 loc) · 2 KB
/
coco_detection.py
File metadata and controls
49 lines (42 loc) · 2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import numpy as np
import torch
import torchvision
import os
from torchvision.tv_tensors import BoundingBoxes
from torchvision.transforms.v2 import functional as F
class CocoDetection(torchvision.datasets.CocoDetection):
    """COCO detection dataset that returns processor-encoded samples.

    Each item is loaded by ``torchvision.datasets.CocoDetection``, optionally
    passed through ``augmentations`` together with its boxes and labels, and
    finally encoded by ``processor``.
    """

    def __init__(self, img_folder, coco_json_filename, processor, augmentations=None, train=True):
        """Set up the dataset.

        Args:
            img_folder: Directory holding the images. The annotation file is
                looked up inside this same directory.
            coco_json_filename: Name of the COCO annotation JSON file,
                joined onto ``img_folder``.
            processor: Callable invoked as
                ``processor(images=..., annotations=..., return_tensors="pt")``.
                Its result must provide ``"pixel_values"`` and ``"labels"``.
            augmentations: Optional callable invoked as
                ``augmentations(image, {"boxes": ..., "labels": ...})`` that
                returns the transformed ``(image, targets)`` pair. Skipped
                when falsy.
            train: Currently unused by this class; kept so existing callers
                that pass it keep working.
        """
        ann_file = os.path.join(img_folder, coco_json_filename)
        super().__init__(img_folder, ann_file, transform=None, target_transform=None)
        self.processor = processor
        self.augmentations = augmentations

    def __getitem__(self, idx):
        """Return ``(pixel_values, labels)`` for the sample at ``idx``.

        ``pixel_values`` and ``labels`` are the first (and only) entries of
        the processor output for this single image.
        """
        # Read in the PIL image and its annotations in COCO format.
        img, anns = super().__getitem__(idx)
        image_id = self.ids[idx]

        if self.augmentations:
            img, anns = self._augment(img, anns)

        target = {'image_id': image_id, 'annotations': anns}
        encoding = self.processor(images=img, annotations=target, return_tensors="pt")
        # Drop only the batch axis; a bare squeeze() would also remove any
        # other axis that happens to have size 1.
        pixel_values = encoding["pixel_values"].squeeze(0)
        labels_enc = encoding["labels"][0]  # remove batch dimension
        return pixel_values, labels_enc

    def _augment(self, img, anns):
        """Run ``self.augmentations`` and return the image plus updated annotation copies.

        Args:
            img: Image as returned by the base dataset.
            anns: List of COCO annotation dicts for that image.

        Returns:
            Tuple ``(image_tensor, new_anns)`` where ``new_anns`` are fresh
            dicts carrying the transformed ``"bbox"`` and ``"category_id"``.
        """
        # Convert to a tensor image so the boxes can carry the canvas size.
        img_tensor = F.to_image(img)

        # reshape(-1, 4) keeps a well-formed (0, 4) tensor when the image has
        # no annotations; torch.tensor([]) alone would have shape (0,).
        bbox_tensor = torch.tensor(
            [a["bbox"] for a in anns], dtype=torch.float32
        ).reshape(-1, 4)
        boxes_tv = BoundingBoxes(bbox_tensor, format="xywh", canvas_size=img_tensor.shape[-2:])
        labels_tensor = torch.tensor([a["category_id"] for a in anns], dtype=torch.int64)

        img_tensor, targets = self.augmentations(
            img_tensor, {'boxes': boxes_tv, 'labels': labels_tensor}
        )

        # Build new dicts instead of writing into `anns`: the base dataset
        # hands back the annotation objects held by the underlying COCO index,
        # so in-place edits would persist and be re-augmented on later visits.
        # NOTE(review): zip pairs outputs with `anns` by position and stops at
        # the shortest input, so this assumes the augmentations neither
        # reorder nor drop boxes -- confirm for the transforms in use.
        new_anns = [
            {**ann, "bbox": bbox, "category_id": category_id}
            for bbox, category_id, ann in zip(
                targets["boxes"].tolist(), targets["labels"].tolist(), anns
            )
        ]
        return img_tensor, new_anns