From 7e2f5a98b8ba9b9f60c84b10ba554dd7074a3794 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Wed, 4 Jun 2025 18:46:50 -0700 Subject: [PATCH] Fix runtime error, pin dependencies --- 0b_preprocess_training_data.py | 2 +- 1_train_motion_prior.py | 2 +- 4_visualize_outputs.py | 6 ++--- README.md | 2 +- pyproject.toml | 32 +++++++++++++-------------- src/egoallo/fncsmpl_jax.py | 4 +++- src/egoallo/guidance_optimizer_jax.py | 2 +- src/egoallo/hand_detection_structs.py | 4 ++-- src/egoallo/inference_utils.py | 6 ++--- 9 files changed, 31 insertions(+), 29 deletions(-) diff --git a/0b_preprocess_training_data.py b/0b_preprocess_training_data.py index 581f866..ec8555d 100644 --- a/0b_preprocess_training_data.py +++ b/0b_preprocess_training_data.py @@ -76,7 +76,7 @@ def worker(device_idx: int) -> None: print( f"Finished ~{total_count - task_queue.qsize()}/{total_count},", - f"{(total_count - task_queue.qsize())/total_count * 100:.2f}% in", + f"{(total_count - task_queue.qsize()) / total_count * 100:.2f}% in", f"{time.time() - start_time} seconds", ) diff --git a/1_train_motion_prior.py b/1_train_motion_prior.py index 1492722..e1f6b48 100644 --- a/1_train_motion_prior.py +++ b/1_train_motion_prior.py @@ -200,7 +200,7 @@ def run_training( mem_free, mem_total = torch.cuda.mem_get_info() logger.info( f"step: {step} ({loop_metrics.iterations_per_sec:.2f} it/sec)" - f" mem: {(mem_total-mem_free)/1024**3:.2f}/{mem_total/1024**3:.2f}G" + f" mem: {(mem_total - mem_free) / 1024**3:.2f}/{mem_total / 1024**3:.2f}G" f" lr: {scheduler.get_last_lr()[0]:.7f}" f" loss: {loss.item():.6f}" ) diff --git a/4_visualize_outputs.py b/4_visualize_outputs.py index 2480e8c..91bfdb6 100644 --- a/4_visualize_outputs.py +++ b/4_visualize_outputs.py @@ -127,9 +127,9 @@ def load_and_visualize( "frame_nums", "timestamps_ns", ] - assert all( - key in outputs for key in expected_keys - ), f"Missing keys in NPZ file. Expected: {expected_keys}, Found: {list(outputs.keys())}" + assert all(key in outputs for key in expected_keys), ( + f"Missing keys in NPZ file. Expected: {expected_keys}, Found: {list(outputs.keys())}" + ) (num_samples, timesteps, _, _) = outputs["body_quats"].shape # We assume the directory structure is: diff --git a/README.md b/README.md index 575de0f..f6b699d 100644 --- a/README.md +++ b/README.md @@ -96,7 +96,7 @@ EgoAllo requires Python 3.12 or newer. ```bash # Also see: https://jax.readthedocs.io/en/latest/installation.html - pip install -U "jax[cuda12]" + pip install "jax[cuda12]==0.6.1" ``` You'll also need [jaxls](https://github.com/brentyi/jaxls): diff --git a/pyproject.toml b/pyproject.toml index a5e9cf4..55dda1f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,23 +15,23 @@ classifiers = [ "Operating System :: OS Independent" ] dependencies = [ - "torch>2.2", + "torch==2.7.1", "viser>=0.2.11", - "typeguard", - "jaxtyping>=0.2.29", - "einops", - "rotary-embedding-torch", - "h5py", - "tensorboard", - "projectaria_tools", - "accelerate", - "tensorboardX", - "loguru", - "projectaria-tools[all]", - "opencv-python", - "gdown", - "scikit-learn", # Only needed for preprocessing - "smplx", # Only needed for preprocessing + "typeguard==4.4.3", + "jaxtyping==0.3.2", + "einops==0.8.1", + "rotary-embedding-torch==0.8.6", + "h5py==3.13.0", + "tensorboard==2.19.0", + "projectaria_tools==1.6.0", + "accelerate==1.7.0", + "tensorboardX==2.6.2.2", + "loguru==0.7.3", + "projectaria-tools[all]==1.6.0", + "opencv-python==4.11.0.86", + "gdown==5.2.0", + "scikit-learn==1.6.1", # Only needed for preprocessing + "smplx==0.1.28", # Only needed for preprocessing ] [tool.setuptools.package-data] diff --git a/src/egoallo/fncsmpl_jax.py b/src/egoallo/fncsmpl_jax.py index 802e46a..37747ba 100644 --- a/src/egoallo/fncsmpl_jax.py +++ b/src/egoallo/fncsmpl_jax.py @@ -215,7 +215,9 @@ def lbs(self) -> SmplhMesh: assert ( self.local_quats.shape[0] == self.shaped_model.body_model.parent_indices.shape[0] - ), "It looks like only a partial set of joint rotations was passed into `with_pose()`. We need all of them for LBS." + ), ( + "It looks like only a partial set of joint rotations was passed into `with_pose()`. We need all of them for LBS." + ) # Linear blend skinning with a pose blend shape. verts_with_blend = self.shaped_model.verts_zero + einsum( diff --git a/src/egoallo/guidance_optimizer_jax.py b/src/egoallo/guidance_optimizer_jax.py index 58acdac..e324aac 100644 --- a/src/egoallo/guidance_optimizer_jax.py +++ b/src/egoallo/guidance_optimizer_jax.py @@ -190,7 +190,7 @@ class JaxGuidanceParams: # Optimization parameters. lambda_initial: float = 0.1 - max_iters: int = 20 + max_iters: jdc.Static[int] = 20 @staticmethod def defaults( diff --git a/src/egoallo/hand_detection_structs.py b/src/egoallo/hand_detection_structs.py index 7553927..2d42eb7 100644 --- a/src/egoallo/hand_detection_structs.py +++ b/src/egoallo/hand_detection_structs.py @@ -97,11 +97,11 @@ class OneSide(Protocol): ): continue - if wp_pose.left_hand.confidence > 0.7: + if wp_pose.left_hand is not None and wp_pose.left_hand.confidence > 0.7: indices_left.append(i) detections_left.append(wp_pose.left_hand) - if wp_pose.right_hand.confidence > 0.7: + if wp_pose.right_hand is not None and wp_pose.right_hand.confidence > 0.7: indices_right.append(i) detections_right.append(wp_pose.right_hand) diff --git a/src/egoallo/inference_utils.py b/src/egoallo/inference_utils.py index 994a0a2..f004c0e 100644 --- a/src/egoallo/inference_utils.py +++ b/src/egoallo/inference_utils.py @@ -80,9 +80,9 @@ def find(traj_root: Path) -> InferenceTrajectoryPaths: if len(wrist_and_palm_poses_csv) == 0: wrist_and_palm_poses_csv = None else: - assert ( - len(wrist_and_palm_poses_csv) == 1 - ), "Found multiple wrist and palm poses files!" + assert len(wrist_and_palm_poses_csv) == 1, ( + "Found multiple wrist and palm poses files!" + ) splat_path = traj_root / "splat.ply" if not splat_path.exists():