diff --git a/README.md b/README.md index 265374a..2ca3d71 100644 --- a/README.md +++ b/README.md @@ -421,6 +421,63 @@ Make sure your docker container has sufficient CPU, swap and GPU memory. Please You will need instance segmentation (e.g. [SAM2](https://ai.meta.com/sam2/)) and motion planning (e.g. [cuRobo](https://curobo.org/)) to run this model. More details can be found in the experiments section of the paper. +### How do I fine-tune the discriminator on real-world robot data? + +GraspGen supports fine-tuning the discriminator on real grasp execution outcomes — no mesh required. + +**Step 1: Record rollouts on your robot** + +In your robot control loop, record each grasp attempt: + +```python +from scripts.record_grasp import GraspRecorder + +recorder = GraspRecorder( + object_id="my_object.obj", # unique name for this object + output_dir="/data/rollouts/my_object", +) + +# After each grasp attempt: +recorder.record( + point_cloud=pc, # (N, 3) numpy array, object-centered, from depth camera + grasp_pose=T, # (4, 4) SE(3) matrix — exactly as returned by GraspGenSampler + confidence=conf, # discriminator score from GraspGenSampler.run_inference() + success=success, # bool: did the object end up in the gripper? + collided=False, # did the gripper collide before closing? +) + +recorder.save() # writes rollouts.h5 + grasps.json +``` + +> **Important:** `point_cloud` and `grasp_pose` must be in the same coordinate frame — the object-centered frame that `GraspGenSampler` uses (it subtracts the point cloud mean before inference and adds it back to the returned poses). Pass the pose exactly as returned by `run_inference()`. + +> **Success labels:** Use fingertip position checking or a force sensor to determine `success`. A grasp that collided with the scene before closing should be marked `collided=True` — it will be excluded from training (it never got a fair attempt). 
+ +**Step 2: Prepare the cache for training** + +Once you have collected enough rollouts (recommended: 10+ non-colliding attempts, at least a few successes): + +```bash +python scripts/prepare_finetune.py \ + --rollouts /data/rollouts/my_object \ + --cache_dir /data/cache \ + --dataset_dir /data/dataset \ + --object_id my_object.obj \ + --gripper robotiq_2f_140 +``` + +This prints the exact `train_graspgen.py` command to run. Execute it to fine-tune. + +**What to expect:** +- Discriminator loss should decrease within the first epoch +- With ~10–50 rollouts the improvement is modest but measurable +- More rollouts from diverse viewpoints → better generalization +- The generator is not fine-tuned, only the discriminator + +**Implementation notes:** +- `rollouts.h5` stores a `gt_grasps` field that mirrors `pred_grasps`. Real-world rollouts don't have ground-truth grasp annotations, so the predicted poses are used as a placeholder — this is intentional, not a bug. +- The printed `train_graspgen.py` command uses a 7-element `discriminator_ratio`. Do not shorten it to 5 elements — `train_graspgen.py` always accesses the onpolicy slots (indices 5 and 6) when `load_discriminator_dataset=true`, and a 5-element ratio will crash inside the dataloader. + ### You did not include the gripper I have/want with your dataset! Sorry we missed your gripper! Please consider completing this quick [survey](https://docs.google.com/forms/d/e/1FAIpQLSdTCstEtaeZz5iSyjAhYFuJqSpMF671ftPylkS3ZJFhRIg3dg/viewform?usp=dialog) to describe your gripper. You can optionally leave a your URDF. 
diff --git a/grasp_gen/dataset/dataset.py b/grasp_gen/dataset/dataset.py index ffb517d..e661d80 100644 --- a/grasp_gen/dataset/dataset.py +++ b/grasp_gen/dataset/dataset.py @@ -1307,6 +1307,8 @@ def __getitem__(self, idx): pc, self.gripper_visual_mesh, ) + if type(outputs["points"]) == np.ndarray: + outputs["points"] = torch.from_numpy(outputs["points"]).float() if len(outputs["points"].shape) == 2: outputs["points"] = outputs["points"].unsqueeze(0) return outputs @@ -1501,9 +1503,13 @@ def load_discriminator_batch_with_stratified_sampling( obj_asset_path = scene_info["assets"][0] obj_pose = scene_info["poses"][0] - scene_mesh = trimesh.load(obj_asset_path) - scene_mesh.apply_scale(obj_scale) - scene_mesh.apply_transform(obj_pose) + if os.path.exists(obj_asset_path): + scene_mesh = trimesh.load(obj_asset_path) + scene_mesh.apply_scale(obj_scale) + scene_mesh.apply_transform(obj_pose) + else: + scene_mesh = None + num_neg_hncolliding = 0 # mesh-based hard negatives unavailable grasps, grasp_ids = None, None diff --git a/scripts/prepare_finetune.py b/scripts/prepare_finetune.py new file mode 100644 index 0000000..2a28d86 --- /dev/null +++ b/scripts/prepare_finetune.py @@ -0,0 +1,115 @@ +""" +Convert real-world rollout data into a training cache and print the fine-tuning command. 
+ +Run this after collecting rollouts with GraspRecorder: + + python scripts/prepare_finetune.py \\ + --rollouts /data/rollouts/banana \\ + --cache_dir /data/cache \\ + --dataset_dir /data/dataset \\ + --object_id banana.obj +""" + +import argparse +import json +import os +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + + +def parse_args(): + p = argparse.ArgumentParser() + p.add_argument("--rollouts", required=True) + p.add_argument("--cache_dir", required=True) + p.add_argument("--dataset_dir", required=True) + p.add_argument("--object_id", required=True) + p.add_argument("--num_renderings", type=int, default=5) + p.add_argument("--gripper", default="robotiq_2f_140") + return p.parse_args() + + +def main(): + args = parse_args() + + h5_path = os.path.join(args.rollouts, "rollouts.h5") + json_path = os.path.join(args.rollouts, "grasps.json") + train_txt = os.path.join(args.dataset_dir, "train.txt") + + assert os.path.exists(h5_path), f"rollouts.h5 not found: {h5_path}" + assert os.path.exists(json_path), f"grasps.json not found: {json_path}" + + import h5py + import numpy as np + with h5py.File(h5_path, "r") as f: + grp = f["objects"][args.object_id] + n_total = grp["pred_grasps"].shape[0] + n_collided = int(grp["collision"][...].sum()) + with open(json_path) as jf: + labels = json.load(jf)["grasps"]["object_in_gripper"] + n_success = sum(labels) + n_noncolliding = n_total - n_collided + + print(f" {args.object_id}: {n_total} attempts, " + f"{n_success} success, {n_noncolliding - n_success} fail, " + f"{n_collided} collided") + + if n_success == 0: + print("No successful grasps — need at least 1 to build a useful cache.") + sys.exit(1) + + if n_noncolliding < 10: + print(f"Only {n_noncolliding} non-colliding attempts — signal may be weak. " + f"Recommend collecting at least 10.") + + # Validate that the ratio we're about to print is the right length. 
+ # train_graspgen.py accesses onpolicy slots (indices 5, 6) unconditionally + # when load_discriminator_dataset=True; a 5-element ratio crashes there. + _ratio = [0.30, 0.10, 0.15, 0.05, 0.0, 0.20, 0.20] + assert len(_ratio) == 7, "discriminator_ratio must have 7 elements for onpolicy data" + + from scripts.record_grasp import GraspRecorder + from grasp_gen.dataset.dataset import get_cache_prefix + + dataset_name = os.path.basename(os.path.abspath(args.dataset_dir)) + cache_subdir = os.path.join(args.cache_dir, dataset_name) + os.makedirs(cache_subdir, exist_ok=True) + + prefix = get_cache_prefix(prob_point_cloud=0.0, load_discriminator_dataset=True) + cache_h5 = os.path.join(cache_subdir, f"cache_train_{prefix}.h5") + + recorder = GraspRecorder.from_h5(h5_path, args.object_id, args.rollouts) + recorder.write_cache(cache_h5, num_renderings=args.num_renderings) + + os.makedirs(args.dataset_dir, exist_ok=True) + existing = set() + if os.path.exists(train_txt): + with open(train_txt) as f: + existing = set(f.read().splitlines()) + if args.object_id not in existing: + with open(train_txt, "a") as f: + f.write(args.object_id + "\n") + + ckpt_root = os.path.join(os.path.dirname(__file__), "..", "checkpoints") + + print(f""" +Run this to fine-tune: + + python scripts/train_graspgen.py \\ + data.root_dir={args.dataset_dir} \\ + data.cache_dir={args.cache_dir} \\ + data.gripper_name={args.gripper} \\ + data.load_discriminator_dataset=true \\ + data.discriminator_ratio='[0.30,0.10,0.15,0.05,0.0,0.20,0.20]' \\ + train.model_name=discriminator \\ + train.num_epochs=5 \\ + train.checkpoint={ckpt_root}/graspgen_{args.gripper}_dis.pth \\ + discriminator.checkpoint={ckpt_root}/graspgen_{args.gripper}_dis.pth \\ + discriminator.gripper_name={args.gripper} + +Loss should drop within the first epoch. If it doesn't move, collect more data. 
+""") + + +if __name__ == "__main__": + main() diff --git a/scripts/record_grasp.py b/scripts/record_grasp.py new file mode 100644 index 0000000..6547a9f --- /dev/null +++ b/scripts/record_grasp.py @@ -0,0 +1,221 @@ +""" +Record grasp rollout data for discriminator finetuning on real-world data. + +Usage: + from scripts.record_grasp import GraspRecorder + + recorder = GraspRecorder("banana.obj", "/data/rollouts/banana") + recorder.record(point_cloud, grasp_pose, confidence, success, collided) + recorder.save() + +Output: + rollouts.h5 - grasp poses, confidences, collision flags, point clouds + grasps.json - success labels for non-colliding grasps + +Note: pass a 7-element discriminator_ratio when fine-tuning, e.g.: + discriminator_ratio=[0.30, 0.10, 0.15, 0.05, 0.0, 0.20, 0.20] +The default 5-element ratio will crash because the onpolicy slots (indices 5, 6) +are always accessed by load_discriminator_batch_with_stratified_sampling. +""" + +import json +import os +from dataclasses import dataclass +from typing import List + +import h5py +import numpy as np + + +@dataclass +class GraspAttempt: + point_cloud: np.ndarray # (N, 3) + grasp_pose: np.ndarray # (4, 4) + confidence: float + success: bool + collided: bool + + +class GraspRecorder: + """Records grasp attempts for a single object and saves them to disk.""" + + def __init__(self, object_id: str, output_dir: str): + self.object_id = object_id + self.output_dir = output_dir + self.attempts: List[GraspAttempt] = [] + os.makedirs(output_dir, exist_ok=True) + + def record( + self, + point_cloud: np.ndarray, + grasp_pose: np.ndarray, + confidence: float, + success: bool, + collided: bool = False, + ): + """ + Record a single grasp attempt. + + point_cloud: (N, 3) in object-centered frame (same frame as grasp_pose). + Pass exactly what GraspGenSampler received — don't re-center. + grasp_pose: (4, 4) SE(3), exactly as returned by GraspGenSampler.run_inference(). + collided: mark True if the gripper hit something before closing. 
+ Collided attempts are excluded from training labels. + """ + assert grasp_pose.shape == (4, 4), "grasp_pose must be (4, 4)" + assert point_cloud.ndim == 2 and point_cloud.shape[1] == 3, \ + "point_cloud must be (N, 3)" + + self.attempts.append(GraspAttempt( + point_cloud=point_cloud.astype(np.float32), + grasp_pose=grasp_pose.astype(np.float64), + confidence=float(confidence), + success=bool(success), + collided=bool(collided), + )) + + def save(self): + """Write all attempts to rollouts.h5 and grasps.json.""" + assert len(self.attempts) > 0, "No attempts recorded yet" + + N = len(self.attempts) + pred_grasps = np.array([a.grasp_pose for a in self.attempts]) + confidences = np.array([a.confidence for a in self.attempts]) + collisions = np.array([a.collided for a in self.attempts]) + successes = np.array([a.success for a in self.attempts]) + + pc_sizes = [a.point_cloud.shape[0] for a in self.attempts] + M = max(pc_sizes) + point_clouds = np.zeros((N, M, 3), dtype=np.float32) + for i, a in enumerate(self.attempts): + point_clouds[i, :a.point_cloud.shape[0]] = a.point_cloud + + h5_path = os.path.join(self.output_dir, "rollouts.h5") + with h5py.File(h5_path, "w") as f: + grp = f.require_group(f"objects/{self.object_id}") + grp.create_dataset("pred_grasps", data=pred_grasps) + grp.create_dataset("gt_grasps", data=pred_grasps) # no GT annotations in real-world rollouts; mirrors pred_grasps + grp.create_dataset("confidence", data=confidences) + grp.create_dataset("collision", data=collisions) + grp.create_dataset("point_clouds", data=point_clouds) + grp.create_dataset("pc_sizes", data=np.array(pc_sizes)) + grp.create_dataset("asset_path", data=np.bytes_(self.object_id)) + + non_colliding = ~collisions + json_data = { + "grasps": { + "transforms": pred_grasps[non_colliding].tolist(), + "object_in_gripper": successes[non_colliding].astype(int).tolist(), + } + } + json_path = os.path.join(self.output_dir, "grasps.json") + with open(json_path, "w") as f: + 
json.dump(json_data, f) + + print(f"Saved {N} attempts for '{self.object_id}' to {self.output_dir}") + print(f" Collided: {collisions.sum()}, Successful: {successes.sum()}") + + def write_cache(self, cache_h5_path: str, num_renderings: int = 5) -> None: + """ + Write to GraspGenDatasetCache format so train_graspgen.py can consume + this data directly with preload_dataset=True — no mesh needed. + + Place the output at {cache_dir}/{dataset_name}/cache_train_{prefix}.h5 + (prepare_finetune.py handles this automatically). + """ + assert len(self.attempts) > 0, "No attempts recorded" + import trimesh.transformations as tra + from grasp_gen.dataset.eval_utils import write_info + + non_colliding = [a for a in self.attempts if not a.collided] + assert len(non_colliding) > 0, "All attempts collided" + + pos_attempts = [a for a in non_colliding if a.success] + neg_attempts = [a for a in non_colliding if not a.success] + + positive_grasps = ( + np.array([a.grasp_pose for a in pos_attempts], dtype=np.float64) + if pos_attempts else np.zeros((0, 4, 4), dtype=np.float64) + ) + negative_grasps = ( + np.array([a.grasp_pose for a in neg_attempts], dtype=np.float64) + if neg_attempts else np.zeros((0, 4, 4), dtype=np.float64) + ) + assert len(positive_grasps) > 0, "No successful grasps to write" + + n = min(num_renderings, len(self.attempts)) + idxs = np.random.choice(len(self.attempts), size=n, replace=False) + renderings = [] + for idx in idxs: + xyz = self.attempts[idx].point_cloud.astype(np.float32) + T = tra.translation_matrix(-xyz.mean(axis=0)) + xyz_centered = tra.transform_points(xyz, T).astype(np.float32) + renderings.append({ + "mesh_mode": np.bool_(False), + "load_contact_batch": np.bool_(False), + "invalid": np.bool_(False), + "points": xyz_centered, + "T_move_to_pc_mean": T.astype(np.float32), + "positive_grasps": positive_grasps, + }) + + grasp_data = { + "object_mesh": None, + "positive_grasps": positive_grasps, + "contacts": None, + "object_asset_path": self.object_id, + "object_scale": float(1.0), 
+ "negative_grasps": negative_grasps, + "positive_grasps_onpolicy": positive_grasps, + "negative_grasps_onpolicy": negative_grasps, + } + + key_h5 = self.object_id.replace("/", "____") + with h5py.File(cache_h5_path, "a") as f: + if key_h5 in f: + del f[key_h5] + grp = f.create_group(key_h5) + write_info(grp.create_group("grasp_data"), grasp_data) + grp_r = grp.create_group("renderings") + for i, rendering in enumerate(renderings): + write_info(grp_r.create_group(str(i)), rendering) + + print(f"Wrote cache for '{self.object_id}' → {cache_h5_path}") + print(f" Positives: {len(positive_grasps)}, Negatives: {len(negative_grasps)}, Renderings: {n}") + + @classmethod + def from_h5(cls, h5_path: str, object_id: str, output_dir: str) -> "GraspRecorder": + """Reconstruct a GraspRecorder from a saved rollouts.h5 + grasps.json.""" + recorder = cls(object_id=object_id, output_dir=output_dir) + with h5py.File(h5_path, "r") as f: + grp = f["objects"][object_id] + pred_grasps = grp["pred_grasps"][...] + confidences = grp["confidence"][...] + collisions = grp["collision"][...] + point_clouds = grp["point_clouds"][...] + pc_sizes = grp["pc_sizes"][...] 
if "pc_sizes" in grp else None + + json_path = os.path.join(output_dir, "grasps.json") + with open(json_path) as jf: + data = json.load(jf) + success_labels = np.array(data["grasps"]["object_in_gripper"]) + + non_colliding_idx = np.where(~collisions)[0] + successes = np.zeros(len(pred_grasps), dtype=bool) + successes[non_colliding_idx] = success_labels.astype(bool) + + for i in range(len(pred_grasps)): + pc = point_clouds[i] + if pc_sizes is not None: + pc = pc[:pc_sizes[i]] + recorder.attempts.append(GraspAttempt( + point_cloud=pc.astype(np.float32), + grasp_pose=pred_grasps[i].astype(np.float64), + confidence=float(confidences[i]), + success=bool(successes[i]), + collided=bool(collisions[i]), + )) + return recorder + + def __len__(self): + return len(self.attempts) diff --git a/scripts/test_training_step.py b/scripts/test_training_step.py new file mode 100644 index 0000000..3bef2d6 --- /dev/null +++ b/scripts/test_training_step.py @@ -0,0 +1,344 @@ +""" +Verify one end-to-end discriminator training step with onpolicy data. + +No mocks. Constructs a minimal dataset, loads one batch, runs one +forward + backward pass through the discriminator, checks loss is finite. 
+ +Run on the GCP VM inside Docker: + docker run --rm --gpus all \\ + -v /opt/GraspGen:/code \\ + graspgen:latest \\ + bash -c "pip install -q -e /code --no-deps 2>/dev/null && \\ + pip install -q viser 2>/dev/null && \\ + python /code/scripts/test_training_step.py" +""" + +import json +import os +import sys +import tempfile + +import h5py +import numpy as np + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + + +OBJECT_KEY = "banana.obj" +NUM_POINTS = 512 +NUM_GRASPS = 20 +GRIPPER_NAME = "robotiq_2f_140" +GRIPPER_DEPTH = 0.195 # from config/grippers/robotiq_2f_140.yaml +CKPT_PATH = "/code/checkpoints/graspgen_robotiq_2f_140_dis.pth" # mount checkpoints at /code/checkpoints inside Docker + + +def make_grasp(z_offset=0.0): + """Grasp whose tool tip lands near origin (inside the point cloud).""" + T = np.eye(4) + T[2, 3] = -GRIPPER_DEPTH + z_offset + return T.tolist() + + +def setup_dataset(root): + from scripts.record_grasp import GraspRecorder + + object_dir = os.path.join(root, "objects") + grasp_dir = os.path.join(root, "grasps") + rollout_dir = os.path.join(root, "onpolicy") + dataset_dir = os.path.join(root, "dataset") + cache_dir = os.path.join(root, "cache") + for d in [object_dir, grasp_dir, rollout_dir, dataset_dir, cache_dir]: + os.makedirs(d, exist_ok=True) + + # Object mesh + banana_mesh = os.path.join(os.path.dirname(__file__), "..", "assets", "objects", "banana.obj") + banana_mesh = os.path.abspath(banana_mesh) + assert os.path.exists(banana_mesh), f"Mesh not found: {banana_mesh}" + os.symlink(banana_mesh, os.path.join(object_dir, OBJECT_KEY)) + + # Grasp JSON + rng = np.random.default_rng(0) + grasps_data = { + "object": {"file": OBJECT_KEY, "scale": 1.0}, + "grasps": { + "transforms": [make_grasp(rng.uniform(-0.01, 0.01)) for _ in range(NUM_GRASPS)], + "object_in_gripper": [i % 2 for i in range(NUM_GRASPS)], + }, + } + with open(os.path.join(grasp_dir, "banana_grasps.json"), "w") as f: + json.dump(grasps_data, f) + with 
open(os.path.join(grasp_dir, "map_uuid_to_path.json"), "w") as f: + json.dump({OBJECT_KEY: "banana_grasps.json"}, f) + + # Onpolicy rollouts + recorder = GraspRecorder(object_id=OBJECT_KEY, output_dir=rollout_dir) + for i in range(8): + pc = rng.standard_normal((512, 3)).astype(np.float32) * 0.05 + T = np.eye(4); T[2, 3] = -GRIPPER_DEPTH + recorder.record(pc, T, confidence=0.8, success=(i % 2 == 0), collided=False) + recorder.save() + h5_path = os.path.join(rollout_dir, "rollouts.h5") + json_path = os.path.join(rollout_dir, "grasps.json") + + # Cache JSON so constructor skips UUID scan + real_cache_dir = os.path.join(cache_dir, os.path.basename(dataset_dir)) + os.makedirs(real_cache_dir, exist_ok=True) + cache_json = os.path.join(real_cache_dir, f"{os.path.basename(rollout_dir)}.json") + with open(cache_json, "w") as f: + json.dump({OBJECT_KEY: json_path}, f) + + # train.txt + with open(os.path.join(dataset_dir, "train.txt"), "w") as f: + f.write(OBJECT_KEY + "\n") + + return object_dir, grasp_dir, rollout_dir, dataset_dir, cache_dir, h5_path + + +def run(): + import torch + from grasp_gen.dataset.dataset import ObjectPickDataset + from grasp_gen.models.discriminator import GraspGenDiscriminator + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"[device] {device}") + + with tempfile.TemporaryDirectory() as root: + sys.path.insert(0, os.path.join(os.path.dirname(__file__))) + object_dir, grasp_dir, rollout_dir, dataset_dir, cache_dir, h5_path = setup_dataset(root) + + # ── Dataset ────────────────────────────────────────────────────────── + print("\n[1] Building dataset...") + dataset = ObjectPickDataset( + root_dir=dataset_dir, + cache_dir=cache_dir, + split="train", + tasks=[], + num_points=NUM_POINTS, + num_obj_points=NUM_POINTS, + cam_coord=False, + num_rotations=1, + grid_res=0.005, + jitter_scale=0.0, + contact_radius=0.005, + dist_above_table=0.0, + offset_bins=None, + robot_prob=0.0, + random_seed=42, + 
object_root_dir=object_dir, + grasp_root_dir=grasp_dir, + dataset_version="v2", + gripper_name=GRIPPER_NAME, + num_grasps_per_object=NUM_GRASPS, + load_discriminator_dataset=True, + prob_point_cloud=0.0, # mesh_mode=True avoids pyrender (no display needed) + onpolicy_dataset_h5_path=h5_path, + onpolicy_dataset_dir=rollout_dir, + discriminator_ratio=[0.30, 0.10, 0.15, 0.05, 0.0, 0.20, 0.20], + ) + print(f" {len(dataset)} items") + + # ── Load one batch ──────────────────────────────────────────────────── + print("\n[2] Loading one batch...") + item = dataset[0] + assert not item.get("invalid", False), "batch returned invalid" + print(f" batch keys: {sorted(item.keys())}") + + # ── Load discriminator model ────────────────────────────────────────── + print("\n[3] Loading discriminator...") + assert os.path.exists(CKPT_PATH), f"Checkpoint not found: {CKPT_PATH}" + + from omegaconf import OmegaConf + cfg_path = os.path.join(os.path.dirname(__file__), "config.yaml") + cfg = OmegaConf.load(cfg_path) + model = GraspGenDiscriminator.from_config(cfg.discriminator).to(device) + + ckpt = torch.load(CKPT_PATH, map_location=device) + state = ckpt.get("model_state_dict", ckpt) + model.load_state_dict(state, strict=False) + model.train() + print(f" loaded from {CKPT_PATH}") + + # ── Forward + backward pass ─────────────────────────────────────────── + print("\n[4] Forward + backward pass...") + from grasp_gen.utils.train_utils import to_gpu + optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) + + # Mimic collate_batch_keys from dataset.py for a single-item batch: + # - "points", "inputs", "seg", "cam_pose" etc. → stacked tensor [1, ...] + # - "grasps", "labels", "grasp_ids" etc. 
→ list of tensors [tensor] + STACKED_KEYS = {"inputs", "points", "seg", "object_inputs", "bottom_center", + "cam_pose", "ee_pose", "placement_masks", "placement_region"} + LIST_KEYS = {"grasps", "labels", "grasp_ids", "positive_grasps", + "negative_grasps", "grasps_ground_truth", "grasps_highres"} + + data = {} + for k, v in item.items(): + if not torch.is_tensor(v): + data[k] = v + elif k in STACKED_KEYS: + data[k] = v.unsqueeze(0) # [1, ...] + elif k in LIST_KEYS: + data[k] = [v] # list of 1 tensor, as collate produces + else: + data[k] = v.unsqueeze(0) # default: stack + to_gpu(data) + + optimizer.zero_grad() + outputs, losses, stats = model(data, None) + loss = sum(w * v for w, v in losses.values()) + loss.backward() + optimizer.step() + + assert torch.isfinite(loss), f"Loss is not finite: {loss.item()}" + + print(f"\n[PASS] One training step completed") + print(f" loss = {loss.item():.6f}") + print(f" grasp_ids = {item['grasp_ids'].unique().tolist()}") + + print("\n" + "=" * 60) + print("TRAINING STEP TEST PASSED") + print("=" * 60) + + +def setup_meshless_dataset(root): + """Minimal dataset with NO mesh — uses GraspRecorder.write_cache() directly.""" + from scripts.record_grasp import GraspRecorder + from grasp_gen.dataset.dataset import get_cache_prefix + + MESHLESS_KEY = "new_real_object.obj" + + dataset_dir = os.path.join(root, "dataset") + cache_dir = os.path.join(root, "cache") + rollout_dir = os.path.join(root, "onpolicy") + for d in [dataset_dir, cache_dir, rollout_dir]: + os.makedirs(d, exist_ok=True) + + # Record rollouts with stored point clouds + rng = np.random.default_rng(1) + recorder = GraspRecorder(object_id=MESHLESS_KEY, output_dir=rollout_dir) + for i in range(8): + pc = rng.standard_normal((512, 3)).astype(np.float32) * 0.05 + T = np.eye(4); T[2, 3] = -GRIPPER_DEPTH + recorder.record(pc, T, confidence=0.8, success=(i % 2 == 0), collided=False) + recorder.save() + + # Write directly to GraspGenDatasetCache format so preload_dataset=True 
works + real_cache_dir = os.path.join(cache_dir, os.path.basename(dataset_dir)) + os.makedirs(real_cache_dir, exist_ok=True) + prefix = get_cache_prefix(prob_point_cloud=0.0, load_discriminator_dataset=True) + cache_h5 = os.path.join(real_cache_dir, f"cache_train_{prefix}.h5") + recorder.write_cache(cache_h5, num_renderings=5) + + with open(os.path.join(dataset_dir, "train.txt"), "w") as f: + f.write(MESHLESS_KEY + "\n") + + return dataset_dir, cache_dir, cache_h5, MESHLESS_KEY + + +def run_meshless(): + """One training step for a meshless real-world object using preload_dataset=True.""" + import torch + from grasp_gen.dataset.dataset import ObjectPickDataset + from grasp_gen.models.discriminator import GraspGenDiscriminator + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"[device] {device}") + + with tempfile.TemporaryDirectory() as root: + sys.path.insert(0, os.path.join(os.path.dirname(__file__))) + dataset_dir, cache_dir, cache_h5, key = setup_meshless_dataset(root) + + # Dummy dirs — not used since preload_dataset=True reads from cache + dummy_obj_dir = os.path.join(root, "objects") + dummy_grasp_dir = os.path.join(root, "grasps") + os.makedirs(dummy_obj_dir, exist_ok=True) + os.makedirs(dummy_grasp_dir, exist_ok=True) + with open(os.path.join(dummy_grasp_dir, "map_uuid_to_path.json"), "w") as f: + json.dump({}, f) + + print("\n[1] Building meshless dataset (preload_dataset=True)...") + dataset = ObjectPickDataset( + root_dir=dataset_dir, + cache_dir=cache_dir, + split="train", + tasks=[], + num_points=NUM_POINTS, + num_obj_points=NUM_POINTS, + cam_coord=False, + num_rotations=1, + grid_res=0.005, + jitter_scale=0.0, + contact_radius=0.005, + dist_above_table=0.0, + offset_bins=None, + robot_prob=0.0, + random_seed=42, + object_root_dir=dummy_obj_dir, + grasp_root_dir=dummy_grasp_dir, + dataset_version="v2", + gripper_name=GRIPPER_NAME, + num_grasps_per_object=NUM_GRASPS, + load_discriminator_dataset=True, + 
prob_point_cloud=0.0, + preload_dataset=True, # production default + discriminator_ratio=[0.30, 0.10, 0.15, 0.05, 0.0, 0.20, 0.20], + ) + print(f" {len(dataset)} items in cache") + + print("\n[2] Loading one batch (preload_dataset=True, meshless)...") + item = dataset[0] + assert not item.get("invalid", False), "batch returned invalid" + print(f" batch keys: {sorted(item.keys())}") + + print("\n[3] Loading discriminator...") + assert os.path.exists(CKPT_PATH), f"Checkpoint not found: {CKPT_PATH}" + from omegaconf import OmegaConf + cfg_path = os.path.join(os.path.dirname(__file__), "config.yaml") + cfg = OmegaConf.load(cfg_path) + model = GraspGenDiscriminator.from_config(cfg.discriminator).to(device) + ckpt = torch.load(CKPT_PATH, map_location=device) + state = ckpt.get("model_state_dict", ckpt) + model.load_state_dict(state, strict=False) + model.train() + + print("\n[4] Forward + backward pass (meshless, preload_dataset=True)...") + from grasp_gen.utils.train_utils import to_gpu + optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) + + STACKED_KEYS = {"inputs", "points", "seg", "object_inputs", "bottom_center", + "cam_pose", "ee_pose", "placement_masks", "placement_region"} + LIST_KEYS = {"grasps", "labels", "grasp_ids", "positive_grasps", + "negative_grasps", "grasps_ground_truth", "grasps_highres"} + + data = {} + for k, v in item.items(): + if not torch.is_tensor(v): + data[k] = v + elif k in STACKED_KEYS: + data[k] = v.unsqueeze(0) + elif k in LIST_KEYS: + data[k] = [v] + else: + data[k] = v.unsqueeze(0) + to_gpu(data) + + optimizer.zero_grad() + outputs, losses, stats = model(data, None) + loss = sum(w * v for w, v in losses.values()) + loss.backward() + optimizer.step() + + assert torch.isfinite(loss), f"Loss is not finite: {loss.item()}" + + print(f"\n[PASS] Meshless training step completed (preload_dataset=True)") + print(f" loss = {loss.item():.6f}") + print(f" grasp_ids = {item['grasp_ids'].unique().tolist()}") + + print("\n" + "=" * 60) + 
print("MESHLESS TRAINING STEP TEST PASSED") + print("=" * 60) + + +if __name__ == "__main__": + run() + run_meshless() diff --git a/tests/test_record_grasp.py b/tests/test_record_grasp.py new file mode 100644 index 0000000..222126d --- /dev/null +++ b/tests/test_record_grasp.py @@ -0,0 +1,343 @@ +"""Tests for GraspRecorder (scripts/record_grasp.py).""" + +import json +import os +import subprocess +import sys +import tempfile + +import h5py +import numpy as np +import pytest + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) +from scripts.record_grasp import GraspRecorder + + +def make_random_grasp(): + T = np.eye(4) + T[:3, 3] = np.random.randn(3) * 0.1 + Q, _ = np.linalg.qr(np.random.randn(3, 3)) + T[:3, :3] = Q + return T + + +def make_random_point_cloud(n=512): + return np.random.randn(n, 3).astype(np.float32) * 0.1 + + +def test_save_creates_files(): + with tempfile.TemporaryDirectory() as tmpdir: + rec = GraspRecorder("banana", tmpdir) + rec.record(make_random_point_cloud(), make_random_grasp(), 0.9, success=True) + rec.record(make_random_point_cloud(), make_random_grasp(), 0.3, success=False) + rec.save() + assert os.path.exists(os.path.join(tmpdir, "rollouts.h5")) + assert os.path.exists(os.path.join(tmpdir, "grasps.json")) + + +def test_h5_fields_exist(): + with tempfile.TemporaryDirectory() as tmpdir: + rec = GraspRecorder("banana", tmpdir) + rec.record(make_random_point_cloud(), make_random_grasp(), 0.9, success=True) + rec.save() + with h5py.File(os.path.join(tmpdir, "rollouts.h5"), "r") as f: + grp = f["objects/banana"] + for field in ["pred_grasps", "gt_grasps", "confidence", "collision", "point_clouds"]: + assert field in grp + + +def test_h5_shapes(): + N = 5 + with tempfile.TemporaryDirectory() as tmpdir: + rec = GraspRecorder("banana", tmpdir) + for _ in range(N): + rec.record(make_random_point_cloud(512), make_random_grasp(), 0.5, success=True) + rec.save() + with h5py.File(os.path.join(tmpdir, "rollouts.h5"), "r") as f: + grp = 
f["objects/banana"]; assert grp["pred_grasps"].shape == (N, 4, 4); assert grp["confidence"].shape == (N,); assert grp["point_clouds"].shape == (N, 512, 3)
# NOTE(review): the line above is the tail of a test that begins before this
# chunk; its statements are kept together on one logical line because the
# enclosing indentation is not visible from here.


def test_critical_constraint_no_collisions():
    """len(json transforms) == sum(~collision) — required by load_onpolicy_dataset."""
    with tempfile.TemporaryDirectory() as tmpdir:
        rec = GraspRecorder("banana", tmpdir)
        rec.record(make_random_point_cloud(), make_random_grasp(), 0.9, success=True, collided=False)
        rec.record(make_random_point_cloud(), make_random_grasp(), 0.8, success=False, collided=False)
        rec.record(make_random_point_cloud(), make_random_grasp(), 0.2, success=False, collided=True)
        rec.save()

        with h5py.File(os.path.join(tmpdir, "rollouts.h5"), "r") as f:
            n_noncolliding = int(np.logical_not(f["objects/banana"]["collision"][...]).sum())
        with open(os.path.join(tmpdir, "grasps.json")) as f:
            n_transforms = len(json.load(f)["grasps"]["transforms"])

        assert n_noncolliding == n_transforms


def test_success_labels_correct():
    """grasps.json lists object_in_gripper labels for the non-colliding attempts only."""
    with tempfile.TemporaryDirectory() as tmpdir:
        rec = GraspRecorder("banana", tmpdir)
        rec.record(make_random_point_cloud(), make_random_grasp(), 0.9, success=True, collided=False)
        rec.record(make_random_point_cloud(), make_random_grasp(), 0.5, success=False, collided=False)
        rec.record(make_random_point_cloud(), make_random_grasp(), 0.1, success=False, collided=True)
        rec.save()
        with open(os.path.join(tmpdir, "grasps.json")) as f:
            labels = json.load(f)["grasps"]["object_in_gripper"]
        # Three attempts were recorded; the collided one contributes no label.
        assert labels == [1, 0]


def test_empty_recorder_raises():
    """save() on a recorder with no recorded attempts raises AssertionError."""
    with tempfile.TemporaryDirectory() as tmpdir:
        with pytest.raises(AssertionError):
            GraspRecorder("banana", tmpdir).save()


def test_wrong_grasp_shape_raises():
    """record() rejects a 3x3 grasp matrix with AssertionError."""
    with tempfile.TemporaryDirectory() as tmpdir:
        with pytest.raises(AssertionError):
            GraspRecorder("banana", tmpdir).record(
                make_random_point_cloud(), np.eye(3), 0.5, success=True)


def test_wrong_point_cloud_shape_raises():
    """record() rejects a 1-D point cloud with AssertionError."""
    with tempfile.TemporaryDirectory() as tmpdir:
        with pytest.raises(AssertionError):
            GraspRecorder("banana", tmpdir).record(
                np.random.randn(512), make_random_grasp(), 0.5, success=True)


def test_load_onpolicy_dataset_logic():
    """Round-trip: save then manually run load_onpolicy_dataset() logic."""
    with tempfile.TemporaryDirectory() as tmpdir:
        rec = GraspRecorder("banana", tmpdir)
        for _ in range(3):
            rec.record(make_random_point_cloud(), make_random_grasp(), 0.9, success=True, collided=False)
        for _ in range(2):
            rec.record(make_random_point_cloud(), make_random_grasp(), 0.3, success=False, collided=False)
        rec.record(make_random_point_cloud(), make_random_grasp(), 0.1, success=False, collided=True)
        rec.save()

        # Copy the datasets into memory and close the file before the
        # temporary directory is removed; an open handle would leak and can
        # block directory cleanup on platforms that lock open files.
        with h5py.File(os.path.join(tmpdir, "rollouts.h5"), "r") as h5:
            h5_obj = h5["objects"]["banana"]
            pred_grasps = h5_obj["pred_grasps"][...]
            scores = h5_obj["confidence"][...]
            collision = h5_obj["collision"][...]
        mask_not_colliding = np.logical_not(collision)

        with open(os.path.join(tmpdir, "grasps.json"), "rb") as fh:
            data = json.load(fh)
        assert mask_not_colliding.sum() == len(data["grasps"]["transforms"])

        # The JSON labels are indexed over the non-colliding attempts, so map
        # the successful positions back to indices in the full attempt list.
        mask_eval_success = np.array(data["grasps"]["object_in_gripper"])
        success_result = np.zeros(len(scores))
        success_result[np.where(mask_not_colliding)[0][np.where(mask_eval_success)[0]]] = 1.0
        is_success = success_result.astype(np.bool_)

        positive_grasps = pred_grasps[is_success]
        negative_grasps = pred_grasps[~is_success]

        # 3 successes; 2 failures + 1 collision count as negatives.
        assert positive_grasps.shape == (3, 4, 4)
        assert negative_grasps.shape == (3, 4, 4)


def test_write_cache_creates_file():
    """write_cache() produces a file at the requested path."""
    with tempfile.TemporaryDirectory() as tmpdir:
        rec = GraspRecorder("banana", tmpdir)
        rec.record(make_random_point_cloud(), make_random_grasp(), 0.9, success=True)
        rec.record(make_random_point_cloud(), make_random_grasp(), 0.3, success=False)
        rec.save()
        cache_path = os.path.join(tmpdir, "cache.h5")
        rec.write_cache(cache_path, num_renderings=2)
        assert os.path.exists(cache_path)


def test_write_cache_loadable():
    """A written cache loads via GraspGenDatasetCache with the requested rendering count."""
    with tempfile.TemporaryDirectory() as tmpdir:
        rec = GraspRecorder("banana", tmpdir)
        for _ in range(4):
            rec.record(make_random_point_cloud(), make_random_grasp(), 0.9, success=True)
        # NOTE(review): assumed to sit after the loop (4 successes + 1 failure);
        # the assertions below hold either way.
        rec.record(make_random_point_cloud(), make_random_grasp(), 0.3, success=False)
        rec.save()
        cache_path = os.path.join(tmpdir, "cache.h5")
        rec.write_cache(cache_path, num_renderings=3)

        from grasp_gen.dataset.dataset_utils import GraspGenDatasetCache
        cache = GraspGenDatasetCache.load_from_h5_file(cache_path)
        assert "banana" in cache
        _, renderings = cache["banana"]
        assert len(renderings) == 3


def test_write_cache_shapes():
    """Cached grasp arrays and per-rendering arrays have the expected shapes."""
    N_POINTS, N_POS, N_NEG, N_RENDERINGS = 256, 3, 2, 4
    with tempfile.TemporaryDirectory() as tmpdir:
        rec = GraspRecorder("mug", tmpdir)
        for _ in range(N_POS):
            rec.record(make_random_point_cloud(N_POINTS), make_random_grasp(), 0.9, success=True)
        for _ in range(N_NEG):
            rec.record(make_random_point_cloud(N_POINTS), make_random_grasp(), 0.2, success=False)
        rec.save()
        cache_path = os.path.join(tmpdir, "cache.h5")
        rec.write_cache(cache_path, num_renderings=N_RENDERINGS)

        from grasp_gen.dataset.dataset_utils import GraspGenDatasetCache
        gd, renderings = GraspGenDatasetCache.load_from_h5_file(cache_path)["mug"]
        assert gd.positive_grasps.shape == (N_POS, 4, 4)
        assert gd.negative_grasps.shape == (N_NEG, 4, 4)
        assert gd.positive_grasps_onpolicy.shape == (N_POS, 4, 4)
        assert gd.negative_grasps_onpolicy.shape == (N_NEG, 4, 4)
        assert len(renderings) == N_RENDERINGS
        for r in renderings:
            assert r["points"].shape == (N_POINTS, 3)
            assert r["positive_grasps"].shape == (N_POS, 4, 4)


def test_write_cache_point_clouds_centered():
    """Every cached rendering's points have (approximately) zero mean."""
    with tempfile.TemporaryDirectory() as tmpdir:
        rec = GraspRecorder("apple", tmpdir)
        for _ in range(3):
            rec.record(make_random_point_cloud(), make_random_grasp(), 0.8, success=True)
        rec.save()
        cache_path = os.path.join(tmpdir, "cache.h5")
        rec.write_cache(cache_path, num_renderings=3)

        from grasp_gen.dataset.dataset_utils import GraspGenDatasetCache
        _, renderings = GraspGenDatasetCache.load_from_h5_file(cache_path)["apple"]
        for r in renderings:
            np.testing.assert_allclose(r["points"].mean(axis=0), 0.0, atol=1e-4)


def test_from_h5_round_trip_attempts():
    """from_h5() restores as many attempts as were recorded."""
    with tempfile.TemporaryDirectory() as tmpdir:
        rec = GraspRecorder("banana", tmpdir)
        for i in range(6):
            rec.record(make_random_point_cloud(), make_random_grasp(),
                       0.9, success=(i % 2 == 0), collided=(i == 5))
        rec.save()
        restored = GraspRecorder.from_h5(os.path.join(tmpdir, "rollouts.h5"), "banana", tmpdir)
        assert len(restored) == len(rec)


def test_from_h5_round_trip_success_labels():
    """from_h5() restores per-attempt success and collided flags as real booleans."""
    with tempfile.TemporaryDirectory() as tmpdir:
        rec = GraspRecorder("banana", tmpdir)
        rec.record(make_random_point_cloud(), make_random_grasp(), 0.9, success=True, collided=False)
        rec.record(make_random_point_cloud(), make_random_grasp(), 0.5, success=False, collided=False)
        rec.record(make_random_point_cloud(), make_random_grasp(), 0.1, success=False, collided=True)
        rec.save()
        restored = GraspRecorder.from_h5(os.path.join(tmpdir, "rollouts.h5"), "banana", tmpdir)
        # Identity checks (`is`) require Python bools, not numpy scalars.
        assert restored.attempts[0].success is True
        assert restored.attempts[0].collided is False
        assert restored.attempts[1].success is False
        assert restored.attempts[2].collided is True


def test_from_h5_write_cache_matches_direct():
    """from_h5 + write_cache should produce same shapes as direct write_cache."""
    from grasp_gen.dataset.dataset_utils import GraspGenDatasetCache
    with tempfile.TemporaryDirectory() as tmpdir:
        rec = GraspRecorder("banana", tmpdir)
        for i in range(5):
            rec.record(make_random_point_cloud(256), make_random_grasp(),
                       0.9, success=(i < 3), collided=False)
        rec.save()

        cache_direct = os.path.join(tmpdir, "cache_direct.h5")
        rec.write_cache(cache_direct, num_renderings=3)

        restored = GraspRecorder.from_h5(os.path.join(tmpdir, "rollouts.h5"), "banana", tmpdir)
        cache_restored = os.path.join(tmpdir, "cache_restored.h5")
        restored.write_cache(cache_restored, num_renderings=3)

        gd_d, rd_d = GraspGenDatasetCache.load_from_h5_file(cache_direct)["banana"]
        gd_r, rd_r = GraspGenDatasetCache.load_from_h5_file(cache_restored)["banana"]

        assert gd_d.positive_grasps.shape == gd_r.positive_grasps.shape
        assert gd_d.negative_grasps.shape == gd_r.negative_grasps.shape
        assert len(rd_d) == len(rd_r)


def test_write_cache_two_objects_same_file():
    """Two objects written to the same cache file should both be readable."""
    from grasp_gen.dataset.dataset_utils import GraspGenDatasetCache
    with tempfile.TemporaryDirectory() as tmpdir:
        cache_path = os.path.join(tmpdir, "cache.h5")
        for obj_id, n_pos, n_neg in [("apple", 3, 2), ("mug", 4, 1)]:
            obj_dir = os.path.join(tmpdir, obj_id)
            rec = GraspRecorder(obj_id, obj_dir)
            for _ in range(n_pos):
                rec.record(make_random_point_cloud(), make_random_grasp(), 0.9, success=True)
            for _ in range(n_neg):
                rec.record(make_random_point_cloud(), make_random_grasp(), 0.2, success=False)
            rec.save()
            rec.write_cache(cache_path, num_renderings=2)

        cache = GraspGenDatasetCache.load_from_h5_file(cache_path)
        assert "apple" in cache
        assert "mug" in cache
        assert cache["apple"][0].positive_grasps.shape[0] == 3
        assert cache["mug"][0].positive_grasps.shape[0] == 4


def test_write_cache_multi_object_shapes_independent():
    """Point cloud shapes for different objects don't bleed into each other."""
    from grasp_gen.dataset.dataset_utils import GraspGenDatasetCache
    with tempfile.TemporaryDirectory() as tmpdir:
        cache_path = os.path.join(tmpdir, "cache.h5")
        for obj_id, n_pts in [("small_obj", 128), ("large_obj", 512)]:
            obj_dir = os.path.join(tmpdir, obj_id)
            rec = GraspRecorder(obj_id, obj_dir)
            rec.record(make_random_point_cloud(n_pts), make_random_grasp(), 0.9, success=True)
            rec.record(make_random_point_cloud(n_pts), make_random_grasp(), 0.2, success=False)
            rec.save()
            rec.write_cache(cache_path, num_renderings=1)

        cache = GraspGenDatasetCache.load_from_h5_file(cache_path)
        assert cache["small_obj"][1][0]["points"].shape == (128, 3)
        assert cache["large_obj"][1][0]["points"].shape == (512, 3)


def test_prepare_finetune_prints_valid_config_keys():
    """prepare_finetune.py output must use config keys that actually exist."""
    with tempfile.TemporaryDirectory() as tmpdir:
        rec = GraspRecorder("banana", tmpdir)
        for i in range(6):
            rec.record(make_random_point_cloud(), make_random_grasp(),
                       0.9, success=(i < 3), collided=False)
        rec.save()

        dataset_dir = os.path.join(tmpdir, "dataset")
        os.makedirs(dataset_dir, exist_ok=True)
        with open(os.path.join(dataset_dir, "train.txt"), "w") as f:
            f.write("banana\n")

        # Run the script with the current interpreter, resolved relative to
        # this test file, and capture its output as text.
        result = subprocess.run(
            [sys.executable,
             os.path.join(os.path.dirname(__file__), "..", "scripts", "prepare_finetune.py"),
             "--rollouts", tmpdir,
             "--cache_dir", os.path.join(tmpdir, "cache"),
             "--dataset_dir", dataset_dir,
             "--object_id", "banana"],
            capture_output=True, text=True
        )
        assert result.returncode == 0, result.stderr
        out = result.stdout
        assert "train.model_name=discriminator" in out
        assert "data.load_discriminator_dataset=true" in out
        assert "data.discriminator_ratio=" in out
        assert "discriminator.checkpoint=" in out
        assert "train.checkpoint=" in out


def test_asset_path_written_to_h5():
    """asset_path must be in HDF5 — load_onpolicy_dataset() reads it to build UUID mapping."""
    with tempfile.TemporaryDirectory() as tmpdir:
        rec = GraspRecorder("banana", tmpdir)
        rec.record(make_random_point_cloud(), make_random_grasp(), 0.9, success=True)
        rec.save()
        with h5py.File(os.path.join(tmpdir, "rollouts.h5"), "r") as f:
            grp = f["objects"]["banana"]
            assert "asset_path" in grp
            assert grp["asset_path"][...].item().decode("utf-8") == "banana"