-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathextract_data.py
More file actions
executable file
·96 lines (76 loc) · 3.02 KB
/
extract_data.py
File metadata and controls
executable file
·96 lines (76 loc) · 3.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
import gc
import joblib
import torch
import numpy as np
from pathlib import Path
from tqdm import tqdm
from typing import Dict, Any, Optional
from decompress import PngCompression
class SceneProcessor:
    """Decompresses Gaussian-splat scenes, normalizes their parameters,
    and dumps each scene as a labeled pickle for downstream training.

    Parameters
    ----------
    input_dir : str
        Directory containing the compressed scene files.
    output_dir : str
        Directory where per-scene ``.pkl`` files are written (created if missing).
    label : int
        Class label stored with every scene (e.g. 0 = real, 1 = fake).
    """

    def __init__(self, input_dir: str, output_dir: str, label: int = 0):
        self.input_dir = Path(input_dir)
        self.output_dir = Path(output_dir)
        self.label = label
        self.decompressor = PngCompression()
        self.output_dir.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def crop_splats(data: Dict[str, np.ndarray], n_crop: int) -> Dict[str, np.ndarray]:
        """Drop the ``n_crop`` lowest-opacity splats from every array field.

        Non-ndarray entries (e.g. metadata) are filtered out of the result;
        callers add metadata after cropping.
        """
        opacities = data["opacity"].squeeze()
        # argsort is ascending, so slicing off the first n_crop indices
        # discards the least-opaque (least-visible) splats.
        keep_idx = np.argsort(opacities)[n_crop:]
        return {k: v[keep_idx] for k, v in data.items() if isinstance(v, np.ndarray)}

    @staticmethod
    def process_splats(splats: Dict[str, torch.Tensor]) -> Dict[str, np.ndarray]:
        """Convert raw splat tensors to activated numpy arrays.

        Applies the standard 3DGS activations: exp on scales, L2-normalization
        on quaternions, sigmoid on opacities. SH coefficients are flattened to
        (N, 48) — assumes 1 DC (sh0) + 15 rest (shN) bands of 3 channels each.
        """
        means = splats["means"]
        scales = torch.exp(splats["scales"])
        quats = splats["quats"] / splats["quats"].norm(dim=-1, keepdim=True)
        opacities = torch.sigmoid(splats["opacities"])
        sh0 = splats["sh0"]
        shN = splats["shN"]
        sh_feat = torch.cat([sh0, shN], dim=1).view(-1, 16 * 3)
        return {
            "means": means.cpu().numpy(),
            "scales": scales.cpu().numpy(),
            "quats": quats.cpu().numpy(),
            "opacity": opacities.cpu().numpy(),
            "sh": sh_feat.cpu().numpy(),
        }

    def load_scene(self, scene_path: Path, crop_to: Optional[int] = None) -> None:
        """Decompress one scene, optionally crop it to ``crop_to`` splats,
        and save it (with label and name) to ``output_dir``.

        Errors are logged and swallowed so a batch run continues past bad files.
        """
        # Pre-bind so the `finally` cleanup cannot raise NameError when
        # decompress() itself fails before `splats` is assigned.
        splats = None
        try:
            splats = self.decompressor.decompress(str(scene_path))
            data = self.process_splats(splats)
            if crop_to is not None and data["means"].shape[0] > crop_to:
                n_crop = data["means"].shape[0] - crop_to
                data = self.crop_splats(data, n_crop)
            data.update({
                "label": self.label,
                "name": scene_path.name
            })
            save_path = self.output_dir / f"{scene_path.name}.pkl"
            joblib.dump(data, save_path)
        except Exception as e:
            print(f"[ERROR] {scene_path}: {e}")
        finally:
            # Release GPU/host memory between scenes to keep peak usage flat.
            del splats
            gc.collect()
            torch.cuda.empty_cache()

    def process_all(self, start_idx: int = 0, crop_to: Optional[int] = None):
        """Process every scene in ``input_dir`` (sorted), starting at ``start_idx``."""
        scenes = sorted(self.input_dir.iterdir())[start_idx:]
        for scene_path in tqdm(scenes, desc="Processing scenes"):
            self.load_scene(scene_path, crop_to=crop_to)
if __name__ == "__main__":
    # Script configuration; edit the paths/label per dataset split.
    CONFIG = {
        "input_dir": "/PATH/TO/original_compressed",  # "/PATH/TO/fake_compressed"
        "output_dir": "./gaussian_pickles/real",      # "./gaussian_pickles/fake"
        "label": 0,          # 0 for Real, 1 for Fake
        "start_index": 0,    # resume point: skip the first N scenes (was missing -> KeyError below)
        "crop_to": None,     # e.g. 65536 to cap splat count per scene
    }
    processor = SceneProcessor(
        input_dir=CONFIG["input_dir"],
        output_dir=CONFIG["output_dir"],
        label=CONFIG["label"]
    )
    processor.process_all(
        start_idx=CONFIG["start_index"],
        crop_to=CONFIG["crop_to"]
    )