Hi,
I was interested in trying to train a FIF on MNIST down to a latent dimensionality of 128, or even 12, to compare the compression.
I am using an encoder/decoder similar to Table 11 of https://arxiv.org/pdf/2306.01843, but with a slightly different network in which the hidden dimensionality of the encoder and decoder is the same. If you're interested, I can share the gist.
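To give a rough idea of the structure (a simplified fully-connected sketch with illustrative layer sizes, not my exact architecture; the gist has the details):

import torch.nn as nn

hidden_dim = 512   # illustrative hidden width, shared by encoder and decoder
latent_dim = 128   # or 12 for the stronger compression

encoder = nn.Sequential(
    nn.Flatten(),                                   # 1x28x28 MNIST image -> 784
    nn.Linear(784, hidden_dim), nn.ReLU(),
    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
    nn.Linear(hidden_dim, latent_dim),              # project down to the latent
)

decoder = nn.Sequential(
    nn.Linear(latent_dim, hidden_dim), nn.ReLU(),
    nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
    nn.Linear(hidden_dim, 784),
    nn.Unflatten(1, (1, 28, 28)),                   # back to image shape
)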
During training with beta=100, I observe that the NLL consistently decreases while the reconstruction loss stagnates. For example, here is a snippet of my training log:
Epoch 446: 0%| | 0/54 [00:00<?, ?it/s]
train_loss: -311.244 | recon_loss: 0.920 | nll_loss: -403.208 | surrogate_loss: 0.677
Epoch 447: 0%| | 0/54 [00:00<?, ?it/s]
train_loss: -312.497 | recon_loss: 0.922 | nll_loss: -404.648 | surrogate_loss: 1.268
Epoch 448: 0%| | 0/54 [00:00<?, ?it/s]
train_loss: -312.804 | recon_loss: 0.920 | nll_loss: -404.820 | surrogate_loss: 0.932
Epoch 449: 0%| | 0/54 [00:00<?, ?it/s]
train_loss: -313.825 | recon_loss: 0.921 | nll_loss: -405.916 | surrogate_loss: 1.183
Epoch 449: 100%|██████████| 54/54 [00:04<00:00, 11.23it/s] val_loss: -313.076
Epoch 450: 0%| | 0/54 [00:00<?, ?it/s]
train_loss: -313.134 | recon_loss: 0.921 | nll_loss: -405.244 | surrogate_loss: 0.111
Epoch 451: 0%| | 0/54 [00:00<?, ?it/s]
train_loss: -313.543 | recon_loss: 0.925 | nll_loss: -406.025 | surrogate_loss: 0.120
Epoch 452: 0%| | 0/54 [00:00<?, ?it/s]
train_loss: -314.921 | recon_loss: 0.921 | nll_loss: -407.062 | surrogate_loss: 0.177
Epoch 453: 0%| | 0/54 [00:00<?, ?it/s]
train_loss: -316.010 | recon_loss: 0.923 | nll_loss: -408.296 | surrogate_loss: 0.582
Epoch 454: 0%| | 0/54 [00:00<?, ?it/s]
train_loss: -317.485 | recon_loss: 0.921 | nll_loss: -409.606 | surrogate_loss: 1.434
Epoch 454: 100%|██████████| 54/54 [00:04<00:00, 11.77it/s] val_loss: -317.109
Epoch 455: 0%| | 0/54 [00:00<?, ?it/s]
train_loss: -317.156 | recon_loss: 0.920 | nll_loss: -409.162 | surrogate_loss: 0.394
As a result, my overall training loss keeps decreasing, but my reconstruction loss plateaus around 0.920. This problem seems to be alluded to in the paper, which suggests that higher betas will help. Is there any intuition on how to fix it?
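For reference, the logged numbers are consistent with the total loss being beta * recon_loss + nll_loss (my reading of the objective; the surrogate seems to be reported separately), which would make the NLL term dominate numerically:

# sanity check against the epoch 446 line of the log above
beta = 100
recon_loss, nll_loss = 0.920, -403.208
print(beta * recon_loss + nll_loss)   # -311.208, close to the logged train_loss of -311.244
# the NLL term (~ -400) is much larger in magnitude than beta * recon (~ 92),
# so the total loss can keep falling without the reconstruction improving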
The training seems to focus almost entirely on minimizing the negative log-likelihood, so the sampled images look fine, but not great. Here is the function I use to sample images:
@torch.no_grad()
def sample(self, num_samples=16, **params):
    """
    Sample a batch of images from the flow.
    """
    # sample from the latent distribution and reshape to (num_samples, 1, embed_dim)
    v = self.latent.sample(num_samples, **params)
    v = v.reshape(num_samples, 1, -1)
    return self.decoder(v)
