Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 172 additions & 4 deletions 2_run_hamer_on_vrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,23 +23,34 @@
from egoallo.inference_utils import InferenceTrajectoryPaths


def main(traj_root: Path, overwrite: bool = False) -> None:
def main(traj_root: Path, detector: str = "hamer",overwrite: bool = False, wilor_home: str = "/secondary/home/annie/code/WiLoR") -> None:
"""Run HaMeR for on trajectory. We'll save outputs to
`traj_root/hamer_outputs.pkl` and `traj_root/hamer_outputs_render".

Arguments:
traj_root: The root directory of the trajectory. We assume that there's
a VRS file in this directory.
detector: The detector to use. Can be "wilor" or "hamer".
overwrite: If True, overwrite any existing HaMeR outputs.
wilor_home: The path to the WiLoR home directory. Only used if `detector`
is "wilor".
"""

paths = InferenceTrajectoryPaths.find(traj_root)

vrs_path = paths.vrs_file
assert vrs_path.exists()
pickle_out = traj_root / "hamer_outputs.pkl"
hamer_render_out = traj_root / "hamer_outputs_render" # This is just for debugging.
run_hamer_and_save(vrs_path, pickle_out, hamer_render_out, overwrite)

if detector == "wilor":
pickle_out = traj_root / "wilor_outputs.pkl"
render_out_path = traj_root / "wilor_outputs_unrot_render" # This is just for debugging.
run_wilor_and_save(vrs_path, pickle_out, render_out_path, overwrite, wilor_home)
elif detector == "hamer":
pickle_out = traj_root / "hamer_outputs.pkl"
render_out_path = traj_root / "hamer_outputs_render" # This is just for debugging.
run_hamer_and_save(vrs_path, pickle_out, render_out_path, overwrite)
else:
raise ValueError(f"Unknown detector: {detector}")


