diff --git a/examples/datasets/nerf_synthetic.py b/examples/datasets/nerf_synthetic.py index 11542924..0115efec 100644 --- a/examples/datasets/nerf_synthetic.py +++ b/examples/datasets/nerf_synthetic.py @@ -5,6 +5,7 @@ import collections import json import os +import math import imageio.v2 as imageio import numpy as np @@ -13,16 +14,15 @@ from .utils import Rays +radii_factor = 2 / math.sqrt(12) + def _load_renderings(root_fp: str, subject_id: str, split: str): """Load images from disk.""" if not root_fp.startswith("/"): # allow relative path. e.g., "./data/nerf_synthetic/" root_fp = os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "..", - "..", - root_fp, + os.path.dirname(os.path.abspath(__file__)), "..", "..", root_fp, ) data_dir = os.path.join(root_fp, subject_id) @@ -79,6 +79,7 @@ def __init__( near: float = None, far: float = None, batch_over_images: bool = True, + get_radii: bool = False, ): super().__init__() assert split in self.SPLITS, "%s" % split @@ -93,6 +94,7 @@ def __init__( ) self.color_bkgd_aug = color_bkgd_aug self.batch_over_images = batch_over_images + self.get_radii = get_radii if split == "trainval": _images_train, _camtoworlds_train, _focal_train = _load_renderings( root_fp, subject_id, "train" @@ -211,18 +213,72 @@ def fetch_data(self, index): directions, dim=-1, keepdims=True ) + if self.get_radii: + camera_dirs_cornor = F.pad( + torch.stack( + [ + (x - self.K[0, 2]) / self.K[0, 0], + (y - self.K[1, 2]) + / self.K[1, 1] + * (-1.0 if self.OPENGL_CAMERA else 1.0), + ], + dim=-1, + ), + (0, 1), + value=(-1.0 if self.OPENGL_CAMERA else 1.0), + ) # [num_rays, 3] + directions_cornor = ( + camera_dirs_cornor[:, None, :] * c2w[:, :3, :3] + ).sum(dim=-1) + dx = torch.sqrt( + torch.sum((directions_cornor - directions) ** 2, -1) + ) + radii = dx[:, None] * radii_factor + else: + radii_value = ( + math.sqrt((0.5 / self.K[0, 0]) ** 2 + (0.5 / self.K[1, 1]) ** 2) + * radii_factor + ) + radii = ( + torch.ones(origins.shape[0], 1, device=self.images.device) + * radii_value + ) + if self.training: origins = torch.reshape(origins, (num_rays, 3)) viewdirs = torch.reshape(viewdirs, (num_rays, 3)) rgba = torch.reshape(rgba, (num_rays, 4)) + radii = torch.reshape(radii, (num_rays, 1)) else: origins = torch.reshape(origins, (self.HEIGHT, self.WIDTH, 3)) viewdirs = torch.reshape(viewdirs, (self.HEIGHT, self.WIDTH, 3)) rgba = torch.reshape(rgba, (self.HEIGHT, self.WIDTH, 4)) + radii = torch.reshape(radii, (self.HEIGHT, self.WIDTH, 1)) - rays = Rays(origins=origins, viewdirs=viewdirs) + rays = Rays(origins=origins, viewdirs=viewdirs, radii=radii) return { "rgba": rgba, # [h, w, 4] or [num_rays, 4] "rays": rays, # [h, w, 3] or [num_rays, 3] } + + def fetch_data_for_x(self, camera_ids, x): + """Fetch the data for a loc and camera (it maybe cached for multiple batches).""" + c2w = self.camtoworlds[camera_ids] # (num_rays, 3, 4) + + origins = torch.broadcast_to(c2w[:, :3, -1], x.shape) + directions = x - origins + viewdirs = directions / torch.linalg.norm( + directions, dim=-1, keepdims=True + ) + + # get fix value to simpliy the calculation + radii_value = ( + math.sqrt((0.5 / self.K[0, 0]) ** 2 + (0.5 / self.K[1, 1]) ** 2) + * radii_factor + ) + radii = ( + torch.ones(origins.shape[0], 1, device=self.images.device) + * radii_value + ) + return Rays(origins=origins, viewdirs=viewdirs, radii=radii) diff --git a/examples/datasets/utils.py b/examples/datasets/utils.py index 04d223f5..7803e6b6 100644 --- a/examples/datasets/utils.py +++ b/examples/datasets/utils.py @@ -4,7 +4,7 @@ import collections -Rays = collections.namedtuple("Rays", ("origins", "viewdirs")) +Rays = collections.namedtuple("Rays", ("origins", "viewdirs", "radii")) def namedtuple_map(fn, tup): diff --git a/examples/mip_utils.py b/examples/mip_utils.py new file mode 100644 index 00000000..37622792 --- /dev/null +++ b/examples/mip_utils.py @@ -0,0 +1,161 @@ +""" +Copyright (c) 2022 Ruilong Li, UC Berkeley. +""" + +import random +from typing import Optional + +import numpy as np +import torch +from datasets.utils import Rays, namedtuple_map + +from nerfacc import OccupancyGrid, ray_marching, rendering + + +def set_random_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + +# gaussion computation from Nerf-Factory +def lift_gaussian(d, t_mean, t_var, r_var): + mean = d * t_mean + + d_mag_sq = torch.sum(d**2, dim=-1, keepdim=True) + thresholds = torch.ones_like(d_mag_sq) * 1e-10 + d_mag_sq = torch.fmax(d_mag_sq, thresholds) + + d_outer_diag = d**2 + null_outer_diag = 1 - d_outer_diag / d_mag_sq + t_cov_diag = t_var * d_outer_diag + xy_cov_diag = r_var * null_outer_diag + cov_diag = t_cov_diag + xy_cov_diag + + return mean, cov_diag + + +def conical_frustum_to_gaussian(d, t0, t1, radius): + + mu = (t0 + t1) / 2 + hw = (t1 - t0) / 2 + t_mean = mu + (2 * mu * hw**2) / (3 * mu**2 + hw**2) + t_var = (hw**2) / 3 - (4 / 15) * ( + (hw**4 * (12 * mu**2 - hw**2)) / (3 * mu**2 + hw**2) ** 2 + ) + r_var = radius**2 * ( + (mu**2) / 4 + + (5 / 12) * hw**2 + - 4 / 15 * (hw**4) / (3 * mu**2 + hw**2) + ) + + return lift_gaussian(d, t_mean, t_var, r_var) + + +def cylinder_to_gaussian(d, t0, t1, radius): + + t_mean = (t0 + t1) / 2 + r_var = radius**2 / 4 + t_var = (t1 - t0) ** 2 / 12 + + return lift_gaussian(d, t_mean, t_var, r_var) + + +def cast_rays(t_starts, t_ends, origins, directions, radii, ray_shape): + if ray_shape == "cone": + gaussian_fn = conical_frustum_to_gaussian + elif ray_shape == "cylinder": + gaussian_fn = cylinder_to_gaussian + else: + assert False + means, covs = gaussian_fn(directions, t_starts, t_ends, radii) + means = means + origins + return means, covs + + +def render_image( + # scene + radiance_field: torch.nn.Module, + occupancy_grid: OccupancyGrid, + rays: Rays, + scene_aabb: torch.Tensor, + # rendering options + near_plane: Optional[float] = None, + far_plane: Optional[float] = None, + render_step_size: float = 1e-3, + render_bkgd: Optional[torch.Tensor] = None, + cone_angle: float = 0.0, + alpha_thre: float = 0.0, + # test options + test_chunk_size: int = 8192, + # only useful for dnerf + ray_shape: str = "cylinder", +): + """Render the pixels of an image.""" + rays_shape = rays.origins.shape + if len(rays_shape) == 3: + height, width, _ = rays_shape + num_rays = height * width + rays = namedtuple_map( + lambda r: r.reshape([num_rays] + list(r.shape[2:])), rays + ) + else: + num_rays, _ = rays_shape + + def sigma_fn(t_starts, t_ends, ray_indices): + ray_indices = ray_indices.long() + t_origins = chunk_rays.origins[ray_indices] # (n_samples, 3) + t_dirs = chunk_rays.viewdirs[ray_indices] # (n_samples, 3) + t_radii = chunk_rays.radii[ray_indices] # (n_samples,) + mean, cov = cast_rays(t_starts, t_ends, t_origins, t_dirs, t_radii, ray_shape) + return radiance_field.query_density(mean, cov) + + def rgb_sigma_fn(t_starts, t_ends, ray_indices): + ray_indices = ray_indices.long() + t_origins = chunk_rays.origins[ray_indices] # (n_samples, 3) + t_dirs = chunk_rays.viewdirs[ray_indices] # (n_samples, 3) + t_radii = chunk_rays.radii[ray_indices] # (n_samples,) + mean, cov = cast_rays(t_starts, t_ends, t_origins, t_dirs, t_radii, ray_shape) + return radiance_field(mean, cov, t_dirs) + + results = [] + chunk = ( + torch.iinfo(torch.int32).max + if radiance_field.training + else test_chunk_size + ) + for i in range(0, num_rays, chunk): + chunk_rays = namedtuple_map(lambda r: r[i : i + chunk], rays) + ray_indices, t_starts, t_ends = ray_marching( + chunk_rays.origins, + chunk_rays.viewdirs, + scene_aabb=scene_aabb, + grid=occupancy_grid, + sigma_fn=sigma_fn, + near_plane=near_plane, + far_plane=far_plane, + render_step_size=render_step_size, + stratified=radiance_field.training, + cone_angle=cone_angle, + alpha_thre=alpha_thre, + ) + rgb, opacity, depth = rendering( + t_starts, + t_ends, + ray_indices, + n_rays=chunk_rays.origins.shape[0], + rgb_sigma_fn=rgb_sigma_fn, + render_bkgd=render_bkgd, + ) + chunk_results = [rgb, opacity, depth, len(t_starts)] + results.append(chunk_results) + colors, opacities, depths, n_rendering_samples = [ + torch.cat(r, dim=0) if isinstance(r[0], torch.Tensor) else r + for r in zip(*results) + ] + return ( + colors.view((*rays_shape[:-1], -1)), + opacities.view((*rays_shape[:-1], -1)), + depths.view((*rays_shape[:-1], -1)), + sum(n_rendering_samples), + ) diff --git a/examples/radiance_fields/mlp.py b/examples/radiance_fields/mlp.py index ff8c7c42..d36dac93 100644 --- a/examples/radiance_fields/mlp.py +++ b/examples/radiance_fields/mlp.py @@ -203,6 +203,53 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return latent +class IntegrateSinusoidalEncoder(nn.Module): + """Integrate Sinusoidal Positional Encoder used in Nerf.""" + + def __init__(self, x_dim, min_deg, max_deg, use_identity: bool = True): + super().__init__() + self.x_dim = x_dim + self.min_deg = min_deg + self.max_deg = max_deg + self.use_identity = use_identity + self.register_buffer( + "scales", torch.tensor([2**i for i in range(min_deg, max_deg)]) + ) + + @property + def latent_dim(self) -> int: + return ( + int(self.use_identity) + (self.max_deg - self.min_deg) * 2 + ) * self.x_dim + + def forward(self, x: torch.Tensor, x_cov: torch.Tensor) -> torch.Tensor: + """ + Args: + x: [..., x_dim] + x_cov: [..., x_dim] + Returns: + latent: [..., latent_dim] + """ + if self.max_deg == self.min_deg: + return x + shape = list(x.shape[:-1]) + [ + (self.max_deg - self.min_deg) * self.x_dim + ] + xb = torch.reshape( + (x[Ellipsis, None, :] * self.scales[:, None]), + shape, + ) + xvar = torch.reshape( + x_cov[..., None, :] * self.scales[:, None] ** 2, shape + ) + latent = torch.exp(-0.5 * torch.cat([xvar] * 2, dim=-1)) * torch.sin( + torch.cat([xb, xb + 0.5 * math.pi], dim=-1) + ) + if self.use_identity: + latent = torch.cat([x] + [latent], dim=-1) + return latent + + class VanillaNeRFRadianceField(nn.Module): def __init__( self, @@ -245,6 +292,41 @@ def forward(self, x, condition=None): return torch.sigmoid(rgb), F.relu(sigma) +class MipNeRFRadianceField(nn.Module): + def __init__( + self, + net_depth: int = 8, # The depth of the MLP. + net_width: int = 256, # The width of the MLP. + skip_layer: int = 4, # The layer to add skip layers to. + net_depth_condition: int = 1, # The depth of the second part of MLP. + net_width_condition: int = 128, # The width of the second part of MLP. + ) -> None: + super().__init__() + self.posi_encoder = IntegrateSinusoidalEncoder(3, 0, 10, True) + self.view_encoder = SinusoidalEncoder(3, 0, 4, True) + self.mlp = NerfMLP( + input_dim=self.posi_encoder.latent_dim, + condition_dim=self.view_encoder.latent_dim, + net_depth=net_depth, + net_width=net_width, + skip_layer=skip_layer, + net_depth_condition=net_depth_condition, + net_width_condition=net_width_condition, + ) + + def query_density(self, x, x_conv): + x = self.posi_encoder(x, x_conv) + sigma = self.mlp.query_density(x) + return F.relu(sigma) + + def forward(self, x, x_conv, condition=None): + x = self.posi_encoder(x, x_conv) + if condition is not None: + condition = self.view_encoder(condition) + rgb, sigma = self.mlp(x, condition=condition) + return torch.sigmoid(rgb), F.relu(sigma) + + class DNeRFRadianceField(nn.Module): def __init__(self) -> None: super().__init__() diff --git a/examples/train_mip_nerf.py b/examples/train_mip_nerf.py new file mode 100644 index 00000000..a248fef9 --- /dev/null +++ b/examples/train_mip_nerf.py @@ -0,0 +1,319 @@ +""" +Copyright (c) 2022 Ruilong Li, UC Berkeley. +""" + +import argparse +import math +import os +import time +import pathlib + +import imageio +import numpy as np +import torch +import torch.nn.functional as F +import tqdm +from radiance_fields.mlp import MipNeRFRadianceField +from mip_utils import render_image, set_random_seed, cast_rays + +from nerfacc import ContractionType, OccupancyGrid + +if __name__ == "__main__": + + device = "cuda:0" + set_random_seed(42) + + parser = argparse.ArgumentParser() + parser.add_argument( + "--data_root", + type=str, + default=str(pathlib.Path.home() / "data"), + help="the root dir of the dataset", + ) + parser.add_argument( + "--train_split", + type=str, + default="trainval", + choices=["train", "trainval"], + help="which train split to use", + ) + parser.add_argument( + "--scene", + type=str, + default="lego", + choices=[ + # nerf synthetic + "chair", + "drums", + "ficus", + "hotdog", + "lego", + "materials", + "mic", + "ship", + # mipnerf360 unbounded + "garden", + ], + help="which scene to use", + ) + parser.add_argument( + "--ray_shape", + type=str, + default="cone", + choices=["cone", "cylinder"], + help="the shape of the ray", + ) + parser.add_argument( + "--aabb", + type=lambda s: [float(item) for item in s.split(",")], + default="-1.5,-1.5,-1.5,1.5,1.5,1.5", + help="delimited list input", + ) + parser.add_argument( + "--test_chunk_size", + type=int, + default=8192, + ) + parser.add_argument( + "--unbounded", + action="store_true", + help="whether to use unbounded rendering", + ) + parser.add_argument("--cone_angle", type=float, default=0.0) + args = parser.parse_args() + + render_n_samples = 1024 + + # setup the scene bounding box. + if args.unbounded: + print("Using unbounded rendering") + contraction_type = ContractionType.UN_BOUNDED_SPHERE + # contraction_type = ContractionType.UN_BOUNDED_TANH + scene_aabb = None + near_plane = 0.2 + far_plane = 1e4 + render_step_size = 1e-2 + else: + contraction_type = ContractionType.AABB + scene_aabb = torch.tensor(args.aabb, dtype=torch.float32, device=device) + near_plane = None + far_plane = None + render_step_size = ( + (scene_aabb[3:] - scene_aabb[:3]).max() + * math.sqrt(3) + / render_n_samples + ).item() + + # setup the radiance field we want to train. + max_steps = 50000 + grad_scaler = torch.cuda.amp.GradScaler(1) + radiance_field = MipNeRFRadianceField().to(device) + optimizer = torch.optim.Adam(radiance_field.parameters(), lr=5e-4) + scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, + milestones=[ + max_steps // 2, + max_steps * 3 // 4, + max_steps * 5 // 6, + max_steps * 9 // 10, + ], + gamma=0.33, + ) + + # setup the dataset + train_dataset_kwargs = {} + test_dataset_kwargs = {} + if args.scene == "garden": + from datasets.nerf_360_v2 import SubjectLoader + + data_root_fp = str(pathlib.Path(args.data_root) / "360_v2") + target_sample_batch_size = 1 << 16 + train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4} + test_dataset_kwargs = {"factor": 4} + grid_resolution = 128 + else: + from datasets.nerf_synthetic import SubjectLoader + + data_root_fp = str(pathlib.Path(args.data_root) / "nerf_synthetic") + target_sample_batch_size = 1 << 16 + grid_resolution = 128 + + train_dataset = SubjectLoader( + subject_id=args.scene, + root_fp=data_root_fp, + split=args.train_split, + num_rays=target_sample_batch_size // render_n_samples, + get_radii=True, + **train_dataset_kwargs, + ) + + train_dataset.images = train_dataset.images.to(device) + train_dataset.camtoworlds = train_dataset.camtoworlds.to(device) + train_dataset.K = train_dataset.K.to(device) + + test_dataset = SubjectLoader( + subject_id=args.scene, + root_fp=data_root_fp, + split="test", + num_rays=None, + get_radii=True, + **test_dataset_kwargs, + ) + test_dataset.images = test_dataset.images.to(device) + test_dataset.camtoworlds = test_dataset.camtoworlds.to(device) + test_dataset.K = test_dataset.K.to(device) + + occupancy_grid = OccupancyGrid( + roi_aabb=args.aabb, + resolution=grid_resolution, + contraction_type=contraction_type, + ).to(device) + + # training + step = 0 + tic = time.time() + for epoch in range(10000000): + for i in range(len(train_dataset)): + radiance_field.train() + data = train_dataset[i] + + render_bkgd = data["color_bkgd"] + rays = data["rays"] + pixels = data["pixels"] + + def occ_eval_fn(x): + # randomly sample a camera for computing step size. + camera_ids = torch.randint( + 0, len(train_dataset), (x.shape[0],), device=device + ) + occ_rays = train_dataset.fetch_data_for_x(camera_ids, x) + + # calculate distance between x and origins + origins = train_dataset.camtoworlds[camera_ids, :3, -1] + t = (origins - x).norm(dim=-1, keepdim=True) + # compute actual step size used in marching, based on the distance to the camera. + step_size = torch.clamp( + t * args.cone_angle, min=render_step_size + ) + # filter out the points that are not in the near far plane. + if (near_plane is not None) and (far_plane is not None): + step_size = torch.where( + (t > near_plane) & (t < far_plane), + step_size, + torch.zeros_like(step_size), + ) + + # calculate t_starts, t_ends + t_starts, t_ends = t - step_size / 2, t + step_size / 2 + + # compute mean, cov of the samples. + x, xcov = cast_rays( + t_starts, + t_ends, + occ_rays.origins, + occ_rays.viewdirs, + occ_rays.radii, + args.ray_shape, + ) + + # compute occupancy + density = radiance_field.query_density(x, xcov) + return density * step_size + + # update occupancy grid + occupancy_grid.every_n_step(step=step, occ_eval_fn=occ_eval_fn) + + # render + rgb, acc, depth, n_rendering_samples = render_image( + radiance_field, + occupancy_grid, + rays, + scene_aabb, + # rendering options + near_plane=near_plane, + far_plane=far_plane, + render_step_size=render_step_size, + render_bkgd=render_bkgd, + cone_angle=args.cone_angle, + ray_shape=args.ray_shape, + ) + if n_rendering_samples == 0: + continue + + # dynamic batch size for rays to keep sample batch size constant. + num_rays = len(pixels) + num_rays = int( + num_rays + * (target_sample_batch_size / float(n_rendering_samples)) + ) + train_dataset.update_num_rays(num_rays) + alive_ray_mask = acc.squeeze(-1) > 0 + + # compute loss + loss = F.smooth_l1_loss(rgb[alive_ray_mask], pixels[alive_ray_mask]) + + optimizer.zero_grad() + # do not unscale it because we are using Adam. + grad_scaler.scale(loss).backward() + optimizer.step() + scheduler.step() + + if step % 5000 == 0: + elapsed_time = time.time() - tic + loss = F.mse_loss(rgb[alive_ray_mask], pixels[alive_ray_mask]) + print( + f"elapsed_time={elapsed_time:.2f}s | step={step} | " + f"loss={loss:.5f} | " + f"alive_ray_mask={alive_ray_mask.long().sum():d} | " + f"n_rendering_samples={n_rendering_samples:d} | num_rays={len(pixels):d} |" + ) + + if step >= 0 and step % max_steps == 0 and step > 0: + # evaluation + radiance_field.eval() + + psnrs = [] + with torch.no_grad(): + for i in tqdm.tqdm(range(len(test_dataset))): + data = test_dataset[i] + render_bkgd = data["color_bkgd"] + rays = data["rays"] + pixels = data["pixels"] + + # rendering + rgb, acc, depth, _ = render_image( + radiance_field, + occupancy_grid, + rays, + scene_aabb, + # rendering options + near_plane=None, + far_plane=None, + render_step_size=render_step_size, + render_bkgd=render_bkgd, + cone_angle=args.cone_angle, + # test options + test_chunk_size=args.test_chunk_size, + ray_shape=args.ray_shape, + ) + mse = F.mse_loss(rgb, pixels) + psnr = -10.0 * torch.log(mse) / np.log(10.0) + psnrs.append(psnr.item()) + # imageio.imwrite( + # "acc_binary_test.png", + # ((acc > 0).float().cpu().numpy() * 255).astype(np.uint8), + # ) + # imageio.imwrite( + # "rgb_test.png", + # (rgb.cpu().numpy() * 255).astype(np.uint8), + # ) + # break + psnr_avg = sum(psnrs) / len(psnrs) + print(f"evaluation: psnr_avg={psnr_avg}") + train_dataset.training = True + + if step == max_steps: + print("training stops") + exit() + + step += 1 diff --git a/examples/train_mlp_dnerf.py b/examples/train_mlp_dnerf.py index 9702ee35..9a43c4bb 100644 --- a/examples/train_mlp_dnerf.py +++ b/examples/train_mlp_dnerf.py @@ -24,6 +24,12 @@ set_random_seed(42) parser = argparse.ArgumentParser() + parser.add_argument( + "--data_root", + type=str, + default=str(pathlib.Path.home() / "data"), + help="the root dir of the dataset", + ) parser.add_argument( "--train_split", type=str, @@ -91,7 +97,7 @@ gamma=0.33, ) # setup the dataset - data_root_fp = "/home/ruilongli/data/dnerf/" + data_root_fp = str(pathlib.Path(args.data_root) / "dnerf") target_sample_batch_size = 1 << 16 grid_resolution = 128 diff --git a/examples/train_mlp_nerf.py b/examples/train_mlp_nerf.py index aa114a44..d13ef62a 100644 --- a/examples/train_mlp_nerf.py +++ b/examples/train_mlp_nerf.py @@ -23,6 +23,12 @@ set_random_seed(42) parser = argparse.ArgumentParser() + parser.add_argument( + "--data_root", + type=str, + default=str(pathlib.Path.home() / "data"), + help="the root dir of the dataset", + ) parser.add_argument( "--train_split", type=str, @@ -112,7 +118,7 @@ if args.scene == "garden": from datasets.nerf_360_v2 import SubjectLoader - data_root_fp = "/home/ruilongli/data/360_v2/" + data_root_fp = str(pathlib.Path(args.data_root) / "360_v2") target_sample_batch_size = 1 << 16 train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4} test_dataset_kwargs = {"factor": 4} @@ -120,7 +126,7 @@ else: from datasets.nerf_synthetic import SubjectLoader - data_root_fp = "/home/ruilongli/data/nerf_synthetic/" + data_root_fp = str(pathlib.Path(args.data_root) / "nerf_synthetic") target_sample_batch_size = 1 << 16 grid_resolution = 128 diff --git a/examples/train_ngp_nerf.py b/examples/train_ngp_nerf.py index 584cc13a..295455fa 100644 --- a/examples/train_ngp_nerf.py +++ b/examples/train_ngp_nerf.py @@ -23,6 +23,12 @@ set_random_seed(42) parser = argparse.ArgumentParser() + parser.add_argument( + "--data_root", + type=str, + default=str(pathlib.Path.home() / "data"), + help="the root dir of the dataset", + ) parser.add_argument( "--train_split", type=str, @@ -87,7 +93,7 @@ if args.unbounded: from datasets.nerf_360_v2 import SubjectLoader - data_root_fp = "/home/ruilongli/data/360_v2/" + data_root_fp = str(pathlib.Path(args.data_root) / "360_v2") target_sample_batch_size = 1 << 20 train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4} test_dataset_kwargs = {"factor": 4} @@ -95,7 +101,7 @@ else: from datasets.nerf_synthetic import SubjectLoader - data_root_fp = "/home/ruilongli/data/nerf_synthetic/" + data_root_fp = str(pathlib.Path(args.data_root) / "nerf_synthetic") target_sample_batch_size = 1 << 18 grid_resolution = 128 diff --git a/examples/train_ngp_nerf_proposal.py b/examples/train_ngp_nerf_proposal.py index 3b2f2d92..592f7f5d 100644 --- a/examples/train_ngp_nerf_proposal.py +++ b/examples/train_ngp_nerf_proposal.py @@ -138,6 +138,12 @@ def rgb_sigma_fn(t_starts, t_ends, ray_indices): set_random_seed(42) parser = argparse.ArgumentParser() + parser.add_argument( + "--data_root", + type=str, + default=str(pathlib.Path.home() / "data"), + help="the root dir of the dataset", + ) parser.add_argument( "--train_split", type=str, @@ -202,14 +208,14 @@ def rgb_sigma_fn(t_starts, t_ends, ray_indices): if args.unbounded: from datasets.nerf_360_v2 import SubjectLoader - data_root_fp = "/home/ruilongli/data/360_v2/" + data_root_fp = str(pathlib.Path(args.data_root) / "360_v2") target_sample_batch_size = 1 << 20 train_dataset_kwargs = {"color_bkgd_aug": "random", "factor": 4} test_dataset_kwargs = {"factor": 4} else: from datasets.nerf_synthetic import SubjectLoader - data_root_fp = "/home/ruilongli/data/nerf_synthetic/" + data_root_fp = str(pathlib.Path(args.data_root) / "nerf_synthetic") target_sample_batch_size = 1 << 18 train_dataset = SubjectLoader(