From 99125d6a8a36d6deebb891efe05c67a866729def Mon Sep 17 00:00:00 2001 From: Jonas Kulhanek Date: Thu, 9 Nov 2023 10:28:23 +0100 Subject: [PATCH 1/2] Add support for multi-camera datasets --- internal/camera_utils.py | 77 +++++++++++++----- internal/datasets.py | 164 ++++++++++++++++++++++++++------------- 2 files changed, 164 insertions(+), 77 deletions(-) diff --git a/internal/camera_utils.py b/internal/camera_utils.py index 1811238..a452cd3 100644 --- a/internal/camera_utils.py +++ b/internal/camera_utils.py @@ -573,25 +573,51 @@ def pix_to_dir(x, y): # Apply inverse intrinsic matrices. camera_dirs_stacked = mat_vec_mul(pixtocams, pixel_dirs_stacked) - if distortion_params is not None: + mask = camtype > 0 + if xnp.any(mask): + is_uniform = xnp.all(mask) + if is_uniform: + ldistortion_params = distortion_params + dl = camera_dirs_stacked + else: + ldistortion_params = distortion_params[mask, :] + dl = camera_dirs_stacked[:, mask, :] + # Correct for distortion. + dist_dict = dict(zip( + ["k1", "k2", "k3", "k4", "p1", "p2"], + xnp.moveaxis(ldistortion_params, -1, 0))) x, y = _radial_and_tangential_undistort( - camera_dirs_stacked[..., 0], - camera_dirs_stacked[..., 1], - **distortion_params, - xnp=xnp) - camera_dirs_stacked = xnp.stack([x, y, xnp.ones_like(x)], -1) - - if camtype == ProjectionType.FISHEYE: - theta = xnp.sqrt(xnp.sum(xnp.square(camera_dirs_stacked[..., :2]), axis=-1)) - theta = xnp.minimum(xnp.pi, theta) - - sin_theta_over_theta = xnp.sin(theta) / theta - camera_dirs_stacked = xnp.stack([ - camera_dirs_stacked[..., 0] * sin_theta_over_theta, - camera_dirs_stacked[..., 1] * sin_theta_over_theta, - xnp.cos(theta), - ], axis=-1) + dl[..., 0], + dl[..., 1], + **dist_dict, + xnp=xnp) + dl = xnp.stack([x, y, xnp.ones_like(x)], -1) + dcamera_types = camtype[mask] + + fisheye_mask = dcamera_types == 2 + if fisheye_mask.any(): + is_all_fisheye = xnp.all(fisheye_mask) + if is_all_fisheye: + dll = dl + else: + dll = dl[:, mask, :2] + theta = 
xnp.sqrt(xnp.sum(xnp.square(dll[..., :2]), axis=-1)) + theta = xnp.minimum(xnp.pi, theta) + sin_theta_over_theta = xnp.sin(theta) / theta + + if is_all_fisheye: + dl[..., :2] *= sin_theta_over_theta + dl[..., 2:] *= xnp.cos(theta) + else: + dl[:, mask, :2] *= sin_theta_over_theta + dl[:, mask, 2:] *= xnp.cos(theta) + + if mask.any(): + if is_uniform: + camera_dirs_stacked = dl + else: + camera_dirs_stacked[:, mask, :] = dl # Flip from OpenCV to OpenGL coordinate system. camera_dirs_stacked = matmul(camera_dirs_stacked, @@ -655,21 +681,30 @@ def cast_ray_batch( Returns: rays: Rays dataclass with computed 3D world space ray data. """ - pixtocams, camtoworlds, distortion_params, pixtocam_ndc = cameras + del camtype + pixtocams, camtoworlds, distortion_params, pixtocam_ndc, camtype = cameras # pixels.cam_idx has shape [..., 1], remove this hanging dimension. cam_idx = pixels.cam_idx[..., 0] batch_index = lambda arr: arr if arr.ndim == 2 else arr[cam_idx] + bs = pixels.pix_x_int.shape + dtype = pixtocams.dtype + origins = xnp.zeros((*bs, 3), dtype=dtype) + directions = xnp.zeros((*bs, 3), dtype=dtype) + viewdirs = xnp.zeros((*bs, 3), dtype=dtype) + radii = xnp.zeros((*bs, 1), dtype=dtype) + imageplane = xnp.zeros((*bs, 1), dtype=dtype) + # Compute rays from pixel coordinates. origins, directions, viewdirs, radii, imageplane = pixels_to_rays( pixels.pix_x_int, pixels.pix_y_int, batch_index(pixtocams), batch_index(camtoworlds), - distortion_params=distortion_params, - pixtocam_ndc=pixtocam_ndc, - camtype=camtype, + distortion_params=distortion_params[cam_idx], + pixtocam_ndc=pixtocam_ndc[cam_idx] if pixtocam_ndc is not None else None, + camtype=camtype[cam_idx], xnp=xnp) # Create Rays data structure. diff --git a/internal/datasets.py b/internal/datasets.py index b1622a3..94aef1c 100644 --- a/internal/datasets.py +++ b/internal/datasets.py @@ -79,15 +79,58 @@ def process( # self.load_points3D() # For now, we do not need the point cloud data. 
# Assume shared intrinsics between all cameras. - cam = self.cameras[1] + cams = {} + for cam_id, cam in self.cameras.items(): + # Extract focal lengths and principal point parameters. + fx, fy, cx, cy = cam.fx, cam.fy, cam.cx, cam.cy + pixtocam = np.linalg.inv(camera_utils.intrinsic_matrix(fx, fy, cx, cy)) + + # Get distortion parameters. + type_ = cam.camera_type + + if type_ == 0 or type_ == 'SIMPLE_PINHOLE': + params = None + camtype = camera_utils.ProjectionType.PERSPECTIVE + + elif type_ == 1 or type_ == 'PINHOLE': + params = None + camtype = camera_utils.ProjectionType.PERSPECTIVE + + if type_ == 2 or type_ == 'SIMPLE_RADIAL': + params = {k: 0. for k in ['k1', 'k2', 'k3', 'p1', 'p2']} + params['k1'] = cam.k1 + camtype = camera_utils.ProjectionType.PERSPECTIVE + + elif type_ == 3 or type_ == 'RADIAL': + params = {k: 0. for k in ['k1', 'k2', 'k3', 'p1', 'p2']} + params['k1'] = cam.k1 + params['k2'] = cam.k2 + camtype = camera_utils.ProjectionType.PERSPECTIVE + + elif type_ == 4 or type_ == 'OPENCV': + params = {k: 0. for k in ['k1', 'k2', 'k3', 'p1', 'p2']} + params['k1'] = cam.k1 + params['k2'] = cam.k2 + params['p1'] = cam.p1 + params['p2'] = cam.p2 + camtype = camera_utils.ProjectionType.PERSPECTIVE + + elif type_ == 5 or type_ == 'OPENCV_FISHEYE': + params = {k: 0. for k in ['k1', 'k2', 'k3', 'k4']} + params['k1'] = cam.k1 + params['k2'] = cam.k2 + params['k3'] = cam.k3 + params['k4'] = cam.k4 + camtype = camera_utils.ProjectionType.FISHEYE + cams[cam_id] = (cam, pixtocam, params, camtype) - # Extract focal lengths and principal point parameters. - fx, fy, cx, cy = cam.fx, cam.fy, cam.cx, cam.cy - pixtocam = np.linalg.inv(camera_utils.intrinsic_matrix(fx, fy, cx, cy)) # Extract extrinsic matrices in world-to-camera format. 
imdata = self.images w2c_mats = [] + pixtocams = [] + all_params = [] + all_camtypes = [] bottom = np.array([0, 0, 0, 1]).reshape(1, 4) for k in imdata: im = imdata[k] @@ -95,6 +138,10 @@ def process( trans = im.tvec.reshape(3, 1) w2c = np.concatenate([np.concatenate([rot, trans], 1), bottom], axis=0) w2c_mats.append(w2c) + cam, pixtocam, params, camtype = cams[im.camera_id] + all_params.append(params) + all_camtypes.append(camtype) + pixtocams.append(pixtocam) w2c_mats = np.stack(w2c_mats, axis=0) # Convert extrinsics to camera-to-world. @@ -108,45 +155,9 @@ def process( # Switch from COLMAP (right, down, fwd) to NeRF (right, up, back) frame. poses = poses @ np.diag([1, -1, -1, 1]) - # Get distortion parameters. - type_ = cam.camera_type - - if type_ == 0 or type_ == 'SIMPLE_PINHOLE': - params = None - camtype = camera_utils.ProjectionType.PERSPECTIVE - - elif type_ == 1 or type_ == 'PINHOLE': - params = None - camtype = camera_utils.ProjectionType.PERSPECTIVE - - if type_ == 2 or type_ == 'SIMPLE_RADIAL': - params = {k: 0. for k in ['k1', 'k2', 'k3', 'p1', 'p2']} - params['k1'] = cam.k1 - camtype = camera_utils.ProjectionType.PERSPECTIVE - - elif type_ == 3 or type_ == 'RADIAL': - params = {k: 0. for k in ['k1', 'k2', 'k3', 'p1', 'p2']} - params['k1'] = cam.k1 - params['k2'] = cam.k2 - camtype = camera_utils.ProjectionType.PERSPECTIVE - - elif type_ == 4 or type_ == 'OPENCV': - params = {k: 0. for k in ['k1', 'k2', 'k3', 'p1', 'p2']} - params['k1'] = cam.k1 - params['k2'] = cam.k2 - params['p1'] = cam.p1 - params['p2'] = cam.p2 - camtype = camera_utils.ProjectionType.PERSPECTIVE - - elif type_ == 5 or type_ == 'OPENCV_FISHEYE': - params = {k: 0. 
for k in ['k1', 'k2', 'k3', 'k4']} - params['k1'] = cam.k1 - params['k2'] = cam.k2 - params['k3'] = cam.k3 - params['k4'] = cam.k4 - camtype = camera_utils.ProjectionType.FISHEYE + pixtocams = np.stack(pixtocams) - return names, poses, pixtocam, params, camtype + return names, poses, pixtocams, all_params, all_camtypes def load_blender_posedata(data_dir, split=None): @@ -288,8 +299,9 @@ def __init__(self, self.images: np.ndarray = None self.camtoworlds: np.ndarray = None self.pixtocams: np.ndarray = None - self.height: int = None - self.width: int = None + self.height: np.ndarray = None + self.width: np.ndarray = None + # Load data from disk using provided config parameters. self._load_renderings(config) @@ -314,11 +326,41 @@ def __init__(self, self.height) self._n_examples = self.camtoworlds.shape[0] + if len(self.pixtocams.shape) == 2: + self.pixtocams = np.repeat(self.pixtocams[None], self._n_examples, 0) + if not isinstance(self.focal, np.ndarray): + self.focal = np.full((self._n_examples,), self.focal, dtype=np.int32) + if not isinstance(self.width, np.ndarray): + self.width = np.full((self._n_examples,), self.width, dtype=np.int32) + if not isinstance(self.height, np.ndarray): + self.height = np.full((self._n_examples,), self.height, dtype=np.int32) + if not isinstance(self.camtype, list): + self.camtype = [self.camtype] * self._n_examples + map_camera = { + camera_utils.ProjectionType.PERSPECTIVE.value: 1, + camera_utils.ProjectionType.FISHEYE.value: 2 + } + self.camtype = np.array([map_camera[x.value] for x in self.camtype], dtype=np.int32) + distortion_params = np.zeros((self._n_examples, 6), dtype=np.float32) + for i in range(self._n_examples): + k = self.distortion_params + if isinstance(self.distortion_params, list): + try: + k = self.distortion_params[i] + except Exception as e: + breakpoint() + print(e) + if k is None: + self.camtype[i] = 0 + continue + distortion_params[i] = np.array([k[m] for m in ["k1", "k2", "k3", "k4", "p1", "p2"]], 
dtype=np.float32) + self.distortion_params = distortion_params self.cameras = (self.pixtocams, self.camtoworlds, self.distortion_params, - self.pixtocam_ndc) + self.pixtocam_ndc, + self.camtype) # Seed the queue with one batch to avoid race condition. if self.split == utils.DataSplit.TRAIN: @@ -456,11 +498,18 @@ def _next_train(self) -> utils.Batch: num_patches = self._batch_size // self._patch_size ** 2 lower_border = self._num_border_pixels_to_mask upper_border = self._num_border_pixels_to_mask + self._patch_size - 1 + + # Random camera indices. + if self._batching == utils.BatchingMethod.ALL_IMAGES: + cam_idx = np.random.randint(0, self._n_examples, (num_patches, 1, 1)) + else: + cam_idx = np.random.randint(0, self._n_examples, (1,)) + # Random pixel patch x-coordinates. - pix_x_int = np.random.randint(lower_border, self.width - upper_border, + pix_x_int = np.random.randint(lower_border, self.width[cam_idx] - upper_border, (num_patches, 1, 1)) # Random pixel patch y-coordinates. - pix_y_int = np.random.randint(lower_border, self.height - upper_border, + pix_y_int = np.random.randint(lower_border, self.height[cam_idx] - upper_border, (num_patches, 1, 1)) # Add patch coordinate offsets. # Shape will broadcast to (num_patches, _patch_size, _patch_size). @@ -468,11 +517,6 @@ def _next_train(self) -> utils.Batch: self._patch_size, self._patch_size) pix_x_int = pix_x_int + patch_dx_int pix_y_int = pix_y_int + patch_dy_int - # Random camera indices. - if self._batching == utils.BatchingMethod.ALL_IMAGES: - cam_idx = np.random.randint(0, self._n_examples, (num_patches, 1, 1)) - else: - cam_idx = np.random.randint(0, self._n_examples, (1,)) if self._apply_bayer_mask: # Compute the Bayer mosaic mask for each pixel in the batch. 
@@ -488,12 +532,12 @@ def generate_ray_batch(self, cam_idx: int) -> utils.Batch: if self._render_spherical: camtoworld = self.camtoworlds[cam_idx] rays = camera_utils.cast_spherical_rays( - camtoworld, self.height, self.width, self.near, self.far, xnp=np) + camtoworld, self.height[cam_idx], self.width[cam_idx], self.near, self.far, xnp=np) return utils.Batch(rays=rays) else: # Generate rays for all pixels in the image. pix_x_int, pix_y_int = camera_utils.pixel_coordinates( - self.width, self.height) + self.width[cam_idx], self.height[cam_idx]) return self._make_ray_batch(pix_x_int, pix_y_int, cam_idx) def _next_test(self) -> utils.Batch: @@ -593,13 +637,15 @@ def _load_renderings(self, config): inds = np.argsort(image_names) image_names = [image_names[i] for i in inds] poses = poses[inds] + pixtocam = pixtocam[inds] + distortion_params = [distortion_params[i] for i in inds] + camtype = [camtype[i] for i in inds] # Scale the inverse intrinsics matrix by the image downsampling factor. pixtocam = pixtocam @ np.diag([factor, factor, 1.]) self.pixtocams = pixtocam.astype(np.float32) - self.focal = 1. / self.pixtocams[0, 0] + self.focal = 1. / self.pixtocams[..., 0, 0] self.distortion_params = distortion_params - self.camtype = camtype raw_testscene = False if config.rawnerf_mode: @@ -706,6 +752,12 @@ def _load_renderings(self, config): # All per-image quantities must be re-indexed using the split indices. 
images = images[indices] poses = poses[indices] + self.pixtocams = self.pixtocams[indices] + self.focal = self.focal[indices] + list_indices = np.where(indices)[0] if indices.dtype == np.bool_ else indices + self.distortion_params = [self.distortion_params[i] for i in list_indices] + self.camtype = [camtype[i] for i in list_indices] + assert len(self.camtype) == len(self.pixtocams) if self.exposures is not None: self.exposures = self.exposures[indices] if config.rawnerf_mode: From 0e6699cc01eb3f0e77e0f7c15057a3ee29ad74ba Mon Sep 17 00:00:00 2001 From: Jonas Kulhanek Date: Thu, 9 Nov 2023 11:18:01 +0100 Subject: [PATCH 2/2] Fix focal length dtype + add support for different znear,zfar --- internal/camera_utils.py | 2 +- internal/datasets.py | 17 +++++++++++++---- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/internal/camera_utils.py b/internal/camera_utils.py index a452cd3..89618ad 100644 --- a/internal/camera_utils.py +++ b/internal/camera_utils.py @@ -602,7 +602,7 @@ def pix_to_dir(x, y): dll = dl else: dll = dl[:, mask, :2] - theta = xnp.sqrt(xnp.sum(xnp.square(dll[..., :2]), axis=-1)) + theta = xnp.sqrt(xnp.sum(xnp.square(dll[..., :2]), axis=-1, keepdims=True)) theta = xnp.minimum(xnp.pi, theta) sin_theta_over_theta = xnp.sin(theta) / theta diff --git a/internal/datasets.py b/internal/datasets.py index 94aef1c..98f297e 100644 --- a/internal/datasets.py +++ b/internal/datasets.py @@ -329,7 +329,11 @@ def __init__(self, if len(self.pixtocams.shape) == 2: self.pixtocams = np.repeat(self.pixtocams[None], self._n_examples, 0) if not isinstance(self.focal, np.ndarray): - self.focal = np.full((self._n_examples,), self.focal, dtype=np.int32) + self.focal = np.full((self._n_examples,), self.focal, dtype=np.float32) + if not isinstance(self.near, np.ndarray): + self.near = np.full((self._n_examples,), self.near, dtype=np.float32) + if not isinstance(self.height, np.ndarray): + self.far = np.full((self._n_examples,), self.far, dtype=np.float32) if not 
isinstance(self.width, np.ndarray): self.width = np.full((self._n_examples,), self.width, dtype=np.int32) if not isinstance(self.height, np.ndarray): @@ -450,10 +454,11 @@ def _make_ray_batch(self, """ broadcast_scalar = lambda x: np.broadcast_to(x, pix_x_int.shape)[..., None] + idx = 0 if self.render_path else cam_idx ray_kwargs = { 'lossmult': broadcast_scalar(1.) if lossmult is None else lossmult, - 'near': broadcast_scalar(self.near), - 'far': broadcast_scalar(self.far), + 'near': broadcast_scalar(self.near[idx]), + 'far': broadcast_scalar(self.far[idx]), 'cam_idx': broadcast_scalar(cam_idx), } # Collect per-camera information needed for each ray. @@ -532,7 +537,11 @@ def generate_ray_batch(self, cam_idx: int) -> utils.Batch: if self._render_spherical: camtoworld = self.camtoworlds[cam_idx] rays = camera_utils.cast_spherical_rays( - camtoworld, self.height[cam_idx], self.width[cam_idx], self.near, self.far, xnp=np) + camtoworld, + self.height[cam_idx], self.width[cam_idx], + self.near[cam_idx], + self.far[cam_idx], + xnp=np) return utils.Batch(rays=rays) else: # Generate rays for all pixels in the image.