diff --git a/2_run_hamer_on_vrs.py b/2_run_hamer_on_vrs.py index fc0b2e3..f03e690 100644 --- a/2_run_hamer_on_vrs.py +++ b/2_run_hamer_on_vrs.py @@ -23,23 +23,34 @@ from egoallo.inference_utils import InferenceTrajectoryPaths -def main(traj_root: Path, overwrite: bool = False) -> None: +def main(traj_root: Path, detector: str = "hamer",overwrite: bool = False, wilor_home: str = "/secondary/home/annie/code/WiLoR") -> None: """Run HaMeR for on trajectory. We'll save outputs to `traj_root/hamer_outputs.pkl` and `traj_root/hamer_outputs_render". Arguments: traj_root: The root directory of the trajectory. We assume that there's a VRS file in this directory. + detector: The detector to use. Can be "wilor" or "hamer". overwrite: If True, overwrite any existing HaMeR outputs. + wilor_home: The path to the WiLoR home directory. Only used if `detector` + is "wilor". """ paths = InferenceTrajectoryPaths.find(traj_root) vrs_path = paths.vrs_file assert vrs_path.exists() - pickle_out = traj_root / "hamer_outputs.pkl" - hamer_render_out = traj_root / "hamer_outputs_render" # This is just for debugging. - run_hamer_and_save(vrs_path, pickle_out, hamer_render_out, overwrite) + + if detector == "wilor": + pickle_out = traj_root / "wilor_outputs.pkl" + render_out_path = traj_root / "wilor_outputs_unrot_render" # This is just for debugging. + run_wilor_and_save(vrs_path, pickle_out, render_out_path, overwrite, wilor_home) + elif detector == "hamer": + pickle_out = traj_root / "hamer_outputs.pkl" + render_out_path = traj_root / "hamer_outputs_render" # This is just for debugging. 
+ run_hamer_and_save(vrs_path, pickle_out, render_out_path, overwrite) + else: + raise ValueError(f"Unknown detector: {detector}") def run_hamer_and_save( @@ -199,6 +210,163 @@ def run_hamer_and_save( with open(pickle_out, "wb") as f: pickle.dump(outputs, f) +def run_wilor_and_save( + vrs_path: Path, pickle_out: Path, render_out_path: Path, overwrite: bool, wilor_home: str +) -> None: + from _wilor_helper import WiLoRHelper + if not overwrite: + assert not pickle_out.exists() + assert not render_out_path.exists() + else: + pickle_out.unlink(missing_ok=True) + shutil.rmtree(render_out_path, ignore_errors=True) + render_out_path.mkdir(exist_ok=True) + wilor_helper = WiLoRHelper(wilor_home) + hamer_helper = HamerHelper(other_detector=True) + + # VRS data provider setup. + provider = create_vrs_data_provider(str(vrs_path.absolute())) + assert isinstance(provider, VrsDataProvider) + rgb_stream_id = provider.get_stream_id_from_label("camera-rgb") + assert rgb_stream_id is not None + + num_images = provider.get_num_data(rgb_stream_id) + print(f"Found {num_images=}") + + # Get calibrations. + device_calib = provider.get_device_calibration() + assert device_calib is not None + camera_calib = device_calib.get_camera_calib("camera-rgb") + assert camera_calib is not None + pinhole = calibration.get_linear_camera_calibration(1408, 1408, 450) + + # Compute camera extrinsics! 
+ sophus_T_device_camera = device_calib.get_transform_device_sensor("camera-rgb") + sophus_T_cpf_camera = device_calib.get_transform_cpf_sensor("camera-rgb") + assert sophus_T_device_camera is not None + assert sophus_T_cpf_camera is not None + T_device_cam = np.concatenate( + [ + sophus_T_device_camera.rotation().to_quat().squeeze(axis=0), + sophus_T_device_camera.translation().squeeze(axis=0), + ] + ) + T_cpf_cam = np.concatenate( + [ + sophus_T_cpf_camera.rotation().to_quat().squeeze(axis=0), + sophus_T_cpf_camera.translation().squeeze(axis=0), + ] + ) + assert T_device_cam.shape == T_cpf_cam.shape == (7,) + + # Dict from capture timestamp in nanoseconds to fields we care about. + detections_left_wrt_cam: dict[int, SingleHandHamerOutputWrtCamera | None] = {} + detections_right_wrt_cam: dict[int, SingleHandHamerOutputWrtCamera | None] = {} + + pbar = tqdm(range(num_images)) + for i in pbar: + image_data, image_data_record = provider.get_image_data_by_index( + rgb_stream_id, i + ) + undistorted_image = calibration.distort_by_calibration( + image_data.to_numpy_array(), pinhole, camera_calib + ) + + hamer_out_left, hamer_out_right = wilor_helper.look_for_hands( + undistorted_image, + focal_length=450, + ) + timestamp_ns = image_data_record.capture_timestamp_ns + + if hamer_out_left is None: + detections_left_wrt_cam[timestamp_ns] = None + else: + detections_left_wrt_cam[timestamp_ns] = { + "verts": hamer_out_left["verts"], + "keypoints_3d": hamer_out_left["keypoints_3d"], + "mano_hand_pose": hamer_out_left["mano_hand_pose"], + "mano_hand_betas": hamer_out_left["mano_hand_betas"], + "mano_hand_global_orient": hamer_out_left["mano_hand_global_orient"], + } + + if hamer_out_right is None: + detections_right_wrt_cam[timestamp_ns] = None + else: + detections_right_wrt_cam[timestamp_ns] = { + "verts": hamer_out_right["verts"], + "keypoints_3d": hamer_out_right["keypoints_3d"], + "mano_hand_pose": hamer_out_right["mano_hand_pose"], + "mano_hand_betas": 
hamer_out_right["mano_hand_betas"], + "mano_hand_global_orient": hamer_out_right["mano_hand_global_orient"], + } + + composited = undistorted_image + composited = hamer_helper.composite_detections( + composited, + hamer_out_left, + border_color=(255, 100, 100), + focal_length=450, + ) + composited = hamer_helper.composite_detections( + composited, + hamer_out_right, + border_color=(100, 100, 255), + focal_length=450, + ) + composited = put_text( + composited, + "L detections: " + + ( + "0" if hamer_out_left is None else str(hamer_out_left["verts"].shape[0]) + ), + 0, + color=(255, 100, 100), + font_scale=10.0 / 2880.0 * undistorted_image.shape[0], + ) + composited = put_text( + composited, + "R detections: " + + ( + "0" + if hamer_out_right is None + else str(hamer_out_right["verts"].shape[0]) + ), + 1, + color=(100, 100, 255), + font_scale=10.0 / 2880.0 * undistorted_image.shape[0], + ) + composited = put_text( + composited, + f"ns={timestamp_ns}", + 2, + color=(255, 255, 255), + font_scale=10.0 / 2880.0 * undistorted_image.shape[0], + ) + + print(f"Saving image {i:06d} to {render_out_path / f'{i:06d}.jpeg'}") + iio.imwrite( + str(render_out_path / f"{i:06d}.jpeg"), + np.concatenate( + [ + # Darken input image, just for contrast... + (undistorted_image * 0.6).astype(np.uint8), + composited, + ], + axis=1, + ), + quality=90, + ) + + outputs = SavedHamerOutputs( + mano_faces_right=wilor_helper.get_mano_faces("right"), + mano_faces_left=wilor_helper.get_mano_faces("left"), + detections_right_wrt_cam=detections_right_wrt_cam, + detections_left_wrt_cam=detections_left_wrt_cam, + T_device_cam=T_device_cam, + T_cpf_cam=T_cpf_cam, + ) + with open(pickle_out, "wb") as f: + pickle.dump(outputs, f) def put_text( image: np.ndarray, diff --git a/3_aria_inference.py b/3_aria_inference.py index c489011..578314e 100644 --- a/3_aria_inference.py +++ b/3_aria_inference.py @@ -38,6 +38,8 @@ class Args: ... ... """ + detector: str = "hamer" + """for choosing pkl file. 
choose from ["hamer", "wilor"]"""
     checkpoint_dir: Path = Path("./egoallo_checkpoint_april13/checkpoints_3000000/")
     smplh_npz_path: Path = Path("./data/smplh/neutral/model.npz")
@@ -66,7 +68,7 @@ class Args:
 
 def main(args: Args) -> None:
     device = torch.device("cuda")
-    traj_paths = InferenceTrajectoryPaths.find(args.traj_root)
+    traj_paths = InferenceTrajectoryPaths.find(args.traj_root, detector=args.detector)
     if traj_paths.splat_path is not None:
         print("Found splat at", traj_paths.splat_path)
     else:
@@ -151,10 +153,16 @@ def main(args: Args) -> None:
         time.strftime("%Y%m%d-%H%M%S")
         + f"_{args.start_index}-{args.start_index + args.traj_length}"
     )
-    out_path = args.traj_root / "egoallo_outputs" / (save_name + ".npz")
+    if args.detector == "hamer":
+        egoallo_dir = "egoallo_outputs"
+    elif args.detector == "wilor":
+        egoallo_dir = "wilor_egoallo_outputs"
+    else:
+        raise ValueError(f"Unknown detector: {args.detector}")
+    out_path = args.traj_root / egoallo_dir / (save_name + ".npz")
     out_path.parent.mkdir(parents=True, exist_ok=True)
     assert not out_path.exists()
-    (args.traj_root / "egoallo_outputs" / (save_name + "_args.yaml")).write_text(
+    (args.traj_root / egoallo_dir / (save_name + "_args.yaml")).write_text(
         yaml.dump(dataclasses.asdict(args))
     )
 
diff --git a/4_visualize_outputs.py b/4_visualize_outputs.py
index 2480e8c..d76e4da 100644
--- a/4_visualize_outputs.py
+++ b/4_visualize_outputs.py
@@ -31,6 +31,7 @@
 def main(
     search_root_dir: Path,
+    detector: str = "hamer",
     smplh_npz_path: Path = Path("./data/smplh/neutral/model.npz"),
 ) -> None:
     """Visualization script for outputs from EgoAllo.
@@ -48,9 +49,15 @@ def main(
     server.gui.configure_theme(dark_mode=True)
 
     def get_file_list():
+        if detector == "hamer":
+            egoallo_outputs_dir = search_root_dir.glob("**/egoallo_outputs/*.npz")
+        elif detector == "wilor":
+            egoallo_outputs_dir = search_root_dir.glob("**/wilor_egoallo_outputs/*.npz")
+        else:
+            raise ValueError(f"Unknown detector: {detector}")
         return ["None"] + sorted(
             str(p.relative_to(search_root_dir))
-            for p in search_root_dir.glob("**/egoallo_outputs/*.npz")
+            for p in egoallo_outputs_dir
         )
 
     options = get_file_list()
@@ -86,6 +91,7 @@ def _(_) -> None:
         loop_cb = load_and_visualize(
             server,
             npz_path,
+            detector,
             body_model,
             device=device,
         )
@@ -100,6 +106,7 @@ def _(_) -> None:
 def load_and_visualize(
     server: viser.ViserServer,
     npz_path: Path,
+    detector: str,
     body_model: fncsmpl.SmplhModel,
     device: torch.device,
 ) -> Callable[[], int]:
@@ -137,7 +144,7 @@ def load_and_visualize(
     # - outputs
     # - the npz file
     traj_dir = npz_path.resolve().parent.parent
-    paths = InferenceTrajectoryPaths.find(traj_dir)
+    paths = InferenceTrajectoryPaths.find(traj_dir, detector=detector)
     provider = create_vrs_data_provider(str(paths.vrs_file))
 
     device_calib = provider.get_device_calibration()
diff --git a/_wilor_helper.py b/_wilor_helper.py
new file mode 100644
index 0000000..f0dab6c
--- /dev/null
+++ b/_wilor_helper.py
@@ -0,0 +1,359 @@
+from pathlib import Path
+import torch
+import argparse
+import os
+import cv2
+import numpy as np
+from dataclasses import dataclass
+from torch import Tensor
+from typing import Literal, TypedDict
+from ultralytics import YOLO
+from jaxtyping import Float, Int
+
+import sys
+
+LIGHT_PURPLE=(0.25098039, 0.274117647, 0.65882353)
+
+
+# same in _hamer_helper.py
+class HandOutputsWrtCamera(TypedDict):
+    """Hand outputs with respect to the camera frame."""
+
+    verts: Float[np.ndarray, "num_hands 778 3"]
+    keypoints_3d: Float[np.ndarray, "num_hands 21 3"]
+    mano_hand_pose: Float[np.ndarray, "num_hands 15 3 3"]
+    mano_hand_betas: Float[np.ndarray, "num_hands 10"]
+    mano_hand_global_orient: 
Float[np.ndarray, "num_hands 1 3 3"] + faces: Int[np.ndarray, "mesh_faces 3"] + +# same in _hamer_helper.py +@dataclass(frozen=True) +class _RawHamerOutputs: + """A typed wrapper for outputs from HaMeR.""" + + # Comments here are what I got when printing out the shapes of different + # HaMeR outputs. + + # pred_cam torch.Size([1, 3]) + pred_cam: Float[Tensor, "num_hands 3"] + # pred_mano_params global_orient torch.Size([1, 1, 3, 3]) + pred_mano_global_orient: Float[Tensor, "num_hands 1 3 3"] + # pred_mano_params hand_pose torch.Size([1, 15, 3, 3]) + pred_mano_hand_pose: Float[Tensor, "num_hands 15 3 3"] + # pred_mano_params betas torch.Size([1, 10]) + pred_mano_hand_betas: Float[Tensor, "num_hands 10"] + # pred_cam_t torch.Size([1, 3]) + pred_cam_t: Float[Tensor, "num_hands 3"] + + # focal length from model is ignored + # focal_length torch.Size([1, 2]) + # focal_length: Float[Tensor, "num_hands 2"] + + # pred_keypoints_3d torch.Size([1, 21, 3]) + pred_keypoints_3d: Float[Tensor, "num_hands 21 3"] + # pred_vertices torch.Size([1, 778, 3]) + pred_vertices: Float[Tensor, "num_hands 778 3"] + # pred_keypoints_2d torch.Size([1, 21, 2]) + pred_keypoints_2d: Float[Tensor, "num_hands 21 3"] + + pred_right: Float[Tensor, "num_hands"] + """A given hand is a right hand if this value is >0.5.""" + + # These aren't technically HaMeR outputs, but putting them here for convenience. 
+ mano_faces_right: Tensor + mano_faces_left: Tensor + +class WiLoRHelper: + def __init__(self, wilor_home: str="./"): + sys.path.append(wilor_home) + global load_wilor + global recursive_to + global ViTDetDataset + global Renderer + global cam_crop_to_full + + from wilor.models import load_wilor + from wilor.utils import recursive_to + from wilor.datasets.vitdet_dataset import ViTDetDataset + from wilor.utils.renderer import Renderer, cam_crop_to_full + checkpoint_path = os.path.join(wilor_home, 'pretrained_models/wilor_final.ckpt') + cfg_path = os.path.join(wilor_home, 'pretrained_models/model_config.yaml') + original_dir = os.getcwd() + os.chdir(wilor_home) + model, model_cfg = load_wilor(checkpoint_path = checkpoint_path , + cfg_path= cfg_path) + os.chdir(original_dir) + detector = YOLO(os.path.join(wilor_home, 'pretrained_models/detector.pt')) + # Setup the renderer + self.renderer = Renderer(model_cfg, faces=model.mano.faces) + + self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + self._model = model.to(self.device) + self._model_cfg = model_cfg + self._detector = detector.to(self.device) + + def look_for_hands( + self, + image: np.ndarray, + focal_length: float | None = None, + rescale_factor: float = 2.0, + ) -> tuple[HandOutputsWrtCamera | None, HandOutputsWrtCamera | None]: + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + detections = self._detector(image, conf = 0.3, verbose=False)[0] + # breakpoint() + bboxes = [] + is_right = [] + for det in detections: + Bbox = det.boxes.data.cpu().detach().squeeze().numpy() + is_right.append(det.boxes.cls.cpu().detach().squeeze().item()) + bboxes.append(Bbox[:4].tolist()) + + if len(bboxes) == 0: + return None, None + boxes = np.stack(bboxes) + right = np.stack(is_right) + dataset = ViTDetDataset(self._model_cfg, image, boxes, right, rescale_factor=rescale_factor) + dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=False, num_workers=0) + + outputs: 
list[_RawHamerOutputs] = [] + + for batch in dataloader: + batch = recursive_to(batch, self.device) + + with torch.no_grad(): + out = self._model(batch) + + multiplier = (2*batch['right']-1) + pred_cam = out['pred_cam'] + pred_cam[:,1] = multiplier*pred_cam[:,1] + box_center = batch["box_center"].float() + box_size = batch["box_size"].float() + img_size = batch["img_size"].float() + if focal_length is None: + # All of the img_size rows should be the same. I think. + focal_length = self._model_cfg.EXTRA.FOCAL_LENGTH / self._model_cfg.MODEL.IMAGE_SIZE * img_size.max() + focal_length = float(focal_length) + scaled_focal_length = focal_length + pred_cam_t_full = cam_crop_to_full(pred_cam, box_center, box_size, img_size, scaled_focal_length).detach().cpu().numpy() + pred_cam_t_full = torch.tensor(pred_cam_t_full, device=self.device) + hamer_out = _RawHamerOutputs( + mano_faces_left=torch.from_numpy( + self._model.mano.faces[:, [0, 2, 1]].astype(np.int64) + ).to(device=self.device), + mano_faces_right=torch.from_numpy( + self._model.mano.faces.astype(np.int64) + ).to(device=self.device), + pred_cam=out["pred_cam"], + pred_mano_global_orient=out["pred_mano_params"]["global_orient"], + pred_mano_hand_pose=out["pred_mano_params"]["hand_pose"], + pred_mano_hand_betas=out["pred_mano_params"]["betas"], + pred_cam_t=pred_cam_t_full, + pred_keypoints_3d=out["pred_keypoints_3d"], + pred_vertices=out["pred_vertices"], + pred_keypoints_2d=out["pred_keypoints_2d"], + pred_right=batch["right"], + ) + + outputs.append(hamer_out) + + assert len(outputs) > 0 + + stacked_outputs = _RawHamerOutputs( + **{ + field_name: torch.cat([getattr(x, field_name) for x in outputs], dim=0) + for field_name in vars(outputs[0]).keys() + }, + ) + # begin new brent stuff + verts = stacked_outputs.pred_vertices.numpy(force=True) + keypoints_3d = stacked_outputs.pred_keypoints_3d.numpy(force=True) + pred_cam_t = stacked_outputs.pred_cam_t.numpy(force=True) + mano_hand_pose = 
stacked_outputs.pred_mano_hand_pose.numpy(force=True) + mano_hand_betas = stacked_outputs.pred_mano_hand_betas.numpy(force=True) + R_camera_hand = stacked_outputs.pred_mano_global_orient.squeeze(dim=1).numpy( + force=True + ) + + is_right = (stacked_outputs.pred_right > 0.5).numpy(force=True) + is_left = ~is_right + + detections_right_wrt_cam: HandOutputsWrtCamera | None + if np.sum(is_right) == 0: + detections_right_wrt_cam = None + else: + detections_right_wrt_cam = { + "verts": verts[is_right] + pred_cam_t[is_right, None, :], + "keypoints_3d": keypoints_3d[is_right] + pred_cam_t[is_right, None, :], + "mano_hand_pose": mano_hand_pose[is_right], + "mano_hand_betas": mano_hand_betas[is_right], + "mano_hand_global_orient": R_camera_hand[is_right], + "faces": self.get_mano_faces("right"), + } + + detections_left_wrt_cam: HandOutputsWrtCamera | None + if np.sum(is_left) == 0: + detections_left_wrt_cam = None + else: + + def flip_rotmats(rotmats: np.ndarray) -> np.ndarray: + assert rotmats.shape[-2:] == (3, 3) + from viser import transforms + + logspace = transforms.SO3.from_matrix(rotmats).log() + logspace[..., 1] *= -1 + logspace[..., 2] *= -1 + return transforms.SO3.exp(logspace).as_matrix() + + detections_left_wrt_cam = { + "verts": verts[is_left] * np.array([-1, 1, 1]) + + pred_cam_t[is_left, None, :], + "keypoints_3d": keypoints_3d[is_left] * np.array([-1, 1, 1]) + + pred_cam_t[is_left, None, :], + "mano_hand_pose": flip_rotmats(mano_hand_pose[is_left]), + "mano_hand_betas": mano_hand_betas[is_left], + "mano_hand_global_orient": flip_rotmats(R_camera_hand[is_left]), + "faces": self.get_mano_faces("left"), + } + # end new brent stuff + return detections_left_wrt_cam, detections_right_wrt_cam + + def get_mano_faces(self, side: Literal["left", "right"]) -> np.ndarray: + if side == "left": + return self._model.mano.faces[:, [0, 2, 1]].copy() + else: + return self._model.mano.faces.copy() + +def example_demo(): + parser = argparse.ArgumentParser(description='WiLoR 
demo code') + parser.add_argument('--img_folder', type=str, default='images', help='Folder with input images') + parser.add_argument('--out_folder', type=str, default='out_demo', help='Output folder to save rendered results') + parser.add_argument('--save_mesh', dest='save_mesh', action='store_true', default=False, help='If set, save meshes to disk also') + parser.add_argument('--rescale_factor', type=float, default=2.0, help='Factor for padding the bbox') + parser.add_argument('--file_type', nargs='+', default=['*.jpg', '*.png', '*.jpeg'], help='List of file extensions to consider') + + args = parser.parse_args() + + # Download and load checkpoints + model, model_cfg = load_wilor(checkpoint_path = './pretrained_models/wilor_final.ckpt' , cfg_path= './pretrained_models/model_config.yaml') + detector = YOLO('./pretrained_models/detector.pt') + # Setup the renderer + renderer = Renderer(model_cfg, faces=model.mano.faces) + renderer_side = Renderer(model_cfg, faces=model.mano.faces) + + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + model = model.to(device) + detector = detector.to(device) + model.eval() + + # Make output directory if it does not exist + os.makedirs(args.out_folder, exist_ok=True) + + # Get all demo images ends with .jpg or .png + img_paths = [img for end in args.file_type for img in Path(args.img_folder).glob(end)] + # Iterate over all images in folder + for img_path in img_paths: + img_cv2 = cv2.imread(str(img_path)) + detections = detector(img_cv2, conf = 0.3, verbose=False)[0] + bboxes = [] + is_right = [] + for det in detections: + Bbox = det.boxes.data.cpu().detach().squeeze().numpy() + is_right.append(det.boxes.cls.cpu().detach().squeeze().item()) + bboxes.append(Bbox[:4].tolist()) + + if len(bboxes) == 0: + # basename = os.path.basename(img_path).split('.')[0] + # cv2.imwrite(os.path.join(args.out_folder, f'{basename}.jpg'), img_cv2) + # print(os.path.join(args.out_folder, f'{basename}.jpg')) + continue + 
boxes = np.stack(bboxes) + right = np.stack(is_right) + dataset = ViTDetDataset(model_cfg, img_cv2, boxes, right, rescale_factor=args.rescale_factor) + dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=False, num_workers=0) + + all_verts = [] + all_cam_t = [] + all_right = [] + all_joints= [] + all_kpts = [] + + for batch in dataloader: + batch = recursive_to(batch, device) + + with torch.no_grad(): + out = model(batch) + + multiplier = (2*batch['right']-1) + pred_cam = out['pred_cam'] + pred_cam[:,1] = multiplier*pred_cam[:,1] + box_center = batch["box_center"].float() + box_size = batch["box_size"].float() + img_size = batch["img_size"].float() + scaled_focal_length = model_cfg.EXTRA.FOCAL_LENGTH / model_cfg.MODEL.IMAGE_SIZE * img_size.max() + pred_cam_t_full = cam_crop_to_full(pred_cam, box_center, box_size, img_size, scaled_focal_length).detach().cpu().numpy() + + + # Render the result + batch_size = batch['img'].shape[0] + for n in range(batch_size): + # Get filename from path img_path + img_fn, _ = os.path.splitext(os.path.basename(img_path)) + + verts = out['pred_vertices'][n].detach().cpu().numpy() + joints = out['pred_keypoints_3d'][n].detach().cpu().numpy() + + is_right = batch['right'][n].cpu().numpy() + verts[:,0] = (2*is_right-1)*verts[:,0] + joints[:,0] = (2*is_right-1)*joints[:,0] + cam_t = pred_cam_t_full[n] + kpts_2d = project_full_img(verts, cam_t, scaled_focal_length, img_size[n]) + + all_verts.append(verts) + all_cam_t.append(cam_t) + all_right.append(is_right) + all_joints.append(joints) + all_kpts.append(kpts_2d) + + + # Save all meshes to disk + if args.save_mesh: + camera_translation = cam_t.copy() + tmesh = renderer.vertices_to_trimesh(verts, camera_translation, LIGHT_PURPLE, is_right=is_right) + tmesh.export(os.path.join(args.out_folder, f'{img_fn}_{n}.obj')) + + # Render front view + if len(all_verts) > 0: + misc_args = dict( + mesh_base_color=LIGHT_PURPLE, + scene_bg_color=(1, 1, 1), + 
focal_length=scaled_focal_length, + ) + cam_view = renderer.render_rgba_multiple(all_verts, cam_t=all_cam_t, render_res=img_size[n], is_right=all_right, **misc_args) + + # Overlay image + input_img = img_cv2.astype(np.float32)[:,:,::-1]/255.0 + input_img = np.concatenate([input_img, np.ones_like(input_img[:,:,:1])], axis=2) # Add alpha channel + input_img_overlay = input_img[:,:,:3] * (1-cam_view[:,:,3:]) + cam_view[:,:,:3] * cam_view[:,:,3:] + final_img = 255*input_img_overlay[:, :, ::-1] + else: + final_img = img_cv2 + for i in range(len(bboxes)): + if right[i] == 0: + final_img = cv2.rectangle(final_img, (int(bboxes[i][0]), int(bboxes[i][1])), (int(bboxes[i][2]), int(bboxes[i][3])), (255, 100, 100), 2) + else: + final_img = cv2.rectangle(final_img, (int(bboxes[i][0]), int(bboxes[i][1])), (int(bboxes[i][2]), int(bboxes[i][3])), (100, 100, 255), 2) + cv2.imwrite(os.path.join(args.out_folder, f'{img_fn}.jpg'), final_img) + +def project_full_img(points, cam_trans, focal_length, img_res): + camera_center = [img_res[0] / 2., img_res[1] / 2.] + K = torch.eye(3) + K[0,0] = focal_length + K[1,1] = focal_length + K[0,2] = camera_center[0] + K[1,2] = camera_center[1] + points = points + cam_trans + points = points / points[..., -1:] + + V_2d = (K @ points.T).T + return V_2d[..., :-1] + diff --git a/src/egoallo/inference_utils.py b/src/egoallo/inference_utils.py index 994a0a2..e3b69aa 100644 --- a/src/egoallo/inference_utils.py +++ b/src/egoallo/inference_utils.py @@ -62,7 +62,7 @@ class InferenceTrajectoryPaths: splat_path: Path | None @staticmethod - def find(traj_root: Path) -> InferenceTrajectoryPaths: + def find(traj_root: Path, detector: str="hamer") -> InferenceTrajectoryPaths: vrs_files = tuple(traj_root.glob("**/*.vrs")) assert len(vrs_files) == 1, f"Found {len(vrs_files)} VRS files!" 
@@ -71,8 +71,8 @@ def find(traj_root: Path) -> InferenceTrajectoryPaths: if len(points_paths) == 0: points_paths = tuple(traj_root.glob("**/global_points.csv.gz")) assert len(points_paths) == 1, f"Found {len(points_paths)} files!" - - hamer_outputs = traj_root / "hamer_outputs.pkl" + pickle_file_name = detector + "_outputs.pkl" + hamer_outputs = traj_root / pickle_file_name if not hamer_outputs.exists(): hamer_outputs = None