diff --git a/denoiser.py b/denoiser.py index e6c3ff2..2d81fb8 100644 --- a/denoiser.py +++ b/denoiser.py @@ -89,13 +89,19 @@ def generate(self, labels): @torch.no_grad() def _forward_sample(self, z, t, labels): + # conditional x_cond = self.net(z, t.flatten(), labels) - v_cond = (x_cond - z) / (1.0 - t).clamp_min(self.t_eps) # unconditional x_uncond = self.net(z, t.flatten(), torch.full_like(labels, self.num_classes)) - v_uncond = (x_uncond - z) / (1.0 - t).clamp_min(self.t_eps) + + # use a tiny eps at sampling, not the training t_eps. + # this prevents visible leftover noise in some generated samples. + den = (1.0 - t).clamp_min(getattr(self, "sample_eps", 1e-5)) + + v_cond = (x_cond - z) / den + v_uncond = (x_uncond - z) / den # cfg interval low, high = self.cfg_interval