Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
51d1c6f
initial branch commit
Crezle Feb 12, 2025
ed71564
format a bigger message for SH coeffs
Crezle Feb 12, 2025
680e78f
add new gsplat message properties definitions
Crezle Feb 12, 2025
6713f5f
add new gsplat message properties definitions 2.0
Crezle Feb 12, 2025
3be0cd3
partial progress SH and norm addition on client end
Crezle Feb 12, 2025
cad878a
update progress on SH implementation
Crezle Feb 12, 2025
4f35d20
add temporary SH rendering equations
Crezle Feb 13, 2025
7babd95
minor tweaks
Crezle Feb 13, 2025
a6f43f0
add SH functioning, but not correct
Crezle Feb 13, 2025
7dc5433
change example script to include SH and normals
Crezle Feb 13, 2025
2751dac
make SH code more readable according to theory
Crezle Feb 17, 2025
aa43e8d
re-ordered f_rest_* are interpreted for correct SH representation
Crezle Feb 17, 2025
0a75ea1
remove surface normals code and clean up some comments
Crezle Feb 18, 2025
81fbb44
minor tweaks
Crezle Feb 18, 2025
464c639
Merge branch 'main' into extended_gsplat_properties
Crezle Feb 18, 2025
ead3811
Merge branch 'main' into extended_gsplat_properties
Crezle Feb 26, 2025
4572b36
Merge branch 'main' into extended_gsplat_properties
brentyi Mar 22, 2025
71243bc
Merge branch 'main' into extended_gsplat_properties
Crezle Mar 24, 2025
e765f96
Merge branch 'main' into extended_gsplat_properties
Crezle Apr 3, 2025
89991c4
Merge branch 'main' into extended_gsplat_properties
Crezle Apr 6, 2025
fdbaf6c
Merge branch 'main' into extended_gsplat_properties
Crezle Apr 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions examples/experimental/gaussian_splats.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class SplatFile(TypedDict):
"""(N, 1). Range [0, 1]."""
covariances: npt.NDArray[np.floating]
"""(N, 3, 3)."""
sh_coeffs: npt.NDArray[np.floating]


def load_splat_file(splat_path: Path, center: bool = False) -> SplatFile:
Expand Down Expand Up @@ -69,6 +70,8 @@ def load_splat_file(splat_path: Path, center: bool = False) -> SplatFile:
"opacities": splat_uint8[:, 27:28] / 255.0,
# Covariances should have shape (N, 3, 3).
"covariances": covariances,
# No SH coefficients in the splat file.
"sh_coeffs": np.zeros((num_gaussians, 45), dtype=np.float32),
}


Expand All @@ -85,6 +88,17 @@ def load_ply_file(ply_file_path: Path, center: bool = False) -> SplatFile:
wxyzs = np.stack([v["rot_0"], v["rot_1"], v["rot_2"], v["rot_3"]], axis=1)
colors = 0.5 + SH_C0 * np.stack([v["f_dc_0"], v["f_dc_1"], v["f_dc_2"]], axis=1)
opacities = 1.0 / (1.0 + np.exp(-v["opacity"][:, None]))
dc_coeffs = np.stack([v["f_dc_0"], v["f_dc_1"], v["f_dc_2"]], axis=1)
# Rest coefficients 0-14 belongs to RED channel, 15-29 to GREEN, 30-44 to BLUE
# Due to spherical harmonic calculations calculating a triplet at a time
# we need to stack them by (0,15,30), (1,16,31), ..., (14,29,44)
rest_coeffs = []
for i in range(15):
rest_coeffs.append(v[f"f_rest_{i}"])
rest_coeffs.append(v[f"f_rest_{i + 15}"])
rest_coeffs.append(v[f"f_rest_{i + 30}"])
rest_coeffs = np.stack(rest_coeffs, axis=1)
sh_coeffs = np.concatenate([dc_coeffs, rest_coeffs], axis=1)

