2 changes: 2 additions & 0 deletions CLAUDE.md
@@ -127,6 +127,8 @@ python linkage_optimization/get_equivalent_thumb.py -r /path/to/sim -b /path/to/

**Multi-Modal Data Handling**: Vision, proprioception, and force data are processed separately then concatenated for policy conditioning. Missing modalities are handled gracefully.
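A minimal sketch of that conditioning step, assuming hypothetical encoder outputs (the helper name `build_condition` and the feature sizes are illustrative, not from the repo):

```python
import torch

def build_condition(vision_feat, proprio_feat, fsr_feat):
    """Concatenate whichever modality features are present into one conditioning vector."""
    parts = [f for f in (vision_feat, proprio_feat, fsr_feat) if f is not None]
    return torch.cat(parts, dim=-1)  # (B, obs_horizon, D_total) fed to the policy

# Example: vision and proprioception present, force data missing
cond = build_condition(torch.randn(1, 2, 512), torch.randn(1, 2, 6), None)
```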

**Direct Proprioceptive Training**: The system supports direct training on robot proprioceptive data without complex preprocessing. This approach bypasses exoskeleton-to-robot data conversion, avoiding interpolation errors and maintaining data integrity. Robot joint angles, positions, and force readings are fed directly into the network model for training.
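A hedged sketch of what "direct" means here, under the assumption that robot-side readings are simply concatenated and min-max scaled (the helper and the scaling form are illustrative; the repo's `normalize_data` may differ):

```python
import numpy as np

def make_proprio_input(joint_angles, joint_positions, force_readings, stats):
    """Build a training input straight from robot readings, with no exo-to-robot remapping."""
    x = np.concatenate([joint_angles, joint_positions, force_readings]).astype(np.float32)
    # illustrative min-max scaling using precomputed dataset statistics
    return (x - stats["min"]) / (stats["max"] - stats["min"] + 1e-8)
```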

**Distributed Training**: Uses Accelerate for multi-GPU training with NCCL backend. Training supports gradient accumulation and EMA model updates.
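An illustrative skeleton of that setup with a dummy model and dataset (hyperparameters and names are assumptions, not the repo's training script). Accelerate handles the multi-GPU/NCCL launch when started with `accelerate launch`; gradients are accumulated, and an EMA copy of the weights is maintained here with diffusers' `EMAModel` as one possible choice:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator
from diffusers.training_utils import EMAModel

model = torch.nn.Linear(10, 7)                            # stand-in for the diffusion policy
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
loader = DataLoader(TensorDataset(torch.randn(64, 10), torch.randn(64, 7)), batch_size=4)

accelerator = Accelerator(gradient_accumulation_steps=8)  # accumulation steps are an assumed value
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)
ema = EMAModel(model.parameters())

for x, y in loader:
    with accelerator.accumulate(model):                   # skips gradient sync between micro-batches
        loss = torch.nn.functional.mse_loss(model(x), y)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
    if accelerator.sync_gradients:                        # one EMA update per effective optimizer step
        ema.step(model.parameters())
```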

**Real-Time Control**: Policy evaluation runs in real-time with 30Hz control loop, requiring careful attention to inference latency and action smoothing.
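A sketch of the 30Hz loop pattern this describes, with an assumed exponential smoothing factor and a hypothetical `robot`/`get_obs` interface:

```python
import time

def control_loop(policy, robot, get_obs, hz=30, alpha=0.5):
    """Fixed-rate control: budget each cycle around inference latency and smooth actions."""
    dt = 1.0 / hz
    prev = None
    while True:
        t0 = time.monotonic()
        proprio, fsr, frames = get_obs()                        # caller-supplied observation getter
        action = policy.predict_action(proprio, fsr, frames)[0]  # first step of the predicted chunk
        if prev is not None:
            action = alpha * action + (1 - alpha) * prev        # simple exponential action smoothing
        prev = action
        robot.execute(action)                                   # hypothetical robot command interface
        time.sleep(max(0.0, dt - (time.monotonic() - t0)))      # hold the 30 Hz cadence
```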
26 changes: 17 additions & 9 deletions dexumi/camera/realsense_camera.py
@@ -83,7 +83,7 @@ def get_camera_frame(self):
try:
# Wait for a coherent pair of frames
frames = self.pipeline.wait_for_frames()
receive_time = time.monotonic()
receive_time = time.time() # Changed to wall clock time for consistency with HTTP client

if self.align_to_color and self.enable_depth:
frames = self.align.process(frames)
@@ -122,19 +122,27 @@ def get_camera_frame(self):

def start_streaming(self):
try:
# Get device info and validate RGB camera
# Get device info and validate color stream support
pipeline_wrapper = rs.pipeline_wrapper(self.pipeline)
pipeline_profile = self.config.resolve(pipeline_wrapper)
device = pipeline_profile.get_device()

found_rgb = False
found_color = False
for sensor in device.sensors:
if sensor.get_info(rs.camera_info.name) == "RGB Camera":
found_rgb = True
break

if not found_rgb:
raise RuntimeError("The RealSense device does not have an RGB camera")
# Check if sensor supports color stream
try:
profiles = sensor.get_stream_profiles()
for profile in profiles:
if profile.stream_type() == rs.stream.color:
found_color = True
break
if found_color:
break
except Exception: # skip sensors that fail to report stream profiles
continue

if not found_color:
raise RuntimeError("The RealSense device does not support color stream")

# Start streaming
self.profile = self.pipeline.start(self.config)
116 changes: 86 additions & 30 deletions dexumi/real_env/real_policy.py
@@ -5,7 +5,7 @@
import torch
from dexumi.common.utility.file import read_pickle
from dexumi.common.utility.model import load_config, load_diffusion_model
from dexumi.constants import INPAINT_RESIZE_RATIO
# INPAINT_RESIZE_RATIO is no longer needed since we use center cropping
from dexumi.diffusion_policy.dataloader.diffusion_bc_dataset import (
normalize_data,
process_image,
@@ -14,33 +14,54 @@


class RealPolicy:
"""
Real-environment policy class for running action prediction on a real robot.
A diffusion-model-based policy that handles multi-modal inputs
(vision, proprioception, and force-sensor data).
"""
def __init__(
self,
model_path: str,
ckpt: int,
model_path: str, # path to the trained model
ckpt: int, # checkpoint number
):
"""
Initialize the real-environment policy.

Args:
model_path: Path to the trained model files.
ckpt: Number of the checkpoint to load.
"""
# Load the model configuration file
model_cfg = load_config(model_path)

# Load the diffusion model and noise scheduler
model, noise_scheduler = load_diffusion_model(
model_path, ckpt, use_ema=model_cfg.training.use_ema
)

# Load dataset statistics (used for data normalization and unnormalization)
stats = read_pickle(os.path.join(model_path, "stats.pickle"))
self.pred_horizon = model_cfg.dataset.pred_horizon
self.action_dim = model_cfg.action_dim
self.obs_horizon = model_cfg.dataset.obs_horizon

# Set model parameters
self.pred_horizon = model_cfg.dataset.pred_horizon # prediction horizon
self.action_dim = model_cfg.action_dim # action dimension
self.obs_horizon = model_cfg.dataset.obs_horizon # observation horizon

# Put the model in evaluation mode
self.model = model.eval()
self.noise_scheduler = noise_scheduler
self.num_inference_steps = model_cfg.num_inference_steps
self.num_inference_steps = model_cfg.num_inference_steps # number of diffusion inference steps
self.stats = stats
self.camera_resize_shape = model_cfg.dataset.camera_resize_shape
self.camera_resize_shape = model_cfg.dataset.camera_resize_shape # camera image resize shape
# Handle the hand action type (relative vs. absolute position)
if model_cfg.dataset.relative_hand_action:
print("Using relative hand action")
print("Using relative hand action") # 使用相对手部动作
print("hand_action stats", stats["relative_hand_action"])
print(
stats["relative_hand_action"]["max"]
- stats["relative_hand_action"]["min"]
> 5e-2
)
# Combine the stats for relative pose and relative hand actions
self.stats["action"] = {
"min": np.concatenate(
[
@@ -56,9 +77,10 @@ def __init__(
),
}
else:
print("Using absolute hand action")
print("Using absolute hand action") # 使用绝对手部动作
print("hand_action stats", stats["hand_action"])
print(stats["hand_action"]["max"] - stats["hand_action"]["min"] > 5e-2)
# Combine the stats for relative pose and absolute hand actions
self.stats["action"] = {
"min": np.concatenate(
[stats["relative_pose"]["min"], stats["hand_action"]["min"]]
@@ -71,49 +93,76 @@ def __init__(
self.model_cfg = model_cfg

def predict_action(self, proprioception, fsr, visual_obs):
# visual_obs: NxHxWxC
_, H, W, _ = visual_obs.shape
B = 1
"""
预测机器人动作

Args:
proprioception: 本体感受数据(关节位置、速度等)
fsr: 力传感器数据
visual_obs: 视觉观测数据,形状为 NxHxWxC

Returns:
action: 预测的动作序列
"""
B = 1 # 批次大小为1(单次预测)
# 处理本体感受数据
if proprioception is not None and "proprioception" in self.stats:
# Normalize the proprioceptive data
proprioception = normalize_data(
proprioception.reshape(1, -1), self.stats["proprioception"]
) # (1,N)
# Convert to a PyTorch tensor and move to GPU
proprioception = (
torch.from_numpy(proprioception).unsqueeze(0).cuda()
) # (B,1,6)
elif proprioception is not None:
print("Warning: proprioception data provided but no stats available, setting to None")
proprioception = None

# Process force-sensor (FSR) data
if fsr is not None and "fsr" in self.stats:
# Normalize the FSR data
fsr = normalize_data(fsr.reshape(1, -1), self.stats["fsr"])
fsr = torch.from_numpy(fsr).unsqueeze(0).cuda() # (B,1,2)
elif fsr is not None:
print("Warning: fsr data provided but no stats available, setting to None")
fsr = None

visual_obs = np.array(
[
cv2.cvtColor(
cv2.resize(
obs,
(
int(W * INPAINT_RESIZE_RATIO),
int(H * INPAINT_RESIZE_RATIO),
),
),
cv2.COLOR_BGR2RGB,
)
for obs in visual_obs
]
)
# Process visual observations - keep the pipeline exactly consistent with training
# 1. Center-crop to a square (identical to XhandMultimodalCollection.py)
processed_obs = []
for obs in visual_obs:
h, w = obs.shape[:2] # expected to be (240, 424, 3)

# Center-crop to a square (exactly the same logic as during training)
crop_size = min(h, w) # min(240, 424) = 240
start_x = (w - crop_size) // 2 # (424-240)//2 = 92
start_y = (h - crop_size) // 2 # (240-240)//2 = 0
cropped = obs[start_y:start_y+crop_size, start_x:start_x+crop_size].copy()

# Ensure the correct 240x240 size (consistent with training)
if cropped.shape[:2] != (240, 240):
cropped = cv2.resize(cropped, (240, 240), interpolation=cv2.INTER_AREA)

# Convert BGR to RGB (consistent with training)
rgb_obs = cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB)
processed_obs.append(rgb_obs)

visual_obs = np.array(processed_obs)

# 2. Standardize with process_image (add a CenterCrop to 224x224 to match the ViT model)
visual_obs = process_image(
visual_obs,
optional_transforms=["Resize", "CenterCrop"],
optional_transforms=["CenterCrop"], # 匹配训练时的["Resize", "RandomCrop"]
resize_shape=self.camera_resize_shape,
)

# 3. Convert to a PyTorch tensor and move to GPU
visual_obs = visual_obs.unsqueeze(0).cuda()
# Initialize a random trajectory as the starting point for the diffusion model
trajectory = torch.randn(B, self.pred_horizon, self.action_dim).cuda()

# Run diffusion model inference to generate an action trajectory
trajectory = self.model.inference(
proprioception=proprioception,
fsr=fsr,
@@ -122,10 +171,17 @@ def predict_action(self, proprioception, fsr, visual_obs):
noise_scheduler=self.noise_scheduler,
num_inference_steps=self.num_inference_steps,
)

# Move the result to CPU and convert it to a numpy array
trajectory = trajectory.detach().to("cpu").numpy()
naction = trajectory[0]
naction = trajectory[0] # take the first batch element

# Unnormalize the action data back to its original scale
action_pred = unnormalize_data(naction, stats=self.stats["action"])

# Extract the valid action sequence (from the end of the observation horizon to the end of the prediction horizon)
start = self.obs_horizon - 1
end = start + self.pred_horizon
action = action_pred[start:end, :]

return action
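# Hypothetical usage sketch (the path, checkpoint number, and proprioception size
# are assumptions, not values from this repo); requires a CUDA device since the
# policy moves tensors to the GPU.
import numpy as np

policy = RealPolicy(model_path="/path/to/model_run", ckpt=500)
# obs_horizon BGR frames at the camera's native 240x424 resolution (NxHxWxC)
frames = np.zeros((policy.obs_horizon, 240, 424, 3), dtype=np.uint8)
proprio = np.zeros(6, dtype=np.float32)  # 6-dim proprioception per the (B,1,6) comment above
fsr = np.zeros(2, dtype=np.float32)      # 2 FSR channels per the (B,1,2) comment above
action = policy.predict_action(proprio, fsr, frames)  # action chunk to execute on the robot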