-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy path pretrain_joint.py
More file actions
157 lines (125 loc) · 6.68 KB
/
pretrain_joint.py
File metadata and controls
157 lines (125 loc) · 6.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import argparse
import os
import numpy as np
import datetime
import pandas as pd
import torch
from torch import distributions
from lib.dataloader import dataloader
from src.icnn import FICNN, PICNN
from src.mapficnn import MapFICNN
from src.pcpmap import PCPMap
from lib.utils import makedirs, get_logger, AverageMeter
"""
argument parser for hyper parameters and model handling
"""
parser = argparse.ArgumentParser('PCP-Map')
parser.add_argument(
'--data', choices=['wt_wine', 'rd_wine', 'parkinson'], type=str, default='rd_wine'
)
parser.add_argument('--input_x_dim', type=int, default=6, help="input data convex dimension")
parser.add_argument('--input_y_dim', type=int, default=5, help="input data non-convex dimension")
parser.add_argument('--out_dim', type=int, default=1, help="output dimension")
parser.add_argument('--clip', type=bool, default=True, help="whether clipping the weights or not")
parser.add_argument('--tol', type=float, default=1e-12, help="LBFGS tolerance")
parser.add_argument('--test_ratio', type=float, default=0.10, help="test set ratio")
parser.add_argument('--valid_ratio', type=float, default=0.10, help="validation set ratio")
parser.add_argument('--random_state', type=int, default=42, help="random state for splitting dataset")
parser.add_argument('--num_epochs', type=int, default=3, help="pilot run number of epochs")
parser.add_argument('--save', type=str, default='experiments/tabjoint', help="define the save directory")
args = parser.parse_args()
sStartTime = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
# logger
makedirs(args.save)
logger = get_logger(logpath=os.path.join(args.save, 'logs'), filepath=os.path.abspath(__file__), saving=True)
logger.info("start time: " + sStartTime)
logger.info(args)
# GPU Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if __name__ == '__main__':
    # Random hyper-parameter search (100 trials) for the joint FICNN + PICNN
    # ("PCP-Map") pilot training. Each trial samples a configuration, trains
    # briefly, and records validation NLLs so the best setting can be picked later.
    columns_params = ["batchsz", "lr", "width", "width_y", "depth"]
    columns_valid = ["ficnn_nll", "picnn_nll", "tot_nll"]
    params_hist = pd.DataFrame(columns=columns_params)
    valid_hist = pd.DataFrame(columns=columns_valid)

    log_msg = ('{:5s} {:9s} {:9s} {:9s}'.format('trial', 'val_ficnn', 'val_picnn', 'val_loss'))
    logger.info(log_msg)

    # sample space for hyperparameters
    width_list = np.array([32, 64, 128, 256, 512])
    depth_list = np.array([2, 3, 4, 5, 6])
    batch_size_list = np.array([32, 64])
    lr_list = np.array([0.01, 0.005, 0.001])

    for trial in range(100):
        batch_size = int(np.random.choice(batch_size_list))
        train_loader, valid_loader, _, train_size = dataloader(args.data, batch_size, args.test_ratio,
                                                               args.valid_ratio, args.random_state)
        # Two mutually exclusive ways to keep ICNN weights non-negative:
        # clip=True -> clamp weights after each optimizer step (below);
        # clip=False -> let the ICNN reparameterize its weights internally.
        if args.clip is True:
            reparam = False
        else:
            reparam = True

        width = np.random.choice(width_list)
        # Candidate widths for the non-convex (y) feature path: the sampled
        # width, the raw y-dimension, and each halving of width down toward it.
        width_y_list = [width, args.input_y_dim]
        feat_dim = width
        while feat_dim // 2 > args.input_y_dim:
            feat_dim = feat_dim // 2
            width_y_list.append(feat_dim)
        width_y = np.random.choice(width_y_list)
        num_layers = np.random.choice(depth_list)
        lr = np.random.choice(lr_list)

        # Multivariate Gaussian as Reference
        prior_ficnn = distributions.MultivariateNormal(torch.zeros(args.input_y_dim).to(device),
                                                       torch.eye(args.input_y_dim).to(device))
        prior_picnn = distributions.MultivariateNormal(torch.zeros(args.input_x_dim).to(device),
                                                       torch.eye(args.input_x_dim).to(device))

        # build FICNN map and PCP-Map (fresh models/optimizers per trial)
        ficnn = FICNN(args.input_y_dim, width, args.out_dim, num_layers, reparam=reparam)
        picnn = PICNN(args.input_x_dim, args.input_y_dim, width, width_y, args.out_dim, num_layers, reparam=reparam)
        map_ficnn = MapFICNN(prior_ficnn, ficnn).to(device)
        map_picnn = PCPMap(prior_picnn, picnn).to(device)
        optimizer1 = torch.optim.Adam(map_ficnn.parameters(), lr=lr)
        optimizer2 = torch.optim.Adam(map_picnn.parameters(), lr=lr)

        # Record this trial's configuration (row index == trial number).
        params_hist.loc[len(params_hist.index)] = [batch_size, lr, width, width_y, num_layers]

        # Pilot-run length: parkinson/wt_wine use the CLI value (default 3);
        # the remaining dataset (rd_wine) gets a fixed 4 epochs.
        if args.data == 'parkinson' or args.data == 'wt_wine':
            num_epochs = args.num_epochs
        else:
            num_epochs = 4

        for epoch in range(num_epochs):
            for sample in train_loader:
                # Columns are laid out [y | x]: first input_y_dim columns are y.
                # requires_grad_ on the inputs — presumably the loglik terms
                # differentiate w.r.t. the inputs internally (transport-map
                # log-determinant); TODO confirm against MapFICNN/PCPMap.
                x = sample[:, args.input_y_dim:].requires_grad_(True).to(device)
                y = sample[:, :args.input_y_dim].requires_grad_(True).to(device)

                # optimizer step for flow1 (FICNN map on the marginal of y)
                optimizer1.zero_grad()
                loss1 = -map_ficnn.loglik_ficnn(y).mean()
                loss1.backward()
                optimizer1.step()

                # non-negative constraint: project the convex-path weights
                # back to the feasible set after the unconstrained step.
                if args.clip is True:
                    for lz in map_ficnn.ficnn.Lz:
                        with torch.no_grad():
                            lz.weight.data = map_ficnn.ficnn.nonneg(lz.weight)

                # optimizer step for flow2 (PICNN map for x conditioned on y)
                optimizer2.zero_grad()
                loss2 = -map_picnn.loglik_picnn(x, y).mean()
                loss2.backward()
                optimizer2.step()

                # non-negative constraint (same projection for the PICNN)
                if args.clip is True:
                    for lw in map_picnn.picnn.Lw:
                        with torch.no_grad():
                            lw.weight.data = map_picnn.picnn.nonneg(lw.weight)

        # Validation: batch-size-weighted mean negative log-likelihoods.
        valLossMeterFICNN = AverageMeter()
        valLossMeterPICNN = AverageMeter()
        for valid_sample in valid_loader:
            x_valid = valid_sample[:, args.input_y_dim:].requires_grad_(True).to(device)
            y_valid = valid_sample[:, :args.input_y_dim].requires_grad_(True).to(device)
            mean_valid_loss_ficnn = -map_ficnn.loglik_ficnn(y_valid).mean()
            mean_valid_loss_picnn = -map_picnn.loglik_picnn(x_valid, y_valid).mean()
            valLossMeterFICNN.update(mean_valid_loss_ficnn.item(), valid_sample.shape[0])
            valLossMeterPICNN.update(mean_valid_loss_picnn.item(), valid_sample.shape[0])

        val_loss_ficnn = valLossMeterFICNN.avg
        val_loss_picnn = valLossMeterPICNN.avg
        # Total joint NLL = marginal (FICNN) + conditional (PICNN) terms.
        val_loss = val_loss_ficnn + val_loss_picnn

        log_message = '{:05d} {:9.3e} {:9.3e} {:9.3e} '.format(trial+1, val_loss_ficnn, val_loss_picnn, val_loss)
        logger.info(log_message)
        valid_hist.loc[len(valid_hist.index)] = [val_loss_ficnn, val_loss_picnn, val_loss]

        # NOTE(review): indentation was lost in extraction; per-trial saving
        # (overwrite the CSVs each iteration so partial search results survive
        # an interruption) assumed here — confirm against the original repo.
        params_hist.to_csv(os.path.join(args.save, '%s_params_hist.csv' % args.data))
        valid_hist.to_csv(os.path.join(args.save, '%s_valid_hist.csv' % args.data))