# Update variables based on resampling
latents = latents.detach().view(-1, num_particles, *latents.shape[1:])[torch.arange(latents.size(0)//num_particles).unsqueeze(1), resample_indices].view(-1, *latents.shape[1:])
noise_pred = noise_pred.view(-1, num_particles, *noise_pred.shape[1:])[torch.arange(noise_pred.size(0)//num_particles).unsqueeze(1), resample_indices].view(-1, *noise_pred.shape[1:])
pred_original_sample = pred_original_sample.view(-1, num_particles, *pred_original_sample.shape[1:])[torch.arange(pred_original_sample.size(0)//num_particles).unsqueeze(1), resample_indices].view(-1, *pred_original_sample.shape[1:])
manifold_deviation_trace = manifold_deviation_trace.view(-1, num_particles, *manifold_deviation_trace.shape[1:])[torch.arange(manifold_deviation_trace.size(0)//num_particles).unsqueeze(1), resample_indices].view(-1, *manifold_deviation_trace.shape[1:])
log_prob_diffusion_trace = log_prob_diffusion_trace.view(-1, num_particles, *log_prob_diffusion_trace.shape[1:])[torch.arange(log_prob_diffusion_trace.size(0)//num_particles).unsqueeze(1), resample_indices].view(-1, *log_prob_diffusion_trace.shape[1:])
all_latents.append(latents.cpu())
################### Propose Particles ###################
# Sample from proposal distribution
prev_sample, prev_sample_mean = ddim_step_with_mean(
self.scheduler, noise_pred, t, latents, **extra_step_kwargs
)
variance = get_variance(self.scheduler, t, prev_timestep)
variance = eta**2 * _left_broadcast(variance, prev_sample.shape).to(device)
std_dev_t = variance.sqrt()
prop_latents = prev_sample + variance * approx_guidance
manifold_deviation_trace = torch.cat([manifold_deviation_trace, ((variance * approx_guidance * (-noise_pred)).view(num_particles, -1).sum(dim=1).abs() / (noise_pred**2).view(num_particles, -1).sum(dim=1).sqrt()).unsqueeze(1)], dim=1)
Code snippet from pipeline_using_SMC.py
In this bit of code, after the resampling step you have updated the particle
latentsandnoise_predvariables. But theapprox_guidanceterm is not updated after resampling. This means we are using theapprox_guidanceof a different particle withnoise_predandlatentsof a different particle. Similarly,log_twist_funcis also not updated after resampling.Is this a bug or am I missing something?