-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathModel.py
More file actions
122 lines (103 loc) · 3.74 KB
/
Model.py
File metadata and controls
122 lines (103 loc) · 3.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from eldr.models.autoencoder import *
from eldr.models.vae.train_torch import *
from eldr.train import train_ae, train_scvis
from eldr.data import *
from torch.utils.data import Dataset, DataLoader
from types import SimpleNamespace
import json
import os
import sys
from eldr.models.vae.utils import sample_reparameterize
class Model(object):
    """Wrapper around a trained low-dimensional representation model.

    Holds either a trained autoencoder or a trained VAE (scvis-style) and
    exposes a uniform encoding interface:

    a) Encode(input_): map an entire dataset to the 2-D latent space.
    b) Encode_ones(input): map a single batch to the latent space.
    """

    def __init__(self, model, model_type):
        """Store a trained model and its type.

        Args:
            model: trained model object (must expose a ``.encoder`` attribute;
                for a VAE the encoder returns ``(means, log_stds)``).
            model_type: 'vae' or 'autoencoder'.
        """
        self.model = model
        self.model_type = model_type

    @classmethod
    def Initialize(cls, model_type, input_, pretrained_path=None, config=None):
        """Build a Model, either by training from scratch or loading a checkpoint.

        Args:
            model_type: 'autoencoder' or 'vae'; anything else exits the process.
            input_: training data (used only when training an autoencoder).
            pretrained_path: path to a fully-pickled model (``torch.save`` of the
                whole model object); when given, training is skipped.
            config: SimpleNamespace of training settings. Required for VAE
                training; ignored for the autoencoder, which reads
                ``./configs/autoencoder.json`` instead.

        Returns:
            A ``Model`` wrapping the trained/loaded model.
        """
        if model_type != 'autoencoder' and model_type != 'vae':
            sys.exit("model_type wrong, provide right model type from: [autoencoder, vae]")
        if model_type == 'autoencoder':
            if pretrained_path is None:
                # Train from scratch using the JSON config shipped with the repo.
                print("Wait, the model is in training...")
                path = os.path.join('./configs', str(model_type) + '.json')
                # `with` ensures the config file handle is closed (the original
                # json.load(open(...)) leaked it).
                with open(path, 'r') as config_file:
                    config = SimpleNamespace(**json.load(config_file))
                print(config)
                model = train_ae(input_,
                                 encoder_shape=config.encoder_shape,
                                 output_dim=config.output_dim,
                                 decoder_shape=config.decoder_shape,
                                 learning_rate=config.learning_rate,
                                 batch_size=config.batch_size,
                                 min_epochs=config.min_epochs,
                                 stopping_epochs=config.stopping_epochs,
                                 tol=config.tol,
                                 eval_freq=config.eval_freq)
            else:
                # The whole model object was pickled at save time, so no
                # architecture args are needed here.
                print("Loading the pretrained model...")
                model = torch.load(pretrained_path)
        if model_type == 'vae':
            if pretrained_path is None:
                # NOTE(review): unlike the autoencoder branch, no config file is
                # loaded here — the caller must pass `config`, otherwise this
                # raises AttributeError on `config.dataset`. Confirm intended.
                print("Wait, the model is in training")
                model = train_scvis(
                    dataset=config.dataset,
                    features_path=config.features_path,
                    labels_path=config.labels_path,
                    model_dir=config.model_dir,
                    batch_size=config.batch_size,
                    min_epochs=config.min_epochs,
                    stopping_epochs=config.stopping_epochs,
                    tol=config.tol,
                    eval_freq=config.eval_freq,
                    lr=config.lr,)
            else:
                print("Loading the pretrained model...")
                model = torch.load(pretrained_path, map_location=torch.device('cpu'))
        return cls(model, model_type)

    def Encode(self, input_):
        """Map the whole dataset ``input_`` to the 2-D latent space.

        Iterates one sample at a time; for a VAE the latent means are used
        (no reparameterization sampling).

        Args:
            input_: array-like of shape (n_samples, n_features).

        Returns:
            torch.Tensor of shape (n_samples, 2) with the latent codes.
        """
        recons = torch.empty(input_.shape[0], 2)
        self.model.eval()
        dl = DataLoader(Data(input_), batch_size=1, shuffle=False)
        # Inference only — no_grad avoids building autograd graphs we never use.
        with torch.no_grad():
            for i, batch in enumerate(dl, 0):
                sample = batch.float()  # renamed from `input` (shadowed builtin)
                if self.model_type != 'vae':
                    recon = self.model.encoder(sample)
                    recons[i, :] = recon.data.view(-1, 2)
                else:
                    means, log_stds = self.model.encoder(sample)
                    recons[i, :] = means.data.view(-1, 2)
        return recons

    def Encode_ones(self, input):
        """Map a single batch to the latent space.

        For a VAE, returns the latent means (log-stds are discarded);
        otherwise returns the encoder output directly.
        """
        if self.model_type != 'vae':
            return self.model.encoder(input)
        else:
            means, log_stds = self.model.encoder(input)
            return means