-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsort.py
More file actions
executable file
·46 lines (38 loc) · 1.58 KB
/
sort.py
File metadata and controls
executable file
·46 lines (38 loc) · 1.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import math
from typing import Dict

import torch
from torch import Tensor
def sort_splats(
    splats: Dict[str, Tensor],
    verbose: bool = True,
    improvement_break: float = 1e-4,
) -> Dict[str, Tensor]:
    """Sort splats with Parallel Linear Assignment Sorting from the paper `Compact 3D Scene Representation via
    Self-Organizing Gaussian Grids <https://arxiv.org/pdf/2312.13299>`_.

    .. warning::
        PLAS must be installed to use sorting.

    Every entry of ``splats`` except ``"shN"`` is flattened per splat and concatenated into a
    single feature vector. These vectors are laid out on a square ``n x n`` grid, which PLAS
    sorts. The resulting permutation is then applied to every entry of ``splats``, including
    ``"shN"``.

    .. note::
        ``splats`` is modified in place (each value is replaced by its permuted tensor), and
        the same dict object is returned. The initial shuffle uses ``torch.randperm`` with the
        global RNG, so results differ between calls unless the torch seed is fixed.

    Args:
        splats (Dict[str, Tensor]): splats. Must contain ``"means"``. The number of splats,
            ``len(splats["means"])``, must be a perfect square.
        verbose (bool, optional): Whether to print verbose information. Defaults to True.
        improvement_break (float, optional): Forwarded unchanged to ``sort_with_plas`` as its
            ``improvement_break`` argument. Defaults to 1e-4.

    Returns:
        Dict[str, Tensor]: sorted splats (the same dict object that was passed in).

    Raises:
        ImportError: If PLAS is not installed.
        ValueError: If the number of splats is not a perfect square.
    """
    try:
        from plas import sort_with_plas
    except ImportError as err:
        # Catch only import failures, and chain them so the underlying cause stays visible.
        raise ImportError(
            "Please install PLAS with 'pip install git+https://github.com/fraunhoferhhi/PLAS.git' to use sorting"
        ) from err

    n_gs = len(splats["means"])
    # math.isqrt is exact for any non-negative int, unlike int(n_gs ** 0.5).
    n_sidelen = math.isqrt(n_gs)
    # Explicit check instead of ``assert`` so validation still runs under ``python -O``.
    if n_sidelen**2 != n_gs:
        raise ValueError(f"Number of splats must be a perfect square, got {n_gs}")

    # Every entry except "shN" contributes to the sort key; each becomes (n_gs, d_k).
    sort_keys = [k for k in splats if k != "shN"]
    params_to_sort = torch.cat([splats[k].reshape(n_gs, -1) for k in sort_keys], dim=-1)

    # Start PLAS from a random arrangement; this permutation is composed back in below.
    shuffled_indices = torch.randperm(n_gs, device=params_to_sort.device)
    params_to_sort = params_to_sort[shuffled_indices]

    # (n_gs, C) -> (n_sidelen, n_sidelen, C), then channels-first (C, n_sidelen, n_sidelen)
    # for sort_with_plas.
    grid = params_to_sort.reshape((n_sidelen, n_sidelen, -1))
    _, sorted_indices = sort_with_plas(
        grid.permute(2, 0, 1), improvement_break=improvement_break, verbose=verbose
    )
    # PLAS indices refer to positions in the shuffled order; map them back to original rows.
    sorted_indices = shuffled_indices[sorted_indices.flatten()]

    # Reassigning existing keys does not change the dict's size, so mutating during
    # iteration is safe here.
    for k, v in splats.items():
        splats[k] = v[sorted_indices]
    return splats