Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 0b_preprocess_training_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def worker(device_idx: int) -> None:

print(
f"Finished ~{total_count - task_queue.qsize()}/{total_count},",
f"{(total_count - task_queue.qsize())/total_count * 100:.2f}% in",
f"{(total_count - task_queue.qsize()) / total_count * 100:.2f}% in",
f"{time.time() - start_time} seconds",
)

Expand Down
2 changes: 1 addition & 1 deletion 1_train_motion_prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def run_training(
mem_free, mem_total = torch.cuda.mem_get_info()
logger.info(
f"step: {step} ({loop_metrics.iterations_per_sec:.2f} it/sec)"
f" mem: {(mem_total-mem_free)/1024**3:.2f}/{mem_total/1024**3:.2f}G"
f" mem: {(mem_total - mem_free) / 1024**3:.2f}/{mem_total / 1024**3:.2f}G"
f" lr: {scheduler.get_last_lr()[0]:.7f}"
f" loss: {loss.item():.6f}"
)
Expand Down
6 changes: 3 additions & 3 deletions 4_visualize_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,9 @@ def load_and_visualize(
"frame_nums",
"timestamps_ns",
]
assert all(
key in outputs for key in expected_keys
), f"Missing keys in NPZ file. Expected: {expected_keys}, Found: {list(outputs.keys())}"
assert all(key in outputs for key in expected_keys), (
f"Missing keys in NPZ file. Expected: {expected_keys}, Found: {list(outputs.keys())}"
)
(num_samples, timesteps, _, _) = outputs["body_quats"].shape

# We assume the directory structure is:
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ EgoAllo requires Python 3.12 or newer.

```bash
# Also see: https://jax.readthedocs.io/en/latest/installation.html
pip install -U "jax[cuda12]"
pip install "jax[cuda12]==0.6.1"
```

You'll also need [jaxls](https://github.com/brentyi/jaxls):
Expand Down
32 changes: 16 additions & 16 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,23 @@ classifiers = [
"Operating System :: OS Independent"
]
dependencies = [
"torch>2.2",
"torch==2.7.1",
"viser>=0.2.11",
"typeguard",
"jaxtyping>=0.2.29",
"einops",
"rotary-embedding-torch",
"h5py",
"tensorboard",
"projectaria_tools",
"accelerate",
"tensorboardX",
"loguru",
"projectaria-tools[all]",
"opencv-python",
"gdown",
"scikit-learn", # Only needed for preprocessing
"smplx", # Only needed for preprocessing
"typeguard==4.4.3",
"jaxtyping==0.3.2",
"einops==0.8.1",
"rotary-embedding-torch==0.8.6",
"h5py==3.13.0",
"tensorboard==2.19.0",
"projectaria_tools==1.6.0",
"accelerate==1.7.0",
"tensorboardX==2.6.2.2",
"loguru==0.7.3",
"projectaria-tools[all]==1.6.0",
"opencv-python==4.11.0.86",
"gdown==5.2.0",
"scikit-learn==1.6.1", # Only needed for preprocessing
"smplx==0.1.28", # Only needed for preprocessing
]

[tool.setuptools.package-data]
Expand Down
4 changes: 3 additions & 1 deletion src/egoallo/fncsmpl_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,9 @@ def lbs(self) -> SmplhMesh:
assert (
self.local_quats.shape[0]
== self.shaped_model.body_model.parent_indices.shape[0]
), "It looks like only a partial set of joint rotations was passed into `with_pose()`. We need all of them for LBS."
), (
"It looks like only a partial set of joint rotations was passed into `with_pose()`. We need all of them for LBS."
)

# Linear blend skinning with a pose blend shape.
verts_with_blend = self.shaped_model.verts_zero + einsum(
Expand Down
2 changes: 1 addition & 1 deletion src/egoallo/guidance_optimizer_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ class JaxGuidanceParams:

# Optimization parameters.
lambda_initial: float = 0.1
max_iters: int = 20
max_iters: jdc.Static[int] = 20

@staticmethod
def defaults(
Expand Down
4 changes: 2 additions & 2 deletions src/egoallo/hand_detection_structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,11 @@ class OneSide(Protocol):
):
continue

if wp_pose.left_hand.confidence > 0.7:
if wp_pose.left_hand is not None and wp_pose.left_hand.confidence > 0.7:
indices_left.append(i)
detections_left.append(wp_pose.left_hand)

if wp_pose.right_hand.confidence > 0.7:
if wp_pose.right_hand is not None and wp_pose.right_hand.confidence > 0.7:
indices_right.append(i)
detections_right.append(wp_pose.right_hand)

Expand Down
6 changes: 3 additions & 3 deletions src/egoallo/inference_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ def find(traj_root: Path) -> InferenceTrajectoryPaths:
if len(wrist_and_palm_poses_csv) == 0:
wrist_and_palm_poses_csv = None
else:
assert (
len(wrist_and_palm_poses_csv) == 1
), "Found multiple wrist and palm poses files!"
assert len(wrist_and_palm_poses_csv) == 1, (
"Found multiple wrist and palm poses files!"
)

splat_path = traj_root / "splat.ply"
if not splat_path.exists():
Expand Down