
Generating image during training #89

@arghavan-kpm


Hi, thanks for your great work. For my project, I want to define a loss function on the model output in pixel space, so I need to convert the generated tokens into an image. I changed training_losses() in gaussian_diffusion.py so that, in addition to the loss, it returns model_output (shape: (bs * self.diffusion_batch_mul * seq_len, 16)) in the terms dict. This value is propagated all the way up to train_one_epoch(), where I changed the code as follows:

  ...
  with torch.no_grad():
      if args.use_cached:
          moments = samples
          posterior = DiagonalGaussianDistribution(moments)
      else:
          posterior = vae.encode(samples)

      # normalize the std of latent to be 1. Change it if you use a different tokenizer
      x = posterior.sample().mul_(0.2325)

  ##### MY CHANGE STARTS #####
  # assumes os, cv2 and numpy (as np) are imported at module level

  # forward: model() now also returns the diffusion head's output
  with torch.cuda.amp.autocast():
      loss, model_output = model(x, labels)

  ### generate image every 100 iterations
  if data_iter_step % 100 == 0:
      # read rows as (batch, diffusion_batch_mul=4, 16x16 token grid, channels),
      # then move channels first: (bs, 4, C, 16, 16)
      model_output = model_output.reshape(samples.shape[0], 4, 16, 16, -1)
      model_output = torch.einsum('nmhwc->nmchw', model_output).float()

      for i in range(4):
          with torch.no_grad():
              # undo the 0.2325 latent scaling before decoding
              sampled_images = vae.decode(model_output[:, i, ...] / 0.2325)

          torch.distributed.barrier()  # requires an initialized process group
          # (x + 1) / 2 maps the decoder output from [-1, 1] to [0, 1]
          sampled_images = sampled_images.detach().cpu()
          sampled_images = (sampled_images + 1) / 2

          # first image of the batch only: CHW -> HWC, scale to [0, 255], RGB -> BGR for cv2
          gen_img = np.round(np.clip(sampled_images[0].numpy().transpose([1, 2, 0]) * 255, 0, 255))
          gen_img = gen_img.astype(np.uint8)[:, :, ::-1]
          save_path = '...'
          cv2.imwrite(os.path.join(save_path, f'{data_iter_step}_{i}.png'), gen_img)

  ##### MY CHANGE ENDS #####

  loss_value = loss.item()
  ...
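The reshape in the code above reads the first axis of model_output as the batch and the second as the diffusion_batch_mul copy. Whether that matches depends on how the rows were stacked before the diffusion loss was computed. The NumPy sketch below is a hypothetical check, not code from the repository. It assumes one token per latent position in row-major (h, w) order, and that the copies were made by tiling the whole flattened batch (the NumPy analogue of torch's .repeat(diffusion_batch_mul, 1)). Under those assumptions the copy index is the outermost axis. The helper name tokens_to_latents is made up for illustration.

```python
import numpy as np


def tokens_to_latents(tokens, batch_size, n_copies, grid):
    """Hypothetical helper: unflatten (n_copies * batch_size * grid * grid, C)
    rows, stacked as (copy, batch, h, w), into (n_copies, batch_size, C, grid, grid)."""
    channels = tokens.shape[-1]
    x = tokens.reshape(n_copies, batch_size, grid, grid, channels)
    return x.transpose(0, 1, 4, 2, 3)  # channels-first, as a VAE decoder expects


rng = np.random.default_rng(0)
bs, n_copies, grid, channels = 2, 4, 16, 16
latents = rng.standard_normal((bs, channels, grid, grid))

# (bs, C, H, W) -> (bs, H, W, C) -> (bs * H * W, C): one token per position
tokens = latents.transpose(0, 2, 3, 1).reshape(bs * grid * grid, channels)
# Stack n_copies copies of the whole flattened batch, copy index outermost
tokens = np.tile(tokens, (n_copies, 1))

# Copy-first reading: every copy reproduces the original batch
recovered = tokens_to_latents(tokens, bs, n_copies, grid)
print(all(np.array_equal(recovered[m], latents) for m in range(n_copies)))

# Batch-first reading (as in the snippet): each slot is still a whole grid,
# but slot (0, 1) holds image 1 rather than a second copy of image 0
batch_first = tokens.reshape(bs, n_copies, grid, grid, channels).transpose(0, 1, 4, 2, 3)
print(np.array_equal(batch_first[0, 1], latents[1]))
```

If these assumptions hold, the batch-first order only changes which (image, copy) pair each saved file shows. Every slot is still a complete 16x16 token grid, so the axis order by itself would not explain an unrecognizable decode.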

Even with the pretrained checkpoint loaded, sampled_images is not a meaningful image. I'd really appreciate it if you could help me find where I'm making a mistake.
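One hypothesis worth checking, not a confirmed diagnosis: in a DDPM-style training_losses(), the network is called on a noised input x_t at a randomly drawn timestep t. With the common epsilon-prediction parameterization (in OpenAI-style implementations this is the model_mean_type setting), the network output is a noise estimate, not a clean latent. If that is the configuration here, a clean-latent estimate follows from the forward-process identity x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps, solved for x0. The NumPy sketch below illustrates that identity with a made-up linear beta schedule; the function name predict_x0_from_eps and the schedule are assumptions, not the repository's.

```python
import numpy as np


def predict_x0_from_eps(x_t, t, eps, alphas_cumprod):
    """Invert x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps for x0.

    Hypothetical helper; `t` holds one integer timestep per row of x_t.
    """
    abar = alphas_cumprod[t].reshape(-1, *([1] * (x_t.ndim - 1)))
    return (x_t - np.sqrt(1.0 - abar) * eps) / np.sqrt(abar)


# Illustrative linear beta schedule (not necessarily the one the repo uses)
num_timesteps = 1000
betas = np.linspace(1e-4, 0.02, num_timesteps)
alphas_cumprod = np.cumprod(1.0 - betas)

# Toy tokens: 8 rows of 16 channels, each row noised at its own timestep
rng = np.random.default_rng(0)
x0 = rng.standard_normal((8, 16))
noise = rng.standard_normal((8, 16))
t = rng.integers(0, num_timesteps, size=8)

# Forward process q(x_t | x0), as a q_sample()-style function would compute it
abar_t = alphas_cumprod[t][:, None]
x_t = np.sqrt(abar_t) * x0 + np.sqrt(1.0 - abar_t) * noise

# With a perfect noise prediction the clean tokens come back (up to float error);
# the noise estimate itself is not a latent of the image
x0_hat = predict_x0_from_eps(x_t, t, noise, alphas_cumprod)
print(np.allclose(x0_hat, x0))
```

Applying this in the training loop would mean also returning x_t and t from training_losses(), then undoing the 0.2325 scaling before vae.decode as the snippet already does. At large t the x0 estimate is typically coarse even when the conversion is correct, so some blur is expected.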
