Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,63 @@ Make sure your docker container has sufficient CPU, swap and GPU memory. Please

You will need instance segmentation (e.g. [SAM2](https://ai.meta.com/sam2/)) and motion planning (e.g. [cuRobo](https://curobo.org/)) to run this model. More details can be found in the experiments section of the paper.

### How do I fine-tune the discriminator on real-world robot data?

GraspGen supports fine-tuning the discriminator on real grasp execution outcomes — no mesh required.

**Step 1: Record rollouts on your robot**

In your robot control loop, record each grasp attempt:

```python
from scripts.record_grasp import GraspRecorder

recorder = GraspRecorder(
object_id="my_object.obj", # unique name for this object
output_dir="/data/rollouts/my_object",
)

# After each grasp attempt:
recorder.record(
point_cloud=pc, # (N, 3) numpy array, object-centered, from depth camera
grasp_pose=T, # (4, 4) SE(3) matrix — exactly as returned by GraspGenSampler
confidence=conf, # discriminator score from GraspGenSampler.run_inference()
    success=True,           # bool — did the object end up in the gripper? (False if not)
collided=False, # did the gripper collide before closing?
)

recorder.save() # writes rollouts.h5 + grasps.json
```

> **Important:** `point_cloud` and `grasp_pose` must be in the same coordinate frame — the object-centered frame that `GraspGenSampler` uses (it subtracts the point cloud mean before inference and adds it back to the returned poses). Pass the pose exactly as returned by `run_inference()`.

> **Success labels:** Use fingertip position checking or a force sensor to determine `success`. A grasp that collided with the scene before closing should be marked `collided=True` — it will be excluded from training (it never got a fair attempt).

**Step 2: Prepare the cache for training**

Once you have collected enough rollouts (recommended: 10+ non-colliding attempts, at least a few successes):

```bash
python scripts/prepare_finetune.py \
--rollouts /data/rollouts/my_object \
--cache_dir /data/cache \
--dataset_dir /data/dataset \
--object_id my_object.obj \
--gripper robotiq_2f_140
```

This prints the exact `train_graspgen.py` command to run. Execute it to fine-tune.

**What to expect:**
- Discriminator loss should decrease within the first epoch
- With ~10–50 rollouts the improvement is modest but measurable
- More rollouts from diverse viewpoints → better generalization
- The generator is not fine-tuned, only the discriminator

**Implementation notes:**
- `rollouts.h5` stores a `gt_grasps` field that mirrors `pred_grasps`. Real-world rollouts don't have ground-truth grasp annotations, so the predicted poses are used as a placeholder — this is intentional, not a bug.
- The printed `train_graspgen.py` command uses a 7-element `discriminator_ratio`. Do not shorten it to 5 elements — `train_graspgen.py` always accesses the onpolicy slots (indices 5 and 6) when `load_discriminator_dataset=true`, and a 5-element ratio will crash inside the dataloader.

### You did not include the gripper I have/want with your dataset!
Sorry we missed your gripper! Please consider completing this quick [survey](https://docs.google.com/forms/d/e/1FAIpQLSdTCstEtaeZz5iSyjAhYFuJqSpMF671ftPylkS3ZJFhRIg3dg/viewform?usp=dialog) to describe your gripper. You can optionally leave your URDF.

Expand Down
12 changes: 9 additions & 3 deletions grasp_gen/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1307,6 +1307,8 @@ def __getitem__(self, idx):
pc,
self.gripper_visual_mesh,
)
if type(outputs["points"]) == np.ndarray:
outputs["points"] = torch.from_numpy(outputs["points"]).float()
if len(outputs["points"].shape) == 2:
outputs["points"] = outputs["points"].unsqueeze(0)
return outputs
Expand Down Expand Up @@ -1501,9 +1503,13 @@ def load_discriminator_batch_with_stratified_sampling(
obj_asset_path = scene_info["assets"][0]
obj_pose = scene_info["poses"][0]

scene_mesh = trimesh.load(obj_asset_path)
scene_mesh.apply_scale(obj_scale)
scene_mesh.apply_transform(obj_pose)
if os.path.exists(obj_asset_path):
scene_mesh = trimesh.load(obj_asset_path)
scene_mesh.apply_scale(obj_scale)
scene_mesh.apply_transform(obj_pose)
else:
scene_mesh = None
num_neg_hncolliding = 0 # mesh-based hard negatives unavailable

grasps, grasp_ids = None, None

Expand Down
115 changes: 115 additions & 0 deletions scripts/prepare_finetune.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""
Convert real-world rollout data into a training cache and print the fine-tuning command.

Run this after collecting rollouts with GraspRecorder:

python scripts/prepare_finetune.py \\
--rollouts /data/rollouts/banana \\
--cache_dir /data/cache \\
--dataset_dir /data/dataset \\
--object_id banana.obj
"""

import argparse
import json
import os
import sys

# Make the repository root importable so `scripts.*` and `grasp_gen.*`
# resolve when this script is run directly (not as an installed package).
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))


def parse_args():
    """Parse command-line arguments.

    Returns:
        argparse.Namespace with the required ``rollouts``, ``cache_dir``,
        ``dataset_dir`` and ``object_id`` paths/ids, plus the optional
        ``num_renderings`` (int, default 5) and ``gripper`` (default
        ``robotiq_2f_140``).
    """
    p = argparse.ArgumentParser(
        description=(
            "Convert real-world rollout data into a training cache and "
            "print the fine-tuning command."
        )
    )
    p.add_argument("--rollouts", required=True,
                   help="Directory containing rollouts.h5 and grasps.json")
    p.add_argument("--cache_dir", required=True,
                   help="Root directory where the training cache is written")
    p.add_argument("--dataset_dir", required=True,
                   help="Dataset directory; train.txt is created/updated here")
    p.add_argument("--object_id", required=True,
                   help="Object identifier used when recording, e.g. banana.obj")
    p.add_argument("--num_renderings", type=int, default=5,
                   help="Number of renderings per rollout to put in the cache")
    p.add_argument("--gripper", default="robotiq_2f_140",
                   help="Gripper name matching a pretrained GraspGen checkpoint")
    return p.parse_args()


def main():
    """Validate rollout data, build the training cache, and print the
    fine-tuning command.

    Steps:
      1. Check the recorded rollout files exist and summarize label counts.
      2. Abort if there are no successful grasps; warn if data is sparse.
      3. Build the discriminator training cache via GraspRecorder.
      4. Register the object in the dataset's train.txt.
      5. Print the exact train_graspgen.py invocation to run.
    """
    args = parse_args()

    h5_path = os.path.join(args.rollouts, "rollouts.h5")
    json_path = os.path.join(args.rollouts, "grasps.json")
    train_txt = os.path.join(args.dataset_dir, "train.txt")

    # Explicit checks instead of `assert`: asserts are stripped under
    # `python -O`, and these validate user input, not internal invariants.
    for required in (h5_path, json_path):
        if not os.path.exists(required):
            print(f"Required rollout file not found: {required}")
            sys.exit(1)

    # Heavy import deferred so `--help` stays fast.
    import h5py

    with h5py.File(h5_path, "r") as f:
        grp = f["objects"][args.object_id]
        n_total = grp["pred_grasps"].shape[0]
        n_collided = int(grp["collision"][...].sum())
    with open(json_path) as jf:
        labels = json.load(jf)["grasps"]["object_in_gripper"]
    # NOTE(review): success is counted over *all* attempts; assumes collided
    # attempts are never labeled successful (otherwise the printed fail count
    # below could go negative) — confirm against GraspRecorder.
    n_success = sum(labels)
    n_noncolliding = n_total - n_collided

    print(f" {args.object_id}: {n_total} attempts, "
          f"{n_success} success, {n_noncolliding - n_success} fail, "
          f"{n_collided} collided")

    if n_success == 0:
        print("No successful grasps — need at least 1 to build a useful cache.")
        sys.exit(1)

    if n_noncolliding < 10:
        print(f"Only {n_noncolliding} non-colliding attempts — signal may be weak. "
              f"Recommend collecting at least 10.")

    # train_graspgen.py accesses onpolicy slots (indices 5, 6) unconditionally
    # when load_discriminator_dataset=True; a 5-element ratio crashes there.
    # The printed command is derived from this validated list so the check and
    # the command can never drift apart (previously the literal was duplicated).
    _ratio = ["0.30", "0.10", "0.15", "0.05", "0.0", "0.20", "0.20"]
    if len(_ratio) != 7:
        raise ValueError("discriminator_ratio must have 7 elements for onpolicy data")
    ratio_str = "[" + ",".join(_ratio) + "]"

    from scripts.record_grasp import GraspRecorder
    from grasp_gen.dataset.dataset import get_cache_prefix

    # The cache layout is <cache_dir>/<dataset_name>/cache_train_<prefix>.h5,
    # where the prefix encodes the dataloader configuration.
    dataset_name = os.path.basename(os.path.abspath(args.dataset_dir))
    cache_subdir = os.path.join(args.cache_dir, dataset_name)
    os.makedirs(cache_subdir, exist_ok=True)

    prefix = get_cache_prefix(prob_point_cloud=0.0, load_discriminator_dataset=True)
    cache_h5 = os.path.join(cache_subdir, f"cache_train_{prefix}.h5")

    recorder = GraspRecorder.from_h5(h5_path, args.object_id, args.rollouts)
    recorder.write_cache(cache_h5, num_renderings=args.num_renderings)

    # Register the object in train.txt exactly once (append-only, idempotent).
    os.makedirs(args.dataset_dir, exist_ok=True)
    existing = set()
    if os.path.exists(train_txt):
        with open(train_txt) as f:
            existing = set(f.read().splitlines())
    if args.object_id not in existing:
        with open(train_txt, "a") as f:
            f.write(args.object_id + "\n")

    ckpt_root = os.path.join(os.path.dirname(__file__), "..", "checkpoints")

    print(f"""
Run this to fine-tune:

python scripts/train_graspgen.py \\
    data.root_dir={args.dataset_dir} \\
    data.cache_dir={args.cache_dir} \\
    data.gripper_name={args.gripper} \\
    data.load_discriminator_dataset=true \\
    data.discriminator_ratio='{ratio_str}' \\
    train.model_name=discriminator \\
    train.num_epochs=5 \\
    train.checkpoint={ckpt_root}/graspgen_{args.gripper}_dis.pth \\
    discriminator.checkpoint={ckpt_root}/graspgen_{args.gripper}_dis.pth \\
    discriminator.gripper_name={args.gripper}

Loss should drop within the first epoch. If it doesn't move, collect more data.
""")


# CLI entry point — run directly, not on import.
if __name__ == "__main__":
    main()
Loading