-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrecognize.py
More file actions
87 lines (70 loc) · 3 KB
/
recognize.py
File metadata and controls
87 lines (70 loc) · 3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
import glob
import argparse
import numpy as np
import PIL.Image as Image
from sklearn import preprocessing
import torch
import torchvision.transforms as transforms
from torchmetrics import CharErrorRate
from nn import CaptchaModel
from util import decode_predictions
def get_image(data_path, image_path, width=200, height=50):
    """Load *data_path*/*image_path*, resize it to (width, height) and
    return a normalized float tensor of shape (1, 3, height, width),
    ready to be fed to the model.
    """
    # ImageNet channel statistics used for input normalization.
    norm_mean = (0.485, 0.456, 0.406)
    norm_std = (0.229, 0.224, 0.225)

    full_path = os.path.join(data_path, image_path)
    pil_img = Image.open(full_path).convert("RGB")
    pil_img = pil_img.resize((width, height), resample=Image.BILINEAR)

    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std, inplace=True),
    ])
    tensor = preprocess(np.array(pil_img))
    # Add a batch dimension: (C, H, W) -> (1, C, H, W).
    return tensor.unsqueeze(0)
def calc_cer(images_path, weights_path):
    """Compute the character error rate (CER) of the saved model over
    every ``*.png`` image in *images_path*.

    Each image's ground-truth label is its file name without the ``.png``
    extension.

    Args:
        images_path: directory containing the labelled captcha images.
        weights_path: path to the saved ``state_dict`` of ``CaptchaModel``.

    Returns:
        float: the CER over the whole directory.
    """
    images_paths = glob.glob(os.path.join(images_path, "*.png"))
    # Use os.path helpers instead of splitting on "/" so this also works
    # with Windows-style path separators.
    targets = [os.path.splitext(os.path.basename(x))[0] for x in images_paths]
    targets_flat = [c for target in targets for c in target]
    lbl_enc = preprocessing.LabelEncoder()
    lbl_enc.fit(targets_flat)
    model = CaptchaModel(num_chars=len(lbl_enc.classes_))
    # map_location lets CPU-only machines load weights saved on a GPU.
    model.load_state_dict(torch.load(weights_path, map_location="cpu"))
    model.eval()
    preds = []
    with torch.no_grad():
        for image_path in images_paths:
            # Pass the full directory to get_image: the original
            # split("/")[-2] kept only the last path component, which
            # broke whenever images_path was a nested directory.
            image = get_image(os.path.dirname(image_path),
                              os.path.basename(image_path))
            pred, _ = model(image)
            preds += decode_predictions(pred, lbl_enc)
    cer = CharErrorRate()
    return cer(preds, targets).item()
def recognize(opt):
    """Recognize the captcha image ``opt.image_path`` and print the result.

    Fits a LabelEncoder on the characters of every ``*.png`` file name in
    ``opt.data_path`` (the character set the model was trained on), loads
    the saved model weights, and prints ground truth vs. prediction.

    Args:
        opt: parsed CLI namespace with ``data_path``, ``image_path``,
            ``width``, ``height`` and ``saved_model`` attributes.
    """
    image_files = glob.glob(os.path.join(opt.data_path, "*.png"))
    # File name (minus extension) is the label; use os.path helpers
    # instead of splitting on "/" so this also works on Windows.
    targets_orig = [os.path.splitext(os.path.basename(x))[0] for x in image_files]
    targets_flat = [c for target in targets_orig for c in target]
    lbl_enc = preprocessing.LabelEncoder()
    lbl_enc.fit(targets_flat)
    image = get_image(opt.data_path, opt.image_path, opt.width, opt.height)
    model = CaptchaModel(num_chars=len(lbl_enc.classes_))
    # map_location lets CPU-only machines load weights saved on a GPU.
    model.load_state_dict(torch.load(opt.saved_model, map_location="cpu"))
    model.eval()
    with torch.no_grad():
        pred, _ = model(image)
        current_preds = decode_predictions(pred, lbl_enc)
    # splitext handles any extension length (the original [:-4] assumed
    # exactly a 3-character extension).
    print(f'Ground truth: {os.path.splitext(opt.image_path)[0]}\n'
          f'Prediction: {current_preds[0]}')
if __name__ == "__main__":
    # CLI entry point: collect options, then run a single-image recognition.
    arg_parser = argparse.ArgumentParser(description='Captcha recognition using RNN')
    arg_parser.add_argument('--data_path', type=str, default='samples', help='path to data folder')
    arg_parser.add_argument('--saved_model', type=str, default='weights/weights.pth', help='path to saved model')
    arg_parser.add_argument('--height', type=int, default=50, help='height of the input image')
    arg_parser.add_argument('--width', type=int, default=200, help='width of the input image')
    arg_parser.add_argument('--image_path', type=str, default='n8pfe.png', help='path to test image')
    recognize(arg_parser.parse_args())