def run_hamer_and_save(
Expand Down Expand Up @@ -199,6 +210,163 @@ def run_hamer_and_save(
with open(pickle_out, "wb") as f:
pickle.dump(outputs, f)

def run_wilor_and_save(
    vrs_path: Path, pickle_out: Path, render_out_path: Path, overwrite: bool, wilor_home: str
) -> None:
    """Run the WiLoR hand detector on every RGB frame of a VRS recording.

    WiLoR analogue of `run_hamer_and_save`: per-frame left/right hand
    detections (expressed in the RGB camera frame) are pickled to
    `pickle_out`, and one side-by-side debug JPEG per frame is written to
    `render_out_path`.

    Arguments:
        vrs_path: Path to the trajectory's VRS file; must contain a
            "camera-rgb" stream.
        pickle_out: Where to pickle the resulting `SavedHamerOutputs`.
        render_out_path: Directory for debug renderings (one JPEG per frame).
        overwrite: If True, delete any existing outputs first. If False,
            assert that neither output exists yet.
        wilor_home: Path to the WiLoR checkout; forwarded to `WiLoRHelper`.
    """
    # Lazy import: the WiLoR dependency is only needed when this detector
    # is actually selected.
    from _wilor_helper import WiLoRHelper

    if not overwrite:
        assert not pickle_out.exists()
        assert not render_out_path.exists()
    else:
        pickle_out.unlink(missing_ok=True)
        shutil.rmtree(render_out_path, ignore_errors=True)
    render_out_path.mkdir(exist_ok=True)

    wilor_helper = WiLoRHelper(wilor_home)
    # HamerHelper is reused here only for compositing the debug renders.
    # NOTE(review): `other_detector=True` presumably skips loading the HaMeR
    # model itself — confirm against HamerHelper's constructor.
    hamer_helper = HamerHelper(other_detector=True)

    # VRS data provider setup.
    provider = create_vrs_data_provider(str(vrs_path.absolute()))
    assert isinstance(provider, VrsDataProvider)
    rgb_stream_id = provider.get_stream_id_from_label("camera-rgb")
    assert rgb_stream_id is not None

    num_images = provider.get_num_data(rgb_stream_id)
    print(f"Found {num_images=}")

    # Get calibrations.
    device_calib = provider.get_device_calibration()
    assert device_calib is not None
    camera_calib = device_calib.get_camera_calib("camera-rgb")
    assert camera_calib is not None
    # Target pinhole model for undistortion: 1408x1408 pixels, focal length
    # 450 — matching the focal length passed to the detector below.
    pinhole = calibration.get_linear_camera_calibration(1408, 1408, 450)

    # Compute camera extrinsics!
    sophus_T_device_camera = device_calib.get_transform_device_sensor("camera-rgb")
    sophus_T_cpf_camera = device_calib.get_transform_cpf_sensor("camera-rgb")
    assert sophus_T_device_camera is not None
    assert sophus_T_cpf_camera is not None
    # Pack each SE(3) transform as a 7-vector: quaternion (4) + translation (3).
    T_device_cam = np.concatenate(
        [
            sophus_T_device_camera.rotation().to_quat().squeeze(axis=0),
            sophus_T_device_camera.translation().squeeze(axis=0),
        ]
    )
    T_cpf_cam = np.concatenate(
        [
            sophus_T_cpf_camera.rotation().to_quat().squeeze(axis=0),
            sophus_T_cpf_camera.translation().squeeze(axis=0),
        ]
    )
    assert T_device_cam.shape == T_cpf_cam.shape == (7,)

    # Dict from capture timestamp in nanoseconds to fields we care about.
    detections_left_wrt_cam: dict[int, SingleHandHamerOutputWrtCamera | None] = {}
    detections_right_wrt_cam: dict[int, SingleHandHamerOutputWrtCamera | None] = {}

    pbar = tqdm(range(num_images))
    for i in pbar:
        image_data, image_data_record = provider.get_image_data_by_index(
            rgb_stream_id, i
        )
        # Undistort the fisheye RGB frame into the linear pinhole model above.
        undistorted_image = calibration.distort_by_calibration(
            image_data.to_numpy_array(), pinhole, camera_calib
        )

        hamer_out_left, hamer_out_right = wilor_helper.look_for_hands(
            undistorted_image,
            focal_length=450,
        )
        timestamp_ns = image_data_record.capture_timestamp_ns

        # Record detections (or None) keyed by capture timestamp; only the
        # fields consumed downstream are kept.
        if hamer_out_left is None:
            detections_left_wrt_cam[timestamp_ns] = None
        else:
            detections_left_wrt_cam[timestamp_ns] = {
                "verts": hamer_out_left["verts"],
                "keypoints_3d": hamer_out_left["keypoints_3d"],
                "mano_hand_pose": hamer_out_left["mano_hand_pose"],
                "mano_hand_betas": hamer_out_left["mano_hand_betas"],
                "mano_hand_global_orient": hamer_out_left["mano_hand_global_orient"],
            }

        if hamer_out_right is None:
            detections_right_wrt_cam[timestamp_ns] = None
        else:
            detections_right_wrt_cam[timestamp_ns] = {
                "verts": hamer_out_right["verts"],
                "keypoints_3d": hamer_out_right["keypoints_3d"],
                "mano_hand_pose": hamer_out_right["mano_hand_pose"],
                "mano_hand_betas": hamer_out_right["mano_hand_betas"],
                "mano_hand_global_orient": hamer_out_right["mano_hand_global_orient"],
            }

        # Debug rendering: overlay detections (left = reddish border,
        # right = bluish border) and annotate counts + timestamp.
        composited = undistorted_image
        composited = hamer_helper.composite_detections(
            composited,
            hamer_out_left,
            border_color=(255, 100, 100),
            focal_length=450,
        )
        composited = hamer_helper.composite_detections(
            composited,
            hamer_out_right,
            border_color=(100, 100, 255),
            focal_length=450,
        )
        composited = put_text(
            composited,
            "L detections: "
            + (
                "0" if hamer_out_left is None else str(hamer_out_left["verts"].shape[0])
            ),
            0,
            color=(255, 100, 100),
            # Scale the font with the image height (tuned at 2880px).
            font_scale=10.0 / 2880.0 * undistorted_image.shape[0],
        )
        composited = put_text(
            composited,
            "R detections: "
            + (
                "0"
                if hamer_out_right is None
                else str(hamer_out_right["verts"].shape[0])
            ),
            1,
            color=(100, 100, 255),
            font_scale=10.0 / 2880.0 * undistorted_image.shape[0],
        )
        composited = put_text(
            composited,
            f"ns={timestamp_ns}",
            2,
            color=(255, 255, 255),
            font_scale=10.0 / 2880.0 * undistorted_image.shape[0],
        )

        # Save input (darkened) and overlay side by side for visual diffing.
        print(f"Saving image {i:06d} to {render_out_path / f'{i:06d}.jpeg'}")
        iio.imwrite(
            str(render_out_path / f"{i:06d}.jpeg"),
            np.concatenate(
                [
                    # Darken input image, just for contrast...
                    (undistorted_image * 0.6).astype(np.uint8),
                    composited,
                ],
                axis=1,
            ),
            quality=90,
        )

    # Persist everything in the same `SavedHamerOutputs` container that the
    # HaMeR path uses, so downstream loading is detector-agnostic.
    outputs = SavedHamerOutputs(
        mano_faces_right=wilor_helper.get_mano_faces("right"),
        mano_faces_left=wilor_helper.get_mano_faces("left"),
        detections_right_wrt_cam=detections_right_wrt_cam,
        detections_left_wrt_cam=detections_left_wrt_cam,
        T_device_cam=T_device_cam,
        T_cpf_cam=T_cpf_cam,
    )
    with open(pickle_out, "wb") as f:
        pickle.dump(outputs, f)

def put_text(
image: np.ndarray,
Expand Down
12 changes: 9 additions & 3 deletions 3_aria_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class Args:
...
...
"""
detector: str = "hamer"
"""for choosing pkl file. choose from ["hamer", "wilor"]"""
checkpoint_dir: Path = Path("./egoallo_checkpoint_april13/checkpoints_3000000/")
smplh_npz_path: Path = Path("./data/smplh/neutral/model.npz")

Expand Down Expand Up @@ -66,7 +68,7 @@ class Args:
def main(args: Args) -> None:
device = torch.device("cuda")

traj_paths = InferenceTrajectoryPaths.find(args.traj_root)
traj_paths = InferenceTrajectoryPaths.find(args.traj_root, detector=args.detector)
if traj_paths.splat_path is not None:
print("Found splat at", traj_paths.splat_path)
else:
Expand Down Expand Up @@ -151,10 +153,14 @@ def main(args: Args) -> None:
time.strftime("%Y%m%d-%H%M%S")
+ f"_{args.start_index}-{args.start_index + args.traj_length}"
)
out_path = args.traj_root / "egoallo_outputs" / (save_name + ".npz")
if args.detector == "hamer":
egoallo_dir = "egoallo_outputs"
elif args.detector == "wilor":
egoallo_dir = "wilor_egoallo_outputs"
out_path = args.traj_root / egoallo_dir / (save_name + ".npz")
out_path.parent.mkdir(parents=True, exist_ok=True)
assert not out_path.exists()
(args.traj_root / "egoallo_outputs" / (save_name + "_args.yaml")).write_text(
(args.traj_root / egoallo_dir / (save_name + "_args.yaml")).write_text(
yaml.dump(dataclasses.asdict(args))
)

Expand Down
11 changes: 9 additions & 2 deletions 4_visualize_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

def main(
search_root_dir: Path,
detector: str = "hamer",
smplh_npz_path: Path = Path("./data/smplh/neutral/model.npz"),
) -> None:
"""Visualization script for outputs from EgoAllo.
Expand All @@ -48,9 +49,13 @@ def main(
server.gui.configure_theme(dark_mode=True)

def get_file_list():
if detector == "hamer":
egoallo_outputs_dir = search_root_dir.glob("**/egoallo_outputs/*.npz")
elif detector == "wilor":
egoallo_outputs_dir = search_root_dir.glob("**/wilor_egoallo_outputs/*.npz")
return ["None"] + sorted(
str(p.relative_to(search_root_dir))
for p in search_root_dir.glob("**/egoallo_outputs/*.npz")
for p in egoallo_outputs_dir
)

options = get_file_list()
Expand Down Expand Up @@ -86,6 +91,7 @@ def _(_) -> None:
loop_cb = load_and_visualize(
server,
npz_path,
detector,
body_model,
device=device,
)
Expand All @@ -100,6 +106,7 @@ def _(_) -> None:
def load_and_visualize(
server: viser.ViserServer,
npz_path: Path,
detector: str,
body_model: fncsmpl.SmplhModel,
device: torch.device,
) -> Callable[[], int]:
Expand Down Expand Up @@ -137,7 +144,7 @@ def load_and_visualize(
# - outputs
# - the npz file
traj_dir = npz_path.resolve().parent.parent
paths = InferenceTrajectoryPaths.find(traj_dir)
paths = InferenceTrajectoryPaths.find(traj_dir,detector=detector)

provider = create_vrs_data_provider(str(paths.vrs_file))
device_calib = provider.get_device_calibration()
Expand Down
Loading