Approx guidance and log twist variables not updated after resampling #1

@chinmay8bit

Code snippet from pipeline_using_SMC.py

                    # Update variables based on resampling
                    latents = latents.detach().view(-1, num_particles, *latents.shape[1:])[torch.arange(latents.size(0)//num_particles).unsqueeze(1), resample_indices].view(-1, *latents.shape[1:])
                    noise_pred = noise_pred.view(-1, num_particles, *noise_pred.shape[1:])[torch.arange(noise_pred.size(0)//num_particles).unsqueeze(1), resample_indices].view(-1, *noise_pred.shape[1:])
                    pred_original_sample = pred_original_sample.view(-1, num_particles, *pred_original_sample.shape[1:])[torch.arange(pred_original_sample.size(0)//num_particles).unsqueeze(1), resample_indices].view(-1, *pred_original_sample.shape[1:])
                    manifold_deviation_trace = manifold_deviation_trace.view(-1, num_particles, *manifold_deviation_trace.shape[1:])[torch.arange(manifold_deviation_trace.size(0)//num_particles).unsqueeze(1), resample_indices].view(-1, *manifold_deviation_trace.shape[1:])
                    log_prob_diffusion_trace = log_prob_diffusion_trace.view(-1, num_particles, *log_prob_diffusion_trace.shape[1:])[torch.arange(log_prob_diffusion_trace.size(0)//num_particles).unsqueeze(1), resample_indices].view(-1, *log_prob_diffusion_trace.shape[1:])
                
                all_latents.append(latents.cpu())

                ################### Propose Particles ###################    
                # Sample from proposal distribution
                prev_sample, prev_sample_mean = ddim_step_with_mean(
                    self.scheduler, noise_pred, t, latents, **extra_step_kwargs
                )

                variance = get_variance(self.scheduler, t, prev_timestep)
                variance = eta**2 * _left_broadcast(variance, prev_sample.shape).to(device)
                std_dev_t = variance.sqrt()

                prop_latents = prev_sample + variance * approx_guidance
                manifold_deviation_trace = torch.cat([manifold_deviation_trace, ((variance * approx_guidance * (-noise_pred)).view(num_particles, -1).sum(dim=1).abs() / (noise_pred**2).view(num_particles, -1).sum(dim=1).sqrt()).unsqueeze(1)], dim=1)

In this snippet, the resampling step re-indexes latents, noise_pred, pred_original_sample, manifold_deviation_trace, and log_prob_diffusion_trace, but approx_guidance is not re-indexed. As a result, whenever resampling replaces a particle, the line prop_latents = prev_sample + variance * approx_guidance combines the guidance computed for one particle with the latents and noise_pred of another. Similarly, log_twist_func is not updated after resampling.
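
If this is indeed a bug, one possible fix is to gather approx_guidance and log_twist_func with the same resample_indices as the other variables. Below is a minimal sketch, not the repository's code: resample_particles is a hypothetical helper, and it assumes both tensors use the same flattened (batch * num_particles, ...) layout as latents and that resample_indices has shape (batch, num_particles).

```python
import torch


def resample_particles(x, resample_indices, num_particles):
    """Reorder a flattened (batch * num_particles, ...) tensor per batch.

    Hypothetical helper: mirrors the indexing pattern used for latents
    and noise_pred in the snippet above.
    """
    # Group particles by batch element: (batch, num_particles, ...)
    grouped = x.view(-1, num_particles, *x.shape[1:])
    # Row index for each batch element, broadcast against resample_indices
    batch_idx = torch.arange(
        grouped.size(0), device=resample_indices.device
    ).unsqueeze(1)
    # Pick the resampled particles, then flatten back to the original layout
    return grouped[batch_idx, resample_indices].reshape(-1, *x.shape[1:])


# Toy example: one batch element with three particles
num_particles = 3
latents = torch.tensor([[0.0], [1.0], [2.0]])
approx_guidance = torch.tensor([[10.0], [11.0], [12.0]])
log_twist_func = torch.tensor([0.0, -1.0, -2.0])  # assumed per-particle scalar
resample_indices = torch.tensor([[2, 2, 0]])

latents = resample_particles(latents, resample_indices, num_particles)
# The two variables that the snippet above leaves untouched:
approx_guidance = resample_particles(approx_guidance, resample_indices, num_particles)
log_twist_func = resample_particles(log_twist_func, resample_indices, num_particles)
```

With this, row i of approx_guidance and entry i of log_twist_func again belong to the same particle as row i of latents and noise_pred.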

Is this a bug or am I missing something?
