-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference.py
More file actions
39 lines (34 loc) · 1.47 KB
/
inference.py
File metadata and controls
39 lines (34 loc) · 1.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import argparse, torch
from torchvision import transforms, models
from PIL import Image
# CIFAR-10 label names. Position i is the label for output logit i
# (predict() maps top-k indices straight into this list).
CLASSES = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]
def get_tf(img_size=224):
    """Build the preprocessing pipeline applied to an input image.

    The pipeline resizes the image to a square ``img_size`` x ``img_size``,
    converts it to a tensor, and normalizes each channel with the mean/std
    constants below.

    Args:
        img_size: Side length in pixels of the square the image is resized to.

    Returns:
        A ``transforms.Compose`` callable mapping a PIL image to a
        normalized tensor.
    """
    # Presumably CIFAR-10 training-set statistics. These must match the
    # normalization used when the checkpoint was trained (TODO confirm
    # against the training script).
    channel_mean = [0.4914, 0.4822, 0.4465]
    channel_std = [0.247, 0.243, 0.261]
    steps = [
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)
def load_model(ckpt_path, device, num_classes=None):
    """Build a ResNet-18 classifier and load trained weights from a checkpoint.

    Args:
        ckpt_path: Path to a file written by ``torch.save``. Two formats are
            accepted:
            - a dict holding the state dict under the ``'model'`` key
              (the original format);
            - a bare state dict.
        device: ``torch.device`` or device string. Tensors are mapped onto
            it and the model is moved to it.
        num_classes: Output width of the final linear layer. Defaults to
            ``len(CLASSES)`` so the classifier head always agrees with the
            label list used to decode predictions.

    Returns:
        The model, in eval mode and placed on ``device``.

    Raises:
        RuntimeError: Raised by ``load_state_dict`` when the checkpoint's keys
            or tensor shapes do not match the architecture.
    """
    if num_classes is None:
        num_classes = len(CLASSES)
    model = models.resnet18(weights=None)
    # Swap the default classification head for one sized to our label set.
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    # NOTE(security): torch.load unpickles the file, and ckpt_path can come
    # from the command line. Only load checkpoints from trusted sources.
    # Consider weights_only=True if the checkpoint holds only tensors and
    # plain containers.
    ckpt = torch.load(ckpt_path, map_location=device)
    # Training checkpoints wrap the weights as {'model': state_dict, ...};
    # also accept a bare state dict.
    if isinstance(ckpt, dict) and 'model' in ckpt:
        state_dict = ckpt['model']
    else:
        state_dict = ckpt
    model.load_state_dict(state_dict)
    model.eval().to(device)
    return model
def predict(img_path, ckpt='runs/cifar_resnet18/best.pt', device='cpu',
            img_size=224, topk=3, model=None):
    """Classify one image and return the ``topk`` most likely labels.

    Args:
        img_path: Path to an image file readable by PIL.
        ckpt: Checkpoint path passed to ``load_model``. Ignored when
            ``model`` is given.
        device: Device string or ``torch.device`` to run inference on.
        img_size: Square resize target passed to ``get_tf``. It should match
            the size used during training.
        topk: Number of predictions to return. It is clamped to the number
            of classes the model outputs.
        model: Optional already-loaded model, so that repeated calls can skip
            reloading the checkpoint. The caller is responsible for having it
            on ``device`` and in eval mode.

    Returns:
        A list of ``(label, probability)`` tuples, ordered from most to
        least probable.
    """
    device = torch.device(device)
    tf = get_tf(img_size)
    # The context manager releases the file handle PIL keeps open after open().
    with Image.open(img_path) as im:
        img = im.convert('RGB')
    x = tf(img).unsqueeze(0).to(device)
    if model is None:
        model = load_model(ckpt, device)
    with torch.no_grad():
        logits = model(x)
    # Drop only the batch dimension so a (1, C) output always becomes (C,).
    probs = torch.softmax(logits, dim=1).squeeze(0).cpu()
    # Tensor.topk raises if k exceeds the number of elements.
    k = min(topk, probs.numel())
    top_vals, top_idx = probs.topk(k)
    return [(CLASSES[i], v) for v, i in zip(top_vals.tolist(), top_idx.tolist())]
if __name__ == "__main__":
    # Command-line entry point: classify a single image and print the
    # top-k (label, probability) pairs.
    ap = argparse.ArgumentParser(
        description="Classify an image with a ResNet-18 checkpoint.")
    ap.add_argument("--image", required=True, help="path to the input image")
    ap.add_argument("--ckpt", default="runs/cifar_resnet18/best.pt",
                    help="checkpoint file to load")
    ap.add_argument("--device", default="cpu",
                    help="torch device, e.g. 'cpu' or 'cuda:0'")
    ap.add_argument("--img_size", type=int, default=224,
                    help="square resize target in pixels")
    # Exposes predict()'s topk; the default matches predict()'s own default.
    ap.add_argument("--topk", type=int, default=3,
                    help="number of top predictions to print")
    args = ap.parse_args()
    print(predict(args.image, args.ckpt, args.device, args.img_size, args.topk))