Skip to content

Commit 8175292

Browse files
committed
update ioai grasp env to use vision model interface
1 parent e73ea24 commit 8175292

1 file changed

Lines changed: 281 additions & 12 deletions

File tree

examples/ioai_examples/ioai_grasp_env.py

Lines changed: 281 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -48,25 +48,134 @@
4848
from physics_simulator.utils.data_types import JointTrajectory
4949
import time
5050
import os
51+
from typing import Dict, List, Tuple, Optional, Any
52+
from dataclasses import dataclass
5153

5254
from physics_simulator.utils.state_machine import SimpleStateMachine
5355

56+
@dataclass
class DetectedObject:
    """Data class for detected object information"""

    class_name: str  # object category label, e.g. "cube" or "bin"
    position: np.ndarray  # [x, y, z] in camera frame
    orientation: np.ndarray  # [qx, qy, qz, qw] in camera frame
    confidence: float  # detector confidence score (dummy model emits 0.95) -- presumably in [0, 1], confirm for real models
    bbox: Optional[np.ndarray] = None  # [x1, y1, x2, y2] if available
64+
65+
class VisionModelInterface:
    """Abstract base for vision models that localize objects in the camera frame.

    Concrete models subclass this and override :meth:`detect_objects` to turn
    camera images into a list of ``DetectedObject`` instances.
    """

    def __init__(self):
        """No base-class state; subclasses may extend as needed."""

    def detect_objects(self, rgb_image: np.ndarray, depth_image: Optional[np.ndarray] = None) -> List[DetectedObject]:
        """Detect objects in the image and return their poses in camera frame.

        Args:
            rgb_image: RGB image from camera.
            depth_image: Depth image from camera (optional).

        Returns:
            List of detected objects with their poses in camera frame.

        Raises:
            NotImplementedError: always -- subclasses must provide the detector.
        """
        raise NotImplementedError("Subclass must implement detect_objects method")
86+
87+
class DummyYoloSegmentationModel(VisionModelInterface):
    """Dummy YOLO segmentation model that uses ground truth from simulator.

    Stands in for a real detector: the camera images are ignored and object
    poses are read straight from the simulator, then re-expressed in the
    camera frame so downstream code sees realistic detector output.
    """

    def __init__(self, simulator, robot):
        # simulator: physics-simulator handle used to query ground-truth
        #   object and sensor states.
        # robot: robot handle; stored but not otherwise used by this dummy.
        super().__init__()
        self.simulator = simulator
        self.robot = robot
        self.object_classes = ["cube", "bin"]  # Supported object classes

    def detect_objects(self, rgb_image: np.ndarray, depth_image: Optional[np.ndarray] = None) -> List[DetectedObject]:
        """
        Dummy YOLO segmentation detection using ground truth.

        The image arguments exist only to satisfy the VisionModelInterface
        contract; detection comes entirely from simulator state.
        """
        detected_objects = []

        # Get ground truth poses for supported objects
        for obj_class in self.object_classes:
            # Get object state from simulator; prim names are capitalized
            # ("/World/Cube", "/World/Bin") to match the scene setup.
            obj_state = self.simulator.get_object_state(f"/World/{obj_class.capitalize()}")
            world_position = obj_state["position"]
            world_orientation = obj_state["orientation"]

            # Transform from world frame to camera frame
            camera_position, camera_orientation = self._world_to_camera_frame(
                world_position, world_orientation
            )

            # Create detected object
            detected_obj = DetectedObject(
                class_name=obj_class,
                position=camera_position,
                orientation=camera_orientation,
                confidence=0.95,  # High confidence for ground truth
                bbox=np.array([100, 100, 200, 200])  # Dummy bbox
            )
            detected_objects.append(detected_obj)

        return detected_objects

    def _world_to_camera_frame(self, world_position, world_orientation):
        """Transform pose from world frame to camera frame"""
        from scipy.spatial.transform import Rotation

        # Get camera pose in world frame
        # NOTE(review): this duplicates IoaiGraspEnv.world_to_camera_frame
        # but hard-codes the camera prim path and reads the camera pose from
        # camera_state["transform_to_base_link"], while the env method reads
        # camera_state["position"]/["orientation"]. If those keys refer to
        # different frames (base_link vs world), the two implementations
        # disagree -- confirm which key holds the world-frame pose and unify.
        camera_prim_path = "/World/Galbot/head_link2/head_end_effector_mount_link/front_head_rgb_camera"
        camera_state = self.simulator.get_sensor_state(camera_prim_path)
        camera_world_position = camera_state["transform_to_base_link"]["position"]
        camera_world_orientation = camera_state["transform_to_base_link"]["orientation"]

        # Create transformation matrices
        camera_world_rot = Rotation.from_quat(camera_world_orientation)
        world_rot = Rotation.from_quat(world_orientation)

        # Transform position: subtract camera position and rotate
        relative_position = world_position - camera_world_position
        camera_position = camera_world_rot.inv().apply(relative_position)

        # Transform orientation: compose rotations
        camera_orientation = (camera_world_rot.inv() * world_rot).as_quat()

        return camera_position, camera_orientation
