-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathload_mvtec_loco.py
More file actions
executable file
·109 lines (71 loc) · 3.44 KB
/
load_mvtec_loco.py
File metadata and controls
executable file
·109 lines (71 loc) · 3.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torchvision
import torchvision.transforms as transforms
import torch
from torch.utils.data import Dataset,DataLoader
import numpy as np
import PIL.Image as Image
import os
def default_loader(path):
    """Open the image at ``path`` and return it converted to RGB.

    The source file is opened in a context manager so its handle is released
    as soon as the RGB copy has been produced (``convert`` returns a new image
    that does not depend on the open file).
    """
    with Image.open(path) as img:
        return img.convert('RGB')
def find_classes(dir):
    """List the immediate sub-directories of ``dir`` as class names.

    Returns:
        A tuple ``(classes, class_to_idx)``: the sub-directory names in sorted
        order, and a dict mapping each name to its position in that order.
    """
    classes = sorted(
        entry
        for entry in os.listdir(dir)
        if os.path.isdir(os.path.join(dir, entry))
    )
    class_to_idx = {name: position for position, name in enumerate(classes)}
    return classes, class_to_idx
class MyDataset(Dataset):
    """In-memory split of the MVTec LOCO / MVTec AD dataset.

    At construction time every kept image of the requested split is read via
    ``torchvision.datasets.ImageFolder``, transformed, and stored in
    ``self.imgs`` (a float tensor of shape ``(N, 3, img_size, img_size)``).
    ``self.targets`` is a float32 numpy array of anomaly labels:
    0 = normal, 1 = anomalous (only test images outside ``good`` get 1).

    Args:
        parent_path: base directory; the data root is
            ``os.path.join(parent_path, '../dataset_loco/')`` (or
            ``'../dataset_mvtec/'`` when ``is_loco`` is False).
        which_set: one of ``'train'``, ``'validation'``, ``'test'``.
        class_name: category sub-folder, e.g. ``'breakfast_box'``.
        anom_type: test split only. Keeps ``good`` images plus images whose
            path contains ``/test/<anom_type>``; ``'all'`` keeps everything.
            Ignored for the other splits.
        img_size: side length of the centre crop applied after resizing.
        img_resize: images are first resized to ``(img_resize, img_resize)``.
        is_loco: select the MVTec LOCO folder (True) or MVTec AD (False).

    Raises:
        ValueError: if ``which_set`` is not a known split.
    """

    _SPLITS = ('train', 'validation', 'test')

    def __init__(self, parent_path, which_set, class_name, anom_type, img_size=1024, img_resize=256, is_loco=True):
        if which_set not in self._SPLITS:
            raise ValueError(
                f"which_set must be one of {self._SPLITS}, got {which_set!r}")
        transform = transforms.Compose([
            transforms.Resize((img_resize, img_resize)),
            # NOTE: per torchvision docs, CenterCrop zero-pads when the
            # resized image is smaller than img_size (img_resize < img_size).
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
        ])
        data_path = '../dataset_loco/' if is_loco else '../dataset_mvtec/'
        fold_path = os.path.join(parent_path, data_path, class_name, which_set)
        dataset = torchvision.datasets.ImageFolder(fold_path, transform)

        # Decide relevance from the file path alone so that discarded images
        # are never decoded.
        keep = [
            idx for idx, (img_path, _) in enumerate(dataset.imgs)
            if self._is_relevant(img_path, which_set, anom_type)
        ]

        imgs = torch.zeros((len(keep), 3, img_size, img_size))
        targets = np.zeros(len(keep), dtype=np.float32)
        for out_idx, src_idx in enumerate(keep):
            img, class_idx = dataset[src_idx]
            imgs[out_idx] = img
            # Label by folder name rather than ImageFolder index: in MVTec AD
            # 'good' is often not the alphabetically first class.
            targets[out_idx] = self._anomaly_label(
                which_set, dataset.classes[class_idx])

        self.imgs = imgs
        self.targets = targets

    @staticmethod
    def _is_relevant(img_path, which_set, anom_type):
        """Return True if the image at ``img_path`` belongs in this split's subset."""
        if which_set != 'test' or anom_type == 'all':
            return True
        # Normalise separators so the '/test/...' substring match also works
        # with Windows-style paths produced by ImageFolder.
        posix_path = img_path.replace(os.sep, '/')
        return f'/test/{anom_type}' in posix_path or '/test/good' in posix_path

    @staticmethod
    def _anomaly_label(which_set, class_name):
        """Return 1 for a test image outside the ``good`` folder, else 0."""
        return int(which_set == 'test' and class_name != 'good')

    def __getitem__(self, index):
        img = self.imgs[index]
        label = self.targets[index]
        return img, label

    def __len__(self):
        return len(self.imgs)
def get_mvt_loader(parent_path, which_set='train', class_name="breakfast_box", anom_type="logical_anomalies", img_size=1024, img_resize=1024, is_loco=True):
    """Wrap a ``MyDataset`` split in a ``DataLoader``.

    The loader yields one sample per batch, in dataset order (no shuffling),
    and loads in the main process (``num_workers=0``). All arguments are
    forwarded unchanged to ``MyDataset``.
    """
    split = MyDataset(
        parent_path,
        which_set,
        class_name,
        anom_type,
        img_size=img_size,
        img_resize=img_resize,
        is_loco=is_loco,
    )
    return torch.utils.data.DataLoader(
        split, batch_size=1, shuffle=False, num_workers=0)