-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathcaptcha_model.py
More file actions
78 lines (62 loc) · 2.01 KB
/
captcha_model.py
File metadata and controls
78 lines (62 loc) · 2.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import string
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
# The 36-symbol alphabet the captcha can contain: "0"-"9" followed by "A"-"Z".
CLASSES = list(string.digits + string.ascii_uppercase)
# Number of characters in one captcha image.
LETTER_COUNT = 5


class Net(nn.Module):
    """Small CNN that classifies a single captcha character.

    Input: a (batch, 1, 50, 50) grayscale tensor — the fc1 size below is
    derived from a 50x50 input, so other sizes will not fit.
    Output: (batch, len(CLASSES)) log-probabilities (log_softmax).
    """

    def __init__(self):
        super().__init__()
        width = 6  # channel count of the first conv layer
        self.conv1 = nn.Conv2d(1, width, 3, 1)
        self.conv2 = nn.Conv2d(width, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        # 33856 = 64 * 23 * 23: a 50x50 input shrinks to 48 after conv1,
        # 46 after conv2, then 23 after the 2x2 max-pool in forward().
        self.fc1 = nn.Linear(33856, 128)
        self.fc2 = nn.Linear(128, len(CLASSES))

    def forward(self, x):
        """Return per-class log-probabilities for a batch of character images."""
        # (The original had a no-op `x = x` here; removed.)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
def split_letters(image, letter_count: int = 5):
    """Yield *letter_count* equal-width vertical slices of *image*.

    The image is cut left-to-right into contiguous strips of width
    ``image_width / letter_count`` (float coordinates are passed to
    ``crop`` as-is), one strip per character of the captcha.
    """
    total_width, height = image.size
    slice_width = total_width / letter_count
    for index in range(letter_count):
        left = index * slice_width
        yield image.crop((left, 0, left + slice_width, height))
def load_net(path: str) -> Net:
    """Build a fresh Net and load trained weights from the file at *path*.

    Weights are mapped onto the device chosen by get_device(), so a model
    trained on GPU can be loaded on a CPU-only machine.
    """
    state = torch.load(path, map_location=get_device())
    model = Net()
    model.load_state_dict(state)
    return model
def solve_image(net: Net, image: Image.Image) -> str:
    """Run *net* over each character slice of *image* and return the text.

    The captcha is converted to grayscale, cut into LETTER_COUNT strips,
    normalized to [-1, 1], and classified in one batched forward pass.

    Fixes vs. original: the annotation said ``Image`` (the PIL *module*);
    the class is ``Image.Image``. ``Normalize`` now receives the 1-tuples
    ``(0.5,)`` instead of bare floats ``(0.5)`` — numerically identical,
    but the documented sequence form for a single channel.
    """
    image = image.convert("L")  # model expects a single grayscale channel
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
    )
    with torch.no_grad():
        net.eval()  # disable dropout for deterministic inference
        letters = [
            transform(part)
            for part in split_letters(image, letter_count=LETTER_COUNT)
        ]
        outputs = net(torch.stack(letters))
        # argmax over the class dimension -> one class index per letter
        predictions = outputs.argmax(dim=1, keepdim=True)
    return "".join(CLASSES[pred] for pred in predictions)
def get_device() -> torch.device:
    """Return the best available compute device.

    Preference order: CUDA GPU, Apple-silicon MPS, then plain CPU.
    """
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")