-
Notifications
You must be signed in to change notification settings - Fork 13
Open
Description
作者您好,我在阅读和修改代码过程中对F.grid_sample()的使用方式有一些问题想向您请教一下,代码中有两处使用到了F.grid_sample():
- 3D voxel投影到2D图像并采样特征:
def unproject_image_to_mem(self, rgb_camB, pixB_T_camA, camB_T_camA, Z, Y, X, assert_cube=False):
    """Fill a Z x Y x X voxel grid with image features sampled from camera B.

    Every voxel of the grid (expressed in A's memory coordinates) is mapped
    into camera A, projected into B's pixel plane, and assigned the feature
    found at that pixel, so each pixel's C-dim feature is smeared along its
    viewing ray through the volume.

    Args:
        rgb_camB: feature map, B x C x H x W, living in B pixel coords.
        pixB_T_camA: B x 4 x 4 transform from camera A coords to B pixels.
        camB_T_camA: transform from camera A coords to camera B coords
            (applied with ``geom.apply_4x4``, same as ``pixB_T_camA``).
        Z, Y, X: voxel grid resolution (Y is height here).
        assert_cube: forwarded unchanged to ``self.Mem2Ref``.

    Returns:
        Tensor of shape B x C x Z x Y x X; voxels that project outside the
        image or lie at non-positive depth in camera B are zeroed.
    """
    B, C, H, W = rgb_camB.shape

    # Voxel positions: memory coords -> camera A coords.
    voxels_mem = basic.gridcloud3d(B, Z, Y, X, norm=False, device=pixB_T_camA.device)
    voxels_camA = self.Mem2Ref(voxels_mem, Z, Y, X, assert_cube=assert_cube)

    # Depth of every voxel as seen from camera B (B x N).
    depth_camB = geom.apply_4x4(camB_T_camA, voxels_camA)[:, :, 2]

    # Perspective divide; the clamp keeps the denominator away from zero.
    eps = 1e-6
    projected = geom.apply_4x4(pixB_T_camA, voxels_camA)
    denom = torch.clamp(projected[:, :, 2:3], min=eps)
    pix = projected[:, :, :2] / denom  # B x N x 2, floating-point pixel coords
    x = pix[:, :, 0]
    y = pix[:, :, 1]

    # A voxel is usable only if it lands inside the image and has z > 0 in B.
    inside_w = (x > -0.5) & (x < float(W - 0.5))
    inside_h = (y > -0.5) & (y < float(H - 0.5))
    in_front = depth_camB > 0.0
    valid_mem = (inside_w & inside_h & in_front).reshape(B, 1, Z, Y, X).float()

    # presumably rescales pixel coords to grid_sample's [-1, 1] range --
    # confirm against basic.normalize_grid2d.
    y_norm, x_norm = basic.normalize_grid2d(y, x, H, W)

    # grid_sample needs 5-D tensors to emit a 3-D volume, so the image gets a
    # dummy depth axis of size 1 and every sample uses depth coordinate 0.
    # The last axis is already ordered (x, y, z), matching the image's
    # (W, H, depth) axes.
    depth_zero = torch.zeros_like(x)
    sample_grid = torch.stack([x_norm, y_norm, depth_zero], dim=2).reshape(B, Z, Y, X, 3)
    image_5d = rgb_camB.unsqueeze(2)  # B x C x 1 x H x W

    sampled = F.grid_sample(image_5d, sample_grid, align_corners=False)
    return sampled.reshape(B, C, Z, Y, X) * valid_mem
- NeRF的射线中的采样点从3D voxel中采样特征:
def grid_sampler(self, xyz, *grids, align_corners=True):
    """Interpolate a dense voxel grid at arbitrary 3D query points.

    Args:
        xyz: tensor of shape (..., 3) with query points, expressed in the
            same frame as ``self.xyz_min`` / ``self.xyz_max``.
        *grids: voxel volumes; only ``grids[0]`` is sampled and any further
            grids are ignored. It must be 5-D with batch size 1, i.e.
            (1, C, D1, D2, D3), because all points are packed into a single
            batch element.
        align_corners: forwarded to ``F.grid_sample``.

    Returns:
        Tensor of shape (..., C) with one feature vector per query point;
        when C == 1 the trailing channel axis is dropped, giving (...).
        Only the channel axis is ever squeezed, so singleton dimensions of
        ``xyz`` (e.g. a single query point) are preserved.

    Raises:
        IndexError: if no grid is passed.
    """
    shape = xyz.shape[:-1]
    # Pack every point into one (1, 1, 1, N, 3) sampling grid.
    xyz = xyz.reshape(1, 1, 1, -1, 3)
    # Rescale to [-1, 1]. grid_sample reads the last axis as (x, y, z), which
    # index the input's last, middle and first spatial dims respectively, so
    # the components are reversed to make xyz[..., 0] address the grid's
    # first spatial axis.
    ind_norm = ((xyz - self.xyz_min) / (self.xyz_max - self.xyz_min)).flip((-1,)) * 2 - 1
    grid = grids[0]
    channels = grid.shape[1]
    sampled = F.grid_sample(grid, ind_norm, mode='bilinear', align_corners=align_corners)
    # (1, C, 1, 1, N) -> (C, N) -> (N, C) -> (..., C)
    ret = sampled.reshape(channels, -1).T.reshape(*shape, channels)
    if channels == 1:
        # Drop only the channel axis; a bare squeeze() would also collapse
        # singleton point dimensions and change the output rank.
        ret = ret.squeeze(-1)
    return ret
我的问题是:为什么在第一个场景中不需要对输入坐标进行 flip 操作,而在第二个场景中需要对输入坐标进行 flip((-1,)) 操作?期待您的解答!
Metadata
Metadata
Assignees
Labels
No labels