-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_mu.py
More file actions
130 lines (108 loc) · 4.22 KB
/
train_mu.py
File metadata and controls
130 lines (108 loc) · 4.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from model import *
import torch
import gzip
import pickle
import os
import random
from alive_progress import alive_bar
# ---------------------------------------------------------------------------
# Run setup: dataset file lists, model, optimizer, checkpoint restore, log file.
# ---------------------------------------------------------------------------

# NOTE(review): the device is only reported here; nothing in this script moves
# the model or the data tensors onto it, so everything runs on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Dataset identifier; selects ./data_<ident>/{train,valid} and names the
# checkpoint and log files.
ident = "lp"
idf = f"data_{ident}"
flist_train = os.listdir(f'./{idf}/train')
flist_valid = os.listdir(f'./{idf}/valid')

best_loss = 1e+20
mdl = framework_fixed_mu(2,2,64,4)
last_epoch = 0

loss_func = torch.nn.MSELoss()
# The optimizer is built before the checkpoint is read so that its saved state
# can be restored together with the model weights.
optimizer = torch.optim.Adam(mdl.parameters(), lr=1e-4)
max_epoch = 10000

ckpt_path = f"./model/best_model_fixed_mu_{ident}.mdl"
if os.path.exists(ckpt_path):
    # The model stays on the CPU in this script, so map every stored tensor
    # there regardless of where the checkpoint was written.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    mdl.load_state_dict(checkpoint['model'])
    # The training loop stores the optimizer state next to the weights;
    # restore it so a resumed run continues with the same Adam statistics.
    if 'optimizer' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
    if 'nepoch' in checkpoint:
        last_epoch = checkpoint['nepoch']
    # Older checkpoints may lack the key; keep the initial sentinel then.
    best_loss = checkpoint.get('best_loss', best_loss)
    print(f'Last best val loss gen: {best_loss}')
    print('Model Loaded')

# Make sure the log directory exists before opening the log file.
os.makedirs('./logs', exist_ok=True)
# NOTE(review): mode 'w' truncates the previous log even when resuming.
flog = open(f'./logs/train_log_fixed_mu_{ident}.log','w')
eps = 0.2
def _load_instance(path, eps):
    """Read one gzip-pickled problem instance and build the model inputs.

    The file must contain a pickled 6-tuple ``(A, v, c, sol, dual, obj)``.

    Args:
        path: Path of the ``.gz`` pickle file.
        eps: Scalar used in ``mu = 1/eps * log(m * max(A) / eps)``, where
            ``m`` is the number of rows of ``A``.

    Returns:
        ``(A, x, y, x_gt, mu)`` as float32 tensors: ``A``, the initial ``x``
        (from ``v``) and ``y`` (from ``c``) handed to the model, the target
        ``x_gt`` (from ``sol``) with a trailing axis appended, and ``mu``.
    """
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file; only point this at dataset files you generated or trust.
    with gzip.open(path, 'rb') as f:
        A, v, c, sol, dual, obj = pickle.load(f)
    A = torch.as_tensor(A, dtype=torch.float32)
    m = A.shape[0]
    mu = 1/eps * torch.log(m*torch.max(A)/eps)
    x = torch.as_tensor(v, dtype=torch.float32)
    y = torch.as_tensor(c, dtype=torch.float32)
    # Training has always compared the model output against the target with a
    # trailing axis added (presumably the model returns x as a column); doing
    # it here keeps the train and validation targets identically shaped.
    x_gt = torch.as_tensor(sol, dtype=torch.float32).unsqueeze(-1)
    return A, x, y, x_gt, mu


# Main loop: one training pass and one validation pass per epoch. Only the
# primal loss (MSE between the predicted x and the stored solution) is
# optimised and logged; the dual loss is not used.
try:
    # Fail early with a clear message instead of a ZeroDivisionError after a
    # whole pass over the other split.
    if not flist_train or not flist_valid:
        raise RuntimeError(f'No data files found under ./{idf}/train or ./{idf}/valid')

    for epoch in range(last_epoch, max_epoch):
        # ---------------- training ----------------
        mdl.train()
        train_loss = 0.0
        random.shuffle(flist_train)
        with alive_bar(len(flist_train),title=f"Training Epoch:{epoch}") as bar:
            for fnm in flist_train:
                A, x, y, x_gt, mu = _load_instance(f'./{idf}/train/{fnm}', eps)
                optimizer.zero_grad()
                x, y = mdl(A, x, y, mu)
                loss_x = loss_func(x, x_gt)
                train_loss += loss_x.item()
                print(loss_x.item())
                loss_x.backward()
                optimizer.step()
                bar()
        train_loss /= len(flist_train)
        print(f'Epoch {epoch} Train:::: primal loss:{train_loss}')
        # Log line format: "<train_loss> <valid_loss>\n" (one line per epoch).
        flog.write(f'{train_loss} ')

        # ---------------- validation ----------------
        # No gradients are needed here; eval() additionally disables any
        # train-only layer behaviour the model might have.
        mdl.eval()
        valid_loss = 0.0
        with alive_bar(len(flist_valid),title=f"Valid Epoch:{epoch}") as bar, torch.no_grad():
            for fnm in flist_valid:
                A, x, y, x_gt, mu = _load_instance(f'./{idf}/valid/{fnm}', eps)
                x, y = mdl(A, x, y, mu)
                loss_x = loss_func(x, x_gt)
                valid_loss += loss_x.item()
                bar()
        valid_loss /= len(flist_valid)
        print(f'Epoch {epoch} Valid:::: primal loss:{valid_loss}')
        flog.write(f'{valid_loss}\n')

        # Keep only the checkpoint with the lowest validation loss so far.
        if best_loss > valid_loss:
            best_loss = valid_loss
            state = {
                'model': mdl.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_loss': best_loss,
                'nepoch': epoch,
            }
            os.makedirs('./model', exist_ok=True)
            torch.save(state, f'./model/best_model_fixed_mu_{ident}.mdl')
            print(f'Saving new best model with valid loss: {best_loss}')
        flog.flush()
finally:
    # Close the log even when training is interrupted or raises.
    flog.close()