Description
Hi, I’m experimenting with some basic tensor indexing operations (e.g., slicing) in Dr.Jit and noticed that integer indexing behaves a bit differently from NumPy/PyTorch.
In NumPy/PyTorch, a pair of integer index tensors is interpreted as a coordinate list (with broadcasting if needed). In Dr.Jit, the same code seems to broadcast the index tensors against each other, outer-product style, which gives a different result shape.
Here is a minimal example:
import drjit as dr
from drjit.cuda.ad import TensorXf, TensorXi
t = dr.zeros(TensorXf, shape=(100, 200))
b = dr.arange(TensorXi, 100)
d = dr.arange(TensorXi, 100)
print(t[b, d].shape) # Dr.Jit: (100, 100)
# PyTorch equivalents
print(t.torch()[b.torch(), d.torch()].shape) # torch.Size([100])
print(t.torch()[b.torch()[:, None], d.torch()].shape)  # torch.Size([100, 100])

So in this example:
- t[b, d] in Dr.Jit produces a (100, 100) tensor.
- In PyTorch, the grid-like behavior corresponds to t[b[:, None], d], while t[b, d] itself uses the “coordinate list” interpretation and returns a (100,) tensor.
Right now, to mimic the NumPy/PyTorch “coordinate list” behavior in Dr.Jit, I’m flattening the tensor and manually computing a linear 1D index, roughly as sketched below.
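For reference, here is a minimal sketch of that workaround. It assumes the tensor data is stored row-major, accesses the flat storage via the tensor’s .array attribute, and gathers with dr.gather:

import drjit as dr
from drjit.cuda.ad import TensorXf, TensorXi, Float

t = dr.zeros(TensorXf, shape=(100, 200))
b = dr.arange(TensorXi, 100)
d = dr.arange(TensorXi, 100)

# Linear index into the (assumed row-major) flat storage: row * num_cols + col
flat_idx = b.array * t.shape[1] + d.array

# Gather from the flattened tensor data; the result is a 1D array of 100
# entries, matching PyTorch's coordinate-list result t[b, d]
vals = dr.gather(Float, t.array, flat_idx)
print(len(vals))  # 100

This works, but the manual index arithmetic gets awkward for higher-rank tensors.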
Is there a recommended way in Dr.Jit to perform NumPy/PyTorch-style “coordinate list” integer indexing without manually flattening and computing the indices?
Thanks!