Redundant warping in local-correlation computation. #45

@ChenYutongTHU

Dear authors,

Thanks so much for contributing this great work!

I have a question that came up while reading refiner.py.

Before computing the local correlation, f_B is first warped to f_A, yielding f_BA. f_BA is then passed to local_correlation.

```python
with torch.no_grad():
    f_BA = bhwc_grid_sample(
        f_B, prev_warp, mode=self.cfg.grid_sample_mode, align_corners=False
    )
im_A_coords = get_normalized_grid(B, H_A, W_A)
in_displacement = prev_warp - im_A_coords
in_displacement_bdhw = in_displacement.permute(0, 3, 1, 2)
emb_in_displacement = self.disp_emb(
    scale_factor[None, :, None, None] * in_displacement_bdhw
)
# Corr in other means take a kxk grid around the predicted coordinate in other image
f_A_bdhw = f_A.permute(0, 3, 1, 2)
f_B_bdhw = f_BA.permute(0, 3, 1, 2)
d = torch.cat((f_A_bdhw, f_B_bdhw, emb_in_displacement), dim=1)
if self.cfg.local_corr_radius is not None:
    local_corr = local_correlation(
        f_A_bdhw,
        f_B_bdhw,
        local_radius=self.cfg.local_corr_radius,
        warp=prev_warp,
        scale_factor=scale_factor,
    )
```

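However, when the local correlation is computed, each pixel of f_A still queries the features at warp[_, :, :, None] + local_window[:, None, None] rather than at the normalized grid (im_A_coords) + local_window[:, None, None], even though feature1 is now the already-warped f_BA: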

```python
def native_torch_local_corr(
    feature0,
    feature1,
    warp,
    local_window,
    B,
    K,
    c,
    r,
    h,
    w,
    device,
    padding_mode="zeros",
    sample_mode="bilinear",
    dtype=torch.float32,
):
    corr = torch.empty((B, K, h, w), device=device, dtype=dtype)
    for _ in range(B):
        with torch.no_grad():
            local_window_coords = (
                warp[_, :, :, None] + local_window[:, None, None]
            ).reshape(1, h, w * K, 2)
            window_feature = F.grid_sample(
                feature1[_ : _ + 1],
                local_window_coords,
                padding_mode=padding_mode,
                align_corners=False,
                mode=sample_mode,  #
            )
            window_feature = window_feature.reshape(c, h, w, K)
        corr[_] = (
            (feature0[_, ..., None] / (c**0.5) * window_feature)
            .sum(dim=0)
            .permute(2, 0, 1)
        )
    return corr
```

The particular line is

```python
warp[_, :, :, None] + local_window[:, None, None]
```

This seems strange to me: f_BA is already spatially aligned with f_A, so sampling it again at warp + local_window looks like it applies the warp a second time.
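To make the concern concrete, here is a minimal toy sketch. It is not the repository's code. It assumes the warp holds normalized (x, y) coordinates in [-1, 1] with align_corners=False, and `identity_grid` and `sample` are hypothetical helpers written only for this illustration. The warp is a pure one-pixel shift, so the effect of sampling the already-warped features at the warp again is easy to see.

```python
import torch
import torch.nn.functional as F


def identity_grid(b, h, w):
    # Hypothetical helper: normalized pixel-center coordinates in (x, y) order,
    # consistent with align_corners=False.
    ys = torch.linspace(-1 + 1 / h, 1 - 1 / h, h)
    xs = torch.linspace(-1 + 1 / w, 1 - 1 / w, w)
    yy, xx = torch.meshgrid(ys, xs, indexing="ij")
    return torch.stack((xx, yy), dim=-1)[None].expand(b, h, w, 2)


def sample(feat_bchw, grid_bhw2):
    # Hypothetical helper: thin wrapper around F.grid_sample.
    return F.grid_sample(
        feat_bchw, grid_bhw2, mode="bilinear",
        padding_mode="zeros", align_corners=False,
    )


torch.manual_seed(0)
b, c, h, w = 1, 4, 8, 8
f_B = torch.randn(b, c, h, w)
grid = identity_grid(b, h, w)

# Toy warp: every pixel of A maps one pixel to the right in B.
warp = grid + torch.tensor([2.0 / w, 0.0])

f_BA = sample(f_B, warp)               # f_B warped into A's frame

center_once = sample(f_B, warp)        # original f_B sampled at the warp
center_identity = sample(f_BA, grid)   # aligned f_BA sampled at the identity grid
center_twice = sample(f_BA, warp)      # aligned f_BA sampled at the warp again

print(torch.allclose(center_identity, center_once, atol=1e-5))
print(torch.allclose(center_twice, center_once, atol=1e-5))
```

In this toy setting, sampling f_BA at the identity grid reproduces the features of f_B at the warped location, whereas sampling f_BA at the warp shifts the features by two pixels instead of one. Whether the same thing happens in the actual local_correlation call depends on details not shown here, which is what the question is about.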

The logic is different in RoMaV1, where the original f_B, instead of f_BA, is used to compute the local correlation. See https://github.com/Parskatt/RoMa/blob/77f8d68803526dcddfd9b7a46bc76125bdc25f15/romatch/models/matcher.py#L132-L159

I would greatly appreciate it if you could help explain this.

Many thanks!
