diff --git a/examples/experimental/gaussian_splats.py b/examples/experimental/gaussian_splats.py index 0fff1085a..dbd28c876 100644 --- a/examples/experimental/gaussian_splats.py +++ b/examples/experimental/gaussian_splats.py @@ -26,6 +26,7 @@ class SplatFile(TypedDict): """(N, 1). Range [0, 1].""" covariances: npt.NDArray[np.floating] """(N, 3, 3).""" + sh_coeffs: npt.NDArray[np.floating] def load_splat_file(splat_path: Path, center: bool = False) -> SplatFile: @@ -69,6 +70,8 @@ def load_splat_file(splat_path: Path, center: bool = False) -> SplatFile: "opacities": splat_uint8[:, 27:28] / 255.0, # Covariances should have shape (N, 3, 3). "covariances": covariances, + # No SH coefficients in the splat file. + "sh_coeffs": np.zeros((num_gaussians, 45), dtype=np.float32), } @@ -85,6 +88,17 @@ def load_ply_file(ply_file_path: Path, center: bool = False) -> SplatFile: wxyzs = np.stack([v["rot_0"], v["rot_1"], v["rot_2"], v["rot_3"]], axis=1) colors = 0.5 + SH_C0 * np.stack([v["f_dc_0"], v["f_dc_1"], v["f_dc_2"]], axis=1) opacities = 1.0 / (1.0 + np.exp(-v["opacity"][:, None])) + dc_coeffs = np.stack([v["f_dc_0"], v["f_dc_1"], v["f_dc_2"]], axis=1) + # Rest coefficients 0-14 belongs to RED channel, 15-29 to GREEN, 30-44 to BLUE + # Due to spherical harmonic calculations calculating a triplet at a time + # we need to stack them by (0,15,30), (1,16,31), ..., (14,29,44) + rest_coeffs = [] + for i in range(15): + rest_coeffs.append(v[f"f_rest_{i}"]) + rest_coeffs.append(v[f"f_rest_{i + 15}"]) + rest_coeffs.append(v[f"f_rest_{i + 30}"]) + rest_coeffs = np.stack(rest_coeffs, axis=1) + sh_coeffs = np.concatenate([dc_coeffs, rest_coeffs], axis=1) Rs = tf.SO3(wxyzs).as_matrix() covariances = np.einsum( @@ -102,6 +116,7 @@ def load_ply_file(ply_file_path: Path, center: bool = False) -> SplatFile: "rgbs": colors, "opacities": opacities, "covariances": covariances, + "sh_coeffs": sh_coeffs, } @@ -136,6 +151,7 @@ def _(event: viser.GuiEvent) -> None: rgbs=splat_data["rgbs"], opacities=splat_data["opacities"], covariances=splat_data["covariances"], + sh_coeffs=splat_data["sh_coeffs"], ) remove_button = server.gui.add_button(f"Remove splat object {i}") diff --git a/src/viser/_messages.py b/src/viser/_messages.py index 045791998..761957a3f 100644 --- a/src/viser/_messages.py +++ b/src/viser/_messages.py @@ -1400,6 +1400,7 @@ class GaussianSplatsProps: # Memory layout is borrowed from: # https://github.com/antimatter15/splat buffer: npt.NDArray[np.uint32] + sh_buffer: npt.NDArray[np.uint32] """Our buffer will contain: - x as f32 - y as f32 diff --git a/src/viser/_scene_api.py b/src/viser/_scene_api.py index 7c2495058..6943553a8 100644 --- a/src/viser/_scene_api.py +++ b/src/viser/_scene_api.py @@ -1508,6 +1508,7 @@ def add_gaussian_splats( covariances: np.ndarray, rgbs: np.ndarray, opacities: np.ndarray, + sh_coeffs: np.ndarray | None = None, wxyz: Tuple[float, float, float, float] | np.ndarray = (1.0, 0.0, 0.0, 0.0), position: Tuple[float, float, float] | np.ndarray = (0.0, 0.0, 0.0), visible: bool = True, @@ -1523,6 +1524,7 @@ def add_gaussian_splats( covariances: Second moment for each Gaussian. (N, 3, 3). rgbs: Color for each Gaussian. (N, 3). opacities: Opacity for each Gaussian. (N, 1). + sh_coeffs: Spherical harmonics coefficients for each Gaussian. (N, 48). wxyz: R_parent_local transformation. position: t_parent_local transformation. visible: Initial visibility of scene node. @@ -1556,10 +1558,33 @@ def add_gaussian_splats( ).view(np.uint32) assert buffer.shape == (num_gaussians, 8) + if sh_coeffs is not None: + assert sh_coeffs.shape == (num_gaussians, 48) + sh_buffer = np.concatenate( + [ + sh_coeffs.astype(np.float16).copy().view(np.uint8) + ], + ).view(np.uint32) + else: + # To ensure backwards compatibility, we'll compute SH coefficients from + # the RGB values. + # However, this is not efficient as packets sent to client + # will be larger. TODO: Permanently incorporate colors with SH coefficients. + SH_C0 = 0.28209479177387814 + dc_coeffs = (rgbs - 0.5) / SH_C0 + sh_buffer = np.concatenate( + [ + dc_coeffs, np.zeros((num_gaussians, 45)) + ], + axis=1, + dtype=np.float16, + ).view(np.uint32) + message = _messages.GaussianSplatsMessage( name=name, props=_messages.GaussianSplatsProps( buffer=buffer, + sh_buffer=sh_buffer, ), ) node_handle = GaussianSplatHandle._make( diff --git a/src/viser/client/src/SceneTree.tsx b/src/viser/client/src/SceneTree.tsx index df5f0cf00..7b5ed5586 100644 --- a/src/viser/client/src/SceneTree.tsx +++ b/src/viser/client/src/SceneTree.tsx @@ -503,6 +503,15 @@ function useObjectFactory(message: SceneNodeMessage | undefined): { ), ) } + sh_buffer={ + new Uint32Array( + message.props.sh_buffer.buffer.slice( + message.props.sh_buffer.byteOffset, + message.props.sh_buffer.byteOffset + + message.props.sh_buffer.byteLength, + ), + ) + } /> ), }; diff --git a/src/viser/client/src/Splatting/GaussianSplats.tsx b/src/viser/client/src/Splatting/GaussianSplats.tsx index a5387d545..6ab8462e9 100644 --- a/src/viser/client/src/Splatting/GaussianSplats.tsx +++ b/src/viser/client/src/Splatting/GaussianSplats.tsx @@ -61,8 +61,9 @@ export const SplatObject = React.forwardRef< THREE.Group, { buffer: Uint32Array; + sh_buffer: Uint32Array; } ->(function SplatObject({ buffer }, ref) { +>(function SplatObject({ buffer, sh_buffer }, ref) { const splatContext = React.useContext(GaussianSplatsContext)!; const setBuffer = splatContext.useGaussianSplatStore( (state) => state.setBuffer, @@ -74,11 +75,14 @@ export const SplatObject = React.forwardRef< (state) => state.nodeRefFromId, ); const name = React.useMemo(() => uuidv4(), [buffer]); + const sh_buffer_name = `sh_buffer_${name}`; React.useEffect(() => { setBuffer(name, buffer); + setBuffer(sh_buffer_name, sh_buffer); return () => { removeBuffer(name); + removeBuffer(sh_buffer_name); delete nodeRefFromId.current[name]; }; }, [buffer]); @@ -129,6 +133,7 @@ function SplatRendererImpl() { const merged = mergeGaussianGroups(groupBufferFromId); const meshProps = useGaussianMeshProps( merged.gaussianBuffer, + merged.combinedSHBuffer, merged.numGroups, ); splatContext.meshPropsRef.current = meshProps; @@ -146,6 +151,7 @@ function SplatRendererImpl() { if (!initializedBufferTexture) { meshProps.material.uniforms.numGaussians.value = merged.numGaussians; meshProps.textureBuffer.needsUpdate = true; + meshProps.shTextureBuffer.needsUpdate = true; initializedBufferTexture = true; } }; @@ -161,6 +167,7 @@ function SplatRendererImpl() { React.useEffect(() => { return () => { meshProps.textureBuffer.dispose(); + meshProps.shTextureBuffer.dispose(); meshProps.geometry.dispose(); meshProps.material.dispose(); postToWorker({ close: true }); @@ -336,7 +343,12 @@ function mergeGaussianGroups(groupBufferFromName: { }) { // Create geometry. Each Gaussian will be rendered as a quad. let totalBufferLength = 0; - for (const buffer of Object.values(groupBufferFromName)) { + const groupBufferFromNameFiltered = Object.fromEntries( + Object.entries(groupBufferFromName).filter( + ([key]) => !key.startsWith("sh_buffer_") + ) + ); + for (const buffer of Object.values(groupBufferFromNameFiltered)) { totalBufferLength += buffer.length; } const numGaussians = totalBufferLength / 8; @@ -345,7 +357,7 @@ function mergeGaussianGroups(groupBufferFromName: { let offset = 0; for (const [groupIndex, groupBuffer] of Object.values( - groupBufferFromName, + groupBufferFromNameFiltered, ).entries()) { groupIndices.fill( groupIndex, @@ -366,6 +378,24 @@ function mergeGaussianGroups(groupBufferFromName: { offset += groupBuffer.length; } - const numGroups = Object.keys(groupBufferFromName).length; - return { numGaussians, gaussianBuffer, numGroups, groupIndices }; + let totalSHBufferLength = 0; + + const shGaussianBuffers = Object.fromEntries( + Object.entries(groupBufferFromName).filter(([key]) => key.startsWith("sh_buffer_")) + ); + + for (const sh_buffer of Object.values(shGaussianBuffers)) { + totalSHBufferLength += sh_buffer.length; + } + + const combinedSHBuffer = new Uint32Array(totalSHBufferLength); + let sh_offset = 0; + for (const sh_buffer of Object.values(shGaussianBuffers)) { + combinedSHBuffer.set(sh_buffer, sh_offset); + sh_offset += sh_buffer.length; + } + + const numGroups = Object.keys(groupBufferFromNameFiltered).length; + + return { numGaussians, gaussianBuffer, numGroups, groupIndices, combinedSHBuffer }; } diff --git a/src/viser/client/src/Splatting/GaussianSplatsHelpers.ts b/src/viser/client/src/Splatting/GaussianSplatsHelpers.ts index c3ed1f9af..e36869dd9 100644 --- a/src/viser/client/src/Splatting/GaussianSplatsHelpers.ts +++ b/src/viser/client/src/Splatting/GaussianSplatsHelpers.ts @@ -29,6 +29,10 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( // copy quadjr for this. uniform usampler2D textureBuffer; + // Buffer for spherical harmonics; Each Gaussian gets 24 int32s representing + // this information (Each coefficient is 16 bits, corr. to 48 coeffs.). + uniform usampler2D shTextureBuffer; + // We could also use a uniform to store transforms, but this would be more // limiting in terms of the # of groups we can have. uniform sampler2D textureT_camera_groups; @@ -97,6 +101,33 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( vec2 triu23 = unpackHalf2x16(intBufferData.y); vec2 triu45 = unpackHalf2x16(intBufferData.z); + // Get spherical harmonics terms from int buffer. 48 coefficents per vertex. + uint shTexStart = sortedIndex * 6u; + ivec2 shTexSize = textureSize(shTextureBuffer, 0); + float sh_coeffs_unpacked[48]; + for (int i = 0; i < 6; i++) { + ivec2 shTexPos = ivec2((shTexStart + uint(i)) % uint(shTexSize.x), (shTexStart + uint(i)) / uint(shTexSize.x)); + uvec4 packedCoeffs = texelFetch(shTextureBuffer, shTexPos, 0); + + // unpack each uint32 directly into two float16 values, we read 4 at a time + vec2 unpacked; + unpacked = unpackHalf2x16(packedCoeffs.x); + sh_coeffs_unpacked[i*8] = unpacked.x; + sh_coeffs_unpacked[i*8+1] = unpacked.y; + + unpacked = unpackHalf2x16(packedCoeffs.y); + sh_coeffs_unpacked[i*8+2] = unpacked.x; + sh_coeffs_unpacked[i*8+3] = unpacked.y; + + unpacked = unpackHalf2x16(packedCoeffs.z); + sh_coeffs_unpacked[i*8+4] = unpacked.x; + sh_coeffs_unpacked[i*8+5] = unpacked.y; + + unpacked = unpackHalf2x16(packedCoeffs.w); + sh_coeffs_unpacked[i*8+6] = unpacked.x; + sh_coeffs_unpacked[i*8+7] = unpacked.y; + } + // Transition in. float startTime = 0.8 * float(sortedIndex) / float(numGaussians); float cov_scale = smoothstep(startTime, startTime + 0.2, transitionInState); @@ -130,12 +161,112 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( vec2 v1 = min(sqrt(2.0 * lambda1), 1024.0) * diagonalVector; vec2 v2 = min(sqrt(2.0 * lambda2), 1024.0) * vec2(diagonalVector.y, -diagonalVector.x); - vRgba = vec4( - float(rgbaUint32 & uint(0xFF)) / 255.0, - float((rgbaUint32 >> uint(8)) & uint(0xFF)) / 255.0, - float((rgbaUint32 >> uint(16)) & uint(0xFF)) / 255.0, - float(rgbaUint32 >> uint(24)) / 255.0 + // Calculate the spherical harmonics. + // According to gsplat implementation, seems that "x" and "y" have opposite direction + // of conventional SH directions, so square brackets contains the sign of resulting variable + // multiplications. + // A comprehensible table of Real SH constants: + // https://en.wikipedia.org/wiki/Table_of_spherical_harmonics + vec3 viewDir = normalize(center - cameraPosition); + // C0 = 0.5 * sqrt(1.0 / pi) + const float C0 = 0.28209479177387814; + // C1[0] = sqrt(3.0 / (4.0 * pi)) * [-1] + // C1[1] = sqrt(3.0 / (4.0 * pi)) * [1] + // C1[2] = sqrt(3.0 / (4.0 * pi)) * [-1] + const float C1[3] = float[3]( + -0.4886025119029199, + 0.4886025119029199, + -0.4886025119029199 + ); + // C2[0] = 0.5 * sqrt(15/pi) * [1] + // C2[1] = 0.5 * sqrt(15/pi) * [-1] + // C2[2] = 0.25 * sqrt(5/pi) * [1] + // C2[3] = 0.5 * sqrt(15/pi) * [-1] + // C2[4] = 0.25 * sqrt(15/pi) * [1] + const float C2[5] = float[5]( + 1.0925484305920792, + -1.0925484305920792, + 0.31539156525252005, + -1.0925484305920792, + 0.5462742152960396 ); + // C3[0] = 0.25 * sqrt(35/(2pi)) * [-1] + // C3[1] = 0.5 * sqrt(105/pi) * [1] + // C3[2] = 0.25 * sqrt(21/(2pi)) * [-1] + // C3[3] = 0.25 * sqrt(7/pi) * [1] + // C3[4] = 0.25 * sqrt(21/(2pi)) * [-1] + // C3[5] = 0.25 * sqrt(105/(pi)) * [1] + // C3[6] = 0.25 * sqrt(35/(2pi)) * [-1] + const float C3[7] = float[7]( + -0.5900435899266435, + 2.890611442640554, + -0.4570457994644658, + 0.3731763325901154, + -0.4570457994644658, + 1.445305721320277, + -0.5900435899266435 + ); + + vec3 sh_coeffs[16]; + for (int i = 0; i < 16; i++) { + sh_coeffs[i] = vec3(sh_coeffs_unpacked[i*3], sh_coeffs_unpacked[i*3+1], sh_coeffs_unpacked[i*3+2]); + } + + // View-dependent variables + + float x = viewDir.x; + float y = viewDir.y; + float z = viewDir.z; + float xx = viewDir.x * viewDir.x; + float yy = viewDir.y * viewDir.y; + float zz = viewDir.z * viewDir.z; + float xy = viewDir.x * viewDir.y; + float yz = viewDir.y * viewDir.z; + float xz = viewDir.x * viewDir.z; + + // 0th degree + vec3 rgb = C0 * sh_coeffs[0]; + vec3 pointFive = vec3(0.5, 0.5, 0.5); + + // 1st degree + // From here, variables are included in multiplication with constants + float pSH1 = C1[0] * y; + float pSH2 = C1[1] * z; + float pSH3 = C1[2] * x; + rgb = rgb + pSH1 * sh_coeffs[1] + + pSH2 * sh_coeffs[2] + + pSH3 * sh_coeffs[3]; + + // 2nd degree + float pSH4 = C2[0] * xy; + float pSH5 = C2[1] * yz; + float pSH6 = C2[2] * (3.0 * zz - 1.0); + float pSH7 = C2[3] * xz; + float pSH8 = C2[4] * (xx - yy); + rgb = rgb + pSH4 * sh_coeffs[4] + + pSH5 * sh_coeffs[5] + + pSH6 * sh_coeffs[6] + + pSH7 * sh_coeffs[7] + + pSH8 * sh_coeffs[8]; + + // 3rd degree + float pSH9 = C3[0] * y * (3.0 * xx - yy); + float pSH10 = C3[1] * x * y * z; + float pSH11 = C3[2] * y * (5.0 * zz - 1.0); + float pSH12 = C3[3] * z * (5.0 * zz - 3.0); + float pSH13 = C3[4] * x * (5.0 * zz - 1.0); + float pSH14 = C3[5] * (xx - yy) * z; + float pSH15 = C3[6] * x * (xx - 3.0 * yy); + rgb = rgb + pSH9 * sh_coeffs[9] + + pSH10 * sh_coeffs[10] + + pSH11 * sh_coeffs[11] + + pSH12 * sh_coeffs[12] + + pSH13 * sh_coeffs[13] + + pSH14 * sh_coeffs[14] + + pSH15 * sh_coeffs[15]; + + // Finalize the color + vRgba = vec4(rgb + pointFive, float(rgbaUint32 >> uint(24)) / 255.0); // Throw the Gaussian off the screen if it's too close, too far, or too small. float weightedDeterminant = vRgba.a * (diag1 * diag2 - offDiag * offDiag); @@ -169,6 +300,7 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( /**Hook to generate properties for rendering Gaussians via a three.js mesh.*/ export function useGaussianMeshProps( gaussianBuffer: Uint32Array, + combinedSHBuffer: Uint32Array, numGroups: number, ) { const numGaussians = gaussianBuffer.length / 8; @@ -198,6 +330,8 @@ export function useGaussianMeshProps( geometry.setAttribute("sortedIndex", sortedIndexAttribute); // Create texture buffers. + // We store 4 floats and 4 int32s per Gaussian. + // One "numGaussians" corresponds to 4 32-bit values. const textureWidth = Math.min(numGaussians * 2, maxTextureSize); const textureHeight = Math.ceil((numGaussians * 2) / textureWidth); const bufferPadded = new Uint32Array(textureWidth * textureHeight * 4); @@ -223,6 +357,24 @@ export function useGaussianMeshProps( textureT_camera_groups.internalFormat = "RGBA32F"; textureT_camera_groups.needsUpdate = true; + // Values taken from PR https://github.com/nerfstudio-project/viser/pull/286/files + // WIDTH AND HEIGHT ARE MEASURED IN TEXELS + // As 48 x float16 = 96 bytes and each texel is 4 uint32s = 16 bytes + // We can fit 6 spherical harmonics coefficients in a single texel + const shTextureWidth = Math.min(numGaussians * 6, maxTextureSize); + const shTextureHeight = Math.ceil((numGaussians * 6) / shTextureWidth); + const shBufferPadded = new Uint32Array(shTextureWidth * shTextureHeight * 4); + shBufferPadded.set(combinedSHBuffer); + const shTextureBuffer = new THREE.DataTexture( + shBufferPadded, + shTextureWidth, + shTextureHeight, + THREE.RGBAIntegerFormat, + THREE.UnsignedIntType, + ); + shTextureBuffer.internalFormat = "RGBA32UI"; + shTextureBuffer.needsUpdate = true; + const material = new GaussianSplatMaterial(); material.textureBuffer = textureBuffer; material.textureT_camera_groups = textureT_camera_groups; @@ -233,6 +385,7 @@ export function useGaussianMeshProps( geometry, material, textureBuffer, + shTextureBuffer, sortedIndexAttribute, textureT_camera_groups, rowMajorT_camera_groups, diff --git a/src/viser/client/src/WebsocketMessages.ts b/src/viser/client/src/WebsocketMessages.ts index 62272d01a..c81ee77be 100644 --- a/src/viser/client/src/WebsocketMessages.ts +++ b/src/viser/client/src/WebsocketMessages.ts @@ -379,7 +379,10 @@ export interface CubicBezierSplineMessage { export interface GaussianSplatsMessage { type: "GaussianSplatsMessage"; name: string; - props: { buffer: Uint8Array }; + props: { + buffer: Uint8Array; + sh_buffer: Uint8Array; + }; } /** Remove a particular node from the scene. *