-
Notifications
You must be signed in to change notification settings - Fork 33
Expand file tree
/
Copy pathsampler.py
More file actions
40 lines (31 loc) · 969 Bytes
/
sampler.py
File metadata and controls
40 lines (31 loc) · 969 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import math
from typing import List
import torch
from ray_utils import RayBundle
from pytorch3d.renderer.cameras import CamerasBase
# Sampler which implements stratified (uniform) point sampling along rays
class StratifiedRaysampler(torch.nn.Module):
    """Sample points at evenly spaced depths along every ray of a bundle.

    The depth range and the number of samples are read once from ``cfg`` at
    construction time; ``forward`` then places the same set of depths on
    every ray.
    """

    def __init__(
        self,
        cfg
    ):
        """Store the sampling configuration.

        Args:
            cfg: Object exposing ``n_pts_per_ray`` (number of samples per
                ray), ``min_depth`` and ``max_depth`` (inclusive depth range).
        """
        super().__init__()

        self.n_pts_per_ray = cfg.n_pts_per_ray
        self.min_depth = cfg.min_depth
        self.max_depth = cfg.max_depth

    def forward(
        self,
        ray_bundle,
    ):
        """Attach sample points and their depths to ``ray_bundle``.

        Args:
            ray_bundle: Bundle of rays supporting ``_replace``.
                NOTE(review): assumes it exposes ``origins`` and
                ``directions`` tensors of shape ``(..., 3)`` - confirm
                against ``ray_utils.RayBundle``.

        Returns:
            A copy of ``ray_bundle`` (via ``_replace``) whose
            ``sample_points`` has shape ``(..., n_pts_per_ray, 3)`` and whose
            ``sample_lengths`` has shape ``(..., n_pts_per_ray, 1)``.
        """
        origins = ray_bundle.origins
        directions = ray_bundle.directions

        # Evenly spaced depths over [min_depth, max_depth], endpoints included.
        # Kept as (n_pts_per_ray, 1): the trailing singleton axis lets the
        # depths broadcast against (..., 1, 3) directions below and keeps
        # sample_lengths at (..., n_pts_per_ray, 1) instead of accidentally
        # expanding to (..., n_pts_per_ray, n_pts_per_ray).
        z_vals = torch.linspace(
            self.min_depth,
            self.max_depth,
            self.n_pts_per_ray,
            device=origins.device,
            dtype=origins.dtype,
        ).unsqueeze(-1)

        # point = origin + depth * direction, for every (ray, depth) pair.
        sample_points = origins.unsqueeze(-2) + z_vals * directions.unsqueeze(-2)

        # Return
        return ray_bundle._replace(
            sample_points=sample_points,
            sample_lengths=z_vals * torch.ones_like(sample_points[..., :1]),
        )
# Registry mapping a sampler name to the class that implements it.
sampler_dict = {
    "stratified": StratifiedRaysampler,
}