For the line `latents[0] = reference_samples`, the error reported in the title is thrown.
I added a few print statements around it, which show the shapes involved. What am I missing?
```text
Shape of reference_samples: torch.Size([3, 4, 128, 128])
Shape of latents after zeros_like: torch.Size([3, 4, 128, 128])
Shape of latents after cat: torch.Size([6, 4, 128, 128])
Shape of latents[0]: torch.Size([4, 128, 128])
Shape of reference_samples: torch.Size([3, 4, 128, 128])
```
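A minimal sketch that reproduces the printed shapes. The construction code is not shown in the post, so building `latents` with `zeros_like` followed by `cat` of two zero tensors is an assumption based on the print labels. It shows that integer indexing (`latents[0]`) drops the leading batch dimension, so a 4-D tensor cannot be assigned into it, while slice assignment keeps that dimension. Copying all three reference samples into the first three rows is an assumed intent; if only one sample is meant, `latents[0] = reference_samples[0]` would match shapes instead.

```python
import torch

# Assumption: shapes taken from the printed output; the real construction code is not shown.
reference_samples = torch.randn(3, 4, 128, 128)
latents = torch.zeros_like(reference_samples)
latents = torch.cat([latents, torch.zeros_like(latents)], dim=0)  # assumed: concatenate with another zeros tensor
print(latents.shape)     # torch.Size([6, 4, 128, 128])

# Integer indexing removes the leading (batch) dimension.
print(latents[0].shape)  # torch.Size([4, 128, 128])

# Assigning a [3, 4, 128, 128] tensor into a [4, 128, 128] view fails,
# because the source cannot be broadcast to the smaller destination shape.
failed = False
try:
    latents[0] = reference_samples
except RuntimeError:
    failed = True

# Slice assignment keeps the batch dimension: rows 0..2 receive the three samples.
n = reference_samples.shape[0]
latents[:n] = reference_samples
```

The slice `latents[:n]` has shape `[3, 4, 128, 128]`, identical to `reference_samples`, so the copy succeeds and rows 3 to 5 keep their zeros.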