148+
54149
def interpolate_joint_positions(start_positions, end_positions, steps):
    """Return `steps` evenly spaced joint configurations from start to end, inclusive."""
    waypoints = np.linspace(start_positions, end_positions, steps)
    return waypoints.tolist()
56151

57152
class IoaiGraspEnv:
58-
def __init__(self, headless=False):
153+
def __init__(self, headless=False, vision_model: Optional[VisionModelInterface] = None):
59154
"""
60155
Initialize the Olympic environment.
61156
62157
Args:
63158
headless: Whether to run in headless mode (without visualization)
159+
vision_model: Vision model for object detection (optional)
64160
"""
65161
self.simulator = None
66162
self.robot = None
163+
164+
# Initialize vision model
165+
self.vision_model = vision_model if vision_model is not None else None
166+
167+
# Vision-related variables
168+
self.detected_objects = []
169+
self.last_detection_time = 0
170+
self.detection_interval = 0.1 # Detection frequency in seconds
67171

68172
# Setup the simulator
69173
self._setup_simulator(headless=headless)
174+
175+
# Initialize vision model after simulator setup
176+
if self.vision_model is None:
177+
self.vision_model = DummyYoloSegmentationModel(self.simulator, self.robot)
178+
70179
# Setup the interface
71180
self._setup_interface()
72181
self._init_pose()
@@ -181,7 +290,7 @@ def _setup_simulator(self, headless=False):
181290

