diff --git a/model.py b/model.py
index c34dc35..7cb84c3 100644
--- a/model.py
+++ b/model.py
@@ -86,12 +86,12 @@ def forward(self, x):
             #encoder
             enc_t = self.enc(torch.cat([phi_x_t, h[-1]], 1))
             enc_mean_t = self.enc_mean(enc_t)
-            enc_std_t = self.enc_std(enc_t)
+            enc_std_t = self.enc_std(enc_t) + 1e-5
 
             #prior
             prior_t = self.prior(h[-1])
             prior_mean_t = self.prior_mean(prior_t)
-            prior_std_t = self.prior_std(prior_t)
+            prior_std_t = self.prior_std(prior_t) + 1e-5
 
             #sampling and reparameterization
             z_t = self._reparameterized_sample(enc_mean_t, enc_std_t)
@@ -100,7 +100,7 @@ def forward(self, x):
             #decoder
             dec_t = self.dec(torch.cat([phi_z_t, h[-1]], 1))
             dec_mean_t = self.dec_mean(dec_t)
-            dec_std_t = self.dec_std(dec_t)
+            dec_std_t = self.dec_std(dec_t) + 1e-5
 
             #recurrence
             _, h = self.rnn(torch.cat([phi_x_t, phi_z_t], 1).unsqueeze(0), h)
@@ -125,6 +125,7 @@ def sample(self, seq_len):
         sample = torch.zeros(seq_len, self.x_dim)
 
         h = Variable(torch.zeros(self.n_layers, 1, self.h_dim))
+
         for t in range(seq_len):
 
             #prior
@@ -172,7 +173,7 @@ def _kld_gauss(self, mean_1, std_1, mean_2, std_2):
 
         kld_element = (2 * torch.log(std_2) - 2 * torch.log(std_1) +
             (std_1.pow(2) + (mean_1 - mean_2).pow(2)) /
-            std_2.pow(2) - 1)
+            (std_2).pow(2) - 1)
         return 0.5 * torch.sum(kld_element)
 
 
diff --git a/train.py b/train.py
index 6ddfe9c..e784767 100644
--- a/train.py
+++ b/train.py
@@ -22,31 +22,33 @@ def train(epoch):
 
         #data = Variable(data)
         #to remove eventually
         data = Variable(data.squeeze().transpose(0, 1))
-        data = (data - data.min().data[0]) / (data.max().data[0] - data.min().data[0])
-
+        #data = (data - data.min().item()) / (data.max().item() - data.min().item())
+
         #forward + backward + optimize
         optimizer.zero_grad()
         kld_loss, nll_loss, _, _ = model(data)
         loss = kld_loss + nll_loss
         loss.backward()
-        optimizer.step()
 
         #grad norm clipping, only in pytorch version >= 1.10
-        nn.utils.clip_grad_norm(model.parameters(), clip)
+        nn.utils.clip_grad_norm_(model.parameters(), clip)
+
+        optimizer.step()
+
         #printing
         if batch_idx % print_every == 0:
             print('Train Epoch: {} [{}/{} ({:.0f}%)]\t KLD Loss: {:.6f} \t NLL Loss: {:.6f}'.format(
                 epoch, batch_idx * len(data), len(train_loader.dataset),
                 100. * batch_idx / len(train_loader),
-                kld_loss.data[0] / batch_size,
-                nll_loss.data[0] / batch_size))
+                kld_loss.item() / batch_size,
+                nll_loss.item() / batch_size))
 
-            sample = model.sample(28)
-            plt.imshow(sample.numpy())
-            plt.pause(1e-6)
+            #sample = model.sample(28)
+            #plt.imshow(sample.numpy())
+            #plt.pause(1e-6)
 
-        train_loss += loss.data[0]
+        train_loss += loss.item()
 
 
     print('====> Epoch: {} Average loss: {:.4f}'.format(
@@ -62,11 +64,11 @@ def test(epoch):
 
         #data = Variable(data)
         data = Variable(data.squeeze().transpose(0, 1))
-        data = (data - data.min().data[0]) / (data.max().data[0] - data.min().data[0])
+        #data = (data - data.min().item()) / (data.max().item() - data.min().item())
 
         kld_loss, nll_loss, _, _ = model(data)
-        mean_kld_loss += kld_loss.data[0]
-        mean_nll_loss += nll_loss.data[0]
+        mean_kld_loss += kld_loss.item()
+        mean_nll_loss += nll_loss.item()
 
     mean_kld_loss /= len(test_loader.dataset)
     mean_nll_loss /= len(test_loader.dataset)
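
Note on the model.py change: adding 1e-5 to each std head keeps the standard deviations strictly positive, so the log terms in _kld_gauss stay finite when an activation underflows to zero. The sketch below assumes nothing about the repo beyond the _kld_gauss formula visible in the hunk above; it just reproduces the failure mode and the fix:

import torch

def kld_gauss(mean_1, std_1, mean_2, std_2):
    # Closed-form KL between two diagonal Gaussians, mirroring
    # model.py's _kld_gauss from the hunk above.
    kld_element = (2 * torch.log(std_2) - 2 * torch.log(std_1) +
                   (std_1.pow(2) + (mean_1 - mean_2).pow(2)) /
                   std_2.pow(2) - 1)
    return 0.5 * torch.sum(kld_element)

mean = torch.zeros(3)
raw_std = torch.zeros(3)  # a std head whose activation underflowed to 0

print(kld_gauss(mean, raw_std, mean, torch.ones(3)))         # inf: log(0) blows up
print(kld_gauss(mean, raw_std + 1e-5, mean, torch.ones(3)))  # finite

A common alternative is to end the std head in Softplus and add the epsilon there, or to predict log-variance instead; the patch opts for the smallest change to the existing code.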
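
Note on the train.py change: clip_grad_norm was renamed to clip_grad_norm_ (the trailing underscore marks it as in-place), and clipping only has an effect if it runs after loss.backward() but before optimizer.step(); the old code stepped first, so the unclipped gradients were already applied. A minimal, self-contained sketch of the corrected order, where the linear model and random data are hypothetical stand-ins rather than the VRNN from this patch:

import torch
import torch.nn as nn

model = nn.Linear(10, 1)  # hypothetical stand-in for the VRNN
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
clip = 10

x, y = torch.randn(32, 10), torch.randn(32, 1)

optimizer.zero_grad()
loss = nn.functional.mse_loss(model(x), y)
loss.backward()

# clip the gradients in place first, then apply them
nn.utils.clip_grad_norm_(model.parameters(), clip)
optimizer.step()

print(loss.item())  # .item() replaces the deprecated .data[0] for 0-dim tensors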