From 51d1c6fbc0359a601e07f766b184621885764989 Mon Sep 17 00:00:00 2001 From: Christian Le Date: Wed, 12 Feb 2025 12:57:32 +0100 Subject: [PATCH 01/14] initial branch commit --- src/viser/_scene_api.py | 62 ++++++++++++++++++++++++++++++----------- 1 file changed, 45 insertions(+), 17 deletions(-) diff --git a/src/viser/_scene_api.py b/src/viser/_scene_api.py index ffdf660b6..1eee60965 100644 --- a/src/viser/_scene_api.py +++ b/src/viser/_scene_api.py @@ -1294,6 +1294,8 @@ def add_gaussian_splats( covariances: np.ndarray, rgbs: np.ndarray, opacities: np.ndarray, + normals: np.ndarray | None = None, + sh_coeffs: np.ndarray | None = None, wxyz: Tuple[float, float, float, float] | np.ndarray = (1.0, 0.0, 0.0, 0.0), position: Tuple[float, float, float] | np.ndarray = (0.0, 0.0, 0.0), visible: bool = True, @@ -1309,6 +1311,8 @@ def add_gaussian_splats( covariances: Second moment for each Gaussian. (N, 3, 3). rgbs: Color for each Gaussian. (N, 3). opacities: Opacity for each Gaussian. (N, 1). + normals: Normals for each Gaussian. (N, 3). + sh_coeffs: Spherical harmonics coefficients for each Gaussian. (N, 48). wxyz: R_parent_local transformation. position: t_parent_local transformation. visible: Initial visibility of scene node. @@ -1321,26 +1325,50 @@ def add_gaussian_splats( assert rgbs.shape == (num_gaussians, 3) assert opacities.shape == (num_gaussians, 1) assert covariances.shape == (num_gaussians, 3, 3) + assert normals is None or normals.shape == (num_gaussians, 3) + assert sh_coeffs is None or sh_coeffs.shape == (num_gaussians, 48) # Get upper-triangular terms of covariance matrix. cov_triu = covariances.reshape((-1, 9))[:, np.array([0, 1, 2, 4, 5, 8])] - buffer = np.concatenate( - [ - # First texelFetch. - # - xyz (96 bits): centers. - centers.astype(np.float32).view(np.uint8), - # - w (32 bits): this is reserved for use by the renderer. - np.zeros((num_gaussians, 4), dtype=np.uint8), - # Second texelFetch. - # - xyz (96 bits): upper-triangular terms of covariance. - cov_triu.astype(np.float16).copy().view(np.uint8), - # - w (32 bits): rgba. - colors_to_uint8(rgbs), - colors_to_uint8(opacities), - ], - axis=-1, - ).view(np.uint32) - assert buffer.shape == (num_gaussians, 8) + + if sh_coeffs and normals: + buffer = np.concatenate( + [ + # First texelFetch. + # - xyz (96 bits): centers. + centers.astype(np.float32).view(np.uint8), + # - w (32 bits): this is reserved for use by the renderer. + np.zeros((num_gaussians, 4), dtype=np.uint8), + # Second texelFetch. + # - xyz (96 bits): upper-triangular terms of covariance. + cov_triu.astype(np.float16).copy().view(np.uint8), + # - w (32 bits): rgba. + np.zeros((num_gaussians, 1), dtype=np.uint8), + colors_to_uint8(opacities), + # - w (56-bit padding). + + ], + axis=-1, + ).view(np.uint32) + assert buffer.shape == (num_gaussians, 8) # TODO: Change assertion criteria. + else: + buffer = np.concatenate( + [ + # First texelFetch. + # - xyz (96 bits): centers. + centers.astype(np.float32).view(np.uint8), + # - w (32 bits): this is reserved for use by the renderer. + np.zeros((num_gaussians, 4), dtype=np.uint8), + # Second texelFetch. + # - xyz (96 bits): upper-triangular terms of covariance. + cov_triu.astype(np.float16).copy().view(np.uint8), + # - w (32 bits): rgba. + colors_to_uint8(rgbs), + colors_to_uint8(opacities), + ], + axis=-1, + ).view(np.uint32) + assert buffer.shape == (num_gaussians, 8) message = _messages.GaussianSplatsMessage( name=name, From ed7156434305b0eb2b34d7d5e047bf58de63e5e1 Mon Sep 17 00:00:00 2001 From: Christian Le Date: Wed, 12 Feb 2025 13:04:10 +0100 Subject: [PATCH 02/14] format a bigger message for SH coeffs --- src/viser/_scene_api.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/viser/_scene_api.py b/src/viser/_scene_api.py index 1eee60965..32db83e87 100644 --- a/src/viser/_scene_api.py +++ b/src/viser/_scene_api.py @@ -1342,15 +1342,15 @@ def add_gaussian_splats( # Second texelFetch. # - xyz (96 bits): upper-triangular terms of covariance. cov_triu.astype(np.float16).copy().view(np.uint8), - # - w (32 bits): rgba. - np.zeros((num_gaussians, 1), dtype=np.uint8), + # - w (32 bits): normals + alphas. + colors_to_uint8(normals), colors_to_uint8(opacities), - # - w (56-bit padding). - + # - SH (768 bits). + sh_coeffs.astype(np.float16).copy().view(np.uint8) ], axis=-1, ).view(np.uint32) - assert buffer.shape == (num_gaussians, 8) # TODO: Change assertion criteria. + assert buffer.shape == (num_gaussians, 32) # TODO: Change assertion criteria. else: buffer = np.concatenate( [ From 680e78fd870027997f7aab802918e758cdc66fa1 Mon Sep 17 00:00:00 2001 From: Christian Le Date: Wed, 12 Feb 2025 15:00:09 +0100 Subject: [PATCH 03/14] add new gsplat message properties definitions --- src/viser/_messages.py | 2 ++ src/viser/client/src/SceneTree.tsx | 18 ++++++++++++++++++ src/viser/client/src/WebsocketMessages.ts | 2 +- 3 files changed, 21 insertions(+), 1 deletion(-) diff --git a/src/viser/_messages.py b/src/viser/_messages.py index db5a68722..8c34646bc 100644 --- a/src/viser/_messages.py +++ b/src/viser/_messages.py @@ -1317,6 +1317,8 @@ class GaussianSplatsProps: # Memory layout is borrowed from: # https://github.com/antimatter15/splat buffer: npt.NDArray[np.uint32] + sh_buffer: npt.NDArray[np.uint32] + norm_buffer: npt.NDArray[np.uint32] """Our buffer will contain: - x as f32 - y as f32 diff --git a/src/viser/client/src/SceneTree.tsx b/src/viser/client/src/SceneTree.tsx index 69a1d1008..e2c45261b 100644 --- a/src/viser/client/src/SceneTree.tsx +++ b/src/viser/client/src/SceneTree.tsx @@ -491,6 +491,24 @@ function useObjectFactory(message: SceneNodeMessage | undefined): { ), ) } + sh_buffer={ + new Uint32Array( + message.props.sh_buffer.buffer.slice( + message.props.sh_buffer.byteOffset, + message.props.sh_buffer.byteOffset + + message.props.sh_buffer.byteLength, + ), + ) + } + norm_buffer={ + new Uint32Array( + message.props.norm_buffer.buffer.slice( + message.props.norm_buffer.byteOffset, + message.props.norm_buffer.byteOffset + + message.props.norm_buffer.byteLength, + ), + ) + } /> ), }; diff --git a/src/viser/client/src/WebsocketMessages.ts b/src/viser/client/src/WebsocketMessages.ts index 6bb26f1d3..c89519b31 100644 --- a/src/viser/client/src/WebsocketMessages.ts +++ b/src/viser/client/src/WebsocketMessages.ts @@ -319,7 +319,7 @@ export interface CubicBezierSplineMessage { export interface GaussianSplatsMessage { type: "GaussianSplatsMessage"; name: string; - props: { buffer: Uint8Array }; + props: { buffer: Uint8Array, sh_buffer: Uint8Array, norm_buffer: Uint8Array }; } /** Remove a particular node from the scene. * From 6713f5f58f1c76bed3044b8ef2590bab52bb6e3a Mon Sep 17 00:00:00 2001 From: Christian Le Date: Wed, 12 Feb 2025 15:00:24 +0100 Subject: [PATCH 04/14] add new gsplat message properties definitions 2.0 --- src/viser/_scene_api.py | 64 ++++++++++++++++++++--------------------- 1 file changed, 31 insertions(+), 33 deletions(-) diff --git a/src/viser/_scene_api.py b/src/viser/_scene_api.py index 32db83e87..02ce7c3b3 100644 --- a/src/viser/_scene_api.py +++ b/src/viser/_scene_api.py @@ -1325,55 +1325,53 @@ def add_gaussian_splats( assert rgbs.shape == (num_gaussians, 3) assert opacities.shape == (num_gaussians, 1) assert covariances.shape == (num_gaussians, 3, 3) - assert normals is None or normals.shape == (num_gaussians, 3) - assert sh_coeffs is None or sh_coeffs.shape == (num_gaussians, 48) + # Get upper-triangular terms of covariance matrix. cov_triu = covariances.reshape((-1, 9))[:, np.array([0, 1, 2, 4, 5, 8])] - if sh_coeffs and normals: - buffer = np.concatenate( + buffer = np.concatenate( + [ + # First texelFetch. + # - xyz (96 bits): centers. + centers.astype(np.float32).view(np.uint8), + # - w (32 bits): this is reserved for use by the renderer. + np.zeros((num_gaussians, 4), dtype=np.uint8), + # Second texelFetch. + # - xyz (96 bits): upper-triangular terms of covariance. + cov_triu.astype(np.float16).copy().view(np.uint8), + # - w (32 bits): rgba. + colors_to_uint8(rgbs), + colors_to_uint8(opacities), + ], + axis=-1, + ).view(np.uint32) + assert buffer.shape == (num_gaussians, 8) + + if sh_coeffs is not None and normals is not None: + sh_buffer = np.concatenate( [ - # First texelFetch. - # - xyz (96 bits): centers. - centers.astype(np.float32).view(np.uint8), - # - w (32 bits): this is reserved for use by the renderer. - np.zeros((num_gaussians, 4), dtype=np.uint8), - # Second texelFetch. - # - xyz (96 bits): upper-triangular terms of covariance. - cov_triu.astype(np.float16).copy().view(np.uint8), - # - w (32 bits): normals + alphas. - colors_to_uint8(normals), - colors_to_uint8(opacities), - # - SH (768 bits). - sh_coeffs.astype(np.float16).copy().view(np.uint8) + sh_coeffs.astype(np.float16).copy().view(np.uint8) ], - axis=-1, ).view(np.uint32) - assert buffer.shape == (num_gaussians, 32) # TODO: Change assertion criteria. else: - buffer = np.concatenate( + sh_buffer = np.empty((0,), dtype=np.uint32) + + if normals is not None: + norm_buffer = np.concatenate( [ - # First texelFetch. - # - xyz (96 bits): centers. - centers.astype(np.float32).view(np.uint8), - # - w (32 bits): this is reserved for use by the renderer. - np.zeros((num_gaussians, 4), dtype=np.uint8), - # Second texelFetch. - # - xyz (96 bits): upper-triangular terms of covariance. - cov_triu.astype(np.float16).copy().view(np.uint8), - # - w (32 bits): rgba. - colors_to_uint8(rgbs), - colors_to_uint8(opacities), + normals.astype(np.float32).view(np.uint8) ], - axis=-1, ).view(np.uint32) - assert buffer.shape == (num_gaussians, 8) + else: + norm_buffer = np.empty((0,), dtype=np.uint32) message = _messages.GaussianSplatsMessage( name=name, props=_messages.GaussianSplatsProps( buffer=buffer, + sh_buffer=sh_buffer, + norm_buffer=norm_buffer, ), ) node_handle = GaussianSplatHandle._make( From 3be0cd3f3afd5a7145d38c441120dd19c3e93a9b Mon Sep 17 00:00:00 2001 From: Christian Le Date: Wed, 12 Feb 2025 15:01:27 +0100 Subject: [PATCH 05/14] partial progress SH and norm addition on client end --- .../client/src/Splatting/GaussianSplats.tsx | 2 + .../src/Splatting/GaussianSplatsHelpers.ts | 72 +++++++++++++++++++ 2 files changed, 74 insertions(+) diff --git a/src/viser/client/src/Splatting/GaussianSplats.tsx b/src/viser/client/src/Splatting/GaussianSplats.tsx index a5387d545..b6f7a22f9 100644 --- a/src/viser/client/src/Splatting/GaussianSplats.tsx +++ b/src/viser/client/src/Splatting/GaussianSplats.tsx @@ -61,6 +61,8 @@ export const SplatObject = React.forwardRef< THREE.Group, { buffer: Uint32Array; + sh_buffer: Uint32Array; + norm_buffer: Uint32Array; } >(function SplatObject({ buffer }, ref) { const splatContext = React.useContext(GaussianSplatsContext)!; diff --git a/src/viser/client/src/Splatting/GaussianSplatsHelpers.ts b/src/viser/client/src/Splatting/GaussianSplatsHelpers.ts index 81b6e315f..86dbba7b9 100644 --- a/src/viser/client/src/Splatting/GaussianSplatsHelpers.ts +++ b/src/viser/client/src/Splatting/GaussianSplatsHelpers.ts @@ -15,7 +15,9 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( depthTest: true, depthWrite: false, transparent: true, + sh_degree: 0, textureBuffer: null, + shTextureBuffer: null, textureT_camera_groups: null, transitionInState: 0.0, }, @@ -29,6 +31,12 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( // copy quadjr for this. uniform usampler2D textureBuffer; + // NEW ADDITION******************************************************* + // Buffer for spherical harmonics; Each Gaussian gets 24 int32s representing + // this information. + uniform usampler2D shTextureBuffer; + // END NEW ADDITION*************************************************** + // We could also use a uniform to store transforms, but this would be more // limiting in terms of the # of groups we can have. uniform sampler2D textureT_camera_groups; @@ -39,6 +47,7 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( uniform vec2 viewport; uniform float near; uniform float far; + uniform uint sh_degree; // Fade in state between [0, 1]. uniform float transitionInState; @@ -97,6 +106,37 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( vec2 triu23 = unpackHalf2x16(intBufferData.y); vec2 triu45 = unpackHalf2x16(intBufferData.z); + // NEW ADDITION******************************************************* + // Get spherical harmonics terms from int buffer. 48 coefficents per vertex. + uint shTexStart = sortedIndex * 6u + ivec2 shTexSize = textureSize(shTextureBuffer, 0); + float sh_coeffs_unpacked[48]; + for (int i = 0; i < 6; i++) { + ivec2 shTexPos = ivec2((shTexStart + uint(i)) % uint(shTexSize.x), + (shTexStart + uint(i)) / uint(shTexSize.x)); + uvec4 packedCoeffs = texelFetch(shTextureBuffer, shTexPos, 0); + + // unpack each uint32 directly into two float16 values, we read 4 at a time + vec2 unpacked; + unpacked = unpackHalf2x16(packedCoeffs.x); + sh_coeffs_unpacked[i*8] = unpacked.x; + sh_coeffs_unpacked[i*8+1] = unpacked.y; + + unpacked = unpackHalf2x16(packedCoeffs.y); + sh_coeffs_unpacked[i*8+2] = unpacked.x; + sh_coeffs_unpacked[i*8+3] = unpacked.y; + + unpacked = unpackHalf2x16(packedCoeffs.z); + sh_coeffs_unpacked[i*8+4] = unpacked.x; + sh_coeffs_unpacked[i*8+5] = unpacked.y; + + unpacked = unpackHalf2x16(packedCoeffs.w); + sh_coeffs_unpacked[i*8+6] = unpacked.x; + sh_coeffs_unpacked[i*8+7] = unpacked.y; + } + // TODO: CONTINUE, NOT YET FINISHED + // END NEW ADDITION*************************************************** + // Transition in. float startTime = 0.8 * float(sortedIndex) / float(numGaussians); float cov_scale = smoothstep(startTime, startTime + 0.2, transitionInState); @@ -212,6 +252,36 @@ export function useGaussianMeshProps( textureBuffer.internalFormat = "RGBA32UI"; textureBuffer.needsUpdate = true; + // NEW ADDITION******************************************************* + // Create texture buffers for spherical harmonics. + const shBufferPadded = new Uint32Array(textureWidth * textureHeight * 4); + shBufferPadded.set(shBuffer); + const shTextureBuffer = new THREE.DataTexture( + shBufferPadded, + textureWidth, + textureHeight, + THREE.RGBAIntegerFormat, + THREE.UnsignedIntType, + ); + shTextureBuffer.internalFormat = "RGBA32UI"; + shTextureBuffer.needsUpdate = true; + // END NEW ADDITION*************************************************** + + // NEW ADDITION******************************************************* + // Create texture buffers for normals. + const normBufferPadded = new Uint32Array(textureWidth * textureHeight * 4); + normBufferPadded.set(normBuffer); + const normTextureBuffer = new THREE.DataTexture( + normBufferPadded, + textureWidth, + textureHeight, + THREE.RGBAIntegerFormat, + THREE.UnsignedIntType, + ); + normTextureBuffer.internalFormat = "RGBA32UI"; + normTextureBuffer.needsUpdate = true; + // END NEW ADDITION*************************************************** + const rowMajorT_camera_groups = new Float32Array(numGroups * 12); const textureT_camera_groups = new THREE.DataTexture( rowMajorT_camera_groups, @@ -235,6 +305,8 @@ export function useGaussianMeshProps( geometry, material, textureBuffer, + shTextureBuffer, // NEW ADDITION******************************************************* + normTextureBuffer, // NEW ADDITION*************************************************** sortedIndexAttribute, textureT_camera_groups, rowMajorT_camera_groups, From cad878a4d0379314b220b5efa8ae41d5470172ba Mon Sep 17 00:00:00 2001 From: Christian Le Date: Wed, 12 Feb 2025 21:45:12 +0100 Subject: [PATCH 06/14] update progress on SH implementation --- src/viser/_scene_api.py | 7 +- .../client/src/Splatting/GaussianSplats.tsx | 62 ++++++++++++-- .../src/Splatting/GaussianSplatsHelpers.ts | 80 +++++++++++++------ src/viser/client/src/WebsocketMessages.ts | 6 +- 4 files changed, 123 insertions(+), 32 deletions(-) diff --git a/src/viser/_scene_api.py b/src/viser/_scene_api.py index 02ce7c3b3..74c99dc13 100644 --- a/src/viser/_scene_api.py +++ b/src/viser/_scene_api.py @@ -1325,7 +1325,6 @@ def add_gaussian_splats( assert rgbs.shape == (num_gaussians, 3) assert opacities.shape == (num_gaussians, 1) assert covariances.shape == (num_gaussians, 3, 3) - # Get upper-triangular terms of covariance matrix. cov_triu = covariances.reshape((-1, 9))[:, np.array([0, 1, 2, 4, 5, 8])] @@ -1349,22 +1348,24 @@ def add_gaussian_splats( assert buffer.shape == (num_gaussians, 8) if sh_coeffs is not None and normals is not None: + assert sh_coeffs.shape == (num_gaussians, 48) sh_buffer = np.concatenate( [ sh_coeffs.astype(np.float16).copy().view(np.uint8) ], ).view(np.uint32) else: - sh_buffer = np.empty((0,), dtype=np.uint32) + sh_buffer = np.zeros((num_gaussians, 48), dtype=np.uint32) if normals is not None: + assert normals.shape == (num_gaussians, 3) norm_buffer = np.concatenate( [ normals.astype(np.float32).view(np.uint8) ], ).view(np.uint32) else: - norm_buffer = np.empty((0,), dtype=np.uint32) + norm_buffer = np.zeros((num_gaussians, 3), dtype=np.uint32) message = _messages.GaussianSplatsMessage( name=name, diff --git a/src/viser/client/src/Splatting/GaussianSplats.tsx b/src/viser/client/src/Splatting/GaussianSplats.tsx index b6f7a22f9..ba45242d2 100644 --- a/src/viser/client/src/Splatting/GaussianSplats.tsx +++ b/src/viser/client/src/Splatting/GaussianSplats.tsx @@ -64,7 +64,7 @@ export const SplatObject = React.forwardRef< sh_buffer: Uint32Array; norm_buffer: Uint32Array; } ->(function SplatObject({ buffer }, ref) { +>(function SplatObject({ buffer, sh_buffer, norm_buffer }, ref) { const splatContext = React.useContext(GaussianSplatsContext)!; const setBuffer = splatContext.useGaussianSplatStore( (state) => state.setBuffer, @@ -76,11 +76,18 @@ export const SplatObject = React.forwardRef< (state) => state.nodeRefFromId, ); const name = React.useMemo(() => uuidv4(), [buffer]); + // Inspired from PR, maybe should look more similar to the original code above + const sh_buffer_name = "sh_buffer_" + name; + const norm_buffer_name = "norm_buffer_" + name; React.useEffect(() => { setBuffer(name, buffer); + setBuffer(sh_buffer_name, sh_buffer); + setBuffer(norm_buffer_name, norm_buffer); return () => { removeBuffer(name); + removeBuffer(sh_buffer_name); + removeBuffer(norm_buffer_name); delete nodeRefFromId.current[name]; }; }, [buffer]); @@ -131,6 +138,8 @@ function SplatRendererImpl() { const merged = mergeGaussianGroups(groupBufferFromId); const meshProps = useGaussianMeshProps( merged.gaussianBuffer, + merged.combinedSHBuffer, + merged.combinedNormBuffer, merged.numGroups, ); splatContext.meshPropsRef.current = meshProps; @@ -148,6 +157,8 @@ function SplatRendererImpl() { if (!initializedBufferTexture) { meshProps.material.uniforms.numGaussians.value = merged.numGaussians; meshProps.textureBuffer.needsUpdate = true; + meshProps.shTextureBuffer.needsUpdate = true; + meshProps.normTextureBuffer.needsUpdate = true; initializedBufferTexture = true; } }; @@ -163,6 +174,8 @@ function SplatRendererImpl() { React.useEffect(() => { return () => { meshProps.textureBuffer.dispose(); + meshProps.shTextureBuffer.dispose(); + meshProps.normTextureBuffer.dispose(); meshProps.geometry.dispose(); meshProps.material.dispose(); postToWorker({ close: true }); @@ -338,7 +351,13 @@ function mergeGaussianGroups(groupBufferFromName: { }) { // Create geometry. Each Gaussian will be rendered as a quad. let totalBufferLength = 0; - for (const buffer of Object.values(groupBufferFromName)) { + const groupBufferFromNameFiltered = Object.fromEntries( + Object.entries(groupBufferFromName).filter( + ([key]) => !key.startsWith("sh_buffer_") && !key.startsWith("norm_buffer_") + ) + ); + + for (const buffer of Object.values(groupBufferFromNameFiltered)) { totalBufferLength += buffer.length; } const numGaussians = totalBufferLength / 8; @@ -347,7 +366,7 @@ function mergeGaussianGroups(groupBufferFromName: { let offset = 0; for (const [groupIndex, groupBuffer] of Object.values( - groupBufferFromName, + groupBufferFromNameFiltered, ).entries()) { groupIndices.fill( groupIndex, @@ -368,6 +387,39 @@ function mergeGaussianGroups(groupBufferFromName: { offset += groupBuffer.length; } - const numGroups = Object.keys(groupBufferFromName).length; - return { numGaussians, gaussianBuffer, numGroups, groupIndices }; + let totalSHBufferLength = 0; + + const shGaussianBuffers = Object.fromEntries( + Object.entries(groupBufferFromName).filter(([key]) => key.startsWith("sh_buffer_")) + ); + + for (const sh_buffer of Object.values(shGaussianBuffers)) { + totalSHBufferLength += sh_buffer.length; + } + + const combinedSHBuffer = new Uint32Array(totalSHBufferLength); + let sh_offset = 0; + for (const sh_buffer of Object.values(shGaussianBuffers)) { + combinedSHBuffer.set(sh_buffer, sh_offset); + sh_offset += sh_buffer.length; + } + + const normGaussianBuffers = Object.fromEntries( + Object.entries(groupBufferFromName).filter(([key]) => key.startsWith("norm_buffer_")) + ); + + for (const norm_buffer of Object.values(normGaussianBuffers)) { + totalSHBufferLength += norm_buffer.length; + } + + const combinedNormBuffer = new Uint32Array(totalSHBufferLength); + let norm_offset = 0; + for (const norm_buffer of Object.values(normGaussianBuffers)) { + combinedNormBuffer.set(norm_buffer, norm_offset); + norm_offset += norm_buffer.length; + } + + const numGroups = Object.keys(groupBufferFromNameFiltered).length; + + return { numGaussians, gaussianBuffer, numGroups, groupIndices, combinedSHBuffer, combinedNormBuffer }; } diff --git a/src/viser/client/src/Splatting/GaussianSplatsHelpers.ts b/src/viser/client/src/Splatting/GaussianSplatsHelpers.ts index 86dbba7b9..6b7d7191c 100644 --- a/src/viser/client/src/Splatting/GaussianSplatsHelpers.ts +++ b/src/viser/client/src/Splatting/GaussianSplatsHelpers.ts @@ -15,9 +15,9 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( depthTest: true, depthWrite: false, transparent: true, - sh_degree: 0, textureBuffer: null, shTextureBuffer: null, + normTextureBuffer: null, textureT_camera_groups: null, transitionInState: 0.0, }, @@ -170,6 +170,33 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( vec2 v1 = min(sqrt(2.0 * lambda1), 1024.0) * diagonalVector; vec2 v2 = min(sqrt(2.0 * lambda2), 1024.0) * vec2(diagonalVector.y, -diagonalVector.x); + // NEW ADDITION******************************************************* + // Calculate the spherical harmonics. + vec3 viewDir = normalize(center - cameraPosition); + const float C0 = 0.28209479177387814; + const float C1 = 0.4886025119029199; + const float C2[5] = float[5]( + 1.0925484305920792, + -1.0925484305920792, + 0.31539156525252005, + -1.0925484305920792, + 0.5462742152960396 + ); + const float C3[7] = float[7]( + -0.5900435899266435, + 2.890611442640554, + -0.4570457994644658, + 0.3731763325901154, + -0.4570457994644658, + 1.445305721320277, + -0.5900435899266435 + ); + + vec3 sh_coeffs[16]; + for (int i = 0; i < 16; i++) { + sh_coeffs[i] = vec3(0.0); + } + vRgba = vec4( float(rgbaUint32 & uint(0xFF)) / 255.0, float((rgbaUint32 >> uint(8)) & uint(0xFF)) / 255.0, @@ -209,6 +236,8 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( /**Hook to generate properties for rendering Gaussians via a three.js mesh.*/ export function useGaussianMeshProps( gaussianBuffer: Uint32Array, + combinedSHBuffer: Uint32Array, // NEW ADDITION******************************************************* + combinedNormBuffer: Uint32Array, // NEW ADDITION*************************************************** numGroups: number, ) { const numGaussians = gaussianBuffer.length / 8; @@ -252,50 +281,55 @@ export function useGaussianMeshProps( textureBuffer.internalFormat = "RGBA32UI"; textureBuffer.needsUpdate = true; + const rowMajorT_camera_groups = new Float32Array(numGroups * 12); + const textureT_camera_groups = new THREE.DataTexture( + rowMajorT_camera_groups, + (numGroups * 12) / 4, + 1, + THREE.RGBAFormat, + THREE.FloatType, + ); + textureT_camera_groups.internalFormat = "RGBA32F"; + textureT_camera_groups.needsUpdate = true; + // NEW ADDITION******************************************************* - // Create texture buffers for spherical harmonics. - const shBufferPadded = new Uint32Array(textureWidth * textureHeight * 4); - shBufferPadded.set(shBuffer); + // Values taken from PR https://github.com/nerfstudio-project/viser/pull/286/files + const shTextureWidth = Math.min(numGaussians * 6, maxTextureSize); + const shTextureHeight = Math.ceil((numGaussians * 6) / shTextureWidth); + const shBufferPadded = new Uint32Array(shTextureWidth * shTextureHeight * 4); + shBufferPadded.set(combinedSHBuffer); const shTextureBuffer = new THREE.DataTexture( shBufferPadded, - textureWidth, - textureHeight, + shTextureWidth, + shTextureHeight, THREE.RGBAIntegerFormat, THREE.UnsignedIntType, ); shTextureBuffer.internalFormat = "RGBA32UI"; shTextureBuffer.needsUpdate = true; - // END NEW ADDITION*************************************************** - // NEW ADDITION******************************************************* - // Create texture buffers for normals. - const normBufferPadded = new Uint32Array(textureWidth * textureHeight * 4); - normBufferPadded.set(normBuffer); + const normTexturwWidth = Math.min(numGaussians * 6, maxTextureSize); + const normTextureHeight = Math.ceil((numGaussians * 6) / normTexturwWidth); + const normBufferPadded = new Uint32Array(normTexturwWidth * normTextureHeight * 4); + normBufferPadded.set(combinedNormBuffer); const normTextureBuffer = new THREE.DataTexture( normBufferPadded, - textureWidth, - textureHeight, + normTexturwWidth, + normTextureHeight, THREE.RGBAIntegerFormat, THREE.UnsignedIntType, ); normTextureBuffer.internalFormat = "RGBA32UI"; normTextureBuffer.needsUpdate = true; + // END NEW ADDITION*************************************************** - const rowMajorT_camera_groups = new Float32Array(numGroups * 12); - const textureT_camera_groups = new THREE.DataTexture( - rowMajorT_camera_groups, - (numGroups * 12) / 4, - 1, - THREE.RGBAFormat, - THREE.FloatType, - ); - textureT_camera_groups.internalFormat = "RGBA32F"; - textureT_camera_groups.needsUpdate = true; const material = new GaussianSplatMaterial({ // @ts-ignore textureBuffer: textureBuffer, + shTextureBuffer: shTextureBuffer, // NEW ADDITION******************************************************* + normTextureBuffer: normTextureBuffer, // NEW ADDITION*************************************************** textureT_camera_groups: textureT_camera_groups, numGaussians: 0, transitionInState: 0.0, diff --git a/src/viser/client/src/WebsocketMessages.ts b/src/viser/client/src/WebsocketMessages.ts index c89519b31..28b96d0b5 100644 --- a/src/viser/client/src/WebsocketMessages.ts +++ b/src/viser/client/src/WebsocketMessages.ts @@ -319,7 +319,11 @@ export interface CubicBezierSplineMessage { export interface GaussianSplatsMessage { type: "GaussianSplatsMessage"; name: string; - props: { buffer: Uint8Array, sh_buffer: Uint8Array, norm_buffer: Uint8Array }; + props: { + buffer: Uint8Array; + sh_buffer: Uint8Array; + norm_buffer: Uint8Array; + }; } /** Remove a particular node from the scene. * From 4f35d207f1d87cbd1104637b3e988f99f425ecf4 Mon Sep 17 00:00:00 2001 From: Christian Le Date: Thu, 13 Feb 2025 01:36:59 +0100 Subject: [PATCH 07/14] add temporary SH rendering equations --- .../src/Splatting/GaussianSplatsHelpers.ts | 60 ++++++++++++++++--- 1 file changed, 52 insertions(+), 8 deletions(-) diff --git a/src/viser/client/src/Splatting/GaussianSplatsHelpers.ts b/src/viser/client/src/Splatting/GaussianSplatsHelpers.ts index 6b7d7191c..367e65e67 100644 --- a/src/viser/client/src/Splatting/GaussianSplatsHelpers.ts +++ b/src/viser/client/src/Splatting/GaussianSplatsHelpers.ts @@ -194,15 +194,59 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( vec3 sh_coeffs[16]; for (int i = 0; i < 16; i++) { - sh_coeffs[i] = vec3(0.0); + sh_coeffs[i] = vec3(sh_coeffs_unpacked[i*3], sh_coeffs_unpacked[i*3+1], sh_coeffs_unpacked[i*3+2]); } - - vRgba = vec4( - float(rgbaUint32 & uint(0xFF)) / 255.0, - float((rgbaUint32 >> uint(8)) & uint(0xFF)) / 255.0, - float((rgbaUint32 >> uint(16)) & uint(0xFF)) / 255.0, - float(rgbaUint32 >> uint(24)) / 255.0 - ); + + float x = viewDir.x; + float y = viewDir.y; + float z = viewDir.z; + float xx = viewDir.x * viewDir.x; + float yy = viewDir.y * viewDir.y; + float zz = viewDir.z * viewDir.z; + float xy = viewDir.x * viewDir.y; + float yz = viewDir.y * viewDir.z; + float xz = viewDir.x * viewDir.z; + + // 0th degree, standard RGB + vec3 rgb = C0 * sh_coeffs[0]; + vec3 pointFive = vec3(0.5, 0.5, 0.5); + + // This is taken from gsplat + // 1st degree + rgb = rgb + C1 * (-y * sh_coeffs[1] + z * sh_coeffs[2] - x * sh_coeffs[3]); + + // 2nd degree + float fTmp0B = -1.0925484305920792 * z; // Reuse the constants + float fC1 = xx - yy; + float fS1 = 2.0 * xy; + float pSH5 = fTmp0B * y; + float pSH6 = (0.9461746957575601 * zz - 0.3153915652525201); + float pSH7 = fTmp0B * x; + float pSH8 = 0.5462742152960395 * fC1; + + rgb = rgb + pSH4 * sh_coeffs[4] + pSH5 * sh_coeffs[5] + pSH6 * sh_coeffs[6] + pSH7 * sh_coeffs[7] + pSH8 * sh_coeffs[8]; + + // 3rd degree + float fTmp0C = -2.285228997322329f * zz + 0.4570457994644658f; + float fTmp1B = 1.445305721320277 * z; + float fC2 = x * fC1 - y * fS1; + float fS2 = x * fS1 + y * fC1; + float pSH12 = z * (1.865881662950577 * zz - 1.119528997770346f); + float pSH13 = fTmp0C * x; + float pSH11 = fTmp0C * y; + float pSH14 = fTmp1B * fC1; + float pSH10 = fTmp1B * fS1; + float pSH15 = -0.5900435899266435 * fC2; + float pSH9 = -0.5900435899266435 * fS2; + + rgb = rgb + pSH9 * sh_coeffs[9] + pSH10 * sh_coeffs[10] + + pSH11 * sh_coeffs[11] + pSH12 * sh_coeffs[12] + + pSH13 * sh_coeffs[13] + pSH14 * sh_coeffs[14]+ + pSH15 * sh_coeffs[15]; + + vRgba = vec4(rgb + pointFive, float(rgbaUint32 >> 24) / 255.0); + // INRIA IMPLEMENTATION NOT INCLUDED + // END NEW ADDITION*************************************************** // Throw the Gaussian off the screen if it's too close, too far, or too small. float weightedDeterminant = vRgba.a * (diag1 * diag2 - offDiag * offDiag); From 7babd950ed670b26688ae85a233f5d11a0aaba44 Mon Sep 17 00:00:00 2001 From: Christian Le Date: Thu, 13 Feb 2025 12:41:41 +0100 Subject: [PATCH 08/14] minor tweaks --- .../client/src/Splatting/GaussianSplats.tsx | 14 +++-- .../src/Splatting/GaussianSplatsHelpers.ts | 62 ++++++++++++------- 2 files changed, 48 insertions(+), 28 deletions(-) diff --git a/src/viser/client/src/Splatting/GaussianSplats.tsx b/src/viser/client/src/Splatting/GaussianSplats.tsx index ba45242d2..f54f78afa 100644 --- a/src/viser/client/src/Splatting/GaussianSplats.tsx +++ b/src/viser/client/src/Splatting/GaussianSplats.tsx @@ -77,8 +77,8 @@ export const SplatObject = React.forwardRef< ); const name = React.useMemo(() => uuidv4(), [buffer]); // Inspired from PR, maybe should look more similar to the original code above - const sh_buffer_name = "sh_buffer_" + name; - const norm_buffer_name = "norm_buffer_" + name; + const sh_buffer_name = `sh_buffer_${name}`; + const norm_buffer_name = `norm_buffer_${name}`; React.useEffect(() => { setBuffer(name, buffer); @@ -357,7 +357,9 @@ function mergeGaussianGroups(groupBufferFromName: { ) ); - for (const buffer of Object.values(groupBufferFromNameFiltered)) { + // Temporarily using groupBufferFromName instead of groupBufferFromNameFiltered + // because it causes an error in the shader + for (const buffer of Object.values(groupBufferFromName)) { totalBufferLength += buffer.length; } const numGaussians = totalBufferLength / 8; @@ -404,15 +406,17 @@ function mergeGaussianGroups(groupBufferFromName: { sh_offset += sh_buffer.length; } + let totalNormBufferLength = 0; + const normGaussianBuffers = Object.fromEntries( Object.entries(groupBufferFromName).filter(([key]) => key.startsWith("norm_buffer_")) ); for (const norm_buffer of Object.values(normGaussianBuffers)) { - totalSHBufferLength += norm_buffer.length; + totalNormBufferLength += norm_buffer.length; } - const combinedNormBuffer = new Uint32Array(totalSHBufferLength); + const combinedNormBuffer = new Uint32Array(totalNormBufferLength); let norm_offset = 0; for (const norm_buffer of Object.values(normGaussianBuffers)) { combinedNormBuffer.set(norm_buffer, norm_offset); diff --git a/src/viser/client/src/Splatting/GaussianSplatsHelpers.ts b/src/viser/client/src/Splatting/GaussianSplatsHelpers.ts index 367e65e67..b4af533a0 100644 --- a/src/viser/client/src/Splatting/GaussianSplatsHelpers.ts +++ b/src/viser/client/src/Splatting/GaussianSplatsHelpers.ts @@ -108,12 +108,11 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( // NEW ADDITION******************************************************* // Get spherical harmonics terms from int buffer. 48 coefficents per vertex. - uint shTexStart = sortedIndex * 6u + uint shTexStart = sortedIndex * 6u; ivec2 shTexSize = textureSize(shTextureBuffer, 0); float sh_coeffs_unpacked[48]; for (int i = 0; i < 6; i++) { - ivec2 shTexPos = ivec2((shTexStart + uint(i)) % uint(shTexSize.x), - (shTexStart + uint(i)) / uint(shTexSize.x)); + ivec2 shTexPos = ivec2((shTexStart + uint(i)) % uint(shTexSize.x), (shTexStart + uint(i)) / uint(shTexSize.x)); uvec4 packedCoeffs = texelFetch(shTextureBuffer, shTexPos, 0); // unpack each uint32 directly into two float16 values, we read 4 at a time @@ -134,19 +133,22 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( sh_coeffs_unpacked[i*8+6] = unpacked.x; sh_coeffs_unpacked[i*8+7] = unpacked.y; } - // TODO: CONTINUE, NOT YET FINISHED // END NEW ADDITION*************************************************** // Transition in. float startTime = 0.8 * float(sortedIndex) / float(numGaussians); float cov_scale = smoothstep(startTime, startTime + 0.2, transitionInState); + // NEW ADDITION******************************************************* // Do the actual splatting. - mat3 cov3d = mat3( + mat3 triu = mat3( triu01.x, triu01.y, triu23.x, - triu01.y, triu23.y, triu45.x, - triu23.x, triu45.x, triu45.y + 0., triu23.y, triu45.x, + 0., 0., triu45.y ); + mat3 cov3d = triu * transpose(triu) * cov_scale; + // END NEW ADDITION*************************************************** + mat3 J = mat3( // Matrices are column-major. focal.x / c_cam.z, 0., 0.0, @@ -196,7 +198,6 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( for (int i = 0; i < 16; i++) { sh_coeffs[i] = vec3(sh_coeffs_unpacked[i*3], sh_coeffs_unpacked[i*3+1], sh_coeffs_unpacked[i*3+2]); } - float x = viewDir.x; float y = viewDir.y; float z = viewDir.z; @@ -206,26 +207,41 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( float xy = viewDir.x * viewDir.y; float yz = viewDir.y * viewDir.z; float xz = viewDir.x * viewDir.z; - - // 0th degree, standard RGB - vec3 rgb = C0 * sh_coeffs[0]; + + // 0th degree + vec3 rgb = C0 * sh_coeffs[0]; // line 74 of plenoxels vec3 pointFive = vec3(0.5, 0.5, 0.5); - // This is taken from gsplat - // 1st degree - rgb = rgb + C1 * (-y * sh_coeffs[1] + z * sh_coeffs[2] - x * sh_coeffs[3]); + //vRgba = vec4(rgb + pointFive, float(rgbaUint32 >> uint(24)) / 255.0); + + // could be useful code for debugging (make sure to set 1 gaussian to use this) + // if (sh_coeffs[0].x > 1.0 && sh_coeffs[0].x < 1.6) { + // vRgba = vec4(1.0, 0.0, 0.0, 1.0); + // } else { + // vRgba = vec4(1.0, 1.0, 1.0, 1.0); + // } + + + // ----GSPLAT IMPLEMENTATION----- + // 1st degree + rgb = rgb + C1 * (-y * sh_coeffs[1] + + z * sh_coeffs[2] - + x * sh_coeffs[3]); // 2nd degree - float fTmp0B = -1.0925484305920792 * z; // Reuse the constants + + float fTmp0B = -1.092548430592079 * z; float fC1 = xx - yy; float fS1 = 2.0 * xy; - float pSH5 = fTmp0B * y; float pSH6 = (0.9461746957575601 * zz - 0.3153915652525201); float pSH7 = fTmp0B * x; + float pSH5 = fTmp0B * y; float pSH8 = 0.5462742152960395 * fC1; - - rgb = rgb + pSH4 * sh_coeffs[4] + pSH5 * sh_coeffs[5] + pSH6 * sh_coeffs[6] + pSH7 * sh_coeffs[7] + pSH8 * sh_coeffs[8]; - + float pSH4 = 0.5462742152960395 * fS1; + rgb = rgb + pSH4 * sh_coeffs[4] + pSH5 * sh_coeffs[5] + + pSH6 * sh_coeffs[6] + pSH7 * sh_coeffs[7] + + pSH8 * sh_coeffs[8]; + // 3rd degree float fTmp0C = -2.285228997322329f * zz + 0.4570457994644658f; float fTmp1B = 1.445305721320277 * z; @@ -243,14 +259,14 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( pSH11 * sh_coeffs[11] + pSH12 * sh_coeffs[12] + pSH13 * sh_coeffs[13] + pSH14 * sh_coeffs[14]+ pSH15 * sh_coeffs[15]; + + vRgba = vec4(rgb + pointFive, float(rgbaUint32 >> uint(24)) / 255.0); - vRgba = vec4(rgb + pointFive, float(rgbaUint32 >> 24) / 255.0); - // INRIA IMPLEMENTATION NOT INCLUDED - // END NEW ADDITION*************************************************** + // ----END GSPLAT IMPLEMENTATION----- // Throw the Gaussian off the screen if it's too close, too far, or too small. float weightedDeterminant = vRgba.a * (diag1 * diag2 - offDiag * offDiag); - if (weightedDeterminant < 0.25) + if (weightedDeterminant < 0.5) return; vPosition = position.xy; From a6f43f08f08904745cb199d7483f42c478241285 Mon Sep 17 00:00:00 2001 From: Christian Le Date: Thu, 13 Feb 2025 16:41:24 +0100 Subject: [PATCH 09/14] add SH functioning, but not correct --- src/viser/_scene_api.py | 12 ++++-------- src/viser/client/src/Splatting/GaussianSplats.tsx | 6 ++++-- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/viser/_scene_api.py b/src/viser/_scene_api.py index 74c99dc13..60ae66909 100644 --- a/src/viser/_scene_api.py +++ b/src/viser/_scene_api.py @@ -1347,7 +1347,7 @@ def add_gaussian_splats( ).view(np.uint32) assert buffer.shape == (num_gaussians, 8) - if sh_coeffs is not None and normals is not None: + if sh_coeffs is not None: assert sh_coeffs.shape == (num_gaussians, 48) sh_buffer = np.concatenate( [ @@ -1355,17 +1355,13 @@ def add_gaussian_splats( ], ).view(np.uint32) else: - sh_buffer = np.zeros((num_gaussians, 48), dtype=np.uint32) + sh_buffer = np.zeros((num_gaussians, 48), dtype=np.float16).view(np.uint32) if normals is not None: assert normals.shape == (num_gaussians, 3) - norm_buffer = np.concatenate( - [ - normals.astype(np.float32).view(np.uint8) - ], - ).view(np.uint32) + norm_buffer = normals.astype(np.float32).view(np.uint8).view(np.uint32) else: - norm_buffer = np.zeros((num_gaussians, 3), dtype=np.uint32) + norm_buffer = np.zeros((num_gaussians, 3), dtype=np.float32).view(np.uint8).view(np.uint32) message = _messages.GaussianSplatsMessage( name=name, diff --git a/src/viser/client/src/Splatting/GaussianSplats.tsx b/src/viser/client/src/Splatting/GaussianSplats.tsx index f54f78afa..90abdeb83 100644 --- a/src/viser/client/src/Splatting/GaussianSplats.tsx +++ b/src/viser/client/src/Splatting/GaussianSplats.tsx @@ -353,13 +353,15 @@ function mergeGaussianGroups(groupBufferFromName: { let totalBufferLength = 0; const groupBufferFromNameFiltered = Object.fromEntries( Object.entries(groupBufferFromName).filter( - ([key]) => !key.startsWith("sh_buffer_") && !key.startsWith("norm_buffer_") + ([key]) => !key.startsWith("sh_buffer_") ) ); // Temporarily using groupBufferFromName instead of groupBufferFromNameFiltered // because it causes an error in the shader - for (const buffer of Object.values(groupBufferFromName)) { + console.log(groupBufferFromName); + console.log(groupBufferFromNameFiltered); + for (const buffer of Object.values(groupBufferFromNameFiltered)) { totalBufferLength += buffer.length; } const numGaussians = totalBufferLength / 8; From 7dc54338ec0b1017656858dda53dee029d4f9dcb Mon Sep 17 00:00:00 2001 From: Christian Le Date: Thu, 13 Feb 2025 16:41:43 +0100 Subject: [PATCH 10/14] change example script to include SH and normals --- examples/experimental/gaussian_splats.py | 61 +++++------------------- 1 file changed, 13 insertions(+), 48 deletions(-) diff --git a/examples/experimental/gaussian_splats.py b/examples/experimental/gaussian_splats.py index 0fff1085a..e5f12d254 100644 --- a/examples/experimental/gaussian_splats.py +++ b/examples/experimental/gaussian_splats.py @@ -26,50 +26,10 @@ class SplatFile(TypedDict): """(N, 1). Range [0, 1].""" covariances: npt.NDArray[np.floating] """(N, 3, 3).""" - - -def load_splat_file(splat_path: Path, center: bool = False) -> SplatFile: - """Load an antimatter15-style splat file.""" - start_time = time.time() - splat_buffer = splat_path.read_bytes() - bytes_per_gaussian = ( - # Each Gaussian is serialized as: - # - position (vec3, float32) - 3 * 4 - # - xyz (vec3, float32) - + 3 * 4 - # - rgba (vec4, uint8) - + 4 - # - ijkl (vec4, uint8), where 0 => -1, 255 => 1. - + 4 - ) - assert len(splat_buffer) % bytes_per_gaussian == 0 - num_gaussians = len(splat_buffer) // bytes_per_gaussian - - # Reinterpret cast to dtypes that we want to extract. - splat_uint8 = np.frombuffer(splat_buffer, dtype=np.uint8).reshape( - (num_gaussians, bytes_per_gaussian) - ) - scales = splat_uint8[:, 12:24].copy().view(np.float32) - wxyzs = splat_uint8[:, 28:32] / 255.0 * 2.0 - 1.0 - Rs = tf.SO3(wxyzs).as_matrix() - covariances = np.einsum( - "nij,njk,nlk->nil", Rs, np.eye(3)[None, :, :] * scales[:, None, :] ** 2, Rs - ) - centers = splat_uint8[:, 0:12].copy().view(np.float32) - if center: - centers -= np.mean(centers, axis=0, keepdims=True) - print( - f"Splat file with {num_gaussians=} loaded in {time.time() - start_time} seconds" - ) - return { - "centers": centers, - # Colors should have shape (N, 3). - "rgbs": splat_uint8[:, 24:27] / 255.0, - "opacities": splat_uint8[:, 27:28] / 255.0, - # Covariances should have shape (N, 3, 3). - "covariances": covariances, - } + sh_coeffs: npt.NDArray[np.floating] + """(N, 48).""" + normals: npt.NDArray[np.floating] + """(N, 3).""" def load_ply_file(ply_file_path: Path, center: bool = False) -> SplatFile: @@ -85,6 +45,10 @@ def load_ply_file(ply_file_path: Path, center: bool = False) -> SplatFile: wxyzs = np.stack([v["rot_0"], v["rot_1"], v["rot_2"], v["rot_3"]], axis=1) colors = 0.5 + SH_C0 * np.stack([v["f_dc_0"], v["f_dc_1"], v["f_dc_2"]], axis=1) opacities = 1.0 / (1.0 + np.exp(-v["opacity"][:, None])) + dc_coeffs = np.stack([v["f_dc_0"], v["f_dc_1"], v["f_dc_2"]], axis=1) + rest_coeffs = np.stack([v[f"f_rest_{i}"] for i in range(45)], axis=1) + sh_coeffs = np.concatenate([dc_coeffs, rest_coeffs], axis=1) + normals = np.stack([v["nx"], v["ny"], v["nz"]], axis=-1) Rs = tf.SO3(wxyzs).as_matrix() covariances = np.einsum( @@ -102,11 +66,13 @@ def load_ply_file(ply_file_path: Path, center: bool = False) -> SplatFile: "rgbs": colors, "opacities": opacities, "covariances": covariances, + "sh_coeffs": sh_coeffs, + "normals": normals, } def main(splat_paths: tuple[Path, ...]) -> None: - server = viser.ViserServer() + server = viser.ViserServer(port=8001) server.gui.configure_theme(dark_mode=True) gui_reset_up = server.gui.add_button( "Reset up direction", @@ -122,9 +88,7 @@ def _(event: viser.GuiEvent) -> None: ) for i, splat_path in enumerate(splat_paths): - if splat_path.suffix == ".splat": - splat_data = load_splat_file(splat_path, center=True) - elif splat_path.suffix == ".ply": + if splat_path.suffix == ".ply": splat_data = load_ply_file(splat_path, center=True) else: raise SystemExit("Please provide a filepath to a .splat or .ply file.") @@ -136,6 +100,7 @@ def _(event: viser.GuiEvent) -> None: rgbs=splat_data["rgbs"], opacities=splat_data["opacities"], covariances=splat_data["covariances"], + sh_coeffs=splat_data["sh_coeffs"], ) remove_button = server.gui.add_button(f"Remove splat object {i}") From 2751dac33b3b395f0adfa411a01f9d0d0e003260 Mon Sep 17 00:00:00 2001 From: Christian Le Date: Mon, 17 Feb 2025 11:39:45 +0100 Subject: [PATCH 11/14] make SH code more readable according to theory --- .../client/src/Splatting/GaussianSplats.tsx | 4 +- .../src/Splatting/GaussianSplatsHelpers.ts | 167 ++++++++++-------- 2 files changed, 99 insertions(+), 72 deletions(-) diff --git a/src/viser/client/src/Splatting/GaussianSplats.tsx b/src/viser/client/src/Splatting/GaussianSplats.tsx index 90abdeb83..ad389e7ce 100644 --- a/src/viser/client/src/Splatting/GaussianSplats.tsx +++ b/src/viser/client/src/Splatting/GaussianSplats.tsx @@ -353,14 +353,12 @@ function mergeGaussianGroups(groupBufferFromName: { let totalBufferLength = 0; const groupBufferFromNameFiltered = Object.fromEntries( Object.entries(groupBufferFromName).filter( - ([key]) => !key.startsWith("sh_buffer_") + ([key]) => !key.startsWith("sh_buffer_") && !key.startsWith("norm_buffer_") ) ); // Temporarily using groupBufferFromName instead of groupBufferFromNameFiltered // because it causes an error in the shader - console.log(groupBufferFromName); - console.log(groupBufferFromNameFiltered); for (const buffer of Object.values(groupBufferFromNameFiltered)) { totalBufferLength += buffer.length; } diff --git a/src/viser/client/src/Splatting/GaussianSplatsHelpers.ts b/src/viser/client/src/Splatting/GaussianSplatsHelpers.ts index b4af533a0..7ed7fdb34 100644 --- a/src/viser/client/src/Splatting/GaussianSplatsHelpers.ts +++ b/src/viser/client/src/Splatting/GaussianSplatsHelpers.ts @@ -47,7 +47,6 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( uniform vec2 viewport; uniform float near; uniform float far; - uniform uint sh_degree; // Fade in state between [0, 1]. uniform float transitionInState; @@ -139,15 +138,12 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( float startTime = 0.8 * float(sortedIndex) / float(numGaussians); float cov_scale = smoothstep(startTime, startTime + 0.2, transitionInState); - // NEW ADDITION******************************************************* // Do the actual splatting. - mat3 triu = mat3( + mat3 cov3d = mat3( triu01.x, triu01.y, triu23.x, - 0., triu23.y, triu45.x, - 0., 0., triu45.y + triu01.y, triu23.y, triu45.x, + triu23.x, triu45.x, triu45.y ); - mat3 cov3d = triu * transpose(triu) * cov_scale; - // END NEW ADDITION*************************************************** mat3 J = mat3( // Matrices are column-major. @@ -172,32 +168,61 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( vec2 v1 = min(sqrt(2.0 * lambda1), 1024.0) * diagonalVector; vec2 v2 = min(sqrt(2.0 * lambda2), 1024.0) * vec2(diagonalVector.y, -diagonalVector.x); - // NEW ADDITION******************************************************* + // NEW ADDITION: SPHERICAL HARMONICS IMPLEMENTATION----- + // Calculate the spherical harmonics. vec3 viewDir = normalize(center - cameraPosition); - const float C0 = 0.28209479177387814; - const float C1 = 0.4886025119029199; + // // 0.5 * sqrt(1.0 / pi) + const float C0 = 0.28209479177387814; + // sqrt(3.0 / (4.0 * pi)) * [y] + // sqrt(3.0 / (4.0 * pi)) * [z] + // sqrt(3.0 / (4.0 * pi)) * [x] + const float C1[3] = float[3]( + -0.4886025119029199, + 0.4886025119029199, + -0.4886025119029199 + ); + // 0.5 * sqrt(15/pi) * [x*y] + // 0.5 * sqrt(15/pi) * [y*z] + // 0.25 * sqrt(5/pi) * [3z^2 - 1] + // 0.5 * sqrt(15/pi) * [z*x] + // 0.25 * sqrt(15/pi) * [x^2 - y^2] const float C2[5] = float[5]( - 1.0925484305920792, - -1.0925484305920792, - 0.31539156525252005, - -1.0925484305920792, - 0.5462742152960396 + 1.0925484305920792, + -1.0925484305920792, + 0.31539156525252005, + -1.0925484305920792, + 0.5462742152960396 ); + // 0.25 * sqrt(35/(2pi)) * [y*(3x^2-y^2)] + // 0.5 * sqrt(105/pi) * [x*y*z] + // 0.25 * sqrt(21/(2pi)) * [y*(5z^2-1)] + // 0.25 * sqrt(7/pi) * [z*(5z^2-3)] + // 0.25 * sqrt(21/(2pi)) * [x*(5z^2-1)] + // 0.25 * sqrt(105/(pi)) * [(x^2-y^2)*z] + // 0.25 * sqrt(35/(2pi)) * [x*(x^2-3y^2)] const float C3[7] = float[7]( - -0.5900435899266435, - 2.890611442640554, - -0.4570457994644658, - 0.3731763325901154, - -0.4570457994644658, - 1.445305721320277, - -0.5900435899266435 + -0.5900435899266435, + 2.890611442640554, + -0.4570457994644658, + 0.3731763325901154, + -0.4570457994644658, + 1.445305721320277, + -0.5900435899266435 ); vec3 sh_coeffs[16]; for (int i = 0; i < 16; i++) { sh_coeffs[i] = vec3(sh_coeffs_unpacked[i*3], sh_coeffs_unpacked[i*3+1], sh_coeffs_unpacked[i*3+2]); } + + // View-dependent variables + // Along with SH coefficients, we represent view-dependent colors + // From Wikipedia definition: + // x = sin(theta) * cos(phi) + // y = sin(theta) * sin(phi) + // z = cos(theta) + float x = viewDir.x; float y = viewDir.y; float z = viewDir.z; @@ -212,57 +237,59 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( vec3 rgb = C0 * sh_coeffs[0]; // line 74 of plenoxels vec3 pointFive = vec3(0.5, 0.5, 0.5); - //vRgba = vec4(rgb + pointFive, float(rgbaUint32 >> uint(24)) / 255.0); - - // could be useful code for debugging (make sure to set 1 gaussian to use this) - // if (sh_coeffs[0].x > 1.0 && sh_coeffs[0].x < 1.6) { - // vRgba = vec4(1.0, 0.0, 0.0, 1.0); - // } else { - // vRgba = vec4(1.0, 1.0, 1.0, 1.0); - // } + // 1st degree - // ----GSPLAT IMPLEMENTATION----- + // float pSH1 = C1[0] * y; + // float pSH2 = C1[1] * z; + // float pSH3 = C1[2] * x; + + // rgb = rgb + pSH1 * sh_coeffs[1] + + // pSH2 * sh_coeffs[2] + + // pSH3 * sh_coeffs[3]; + + // // 2nd degree + + // float pSH4 = C2[0] * xy; + // float pSH5 = C2[1] * yz; + // float pSH6 = C2[2] * (3.0 * zz - 1.0); + // float pSH7 = C2[3] * xz; + // float pSH8 = C2[4] * (xx - yy); + + // rgb = rgb + pSH4 * sh_coeffs[4] + + // pSH5 * sh_coeffs[5] + + // pSH6 * sh_coeffs[6] + + // pSH7 * sh_coeffs[7] + + // pSH8 * sh_coeffs[8]; + + // // 3rd degree + + // float pSH9 = C3[0] * y * (3.0 * xx - yy); + // float pSH10 = C3[1] * x * y * z; + // float pSH11 = C3[2] * y * (5.0 * zz - 1.0); + // float pSH12 = C3[3] * z * (5.0 * zz - 3.0); + // float pSH13 = C3[4] * x * (5.0 * zz - 1.0); + // float pSH14 = C3[5] * (xx - yy) * z; + // float pSH15 = C3[6] * x * (xx - 3.0 * yy); + + // rgb = rgb + pSH9 * sh_coeffs[9] + + // pSH10 * sh_coeffs[10] + + // pSH11 * sh_coeffs[11] + + // pSH12 * sh_coeffs[12] + + // pSH13 * sh_coeffs[13] + + // pSH14 * sh_coeffs[14] + + // pSH15 * sh_coeffs[15]; - // 1st degree - rgb = rgb + C1 * (-y * sh_coeffs[1] + - z * sh_coeffs[2] - - x * sh_coeffs[3]); - // 2nd degree - - float fTmp0B = -1.092548430592079 * z; - float fC1 = xx - yy; - float fS1 = 2.0 * xy; - float pSH6 = (0.9461746957575601 * zz - 0.3153915652525201); - float pSH7 = fTmp0B * x; - float pSH5 = fTmp0B * y; - float pSH8 = 0.5462742152960395 * fC1; - float pSH4 = 0.5462742152960395 * fS1; - rgb = rgb + pSH4 * sh_coeffs[4] + pSH5 * sh_coeffs[5] + - pSH6 * sh_coeffs[6] + pSH7 * sh_coeffs[7] + - pSH8 * sh_coeffs[8]; - - // 3rd degree - float fTmp0C = -2.285228997322329f * zz + 0.4570457994644658f; - float fTmp1B = 1.445305721320277 * z; - float fC2 = x * fC1 - y * fS1; - float fS2 = x * fS1 + y * fC1; - float pSH12 = z * (1.865881662950577 * zz - 1.119528997770346f); - float pSH13 = fTmp0C * x; - float pSH11 = fTmp0C * y; - float pSH14 = fTmp1B * fC1; - float pSH10 = fTmp1B * fS1; - float pSH15 = -0.5900435899266435 * fC2; - float pSH9 = -0.5900435899266435 * fS2; - - rgb = rgb + pSH9 * sh_coeffs[9] + pSH10 * sh_coeffs[10] + - pSH11 * sh_coeffs[11] + pSH12 * sh_coeffs[12] + - pSH13 * sh_coeffs[13] + pSH14 * sh_coeffs[14]+ - pSH15 * sh_coeffs[15]; - vRgba = vec4(rgb + pointFive, float(rgbaUint32 >> uint(24)) / 255.0); - - // ----END GSPLAT IMPLEMENTATION----- + + // // DEBUGGING USING INPUT RGB INSTEAD OF SH + // vRgba = vec4( + // float(rgbaUint32 & uint(0xFF)) / 255.0, + // float((rgbaUint32 >> uint(8)) & uint(0xFF)) / 255.0, + // float((rgbaUint32 >> uint(16)) & uint(0xFF)) / 255.0, + // float(rgbaUint32 >> uint(24)) / 255.0 + // ); + // NEW SPHERICAL HARMONICS IMPLEMENTATION END----- // Throw the Gaussian off the screen if it's too close, too far, or too small. float weightedDeterminant = vRgba.a * (diag1 * diag2 - offDiag * offDiag); @@ -327,6 +354,8 @@ export function useGaussianMeshProps( geometry.setAttribute("sortedIndex", sortedIndexAttribute); // Create texture buffers. + // We store 4 floats and 4 int32s per Gaussian. + // One "numGaussians" corresponds to 4 32-bit values. const textureWidth = Math.min(numGaussians * 2, maxTextureSize); const textureHeight = Math.ceil((numGaussians * 2) / textureWidth); const bufferPadded = new Uint32Array(textureWidth * textureHeight * 4); From aa43e8d4950be33c7356f4b01b5e72181062cac8 Mon Sep 17 00:00:00 2001 From: Christian Le Date: Mon, 17 Feb 2025 15:32:21 +0100 Subject: [PATCH 12/14] re-ordered f_rest_* are interpreted for correct SH representation --- examples/experimental/gaussian_splats.py | 12 ++- .../src/Splatting/GaussianSplatsHelpers.ts | 87 ++++++++++--------- 2 files changed, 55 insertions(+), 44 deletions(-) diff --git a/examples/experimental/gaussian_splats.py b/examples/experimental/gaussian_splats.py index e5f12d254..bb0f30433 100644 --- a/examples/experimental/gaussian_splats.py +++ b/examples/experimental/gaussian_splats.py @@ -46,7 +46,15 @@ def load_ply_file(ply_file_path: Path, center: bool = False) -> SplatFile: colors = 0.5 + SH_C0 * np.stack([v["f_dc_0"], v["f_dc_1"], v["f_dc_2"]], axis=1) opacities = 1.0 / (1.0 + np.exp(-v["opacity"][:, None])) dc_coeffs = np.stack([v["f_dc_0"], v["f_dc_1"], v["f_dc_2"]], axis=1) - rest_coeffs = np.stack([v[f"f_rest_{i}"] for i in range(45)], axis=1) + # Rest coefficients 0-14 belongs to RED channel, 15-29 to GREEN, 30-44 to BLUE + # Due to spherical harmonic calculations calculating a triplet at a time + # we need to stack them by (0,15,30), (1,16,31), ..., (14,29,44) + rest_coeffs = [] + for i in range(15): + rest_coeffs.append(v[f"f_rest_{i}"]) + rest_coeffs.append(v[f"f_rest_{i + 15}"]) + rest_coeffs.append(v[f"f_rest_{i + 30}"]) + rest_coeffs = np.stack(rest_coeffs, axis=1) sh_coeffs = np.concatenate([dc_coeffs, rest_coeffs], axis=1) normals = np.stack([v["nx"], v["ny"], v["nz"]], axis=-1) @@ -72,7 +80,7 @@ def load_ply_file(ply_file_path: Path, center: bool = False) -> SplatFile: def main(splat_paths: tuple[Path, ...]) -> None: - server = viser.ViserServer(port=8001) + server = viser.ViserServer(port=8014) server.gui.configure_theme(dark_mode=True) gui_reset_up = server.gui.add_button( "Reset up direction", diff --git a/src/viser/client/src/Splatting/GaussianSplatsHelpers.ts b/src/viser/client/src/Splatting/GaussianSplatsHelpers.ts index 7ed7fdb34..a198cf6c3 100644 --- a/src/viser/client/src/Splatting/GaussianSplatsHelpers.ts +++ b/src/viser/client/src/Splatting/GaussianSplatsHelpers.ts @@ -172,7 +172,7 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( // Calculate the spherical harmonics. vec3 viewDir = normalize(center - cameraPosition); - // // 0.5 * sqrt(1.0 / pi) + // // 0.5 * sqrt(1.0 / pi) const float C0 = 0.28209479177387814; // sqrt(3.0 / (4.0 * pi)) * [y] // sqrt(3.0 / (4.0 * pi)) * [z] @@ -240,45 +240,45 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( // 1st degree - // float pSH1 = C1[0] * y; - // float pSH2 = C1[1] * z; - // float pSH3 = C1[2] * x; - - // rgb = rgb + pSH1 * sh_coeffs[1] + - // pSH2 * sh_coeffs[2] + - // pSH3 * sh_coeffs[3]; - - // // 2nd degree - - // float pSH4 = C2[0] * xy; - // float pSH5 = C2[1] * yz; - // float pSH6 = C2[2] * (3.0 * zz - 1.0); - // float pSH7 = C2[3] * xz; - // float pSH8 = C2[4] * (xx - yy); - - // rgb = rgb + pSH4 * sh_coeffs[4] + - // pSH5 * sh_coeffs[5] + - // pSH6 * sh_coeffs[6] + - // pSH7 * sh_coeffs[7] + - // pSH8 * sh_coeffs[8]; - - // // 3rd degree - - // float pSH9 = C3[0] * y * (3.0 * xx - yy); - // float pSH10 = C3[1] * x * y * z; - // float pSH11 = C3[2] * y * (5.0 * zz - 1.0); - // float pSH12 = C3[3] * z * (5.0 * zz - 3.0); - // float pSH13 = C3[4] * x * (5.0 * zz - 1.0); - // float pSH14 = C3[5] * (xx - yy) * z; - // float pSH15 = C3[6] * x * (xx - 3.0 * yy); - - // rgb = rgb + pSH9 * sh_coeffs[9] + - // pSH10 * sh_coeffs[10] + - // pSH11 * sh_coeffs[11] + - // pSH12 * sh_coeffs[12] + - // pSH13 * sh_coeffs[13] + - // pSH14 * sh_coeffs[14] + - // pSH15 * sh_coeffs[15]; + float pSH1 = C1[0] * y; + float pSH2 = C1[1] * z; + float pSH3 = C1[2] * x; + + rgb = rgb + pSH1 * sh_coeffs[1] + + pSH2 * sh_coeffs[2] + + pSH3 * sh_coeffs[3]; + + // 2nd degree + + float pSH4 = C2[0] * xy; + float pSH5 = C2[1] * yz; + float pSH6 = C2[2] * (3.0 * zz - 1.0); + float pSH7 = C2[3] * xz; + float pSH8 = C2[4] * (xx - yy); + + rgb = rgb + pSH4 * sh_coeffs[4] + + pSH5 * sh_coeffs[5] + + pSH6 * sh_coeffs[6] + + pSH7 * sh_coeffs[7] + + pSH8 * sh_coeffs[8]; + + // 3rd degree + + float pSH9 = C3[0] * y * (3.0 * xx - yy); + float pSH10 = C3[1] * x * y * z; + float pSH11 = C3[2] * y * (5.0 * zz - 1.0); + float pSH12 = C3[3] * z * (5.0 * zz - 3.0); + float pSH13 = C3[4] * x * (5.0 * zz - 1.0); + float pSH14 = C3[5] * (xx - yy) * z; + float pSH15 = C3[6] * x * (xx - 3.0 * yy); + + rgb = rgb + pSH9 * sh_coeffs[9] + + pSH10 * sh_coeffs[10] + + pSH11 * sh_coeffs[11] + + pSH12 * sh_coeffs[12] + + pSH13 * sh_coeffs[13] + + pSH14 * sh_coeffs[14] + + pSH15 * sh_coeffs[15]; vRgba = vec4(rgb + pointFive, float(rgbaUint32 >> uint(24)) / 255.0); @@ -383,6 +383,9 @@ export function useGaussianMeshProps( // NEW ADDITION******************************************************* // Values taken from PR https://github.com/nerfstudio-project/viser/pull/286/files + // WIDTH AND HEIGHT ARE MEASURED IN TEXELS + // As 48 x float16 = 96 bytes and each texel is 4 uint32s = 16 bytes + // We can fit 6 spherical harmonics coefficients in a single texel const shTextureWidth = Math.min(numGaussians * 6, maxTextureSize); const shTextureHeight = Math.ceil((numGaussians * 6) / shTextureWidth); const shBufferPadded = new Uint32Array(shTextureWidth * shTextureHeight * 4); @@ -397,8 +400,8 @@ export function useGaussianMeshProps( shTextureBuffer.internalFormat = "RGBA32UI"; shTextureBuffer.needsUpdate = true; - const normTexturwWidth = Math.min(numGaussians * 6, maxTextureSize); - const normTextureHeight = Math.ceil((numGaussians * 6) / normTexturwWidth); + const normTexturwWidth = Math.min(numGaussians, maxTextureSize); + const normTextureHeight = Math.ceil((numGaussians) / normTexturwWidth); const normBufferPadded = new Uint32Array(normTexturwWidth * normTextureHeight * 4); normBufferPadded.set(combinedNormBuffer); const normTextureBuffer = new THREE.DataTexture( From 0a75ea161d957eb066fd7ad35e525ea22ce8f4c2 Mon Sep 17 00:00:00 2001 From: Christian Le Date: Tue, 18 Feb 2025 14:02:12 +0100 Subject: [PATCH 13/14] remove surface normals code and clean up some comments --- examples/experimental/gaussian_splats.py | 57 +++++++++-- src/viser/_messages.py | 1 - src/viser/_scene_api.py | 18 ++-- src/viser/client/src/SceneTree.tsx | 9 -- .../client/src/Splatting/GaussianSplats.tsx | 30 +----- .../src/Splatting/GaussianSplatsHelpers.ts | 94 ++++++------------- src/viser/client/src/WebsocketMessages.ts | 3 +- 7 files changed, 92 insertions(+), 120 deletions(-) diff --git a/examples/experimental/gaussian_splats.py b/examples/experimental/gaussian_splats.py index bb0f30433..dbd28c876 100644 --- a/examples/experimental/gaussian_splats.py +++ b/examples/experimental/gaussian_splats.py @@ -27,9 +27,52 @@ class SplatFile(TypedDict): covariances: npt.NDArray[np.floating] """(N, 3, 3).""" sh_coeffs: npt.NDArray[np.floating] - """(N, 48).""" - normals: npt.NDArray[np.floating] - """(N, 3).""" + + +def load_splat_file(splat_path: Path, center: bool = False) -> SplatFile: + """Load an antimatter15-style splat file.""" + start_time = time.time() + splat_buffer = splat_path.read_bytes() + bytes_per_gaussian = ( + # Each Gaussian is serialized as: + # - position (vec3, float32) + 3 * 4 + # - xyz (vec3, float32) + + 3 * 4 + # - rgba (vec4, uint8) + + 4 + # - ijkl (vec4, uint8), where 0 => -1, 255 => 1. + + 4 + ) + assert len(splat_buffer) % bytes_per_gaussian == 0 + num_gaussians = len(splat_buffer) // bytes_per_gaussian + + # Reinterpret cast to dtypes that we want to extract. + splat_uint8 = np.frombuffer(splat_buffer, dtype=np.uint8).reshape( + (num_gaussians, bytes_per_gaussian) + ) + scales = splat_uint8[:, 12:24].copy().view(np.float32) + wxyzs = splat_uint8[:, 28:32] / 255.0 * 2.0 - 1.0 + Rs = tf.SO3(wxyzs).as_matrix() + covariances = np.einsum( + "nij,njk,nlk->nil", Rs, np.eye(3)[None, :, :] * scales[:, None, :] ** 2, Rs + ) + centers = splat_uint8[:, 0:12].copy().view(np.float32) + if center: + centers -= np.mean(centers, axis=0, keepdims=True) + print( + f"Splat file with {num_gaussians=} loaded in {time.time() - start_time} seconds" + ) + return { + "centers": centers, + # Colors should have shape (N, 3). + "rgbs": splat_uint8[:, 24:27] / 255.0, + "opacities": splat_uint8[:, 27:28] / 255.0, + # Covariances should have shape (N, 3, 3). + "covariances": covariances, + # No SH coefficients in the splat file. + "sh_coeffs": np.zeros((num_gaussians, 45), dtype=np.float32), + } def load_ply_file(ply_file_path: Path, center: bool = False) -> SplatFile: @@ -56,7 +99,6 @@ def load_ply_file(ply_file_path: Path, center: bool = False) -> SplatFile: rest_coeffs.append(v[f"f_rest_{i + 30}"]) rest_coeffs = np.stack(rest_coeffs, axis=1) sh_coeffs = np.concatenate([dc_coeffs, rest_coeffs], axis=1) - normals = np.stack([v["nx"], v["ny"], v["nz"]], axis=-1) Rs = tf.SO3(wxyzs).as_matrix() covariances = np.einsum( @@ -75,12 +117,11 @@ def load_ply_file(ply_file_path: Path, center: bool = False) -> SplatFile: "opacities": opacities, "covariances": covariances, "sh_coeffs": sh_coeffs, - "normals": normals, } def main(splat_paths: tuple[Path, ...]) -> None: - server = viser.ViserServer(port=8014) + server = viser.ViserServer() server.gui.configure_theme(dark_mode=True) gui_reset_up = server.gui.add_button( "Reset up direction", @@ -96,7 +137,9 @@ def _(event: viser.GuiEvent) -> None: ) for i, splat_path in enumerate(splat_paths): - if splat_path.suffix == ".ply": + if splat_path.suffix == ".splat": + splat_data = load_splat_file(splat_path, center=True) + elif splat_path.suffix == ".ply": splat_data = load_ply_file(splat_path, center=True) else: raise SystemExit("Please provide a filepath to a .splat or .ply file.") diff --git a/src/viser/_messages.py b/src/viser/_messages.py index 8c34646bc..8b45c7581 100644 --- a/src/viser/_messages.py +++ b/src/viser/_messages.py @@ -1318,7 +1318,6 @@ class GaussianSplatsProps: # https://github.com/antimatter15/splat buffer: npt.NDArray[np.uint32] sh_buffer: npt.NDArray[np.uint32] - norm_buffer: npt.NDArray[np.uint32] """Our buffer will contain: - x as f32 - y as f32 diff --git a/src/viser/_scene_api.py b/src/viser/_scene_api.py index 60ae66909..93ccdef14 100644 --- a/src/viser/_scene_api.py +++ b/src/viser/_scene_api.py @@ -1294,7 +1294,6 @@ def add_gaussian_splats( covariances: np.ndarray, rgbs: np.ndarray, opacities: np.ndarray, - normals: np.ndarray | None = None, sh_coeffs: np.ndarray | None = None, wxyz: Tuple[float, float, float, float] | np.ndarray = (1.0, 0.0, 0.0, 0.0), position: Tuple[float, float, float] | np.ndarray = (0.0, 0.0, 0.0), @@ -1355,20 +1354,23 @@ def add_gaussian_splats( ], ).view(np.uint32) else: - sh_buffer = np.zeros((num_gaussians, 48), dtype=np.float16).view(np.uint32) + SH_C0 = 0.28209479177387814 + dc_coeffs = (rgbs - 0.5) / SH_C0 + sh_buffer = np.concatenate( + [ + dc_coeffs, np.zeros((num_gaussians, 45)) + ], + axis=1, + dtype=np.float16, + ).view(np.uint32) + # sh_buffer = np.zeros((num_gaussians, 48), dtype=np.float16).view(np.uint32) - if normals is not None: - assert normals.shape == (num_gaussians, 3) - norm_buffer = normals.astype(np.float32).view(np.uint8).view(np.uint32) - else: - norm_buffer = np.zeros((num_gaussians, 3), dtype=np.float32).view(np.uint8).view(np.uint32) message = _messages.GaussianSplatsMessage( name=name, props=_messages.GaussianSplatsProps( buffer=buffer, sh_buffer=sh_buffer, - norm_buffer=norm_buffer, ), ) node_handle = GaussianSplatHandle._make( diff --git a/src/viser/client/src/SceneTree.tsx b/src/viser/client/src/SceneTree.tsx index e2c45261b..4203115a7 100644 --- a/src/viser/client/src/SceneTree.tsx +++ b/src/viser/client/src/SceneTree.tsx @@ -500,15 +500,6 @@ function useObjectFactory(message: SceneNodeMessage | undefined): { ), ) } - norm_buffer={ - new Uint32Array( - message.props.norm_buffer.buffer.slice( - message.props.norm_buffer.byteOffset, - message.props.norm_buffer.byteOffset + - message.props.norm_buffer.byteLength, - ), - ) - } /> ), }; diff --git a/src/viser/client/src/Splatting/GaussianSplats.tsx b/src/viser/client/src/Splatting/GaussianSplats.tsx index ad389e7ce..75529bd2f 100644 --- a/src/viser/client/src/Splatting/GaussianSplats.tsx +++ b/src/viser/client/src/Splatting/GaussianSplats.tsx @@ -62,9 +62,8 @@ export const SplatObject = React.forwardRef< { buffer: Uint32Array; sh_buffer: Uint32Array; - norm_buffer: Uint32Array; } ->(function SplatObject({ buffer, sh_buffer, norm_buffer }, ref) { +>(function SplatObject({ buffer, sh_buffer }, ref) { const splatContext = React.useContext(GaussianSplatsContext)!; const setBuffer = splatContext.useGaussianSplatStore( (state) => state.setBuffer, @@ -78,16 +77,13 @@ export const SplatObject = React.forwardRef< const name = React.useMemo(() => uuidv4(), [buffer]); // Inspired from PR, maybe should look more similar to the original code above const sh_buffer_name = `sh_buffer_${name}`; - const norm_buffer_name = `norm_buffer_${name}`; React.useEffect(() => { setBuffer(name, buffer); setBuffer(sh_buffer_name, sh_buffer); - setBuffer(norm_buffer_name, norm_buffer); return () => { removeBuffer(name); removeBuffer(sh_buffer_name); - removeBuffer(norm_buffer_name); delete nodeRefFromId.current[name]; }; }, [buffer]); @@ -139,7 +135,6 @@ function SplatRendererImpl() { const meshProps = useGaussianMeshProps( merged.gaussianBuffer, merged.combinedSHBuffer, - merged.combinedNormBuffer, merged.numGroups, ); splatContext.meshPropsRef.current = meshProps; @@ -158,7 +153,6 @@ function SplatRendererImpl() { meshProps.material.uniforms.numGaussians.value = merged.numGaussians; meshProps.textureBuffer.needsUpdate = true; meshProps.shTextureBuffer.needsUpdate = true; - meshProps.normTextureBuffer.needsUpdate = true; initializedBufferTexture = true; } }; @@ -175,7 +169,6 @@ function SplatRendererImpl() { return () => { meshProps.textureBuffer.dispose(); meshProps.shTextureBuffer.dispose(); - meshProps.normTextureBuffer.dispose(); meshProps.geometry.dispose(); meshProps.material.dispose(); postToWorker({ close: true }); @@ -353,7 +346,7 @@ function mergeGaussianGroups(groupBufferFromName: { let totalBufferLength = 0; const groupBufferFromNameFiltered = Object.fromEntries( Object.entries(groupBufferFromName).filter( - ([key]) => !key.startsWith("sh_buffer_") && !key.startsWith("norm_buffer_") + ([key]) => !key.startsWith("sh_buffer_") ) ); @@ -406,24 +399,7 @@ function mergeGaussianGroups(groupBufferFromName: { sh_offset += sh_buffer.length; } - let totalNormBufferLength = 0; - - const normGaussianBuffers = Object.fromEntries( - Object.entries(groupBufferFromName).filter(([key]) => key.startsWith("norm_buffer_")) - ); - - for (const norm_buffer of Object.values(normGaussianBuffers)) { - totalNormBufferLength += norm_buffer.length; - } - - const combinedNormBuffer = new Uint32Array(totalNormBufferLength); - let norm_offset = 0; - for (const norm_buffer of Object.values(normGaussianBuffers)) { - combinedNormBuffer.set(norm_buffer, norm_offset); - norm_offset += norm_buffer.length; - } - const numGroups = Object.keys(groupBufferFromNameFiltered).length; - return { numGaussians, gaussianBuffer, numGroups, groupIndices, combinedSHBuffer, combinedNormBuffer }; + return { numGaussians, gaussianBuffer, numGroups, groupIndices, combinedSHBuffer }; } diff --git a/src/viser/client/src/Splatting/GaussianSplatsHelpers.ts b/src/viser/client/src/Splatting/GaussianSplatsHelpers.ts index a198cf6c3..2d1e672ab 100644 --- a/src/viser/client/src/Splatting/GaussianSplatsHelpers.ts +++ b/src/viser/client/src/Splatting/GaussianSplatsHelpers.ts @@ -17,7 +17,6 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( transparent: true, textureBuffer: null, shTextureBuffer: null, - normTextureBuffer: null, textureT_camera_groups: null, transitionInState: 0.0, }, @@ -31,11 +30,9 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( // copy quadjr for this. uniform usampler2D textureBuffer; - // NEW ADDITION******************************************************* // Buffer for spherical harmonics; Each Gaussian gets 24 int32s representing - // this information. + // this information (Each coefficient is 16 bits, corr. to 48 coeffs.). uniform usampler2D shTextureBuffer; - // END NEW ADDITION*************************************************** // We could also use a uniform to store transforms, but this would be more // limiting in terms of the # of groups we can have. @@ -105,7 +102,6 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( vec2 triu23 = unpackHalf2x16(intBufferData.y); vec2 triu45 = unpackHalf2x16(intBufferData.z); - // NEW ADDITION******************************************************* // Get spherical harmonics terms from int buffer. 48 coefficents per vertex. uint shTexStart = sortedIndex * 6u; ivec2 shTexSize = textureSize(shTextureBuffer, 0); @@ -132,7 +128,6 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( sh_coeffs_unpacked[i*8+6] = unpacked.x; sh_coeffs_unpacked[i*8+7] = unpacked.y; } - // END NEW ADDITION*************************************************** // Transition in. float startTime = 0.8 * float(sortedIndex) / float(numGaussians); @@ -168,25 +163,28 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( vec2 v1 = min(sqrt(2.0 * lambda1), 1024.0) * diagonalVector; vec2 v2 = min(sqrt(2.0 * lambda2), 1024.0) * vec2(diagonalVector.y, -diagonalVector.x); - // NEW ADDITION: SPHERICAL HARMONICS IMPLEMENTATION----- - // Calculate the spherical harmonics. + // According to gsplat implementation, seems that "x" and "y" have opposite direction + // of conventional SH directions, so square brackets contains the sign of resulting variable + // multiplications. + // A comprehensible table of Real SH constants: + // https://en.wikipedia.org/wiki/Table_of_spherical_harmonics vec3 viewDir = normalize(center - cameraPosition); - // // 0.5 * sqrt(1.0 / pi) + // C0 = 0.5 * sqrt(1.0 / pi) const float C0 = 0.28209479177387814; - // sqrt(3.0 / (4.0 * pi)) * [y] - // sqrt(3.0 / (4.0 * pi)) * [z] - // sqrt(3.0 / (4.0 * pi)) * [x] + // C1[0] = sqrt(3.0 / (4.0 * pi)) * [-1] + // C1[1] = sqrt(3.0 / (4.0 * pi)) * [1] + // C1[2] = sqrt(3.0 / (4.0 * pi)) * [-1] const float C1[3] = float[3]( -0.4886025119029199, 0.4886025119029199, -0.4886025119029199 ); - // 0.5 * sqrt(15/pi) * [x*y] - // 0.5 * sqrt(15/pi) * [y*z] - // 0.25 * sqrt(5/pi) * [3z^2 - 1] - // 0.5 * sqrt(15/pi) * [z*x] - // 0.25 * sqrt(15/pi) * [x^2 - y^2] + // C2[0] = 0.5 * sqrt(15/pi) * [1] + // C2[1] = 0.5 * sqrt(15/pi) * [-1] + // C2[2] = 0.25 * sqrt(5/pi) * [1] + // C2[3] = 0.5 * sqrt(15/pi) * [-1] + // C2[4] = 0.25 * sqrt(15/pi) * [1] const float C2[5] = float[5]( 1.0925484305920792, -1.0925484305920792, @@ -194,13 +192,13 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( -1.0925484305920792, 0.5462742152960396 ); - // 0.25 * sqrt(35/(2pi)) * [y*(3x^2-y^2)] - // 0.5 * sqrt(105/pi) * [x*y*z] - // 0.25 * sqrt(21/(2pi)) * [y*(5z^2-1)] - // 0.25 * sqrt(7/pi) * [z*(5z^2-3)] - // 0.25 * sqrt(21/(2pi)) * [x*(5z^2-1)] - // 0.25 * sqrt(105/(pi)) * [(x^2-y^2)*z] - // 0.25 * sqrt(35/(2pi)) * [x*(x^2-3y^2)] + // C3[0] = 0.25 * sqrt(35/(2pi)) * [-1] + // C3[1] = 0.5 * sqrt(105/pi) * [1] + // C3[2] = 0.25 * sqrt(21/(2pi)) * [-1] + // C3[3] = 0.25 * sqrt(7/pi) * [1] + // C3[4] = 0.25 * sqrt(21/(2pi)) * [-1] + // C3[5] = 0.25 * sqrt(105/(pi)) * [1] + // C3[6] = 0.25 * sqrt(35/(2pi)) * [-1] const float C3[7] = float[7]( -0.5900435899266435, 2.890611442640554, @@ -217,11 +215,6 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( } // View-dependent variables - // Along with SH coefficients, we represent view-dependent colors - // From Wikipedia definition: - // x = sin(theta) * cos(phi) - // y = sin(theta) * sin(phi) - // z = cos(theta) float x = viewDir.x; float y = viewDir.y; @@ -234,28 +227,24 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( float xz = viewDir.x * viewDir.z; // 0th degree - vec3 rgb = C0 * sh_coeffs[0]; // line 74 of plenoxels + vec3 rgb = C0 * sh_coeffs[0]; vec3 pointFive = vec3(0.5, 0.5, 0.5); - // 1st degree - + // From here, variables are included in multiplication with constants float pSH1 = C1[0] * y; float pSH2 = C1[1] * z; float pSH3 = C1[2] * x; - rgb = rgb + pSH1 * sh_coeffs[1] + pSH2 * sh_coeffs[2] + pSH3 * sh_coeffs[3]; // 2nd degree - float pSH4 = C2[0] * xy; float pSH5 = C2[1] * yz; float pSH6 = C2[2] * (3.0 * zz - 1.0); float pSH7 = C2[3] * xz; float pSH8 = C2[4] * (xx - yy); - rgb = rgb + pSH4 * sh_coeffs[4] + pSH5 * sh_coeffs[5] + pSH6 * sh_coeffs[6] + @@ -263,7 +252,6 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( pSH8 * sh_coeffs[8]; // 3rd degree - float pSH9 = C3[0] * y * (3.0 * xx - yy); float pSH10 = C3[1] * x * y * z; float pSH11 = C3[2] * y * (5.0 * zz - 1.0); @@ -271,7 +259,6 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( float pSH13 = C3[4] * x * (5.0 * zz - 1.0); float pSH14 = C3[5] * (xx - yy) * z; float pSH15 = C3[6] * x * (xx - 3.0 * yy); - rgb = rgb + pSH9 * sh_coeffs[9] + pSH10 * sh_coeffs[10] + pSH11 * sh_coeffs[11] + @@ -280,16 +267,8 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( pSH14 * sh_coeffs[14] + pSH15 * sh_coeffs[15]; + // Finalize the color vRgba = vec4(rgb + pointFive, float(rgbaUint32 >> uint(24)) / 255.0); - - // // DEBUGGING USING INPUT RGB INSTEAD OF SH - // vRgba = vec4( - // float(rgbaUint32 & uint(0xFF)) / 255.0, - // float((rgbaUint32 >> uint(8)) & uint(0xFF)) / 255.0, - // float((rgbaUint32 >> uint(16)) & uint(0xFF)) / 255.0, - // float(rgbaUint32 >> uint(24)) / 255.0 - // ); - // NEW SPHERICAL HARMONICS IMPLEMENTATION END----- // Throw the Gaussian off the screen if it's too close, too far, or too small. float weightedDeterminant = vRgba.a * (diag1 * diag2 - offDiag * offDiag); @@ -323,8 +302,7 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( /**Hook to generate properties for rendering Gaussians via a three.js mesh.*/ export function useGaussianMeshProps( gaussianBuffer: Uint32Array, - combinedSHBuffer: Uint32Array, // NEW ADDITION******************************************************* - combinedNormBuffer: Uint32Array, // NEW ADDITION*************************************************** + combinedSHBuffer: Uint32Array, numGroups: number, ) { const numGaussians = gaussianBuffer.length / 8; @@ -400,28 +378,13 @@ export function useGaussianMeshProps( shTextureBuffer.internalFormat = "RGBA32UI"; shTextureBuffer.needsUpdate = true; - const normTexturwWidth = Math.min(numGaussians, maxTextureSize); - const normTextureHeight = Math.ceil((numGaussians) / normTexturwWidth); - const normBufferPadded = new Uint32Array(normTexturwWidth * normTextureHeight * 4); - normBufferPadded.set(combinedNormBuffer); - const normTextureBuffer = new THREE.DataTexture( - normBufferPadded, - normTexturwWidth, - normTextureHeight, - THREE.RGBAIntegerFormat, - THREE.UnsignedIntType, - ); - normTextureBuffer.internalFormat = "RGBA32UI"; - normTextureBuffer.needsUpdate = true; - // END NEW ADDITION*************************************************** const material = new GaussianSplatMaterial({ // @ts-ignore textureBuffer: textureBuffer, - shTextureBuffer: shTextureBuffer, // NEW ADDITION******************************************************* - normTextureBuffer: normTextureBuffer, // NEW ADDITION*************************************************** + shTextureBuffer: shTextureBuffer, textureT_camera_groups: textureT_camera_groups, numGaussians: 0, transitionInState: 0.0, @@ -431,8 +394,7 @@ export function useGaussianMeshProps( geometry, material, textureBuffer, - shTextureBuffer, // NEW ADDITION******************************************************* - normTextureBuffer, // NEW ADDITION*************************************************** + shTextureBuffer, sortedIndexAttribute, textureT_camera_groups, rowMajorT_camera_groups, diff --git a/src/viser/client/src/WebsocketMessages.ts b/src/viser/client/src/WebsocketMessages.ts index 28b96d0b5..b7c23496a 100644 --- a/src/viser/client/src/WebsocketMessages.ts +++ b/src/viser/client/src/WebsocketMessages.ts @@ -321,8 +321,7 @@ export interface GaussianSplatsMessage { name: string; props: { buffer: Uint8Array; - sh_buffer: Uint8Array; - norm_buffer: Uint8Array; + sh_buffer: Uint8Array; }; } /** Remove a particular node from the scene. From 81fbb44d7aaa8a4de2bf5dc9c45de8a502688ea1 Mon Sep 17 00:00:00 2001 From: Christian Le Date: Tue, 18 Feb 2025 14:11:15 +0100 Subject: [PATCH 14/14] minor tweaks --- src/viser/_scene_api.py | 8 ++++---- src/viser/client/src/Splatting/GaussianSplats.tsx | 4 ---- src/viser/client/src/Splatting/GaussianSplatsHelpers.ts | 7 +------ 3 files changed, 5 insertions(+), 14 deletions(-) diff --git a/src/viser/_scene_api.py b/src/viser/_scene_api.py index 93ccdef14..1de5a6990 100644 --- a/src/viser/_scene_api.py +++ b/src/viser/_scene_api.py @@ -1310,7 +1310,6 @@ def add_gaussian_splats( covariances: Second moment for each Gaussian. (N, 3, 3). rgbs: Color for each Gaussian. (N, 3). opacities: Opacity for each Gaussian. (N, 1). - normals: Normals for each Gaussian. (N, 3). sh_coeffs: Spherical harmonics coefficients for each Gaussian. (N, 48). wxyz: R_parent_local transformation. position: t_parent_local transformation. @@ -1327,7 +1326,6 @@ def add_gaussian_splats( # Get upper-triangular terms of covariance matrix. cov_triu = covariances.reshape((-1, 9))[:, np.array([0, 1, 2, 4, 5, 8])] - buffer = np.concatenate( [ # First texelFetch. @@ -1354,6 +1352,10 @@ def add_gaussian_splats( ], ).view(np.uint32) else: + # To ensure backwards compatibility, we'll compute SH coefficients from + # the RGB values. + # However, this is not efficient as packets sent to client + # will be larger. TODO: Permanently incorporate colors with SH coefficients. SH_C0 = 0.28209479177387814 dc_coeffs = (rgbs - 0.5) / SH_C0 sh_buffer = np.concatenate( @@ -1363,8 +1365,6 @@ def add_gaussian_splats( axis=1, dtype=np.float16, ).view(np.uint32) - # sh_buffer = np.zeros((num_gaussians, 48), dtype=np.float16).view(np.uint32) - message = _messages.GaussianSplatsMessage( name=name, diff --git a/src/viser/client/src/Splatting/GaussianSplats.tsx b/src/viser/client/src/Splatting/GaussianSplats.tsx index 75529bd2f..6ab8462e9 100644 --- a/src/viser/client/src/Splatting/GaussianSplats.tsx +++ b/src/viser/client/src/Splatting/GaussianSplats.tsx @@ -75,7 +75,6 @@ export const SplatObject = React.forwardRef< (state) => state.nodeRefFromId, ); const name = React.useMemo(() => uuidv4(), [buffer]); - // Inspired from PR, maybe should look more similar to the original code above const sh_buffer_name = `sh_buffer_${name}`; React.useEffect(() => { @@ -349,9 +348,6 @@ function mergeGaussianGroups(groupBufferFromName: { ([key]) => !key.startsWith("sh_buffer_") ) ); - - // Temporarily using groupBufferFromName instead of groupBufferFromNameFiltered - // because it causes an error in the shader for (const buffer of Object.values(groupBufferFromNameFiltered)) { totalBufferLength += buffer.length; } diff --git a/src/viser/client/src/Splatting/GaussianSplatsHelpers.ts b/src/viser/client/src/Splatting/GaussianSplatsHelpers.ts index 2d1e672ab..b6fb216a4 100644 --- a/src/viser/client/src/Splatting/GaussianSplatsHelpers.ts +++ b/src/viser/client/src/Splatting/GaussianSplatsHelpers.ts @@ -139,7 +139,6 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( triu01.y, triu23.y, triu45.x, triu23.x, triu45.x, triu45.y ); - mat3 J = mat3( // Matrices are column-major. focal.x / c_cam.z, 0., 0.0, @@ -272,7 +271,7 @@ const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( // Throw the Gaussian off the screen if it's too close, too far, or too small. float weightedDeterminant = vRgba.a * (diag1 * diag2 - offDiag * offDiag); - if (weightedDeterminant < 0.5) + if (weightedDeterminant < 0.25) return; vPosition = position.xy; @@ -359,7 +358,6 @@ export function useGaussianMeshProps( textureT_camera_groups.internalFormat = "RGBA32F"; textureT_camera_groups.needsUpdate = true; - // NEW ADDITION******************************************************* // Values taken from PR https://github.com/nerfstudio-project/viser/pull/286/files // WIDTH AND HEIGHT ARE MEASURED IN TEXELS // As 48 x float16 = 96 bytes and each texel is 4 uint32s = 16 bytes @@ -378,9 +376,6 @@ export function useGaussianMeshProps( shTextureBuffer.internalFormat = "RGBA32UI"; shTextureBuffer.needsUpdate = true; - // END NEW ADDITION*************************************************** - - const material = new GaussianSplatMaterial({ // @ts-ignore textureBuffer: textureBuffer,