182291
# Add bin
183292
bin_config = MeshConfig(
184-
prim_path="/World/bin",
293+
prim_path="/World/Bin",
185294
mjcf_path=Path()
186295
.joinpath(self.simulator.synthnova_assets_directory)
187296
.joinpath("synthnova_assets")
@@ -206,7 +315,7 @@ def _setup_simulator(self, headless=False):
206315
# Initialize the simulator
207316
self.simulator.initialize()
208317

209-
bin_state = self.simulator.get_object_state("/World/bin")
318+
bin_state = self.simulator.get_object_state("/World/Bin")
210319
self.bin_position = bin_state["position"]
211320
self.bin_orientation = bin_state["orientation"]
212321

@@ -392,6 +501,135 @@ def robot_to_world_frame(self, robot_position, robot_orientation):
392501

393502
return world_position, world_orientation
394503

504+
def camera_to_world_frame(self, camera_position, camera_orientation):
    """Transform pose from camera frame to world frame.

    Args:
        camera_position: Position in camera frame [x, y, z]
        camera_orientation: Orientation in camera frame [qx, qy, qz, qw]

    Returns:
        Tuple of (world_position, world_orientation) in world frame
    """
    from scipy.spatial.transform import Rotation

    # Get camera pose in world frame
    # NOTE(review): this reads the camera pose from
    # camera_state["transform_to_base_link"], while the inverse method
    # world_to_camera_frame reads camera_state["position"]/["orientation"]
    # directly. If those entries are expressed in different frames
    # (base_link vs world), the two methods are not inverses of each other
    # -- confirm which key holds the world-frame pose and make both
    # methods consistent.
    camera_prim_path = self.front_head_rgb_camera_path
    camera_state = self.simulator.get_sensor_state(camera_prim_path)
    camera_world_position = camera_state["transform_to_base_link"]["position"]
    camera_world_orientation = camera_state["transform_to_base_link"]["orientation"]

    # Create transformation matrices
    camera_world_rot = Rotation.from_quat(camera_world_orientation)
    camera_local_rot = Rotation.from_quat(camera_orientation)

    # Transform position: rotate into the parent frame, then translate by camera origin
    world_position = camera_world_rot.apply(camera_position) + camera_world_position

    # Transform orientation: compose rotations (camera pose * local pose)
    world_orientation = (camera_world_rot * camera_local_rot).as_quat()

    return world_position, world_orientation
533+
534+
def world_to_camera_frame(self, world_position, world_orientation):
535+
"""Transform pose from world frame to camera frame.
536+
537+
Args:
538+
world_position: Position in world frame [x, y, z]
539+
world_orientation: Orientation in world frame [qx, qy, qz, qw]
540+
541+
Returns:
542+
Tuple of (camera_position, camera_orientation) in camera frame
543+
"""
544+
from scipy.spatial.transform import Rotation
545+
546+
# Get camera pose in world frame
547+
camera_prim_path = self.front_head_rgb_camera_path
548+
camera_state = self.simulator.get_sensor_state(camera_prim_path)
549+
camera_world_position = camera_state["position"]
550+
camera_world_orientation = camera_state["orientation"]
551+
552+
# Create transformation matrices
553+
camera_world_rot = Rotation.from_quat(camera_world_orientation)
554+
world_rot = Rotation.from_quat(world_orientation)
555+
556+
# Transform position: subtract camera position and rotate
557+
relative_position = world_position - camera_world_position
558+
camera_position = camera_world_rot.inv().apply(relative_position)
559+
560+
# Transform orientation: compose rotations
561+
camera_orientation = (camera_world_rot.inv() * world_rot).as_quat()
562+
563+
return camera_position, camera_orientation
564+
565+
def get_camera_images(self):
    """Get RGB and depth images from the front head camera.

    Returns:
        Tuple of (rgb_image, depth_image); depth_image is None when the
        camera provides no depth stream. Returns (None, None) when the RGB
        grab itself fails.
    """
    try:
        # Get RGB image
        rgb_image = self.interface.front_head_camera.get_rgb()

        # Get depth image if available
        depth_image = None
        try:
            depth_image = self.interface.front_head_camera.get_depth()
        except Exception:
            # Was a bare `except:`, which would also swallow
            # KeyboardInterrupt/SystemExit; only camera errors should be
            # treated as "depth not available".
            pass

        return rgb_image, depth_image
    except Exception as e:
        # Best-effort: report and signal failure rather than crash the
        # state machine that polls this method.
        print(f"Error getting camera images: {e}")
        return None, None
586+
587+
def detect_objects_vision(self) -> List[DetectedObject]:
    """Detect objects using vision model"""
    now = time.time()

    # Throttle: inside the detection interval, serve the cached results.
    if now - self.last_detection_time < self.detection_interval:
        return self.detected_objects

    rgb_image, depth_image = self.get_camera_images()
    if rgb_image is None:
        # Camera grab failed -- keep whatever we detected last time.
        return self.detected_objects

    # Run the vision model, then cache results and timestamp.
    self.detected_objects = self.vision_model.detect_objects(rgb_image, depth_image)
    self.last_detection_time = now
    return self.detected_objects
609+
610+
def get_object_pose_from_vision(self, target_class: str = "cube") -> Optional[Tuple[np.ndarray, np.ndarray]]:
    """Get object pose from vision detection"""
    # Case-insensitive search over the latest (possibly cached) detections;
    # first match wins, mirroring detector output order.
    wanted = target_class.lower()
    target_object = next(
        (obj for obj in self.detect_objects_vision() if obj.class_name.lower() == wanted),
        None,
    )

    if target_object is None:
        print(f"Target object '{target_class}' not detected")
        return None

    # Lift the camera-frame pose into the world frame for planning.
    return self.camera_to_world_frame(target_object.position, target_object.orientation)
632+
395633
def compute_simple_ik(self, start_joint, target_pose, arm_id="left_arm"):
396634
"""Compute inverse kinematics using Mink.
397635
@@ -669,22 +907,42 @@ def init_state():
669907
def move_to_pre_pick_state():
    """Move to pre-pick position.

    Closure over the enclosing state machine: on first entry it localizes
    the cube (vision first, ground truth as fallback) and caches the pose
    on self; every call then commands the arm 15 cm above that pose.
    """
    if self.state_first_entry:
        # Use vision model to detect object pose instead of ground truth
        vision_result = self.get_object_pose_from_vision("cube")
        if vision_result is not None:
            world_pos, world_ori = vision_result
            self.cube_position = world_pos.copy()
            # Orientation is cached but the grasp below uses a fixed
            # top-down orientation regardless.
            self.cube_orientation = world_ori.copy()
            print(f"Vision detected cube at position: {world_pos}")
        else:
            # Fallback to ground truth if vision fails
            cube_state = self.simulator.get_object_state("/World/Cube")
            self.cube_position = cube_state["position"].copy()
            self.cube_orientation = cube_state["orientation"].copy()
            print("Using ground truth fallback for cube position")
        self.state_first_entry = False

    # Convert world frame pose to robot frame
    world_pos = self.cube_position + np.array([0, 0, 0.15])  # hover 15 cm above the cube
    world_ori = np.array([0, 0.7071, 0, 0.7071])  # Fixed orientation for grasping
    robot_pos, robot_ori = self.world_to_robot_frame(world_pos, world_ori)
    return self._move_left_arm_to_pose(robot_pos, robot_ori)
681930

682931
def move_to_pick_state():
    """Move to pick position.

    Closure over the enclosing state machine: re-detects the cube on every
    call for a fresher pick target, falling back to the pose cached during
    the pre-pick state.
    """
    # Re-detect object position for more accurate pick
    vision_result = self.get_object_pose_from_vision("cube")
    if vision_result is not None:
        # NOTE(review): world_ori unpacked here is unused -- a fixed grasp
        # orientation is assigned below.
        world_pos, world_ori = vision_result
        # Use detected position for more accurate pick
        pick_pos = world_pos + np.array([0, 0, 0.03])  # 3 cm above cube center
    else:
        # Fallback to stored position
        pick_pos = self.cube_position + np.array([0, 0, 0.03])

    # Convert world frame pose to robot frame
    world_ori = np.array([0, 0.7071, 0, 0.7071])  # Fixed orientation for grasping
    robot_pos, robot_ori = self.world_to_robot_frame(pick_pos, world_ori)
    return self._move_left_arm_to_pose(robot_pos, robot_ori)
689947

690948
def grasp_state():
@@ -710,13 +968,24 @@ def move_to_pre_place_state():
710968
def move_to_place_state():
    """Move to place position.

    Closure over the enclosing state machine: on first entry it localizes
    the bin (vision first, ground truth as fallback) and caches the pose
    on self; every call then commands the arm 30 cm above that pose.
    """
    if self.state_first_entry:
        # Use vision model to detect bin pose instead of ground truth
        vision_result = self.get_object_pose_from_vision("bin")
        if vision_result is not None:
            world_pos, world_ori = vision_result
            self.bin_position = world_pos.copy()
            # Orientation cached; the place below uses a fixed orientation.
            self.bin_orientation = world_ori.copy()
            print(f"Vision detected bin at position: {world_pos}")
        else:
            # Fallback to ground truth if vision fails
            bin_state = self.simulator.get_object_state("/World/Bin")
            self.bin_position = bin_state["position"].copy()
            self.bin_orientation = bin_state["orientation"].copy()
            print("Using ground truth fallback for bin position")
        self.state_first_entry = False

    # Convert world frame pose to robot frame
    world_pos = self.bin_position + np.array([0, 0, 0.3])  # hover 30 cm above the bin
    world_ori = np.array([0, 0.7071, 0, 0.7071])  # Fixed orientation for placing
    robot_pos, robot_ori = self.world_to_robot_frame(world_pos, world_ori)
    return self._move_left_arm_to_pose(robot_pos, robot_ori)
722991

0 commit comments

Comments
 (0)