3 changes: 3 additions & 0 deletions .gitmodules
@@ -0,0 +1,3 @@
[submodule "examples/8_nerfacc/pycolmap"]
path = examples/8_nerfacc/pycolmap
url = https://github.com/rmbrualla/pycolmap.git
10,489 changes: 10,489 additions & 0 deletions client/yarn.lock

Large diffs are not rendered by default.

20 changes: 20 additions & 0 deletions examples/8_nerfacc/README.md
@@ -0,0 +1,20 @@
## NerfAcc visualization example

The majority of the code is copy-pasted from the [nerfacc](https://github.com/KAIR-BAIR/nerfacc) repo.

```
# Setup.
cd <PATH/TO/8_nerfacc>
pip install -r requirements.txt
git submodule update --init

# Train models with train_ngp_nerf*.py...
# python train_ngp_nerf.py --data_root <PATH/TO/BLENDER> --scene lego
# python train_ngp_nerf_prop.py --data_root <PATH/TO/360> --scene garden
# ... Or use provided checkpoints.
bash ../assets/download_nerfacc_checkpoints.sh

# Visualize using viser!
python visualize_ngp_nerf.py --data_root <PATH/TO/BLENDER> --scene lego
python visualize_ngp_nerf_prop.py --data_root <PATH/TO/360> --scene garden
```
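
The visualize scripts stream renders into a [viser](https://github.com/nerfstudio-project/viser) server. As rough orientation only, here is a hypothetical minimal sketch using just `viser.ViserServer` (not code from the scripts themselves):

```
import time

import viser

# Hypothetical sketch: start a bare viser server. It serves a web UI
# (typically at http://localhost:8080); the actual visualize scripts also
# render NeRF frames and push them to connected clients.
server = viser.ViserServer()
while True:
    time.sleep(1.0)
```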
Empty file.
236 changes: 236 additions & 0 deletions examples/8_nerfacc/datasets/dnerf_synthetic.py
@@ -0,0 +1,236 @@
"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
"""

import json
import os
from typing import Optional

import imageio.v2 as imageio
import numpy as np
import torch
import torch.nn.functional as F

from .utils import Rays


def _load_renderings(root_fp: str, subject_id: str, split: str):
"""Load images from disk."""
if not root_fp.startswith("/"):
# allow relative path. e.g., "./data/dnerf_synthetic/"
root_fp = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"..",
"..",
root_fp,
)

data_dir = os.path.join(root_fp, subject_id)
with open(
os.path.join(data_dir, "transforms_{}.json".format(split)), "r"
) as fp:
meta = json.load(fp)
images = []
camtoworlds = []
timestamps = []

for i in range(len(meta["frames"])):
frame = meta["frames"][i]
fname = os.path.join(data_dir, frame["file_path"] + ".png")
rgba = imageio.imread(fname)
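        # Prefer the per-frame "time" field (D-NeRF convention); otherwise
        # spread timestamps uniformly over [0, 1].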
timestamp = (
frame["time"]
if "time" in frame
else float(i) / (len(meta["frames"]) - 1)
)
timestamps.append(timestamp)
camtoworlds.append(frame["transform_matrix"])
images.append(rgba)

images = np.stack(images, axis=0)
camtoworlds = np.stack(camtoworlds, axis=0)
timestamps = np.stack(timestamps, axis=0)

h, w = images.shape[1:3]
camera_angle_x = float(meta["camera_angle_x"])
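    # Pinhole model: tan(camera_angle_x / 2) = (w / 2) / focal.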
focal = 0.5 * w / np.tan(0.5 * camera_angle_x)

return images, camtoworlds, focal, timestamps


class SubjectLoader(torch.utils.data.Dataset):
"""Single subject data loader for training and evaluation."""

SPLITS = ["train", "val", "test"]
SUBJECT_IDS = [
"bouncingballs",
"hellwarrior",
"hook",
"jumpingjacks",
"lego",
"mutant",
"standup",
"trex",
]

WIDTH, HEIGHT = 800, 800
NEAR, FAR = 2.0, 6.0
OPENGL_CAMERA = True

def __init__(
self,
subject_id: str,
root_fp: str,
split: str,
color_bkgd_aug: str = "white",
        num_rays: Optional[int] = None,
        near: Optional[float] = None,
        far: Optional[float] = None,
batch_over_images: bool = True,
device: str = "cpu",
):
super().__init__()
assert split in self.SPLITS, "%s" % split
assert subject_id in self.SUBJECT_IDS, "%s" % subject_id
assert color_bkgd_aug in ["white", "black", "random"]
self.split = split
self.num_rays = num_rays
self.near = self.NEAR if near is None else near
self.far = self.FAR if far is None else far
self.training = (num_rays is not None) and (
split in ["train", "trainval"]
)
self.color_bkgd_aug = color_bkgd_aug
self.batch_over_images = batch_over_images
(
self.images,
self.camtoworlds,
self.focal,
self.timestamps,
) = _load_renderings(root_fp, subject_id, split)
self.images = torch.from_numpy(self.images).to(device).to(torch.uint8)
self.camtoworlds = (
torch.from_numpy(self.camtoworlds).to(device).to(torch.float32)
)
self.timestamps = (
torch.from_numpy(self.timestamps)
.to(device)
.to(torch.float32)[:, None]
)
self.K = torch.tensor(
[
[self.focal, 0, self.WIDTH / 2.0],
[0, self.focal, self.HEIGHT / 2.0],
[0, 0, 1],
],
dtype=torch.float32,
device=device,
) # (3, 3)
assert self.images.shape[1:3] == (self.HEIGHT, self.WIDTH)

def __len__(self):
return len(self.images)

@torch.no_grad()
def __getitem__(self, index):
data = self.fetch_data(index)
data = self.preprocess(data)
return data

def preprocess(self, data):
"""Process the fetched / cached data with randomness."""
rgba, rays = data["rgba"], data["rays"]
pixels, alpha = torch.split(rgba, [3, 1], dim=-1)

if self.training:
if self.color_bkgd_aug == "random":
color_bkgd = torch.rand(3, device=self.images.device)
elif self.color_bkgd_aug == "white":
color_bkgd = torch.ones(3, device=self.images.device)
elif self.color_bkgd_aug == "black":
color_bkgd = torch.zeros(3, device=self.images.device)
else:
# just use white during inference
color_bkgd = torch.ones(3, device=self.images.device)

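        # Composite the straight-alpha RGBA over the chosen background:
        # pixels = rgb * alpha + bkgd * (1 - alpha).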
pixels = pixels * alpha + color_bkgd * (1.0 - alpha)
return {
"pixels": pixels, # [n_rays, 3] or [h, w, 3]
"rays": rays, # [n_rays,] or [h, w]
"color_bkgd": color_bkgd, # [3,]
**{k: v for k, v in data.items() if k not in ["rgba", "rays"]},
}

def update_num_rays(self, num_rays):
self.num_rays = num_rays

def fetch_data(self, index):
"""Fetch the data (it maybe cached for multiple batches)."""
num_rays = self.num_rays

if self.training:
if self.batch_over_images:
image_id = torch.randint(
0,
len(self.images),
size=(num_rays,),
device=self.images.device,
)
else:
image_id = [index]
x = torch.randint(
0, self.WIDTH, size=(num_rays,), device=self.images.device
)
y = torch.randint(
0, self.HEIGHT, size=(num_rays,), device=self.images.device
)
else:
image_id = [index]
x, y = torch.meshgrid(
torch.arange(self.WIDTH, device=self.images.device),
torch.arange(self.HEIGHT, device=self.images.device),
indexing="xy",
)
x = x.flatten()
y = y.flatten()

# generate rays
rgba = self.images[image_id, y, x] / 255.0 # (num_rays, 4)
        c2w = self.camtoworlds[image_id]  # (num_rays, 4, 4)
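        # Offset pixel indices by 0.5 to sample pixel centers, normalize by
        # the focal lengths, and flip y / set z = -1 for the OpenGL camera
        # convention (the camera looks down -z).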
camera_dirs = F.pad(
torch.stack(
[
(x - self.K[0, 2] + 0.5) / self.K[0, 0],
(y - self.K[1, 2] + 0.5)
/ self.K[1, 1]
* (-1.0 if self.OPENGL_CAMERA else 1.0),
],
dim=-1,
),
(0, 1),
value=(-1.0 if self.OPENGL_CAMERA else 1.0),
) # [num_rays, 3]

        # Rotate camera-frame directions into the world frame, per ray:
        # d_world[i] = R[i] @ d_cam[i].  -> (num_rays, 3)
        directions = (camera_dirs[:, None, :] * c2w[:, :3, :3]).sum(dim=-1)
origins = torch.broadcast_to(c2w[:, :3, -1], directions.shape)
        viewdirs = directions / torch.linalg.norm(
            directions, dim=-1, keepdim=True
        )

if self.training:
origins = torch.reshape(origins, (num_rays, 3))
viewdirs = torch.reshape(viewdirs, (num_rays, 3))
rgba = torch.reshape(rgba, (num_rays, 4))
else:
origins = torch.reshape(origins, (self.HEIGHT, self.WIDTH, 3))
viewdirs = torch.reshape(viewdirs, (self.HEIGHT, self.WIDTH, 3))
rgba = torch.reshape(rgba, (self.HEIGHT, self.WIDTH, 4))

rays = Rays(origins=origins, viewdirs=viewdirs)
timestamps = self.timestamps[image_id]

return {
"rgba": rgba, # [h, w, 4] or [num_rays, 4]
"rays": rays, # [h, w, 3] or [num_rays, 3]
"timestamps": timestamps, # [num_rays, 1]
}
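
The loader is self-contained, so it can be smoke-tested directly. A minimal usage sketch (the data path is a placeholder, `num_rays=1024` is an arbitrary choice, and the D-NeRF synthetic data is assumed to be on disk):

```
from datasets.dnerf_synthetic import SubjectLoader

train_dataset = SubjectLoader(
    subject_id="lego",
    root_fp="./data/dnerf_synthetic/",  # placeholder path
    split="train",
    num_rays=1024,  # passing num_rays enables training mode (random rays)
)
batch = train_dataset[0]
print(batch["pixels"].shape)        # (1024, 3): RGB composited over background
print(batch["rays"].origins.shape)  # (1024, 3)
print(batch["timestamps"].shape)    # (1024, 1)
```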