-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathv2
More file actions
81 lines (63 loc) · 2.33 KB
/
v2
File metadata and controls
81 lines (63 loc) · 2.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
import torch.nn as nn
from torch.nn import functional as F
import math
import numpy as np
class Model:
    """Bigram character-level language model with a single weight matrix.

    The model reads one word per line from a text file, builds a character
    vocabulary with '.' (index 0) as the start/end-of-word marker, and learns
    a ``(vocabsz, vocabsz)`` matrix of logits where row ``i`` scores every
    possible character following character ``i``.

    Attributes:
        words: Non-empty lines read from the input file.
        stoi: Mapping from character to integer index ('.' is always 0).
        itos: Inverse mapping of ``stoi``.
        vocabsz: Number of distinct symbols, including the '.' marker.
        xs: 1-D integer tensor holding the first character of every bigram.
        ys: 1-D integer tensor holding the second character of every bigram.
        wei: Trainable ``(vocabsz, vocabsz)`` logit matrix.
        probs: Next-character probabilities for every training bigram as of
            the last ``forward``/``trainloop`` step (a 0-dim zero tensor
            until then).
    """

    # Sampled names are kept only when their length is within this inclusive
    # range (the original condition was ``2 < len(word) < 9``).
    MIN_NAME_LEN = 3
    MAX_NAME_LEN = 8

    def __init__(self, inptxt, temp):
        """Load the word list, build the vocabulary and the bigram tensors.

        Args:
            inptxt: Path to a text file containing one word per line.
            temp: Accepted for interface compatibility; it is not used by the
                model. Sampling temperature is the ``T`` argument of
                ``sample``.

        Raises:
            ValueError: If the file contains no non-empty lines.
        """
        # Context manager so the file handle is closed deterministically.
        with open(inptxt, 'r') as f:
            # Blank lines would only contribute a meaningless '.' -> '.' bigram.
            self.words = [w for w in f.read().splitlines() if w]
        if not self.words:
            raise ValueError(f'{inptxt!r} contains no words to train on')

        self.probs = torch.zeros(())

        # Map characters to ints. '.' is reserved for the boundary marker, so
        # it is excluded here; otherwise the indices would have a gap and
        # exceed vocabsz.
        chars = sorted(set(''.join(self.words)) - {'.'})
        self.stoi = {ch: i + 1 for i, ch in enumerate(chars)}
        self.stoi['.'] = 0
        self.itos = {i: ch for ch, i in self.stoi.items()}
        self.vocabsz = len(self.stoi)

        # Prepare training inputs and targets: every adjacent character pair
        # of each word, padded with the boundary marker on both sides.
        xs, ys = [], []
        for w in self.words:
            chs = ['.'] + list(w) + ['.']
            for ch1, ch2 in zip(chs, chs[1:]):
                xs.append(self.stoi[ch1])
                ys.append(self.stoi[ch2])
        self.xs = torch.tensor(xs, dtype=torch.long)
        self.ys = torch.tensor(ys, dtype=torch.long)

        # Sized from the actual vocabulary rather than a hard-coded 27.
        self.wei = torch.randn(
            (self.vocabsz, self.vocabsz), requires_grad=True
        )

    def _logits(self):
        """Return the logit row for every training input character."""
        # Row indexing selects the same rows that one_hot(xs) @ wei would,
        # without materialising the one-hot matrix.
        return self.wei[self.xs]

    def forward(self):
        """Compute next-character probabilities for every training bigram.

        Returns:
            A ``(len(xs), vocabsz)`` tensor whose rows sum to 1. The result is
            also stored in ``self.probs``.
        """
        # F.softmax subtracts the row maximum internally, so large logits do
        # not overflow the way a bare exp() would.
        self.probs = F.softmax(self._logits(), dim=1)
        return self.probs

    def trainloop(self, iters, lr=50.0, reg=0.01):
        """Run full-batch gradient descent on the bigram weights.

        The loss is the mean negative log-likelihood of the training targets
        plus ``reg`` times the mean squared weight. Progress is printed every
        100 iterations.

        Args:
            iters: Number of gradient steps to take.
            lr: Step size of each update (default matches the original 50).
            reg: Weight of the L2 penalty (default matches the original 0.01).

        Returns:
            The loss computed in the last iteration (before that iteration's
            weight update), or ``None`` when ``iters`` is not positive.
        """
        last_loss = None
        for i in range(iters):
            logits = self._logits()
            self.probs = F.softmax(logits, dim=1)
            # cross_entropy is the numerically stable form of
            # -log(softmax(logits))[target].mean().
            loss = F.cross_entropy(logits, self.ys)
            loss = loss + reg * (self.wei ** 2).mean()
            last_loss = loss.item()
            if i % 100 == 0:
                print(f'Iteration {i}, Loss: {last_loss}')
            self.wei.grad = None
            loss.backward()  # backpropagate the loss
            # Update in place without recording the step in autograd.
            with torch.no_grad():
                self.wei -= lr * self.wei.grad
        return last_loss

    def sample(self, numnames, T=1.0):
        """Generate names by walking the learned bigram transitions.

        Each name starts from the boundary marker and repeatedly draws the
        next character from the distribution of the current character until
        the boundary marker is drawn again. Only names whose length is within
        ``MIN_NAME_LEN``..``MAX_NAME_LEN`` are kept.

        Args:
            numnames: Number of names to return.
            T: Sampling temperature; logits are divided by ``T`` before the
                softmax. ``1.0`` samples from the model distribution as is.

        Returns:
            A list of ``numnames`` strings (empty if ``numnames <= 0``).

        Raises:
            ValueError: If ``T`` is not positive.
        """
        if T <= 0:
            raise ValueError('T must be positive')

        out = []
        with torch.no_grad():
            # Row i is the next-character distribution after character i.
            # (self.probs is indexed by training example, not by character,
            # so it cannot be used for this lookup.)
            trans = F.softmax(self.wei / T, dim=1)
            while len(out) < numnames:
                letters = []
                idx = 0
                while True:
                    idx = torch.multinomial(
                        trans[idx], 1, replacement=True
                    ).item()
                    if idx == 0:  # boundary marker: the word is finished
                        break
                    letters.append(self.itos[idx])
                    if len(letters) > self.MAX_NAME_LEN:
                        # Already too long to be accepted; stop early so the
                        # walk cannot run unboundedly.
                        break
                if self.MIN_NAME_LEN <= len(letters) <= self.MAX_NAME_LEN:
                    out.append(''.join(letters))
        return out
# Train a bigram model on names.txt, then print nine sampled names.
m = Model('names.txt', 1)
# trainloop prints its own progress every 100 iterations; its return value is
# not needed here, so it is no longer wrapped in print().
m.trainloop(1000)
print(m.sample(9))