Rs = tf.SO3(wxyzs).as_matrix()
covariances = np.einsum(
Expand All @@ -102,6 +116,7 @@ def load_ply_file(ply_file_path: Path, center: bool = False) -> SplatFile:
"rgbs": colors,
"opacities": opacities,
"covariances": covariances,
"sh_coeffs": sh_coeffs,
}


Expand Down Expand Up @@ -136,6 +151,7 @@ def _(event: viser.GuiEvent) -> None:
rgbs=splat_data["rgbs"],
opacities=splat_data["opacities"],
covariances=splat_data["covariances"],
sh_coeffs=splat_data["sh_coeffs"],
)

remove_button = server.gui.add_button(f"Remove splat object {i}")
Expand Down
1 change: 1 addition & 0 deletions src/viser/_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1400,6 +1400,7 @@ class GaussianSplatsProps:
# Memory layout is borrowed from:
# https://github.com/antimatter15/splat
buffer: npt.NDArray[np.uint32]
sh_buffer: npt.NDArray[np.uint32]
"""Our buffer will contain:
- x as f32
- y as f32
Expand Down
25 changes: 25 additions & 0 deletions src/viser/_scene_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1508,6 +1508,7 @@ def add_gaussian_splats(
covariances: np.ndarray,
rgbs: np.ndarray,
opacities: np.ndarray,
sh_coeffs: np.ndarray | None = None,
wxyz: Tuple[float, float, float, float] | np.ndarray = (1.0, 0.0, 0.0, 0.0),
position: Tuple[float, float, float] | np.ndarray = (0.0, 0.0, 0.0),
visible: bool = True,
Expand All @@ -1523,6 +1524,7 @@ def add_gaussian_splats(
covariances: Second moment for each Gaussian. (N, 3, 3).
rgbs: Color for each Gaussian. (N, 3).
opacities: Opacity for each Gaussian. (N, 1).
sh_coeffs: Spherical harmonics coefficients for each Gaussian. (N, 48).
wxyz: R_parent_local transformation.
position: t_parent_local transformation.
visible: Initial visibility of scene node.
Expand Down Expand Up @@ -1556,10 +1558,33 @@ def add_gaussian_splats(
).view(np.uint32)
assert buffer.shape == (num_gaussians, 8)

if sh_coeffs is not None:
assert sh_coeffs.shape == (num_gaussians, 48)
sh_buffer = np.concatenate(
[
sh_coeffs.astype(np.float16).copy().view(np.uint8)
],
).view(np.uint32)
else:
# To ensure backwards compatibility, we'll compute SH coefficients from
# the RGB values.
# However, this is not efficient as packets sent to client
# will be larger. TODO: Permanently incorporate colors with SH coefficients.
SH_C0 = 0.28209479177387814
dc_coeffs = (rgbs - 0.5) / SH_C0
sh_buffer = np.concatenate(
[
dc_coeffs, np.zeros((num_gaussians, 45))
],
axis=1,
dtype=np.float16,
).view(np.uint32)

message = _messages.GaussianSplatsMessage(
name=name,
props=_messages.GaussianSplatsProps(
buffer=buffer,
sh_buffer=sh_buffer,
),
)
node_handle = GaussianSplatHandle._make(
Expand Down
9 changes: 9 additions & 0 deletions src/viser/client/src/SceneTree.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,15 @@ function useObjectFactory(message: SceneNodeMessage | undefined): {
),
)
}
sh_buffer={
new Uint32Array(
message.props.sh_buffer.buffer.slice(
message.props.sh_buffer.byteOffset,
message.props.sh_buffer.byteOffset +
message.props.sh_buffer.byteLength,
),
)
}
/>
),
};
Expand Down
40 changes: 35 additions & 5 deletions src/viser/client/src/Splatting/GaussianSplats.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,9 @@ export const SplatObject = React.forwardRef<
THREE.Group,
{
buffer: Uint32Array;
sh_buffer: Uint32Array;
}
>(function SplatObject({ buffer }, ref) {
>(function SplatObject({ buffer, sh_buffer }, ref) {
const splatContext = React.useContext(GaussianSplatsContext)!;
const setBuffer = splatContext.useGaussianSplatStore(
(state) => state.setBuffer,
Expand All @@ -74,11 +75,14 @@ export const SplatObject = React.forwardRef<
(state) => state.nodeRefFromId,
);
const name = React.useMemo(() => uuidv4(), [buffer]);
const sh_buffer_name = `sh_buffer_${name}`;

React.useEffect(() => {
setBuffer(name, buffer);
setBuffer(sh_buffer_name, sh_buffer);
return () => {
removeBuffer(name);
removeBuffer(sh_buffer_name);
delete nodeRefFromId.current[name];
};
}, [buffer]);
Expand Down Expand Up @@ -129,6 +133,7 @@ function SplatRendererImpl() {
const merged = mergeGaussianGroups(groupBufferFromId);
const meshProps = useGaussianMeshProps(
merged.gaussianBuffer,
merged.combinedSHBuffer,
merged.numGroups,
);
splatContext.meshPropsRef.current = meshProps;
Expand All @@ -146,6 +151,7 @@ function SplatRendererImpl() {
if (!initializedBufferTexture) {
meshProps.material.uniforms.numGaussians.value = merged.numGaussians;
meshProps.textureBuffer.needsUpdate = true;
meshProps.shTextureBuffer.needsUpdate = true;
initializedBufferTexture = true;
}
};
Expand All @@ -161,6 +167,7 @@ function SplatRendererImpl() {
React.useEffect(() => {
return () => {
meshProps.textureBuffer.dispose();
meshProps.shTextureBuffer.dispose();
meshProps.geometry.dispose();
meshProps.material.dispose();
postToWorker({ close: true });
Expand Down Expand Up @@ -336,7 +343,12 @@ function mergeGaussianGroups(groupBufferFromName: {
}) {
// Create geometry. Each Gaussian will be rendered as a quad.
let totalBufferLength = 0;
for (const buffer of Object.values(groupBufferFromName)) {
const groupBufferFromNameFiltered = Object.fromEntries(
Object.entries(groupBufferFromName).filter(
([key]) => !key.startsWith("sh_buffer_")
)
);
for (const buffer of Object.values(groupBufferFromNameFiltered)) {
totalBufferLength += buffer.length;
}
const numGaussians = totalBufferLength / 8;
Expand All @@ -345,7 +357,7 @@ function mergeGaussianGroups(groupBufferFromName: {

let offset = 0;
for (const [groupIndex, groupBuffer] of Object.values(
groupBufferFromName,
groupBufferFromNameFiltered,
).entries()) {
groupIndices.fill(
groupIndex,
Expand All @@ -366,6 +378,24 @@ function mergeGaussianGroups(groupBufferFromName: {
offset += groupBuffer.length;
}

const numGroups = Object.keys(groupBufferFromName).length;
return { numGaussians, gaussianBuffer, numGroups, groupIndices };
let totalSHBufferLength = 0;

const shGaussianBuffers = Object.fromEntries(
Object.entries(groupBufferFromName).filter(([key]) => key.startsWith("sh_buffer_"))
);

for (const sh_buffer of Object.values(shGaussianBuffers)) {
totalSHBufferLength += sh_buffer.length;
}

const combinedSHBuffer = new Uint32Array(totalSHBufferLength);
let sh_offset = 0;
for (const sh_buffer of Object.values(shGaussianBuffers)) {
combinedSHBuffer.set(sh_buffer, sh_offset);
sh_offset += sh_buffer.length;
}

const numGroups = Object.keys(groupBufferFromNameFiltered).length;

return { numGaussians, gaussianBuffer, numGroups, groupIndices, combinedSHBuffer };
}
Loading