-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy patheval.py
More file actions
126 lines (104 loc) · 4.29 KB
/
eval.py
File metadata and controls
126 lines (104 loc) · 4.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import argparse
from typing import List
from pathlib import Path
import torch
from torchvision import transforms
import numpy as np
from PIL import Image
from skimage.metrics import peak_signal_noise_ratio as psnr
import lpips
from pytorch_fid import fid_score
from pytorch_msssim import ssim
def tag(name:str):
def wrapper(func):
func.tag = name
return func
return wrapper
class Factory(object):
    """Dispatcher over methods decorated with ``@tag``.

    At construction it scans the instance for callables carrying a ``tag``
    attribute, builds ``tagged_method`` (tag name -> bound method), and
    resolves the requested ``name`` list into ``_call_func``. Calling the
    instance invokes each resolved method in order and collects the results.
    """

    def __init__(self, name: List[str]):
        self.name = name
        # Scan every attribute once; keep only callables stamped by @tag.
        self.tagged_method = {}
        for attr in dir(self):
            candidate = getattr(self, attr)
            if callable(candidate) and hasattr(candidate, 'tag'):
                self.tagged_method[candidate.tag] = candidate
        self._call_func = self.get_method(name)

    def retrieve(self, input_dir, pred_dir):
        """Return the image files of both directories: sorted .png files first,
        then sorted .jpg files, so the two lists pair up by filename order."""
        def _collect(directory):
            root = Path(directory)
            return sorted(root.glob('*.png')) + sorted(root.glob('*.jpg'))
        return _collect(input_dir), _collect(pred_dir)

    def __call__(self, *args, **kwargs):
        """Invoke every resolved method with the same arguments; return a list of results."""
        return [method(*args, **kwargs) for method in self._call_func]

    def get_method(self, name: list[str]):
        """Map each requested tag to its method; raise ValueError for unknown tags."""
        resolved = []
        for requested in name:
            if requested not in self.tagged_method:
                raise ValueError(f'Cannot find {self.__class__.__name__} ({requested})')
            resolved.append(self.tagged_method[requested])
        return resolved
class Metric(Factory):
    """Image-quality metrics dispatched by tag name (see ``Factory``).

    Each metric receives two directories, pairs their images by the sorted
    filename order produced by ``retrieve``, and returns a scalar. Image
    pairs that fail to load or compare are skipped; if no pair succeeds,
    ``np.mean([])`` yields NaN (with a RuntimeWarning).
    """

    @tag('psnr')
    def _psnr(self, input_path, pred_path, transform=None, data_range: int = 255, **kwargs):
        """Mean peak signal-to-noise ratio over all paired images.

        transform: optional callable PIL.Image -> tensor. Defaults to
        ``ToTensor`` (values in [0, 1]), rescaled below by ``data_range``.
        """
        if transform is None:
            transform = transforms.Compose([
                transforms.ToTensor()
            ])
        values = []
        in_fs, pred_fs = self.retrieve(input_path, pred_path)
        for in_f, pred_f in zip(in_fs, pred_fs):
            try:
                img1 = np.array(transform(Image.open(in_f).convert('RGB'))) * data_range
                img2 = np.array(transform(Image.open(pred_f).convert('RGB'))) * data_range
                values.append(psnr(img1, img2, data_range=data_range))
            except Exception:
                # Was a bare `except:`, which also swallowed KeyboardInterrupt
                # and SystemExit. Skip only ordinary failures (corrupt file,
                # size mismatch) and keep going.
                continue
        return np.mean(values)

    @tag('ssim')
    def _ssim(self, input_path, pred_path, transform=None, data_range: int = 255, **kwargs):
        """Mean structural similarity (pytorch_msssim) over all paired images.

        Images are scaled to [0, data_range] to match pytorch_msssim's
        default ``data_range`` of 255.
        """
        if transform is None:
            transform = transforms.Compose([
                transforms.ToTensor()
            ])
        values = []
        in_fs, pred_fs = self.retrieve(input_path, pred_path)
        for in_f, pred_f in zip(in_fs, pred_fs):
            try:
                img1 = transform(Image.open(in_f).convert('RGB')).unsqueeze(0) * data_range
                img2 = transform(Image.open(pred_f).convert('RGB')).unsqueeze(0) * data_range
                values.append(ssim(img1, img2).item())
            except Exception:
                # Narrowed from a bare `except:` — see _psnr.
                continue
        return np.mean(values)

    @tag('fid')
    def _fid(self, pred_path, label_path, **kwargs):
        """Fréchet Inception Distance between two image directories.

        Note: despite the parameter names, the dispatcher passes the two CLI
        paths positionally, so the argument order matches the other metrics.
        Uses batch size 50, CUDA, and the standard 2048-dim Inception features;
        requires a CUDA device.
        """
        return fid_score.calculate_fid_given_paths([str(pred_path), str(label_path)], 50, 'cuda', 2048).item()

    @tag('lpips')
    def _lpips(self, input_path, pred_path, transform=None, **kwargs):
        """Mean LPIPS (VGG backbone) distance over all paired images.

        Images are resized to 224x224 and moved to CUDA.
        NOTE(review): LPIPS documents inputs normalized to [-1, 1], but
        ToTensor yields [0, 1]; either pass ``normalize=True`` to
        ``lpips.LPIPS`` or rescale — confirm intended behavior before changing.
        NOTE(review): the VGG network is rebuilt on every call, which is slow;
        consider caching it on the instance.
        """
        lpips_fn = lpips.LPIPS(net='vgg').to('cuda').eval()
        if transform is None:
            transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor()
            ])
        values = []
        in_fs, pred_fs = self.retrieve(input_path, pred_path)
        for in_f, pred_f in zip(in_fs, pred_fs):
            try:
                img1 = transform(Image.open(in_f).convert('RGB')).to('cuda')
                img2 = transform(Image.open(pred_f).convert('RGB')).to('cuda')
                values.append(lpips_fn(img1, img2).item())
            except Exception:
                # Narrowed from a bare `except:` — see _psnr.
                continue
        return np.mean(values)
if __name__ == '__main__':
    # CLI entry point: compute one or more metrics between two image directories.
    parser = argparse.ArgumentParser(
        description='Compute image-quality metrics between two image directories.')
    # The three essential flags are now required: previously, omitting them
    # crashed later with a raw TypeError inside Metric(None) instead of a
    # clean argparse usage error.
    parser.add_argument('--path1', type=Path, required=True,
                        help='Directory of reference/input images (.png/.jpg).')
    parser.add_argument('--path2', type=Path, required=True,
                        help='Directory of predicted images (.png/.jpg).')
    parser.add_argument('--metric', type=str, nargs='+', required=True,
                        help="Metrics to compute: any of 'psnr', 'ssim', 'fid', 'lpips'.")
    parser.add_argument('--prompt', type=str,
                        help='Optional prompt string forwarded to metrics '
                             '(accepted via **kwargs; unused by the built-in metrics).')
    args = parser.parse_args()

    metric = Metric(args.metric)
    results = metric(args.path1, args.path2, prompt=args.prompt)
    for name, value in zip(args.metric, results):
        print(f'{name}: {value}')