-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
78 lines (63 loc) · 2.84 KB
/
utils.py
File metadata and controls
78 lines (63 loc) · 2.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import sys
import torch
import datetime
import numpy as np
class EarlyStopping(object):
    """Early-stopping helper that checkpoints the best (feasible) model.

    Training should stop once the monitored loss has failed to improve for
    ``patience`` consecutive calls to :meth:`step`.
    """

    def __init__(self, save_path, patience=10):
        # NOTE(review): the original also computed datetime.datetime.now()
        # into an unused local here; removed as dead code.
        self.filename = save_path   # path the best state_dict is saved to
        self.patience = patience    # allowed consecutive non-improving steps
        self.counter = 0            # consecutive non-improving steps so far
        self.best_loss = None       # best feasible loss seen (None until first)
        self.early_stop = False     # latched True once patience is exhausted

    def step(self, loss, model, mode, tol, *args):
        """Record one evaluation and return whether training should stop.

        Parameters
        ----------
        loss : float
            Current value of the monitored metric.
        model : torch.nn.Module
            Model checkpointed whenever ``loss`` improves.
        mode : str
            ``'min'`` if lower is better, ``'max'`` if higher is better.
        tol : float
            Feasibility tolerance for the violations in ``args``.
        *args : float
            Constraint violations; the step only counts as a candidate
            improvement when every violation is <= ``tol`` (an empty
            ``args`` is trivially feasible).

        Returns
        -------
        bool
            True once ``patience`` consecutive non-improving steps occurred.
        """
        if all(vio <= tol for vio in args):
            if self.best_loss is None:
                # First feasible observation becomes the baseline.
                self.best_loss = loss
                self.save_checkpoint(model)
                self.counter = 0
            elif mode == 'min':
                if loss <= self.best_loss:
                    self.save_checkpoint(model)
                    self.best_loss = np.min((loss, self.best_loss))
                    self.counter = 0
                else:
                    self.counter += 1
                    print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            elif mode == 'max':
                if loss >= self.best_loss:
                    self.save_checkpoint(model)
                    self.best_loss = np.max((loss, self.best_loss))
                    self.counter = 0
                else:
                    self.counter += 1
                    print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
        else:
            # Infeasible step: cannot be an improvement, so it counts
            # against patience.  (NOTE(review): source indentation was lost;
            # this `else` is assumed to pair with the feasibility check.)
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
        if self.counter >= self.patience:
            self.early_stop = True
        return self.early_stop

    def save_checkpoint(self, model):
        """Persist ``model``'s state_dict to ``self.filename``."""
        torch.save(model.state_dict(), self.filename)

    def load_checkpoint(self, model):
        """Restore ``model`` in-place from the saved checkpoint."""
        model.load_state_dict(torch.load(self.filename))
def obj_fn(x, Q, p):
    """Batched quadratic objective 0.5 * x^T Q x + p^T x.

    All arguments are batched tensors; the result has shape (batch, 1, 1).
    """
    xT = x.permute(0, 2, 1)
    quadratic = 0.5 * torch.bmm(xT, torch.bmm(Q, x))
    linear = torch.bmm(p.permute(0, 2, 1), x)
    return quadratic + linear
def ineq_dist(x, G, c):
    """Elementwise violation of the inequality G x <= c (zero where satisfied)."""
    residual = torch.bmm(G, x) - c
    return residual.clamp(min=0)
def eq_dist(x, A, b):
    """Elementwise absolute residual |b - A x| of the equality constraints."""
    residual = b - torch.bmm(A, x)
    return residual.abs()
def lb_dist(x, lb):
    """How far x falls below the lower bound lb, elementwise (zero if feasible)."""
    return (lb - x).clamp(min=0)
def ub_dist(x, ub):
    """How far x exceeds the upper bound ub, elementwise (zero if feasible)."""
    return (x - ub).clamp(min=0)
def primal_dual_loss(x, y, z, Q, p, A0):
    """Norms of the primal and dual residuals, and their sum.

    Primal residual: ||A0 x - z||; dual residual: ||Q x + p + A0^T y||.
    Each norm is taken over dims (1, 2) with keepdim, giving (batch, 1, 1)
    tensors.  Returns (primal, dual, primal + dual).
    """
    primal_vec = torch.bmm(A0, x) - z
    r_primal = torch.linalg.vector_norm(primal_vec, dim=(1, 2), keepdim=True)
    dual_vec = torch.bmm(Q, x) + p + torch.bmm(A0.permute(0, 2, 1), y)
    r_dual = torch.linalg.vector_norm(dual_vec, dim=(1, 2), keepdim=True)
    total = r_primal + r_dual
    return r_primal, r_dual, total
def aug_lagr(x, z, y, Q, p, A0, rho_vec):
    """Batched augmented Lagrangian of the QP.

    Computes  f(x) + y^T (A0 x - z) + 0.5 (A0 x - z)^T diag(rho) (A0 x - z)
    where f(x) = 0.5 x^T Q x + p^T x, returning a (batch, 1, 1) tensor.

    BUG FIX: the original computed the objective as 0.5 * x^T Q p
    (``torch.bmm(Q, p)``) instead of 0.5 * x^T Q x, inconsistent with
    ``obj_fn`` defined above; corrected to ``torch.bmm(Q, x)``.
    """
    # Quadratic objective, matching obj_fn.
    fx = 0.5 * torch.bmm(x.permute(0, 2, 1), torch.bmm(Q, x)) + torch.bmm(p.permute(0, 2, 1), x)
    # Constraint residual A0 x - z, reused by both Lagrangian terms.
    residual = torch.bmm(A0, x) - z
    dual_item = torch.bmm(y.permute(0, 2, 1), residual)
    # Per-constraint penalty weights rho as a diagonal matrix.
    rho_mat = torch.diag_embed(rho_vec.squeeze(-1))
    aug_item = 0.5 * torch.bmm(residual.permute(0, 2, 1), torch.bmm(rho_mat, residual))
    return fx + dual_item + aug_item