
THE INFERENCE PROCESS DOES NOT PERFORM DIFFUSION SAMPLING #85

@lixirui142

Description


After carefully checking and debugging the inference process (i.e., forward_test() for TrajectoryHead), I found that it is entirely incorrect, or at least it is not a diffusion sampling process. There is no iterative denoising in the inference process. The model just sends 20 noised anchor trajs to the diff decoder and returns its output traj predictions as the final result.

Explanation:

First, the number of inference timesteps is set to 1000 in forward_test():

```python
self.diffusion_scheduler.set_timesteps(1000, device)
```

So the scheduler only denoises by 1 timestep per scheduler.step() call, i.e. 999 -> 998 -> 997 -> ... -> 1 -> 0.

```python
roll_timesteps = (np.arange(0, step_num) * step_ratio).round()[::-1].copy().astype(np.int64)
```

Then roll_timesteps is computed as [10, 0] (step_num=2, step_ratio=10): there are 2 denoising steps, the first at noise level 10 and the second at noise level 0. The intended timestep interval is therefore 10, so that the scheduler can denoise the sample from ts=10 all the way to ts=0. To get that interval, the number of inference timesteps needs to be set to 100: self.diffusion_scheduler.set_timesteps(100, device).
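To make the effect of the step count concrete, here is a minimal sketch, assuming a diffusers DDIMScheduler with prediction_type="sample" (the repo may configure its scheduler differently; the shapes and names below are placeholders, not taken from the code):

```python
import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=1000, prediction_type="sample")

sample = torch.randn(1, 20, 8, 3)     # stand-in for the 20 noised anchor trajs
x_start = torch.zeros_like(sample)    # stand-in for the decoder's clean prediction

# Current behaviour: 1000 inference steps -> stride 1000 // 1000 = 1,
# so one step at t=10 only reaches t=9.
scheduler.set_timesteps(1000)
almost_unchanged = scheduler.step(model_output=x_start, timestep=10, sample=sample).prev_sample

# Proposed fix: 100 inference steps -> stride 1000 // 100 = 10,
# so one step at t=10 lands on t=0, matching roll_timesteps = [10, 0].
scheduler.set_timesteps(100)
almost_clean = scheduler.step(model_output=x_start, timestep=10, sample=sample).prev_sample
```

For the DDIM scheduler the per-step stride is num_train_timesteps // num_inference_steps, which is why set_timesteps(100) gives the intended interval of 10.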

Let's first look at what happens in the first iteration (k=10)

```python
img = self.diffusion_scheduler.step(
    model_output=x_start,
    timestep=k,
    sample=img
).prev_sample
```

Here, the scheduler denoises the noisy trajectories (img) based on the predicted original sample (x_start) and the current timestep (k=10). However, because the code uses 1000 inference steps, the scheduler only denoises the sample from ts=10 to ts=9, i.e. a single step out of 1000.

[Figures: the 20 candidate trajectories before and after the scheduler step]

The two figures above show the 20 trajectories before and after the scheduler step. Since the scheduler only advanced by a single 1/1000 step, the input and output trajs are almost identical: the updated trajectories sit at timestep 9 instead of the expected 0.
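This near-identity can also be checked numerically from the scheduler's own noise schedule. A small sanity-check sketch, again assuming a diffusers DDIMScheduler (it uses diffusers' default linear beta schedule, so the absolute numbers will differ from the repo's schedule, but the qualitative point holds for any smooth schedule):

```python
import torch
from diffusers import DDIMScheduler

sched = DDIMScheduler(num_train_timesteps=1000, prediction_type="sample")
a10, a9, a0 = sched.alphas_cumprod[10], sched.alphas_cumprod[9], sched.alphas_cumprod[0]

# DDIM (eta=0): prev_sample = sqrt(a_prev) * x_start + sqrt(1 - a_prev) * eps_hat,
# where eps_hat is reconstructed from the current sample. Since a_9 is almost equal
# to a_10, prev_sample stays almost equal to sample regardless of x_start.
print((1 - a10).sqrt().item(), (1 - a9).sqrt().item())  # noise weights at t=10 vs t=9: very close
print((1 - a0).sqrt().item())                           # noise weight at t=0: much smaller
```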

```python
trunc_timesteps = torch.ones((bs,), device=device, dtype=torch.long) * 8
```

There is actually another bug here: the code sets the initial noise level to 8, which does not match the timestep (k=10) passed to the model in the first iteration.
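I don't know the exact surrounding code, but a consistent version could look roughly like the sketch below; all names (anchor_trajs, noise, etc.) and shapes are placeholders, and a diffusers DDIMScheduler stands in for the repo's scheduler:

```python
import numpy as np
import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=1000, prediction_type="sample")

step_num, step_ratio = 2, 10
roll_timesteps = (np.arange(0, step_num) * step_ratio).round()[::-1].copy().astype(np.int64)  # [10, 0]

bs = 1
anchor_trajs = torch.zeros(bs, 20, 8, 3)   # placeholder clean anchor trajectories
noise = torch.randn_like(anchor_trajs)

# Derive the initial noise level from roll_timesteps[0] (=10) instead of hard-coding 8,
# so the noising timestep matches the timestep later passed to the decoder and scheduler.
trunc_timesteps = torch.full((bs,), int(roll_timesteps[0]), dtype=torch.long)
img = scheduler.add_noise(anchor_trajs, noise, trunc_timesteps)  # noised at t=10
```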

Then in the second iteration (k=0)

Since the updated noisy trajs are almost identical to the original noisy trajs (ts=9 vs ts=10), the diff decoder input in the second iteration is almost identical to that of the first iteration, i.e. it is still essentially the raw noisy anchor trajectories.

```python
mode_idx = poses_cls.argmax(dim=-1)
mode_idx = mode_idx[..., None, None, None].repeat(1, 1, self._num_poses, 3)
best_reg = torch.gather(poses_reg, 1, mode_idx).squeeze(1)
return {"trajectory": best_reg}
```

Meanwhile, as shown in the code above, the method directly returns the second (final) iteration's model output (poses_cls, poses_reg) as the final result. This brings us back to the initial conclusion: the model just sends 20 noised anchor trajs to the diff decoder and returns its output traj predictions as the final result.
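For reference, here is the same selection logic on toy tensors to make the shapes explicit (shapes are illustrative, not the repo's actual dimensions):

```python
import torch

bs, num_modes, num_poses = 1, 20, 8
poses_cls = torch.randn(bs, num_modes)                  # per-mode classification scores
poses_reg = torch.randn(bs, num_modes, num_poses, 3)    # per-mode predicted trajectories

mode_idx = poses_cls.argmax(dim=-1)                     # index of the highest-scoring mode
mode_idx = mode_idx[..., None, None, None].repeat(1, 1, num_poses, 3)
best_reg = torch.gather(poses_reg, 1, mode_idx).squeeze(1)   # (bs, num_poses, 3)
```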

Why does it still work?

The diff decoder is trained to take noisy trajectories and predict the clean trajectories. Therefore, feeding 20 noised anchor trajs to the diff decoder and returning its output traj predictions is still a valid inference strategy. Here is the visualized final result:

[Figure: visualization of the final predicted trajectories]

However, the actual inference is effectively a one-step prediction rather than iterative denoising. It does not perform a standard diffusion sampling process and is not aligned with the description in the paper, so it is still a bug and should be fixed.
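For concreteness, here is a minimal sketch of what the intended two-step sampling loop could look like; this is not a patch for the repo, it assumes a diffusers DDIMScheduler with prediction_type="sample", and decoder() and all shapes are placeholders:

```python
import numpy as np
import torch
from diffusers import DDIMScheduler

scheduler = DDIMScheduler(num_train_timesteps=1000, prediction_type="sample")
scheduler.set_timesteps(100)          # stride 10, so one step moves t=10 -> t=0

step_num, step_ratio = 2, 10
roll_timesteps = (np.arange(0, step_num) * step_ratio).round()[::-1].copy().astype(np.int64)

def decoder(noisy_trajs, t):
    # Placeholder for the diff decoder: predicts clean trajectories from noisy ones.
    return torch.zeros_like(noisy_trajs)

img = torch.randn(1, 20, 8, 3)        # 20 noised anchor trajectories (placeholder shape)
for k in roll_timesteps:              # k = 10, then k = 0
    x_start = decoder(img, int(k))
    img = scheduler.step(model_output=x_start, timestep=int(k), sample=img).prev_sample
```

With 100 inference steps, the first scheduler.step already moves the sample from t=10 to t=0, so the second iteration actually refines an (almost) clean sample instead of re-processing the raw noisy anchors.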
