Questions about F.grid_sample() #18

@Lazyangel

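Hello, while reading and modifying the code I ran into some questions about how F.grid_sample() is used, and I'd like to ask you about them. F.grid_sample() is used in two places in the code: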

  1. Projecting the 3D voxels into the 2D image and sampling features:
def unproject_image_to_mem(self, rgb_camB, pixB_T_camA, camB_T_camA, Z, Y, X, assert_cube=False):
        # rgb_camB is B x C x H x W
        # pixB_T_camA is B x 4 x 4

        # rgb lives in B pixel coords
        # we want everything in A memory coords

        # this puts each C-dim pixel in the rgb_camB
        # along a ray in the voxelgrid
        B, C, H, W = list(rgb_camB.shape)

        xyz_memA = basic.gridcloud3d(B, Z, Y, X, norm=False, device=pixB_T_camA.device)
        # y is height here

        xyz_camA = self.Mem2Ref(xyz_memA, Z, Y, X, assert_cube=assert_cube) #

        xyz_camB = geom.apply_4x4(camB_T_camA, xyz_camA)

        z = xyz_camB[:,:,2]

        xyz_pixB = geom.apply_4x4(pixB_T_camA, xyz_camA)
        normalizer = torch.unsqueeze(xyz_pixB[:,:,2], 2)
        EPS=1e-6
        # z = xyz_pixB[:,:,2]
        xy_pixB = xyz_pixB[:,:,:2]/torch.clamp(normalizer, min=EPS)
        # this is B x N x 2
        # this is the (floating point) pixel coordinate of each voxel
        x, y = xy_pixB[:,:,0], xy_pixB[:,:,1]
        # these are B x N

        x_valid = (x>-0.5).bool() & (x<float(W-0.5)).bool()
        y_valid = (y>-0.5).bool() & (y<float(H-0.5)).bool()
        z_valid = (z>0.0).bool()
        valid_mem = (x_valid & y_valid & z_valid).reshape(B, 1, Z, Y, X).float()


        # native pytorch version
        y_pixB, x_pixB = basic.normalize_grid2d(y, x, H, W)
        # since we want a 3d output, we need 5d tensors
        z_pixB = torch.zeros_like(x)
        xyz_pixB = torch.stack([x_pixB, y_pixB, z_pixB], dim=2)  # B x N x 3, ordered (x, y, z)

        rgb_camB = rgb_camB.unsqueeze(2)  # B x C x 1 x H x W
        xyz_pixB = torch.reshape(xyz_pixB, [B, Z, Y, X, 3])
        values = F.grid_sample(rgb_camB, xyz_pixB, align_corners=False)  # B x C x Z x Y x X

        values = torch.reshape(values, (B, C, Z, Y, X))
        # 16, 256, 256
        values = values * valid_mem

        return values
  2. Sampling features from the 3D voxel grid at the sample points along NeRF rays:
def grid_sampler(self, xyz, *grids, align_corners=True):
        '''Wrapper for the interp operation'''
        # pdb.set_trace()
        shape = xyz.shape[:-1]
        xyz = xyz.reshape(1, 1, 1, -1, 3)
        ind_norm = ((xyz - self.xyz_min) / (self.xyz_max - self.xyz_min)).flip((-1,)) * 2 - 1 # XYZ
        grid = grids[0] # BCXYZ # torch.Size([1, 1, 256, 256, 16])
        # grid_sample output is 1 x C x 1 x 1 x N; reshape it back to (*shape, C)
        ret_lst = F.grid_sample(grid, ind_norm, mode='bilinear', align_corners=align_corners)
        ret_lst = ret_lst.reshape(grid.shape[1], -1).T.reshape(*shape, grid.shape[1]).squeeze()

        return ret_lst
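Both snippets map coordinates into the [-1, 1] range that F.grid_sample expects, but the first calls it with align_corners=False and the second with align_corners=True. As a minimal sketch of what that flag changes (based on the PyTorch documentation, not on this repository's helpers), the hypothetical function `to_normalized` below encodes the same pixel column under both conventions. It is not a reimplementation of `basic.normalize_grid2d`, whose exact formula is not shown here.

```python
import torch
import torch.nn.functional as F


def to_normalized(p: float, size: int, align_corners: bool) -> float:
    """Map a cell index p (0 = center of the first cell) to grid_sample's [-1, 1] range.

    Hypothetical helper following the two conventions described in the
    torch.nn.functional.grid_sample docs:
      align_corners=True:  -1 / +1 are the centers of the first / last cells.
      align_corners=False: -1 / +1 are the outer edges of the first / last cells.
    """
    if align_corners:
        return 2.0 * p / (size - 1) - 1.0
    return (2.0 * p + 1.0) / size - 1.0


# Toy 1 x 1 x 1 x W image whose value at column j is j.
W = 5
img = torch.arange(W, dtype=torch.float32).view(1, 1, 1, W)

results = {}
for ac in (True, False):
    x = to_normalized(3, W, ac)
    # 4-D input -> grid is N x H_out x W_out x 2 holding (x, y); y = 0 hits the single row.
    grid = torch.tensor([x, 0.0]).view(1, 1, 1, 2)
    results[ac] = F.grid_sample(img, grid, mode='bilinear', align_corners=ac).item()

print(results)  # both entries are approximately 3.0: same column, two encodings
print(to_normalized(3, W, True), to_normalized(3, W, False))  # 0.5 vs roughly 0.4
```

The same column index therefore needs a different normalized value depending on the flag, so the normalization formula and the align_corners argument have to agree.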

My question is: why does the first case not need to flip the input coordinates, while the second case needs flip((-1,))? Looking forward to your answer!
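The flip comes down to how F.grid_sample reads the grid's last dimension. Per the PyTorch documentation, for a 5-D input of shape N x C x D x H x W, grid[..., 0] is the x coordinate along W, grid[..., 1] is y along H, and grid[..., 2] is z along D, i.e. the reverse of the tensor's dimension order. The sketch below checks this with small made-up shapes. It assumes, following the `# BCXYZ` comment on `grids[0]`, that the voxel grid in the second snippet is stored as B x C x X x Y x Z and that `xyz` holds points in (x, y, z) order; `norm_ac_true` and `sample_point` are hypothetical helpers.

```python
import torch
import torch.nn.functional as F


def norm_ac_true(idx: int, size: int) -> float:
    # Hypothetical helper: align_corners=True, index 0 -> -1, index size-1 -> +1.
    return 2.0 * idx / (size - 1) - 1.0


def sample_point(inp: torch.Tensor, coord: torch.Tensor, align_corners: bool) -> float:
    # Hypothetical helper: sample a 5-D input at one normalized 3-vector.
    grid = coord.view(1, 1, 1, 1, 3)  # N x D_out x H_out x W_out x 3
    return F.grid_sample(inp, grid, mode='bilinear',
                         align_corners=align_corners).item()


# Case 2: volume stored as B x C x X x Y x Z, so the tensor dims are D=X, H=Y, W=Z.
X, Y, Z = 4, 5, 6
vol = torch.arange(X * Y * Z, dtype=torch.float32).view(1, 1, X, Y, Z)
i, j, k = 1, 3, 2  # target voxel indices along x, y, z
xyz = torch.tensor([norm_ac_true(i, X), norm_ac_true(j, Y), norm_ac_true(k, Z)])

target = vol[0, 0, i, j, k].item()
flipped = sample_point(vol, xyz.flip((-1,)), align_corners=True)  # (z, y, x) -> (W, H, D)
no_flip = sample_point(vol, xyz, align_corners=True)  # x is read along W (= Z axis): wrong voxel

# Case 1: image B x C x H x W lifted to B x C x 1 x H x W; grid built as (x, y, 0).
H, W = 5, 7
img = torch.arange(H * W, dtype=torch.float32).view(1, 1, H, W)
row, col = 2, 4
x = (2.0 * col + 1.0) / W - 1.0  # align_corners=False convention
y = (2.0 * row + 1.0) / H - 1.0
img_val = sample_point(img.unsqueeze(2), torch.tensor([x, y, 0.0]), align_corners=False)

print(target, flipped, no_flip)  # flipped matches target; no_flip does not
print(img[0, 0, row, col].item(), img_val)  # these match without any flip
```

Under those assumptions, only the flipped (z, y, x) coordinate lands on the intended voxel, because the volume's D, H, W axes are X, Y, Z. In the image case, the rgb tensor becomes B x C x 1 x H x W and the grid is stacked as (x, y, 0), i.e. (column, row, depth), which already matches (W, H, D), so no flip is needed.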
