From 49f9be6bad0a316e94378da1eb7e6354a3adfd3c Mon Sep 17 00:00:00 2001 From: krenerd <48239275+krenerd@users.noreply.github.com> Date: Tue, 8 Nov 2022 22:08:41 +0900 Subject: [PATCH] b: update deprecated code in datasets, metric. --- dataset/cifar100.py | 16 +--------------- dataset/imagenet.py | 8 +------- helper/util.py | 2 +- 3 files changed, 3 insertions(+), 23 deletions(-) diff --git a/dataset/cifar100.py b/dataset/cifar100.py index 7f59aafe..2f3c4a4b 100644 --- a/dataset/cifar100.py +++ b/dataset/cifar100.py @@ -40,21 +40,7 @@ class CIFAR100Instance(datasets.CIFAR100): """CIFAR100Instance Dataset. """ def __getitem__(self, index): - if self.train: - img, target = self.train_data[index], self.train_labels[index] - else: - img, target = self.test_data[index], self.test_labels[index] - - # doing this so that it is consistent with all other datasets - # to return a PIL Image - img = Image.fromarray(img) - - if self.transform is not None: - img = self.transform(img) - - if self.target_transform is not None: - target = self.target_transform(target) - + img, target = super().__getitem__(index) return img, target, index diff --git a/dataset/imagenet.py b/dataset/imagenet.py index 47d6de3b..4a4a6496 100644 --- a/dataset/imagenet.py +++ b/dataset/imagenet.py @@ -39,13 +39,7 @@ def __getitem__(self, index): Returns: tuple: (image, target) where target is class_index of the target class. """ - path, target = self.imgs[index] - img = self.loader(path) - if self.transform is not None: - img = self.transform(img) - if self.target_transform is not None: - target = self.target_transform(target) - + img, target = super().__getitem__(index) return img, target, index diff --git a/helper/util.py b/helper/util.py index 4d412209..3d01b345 100644 --- a/helper/util.py +++ b/helper/util.py @@ -52,7 +52,7 @@ def accuracy(output, target, topk=(1,)): res = [] for k in topk: - correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) + correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True) res.append(correct_k.mul_(100.0 / batch_size)) return res