diff --git a/examples/experimental/gaussian_splats.py b/examples/experimental/gaussian_splats.py index 513a151f1..38d911738 100644 --- a/examples/experimental/gaussian_splats.py +++ b/examples/experimental/gaussian_splats.py @@ -82,8 +82,30 @@ def load_ply_file(ply_file_path: Path, center: bool = False) -> SplatFile: positions = onp.stack([v["x"], v["y"], v["z"]], axis=-1) scales = onp.exp(onp.stack([v["scale_0"], v["scale_1"], v["scale_2"]], axis=-1)) wxyzs = onp.stack([v["rot_0"], v["rot_1"], v["rot_2"], v["rot_3"]], axis=1) - colors = 0.5 + SH_C0 * onp.stack([v["f_dc_0"], v["f_dc_1"], v["f_dc_2"]], axis=1) + colors = 0.5 + SH_C0 * onp.stack([v["f_dc_0"], v["f_dc_1"], v["f_dc_2"]], axis=1) + # print(v["f_dc_0"].shape) # prints (numGaussians) + # print(colors.shape) # prints (numGaussians, 3) opacities = 1.0 / (1.0 + onp.exp(-v["opacity"][:, None])) + + # Load all zero order SH coefficients + dc_terms = onp.stack([v["f_dc_0"], v["f_dc_1"], v["f_dc_2"]], axis=1) + + # Load higher order SH coefficients (f_rest_0, ... f_rest_44), which are either level 1 or higher + # Note: .ply file supports maximum SH degree of 3, R = f_rest_0, ... f_rest_14; G = f_rest_15, ... f_rest_29 + rest_terms = [] + i = 0 + #while f"f_rest_{i}" in v: + while i < 15: + rest_terms.append(v[f"f_rest_{i}"]) # has shape (numGaussians, ) + rest_terms.append(v[f"f_rest_{15 + i}"]) + rest_terms.append(v[f"f_rest_{30 + i}"]) + i += 1 + # while f"f_rest_{i}" in v: + # rest_terms.append(v[f"f_rest_{i}"]) + # i += 1 + if len(rest_terms) > 0: # if we do have higher than zero order SH, we will process them and add them here. + sh_coeffs = onp.stack([v["f_dc_0"], v["f_dc_1"], v["f_dc_2"]] + rest_terms, axis=1) + sh_degree = int(onp.sqrt(sh_coeffs.shape[1] // 3) - 1) Rs = tf.SO3(wxyzs).as_matrix() covariances = onp.einsum( @@ -93,19 +115,42 @@ def load_ply_file(ply_file_path: Path, center: bool = False) -> SplatFile: positions -= onp.mean(positions, axis=0, keepdims=True) num_gaussians = len(v) + + # print(sh_coeffs.shape) # prints (447703, 48) + # print(v["x"].shape) # prints (447703,) + # print(positions.shape) + # print(colors.shape) + # print(covariances.shape) + print( f"PLY file with {num_gaussians=} loaded in {time.time() - start_time} seconds" ) + print(onp.stack([v["f_dc_0"], v["f_dc_1"], v["f_dc_2"]], axis=1)[0 , :]) + print(sh_coeffs[0, :]) # first gaussian, all 48 coefficients + # print(sh_coeffs[0, :]) # next 3 SH coefficients that are the 1st order + + # return { + # "centers": positions[0:1, :], + # "rgbs": colors[0:1, :], + # "opacities": opacities[0:1, :], + # "covariances": 10000*covariances[0:1, :], + # "sh_degree": sh_degree, + # "sh_coeffs": sh_coeffs[0:1, :], + # } + return { "centers": positions, "rgbs": colors, "opacities": opacities, "covariances": covariances, + "sh_degree": sh_degree, + "sh_coeffs": sh_coeffs, } def main(splat_paths: tuple[Path, ...]) -> None: server = viser.ViserServer() + print(server.request_share_url()) server.gui.configure_theme(dark_mode=True) gui_reset_up = server.gui.add_button( "Reset up direction", @@ -135,6 +180,8 @@ def _(event: viser.GuiEvent) -> None: rgbs=splat_data["rgbs"], opacities=splat_data["opacities"], covariances=splat_data["covariances"], + sh_degree=splat_data["sh_degree"], + sh_coeffs=splat_data["sh_coeffs"], ) remove_button = server.gui.add_button(f"Remove splat object {i}") @@ -150,3 +197,7 @@ def _(_, gs_handle=gs_handle, remove_button=remove_button) -> None: if __name__ == "__main__": tyro.cli(main) + + + +print("yapyap") \ No newline at end of file diff --git a/src/viser/_messages.py b/src/viser/_messages.py index ed01f4c29..43ac1eecd 100644 --- a/src/viser/_messages.py +++ b/src/viser/_messages.py @@ -761,6 +761,14 @@ class GaussianSplatsMessage(Message): - cov5 (f16), cov6 (f16) - rgba (int32) Where cov1-6 are the upper triangular elements of the covariance matrix.""" + + sh_buffer: onpt.NDArray[onp.uint32] + """The spherical harmonic buffer contains: + - + - + - + """ + @dataclasses.dataclass diff --git a/src/viser/_scene_api.py b/src/viser/_scene_api.py index 649abae94..2c77448c8 100644 --- a/src/viser/_scene_api.py +++ b/src/viser/_scene_api.py @@ -942,6 +942,8 @@ def _add_gaussian_splats( wxyz: Tuple[float, float, float, float] | onp.ndarray = (1.0, 0.0, 0.0, 0.0), position: Tuple[float, float, float] | onp.ndarray = (0.0, 0.0, 0.0), visible: bool = True, + sh_degree: int = 0, + sh_coeffs: onp.ndarray = None, ) -> GaussianSplatHandle: """Add a model to render using Gaussian Splatting. @@ -990,12 +992,24 @@ def _add_gaussian_splats( ], axis=-1, ).view(onp.uint32) - assert buffer.shape == (num_gaussians, 8) + + assert buffer.shape == (num_gaussians, 8) + + # We have 48 float32 spherical coeffecients per gaussian. + # However, by converting them to float16 we now have 48 float16 values per gaussian + # This means, each cell of sh_buffer contains 2 spherical coefficients because each cell is 32bits + + # - (768 bits): spherical harmonics + print("sh_coeffs.shape", sh_coeffs.shape) + sh_buffer = (sh_coeffs.astype(onp.float16)).view(onp.uint32) + print("sh_buffer.shape", sh_buffer.shape) # has shape (num_gaussians, 24), each cell contains 2 spherical coeff. + # print(sh_buffer) self._websock_interface.queue_message( _messages.GaussianSplatsMessage( name=name, buffer=buffer, + sh_buffer=sh_buffer, ) ) node_handle = GaussianSplatHandle._make(self, name, wxyz, position, visible) diff --git a/src/viser/client/.yarn/install-state.gz b/src/viser/client/.yarn/install-state.gz new file mode 100644 index 000000000..ca619cdb9 Binary files /dev/null and b/src/viser/client/.yarn/install-state.gz differ diff --git a/src/viser/client/.yarnrc.yml b/src/viser/client/.yarnrc.yml new file mode 100644 index 000000000..3186f3f07 --- /dev/null +++ b/src/viser/client/.yarnrc.yml @@ -0,0 +1 @@ +nodeLinker: node-modules diff --git a/src/viser/client/src/MessageHandler.tsx b/src/viser/client/src/MessageHandler.tsx index 0c10c0fe4..67878c2f4 100644 --- a/src/viser/client/src/MessageHandler.tsx +++ b/src/viser/client/src/MessageHandler.tsx @@ -996,6 +996,14 @@ function useMessageHandler() { ), ) } + sh_buffer={ + new Uint32Array( + message.sh_buffer.buffer.slice( + message.sh_buffer.byteOffset, + message.sh_buffer.byteOffset + message.sh_buffer.byteLength, + ), + ) + } /> ); }), diff --git a/src/viser/client/src/Splatting/GaussianSplats.tsx b/src/viser/client/src/Splatting/GaussianSplats.tsx index 4d42313aa..d04e68f52 100644 --- a/src/viser/client/src/Splatting/GaussianSplats.tsx +++ b/src/viser/client/src/Splatting/GaussianSplats.tsx @@ -83,184 +83,295 @@ export function SplatRenderContext({ } const GaussianSplatMaterial = /* @__PURE__ */ shaderMaterial( - { - numGaussians: 0, - focal: 100.0, - viewport: [640, 480], - near: 1.0, - far: 100.0, - depthTest: true, - depthWrite: false, - transparent: true, - textureBuffer: null, - textureT_camera_groups: null, - transitionInState: 0.0, - }, - `precision highp usampler2D; // Most important: ints must be 32-bit. - precision mediump float; - - // Index from the splat sorter. - attribute uint sortedIndex; - - // Buffers for splat data; each Gaussian gets 4 floats and 4 int32s. We just - // copy quadjr for this. - uniform usampler2D textureBuffer; - - // We could also use a uniform to store transforms, but this would be more - // limiting in terms of the # of groups we can have. - uniform sampler2D textureT_camera_groups; - - // Various other uniforms... - uniform uint numGaussians; - uniform vec2 focal; - uniform vec2 viewport; - uniform float near; - uniform float far; - - // Fade in state between [0, 1]. - uniform float transitionInState; - - out vec4 vRgba; - out vec2 vPosition; - - // Function to fetch and construct the i-th transform matrix using texelFetch - mat4 getGroupTransform(uint i) { - // Calculate the base index for the i-th transform. - uint baseIndex = i * 3u; - - // Fetch the texels that represent the first 3 rows of the transform. We - // choose to use row-major here, since it lets us exclude the fourth row of - // the matrix. - vec4 row0 = texelFetch(textureT_camera_groups, ivec2(baseIndex + 0u, 0), 0); - vec4 row1 = texelFetch(textureT_camera_groups, ivec2(baseIndex + 1u, 0), 0); - vec4 row2 = texelFetch(textureT_camera_groups, ivec2(baseIndex + 2u, 0), 0); - - // Construct the mat4 with the fetched rows. - mat4 transform = mat4(row0, row1, row2, vec4(0.0, 0.0, 0.0, 1.0)); - return transpose(transform); - } - - void main () { - // Get position + scale from float buffer. - ivec2 texSize = textureSize(textureBuffer, 0); - uint texStart = sortedIndex << 1u; - ivec2 texPos0 = ivec2(texStart % uint(texSize.x), texStart / uint(texSize.x)); - - - // Fetch from textures. - uvec4 floatBufferData = texelFetch(textureBuffer, texPos0, 0); - mat4 T_camera_group = getGroupTransform(floatBufferData.w); - - // Any early return will discard the fragment. - gl_Position = vec4(0.0, 0.0, 2.0, 1.0); - - // Get center wrt camera. modelViewMatrix is T_cam_world. - vec3 center = uintBitsToFloat(floatBufferData.xyz); - vec4 c_cam = T_camera_group * vec4(center, 1); - if (-c_cam.z < near || -c_cam.z > far) - return; - vec4 pos2d = projectionMatrix * c_cam; - float clip = 1.1 * pos2d.w; - if (pos2d.x < -clip || pos2d.x > clip || pos2d.y < -clip || pos2d.y > clip) - return; - - // Read covariance terms. - ivec2 texPos1 = ivec2((texStart + 1u) % uint(texSize.x), (texStart + 1u) / uint(texSize.x)); - uvec4 intBufferData = texelFetch(textureBuffer, texPos1, 0); - - // Get covariance terms from int buffer. - uint rgbaUint32 = intBufferData.w; - vec2 chol01 = unpackHalf2x16(intBufferData.x); - vec2 chol23 = unpackHalf2x16(intBufferData.y); - vec2 chol45 = unpackHalf2x16(intBufferData.z); - - // Transition in. - float startTime = 0.8 * float(sortedIndex) / float(numGaussians); - float cov_scale = smoothstep(startTime, startTime + 0.2, transitionInState); - - // Do the actual splatting. - mat3 chol = mat3( - chol01.x, chol01.y, chol23.x, - 0., chol23.y, chol45.x, - 0., 0., chol45.y - ); - mat3 cov3d = chol * transpose(chol) * cov_scale; - mat3 J = mat3( - // Matrices are column-major. - focal.x / c_cam.z, 0., 0.0, - 0., focal.y / c_cam.z, 0.0, - -(focal.x * c_cam.x) / (c_cam.z * c_cam.z), -(focal.y * c_cam.y) / (c_cam.z * c_cam.z), 0. - ); - mat3 A = J * mat3(T_camera_group); - mat3 cov_proj = A * cov3d * transpose(A); - float diag1 = cov_proj[0][0] + 0.3; - float offDiag = cov_proj[0][1]; - float diag2 = cov_proj[1][1] + 0.3; - - // Eigendecomposition. - float mid = 0.5 * (diag1 + diag2); - float radius = length(vec2((diag1 - diag2) / 2.0, offDiag)); - float lambda1 = mid + radius; - float lambda2 = mid - radius; - if (lambda2 < 0.0) - return; - vec2 diagonalVector = normalize(vec2(offDiag, lambda1 - diag1)); - vec2 v1 = min(sqrt(2.0 * lambda1), 1024.0) * diagonalVector; - vec2 v2 = min(sqrt(2.0 * lambda2), 1024.0) * vec2(diagonalVector.y, -diagonalVector.x); - - vRgba = vec4( - float(rgbaUint32 & uint(0xFF)) / 255.0, - float((rgbaUint32 >> uint(8)) & uint(0xFF)) / 255.0, - float((rgbaUint32 >> uint(16)) & uint(0xFF)) / 255.0, - float(rgbaUint32 >> uint(24)) / 255.0 - ); - - // Throw the Gaussian off the screen if it's too close, too far, or too small. - float weightedDeterminant = vRgba.a * (diag1 * diag2 - offDiag * offDiag); - if (weightedDeterminant < 0.5) - return; - vPosition = position.xy; - - gl_Position = vec4( - vec2(pos2d) / pos2d.w - + position.x * v1 / viewport * 2.0 - + position.y * v2 / viewport * 2.0, pos2d.z / pos2d.w, 1.); - } -`, - `precision mediump float; - - uniform vec2 viewport; - uniform vec2 focal; - - in vec4 vRgba; - in vec2 vPosition; - - void main () { - float A = -dot(vPosition, vPosition); - if (A < -4.0) discard; - float B = exp(A) * vRgba.a; - if (B < 0.01) discard; // alphaTest. - gl_FragColor = vec4(vRgba.rgb, B); - }`, -); + { + numGaussians: 0, + focal: 100.0, + viewport: [640, 480], + near: 1.0, + far: 100.0, + depthTest: true, + depthWrite: false, + transparent: true, + sh_degree: 0, + textureBuffer: null, + shTextureBuffer: null, + textureT_camera_groups: null, + transitionInState: 0.0, + }, + `precision highp usampler2D; // Most important: ints must be 32-bit. + precision mediump float; + + // Index from the splat sorter. + attribute uint sortedIndex; + + // Buffers for splat data; each Gaussian gets 4 floats and 4 int32s. We just + // copy quadjr for this. + uniform usampler2D textureBuffer; + + // Buffer for spherical harmonics; each Gaussian also gets 24 int32s representing this information. + uniform usampler2D shTextureBuffer; + + // We could also use a uniform to store transforms, but this would be more + // limiting in terms of the # of groups we can have. + uniform sampler2D textureT_camera_groups; + + // Various other uniforms... + uniform uint numGaussians; + uniform vec2 focal; + uniform vec2 viewport; + uniform float near; + uniform float far; + uniform uint sh_degree; + + // Fade in state between [0, 1]. + uniform float transitionInState; + + out vec4 vRgba; + out vec2 vPosition; + + // Function to fetch and construct the i-th transform matrix using texelFetch + mat4 getGroupTransform(uint i) { + // Calculate the base index for the i-th transform. + uint baseIndex = i * 3u; + + // Fetch the texels that represent the first 3 rows of the transform. We + // choose to use row-major here, since it lets us exclude the fourth row of + // the matrix. + vec4 row0 = texelFetch(textureT_camera_groups, ivec2(baseIndex + 0u, 0), 0); + vec4 row1 = texelFetch(textureT_camera_groups, ivec2(baseIndex + 1u, 0), 0); + vec4 row2 = texelFetch(textureT_camera_groups, ivec2(baseIndex + 2u, 0), 0); + + // Construct the mat4 with the fetched rows. + mat4 transform = mat4(row0, row1, row2, vec4(0.0, 0.0, 0.0, 1.0)); + return transpose(transform); + } + + void main () { + // Get position + scale from float buffer. + ivec2 texSize = textureSize(textureBuffer, 0); + uint texStart = sortedIndex << 1u; + ivec2 texPos0 = ivec2(texStart % uint(texSize.x), texStart / uint(texSize.x)); + + + // Fetch from textures. + uvec4 floatBufferData = texelFetch(textureBuffer, texPos0, 0); + mat4 T_camera_group = getGroupTransform(floatBufferData.w); + + // Any early return will discard the fragment. + gl_Position = vec4(0.0, 0.0, 2.0, 1.0); + + // Get center wrt camera. modelViewMatrix is T_cam_world. + vec3 center = uintBitsToFloat(floatBufferData.xyz); + vec4 c_cam = T_camera_group * vec4(center, 1); + if (-c_cam.z < near || -c_cam.z > far) + return; + vec4 pos2d = projectionMatrix * c_cam; + float clip = 1.1 * pos2d.w; + if (pos2d.x < -clip || pos2d.x > clip || pos2d.y < -clip || pos2d.y > clip) + return; + + // Read covariance terms. + ivec2 texPos1 = ivec2((texStart + 1u) % uint(texSize.x), (texStart + 1u) / uint(texSize.x)); + uvec4 intBufferData = texelFetch(textureBuffer, texPos1, 0); + + // Get covariance terms from int buffer. + uint rgbaUint32 = intBufferData.w; + vec2 chol01 = unpackHalf2x16(intBufferData.x); + vec2 chol23 = unpackHalf2x16(intBufferData.y); + vec2 chol45 = unpackHalf2x16(intBufferData.z); + + // Get spherical harmonic terms from the buffer, there are 48 coeffecients per vertex + uint shTexStart = sortedIndex * 6u; + ivec2 shTexSize = textureSize(shTextureBuffer, 0); + float sh_coeffs_unpacked[48]; + for (int i = 0; i < 6; i++) { + ivec2 shTexPos = ivec2((shTexStart + uint(i)) % uint(shTexSize.x), + (shTexStart + uint(i)) / uint(shTexSize.x)); + uvec4 packedCoeffs = texelFetch(shTextureBuffer, shTexPos, 0); + + // unpack each uint32 directly into two float16 values, we read 4 at a time + vec2 unpacked; + unpacked = unpackHalf2x16(packedCoeffs.x); + sh_coeffs_unpacked[i*8] = unpacked.x; + sh_coeffs_unpacked[i*8+1] = unpacked.y; + + unpacked = unpackHalf2x16(packedCoeffs.y); + sh_coeffs_unpacked[i*8+2] = unpacked.x; + sh_coeffs_unpacked[i*8+3] = unpacked.y; + + unpacked = unpackHalf2x16(packedCoeffs.z); + sh_coeffs_unpacked[i*8+4] = unpacked.x; + sh_coeffs_unpacked[i*8+5] = unpacked.y; + + unpacked = unpackHalf2x16(packedCoeffs.w); + sh_coeffs_unpacked[i*8+6] = unpacked.x; + sh_coeffs_unpacked[i*8+7] = unpacked.y; + } + + // Transition in. + float startTime = 0.8 * float(sortedIndex) / float(numGaussians); + float cov_scale = smoothstep(startTime, startTime + 0.2, transitionInState); + + // Do the actual splatting. + mat3 chol = mat3( + chol01.x, chol01.y, chol23.x, + 0., chol23.y, chol45.x, + 0., 0., chol45.y + ); + mat3 cov3d = chol * transpose(chol) * cov_scale; + mat3 J = mat3( + // Matrices are column-major. + focal.x / c_cam.z, 0., 0.0, + 0., focal.y / c_cam.z, 0.0, + -(focal.x * c_cam.x) / (c_cam.z * c_cam.z), -(focal.y * c_cam.y) / (c_cam.z * c_cam.z), 0. + ); + mat3 A = J * mat3(T_camera_group); + mat3 cov_proj = A * cov3d * transpose(A); + float diag1 = cov_proj[0][0] + 0.3; + float offDiag = cov_proj[0][1]; + float diag2 = cov_proj[1][1] + 0.3; + + // Eigendecomposition. + float mid = 0.5 * (diag1 + diag2); + float radius = length(vec2((diag1 - diag2) / 2.0, offDiag)); + float lambda1 = mid + radius; + float lambda2 = mid - radius; + if (lambda2 < 0.0) + return; + vec2 diagonalVector = normalize(vec2(offDiag, lambda1 - diag1)); + vec2 v1 = min(sqrt(2.0 * lambda1), 1024.0) * diagonalVector; + vec2 v2 = min(sqrt(2.0 * lambda2), 1024.0) * vec2(diagonalVector.y, -diagonalVector.x); + + // Get the spherical harmonic view direction in world coordinates and calculate color + // vec3 viewDir = normalize(center - cameraPosition); + vec3 t_group_camera = -(transpose(mat3(T_camera_group)) * T_camera_group[3].xyz); + vec3 viewDir = normalize(center - t_group_camera); + const float C0 = 0.28209479177387814; + const float C1 = 0.4886025119029199; + const float C2[5] = float[5]( + 1.0925484305920792, + -1.0925484305920792, + 0.31539156525252005, + -1.0925484305920792, + 0.5462742152960396 + ); + const float C3[7] = float[7]( + -0.5900435899266435, + 2.890611442640554, + -0.4570457994644658, + 0.3731763325901154, + -0.4570457994644658, + 1.445305721320277, + -0.5900435899266435 + ); + + vec3 sh_coeffs[16]; + for (int i = 0; i < 16; i++) { + sh_coeffs[i] = vec3(sh_coeffs_unpacked[i*3], sh_coeffs_unpacked[i*3+1], sh_coeffs_unpacked[i*3+2]); + } + float x = viewDir.x; + float y = viewDir.y; + float z = viewDir.z; + float xx = viewDir.x * viewDir.x; + float yy = viewDir.y * viewDir.y; + float zz = viewDir.z * viewDir.z; + float xy = viewDir.x * viewDir.y; + float yz = viewDir.y * viewDir.z; + float xz = viewDir.x * viewDir.z; + + + vec3 rgb = C0 * sh_coeffs[0]; // line 74 of plenoxels + vec3 pointFive = vec3(0.5, 0.5, 0.5); + + // ----GSPLAT IMPLEMENTATION----- + + // degree 1 + rgb = rgb + C1 * (-y * sh_coeffs[1] + + z * sh_coeffs[2] - + x * sh_coeffs[3]); + // degree 2 + + float fTmp0B = -1.092548430592079 * z; + float fC1 = xx - yy; + float fS1 = 2.0 * xy; + float pSH6 = (0.9461746957575601 * zz - 0.3153915652525201); + float pSH7 = fTmp0B * x; + float pSH5 = fTmp0B * y; + float pSH8 = 0.5462742152960395 * fC1; + float pSH4 = 0.5462742152960395 * fS1; + rgb = rgb + pSH4 * sh_coeffs[4] + pSH5 * sh_coeffs[5] + + pSH6 * sh_coeffs[6] + pSH7 * sh_coeffs[7] + + pSH8 * sh_coeffs[8]; + + // degree 3 + float fTmp0C = -2.285228997322329f * zz + 0.4570457994644658f; + float fTmp1B = 1.445305721320277 * z; + float fC2 = x * fC1 - y * fS1; + float fS2 = x * fS1 + y * fC1; + float pSH12 = z * (1.865881662950577 * zz - 1.119528997770346f); + float pSH13 = fTmp0C * x; + float pSH11 = fTmp0C * y; + float pSH14 = fTmp1B * fC1; + float pSH10 = fTmp1B * fS1; + float pSH15 = -0.5900435899266435 * fC2; + float pSH9 = -0.5900435899266435 * fS2; + + rgb = rgb + pSH9 * sh_coeffs[9] + pSH10 * sh_coeffs[10] + + pSH11 * sh_coeffs[11] + pSH12 * sh_coeffs[12] + + pSH13 * sh_coeffs[13] + pSH14 * sh_coeffs[14]+ + pSH15 * sh_coeffs[15]; + + vRgba = vec4(rgb + pointFive, float(rgbaUint32 >> uint(24)) / 255.0); + + // Throw the Gaussian off the screen if it's too close, too far, or too small. + float weightedDeterminant = vRgba.a * (diag1 * diag2 - offDiag * offDiag); + if (weightedDeterminant < 0.5) + return; + vPosition = position.xy; + + gl_Position = vec4( + vec2(pos2d) / pos2d.w + + position.x * v1 / viewport * 2.0 + + position.y * v2 / viewport * 2.0, pos2d.z / pos2d.w, 1.); + } + `, + `precision mediump float; + + uniform vec2 viewport; + uniform vec2 focal; + + in vec4 vRgba; + in vec2 vPosition; + + void main () { + float A = -dot(vPosition, vPosition); + if (A < -4.0) discard; + float B = exp(A) * vRgba.a; + if (B < 0.01) discard; // alphaTest. + gl_FragColor = vec4(vRgba.rgb, B); + }`, + ); export const SplatObject = React.forwardRef< THREE.Group, { buffer: Uint32Array; + sh_buffer: Uint32Array; } ->(function SplatObject({ buffer }, ref) { +>(function SplatObject({ buffer, sh_buffer }, ref) { const splatContext = React.useContext(GaussianSplatsContext)!; const setBuffer = splatContext((state) => state.setBuffer); const removeBuffer = splatContext((state) => state.removeBuffer); const nodeRefFromId = splatContext((state) => state.nodeRefFromId); const name = React.useMemo(() => crypto.randomUUID(), [buffer]); + const sh_buffer_name = "sh_buffer_" + name const [obj, setRef] = React.useState(null); React.useEffect(() => { if (obj === null) return; setBuffer(name, buffer); + setBuffer(sh_buffer_name, sh_buffer) if (ref !== null) { if ("current" in ref) { ref.current = obj; @@ -269,8 +380,9 @@ export const SplatObject = React.forwardRef< } } nodeRefFromId.current[name] = obj; - return () => { + return () => { // this is the cleanup function that is triggered on every re-render (every time obj is changed) removeBuffer(name); + removeBuffer(sh_buffer_name) delete nodeRefFromId.current[name]; }; }, [obj]); @@ -288,6 +400,7 @@ function SplatRenderer() { const merged = mergeGaussianGroups(groupBufferFromId); const meshProps = useGaussianMeshProps( merged.gaussianBuffer, + merged.combinedSHBuffer, merged.numGroups, ); @@ -304,6 +417,7 @@ function SplatRenderer() { if (!initializedBufferTexture) { meshProps.material.uniforms.numGaussians.value = merged.numGaussians; meshProps.textureBuffer.needsUpdate = true; + meshProps.shTextureBuffer.needsUpdate = true; initializedBufferTexture = true; } }; @@ -320,6 +434,7 @@ function SplatRenderer() { React.useEffect(() => { return () => { meshProps.textureBuffer.dispose(); + meshProps.shTextureBuffer.dispose(); meshProps.geometry.dispose(); meshProps.material.dispose(); postToWorker({ close: true }); @@ -446,16 +561,21 @@ function mergeGaussianGroups(groupBufferFromName: { }) { // Create geometry. Each Gaussian will be rendered as a quad. let totalBufferLength = 0; - for (const buffer of Object.values(groupBufferFromName)) { + const groupBufferFromNameFiltered = Object.fromEntries( + Object.entries(groupBufferFromName).filter(([key]) => !key.startsWith("sh_buffer_")) + ); + // groupBufferFromNameFiltered only contains buffers for the gaussians (not sh buffers) + for (const buffer of Object.values(groupBufferFromNameFiltered)) { totalBufferLength += buffer.length; } const numGaussians = totalBufferLength / 8; + // console.log(numGaussians) # this is correct! it logged 447703 const gaussianBuffer = new Uint32Array(totalBufferLength); const groupIndices = new Uint32Array(numGaussians); let offset = 0; for (const [groupIndex, groupBuffer] of Object.values( - groupBufferFromName, + groupBufferFromNameFiltered, ).entries()) { groupIndices.fill( groupIndex, @@ -476,12 +596,28 @@ function mergeGaussianGroups(groupBufferFromName: { offset += groupBuffer.length; } - const numGroups = Object.keys(groupBufferFromName).length; - return { numGaussians, gaussianBuffer, numGroups, groupIndices }; + // Consolidate the spherical harmonic coefficient buffers + let totalSHBufferLength = 0; + const shGaussianBuffers = Object.fromEntries( + Object.entries(groupBufferFromName).filter(([key]) => key.startsWith("sh_buffer_")) + ); + for (const sh_buffer of Object.values(shGaussianBuffers)) { + totalSHBufferLength += sh_buffer.length; + } + const combinedSHBuffer = new Uint32Array(totalSHBufferLength); + let sh_offset = 0 + for (const sh_buffer of Object.values(shGaussianBuffers)) { + combinedSHBuffer.set(sh_buffer, sh_offset) + sh_offset += sh_buffer.length; + } + console.log(combinedSHBuffer) + + const numGroups = Object.keys(groupBufferFromNameFiltered).length; + return { numGaussians, gaussianBuffer, numGroups, groupIndices, combinedSHBuffer}; } /**Hook to generate properties for rendering Gaussians via a three.js mesh.*/ -function useGaussianMeshProps(gaussianBuffer: Uint32Array, numGroups: number) { +function useGaussianMeshProps(gaussianBuffer: Uint32Array, combinedSHBuffer: Uint32Array, numGroups: number) { const numGaussians = gaussianBuffer.length / 8; const maxTextureSize = useThree((state) => state.gl).capabilities .maxTextureSize; @@ -522,6 +658,10 @@ function useGaussianMeshProps(gaussianBuffer: Uint32Array, numGroups: number) { ); textureBuffer.internalFormat = "RGBA32UI"; textureBuffer.needsUpdate = true; + console.log("textureWidth ", textureWidth); // + console.log("textureHeight ", textureHeight); // + console.log("bufferPadded ", bufferPadded); + console.log("textureBuffer ", textureBuffer); const rowMajorT_camera_groups = new Float32Array(numGroups * 12); const textureT_camera_groups = new THREE.DataTexture( @@ -534,18 +674,40 @@ function useGaussianMeshProps(gaussianBuffer: Uint32Array, numGroups: number) { textureT_camera_groups.internalFormat = "RGBA32F"; textureT_camera_groups.needsUpdate = true; + // Spherical Harmonics Texture Buffer + const shTextureWidth = Math.min(numGaussians * 6, maxTextureSize); + const shTextureHeight = Math.ceil((numGaussians * 6) / shTextureWidth); + const shBufferPadded = new Uint32Array(shTextureWidth * shTextureHeight * 4); + shBufferPadded.set(combinedSHBuffer); + const shTextureBuffer = new THREE.DataTexture( + shBufferPadded, + shTextureWidth, + shTextureHeight, + THREE.RGBAIntegerFormat, + THREE.UnsignedIntType, + ); + shTextureBuffer.internalFormat = "RGBA32UI"; + shTextureBuffer.needsUpdate = true; + console.log("numGaussians ", numGaussians); // 1 + console.log("shTextureWidth ", shTextureWidth); // 6 + console.log("shTextureHeight ", shTextureHeight); // 1 + console.log("shBufferPadded ", shBufferPadded); + console.log("shTextureBuffer ", shTextureBuffer); + const material = new GaussianSplatMaterial({ // @ts-ignore textureBuffer: textureBuffer, + shTextureBuffer: shTextureBuffer, textureT_camera_groups: textureT_camera_groups, numGaussians: 0, transitionInState: 0.0, }); - + // console.log(gaussianBuffer) // long list of ints return { geometry, material, textureBuffer, + shTextureBuffer, sortedIndexAttribute, textureT_camera_groups, rowMajorT_camera_groups, diff --git a/src/viser/client/src/Splatting/spherical_harmonics_testing.ipynb b/src/viser/client/src/Splatting/spherical_harmonics_testing.ipynb new file mode 100644 index 000000000..067687b8a --- /dev/null +++ b/src/viser/client/src/Splatting/spherical_harmonics_testing.ipynb @@ -0,0 +1,355 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# Copyright 2021 The PlenOctree Authors.\n", + "# Redistribution and use in source and binary forms, with or without\n", + "# modification, are permitted provided that the following conditions are met:\n", + "#\n", + "# 1. Redistributions of source code must retain the above copyright notice,\n", + "# this list of conditions and the following disclaimer.\n", + "#\n", + "# 2. Redistributions in binary form must reproduce the above copyright notice,\n", + "# this list of conditions and the following disclaimer in the documentation\n", + "# and/or other materials provided with the distribution.\n", + "#\n", + "# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n", + "# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n", + "# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE\n", + "# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE\n", + "# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR\n", + "# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF\n", + "# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS\n", + "# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN\n", + "# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)\n", + "# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE\n", + "# POSSIBILITY OF SUCH DAMAGE.\n", + "\n", + "import torch\n", + "\n", + "C0 = 0.28209479177387814\n", + "C1 = 0.4886025119029199\n", + "C2 = [\n", + " 1.0925484305920792,\n", + " -1.0925484305920792,\n", + " 0.31539156525252005,\n", + " -1.0925484305920792,\n", + " 0.5462742152960396\n", + "]\n", + "C3 = [\n", + " -0.5900435899266435,\n", + " 2.890611442640554,\n", + " -0.4570457994644658,\n", + " 0.3731763325901154,\n", + " -0.4570457994644658,\n", + " 1.445305721320277,\n", + " -0.5900435899266435\n", + "]\n", + "C4 = [\n", + " 2.5033429417967046,\n", + " -1.7701307697799304,\n", + " 0.9461746957575601,\n", + " -0.6690465435572892,\n", + " 0.10578554691520431,\n", + " -0.6690465435572892,\n", + " 0.47308734787878004,\n", + " -1.7701307697799304,\n", + " 0.6258357354491761,\n", + "] \n", + "\n", + "\n", + "def eval_sh(deg, sh, dirs):\n", + " \"\"\"\n", + " Evaluate spherical harmonics at unit directions\n", + " using hardcoded SH polynomials.\n", + " Works with torch/np/jnp.\n", + " ... Can be 0 or more batch dimensions.\n", + " Args:\n", + " deg: int SH deg. Currently, 0-3 supported\n", + " sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]\n", + " dirs: jnp.ndarray unit directions [..., 3]\n", + " Returns:\n", + " [..., C]\n", + " \"\"\"\n", + " assert deg <= 4 and deg >= 0\n", + " coeff = (deg + 1) ** 2\n", + " assert sh.shape[-1] >= coeff\n", + "\n", + " result = C0 * sh[..., 0]\n", + " if deg > 0:\n", + " x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]\n", + " result = (result -\n", + " C1 * y * sh[..., 1] +\n", + " C1 * z * sh[..., 2] -\n", + " C1 * x * sh[..., 3])\n", + "\n", + " if deg > 1:\n", + " xx, yy, zz = x * x, y * y, z * z\n", + " xy, yz, xz = x * y, y * z, x * z\n", + " result = (result +\n", + " C2[0] * xy * sh[..., 4] +\n", + " C2[1] * yz * sh[..., 5] +\n", + " C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +\n", + " C2[3] * xz * sh[..., 7] +\n", + " C2[4] * (xx - yy) * sh[..., 8])\n", + "\n", + " if deg > 2:\n", + " result = (result +\n", + " C3[0] * y * (3 * xx - yy) * sh[..., 9] +\n", + " C3[1] * xy * z * sh[..., 10] +\n", + " C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +\n", + " C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +\n", + " C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +\n", + " C3[5] * z * (xx - yy) * sh[..., 14] +\n", + " C3[6] * x * (xx - 3 * yy) * sh[..., 15])\n", + "\n", + " if deg > 3:\n", + " result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +\n", + " C4[1] * yz * (3 * xx - yy) * sh[..., 17] +\n", + " C4[2] * xy * (7 * zz - 1) * sh[..., 18] +\n", + " C4[3] * yz * (7 * zz - 3) * sh[..., 19] +\n", + " C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +\n", + " C4[5] * xz * (7 * zz - 3) * sh[..., 21] +\n", + " C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +\n", + " C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +\n", + " C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])\n", + " return result\n", + "\n", + "def RGB2SH(rgb):\n", + " return (rgb - 0.5) / C0\n", + "\n", + "def SH2RGB(sh):\n", + " return sh * C0 + 0.5" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[0.34397008 0.1287549 0.09941738]\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "deg = 3\n", + "sh = np.random.rand(3, 16)\n", + "dirs = np.array([1, 0, 0])\n", + "print(eval_sh(deg, sh, dirs))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import struct\n", + "import numpy as np\n", + "\n", + "def unpackHalf2x16(value):\n", + " \"\"\"The first float is the least significant 16bits, the second is the most significant 16bits.\"\"\"\n", + " # Convert int32 to its binary representation\n", + " binary = format(value, '032b')\n", + " \n", + " # Split the binary string into two 16-bit parts\n", + " binary1 = binary[:16]\n", + " binary2 = binary[16:]\n", + " \n", + " # Convert each 16-bit binary string to an integer\n", + " int1 = int(binary1, 2)\n", + " int2 = int(binary2, 2)\n", + " \n", + " # Use numpy to convert uint16 to float16\n", + " float1 = np.frombuffer(struct.pack('H', int1), dtype=np.float16)[0]\n", + " float2 = np.frombuffer(struct.pack('H', int2), dtype=np.float16)[0]\n", + " \n", + " return float1, float2\n", + "\n", + "def int32_to_rgba(value):\n", + " # Ensure the input is a 32-bit integer\n", + " value = int(value) & 0xFFFFFFFF\n", + " \n", + " # Extract each 8-bit piece\n", + " r = (value >> 24) & 0xFF\n", + " g = (value >> 16) & 0xFF\n", + " b = (value >> 8) & 0xFF\n", + " a = value & 0xFF\n", + " \n", + " return r / 255.0, g / 255.0, b / 255.0, a / 255.0" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(0.803, 0.8213)" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "unpackHalf2x16(980236946)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(-0.00319, 0.786)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "unpackHalf2x16(2592619082)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(0.3686274509803922,\n", + " 0.7215686274509804,\n", + " 0.7254901960784313,\n", + " 0.7294117647058823)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "int32_to_rgba(1589164474)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ae6ae4b686e9423895d3a8ac6156387c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(FloatSlider(value=0.8212358, description='Red:', max=1.0, step=0.01), FloatSlider(value=0.80300…" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import ipywidgets as widgets\n", + "from IPython.display import display, HTML\n", + "import matplotlib.colors as mcolors\n", + "\n", + "def color_picker(r=0.5, g=0.5, b=0.5):\n", + " def update_color(r, g, b):\n", + " color = mcolors.to_hex([r, g, b])\n", + " color_display.value = f'
'\n", + " rgb_display.value = f'RGB: ({r:.2f}, {g:.2f}, {b:.2f})'\n", + " hex_display.value = f'Hex: {color}'\n", + "\n", + " r_slider = widgets.FloatSlider(value=r, min=0, max=1, step=0.01, description='Red:')\n", + " g_slider = widgets.FloatSlider(value=g, min=0, max=1, step=0.01, description='Green:')\n", + " b_slider = widgets.FloatSlider(value=b, min=0, max=1, step=0.01, description='Blue:')\n", + "\n", + " color_display = widgets.HTML()\n", + " rgb_display = widgets.Label()\n", + " hex_display = widgets.Label()\n", + "\n", + " widgets.interactive(update_color, r=r_slider, g=g_slider, b=b_slider)\n", + "\n", + " display(widgets.VBox([r_slider, g_slider, b_slider, color_display, rgb_display, hex_display]))\n", + "\n", + "# Usage\n", + "color_picker(0.8212358, 0.8030037, 0.78623223) # You can change these initial values" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "df534cd29685413f8ee9b4bd3d94794d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(FloatSlider(value=0.3686274509803922, description='Red:', max=1.0, step=0.01), FloatSlider(valu…" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "color_picker(0.3686274509803922, 0.7215686274509804, 0.7254901960784313)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "nerfstudio", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.19" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/viser/client/src/WebsocketMessages.tsx b/src/viser/client/src/WebsocketMessages.tsx index a09df19e4..264fcbf02 100644 --- a/src/viser/client/src/WebsocketMessages.tsx +++ b/src/viser/client/src/WebsocketMessages.tsx @@ -851,6 +851,7 @@ export interface GaussianSplatsMessage { type: "GaussianSplatsMessage"; name: string; buffer: Uint8Array; + sh_buffer: Uint8Array; } /** Message from server->client requesting a render of the current viewport. *