count_a_mlp.py
import os
import random

import torch
import torch.nn as nn
import wandb
from tqdm import tqdm

class Net(nn.Module):
    """MLP regressor: one-hot encoded 1200-nt sequence -> single scalar."""

    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(4800, 2400)  # 4 channels * 1200 positions, flattened
        self.dense2 = nn.Linear(2400, 1200)
        self.dense3 = nn.Linear(1200, 600)
        self.dense4 = nn.Linear(600, 300)
        self.dense5 = nn.Linear(300, 1)

    def forward(self, x):
        # Flatten the [batch, 4, 1200] one-hot input to [batch, 4800].
        d1 = torch.sigmoid(self.dense(x.flatten(start_dim=1, end_dim=2)))
        d2 = torch.sigmoid(self.dense2(d1))
        d3 = torch.sigmoid(self.dense3(d2))
        d4 = torch.sigmoid(self.dense4(d3))
        return self.dense5(d4)  # linear head: no output activation
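
# A quick shape sanity check (illustrative, not part of the training script):
# Net expects [batch, 4, 1200] one-hot input, which forward() flattens to
# 4800 features and maps to a single scalar per example:
#
#   >>> Net()(torch.zeros(2, 4, 1200)).shape
#   torch.Size([2, 1])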

def seq_to_one_hot(seq):
    """One-hot encode a nucleotide string into a [1, 4, len(seq)] tensor."""
    indexes = {"A": 0, "C": 1, "G": 2, "T": 3, "U": 3}
    encoded_seq = torch.zeros([4, len(seq)])
    for i, nuc in enumerate(seq):
        encoded_seq[indexes[nuc], i] = 1
    return encoded_seq.unsqueeze(0)
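
# Illustrative example: seq_to_one_hot("ACGT") gives a [1, 4, 4] tensor with
# exactly one 1 per column; row 0 marks "A" positions, so
# seq_to_one_hot("ACGT")[0, 0].sum() equals count_a("ACGT") == 1.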

def count_a(seq):
    """Count occurrences of "A" in a sequence string."""
    return seq.count("A")

if __name__ == "__main__":
    nucs = ["A", "C", "G", "T"]
    # Random sequences with a random per-sequence nucleotide composition, so
    # the fraction of "A" varies widely across examples.
    sequences_train = ["".join(random.choices(nucs, k=1200, weights=[random.random() for _ in range(4)])) for _ in range(20000)]
    sequences_test = ["".join(random.choices(nucs, k=1200, weights=[random.random() for _ in range(4)])) for _ in range(1000)]
    # Stack the [1, 4, 1200] encodings into [N, 4, 1200] float tensors.
    encoded_sequences_train = torch.cat([seq_to_one_hot(seq) for seq in sequences_train]).float()
    encoded_sequences_test = torch.cat([seq_to_one_hot(seq) for seq in sequences_test]).float()

    net = Net()
    epochs = 20
    optim = torch.optim.AdamW(params=net.parameters())
    loss_fn = nn.MSELoss(reduction="mean")

    # Targets are the fraction of "A" per sequence, so they lie in [0, 1].
    Xtrain = encoded_sequences_train
    Ytrain = torch.tensor([count_a(seq) / 1200 for seq in sequences_train]).float().unsqueeze(dim=1)
    Xtest = encoded_sequences_test
    Ytest = torch.tensor([count_a(seq) / 1200 for seq in sequences_test]).float().unsqueeze(dim=1)
    dataset = torch.utils.data.TensorDataset(Xtrain, Ytrain)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=20, shuffle=True)

    wandb.init(
        project="count_a",
        config={
            "sequence_length": 1200,
            "batch_size": 20,
            "architecture": "MLP",
            "epochs": epochs,
        },
    )

    for e in tqdm(range(epochs)):
        # Log full-dataset train/test loss before each epoch; evaluation runs
        # under torch.no_grad() since no gradients are needed here.
        net.eval()
        with torch.no_grad():
            lossTest = loss_fn(net(Xtest), Ytest).item()
            lossTrain = loss_fn(net(Xtrain), Ytrain).item()
        wandb.log({"train_loss": lossTrain, "test_loss": lossTest})

        net.train()
        for batch_X, batch_Y in dataloader:
            optim.zero_grad()
            pred = net(batch_X)
            loss = loss_fn(pred, batch_Y)
            loss.backward()
            optim.step()

    # Make sure the output directory exists before saving the checkpoint.
    os.makedirs("models", exist_ok=True)
    torch.save(net.state_dict(), "models/first_model.pth")
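
# A minimal reload-and-predict sketch (assumptions: the training run above has
# completed, so models/first_model.pth exists; the query sequence "ACGT" * 300
# is made up for illustration):
#
#   model = Net()
#   model.load_state_dict(torch.load("models/first_model.pth"))
#   model.eval()
#   with torch.no_grad():
#       frac_a = model(seq_to_one_hot("ACGT" * 300).float()).item()
#   print(f"predicted A count: {frac_a * 1200:.1f}")  # labels were count/1200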