From 460d1084389d5ec98d6c433ed21d6c8c15ef403d Mon Sep 17 00:00:00 2001
From: Ragzz258
Date: Wed, 28 Jun 2023 13:08:07 +0000
Subject: [PATCH] added intermediate latent during DDIM inversion for content
 querying

---
 masactrl_w_adapter/ddim.py               | 8 +++++++-
 masactrl_w_adapter/masactrl_w_adapter.py | 1 +
 2 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/masactrl_w_adapter/ddim.py b/masactrl_w_adapter/ddim.py
index eddff2c..08d4bc2 100644
--- a/masactrl_w_adapter/ddim.py
+++ b/masactrl_w_adapter/ddim.py
@@ -71,6 +71,7 @@ def sample(self,
                corrector_kwargs=None,
                verbose=True,
                x_T=None,
+               ld = None,
                log_every_t=100,
                unconditional_guidance_scale=1.,
                unconditional_conditioning=None,
@@ -114,6 +115,7 @@ def sample(self,
                                                     append_to_context=append_to_context,
                                                     cond_tau=cond_tau,
                                                     style_cond_tau=style_cond_tau,
+                                                    ld = ld,
                                                     )
         return samples, intermediates
 
@@ -124,7 +126,7 @@ def ddim_sampling(self, cond, shape,
                       mask=None, x0=None, img_callback=None, log_every_t=100,
                       temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                       unconditional_guidance_scale=1., unconditional_conditioning=None, features_adapter=None,
-                      append_to_context=None, cond_tau=0.4, style_cond_tau=1.0):
+                      append_to_context=None, cond_tau=0.4, style_cond_tau=1.0,ld = None):
         device = self.model.betas.device
         b = shape[0]
         if x_T is None:
@@ -148,6 +150,10 @@ def ddim_sampling(self, cond, shape,
         for i, step in enumerate(iterator):
             index = total_steps - i - 1
             ts = torch.full((b,), step, device=device, dtype=torch.long)
+            if ld != None:
+                latent_ref = ld['x_inter'][index]
+                _, latents_cur =img.chunk(2)
+                img = torch.cat([latent_ref,latents_cur])
 
             if mask is not None:
                 assert x0 is not None
diff --git a/masactrl_w_adapter/masactrl_w_adapter.py b/masactrl_w_adapter/masactrl_w_adapter.py
index e4002d9..418eea9 100644
--- a/masactrl_w_adapter/masactrl_w_adapter.py
+++ b/masactrl_w_adapter/masactrl_w_adapter.py
@@ -191,6 +191,7 @@ def main():
                                          verbose=False,
                                          unconditional_guidance_scale=opt.scale,
                                          unconditional_conditioning=uc,
+                                         ld = latents_dict,
                                          x_T=start_code,
                                          features_adapter=adapter_features,
                                          append_to_context=append_to_context,