From d0ee7064a0cef8724bf5465eab1777b545c69651 Mon Sep 17 00:00:00 2001 From: Gray Date: Tue, 2 Sep 2025 14:20:08 +0800 Subject: [PATCH 01/10] [feat] eval, sh --- real_script/eval_policy/eval_xhand_franka.sh | 76 +++++ .../eval_policy/eval_xhand_franka_advanced.sh | 270 ++++++++++++++++++ .../eval_policy/test_xhand_franka_setup.sh | 200 +++++++++++++ 3 files changed, 546 insertions(+) create mode 100755 real_script/eval_policy/eval_xhand_franka.sh create mode 100755 real_script/eval_policy/eval_xhand_franka_advanced.sh create mode 100755 real_script/eval_policy/test_xhand_franka_setup.sh diff --git a/real_script/eval_policy/eval_xhand_franka.sh b/real_script/eval_policy/eval_xhand_franka.sh new file mode 100755 index 0000000..7c3655a --- /dev/null +++ b/real_script/eval_policy/eval_xhand_franka.sh @@ -0,0 +1,76 @@ +#!/bin/bash + +# Evaluation script for XHand with Franka using HTTP control + +# Activate conda environment +source ~/miniconda3/etc/profile.d/conda.sh +conda activate dexumi + +# Path to your trained model +MODEL_PATH="/path/to/your/trained/model" # TODO: Update this path +CHECKPOINT=600 + +# Control parameters +FREQUENCY=10 # Control frequency in Hz +EXEC_HORIZON=8 # Number of action steps to execute before re-predicting + +# Camera configuration +CAMERA_TYPE="realsense" # Options: "realsense" or "oak" + +# Latency parameters (in seconds) +CAMERA_LATENCY=0.185 +HAND_ACTION_LATENCY=0.3 +ROBOT_ACTION_LATENCY=0.170 + +# Video recording path +VIDEO_RECORD_PATH="video_record" + +echo "=========================================" +echo "DexUMI Evaluation with XHand + Franka" +echo "=========================================" +echo "" +echo "Model: $MODEL_PATH" +echo "Checkpoint: $CHECKPOINT" +echo "Camera Type: $CAMERA_TYPE" +echo "Frequency: $FREQUENCY Hz" +echo "Execution Horizon: $EXEC_HORIZON steps" +echo "" +echo "Latency Settings:" +echo " Camera: ${CAMERA_LATENCY}s" +echo " Hand Action: ${HAND_ACTION_LATENCY}s" +echo " Robot Action: ${ROBOT_ACTION_LATENCY}s" +echo "" +echo "Key Features:" +echo "✓ Direct Franka ee_pose (no T_ET transformation)" +echo "✓ Fixed initial positions" +echo "✓ HTTP control interface" +echo "✓ RealSense/OAK camera support" +echo "✓ Multi-step action execution" +echo "" +echo "Make sure the robot server is running:" +echo " python franka_server.py" +echo "" +echo "Press Ctrl+C to abort, or wait 3 seconds to continue..." 
+echo "=========================================" +echo "" + +# Wait for user to check +sleep 3 + +# Run the evaluation script +python real_script/eval_policy/eval_xhand_franka.py \ + --model_path "$MODEL_PATH" \ + --ckpt $CHECKPOINT \ + --frequency $FREQUENCY \ + --exec_horizon $EXEC_HORIZON \ + --camera_type $CAMERA_TYPE \ + --camera_latency $CAMERA_LATENCY \ + --hand_action_latency $HAND_ACTION_LATENCY \ + --robot_action_latency $ROBOT_ACTION_LATENCY \ + --video_record_path "$VIDEO_RECORD_PATH" + +# Optional: Enable record camera (requires second camera) +# Add --enable_record_camera flag if you have two cameras + +# Optional: Match episode path for comparison +# --match_episode_path "/path/to/reference/episodes" \ No newline at end of file diff --git a/real_script/eval_policy/eval_xhand_franka_advanced.sh b/real_script/eval_policy/eval_xhand_franka_advanced.sh new file mode 100755 index 0000000..b1ca8da --- /dev/null +++ b/real_script/eval_policy/eval_xhand_franka_advanced.sh @@ -0,0 +1,270 @@ +#!/bin/bash + +# Advanced evaluation script for XHand with Franka using HTTP control +# Supports multiple configurations and debug modes + +set -e # Exit on error + +# ============================================ +# Configuration Section +# ============================================ + +# Model configuration +MODEL_BASE_DIR="/home/gray/Project/DexUMI/data/weight" +MODEL_NAME="vision_tactile_propio" # Update this to your model name +MODEL_PATH="${MODEL_BASE_DIR}/${MODEL_NAME}" +CHECKPOINT=600 + +# Control parameters +FREQUENCY=10 # Control frequency in Hz +EXEC_HORIZON=8 # Number of action steps to execute before re-predicting + +# Camera configuration +CAMERA_TYPE="realsense" # Options: "realsense" or "oak" +ENABLE_RECORD_CAMERA=false # Set to true if you have a second camera + +# Latency parameters (in seconds) +CAMERA_LATENCY=0.185 +HAND_ACTION_LATENCY=0.3 +ROBOT_ACTION_LATENCY=0.170 + +# Recording configuration +VIDEO_RECORD_PATH="video_record/$(date +%Y%m%d_%H%M%S)" +MATCH_EPISODE_PATH="" # Optional: path to reference episodes + +# Server configuration +ROBOT_SERVER_URL="http://127.0.0.1:5000" + +# Debug mode +DEBUG_MODE=false + +# ============================================ +# Functions +# ============================================ + +print_banner() { + echo "=========================================" + echo "DexUMI Evaluation with XHand + Franka" + echo "=========================================" +} + +print_config() { + echo "" + echo "Configuration:" + echo " Model: $MODEL_PATH" + echo " Checkpoint: $CHECKPOINT" + echo " Camera Type: $CAMERA_TYPE" + echo " Record Camera: $ENABLE_RECORD_CAMERA" + echo " Frequency: $FREQUENCY Hz" + echo " Execution Horizon: $EXEC_HORIZON steps" + echo "" + echo "Latency Settings:" + echo " Camera: ${CAMERA_LATENCY}s" + echo " Hand Action: ${HAND_ACTION_LATENCY}s" + echo " Robot Action: ${ROBOT_ACTION_LATENCY}s" + echo "" + echo "Recording:" + echo " Video Path: $VIDEO_RECORD_PATH" + if [ -n "$MATCH_EPISODE_PATH" ]; then + echo " Match Episode: $MATCH_EPISODE_PATH" + fi + echo "" +} + +check_prerequisites() { + echo "Checking prerequisites..." + + # Check if model exists + if [ ! -d "$MODEL_PATH" ]; then + echo "❌ Error: Model path does not exist: $MODEL_PATH" + exit 1 + fi + + # Check if checkpoint exists + CKPT_FILE="${MODEL_PATH}/ckpt_${CHECKPOINT}.pt" + if [ ! 
-f "$CKPT_FILE" ]; then + echo "⚠️ Warning: Checkpoint file not found: $CKPT_FILE" + echo " Available checkpoints:" + ls -la ${MODEL_PATH}/ckpt_*.pt 2>/dev/null || echo " No checkpoints found" + fi + + # Check if robot server is running + echo -n "Checking robot server at $ROBOT_SERVER_URL... " + if curl -s -o /dev/null -w "%{http_code}" "${ROBOT_SERVER_URL}/health" | grep -q "200"; then + echo "✅ Connected" + else + echo "❌ Not responding" + echo "" + echo "Please start the robot server first:" + echo " python franka_server.py" + exit 1 + fi + + # Check camera availability + echo -n "Checking ${CAMERA_TYPE} camera... " + if [ "$CAMERA_TYPE" == "realsense" ]; then + if python3 -c "import pyrealsense2" 2>/dev/null; then + echo "✅ Library available" + else + echo "❌ pyrealsense2 not installed" + exit 1 + fi + else + echo "✅ Using OAK camera" + fi + + echo "" +} + +show_help() { + echo "Usage: $0 [OPTIONS]" + echo "" + echo "Options:" + echo " -m, --model MODEL_NAME Model name (default: $MODEL_NAME)" + echo " -c, --checkpoint CKPT Checkpoint number (default: $CHECKPOINT)" + echo " -f, --frequency HZ Control frequency (default: $FREQUENCY)" + echo " -e, --exec-horizon N Execution horizon (default: $EXEC_HORIZON)" + echo " -t, --camera-type TYPE Camera type: realsense|oak (default: $CAMERA_TYPE)" + echo " -r, --record Enable record camera" + echo " -d, --debug Enable debug mode" + echo " -h, --help Show this help message" + echo "" + echo "Examples:" + echo " $0 # Run with default settings" + echo " $0 -m my_model -c 800 # Use specific model and checkpoint" + echo " $0 -t oak -r # Use OAK camera with recording" + echo " $0 -d # Run in debug mode" + exit 0 +} + +# ============================================ +# Parse command line arguments +# ============================================ + +while [[ $# -gt 0 ]]; do + case $1 in + -m|--model) + MODEL_NAME="$2" + MODEL_PATH="${MODEL_BASE_DIR}/${MODEL_NAME}" + shift 2 + ;; + -c|--checkpoint) + CHECKPOINT="$2" + shift 2 + ;; + -f|--frequency) + FREQUENCY="$2" + shift 2 + ;; + -e|--exec-horizon) + EXEC_HORIZON="$2" + shift 2 + ;; + -t|--camera-type) + CAMERA_TYPE="$2" + shift 2 + ;; + -r|--record) + ENABLE_RECORD_CAMERA=true + shift + ;; + -d|--debug) + DEBUG_MODE=true + shift + ;; + -h|--help) + show_help + ;; + *) + echo "Unknown option: $1" + show_help + ;; + esac +done + +# ============================================ +# Main execution +# ============================================ + +print_banner + +# Activate conda environment +echo "Activating conda environment..." +source ~/miniconda3/etc/profile.d/conda.sh || source ~/anaconda3/etc/profile.d/conda.sh +conda activate dexumi + +print_config + +echo "Key Features:" +echo "✓ Direct Franka ee_pose (no T_ET transformation)" +echo "✓ Fixed initial positions" +echo "✓ HTTP control interface" +echo "✓ ${CAMERA_TYPE^} camera support" +echo "✓ Multi-step action execution" +echo "✓ Relative control semantics" +echo "" + +check_prerequisites + +echo "Starting in 3 seconds... (Press Ctrl+C to abort)" +for i in 3 2 1; do + echo -n "$i... 
" + sleep 1 +done +echo "" +echo "" + +# Create video recording directory +mkdir -p "$VIDEO_RECORD_PATH" + +# Build command +CMD="python real_script/eval_policy/eval_xhand_franka.py" +CMD="$CMD --model_path \"$MODEL_PATH\"" +CMD="$CMD --ckpt $CHECKPOINT" +CMD="$CMD --frequency $FREQUENCY" +CMD="$CMD --exec_horizon $EXEC_HORIZON" +CMD="$CMD --camera_type $CAMERA_TYPE" +CMD="$CMD --camera_latency $CAMERA_LATENCY" +CMD="$CMD --hand_action_latency $HAND_ACTION_LATENCY" +CMD="$CMD --robot_action_latency $ROBOT_ACTION_LATENCY" +CMD="$CMD --video_record_path \"$VIDEO_RECORD_PATH\"" + +if [ "$ENABLE_RECORD_CAMERA" = true ]; then + CMD="$CMD --enable_record_camera" +fi + +if [ -n "$MATCH_EPISODE_PATH" ]; then + CMD="$CMD --match_episode_path \"$MATCH_EPISODE_PATH\"" +fi + +# Execute +echo "Executing command:" +echo "$CMD" +echo "" +echo "=========================================" +echo "" + +if [ "$DEBUG_MODE" = true ]; then + # Debug mode: run with Python debugger + python -m pdb real_script/eval_policy/eval_xhand_franka.py \ + --model_path "$MODEL_PATH" \ + --ckpt $CHECKPOINT \ + --frequency $FREQUENCY \ + --exec_horizon $EXEC_HORIZON \ + --camera_type $CAMERA_TYPE \ + --camera_latency $CAMERA_LATENCY \ + --hand_action_latency $HAND_ACTION_LATENCY \ + --robot_action_latency $ROBOT_ACTION_LATENCY \ + --video_record_path "$VIDEO_RECORD_PATH" \ + $([ "$ENABLE_RECORD_CAMERA" = true ] && echo "--enable_record_camera") \ + $([ -n "$MATCH_EPISODE_PATH" ] && echo "--match_episode_path \"$MATCH_EPISODE_PATH\"") +else + # Normal execution + eval $CMD +fi + +echo "" +echo "=========================================" +echo "Evaluation completed!" +echo "Video saved to: $VIDEO_RECORD_PATH" +echo "=========================================" \ No newline at end of file diff --git a/real_script/eval_policy/test_xhand_franka_setup.sh b/real_script/eval_policy/test_xhand_franka_setup.sh new file mode 100755 index 0000000..4799651 --- /dev/null +++ b/real_script/eval_policy/test_xhand_franka_setup.sh @@ -0,0 +1,200 @@ +#!/bin/bash + +# Quick test script to verify XHand + Franka setup + +echo "=========================================" +echo "XHand + Franka Setup Test" +echo "=========================================" +echo "" + +# Colors for output +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color + +# Test results +TESTS_PASSED=0 +TESTS_FAILED=0 + +# Function to print test results +print_test() { + local test_name=$1 + local result=$2 + local message=$3 + + if [ "$result" = "pass" ]; then + echo -e "${GREEN}✅ $test_name${NC}" + ((TESTS_PASSED++)) + elif [ "$result" = "warn" ]; then + echo -e "${YELLOW}⚠️ $test_name${NC}" + [ -n "$message" ] && echo " $message" + else + echo -e "${RED}❌ $test_name${NC}" + [ -n "$message" ] && echo " $message" + ((TESTS_FAILED++)) + fi +} + +echo "1. Checking Python environment..." +if python3 -c "import sys; sys.exit(0 if sys.version_info >= (3,7) else 1)" 2>/dev/null; then + print_test "Python version" "pass" +else + print_test "Python version" "fail" "Python 3.7+ required" +fi + +echo "" +echo "2. Checking required packages..." 
+
+# Check DexUMI package
+if python3 -c "import dexumi" 2>/dev/null; then
+    print_test "DexUMI package" "pass"
+else
+    print_test "DexUMI package" "fail" "DexUMI not installed or not in PYTHONPATH"
+fi
+
+# Check HTTP client
+if python3 -c "from dexumi.real_env.common.http_client import HTTPRobotClient" 2>/dev/null; then
+    print_test "HTTP client module" "pass"
+else
+    print_test "HTTP client module" "fail" "Cannot import HTTPRobotClient"
+fi
+
+# Check camera modules
+if python3 -c "import pyrealsense2" 2>/dev/null; then
+    print_test "RealSense library" "pass"
+else
+    print_test "RealSense library" "warn" "pyrealsense2 not installed (OK if using OAK)"
+fi
+
+if python3 -c "import depthai" 2>/dev/null; then
+    print_test "OAK library" "pass"
+else
+    print_test "OAK library" "warn" "depthai not installed (OK if using RealSense)"
+fi
+
+# Check other dependencies
+if python3 -c "import cv2" 2>/dev/null; then
+    print_test "OpenCV" "pass"
+else
+    print_test "OpenCV" "fail" "opencv-python not installed"
+fi
+
+if python3 -c "import numpy" 2>/dev/null; then
+    print_test "NumPy" "pass"
+else
+    print_test "NumPy" "fail" "numpy not installed"
+fi
+
+if python3 -c "import scipy" 2>/dev/null; then
+    print_test "SciPy" "pass"
+else
+    print_test "SciPy" "fail" "scipy not installed"
+fi
+
+echo ""
+echo "3. Checking hardware connections..."
+
+# Test robot server connection
+SERVER_URL="http://127.0.0.1:5000"
+echo -n "Testing robot server at $SERVER_URL... "
+if curl -s -o /dev/null -w "%{http_code}" "${SERVER_URL}/health" 2>/dev/null | grep -q "200"; then
+    print_test "Robot server" "pass"
+else
+    print_test "Robot server" "fail" "Server not responding. Run: python franka_server.py"
+fi
+
+# Test camera availability
+echo ""
+echo "4. Checking camera devices..."
+
+# Check RealSense cameras
+python3 - <<EOF 2>/dev/null
+try:
+    import pyrealsense2 as rs
+    ctx = rs.context()
+    devices = ctx.query_devices()
+    if len(devices) > 0:
+        print("REALSENSE_FOUND")
+        for i, device in enumerate(devices):
+            serial = device.get_info(rs.camera_info.serial_number)
+            name = device.get_info(rs.camera_info.name)
+            print(f"  Camera {i}: {name} (Serial: {serial})")
+    else:
+        print("REALSENSE_NOT_FOUND")
+except:
+    print("REALSENSE_ERROR")
+EOF
+
+RS_RESULT=$?
+if [ $RS_RESULT -eq 0 ]; then
+    if grep -q "REALSENSE_FOUND" <<< "$(python3 -c 'from dexumi.camera.realsense_camera import get_all_realsense_cameras; cams=get_all_realsense_cameras(); print("REALSENSE_FOUND" if cams else "REALSENSE_NOT_FOUND")' 2>/dev/null)"; then
+        print_test "RealSense camera detection" "pass"
+    else
+        print_test "RealSense camera detection" "warn" "No RealSense cameras found"
+    fi
+fi
+
+# Check OAK cameras
+python3 - <<EOF 2>/dev/null
+try:
+    from dexumi.camera.oak_camera import get_all_oak_cameras
+    cameras = get_all_oak_cameras()
+    if cameras:
+        print("OAK_FOUND")
+        for i, cam_id in enumerate(cameras):
+            print(f"  Camera {i}: {cam_id}")
+    else:
+        print("OAK_NOT_FOUND")
+except:
+    print("OAK_ERROR")
+EOF
+
+OAK_RESULT=$?
+if [ $OAK_RESULT -eq 0 ]; then
+    if grep -q "OAK_FOUND" <<< "$(python3 -c 'from dexumi.camera.oak_camera import get_all_oak_cameras; cams=get_all_oak_cameras(); print("OAK_FOUND" if cams else "OAK_NOT_FOUND")' 2>/dev/null)"; then
+        print_test "OAK camera detection" "pass"
+    else
+        print_test "OAK camera detection" "warn" "No OAK cameras found"
+    fi
+fi
+
+echo ""
+echo "5. Checking model files..."
+ +# Default model path +MODEL_BASE="/home/gray/Project/DexUMI/data/weight" +if [ -d "$MODEL_BASE" ]; then + echo "Available models in $MODEL_BASE:" + for model_dir in "$MODEL_BASE"/*; do + if [ -d "$model_dir" ]; then + model_name=$(basename "$model_dir") + ckpt_count=$(ls -1 "$model_dir"/ckpt_*.pt 2>/dev/null | wc -l) + if [ $ckpt_count -gt 0 ]; then + echo -e " ${GREEN}✓${NC} $model_name ($ckpt_count checkpoints)" + else + echo -e " ${YELLOW}⚠${NC} $model_name (no checkpoints)" + fi + fi + done +else + print_test "Model directory" "warn" "$MODEL_BASE not found" +fi + +echo "" +echo "=========================================" +echo "Test Summary:" +echo " Passed: $TESTS_PASSED" +echo " Failed: $TESTS_FAILED" + +if [ $TESTS_FAILED -eq 0 ]; then + echo -e "${GREEN}All critical tests passed!${NC}" + echo "" + echo "You can now run the evaluation script:" + echo " ./eval_xhand_franka.sh" +else + echo -e "${RED}Some tests failed. Please fix the issues above.${NC}" +fi +echo "=========================================" + +exit $TESTS_FAILED \ No newline at end of file From 54af12f2949f16d4109ad0d362afdf844ece7deb Mon Sep 17 00:00:00 2001 From: Gray Date: Tue, 2 Sep 2025 15:42:33 +0800 Subject: [PATCH 02/10] [refactor] simplified eval code --- dexumi/real_env/real_policy.py | 84 ++++++-- real_script/eval_policy/eval_xhand_franka.py | 215 ++++--------------- 2 files changed, 108 insertions(+), 191 deletions(-) diff --git a/dexumi/real_env/real_policy.py b/dexumi/real_env/real_policy.py index c601da6..ca9efda 100644 --- a/dexumi/real_env/real_policy.py +++ b/dexumi/real_env/real_policy.py @@ -14,33 +14,54 @@ class RealPolicy: + """ + 真实环境策略类 - 用于在真实机器人环境中执行动作预测 + 基于扩散模型的机器人策略,可以处理多模态输入(视觉、本体感受、力传感器数据) + """ def __init__( self, - model_path: str, - ckpt: int, + model_path: str, # 模型路径 + ckpt: int, # 检查点编号 ): + """ + 初始化真实环境策略 + + Args: + model_path: 训练好的模型文件路径 + ckpt: 要加载的检查点编号 + """ + # 加载模型配置文件 model_cfg = load_config(model_path) + + # 加载扩散模型和噪声调度器 model, noise_scheduler = load_diffusion_model( model_path, ckpt, use_ema=model_cfg.training.use_ema ) + + # 加载数据统计信息(用于数据归一化和反归一化) stats = read_pickle(os.path.join(model_path, "stats.pickle")) - self.pred_horizon = model_cfg.dataset.pred_horizon - self.action_dim = model_cfg.action_dim - self.obs_horizon = model_cfg.dataset.obs_horizon + + # 设置模型参数 + self.pred_horizon = model_cfg.dataset.pred_horizon # 预测时间范围 + self.action_dim = model_cfg.action_dim # 动作维度 + self.obs_horizon = model_cfg.dataset.obs_horizon # 观测时间范围 + # 将模型设置为评估模式 self.model = model.eval() self.noise_scheduler = noise_scheduler - self.num_inference_steps = model_cfg.num_inference_steps + self.num_inference_steps = model_cfg.num_inference_steps # 推理步数 self.stats = stats - self.camera_resize_shape = model_cfg.dataset.camera_resize_shape + self.camera_resize_shape = model_cfg.dataset.camera_resize_shape # 相机图像调整尺寸 + # 处理手部动作类型(相对位置 vs 绝对位置) if model_cfg.dataset.relative_hand_action: - print("Using relative hand action") + print("Using relative hand action") # 使用相对手部动作 print("hand_action stats", stats["relative_hand_action"]) print( stats["relative_hand_action"]["max"] - stats["relative_hand_action"]["min"] > 5e-2 ) + # 组合相对姿态和相对手部动作的统计信息 self.stats["action"] = { "min": np.concatenate( [ @@ -56,9 +77,10 @@ def __init__( ), } else: - print("Using absolute hand action") + print("Using absolute hand action") # 使用绝对手部动作 print("hand_action stats", stats["hand_action"]) print(stats["hand_action"]["max"] - stats["hand_action"]["min"] > 5e-2) + # 组合相对姿态和绝对手部动作的统计信息 self.stats["action"] = { "min": 
np.concatenate( [stats["relative_pose"]["min"], stats["hand_action"]["min"]] @@ -71,13 +93,27 @@ def __init__( self.model_cfg = model_cfg def predict_action(self, proprioception, fsr, visual_obs): - # visual_obs: NxHxWxC + """ + 预测机器人动作 + + Args: + proprioception: 本体感受数据(关节位置、速度等) + fsr: 力传感器数据 + visual_obs: 视觉观测数据,形状为 NxHxWxC + + Returns: + action: 预测的动作序列 + """ + # 获取视觉观测的尺寸信息 _, H, W, _ = visual_obs.shape - B = 1 + B = 1 # 批次大小为1(单次预测) + # 处理本体感受数据 if proprioception is not None and "proprioception" in self.stats: + # 归一化本体感受数据 proprioception = normalize_data( proprioception.reshape(1, -1), self.stats["proprioception"] ) # (1,N) + # 转换为PyTorch张量并移到GPU proprioception = ( torch.from_numpy(proprioception).unsqueeze(0).cuda() ) # (B,1,6) @@ -85,35 +121,46 @@ def predict_action(self, proprioception, fsr, visual_obs): print("Warning: proprioception data provided but no stats available, setting to None") proprioception = None + # 处理力传感器数据 if fsr is not None and "fsr" in self.stats: + # 归一化力传感器数据 fsr = normalize_data(fsr.reshape(1, -1), self.stats["fsr"]) fsr = torch.from_numpy(fsr).unsqueeze(0).cuda() # (B,1,2) elif fsr is not None: print("Warning: fsr data provided but no stats available, setting to None") fsr = None + # 处理视觉观测数据 + # 1. 调整图像尺寸并转换颜色空间(BGR到RGB) visual_obs = np.array( [ cv2.cvtColor( cv2.resize( obs, ( - int(W * INPAINT_RESIZE_RATIO), - int(H * INPAINT_RESIZE_RATIO), + int(W * INPAINT_RESIZE_RATIO), # 按比例调整宽度 + int(H * INPAINT_RESIZE_RATIO), # 按比例调整高度 ), ), - cv2.COLOR_BGR2RGB, + cv2.COLOR_BGR2RGB, # BGR转RGB ) for obs in visual_obs ] ) + + # 2. 进一步处理图像(调整大小、中心裁剪等) visual_obs = process_image( visual_obs, optional_transforms=["Resize", "CenterCrop"], resize_shape=self.camera_resize_shape, ) + + # 3. 转换为PyTorch张量并移到GPU visual_obs = visual_obs.unsqueeze(0).cuda() + # 初始化随机轨迹作为扩散模型的起点 trajectory = torch.randn(B, self.pred_horizon, self.action_dim).cuda() + + # 使用扩散模型进行推理,生成动作轨迹 trajectory = self.model.inference( proprioception=proprioception, fsr=fsr, @@ -122,10 +169,17 @@ def predict_action(self, proprioception, fsr, visual_obs): noise_scheduler=self.noise_scheduler, num_inference_steps=self.num_inference_steps, ) + + # 将结果转移到CPU并转换为numpy数组 trajectory = trajectory.detach().to("cpu").numpy() - naction = trajectory[0] + naction = trajectory[0] # 获取第一个批次的结果 + + # 反归一化动作数据,恢复到原始尺度 action_pred = unnormalize_data(naction, stats=self.stats["action"]) + + # 提取有效的动作序列(从观测范围结束到预测范围结束) start = self.obs_horizon - 1 end = start + self.pred_horizon action = action_pred[start:end, :] + return action diff --git a/real_script/eval_policy/eval_xhand_franka.py b/real_script/eval_policy/eval_xhand_franka.py index 351cb45..197a74f 100644 --- a/real_script/eval_policy/eval_xhand_franka.py +++ b/real_script/eval_policy/eval_xhand_franka.py @@ -1,4 +1,3 @@ -import os import time from collections import deque @@ -7,22 +6,16 @@ import numpy as np import scipy.spatial.transform as st -from dexumi.camera.camera import FrameData from dexumi.camera.realsense_camera import RealSenseCamera, get_all_realsense_cameras from dexumi.common.frame_manager import FrameRateContext from dexumi.common.utility.matrix import ( homogeneous_matrix_to_6dof, vec6dof_to_homogeneous_matrix, ) -from dexumi.common.utility.video import ( - extract_frames_videos, -) from dexumi.constants import ( XHAND_HAND_MOTOR_SCALE_FACTOR, ) -from dexumi.data_recording import VideoRecorder from dexumi.data_recording.data_buffer import PoseInterpolator -from dexumi.data_recording.record_manager import RecorderManager # Import HTTP control classes from 
dexumi.real_env.common.http_client import HTTPRobotClient, HTTPHandClient @@ -78,9 +71,6 @@ def compute_total_force_per_finger(all_fsr_observations): @click.command() @click.option("-f", "--frequency", type=float, default=10, help="Control frequency (Hz)") -@click.option( - "-rc", "--enable_record_camera", is_flag=True, help="Enable record camera" -) @click.option( "-ct", "--camera_type", type=click.Choice(['realsense', 'oak']), default="realsense", help="Camera type to use" @@ -105,23 +95,8 @@ def compute_total_force_per_finger(all_fsr_observations): help="Robot action latency", ) @click.option("-eh", "--exec_horizon", type=int, default=8, help="Execution horizon") -@click.option( - "-vp", - "--video_record_path", - type=str, - default="video_record", - help="Path to save video recordings", -) -@click.option( - "-mep", - "--match_episode_path", - type=str, - default=None, - help="Path to match episode folder", -) def main( frequency, - enable_record_camera, camera_type, model_path, ckpt, @@ -129,8 +104,6 @@ def main( hand_action_latency, robot_action_latency, exec_horizon, - video_record_path, - match_episode_path, ): # Initialize HTTP clients for robot and hand control robot_client = HTTPRobotClient(base_url="http://127.0.0.1:5000") @@ -144,25 +117,13 @@ def main( return # Use the first available camera for observation - # Configure for 240x240 output to match training data obs_camera = RealSenseCamera( camera_name="obs camera", device_id=all_cameras[0], - camera_resolution=(640, 480), # Native resolution - enable_depth=False, # We don't need depth for inference + camera_resolution=(640, 480), + enable_depth=False, fps=30 ) - camera_sources = [obs_camera] - - if enable_record_camera and len(all_cameras) > 1: - record_camera = RealSenseCamera( - camera_name="record camera", - device_id=all_cameras[1], - camera_resolution=(640, 480), - enable_depth=False, - fps=30 - ) - camera_sources.append(record_camera) else: # Fall back to OAK cameras from dexumi.camera.oak_camera import OakCamera, get_all_oak_cameras @@ -172,53 +133,16 @@ def main( return obs_camera = OakCamera("obs camera", device_id=all_cameras[0]) - camera_sources = [obs_camera] - - if enable_record_camera and len(all_cameras) > 1: - record_camera = OakCamera("record camera", device_id=all_cameras[1]) - camera_sources.append(record_camera) - # Start cameras - for camera in camera_sources: - camera.start_streaming() - video_recorder = VideoRecorder( - record_fps=45, - stream_fps=60, - video_record_path=video_record_path, - camera_sources=camera_sources, - frame_data_class=FrameData, - verbose=False, - ) - recorder_manager = RecorderManager( - recorders=[video_recorder], - verbose=False, - ) - recorder_manager.start_streaming() + # Start camera + obs_camera.start_streaming() + dt = 1 / frequency - match_episode_folder = match_episode_path - # Main control loop (without manual control) + # Main control loop while True: - print("Ready!") - - # Handle match episode if provided - if match_episode_folder is not None: - print( - f"Extracting frames from match episode {recorder_manager.episode_id}" - ) - # Extract frames for reference (not used in simplified version) - _ = extract_frames_videos( - os.path.join( - match_episode_folder, - f"episode_{recorder_manager.episode_id}/camera_1.mp4", - ), - BGR2RGB=True, - ) - # match_initial_frame = match_episode[0] # Not used in simplified version - else: - print("No match episode folder provided") - # match_initial_frame = None # Not used in simplified version + print("Ready! 
Starting 20-second inference session...") # Reset robot to initial position print("Moving robot to initial position...") @@ -263,36 +187,21 @@ def main( model_path=model_path, ckpt=ckpt, ) - - # Start recording - if recorder_manager.reset_episode_recording(): - click.echo("Starting recording...") - recorder_manager.start_recording() # Calculate inference parameters inference_iter_time = exec_horizon * dt inference_fps = 1 / inference_iter_time print("inference_fps", inference_fps) + # Start 20-second inference session + session_start_time = time.time() + session_duration = 20.0 # 20 seconds + # Policy execution loop - while True: + while time.time() - session_start_time < session_duration: with FrameRateContext(frame_rate=inference_fps): - # gather observation - record_frame = recorder_manager.get_latest_frames() - if enable_record_camera: - video_frame = record_frame["record camera"][-1] - viz_frame = video_frame.rgb.copy() - cv2.putText( - viz_frame, - f"Episode: {recorder_manager.episode_id}", - (10, 30), - cv2.FONT_HERSHEY_SIMPLEX, - 1, - (0, 255, 0), - 2, - ) - - obs_frame = record_frame["obs camera"][-1] + # Get observation from camera + obs_frame = obs_camera.get_latest_frame() obs_frame_recieved_time = obs_frame.receive_time obs_frame_rgb = obs_frame.rgb.copy() @@ -306,76 +215,7 @@ def main( cropped = obs_frame_rgb[start_y:start_y+crop_size, start_x:start_x+crop_size] # Resize to 240x240 obs_frame_rgb = cv2.resize(cropped, (240, 240), interpolation=cv2.INTER_AREA) - cv2.putText( - obs_frame_rgb, - f"Episode: {recorder_manager.episode_id}", - (10, 30), - cv2.FONT_HERSHEY_SIMPLEX, - 1, - (0, 255, 0), - 2, - ) - if policy.model_cfg.dataset.enable_fsr: - # Draw FSR values on viz_frame - cv2.putText( - obs_frame_rgb, - f"FSR1: {fsr_value[0]:.0f}", - (10, 60), - cv2.FONT_HERSHEY_SIMPLEX, - 1, - (0, 255, 0), - 2, - ) - cv2.putText( - obs_frame_rgb, - f"FSR2: {fsr_value[1]:.0f}", - (10, 90), - cv2.FONT_HERSHEY_SIMPLEX, - 1, - (0, 255, 0), - 2, - ) - # Draw binary cutoff values - cv2.putText( - obs_frame_rgb, - f"FSR1 Binary: {int(fsr_value[0] > binary_cutoff[0])}", - (10, 120), - cv2.FONT_HERSHEY_SIMPLEX, - 1, - (0, 255, 0), - 2, - ) - cv2.putText( - obs_frame_rgb, - f"FSR2 Binary: {int(fsr_value[1] > binary_cutoff[1])}", - (10, 150), - cv2.FONT_HERSHEY_SIMPLEX, - 1, - (0, 255, 0), - 2, - ) - cv2.putText( - obs_frame_rgb, - f"FSR3 Binary: {int(fsr_value[2] > binary_cutoff[2])}", - (10, 210), - cv2.FONT_HERSHEY_SIMPLEX, - 1, - (0, 255, 0), - 2, - ) - cv2.imshow("obs frame", obs_frame_rgb) - if enable_record_camera: - cv2.imshow("record frame", viz_frame) - key = cv2.waitKey(1) & 0xFF - if key == ord("q"): - if recorder_manager.stop_recording(): - recorder_manager.save_recordings() - cv2.destroyAllWindows() - break - elif key == ord("a"): - if recorder_manager.stop_recording(): - recorder_manager.clear_recording() - break + print(f"Time remaining: {session_duration - (time.time() - session_start_time):.1f}s") if policy.model_cfg.dataset.enable_fsr: print("Using FSR") fsr_raw_obs = dexhand_client.get_tactile(calc=True) @@ -534,7 +374,30 @@ def main( print( f"Scheduled actions: {robot_scheduled} robot waypoints, {hand_scheduled} hand waypoints" ) - virtual_hand_pos = hand_action[exec_horizon + 1] + if len(hand_action) > exec_horizon + 1: + virtual_hand_pos = hand_action[exec_horizon + 1] + else: + virtual_hand_pos = hand_action[-1] + + # Session completed, reset to initial positions + print("20-second session completed. 
Resetting to initial positions...") + + # Reset robot to initial position + initial_pose_6d = np.zeros(6) + initial_pose_6d[:3] = initial_robot_pose[:3] + initial_pose_6d[3:] = st.Rotation.from_quat(initial_robot_pose[3:]).as_rotvec() + robot_client.schedule_waypoint(initial_pose_6d, time.time()) + + # Reset hand to initial position + for _ in range(3): + dexhand_client.schedule_waypoint( + target_pos=initial_hand_pos, + target_time=time.time() + 0.05, + ) + time.sleep(1) + + print("Reset completed. Ready for next session.") + time.sleep(2) if __name__ == "__main__": From 5ea6158733a4215a9a1b819b3c680fb2bcf26fae Mon Sep 17 00:00:00 2001 From: Gray Date: Wed, 3 Sep 2025 12:04:54 +0800 Subject: [PATCH 03/10] [fix] realsense connect error(rgb camera detect) --- dexumi/camera/realsense_camera.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/dexumi/camera/realsense_camera.py b/dexumi/camera/realsense_camera.py index 5139cac..a451ae9 100644 --- a/dexumi/camera/realsense_camera.py +++ b/dexumi/camera/realsense_camera.py @@ -122,19 +122,27 @@ def get_camera_frame(self): def start_streaming(self): try: - # Get device info and validate RGB camera + # Get device info and validate color stream support pipeline_wrapper = rs.pipeline_wrapper(self.pipeline) pipeline_profile = self.config.resolve(pipeline_wrapper) device = pipeline_profile.get_device() - found_rgb = False + found_color = False for sensor in device.sensors: - if sensor.get_info(rs.camera_info.name) == "RGB Camera": - found_rgb = True - break - - if not found_rgb: - raise RuntimeError("The RealSense device does not have an RGB camera") + # Check if sensor supports color stream + try: + profiles = sensor.get_stream_profiles() + for profile in profiles: + if profile.stream_type() == rs.stream.color: + found_color = True + break + if found_color: + break + except: + continue + + if not found_color: + raise RuntimeError("The RealSense device does not support color stream") # Start streaming self.profile = self.pipeline.start(self.config) From 3aee3a01ea4e6774028628d7a9fcaa77ffa9cceb Mon Sep 17 00:00:00 2001 From: Gray Date: Thu, 4 Sep 2025 14:10:53 +0800 Subject: [PATCH 04/10] [fix] camera image preprocess --- dexumi/real_env/real_policy.py | 46 +++++----- real_script/eval_policy/eval_xhand_franka.py | 97 ++++++++++---------- 2 files changed, 72 insertions(+), 71 deletions(-) diff --git a/dexumi/real_env/real_policy.py b/dexumi/real_env/real_policy.py index ca9efda..57f4669 100644 --- a/dexumi/real_env/real_policy.py +++ b/dexumi/real_env/real_policy.py @@ -5,7 +5,7 @@ import torch from dexumi.common.utility.file import read_pickle from dexumi.common.utility.model import load_config, load_diffusion_model -from dexumi.constants import INPAINT_RESIZE_RATIO +# INPAINT_RESIZE_RATIO is no longer needed since we use center cropping from dexumi.diffusion_policy.dataloader.diffusion_bc_dataset import ( normalize_data, process_image, @@ -104,8 +104,6 @@ def predict_action(self, proprioception, fsr, visual_obs): Returns: action: 预测的动作序列 """ - # 获取视觉观测的尺寸信息 - _, H, W, _ = visual_obs.shape B = 1 # 批次大小为1(单次预测) # 处理本体感受数据 if proprioception is not None and "proprioception" in self.stats: @@ -130,28 +128,32 @@ def predict_action(self, proprioception, fsr, visual_obs): print("Warning: fsr data provided but no stats available, setting to None") fsr = None - # 处理视觉观测数据 - # 1. 
调整图像尺寸并转换颜色空间(BGR到RGB) - visual_obs = np.array( - [ - cv2.cvtColor( - cv2.resize( - obs, - ( - int(W * INPAINT_RESIZE_RATIO), # 按比例调整宽度 - int(H * INPAINT_RESIZE_RATIO), # 按比例调整高度 - ), - ), - cv2.COLOR_BGR2RGB, # BGR转RGB - ) - for obs in visual_obs - ] - ) + # 处理视觉观测数据 - 与训练时保持完全一致的处理流程 + # 1. 中心裁剪到正方形(与XhandMultimodalCollection.py完全一致) + processed_obs = [] + for obs in visual_obs: + h, w = obs.shape[:2] # 应该是 (240, 424, 3) + + # 中心裁剪为正方形(与训练时完全相同的逻辑) + crop_size = min(h, w) # min(240, 424) = 240 + start_x = (w - crop_size) // 2 # (424-240)//2 = 92 + start_y = (h - crop_size) // 2 # (240-240)//2 = 0 + cropped = obs[start_y:start_y+crop_size, start_x:start_x+crop_size].copy() + + # 确保正确的240x240尺寸(与训练时一致) + if cropped.shape[:2] != (240, 240): + cropped = cv2.resize(cropped, (240, 240), interpolation=cv2.INTER_AREA) + + # BGR转RGB(与训练时一致) + rgb_obs = cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB) + processed_obs.append(rgb_obs) + + visual_obs = np.array(processed_obs) - # 2. 进一步处理图像(调整大小、中心裁剪等) + # 2. 使用process_image进行标准化处理(需要添加CenterCrop到224x224以匹配ViT模型) visual_obs = process_image( visual_obs, - optional_transforms=["Resize", "CenterCrop"], + optional_transforms=["CenterCrop"], # 匹配训练时的["Resize", "RandomCrop"] resize_shape=self.camera_resize_shape, ) diff --git a/real_script/eval_policy/eval_xhand_franka.py b/real_script/eval_policy/eval_xhand_franka.py index 197a74f..0179bf4 100644 --- a/real_script/eval_policy/eval_xhand_franka.py +++ b/real_script/eval_policy/eval_xhand_franka.py @@ -71,10 +71,6 @@ def compute_total_force_per_finger(all_fsr_observations): @click.command() @click.option("-f", "--frequency", type=float, default=10, help="Control frequency (Hz)") -@click.option( - "-ct", "--camera_type", type=click.Choice(['realsense', 'oak']), default="realsense", - help="Camera type to use" -) @click.option( "-mp", "--model_path", @@ -97,7 +93,6 @@ def compute_total_force_per_finger(all_fsr_observations): @click.option("-eh", "--exec_horizon", type=int, default=8, help="Execution horizon") def main( frequency, - camera_type, model_path, ckpt, camera_latency, @@ -109,30 +104,21 @@ def main( robot_client = HTTPRobotClient(base_url="http://127.0.0.1:5000") dexhand_client = HTTPHandClient(base_url="http://127.0.0.1:5000") - # Initialize cameras based on selected type - if camera_type == "realsense": - all_cameras = get_all_realsense_cameras() - if len(all_cameras) < 1: - print("Warning: No RealSense cameras found. Exiting...") - return - - # Use the first available camera for observation - obs_camera = RealSenseCamera( - camera_name="obs camera", - device_id=all_cameras[0], - camera_resolution=(640, 480), - enable_depth=False, - fps=30 - ) - else: - # Fall back to OAK cameras - from dexumi.camera.oak_camera import OakCamera, get_all_oak_cameras - all_cameras = get_all_oak_cameras() - if len(all_cameras) < 1: - print("Warning: No OAK cameras found. Exiting...") - return - - obs_camera = OakCamera("obs camera", device_id=all_cameras[0]) + # Initialize RealSense cameras + all_cameras = get_all_realsense_cameras() + if len(all_cameras) < 1: + print("Warning: No RealSense cameras found. 
Exiting...") + return + + # Use the first available camera for observation + # Match training format exactly: 424x240 BGR -> center crop to 240x240 (same as XhandMultimodalCollection.py) + obs_camera = RealSenseCamera( + camera_name="obs camera", + device_id=all_cameras[0], + camera_resolution=(424, 240), # Exact same as training: 424x240 + enable_depth=False, + fps=30 + ) # Start camera obs_camera.start_streaming() @@ -205,16 +191,8 @@ def main( obs_frame_recieved_time = obs_frame.receive_time obs_frame_rgb = obs_frame.rgb.copy() - # Ensure image is 240x240 for model input - if obs_frame_rgb.shape[:2] != (240, 240): - # Center crop to square - h, w = obs_frame_rgb.shape[:2] - crop_size = min(h, w) - start_x = (w - crop_size) // 2 - start_y = (h - crop_size) // 2 - cropped = obs_frame_rgb[start_y:start_y+crop_size, start_x:start_x+crop_size] - # Resize to 240x240 - obs_frame_rgb = cv2.resize(cropped, (240, 240), interpolation=cv2.INTER_AREA) + # Note: real_policy.py will handle all image preprocessing + # The image should be in BGR format to match training data print(f"Time remaining: {session_duration - (time.time() - session_start_time):.1f}s") if policy.model_cfg.dataset.enable_fsr: print("Using FSR") @@ -246,26 +224,40 @@ def main( ) print("camera_total_latency", camera_total_latency) t_actual_inference = t_inference - camera_total_latency - # Prepare image for model (ensure 240x240) - model_input_image = obs_frame.rgb.copy() - if model_input_image.shape[:2] != (240, 240): - h, w = model_input_image.shape[:2] - crop_size = min(h, w) - start_x = (w - crop_size) // 2 - start_y = (h - crop_size) // 2 - cropped = model_input_image[start_y:start_y+crop_size, start_x:start_x+crop_size] - model_input_image = cv2.resize(cropped, (240, 240), interpolation=cv2.INTER_AREA) + + # ============ DEBUG SECTION START ============ + print("\n" + "="*50) + print("DEBUG: Input Information") + print(f"Image shape: {obs_frame_rgb.shape}, dtype: {obs_frame_rgb.dtype}") + print(f"Image range: [{obs_frame_rgb.min():.2f}, {obs_frame_rgb.max():.2f}]") + if policy.model_cfg.dataset.enable_fsr: + print(f"FSR obs shape: {np.array(list(fsr_obs)).shape}") + print(f"FSR obs values: {np.array(list(fsr_obs))[-1]}") # Last FSR reading action = policy.predict_action( None, np.array(list(fsr_obs)).astype(np.float32) if policy.model_cfg.dataset.enable_fsr else None, - model_input_image[None, ...], # Use processed image + obs_frame_rgb[None, ...], # Use original image, let real_policy.py handle preprocessing ) + + print("\nDEBUG: Raw Action Output") + print(f"Action shape: {action.shape}") + print(f"Action min/max: [{action.min():.4f}, {action.max():.4f}]") + print(f"Action mean/std: [mean={action.mean():.4f}, std={action.std():.4f}]") + # convert to abs action relative_pose = action[:, :6] hand_action = action[:, 6:] + + print("\nDEBUG: Action Components") + print(f"Relative pose (first): {relative_pose[0]}") + print(f"Hand action (first): {hand_action[0][:4]}...") # Show first 4 joints + print(f"Relative pose norm: {np.linalg.norm(relative_pose[0][:3]):.4f}") # Position magnitude + print(f"Relative rotation norm: {np.linalg.norm(relative_pose[0][3:]):.4f}") # Rotation magnitude + # ============ DEBUG SECTION END ============ + relative_pose = np.array( [ vec6dof_to_homogeneous_matrix(rp[:3], rp[3:]) @@ -339,6 +331,13 @@ def main( for iter_idx in range(len(relative_pose)): # Direct application: T_BN = T_BE @ relative_pose T_BN[iter_idx] = T_BE @ relative_pose[iter_idx] + + # ============ DEBUG: Target Poses ============ + 
print("\nDEBUG: Target Transformation") + print(f"Current EE pose: {ee_aligned_pose[:3]}") # Current position + print(f"First target pose: {T_BN[0, :3, -1]}") # First target position + print(f"Position change: {T_BN[0, :3, -1] - ee_aligned_pose[:3]}") # Delta position + # ============ DEBUG END ============ # discard actions which in the past n_action = T_BN.shape[0] t_exec = time.monotonic() From 807f89d417e44b036af22e2ba62329e3df24e2fc Mon Sep 17 00:00:00 2001 From: Hly-123 <452663784@qq.com> Date: Fri, 5 Sep 2025 12:56:37 +0800 Subject: [PATCH 05/10] [fix] dp inference version 0.1, can move --- real_script/eval_policy/eval_xhand_franka.py | 122 +++++++++++++------ real_script/eval_policy/eval_xhand_franka.sh | 19 ++- 2 files changed, 89 insertions(+), 52 deletions(-) diff --git a/real_script/eval_policy/eval_xhand_franka.py b/real_script/eval_policy/eval_xhand_franka.py index 0179bf4..da99407 100644 --- a/real_script/eval_policy/eval_xhand_franka.py +++ b/real_script/eval_policy/eval_xhand_franka.py @@ -52,20 +52,20 @@ def compute_total_force_per_finger(all_fsr_observations): obs_horizon = 1 binary_cutoff = [10, 10, 10] -# Initial hand position (open position) +# Initial hand position (open position) - Updated to match open_gripper command initial_hand_pos = np.array([ - 0.92755819, - 0.52026953, - 0.22831853, - 0.0707963, - 1.1, - 0.15707963, - 0.95, - 0.12217305, - 1.0392188, - 0.03490659, - 1.0078164, - 0.17453293, + 1.516937255859375, + 0.5177657604217529, + 0.04799513891339302, + 0.01787799410521984, + 0.005817593075335026, + 0.034905556589365005, + 0.014543981291353703, + 0.011635186150670052, + 0.002908796537667513, + 0.01599838025867939, + 0.007271990645676851, + 0.024724768474698067, ]) @@ -142,14 +142,10 @@ def main( # Initialize FSR observations fsr_obs = deque(maxlen=obs_horizon) - fsr_raw_obs = dexhand_client.get_tactile(calc=True) - # Reshape fsr_raw_obs to add a batch dimension - fsr_raw_obs = fsr_raw_obs[None, ...] # This adds a dimension at the start - fsr_raw_obs = compute_total_force_per_finger(fsr_raw_obs)[0] - fsr_value = np.array(fsr_raw_obs[:3]) - print("fsr_value", fsr_value) + + # Initialize FSR with proper dimensions for _ in range(obs_horizon): - fsr_obs.append(np.zeros(2)) + fsr_obs.append(np.zeros(3, dtype=np.float32)) # 3 fingers, binary values print( "resetting hand----------------------------------------------------------------------------------" @@ -187,7 +183,7 @@ def main( while time.time() - session_start_time < session_duration: with FrameRateContext(frame_rate=inference_fps): # Get observation from camera - obs_frame = obs_camera.get_latest_frame() + obs_frame = obs_camera.get_camera_frame() obs_frame_recieved_time = obs_frame.receive_time obs_frame_rgb = obs_frame.rgb.copy() @@ -198,21 +194,28 @@ def main( print("Using FSR") fsr_raw_obs = dexhand_client.get_tactile(calc=True) print("raw", fsr_raw_obs) - # Reshape fsr_raw_obs to add a batch dimension - fsr_raw_obs = fsr_raw_obs[ - None, ... 
- ] # This adds a dimension at the start - fsr_raw_obs = compute_total_force_per_finger(fsr_raw_obs)[0] - fsr_value = np.array(fsr_raw_obs[:3]) + print(f"FSR raw shape: {fsr_raw_obs.shape}") + + # Process FSR data - should be (5, 3) -> compute total force per finger -> take first 3 fingers + if fsr_raw_obs.ndim == 2 and fsr_raw_obs.shape[0] >= 3: + # Compute total force per finger (magnitude of 3D force vector) + fsr_total_forces = np.linalg.norm(fsr_raw_obs, axis=1) # Shape: (5,) + # Take first 3 fingers to match training data + fsr_value = fsr_total_forces[:3] # Shape: (3,) + else: + # Fallback: if shape is unexpected, try to extract 3 values + fsr_value = fsr_raw_obs.flatten()[:3] + + print(f"fsr_value shape: {fsr_value.shape}") print("fsr_value", fsr_value) fsr_value = fsr_value.astype(np.float32) - # Apply binary cutoff - fsr_value_binary = (fsr_value >= binary_cutoff).astype( - np.float32 - ) + + # Apply binary cutoff (same as training) + fsr_value_binary = (fsr_value >= binary_cutoff).astype(np.float32) fsr_obs.append(fsr_value_binary) # inference action t_inference = time.monotonic() + t_inference_wall = time.time() # Wall clock time for interpolation # camera latency + transfer time print( "t_inference|obs_frame_recieved_time", @@ -223,7 +226,7 @@ def main( camera_latency + t_inference - obs_frame_recieved_time ) print("camera_total_latency", camera_total_latency) - t_actual_inference = t_inference - camera_total_latency + t_actual_inference = t_inference_wall - camera_total_latency # ============ DEBUG SECTION START ============ print("\n" + "="*50) @@ -234,8 +237,27 @@ def main( print(f"FSR obs shape: {np.array(list(fsr_obs)).shape}") print(f"FSR obs values: {np.array(list(fsr_obs))[-1]}") # Last FSR reading + # Get robot proprioception data for model input + robot_state = robot_client.get_state() + proprioception = None + + if robot_state and "state" in robot_state: + # Extract joint positions and velocities + joint_q = np.array(robot_state["state"]["ActualQ"]) # 7D joint positions + joint_dq = np.array(robot_state["state"]["ActualQd"]) # 7D joint velocities + # Create 14D proprioception vector: [joint_q, joint_dq] + proprioception = np.concatenate([joint_q, joint_dq]).astype(np.float32) + + print(f"Proprioception shape: {proprioception.shape}") + print(f"Joint positions: {joint_q}") + print(f"Joint velocities: {joint_dq}") + else: + # Fallback to zeros if state unavailable + proprioception = np.zeros(14, dtype=np.float32) + print("Warning: Robot state unavailable, using zero proprioception") + action = policy.predict_action( - None, + proprioception[None, ...] 
if proprioception is not None else None, # Add batch dimension np.array(list(fsr_obs)).astype(np.float32) if policy.model_cfg.dataset.enable_fsr else None, @@ -310,13 +332,33 @@ def main( robot_timestamp = np.array(robot_timestamp) robot_homogeneous_matrix = np.array(robot_homogeneous_matrix) - # Interpolate to get pose at inference time - robot_pose_interpolator = PoseInterpolator( - timestamps=robot_timestamp, - homogeneous_matrix=robot_homogeneous_matrix, - ) - aligned_pose = robot_pose_interpolator([t_actual_inference])[0] - ee_aligned_pose = homogeneous_matrix_to_6dof(aligned_pose) + # Handle insufficient history for interpolation + if len(robot_frames) < 2: + print(f"Warning: Only {len(robot_frames)} robot states in history, using current state") + # Use current robot state directly + current_state = robot_client.get_state() + tcp_pose = current_state["state"]["ActualTCPPose"] + xyz = tcp_pose[:3] + if len(tcp_pose) == 7: + rotvec = st.Rotation.from_quat(tcp_pose[3:]).as_rotvec() + else: + rotvec = tcp_pose[3:] + aligned_pose_matrix = vec6dof_to_homogeneous_matrix(xyz, rotvec) + aligned_pose = homogeneous_matrix_to_6dof(aligned_pose_matrix) + else: + # Interpolate to get pose at inference time + try: + robot_pose_interpolator = PoseInterpolator( + timestamps=robot_timestamp, + homogeneous_matrix=robot_homogeneous_matrix, + ) + aligned_pose_matrix = robot_pose_interpolator([t_actual_inference])[0] + aligned_pose = homogeneous_matrix_to_6dof(aligned_pose_matrix) + except ValueError as e: + print(f"Interpolation failed: {e}, using most recent state") + # Fallback to most recent state + aligned_pose = homogeneous_matrix_to_6dof(robot_homogeneous_matrix[-1]) + ee_aligned_pose = aligned_pose # Build current end-effector transformation matrix T_BE T_BE = np.eye(4) diff --git a/real_script/eval_policy/eval_xhand_franka.sh b/real_script/eval_policy/eval_xhand_franka.sh index 7c3655a..bf13a5d 100755 --- a/real_script/eval_policy/eval_xhand_franka.sh +++ b/real_script/eval_policy/eval_xhand_franka.sh @@ -3,17 +3,20 @@ # Evaluation script for XHand with Franka using HTTP control # Activate conda environment -source ~/miniconda3/etc/profile.d/conda.sh +source ~/anaconda3/etc/profile.d/conda.sh conda activate dexumi # Path to your trained model -MODEL_PATH="/path/to/your/trained/model" # TODO: Update this path +MODEL_PATH="/home/ubuntu/hgw/IL/DexUMI/data/weight/vision_tactile_propio" # TODO: Update this path CHECKPOINT=600 # Control parameters -FREQUENCY=10 # Control frequency in Hz +FREQUENCY=15 # Control frequency in Hz EXEC_HORIZON=8 # Number of action steps to execute before re-predicting +# Visualization settings +ENABLE_VISUALIZATION=false # Set to true to enable real-time camera visualization + # Camera configuration CAMERA_TYPE="realsense" # Options: "realsense" or "oak" @@ -63,14 +66,6 @@ python real_script/eval_policy/eval_xhand_franka.py \ --ckpt $CHECKPOINT \ --frequency $FREQUENCY \ --exec_horizon $EXEC_HORIZON \ - --camera_type $CAMERA_TYPE \ --camera_latency $CAMERA_LATENCY \ --hand_action_latency $HAND_ACTION_LATENCY \ - --robot_action_latency $ROBOT_ACTION_LATENCY \ - --video_record_path "$VIDEO_RECORD_PATH" - -# Optional: Enable record camera (requires second camera) -# Add --enable_record_camera flag if you have two cameras - -# Optional: Match episode path for comparison -# --match_episode_path "/path/to/reference/episodes" \ No newline at end of file + --robot_action_latency $ROBOT_ACTION_LATENCY \ No newline at end of file From be2f58f039c5ba4909c900e4d04721d81408336e Mon Sep 
17 00:00:00 2001 From: Gray Date: Fri, 5 Sep 2025 19:45:01 +0800 Subject: [PATCH 06/10] [fix] timestamp fix --- dexumi/camera/realsense_camera.py | 2 +- real_script/eval_policy/eval_xhand_franka.py | 41 +++++++++++++++----- 2 files changed, 32 insertions(+), 11 deletions(-) diff --git a/dexumi/camera/realsense_camera.py b/dexumi/camera/realsense_camera.py index a451ae9..40c6e16 100644 --- a/dexumi/camera/realsense_camera.py +++ b/dexumi/camera/realsense_camera.py @@ -83,7 +83,7 @@ def get_camera_frame(self): try: # Wait for a coherent pair of frames frames = self.pipeline.wait_for_frames() - receive_time = time.monotonic() + receive_time = time.time() # Changed to wall clock time for consistency with HTTP client if self.align_to_color and self.enable_depth: frames = self.align.process(frames) diff --git a/real_script/eval_policy/eval_xhand_franka.py b/real_script/eval_policy/eval_xhand_franka.py index da99407..03d751b 100644 --- a/real_script/eval_policy/eval_xhand_franka.py +++ b/real_script/eval_policy/eval_xhand_franka.py @@ -213,10 +213,9 @@ def main( # Apply binary cutoff (same as training) fsr_value_binary = (fsr_value >= binary_cutoff).astype(np.float32) fsr_obs.append(fsr_value_binary) - # inference action - t_inference = time.monotonic() - t_inference_wall = time.time() # Wall clock time for interpolation - # camera latency + transfer time + # inference action - use wall clock time consistently + t_inference = time.time() # Use wall clock time consistently with camera and HTTP client + # camera latency + transfer time (both now in wall clock time) print( "t_inference|obs_frame_recieved_time", t_inference, @@ -226,7 +225,15 @@ def main( camera_latency + t_inference - obs_frame_recieved_time ) print("camera_total_latency", camera_total_latency) - t_actual_inference = t_inference_wall - camera_total_latency + t_actual_inference = t_inference - camera_total_latency + + # ============ TIMESTAMP VERIFICATION ============ + print(f"\n🕒 TIMESTAMP DEBUG:") + print(f" Camera timestamp: {obs_frame_recieved_time:.3f} (wall time)") + print(f" Inference timestamp: {t_inference:.3f} (wall time)") + print(f" Camera to inference delay: {t_inference - obs_frame_recieved_time:.3f}s (should be < 0.1s)") + print(f" Actual inference timestamp: {t_actual_inference:.3f} (wall time)") + # ============ TIMESTAMP VERIFICATION END ============ # ============ DEBUG SECTION START ============ print("\n" + "="*50) @@ -332,6 +339,14 @@ def main( robot_timestamp = np.array(robot_timestamp) robot_homogeneous_matrix = np.array(robot_homogeneous_matrix) + # ============ ROBOT TIMESTAMP VERIFICATION ============ + if len(robot_timestamp) > 0: + print(f"🤖 ROBOT TIMESTAMP DEBUG:") + print(f" Latest robot timestamp: {robot_timestamp[-1]:.3f} (wall time)") + print(f" Robot timestamps range: {len(robot_timestamp)} samples") + print(f" Time gap robot->inference: {t_actual_inference - robot_timestamp[-1]:.3f}s") + # ============ ROBOT TIMESTAMP VERIFICATION END ============ + # Handle insufficient history for interpolation if len(robot_frames) < 2: print(f"Warning: Only {len(robot_frames)} robot states in history, using current state") @@ -382,15 +397,22 @@ def main( # ============ DEBUG END ============ # discard actions which in the past n_action = T_BN.shape[0] - t_exec = time.monotonic() + t_exec = time.time() # Use wall clock time consistently robot_scheduled = 0 hand_scheduled = 0 # Process robot waypoints robot_times = t_actual_inference + np.arange(n_action) * dt valid_robot_idx = robot_times >= t_exec + 
robot_action_latency + dt - # convert to global time - robot_times = robot_times - time.monotonic() + time.time() + # robot_times are already in wall clock time, no conversion needed + + # ============ SCHEDULING TIME VERIFICATION ============ + print(f"📅 SCHEDULING DEBUG:") + print(f" Current execution time: {t_exec:.3f}") + print(f" First robot waypoint time: {robot_times[0]:.3f}") + print(f" Time until first execution: {robot_times[0] - t_exec:.3f}s") + print(f" Valid robot actions: {np.sum(valid_robot_idx)}/{n_action}") + # ============ SCHEDULING TIME VERIFICATION END ============ for k in np.where(valid_robot_idx)[0]: target_pose = np.zeros(6) target_pose[:3] = T_BN[k, :3, -1] @@ -403,8 +425,7 @@ def main( # Process hand waypoints hand_times = t_actual_inference + np.arange(n_action) * dt valid_hand_idx = hand_times >= t_exec + hand_action_latency + dt - # convert to global time - hand_times = hand_times - time.monotonic() + time.time() + # hand_times are already in wall clock time, no conversion needed for k in np.where(valid_hand_idx)[0]: target_hand_action = hand_action[k] dexhand_client.schedule_waypoint( From a8882207a205ef98d65305b27304bb12f71ee687 Mon Sep 17 00:00:00 2001 From: Gray Date: Tue, 9 Sep 2025 14:55:32 +0800 Subject: [PATCH 07/10] [doc] update CLAUDE.md --- CLAUDE.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CLAUDE.md b/CLAUDE.md index 627ae4d..fe52c2d 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -127,6 +127,8 @@ python linkage_optimization/get_equivalent_thumb.py -r /path/to/sim -b /path/to/ **Multi-Modal Data Handling**: Vision, proprioception, and force data are processed separately then concatenated for policy conditioning. Missing modalities are handled gracefully. +**Direct Proprioceptive Training**: The system supports direct training on robot proprioceptive data without complex preprocessing. This approach bypasses exoskeleton-to-robot data conversion, avoiding interpolation errors and maintaining data integrity. Robot joint angles, positions, and force readings are fed directly into the network model for training. + **Distributed Training**: Uses Accelerate for multi-GPU training with NCCL backend. Training supports gradient accumulation and EMA model updates. **Real-Time Control**: Policy evaluation runs in real-time with 30Hz control loop, requiring careful attention to inference latency and action smoothing. 
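+
+**Proprioception Input Sketch**: A minimal illustrative sketch of how the direct proprioceptive input described above is assembled from the robot server state before normalization; the helper name here is hypothetical, and the actual logic lives in `real_script/eval_policy/eval_xhand_franka.py`:
+
+```python
+import numpy as np
+
+def build_proprioception(robot_state: dict) -> np.ndarray:
+    """Assemble the 14-D input: 7 joint positions + 7 joint velocities."""
+    if robot_state and "state" in robot_state:
+        joint_q = np.asarray(robot_state["state"]["ActualQ"], dtype=np.float32)
+        joint_dq = np.asarray(robot_state["state"]["ActualQd"], dtype=np.float32)
+        return np.concatenate([joint_q, joint_dq])
+    # Fall back to zeros when the robot state is unavailable
+    return np.zeros(14, dtype=np.float32)
+```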
\ No newline at end of file From 4dae422c13cd31c63496c7ab8a966f43065626dd Mon Sep 17 00:00:00 2001 From: Gray Date: Tue, 9 Sep 2025 15:24:57 +0800 Subject: [PATCH 08/10] [feat] enhance data collection with quality tools MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix thread-safe camera collection with per-camera locks - Add pickle-to-zarr conversion for training compatibility - Add comprehensive data quality check system - Include interactive visualizer and export tools 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- .../XhandMultimodalCollection.py | 329 +++++--- .../data_collection/convert_pickle_to_zarr.py | 568 +++++++++++++ .../comprehensive_data_quality_check.py | 767 ++++++++++++++++++ .../data_check/doc/CONVERT_USAGE.md | 101 +++ .../doc/DATA_QUALITY_CHECK_USAGE.md | 236 ++++++ .../data_check/export_episode_sequence.py | 220 +++++ .../data_check/interactive_zarr_visualizer.py | 317 ++++++++ .../data_check/visualize_rgb.py | 149 ++++ 8 files changed, 2559 insertions(+), 128 deletions(-) create mode 100644 real_script/data_collection/convert_pickle_to_zarr.py create mode 100644 real_script/data_collection/data_check/comprehensive_data_quality_check.py create mode 100644 real_script/data_collection/data_check/doc/CONVERT_USAGE.md create mode 100644 real_script/data_collection/data_check/doc/DATA_QUALITY_CHECK_USAGE.md create mode 100644 real_script/data_collection/data_check/export_episode_sequence.py create mode 100644 real_script/data_collection/data_check/interactive_zarr_visualizer.py create mode 100644 real_script/data_collection/data_check/visualize_rgb.py diff --git a/real_script/data_collection/XhandMultimodalCollection.py b/real_script/data_collection/XhandMultimodalCollection.py index 4904b01..79a3a58 100644 --- a/real_script/data_collection/XhandMultimodalCollection.py +++ b/real_script/data_collection/XhandMultimodalCollection.py @@ -43,20 +43,24 @@ class CollectionState(Enum): class CameraCollector: """独立相机采集器 - 高帧率采集,记录硬件时间戳""" - def __init__(self): - self.camera_data = [] # 存储采集的相机帧 + def __init__(self, max_frames_per_camera: int = 10000): + self.camera_data = {} # 改为每个相机独立的数据列表 {camera_id: [frames]} self._state = CollectionState.STOPPED self._state_lock = threading.Lock() - self.lock = threading.Lock() + self.locks = {} # {camera_id: threading.Lock()} self._last_frame_hash = {} # 用于检测重复帧 + self._hash_lock = threading.Lock() # 保护 _last_frame_hash 的锁 + self.max_frames_per_camera = max_frames_per_camera # 每个相机最大帧数限制 def start_collecting(self): """开始新的采集会话(清空数据)""" with self._state_lock: self._state = CollectionState.RUNNING - with self.lock: - self.camera_data = [] - self._last_frame_hash = {} + # 清空所有相机的数据 + for cam_id in self.camera_data: + with self.locks[cam_id]: + self.camera_data[cam_id] = [] + self._last_frame_hash = {} def resume_collecting(self, pipelines=None): """恢复采集(保持现有数据)""" @@ -115,41 +119,75 @@ def _validate_frame(self, frame_data: Dict, cam_id: int) -> bool: # 检查图像内容是否变化(可选 - 用于调试) current_hash = hash(rgb[::8, ::8].tobytes()) # 使用降采样避免计算开销 - if cam_id in self._last_frame_hash: - if current_hash == self._last_frame_hash[cam_id]: - logger.debug(f"警告: 相机{cam_id}图像内容可能重复") - # 注意:这里不返回False,因为某些场景下图像确实可能相同 - self._last_frame_hash[cam_id] = current_hash + + # 使用锁保护 _last_frame_hash 的读写 + with self._hash_lock: + if cam_id in self._last_frame_hash: + if current_hash == self._last_frame_hash[cam_id]: + logger.debug(f"警告: 相机{cam_id}图像内容可能重复") + # 注意:这里不返回False,因为某些场景下图像确实可能相同 + 
self._last_frame_hash[cam_id] = current_hash return True def get_current_frame(self) -> Optional[Dict]: """获取最新的相机帧(用于实时显示)""" - with self.lock: - if self.camera_data: - return self.camera_data[-1] - return None + # 使用固定顺序加锁避免死锁 + latest_frame = {} + sorted_cam_ids = sorted(self.camera_data.keys()) # 固定加锁顺序 + + for cam_id in sorted_cam_ids: + with self.locks[cam_id]: + if self.camera_data[cam_id]: + # 快速复制引用,减少锁持有时间 + latest_frame.update(self.camera_data[cam_id][-1]) + + return latest_frame if latest_frame else None def get_collected_data(self) -> List[Dict]: """获取已采集的所有相机数据""" - with self.lock: - return self.camera_data.copy() + # 合并所有相机的数据 + merged_data = [] + + # 找到最短的相机数据长度(用于同步) + min_length = float('inf') + for cam_id in self.camera_data: + with self.locks[cam_id]: + length = len(self.camera_data[cam_id]) + if length > 0: + min_length = min(min_length, length) + + if min_length == float('inf'): + return [] + + # 按时间戳顺序合并数据 + for i in range(min_length): + frame_data = {} + for cam_id in self.camera_data: + with self.locks[cam_id]: + if i < len(self.camera_data[cam_id]): + frame_data.update(self.camera_data[cam_id][i]) + merged_data.append(frame_data) + + return merged_data def clear_data(self): """清空采集数据""" - with self.lock: - self.camera_data = [] + for cam_id in self.camera_data: + with self.locks[cam_id]: + self.camera_data[cam_id] = [] class RobotDataCollector: """独立机器人数据采集器 - 固定频率采集HTTP数据""" - def __init__(self, url: str = "http://127.0.0.1:5000/"): + def __init__(self, url: str = "http://127.0.0.1:5000/", max_frames: int = 10000): self.url = url self.robot_data = [] # 存储采集的机器人数据 self._state = CollectionState.STOPPED self._state_lock = threading.Lock() self.lock = threading.Lock() + self.max_frames = max_frames # 最大帧数限制 def start_collecting(self): """开始新的采集会话(清空数据)""" @@ -288,7 +326,7 @@ def stop(self): class XhandMultimodalDataCollector: """XHand + Franka多模态数据采集接口""" - def __init__(self, num_cameras: int = 1, xhand_port: str = "/dev/ttyUSB0"): + def __init__(self, num_cameras: int = 2, xhand_port: str = "/dev/ttyUSB0"): self.num_cameras = num_cameras self.xhand_port = xhand_port self.url = "http://127.0.0.1:5000/" @@ -301,9 +339,9 @@ def __init__(self, num_cameras: int = 1, xhand_port: str = "/dev/ttyUSB0"): self.camera_collector = CameraCollector() self.robot_collector = RobotDataCollector(self.url) - # 相机线程 + # 相机线程管理 self.camera_running = False - self.camera_thread = None + self.camera_threads = [] # 改为线程列表,每个相机一个线程 self.robot_thread = None # 性能统计 @@ -317,9 +355,16 @@ def __init__(self, num_cameras: int = 1, xhand_port: str = "/dev/ttyUSB0"): # 初始化系统 self._init_cameras() + # 为每个相机初始化独立的锁和数据存储 + for cam_info in self.pipelines: + cam_id = cam_info['camera_id'] + self.camera_collector.locks[cam_id] = threading.Lock() + self.camera_collector.camera_data[cam_id] = [] + print("✓ XHand + Franka多模态数据采集系统初始化完成") print("✓ 相机分辨率: 240x240 (匹配DexUMI训练需求)") print("✓ 独立采集模式已启用") + print(f"✓ 初始化了 {len(self.pipelines)} 个相机的独立数据存储") # 初始化重置 print("\n执行初始化重置...") @@ -327,8 +372,9 @@ def __init__(self, num_cameras: int = 1, xhand_port: str = "/dev/ttyUSB0"): def _init_cameras(self): """初始化相机 - 240x240分辨率""" - top_serial = "244622072813" - wrist_serial = "230322271519" + # 更新为当前实际连接的相机序列号 + top_serial = "218622274962" # 更新序列号 , 218622274962 + wrist_serial = "218622270499" # 更新序列号 , 218622270499 try: ctx = rs.context() @@ -387,10 +433,18 @@ def _init_cameras(self): def start_collection(self): """启动独立的数据采集线程""" - # 启动相机采集线程 + # 启动每个相机的独立采集线程 self.camera_running = True - self.camera_thread = 
threading.Thread(target=self._camera_thread, daemon=True) - self.camera_thread.start() + self.camera_threads = [] + + for cam_info in self.pipelines: + camera_thread = threading.Thread( + target=self._single_camera_thread, + args=(cam_info,), + daemon=True + ) + camera_thread.start() + self.camera_threads.append(camera_thread) # 启动机器人数据采集线程 self.robot_thread = threading.Thread(target=self._robot_thread, daemon=True) @@ -400,10 +454,10 @@ def start_collection(self): self.camera_collector.start_collecting() self.robot_collector.start_collecting() - print("✓ 独立采集线程已启动") + print(f"✓ 启动了 {len(self.camera_threads)} 个相机独立采集线程") def stop_collection(self): - """停止数据采集""" + """停止数据采集并验证线程停止""" # 停止采集 self.camera_collector.stop_collecting() self.robot_collector.stop_collecting() @@ -411,110 +465,106 @@ def stop_collection(self): # 停止线程 self.camera_running = False - if self.camera_thread: - self.camera_thread.join(timeout=2.0) + # 等待所有相机线程结束,并验证 + for i, thread in enumerate(self.camera_threads): + thread.join(timeout=5.0) + if thread.is_alive(): + logger.warning(f"警告: 相机线程 {i} 未能在5秒内停止") + # 这里可以考虑强制终止,但Python线程不支持强制终止 + # 最好的做法是确保线程内部有正确的退出逻辑 + if self.robot_thread: - self.robot_thread.join(timeout=2.0) + self.robot_thread.join(timeout=5.0) + if self.robot_thread.is_alive(): + logger.warning("警告: 机器人线程未能在5秒内停止") + + # 清理线程列表 + self.camera_threads = [] + self.robot_thread = None print("✓ 数据采集已停止") - def _camera_thread(self): - """相机采集线程 - 最大帧率采集""" - print("相机采集线程启动") + def _single_camera_thread(self, cam_info): + """单个相机的独立采集线程""" + camera_id = cam_info['camera_id'] + print(f"相机 {camera_id} 采集线程启动 (序列号: {cam_info['serial']})") frame_counter = 0 while self.camera_running: if not self.camera_collector.is_running(): - # 不采集时持续消费帧,防止缓冲区积累旧数据 - for cam_info in self.pipelines: - try: - # 使用try_wait_for_frames短时间等待并丢弃帧 - frames = cam_info['pipeline'].try_wait_for_frames(timeout_ms=10) - except: - pass + # 不采集时持续消费帧,防止缓冲区积累 + try: + frames = cam_info['pipeline'].try_wait_for_frames(timeout_ms=10) + except: + pass time.sleep(0.01) continue try: capture_start = time.time() - frame_data = {} - # 获取所有相机帧 - for cam_info in self.pipelines: - try: - # 等待新帧,设置合理超时时间 - frames = cam_info['pipeline'].wait_for_frames(timeout_ms=50) - if not frames: - continue - - color_frame = frames.get_color_frame() - - if color_frame: - # 记录硬件时间戳和系统时间戳 - hardware_timestamp = color_frame.get_timestamp() # 硬件时间戳 - system_timestamp = time.time() # 系统时间戳 - - # 增加帧计数器用于调试 - frame_counter += 1 - - # 获取原始图像并处理为240x240 - # 安全的内存拷贝 - raw_img = np.asanyarray(color_frame.get_data()).copy() # (240, 424, 3) - - # 中心裁剪为240x240 - h, w = raw_img.shape[:2] - crop_size = min(h, w) # 240 - start_x = (w - crop_size) // 2 # (424-240)//2 = 92 - start_y = (h - crop_size) // 2 # 0 - - # 使用copy()确保是独立的数组,不是视图 - cropped = raw_img[start_y:start_y+crop_size, start_x:start_x+crop_size].copy() - - # 确保正确的240x240尺寸 - if cropped.shape[:2] != (240, 240): - cropped = cv2.resize(cropped, (240, 240), interpolation=cv2.INTER_AREA) - - frame_data[f'camera_{cam_info["camera_id"]}'] = { - 'rgb': cropped, - 'hardware_timestamp': hardware_timestamp, - 'system_timestamp': system_timestamp, - 'capture_time': capture_start - } - except Exception as e: - logger.debug(f"相机 {cam_info['camera_id']} 获取帧失败: {e}") - - # 验证和存储相机数据 - if frame_data: - # 验证帧数据有效性 - valid_frames = {} - for cam_key, cam_data in frame_data.items(): - if cam_key.startswith('camera_'): - cam_id = int(cam_key.split('_')[1]) - if self.camera_collector._validate_frame(cam_data, cam_id): - valid_frames[cam_key] = 
cam_data - else: - logger.debug(f"相机 {cam_id} 帧数据验证失败") + # 获取相机帧 + frames = cam_info['pipeline'].wait_for_frames(timeout_ms=50) + if not frames: + continue - # 只存储验证成功的帧 - if valid_frames: - with self.camera_collector.lock: - self.camera_collector.camera_data.append(valid_frames) - - # 更新FPS统计 - now = time.time() - if self.perf_stats['last_camera_time'] > 0: - fps = 1.0 / (now - self.perf_stats['last_camera_time']) - self.perf_stats['camera_fps'].append(fps) - self.perf_stats['last_camera_time'] = now + color_frame = frames.get_color_frame() + if color_frame: + # 记录时间戳 + hardware_timestamp = color_frame.get_timestamp() + system_timestamp = time.time() + frame_counter += 1 + + # 处理图像 + raw_img = np.asanyarray(color_frame.get_data()).copy() + + # 中心裁剪为240x240 + h, w = raw_img.shape[:2] + crop_size = min(h, w) + start_x = (w - crop_size) // 2 + start_y = (h - crop_size) // 2 + cropped = raw_img[start_y:start_y+crop_size, start_x:start_x+crop_size].copy() + + if cropped.shape[:2] != (240, 240): + cropped = cv2.resize(cropped, (240, 240), interpolation=cv2.INTER_AREA) + + frame_data = { + f'camera_{camera_id}': { + 'rgb': cropped, + 'hardware_timestamp': hardware_timestamp, + 'system_timestamp': system_timestamp, + 'capture_time': capture_start + } + } + + # 线程安全地存储数据,实施有界缓冲 + with self.camera_collector.locks[camera_id]: + cam_data = self.camera_collector.camera_data[camera_id] + cam_data.append(frame_data) + + # 如果超过最大帧数限制,移除最旧的帧(FIFO) + if len(cam_data) > self.camera_collector.max_frames_per_camera: + cam_data.pop(0) # 移除最旧的帧 + logger.debug(f"相机{camera_id}达到最大帧数限制,移除最旧帧") + + # 更新FPS统计 - 为每个相机独立统计 + now = time.time() + if hasattr(self, f'_last_camera_{camera_id}_time'): + last_time = getattr(self, f'_last_camera_{camera_id}_time') + if last_time > 0: + fps = 1.0 / (now - last_time) + self.perf_stats['camera_fps'].append(fps) + setattr(self, f'_last_camera_{camera_id}_time', now) + except Exception as e: - logger.debug(f"相机采集错误: {e}") + logger.debug(f"相机 {camera_id} 采集错误: {e}") - # 控制轮询频率 - time.sleep(0.005) # 5ms轮询间隔,在性能和稳定性间平衡 + # 最小延迟,让CPU有机会调度其他线程 + time.sleep(0.001) # 1ms - print("相机采集线程停止") - + print(f"相机 {camera_id} 采集线程停止") + def _robot_thread(self): """机器人数据采集线程 - 固定20Hz采集""" print("机器人采集线程启动") @@ -536,8 +586,6 @@ def _robot_thread(self): tactile_response = requests.post(self.url + "get_handtactile", timeout=1.0) tactile_data = tactile_response.json() tactile_end = time.time() - - # ly TODO: send object pose here from Foundationpose # 计算网络延迟 robot_delay = request_end - request_start @@ -569,9 +617,14 @@ def _robot_thread(self): 'network_delay_tactile': tactile_delay } - # 存储机器人数据 + # 存储机器人数据,实施有界缓冲 with self.robot_collector.lock: self.robot_collector.robot_data.append(state_data) + + # 如果超过最大帧数限制,移除最旧的帧 + if len(self.robot_collector.robot_data) > self.robot_collector.max_frames: + self.robot_collector.robot_data.pop(0) + logger.debug("机器人数据达到最大帧数限制,移除最旧帧") # 更新FPS统计 now = time.time() @@ -664,17 +717,37 @@ def get_performance_stats(self) -> Dict: 'robot_frames': len(self.robot_collector.get_collected_data()) } - def __del__(self): - """清理资源""" + def cleanup(self): + """显式清理资源 - 应该在程序退出前调用""" try: + # 停止所有采集 self.stop_collection() + + # 关闭所有相机pipeline for cam_info in self.pipelines: try: cam_info['pipeline'].stop() - except: - pass + logger.info(f"相机 {cam_info['camera_id']} pipeline已关闭") + except Exception as e: + logger.error(f"关闭相机 {cam_info['camera_id']} 失败: {e}") + + # 清空pipeline列表 + self.pipelines = [] + + logger.info("✓ 所有资源已清理") + except Exception as e: logger.error(f"清理资源时出错: 
{e}") + + def __del__(self): + """析构函数 - 仅作为备份,不应依赖此方法""" + # 尝试清理,但不应该依赖析构函数 + try: + if hasattr(self, 'pipelines') and self.pipelines: + logger.warning("警告: 依赖__del__进行资源清理,应显式调用cleanup()") + self.cleanup() + except: + pass # 在析构时忽略所有错误 def save_episode_offline_aligned(episode_path: str, camera_data: List[Dict], robot_data: List[Dict]) -> bool: @@ -805,8 +878,8 @@ def display_status_offline_aligned(display_data: Dict, episode_num: int, perf_st def main(): parser = argparse.ArgumentParser(description="XHand + Franka多模态数据采集") - parser.add_argument('--num_cameras', type=int, default=1, choices=[1, 2], - help='使用的相机数量 (默认: 1)') + parser.add_argument('--num_cameras', type=int, default=2, choices=[1, 2], + help='使用的相机数量 (默认: 2)') parser.add_argument('--data_dir', type=str, default='XhandData_Multimodal', help='数据保存目录 (默认: XhandData_Multimodal)') parser.add_argument('--episode_start', type=int, default=None, @@ -964,10 +1037,10 @@ def main(): finally: collector.stop() try: - data_collector.stop_collection() - except: - pass - del data_collector + # 使用显式的cleanup方法而不是依赖__del__ + data_collector.cleanup() + except Exception as e: + logger.error(f"清理数据采集器失败: {e}") print("\n多模态采集系统已关闭") diff --git a/real_script/data_collection/convert_pickle_to_zarr.py b/real_script/data_collection/convert_pickle_to_zarr.py new file mode 100644 index 0000000..4de9468 --- /dev/null +++ b/real_script/data_collection/convert_pickle_to_zarr.py @@ -0,0 +1,568 @@ +""" +Convert pickle format data to zarr format for DexUMI training + +This script converts collected data from pickle format to zarr format that is compatible +with DexUMI training pipeline. + +Input structure (pickle): + collected_data/ + └── episode_0/ + ├── camera_0/ + │ ├── rgb.pkl # [T, H, W, 3] + │ └── receive_time.pkl # [T] + ├── pose.pkl # [T, 7] (xyz + quaternion) + ├── hand_action.pkl # [T, 12] + ├── proprioception.pkl # [T, 14] + ├── fsr.pkl # [T, 5, 3] + └── timestamps.pkl # dict with various timestamps + +Output structure (zarr): + dataset.zarr/ + └── episode_0/ + ├── pose # [T, 6] (xyz + euler angles) + ├── hand_action # [T, 12] + ├── proprioception # [T, 14] + ├── fsr # [T, 3] (averaged across fingers) + ├── camera_0/ + │ └── rgb # [T, H, W, 3] + └── camera_1/ # (if available) + └── rgb # [T, H, W, 3] +""" + +import os +import pickle +import numpy as np +import zarr +from typing import Dict, List, Optional, Tuple +import argparse +from pathlib import Path +from tqdm import tqdm +from scipy.spatial.transform import Rotation + + +def load_timestamps(timestamp_path: Path) -> np.ndarray: + """ + Load timestamps from pickle file + + Args: + timestamp_path: Path to timestamp pickle file + + Returns: + timestamps: Array of timestamps in seconds + """ + with open(timestamp_path, "rb") as f: + timestamps = pickle.load(f) + + # Handle different timestamp formats + if isinstance(timestamps, dict): + # For main timestamps.pkl files + if 'main_timestamps' in timestamps: + return timestamps['main_timestamps'] + elif 'robot_state_timestamps' in timestamps: + return timestamps['robot_state_timestamps'] + else: + # Return first available timestamp array + for value in timestamps.values(): + if isinstance(value, np.ndarray): + return value + elif isinstance(timestamps, np.ndarray): + # For direct timestamp arrays (like receive_time.pkl) + # Convert from milliseconds to seconds if needed + if timestamps.max() > 1e10: # Likely milliseconds + return timestamps / 1000.0 + return timestamps + + raise ValueError(f"Unsupported timestamp format in {timestamp_path}") + + 
+def find_nearest_timestamps(target_timestamps: np.ndarray, source_timestamps: np.ndarray) -> np.ndarray: + """ + Find nearest timestamp indices using binary search for efficiency + + Args: + target_timestamps: Target timestamps to match to (e.g., robot timestamps) + source_timestamps: Source timestamps to match from (e.g., camera timestamps) + + Returns: + matched_indices: Indices in source_timestamps that best match target_timestamps + """ + matched_indices = np.zeros(len(target_timestamps), dtype=int) + + for i, target_ts in enumerate(target_timestamps): + # Find closest timestamp using binary search + idx = np.searchsorted(source_timestamps, target_ts) + + # Handle boundary cases + if idx == 0: + matched_indices[i] = 0 + elif idx == len(source_timestamps): + matched_indices[i] = len(source_timestamps) - 1 + else: + # Choose the closer one + if abs(source_timestamps[idx-1] - target_ts) <= abs(source_timestamps[idx] - target_ts): + matched_indices[i] = idx - 1 + else: + matched_indices[i] = idx + + return matched_indices + + +def align_multimodal_episode(episode_path: Path, camera_ids: Optional[List[int]] = None) -> Dict: + """ + Align multimodal data for a single episode using timestamp matching + + Args: + episode_path: Path to episode directory + camera_ids: List of camera IDs to include (None = all cameras) + + Returns: + Dict containing aligned episode data + """ + # Load robot timestamps (low frequency reference) + robot_timestamps = load_timestamps(episode_path / "timestamps.pkl") + + # Load robot state data + with open(episode_path / "pose.pkl", "rb") as f: + pose_data = pickle.load(f) + positions = pose_data[:, :3] + quaternions = pose_data[:, 3:7] + euler_angles = np.array([quaternion_to_euler(q) for q in quaternions]) + aligned_pose = np.concatenate([positions, euler_angles], axis=-1).astype(np.float32) + + with open(episode_path / "hand_action.pkl", "rb") as f: + aligned_hand_action = pickle.load(f).astype(np.float32) + + with open(episode_path / "proprioception.pkl", "rb") as f: + aligned_proprioception = pickle.load(f).astype(np.float32) + + with open(episode_path / "fsr.pkl", "rb") as f: + fsr_data = pickle.load(f) + aligned_fsr = np.mean(fsr_data, axis=1).astype(np.float32) + + # Verify robot data consistency + robot_length = len(aligned_pose) + for data, name in [(aligned_hand_action, "hand_action"), (aligned_proprioception, "proprioception"), (aligned_fsr, "fsr")]: + if len(data) != robot_length: + print(f"Warning: {name} length ({len(data)}) doesn't match pose length ({robot_length})") + + # Align camera data + aligned_cameras = {} + for cam_dir in episode_path.glob("camera_*"): + if cam_dir.is_dir(): + cam_id = int(cam_dir.name.split("_")[1]) + + # Filter by camera_ids if specified + if camera_ids is not None and cam_id not in camera_ids: + print(f" Skipping camera {cam_id} (not in selected camera IDs: {camera_ids})") + continue + + # Load camera timestamps and data + camera_timestamps = load_timestamps(cam_dir / "receive_time.pkl") + + with open(cam_dir / "rgb.pkl", "rb") as f: + rgb_data = pickle.load(f) + if rgb_data.dtype != np.uint8: + rgb_data = (rgb_data * 255).astype(np.uint8) + + # Find matching camera frames for each robot timestamp + matched_indices = find_nearest_timestamps(robot_timestamps, camera_timestamps) + + # Align camera data to robot frequency + aligned_rgb = rgb_data[matched_indices] + aligned_cameras[f"camera_{cam_id}"] = aligned_rgb + + # Print alignment info + max_time_diff = np.max(np.abs(camera_timestamps[matched_indices] - 
robot_timestamps)) + print(f" Camera {cam_id}: {len(rgb_data)} -> {len(aligned_rgb)} frames, max time diff: {max_time_diff:.3f}s") + + return { + "pose": aligned_pose, + "hand_action": aligned_hand_action, + "proprioception": aligned_proprioception, + "fsr": aligned_fsr, + "cameras": aligned_cameras + } + + +def quaternion_to_euler(quat: np.ndarray) -> np.ndarray: + """ + Convert quaternion (w, x, y, z) to Euler angles (roll, pitch, yaw) + + Args: + quat: Array of shape (..., 4) with quaternion in (w, x, y, z) format + + Returns: + euler: Array of shape (..., 3) with Euler angles in radians + """ + # Ensure input is numpy array + quat = np.asarray(quat) + + # Handle both single quaternion and batch + original_shape = quat.shape + if quat.ndim == 1: + quat = quat.reshape(1, -1) + + # Convert from (w, x, y, z) to (x, y, z, w) for scipy + quat_scipy = np.concatenate([quat[..., 1:], quat[..., :1]], axis=-1) + + # Create rotation object and get euler angles + r = Rotation.from_quat(quat_scipy) + euler = r.as_euler('xyz', degrees=False) + + # Restore original shape + if len(original_shape) == 1: + euler = euler.squeeze(0) + + return euler + + +def load_pickle_episode(episode_path: Path, multimodal_format: bool = False, camera_ids: Optional[List[int]] = None) -> Dict: + """ + Load a single episode from pickle files + + Args: + episode_path: Path to episode directory containing pickle files + multimodal_format: Whether to use XhandData_Multimodal format with timestamp alignment + camera_ids: List of camera IDs to include (None = all cameras) + + Returns: + Dict containing all episode data + """ + if multimodal_format: + return align_multimodal_episode(episode_path, camera_ids) + + # Original format loading + data = {} + + # Load core data files + with open(episode_path / "pose.pkl", "rb") as f: + pose_data = pickle.load(f) + # Convert quaternion (xyz + quat_wxyz) to 6DoF (xyz + euler_xyz) + positions = pose_data[:, :3] # xyz positions + quaternions = pose_data[:, 3:7] # quaternion (w, x, y, z) + euler_angles = np.array([quaternion_to_euler(q) for q in quaternions]) + data["pose"] = np.concatenate([positions, euler_angles], axis=-1).astype(np.float32) + + with open(episode_path / "hand_action.pkl", "rb") as f: + data["hand_action"] = pickle.load(f).astype(np.float32) + + with open(episode_path / "proprioception.pkl", "rb") as f: + data["proprioception"] = pickle.load(f).astype(np.float32) + + with open(episode_path / "fsr.pkl", "rb") as f: + fsr_data = pickle.load(f) # Shape: [T, 5, 3] + # Average across fingers to get [T, 3] + data["fsr"] = np.mean(fsr_data, axis=1).astype(np.float32) + + # Load camera data + cameras = {} + for cam_dir in episode_path.glob("camera_*"): + if cam_dir.is_dir(): + cam_id = int(cam_dir.name.split("_")[1]) + + # Filter by camera_ids if specified + if camera_ids is not None and cam_id not in camera_ids: + print(f" Skipping camera {cam_id} (not in selected camera IDs: {camera_ids})") + continue + + rgb_path = cam_dir / "rgb.pkl" + if rgb_path.exists(): + with open(rgb_path, "rb") as f: + rgb_data = pickle.load(f) + # Ensure RGB data is uint8 and has correct shape + if rgb_data.dtype != np.uint8: + rgb_data = (rgb_data * 255).astype(np.uint8) + cameras[f"camera_{cam_id}"] = rgb_data + + data["cameras"] = cameras + + # Verify data consistency + episode_length = len(data["pose"]) + for key in ["hand_action", "proprioception", "fsr"]: + if len(data[key]) != episode_length: + print(f"Warning: {key} length ({len(data[key])}) doesn't match pose length ({episode_length})") + + 
return data + + +def detect_data_format(input_path: Path) -> bool: + """ + Detect if the data is in XhandData_Multimodal format + + Args: + input_path: Path to input directory + + Returns: + True if multimodal format, False if original format + """ + # Check first episode for multimodal format indicators + episode_dirs = sorted([d for d in input_path.glob("episode_*") if d.is_dir()]) + if not episode_dirs: + return False + + first_episode = episode_dirs[0] + + # Check for multimodal format indicators + has_timestamps = (first_episode / "timestamps.pkl").exists() + has_camera_timestamps = any((cam_dir / "receive_time.pkl").exists() + for cam_dir in first_episode.glob("camera_*") + if cam_dir.is_dir()) + + return has_timestamps and has_camera_timestamps + + +def create_zarr_dataset( + input_dir: str, + output_path: str, + episode_ids: Optional[List[int]] = None, + compression: str = "blosc", + overwrite: bool = False, + multimodal_format: Optional[bool] = None, + camera_ids: Optional[List[int]] = None +) -> None: + """ + Convert pickle episodes to zarr format + + Args: + input_dir: Directory containing pickle episodes + output_path: Path to output zarr file + episode_ids: List of episode IDs to convert (None = all) + compression: Compression algorithm for zarr + overwrite: Whether to overwrite existing zarr file + multimodal_format: Whether to use XhandData_Multimodal format (None = auto-detect) + camera_ids: List of camera IDs to include (None = all cameras) + """ + input_path = Path(input_dir) + + # Auto-detect data format if not specified + if multimodal_format is None: + multimodal_format = detect_data_format(input_path) + format_type = "XhandData_Multimodal" if multimodal_format else "Original" + print(f"Auto-detected data format: {format_type}") + + # Find all episodes + if episode_ids is None: + episode_dirs = sorted([d for d in input_path.glob("episode_*") if d.is_dir()]) + episode_ids = [int(d.name.split("_")[1]) for d in episode_dirs] + else: + episode_dirs = [input_path / f"episode_{i}" for i in episode_ids] + + if not episode_dirs: + print(f"No episodes found in {input_dir}") + return + + print(f"Found {len(episode_dirs)} episodes to convert") + if multimodal_format: + print("Using multimodal format with timestamp alignment") + if camera_ids is not None: + print(f"Filtering cameras: only including camera IDs {camera_ids}") + + # Create or open zarr file + if overwrite and os.path.exists(output_path): + import shutil + shutil.rmtree(output_path) + + store = zarr.DirectoryStore(output_path) + root = zarr.group(store=store, overwrite=overwrite) + + # Process each episode + for episode_dir in tqdm(episode_dirs, desc="Converting episodes"): + episode_name = episode_dir.name + print(f"\nProcessing {episode_name}...") + + try: + # Load episode data + episode_data = load_pickle_episode(episode_dir, multimodal_format=multimodal_format, camera_ids=camera_ids) + + # Create episode group in zarr + episode_group = root.create_group(episode_name, overwrite=True) + + # Save core data + # Pose: [T, 6] (xyz + euler angles) + episode_group.create_dataset( + "pose", + data=episode_data["pose"], + chunks=(100, 6), + dtype=np.float32, + compressor=zarr.Blosc(cname=compression, clevel=5, shuffle=1) + ) + + # Hand action: [T, 12] + episode_group.create_dataset( + "hand_action", + data=episode_data["hand_action"], + chunks=(100, 12), + dtype=np.float32, + compressor=zarr.Blosc(cname=compression, clevel=5, shuffle=1) + ) + + # Proprioception: [T, 14] + episode_group.create_dataset( + "proprioception", + 
data=episode_data["proprioception"], + chunks=(100, 14), + dtype=np.float32, + compressor=zarr.Blosc(cname=compression, clevel=5, shuffle=1) + ) + + # FSR: [T, 3] + episode_group.create_dataset( + "fsr", + data=episode_data["fsr"], + chunks=(100, 3), + dtype=np.float32, + compressor=zarr.Blosc(cname=compression, clevel=5, shuffle=1) + ) + + # Save camera data + for cam_name, rgb_data in episode_data["cameras"].items(): + cam_group = episode_group.create_group(cam_name) + + # RGB data: [T, H, W, C] + cam_group.create_dataset( + "rgb", + data=rgb_data, + chunks=(10, rgb_data.shape[1], rgb_data.shape[2], 3), + dtype=np.uint8, + compressor=zarr.Blosc(cname=compression, clevel=5, shuffle=1) + ) + + # Print episode info + print(f" - Pose shape: {episode_data['pose'].shape}") + print(f" - Hand action shape: {episode_data['hand_action'].shape}") + print(f" - Proprioception shape: {episode_data['proprioception'].shape}") + print(f" - FSR shape: {episode_data['fsr'].shape}") + for cam_name, rgb_data in episode_data["cameras"].items(): + print(f" - {cam_name} RGB shape: {rgb_data.shape}") + + except Exception as e: + print(f"Error processing {episode_name}: {e}") + continue + + print(f"\n✓ Conversion complete! Zarr dataset saved to: {output_path}") + + # Print dataset summary + print("\nDataset summary:") + total_frames = 0 + for episode_name in root.group_keys(): + episode = root[episode_name] + frames = len(episode["pose"]) + total_frames += frames + print(f" - {episode_name}: {frames} frames") + print(f" Total: {len(list(root.group_keys()))} episodes, {total_frames} frames") + + +def verify_zarr_dataset(zarr_path: str, num_samples: int = 3) -> None: + """ + Verify the converted zarr dataset + + Args: + zarr_path: Path to zarr dataset + num_samples: Number of sample frames to check + """ + print(f"\nVerifying zarr dataset: {zarr_path}") + + root = zarr.open(zarr_path, mode='r') + episodes = list(root.group_keys()) + + print(f"Found {len(episodes)} episodes") + + for episode_name in episodes[:num_samples]: + print(f"\n{episode_name}:") + episode = root[episode_name] + + # Check all expected keys + expected_keys = ["pose", "hand_action", "proprioception", "fsr"] + for key in expected_keys: + if key in episode: + data = episode[key] + print(f" - {key}: shape={data.shape}, dtype={data.dtype}") + # Print sample values + if len(data) > 0: + print(f" Sample: {data[0][:5]}...") + else: + print(f" - {key}: MISSING") + + # Check cameras + for key in episode.group_keys(): + if key.startswith("camera_"): + cam_group = episode[key] + if "rgb" in cam_group: + rgb_data = cam_group["rgb"] + print(f" - {key}/rgb: shape={rgb_data.shape}, dtype={rgb_data.dtype}") + # Check value range + if len(rgb_data) > 0: + print(f" Value range: [{rgb_data[0].min()}, {rgb_data[0].max()}]") + + +def main(): + parser = argparse.ArgumentParser(description="Convert pickle data to zarr format for DexUMI") + parser.add_argument( + "--input_dir", + type=str, + default="collected_data", + help="Input directory containing pickle episodes" + ) + parser.add_argument( + "--output_path", + type=str, + default="dataset.zarr", + help="Output zarr file path" + ) + parser.add_argument( + "--episodes", + type=int, + nargs="+", + default=None, + help="Specific episode IDs to convert (default: all)" + ) + parser.add_argument( + "--compression", + type=str, + default="blosclz", + choices=["blosclz", "zstd", "lz4"], + help="Compression algorithm" + ) + parser.add_argument( + "--overwrite", + action="store_true", + help="Overwrite existing zarr file" + 
) + parser.add_argument( + "--verify", + action="store_true", + help="Verify the converted dataset" + ) + parser.add_argument( + "--multimodal_format", + action="store_true", + help="Force XhandData_Multimodal format with timestamp alignment (default: auto-detect)" + ) + parser.add_argument( + "--camera_ids", + type=int, + nargs="+", + default=None, + help="Specific camera IDs to include (e.g., --camera_ids 0 1). Default: all cameras" + ) + + args = parser.parse_args() + + # Convert data + create_zarr_dataset( + input_dir=args.input_dir, + output_path=args.output_path, + episode_ids=args.episodes, + compression=args.compression, + overwrite=args.overwrite, + multimodal_format=args.multimodal_format if args.multimodal_format else None, + camera_ids=args.camera_ids + ) + + # Verify if requested + if args.verify: + verify_zarr_dataset(args.output_path) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/real_script/data_collection/data_check/comprehensive_data_quality_check.py b/real_script/data_collection/data_check/comprehensive_data_quality_check.py new file mode 100644 index 0000000..daffdc5 --- /dev/null +++ b/real_script/data_collection/data_check/comprehensive_data_quality_check.py @@ -0,0 +1,767 @@ +#!/usr/bin/env python3 +""" +综合数据质量检查脚本 - 检查XHand多模态数据的质量问题 + +功能: +1. 轨迹长度分布检查 - 识别异常短或长的轨迹 +2. 数据分布分析 - TCP位置、关节角度、触觉数据的统计分析 +3. 异常轨迹检测 - 静止轨迹、异常跳跃、图像质量问题 +4. 数据完整性验证 - 文件存在性、帧数匹配 +5. 可视化报告 - 生成统计图表和异常列表 + +作者: Claude +日期: 2024-09-09 +""" + +import pickle +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns +from pathlib import Path +import os +import sys +import argparse +from typing import Dict, List, Tuple, Optional +import warnings +from collections import defaultdict +# import cv2 # 暂时不需要cv2 + +# 设置字体 - 如果没有中文字体就使用默认字体 +try: + plt.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans', 'Arial'] + plt.rcParams['axes.unicode_minus'] = False +except: + pass + +warnings.filterwarnings('ignore') + +class DataQualityChecker: + """数据质量检查器""" + + def __init__(self, data_dir: str, min_trajectory_length: int = 50, + max_trajectory_length: int = 2000, workspace_bounds: Dict = None): + """ + 初始化数据质量检查器 + + Args: + data_dir: 数据目录路径 + min_trajectory_length: 最小轨迹长度阈值 + max_trajectory_length: 最大轨迹长度阈值 + workspace_bounds: 工作空间边界 {'x': [min, max], 'y': [min, max], 'z': [min, max]} + """ + self.data_dir = Path(data_dir) + self.min_traj_len = min_trajectory_length + self.max_traj_len = max_trajectory_length + + # 默认工作空间边界 (基于Franka机械臂的典型工作空间) + self.workspace_bounds = workspace_bounds or { + 'x': [0.2, 0.8], # 机械臂前方0.2-0.8m + 'y': [-0.4, 0.4], # 左右±0.4m + 'z': [0.0, 0.6] # 高度0-0.6m + } + + # 存储检查结果 + self.results = { + 'episodes': [], + 'length_issues': [], + 'distribution_issues': [], + 'anomaly_issues': [], + 'completeness_issues': [], + 'statistics': {} + } + + def check_all_episodes(self) -> Dict: + """检查所有episode的数据质量""" + print(f"🔍 开始检查数据目录: {self.data_dir}") + print("=" * 80) + + # 获取所有episode目录 + episode_dirs = sorted([d for d in self.data_dir.glob("episode_*") if d.is_dir()]) + + if not episode_dirs: + print("❌ 未找到任何episode目录!") + return self.results + + print(f"📊 找到 {len(episode_dirs)} 个episodes,开始质量检查...") + + # 检查每个episode + for i, episode_dir in enumerate(episode_dirs): + print(f"\n[{i+1}/{len(episode_dirs)}] 检查 {episode_dir.name}...") + episode_result = self._check_single_episode(episode_dir) + self.results['episodes'].append(episode_result) + + # 计算整体统计 + self._compute_overall_statistics() + + # 生成报告 + self._generate_report() + + 
return self.results + + def _check_single_episode(self, episode_dir: Path) -> Dict: + """检查单个episode的数据质量""" + result = { + 'name': episode_dir.name, + 'path': str(episode_dir), + 'length_check': {}, + 'distribution_check': {}, + 'anomaly_check': {}, + 'completeness_check': {}, + 'overall_quality': 'unknown' + } + + try: + # 1. 完整性检查 + completeness = self._check_completeness(episode_dir) + result['completeness_check'] = completeness + + if not completeness['all_files_exist']: + result['overall_quality'] = 'bad' + self.results['completeness_issues'].append(result) + return result + + # 2. 长度检查 + length_check = self._check_trajectory_length(episode_dir) + result['length_check'] = length_check + + # 3. 分布检查 + distribution_check = self._check_data_distribution(episode_dir) + result['distribution_check'] = distribution_check + + # 4. 异常检测 + anomaly_check = self._check_anomalies(episode_dir) + result['anomaly_check'] = anomaly_check + + # 5. 综合质量评估 + result['overall_quality'] = self._assess_overall_quality( + length_check, distribution_check, anomaly_check + ) + + # 记录问题episode + if result['overall_quality'] == 'bad': + if length_check.get('is_too_short') or length_check.get('is_too_long'): + self.results['length_issues'].append(result) + if distribution_check.get('has_issues'): + self.results['distribution_issues'].append(result) + if anomaly_check.get('has_anomalies'): + self.results['anomaly_issues'].append(result) + + except Exception as e: + print(f" ❌ 检查时出错: {e}") + result['error'] = str(e) + result['overall_quality'] = 'error' + + return result + + def _check_completeness(self, episode_dir: Path) -> Dict: + """检查数据完整性""" + required_files = [ + 'pose.pkl', 'hand_action.pkl', 'proprioception.pkl', 'fsr.pkl', 'timestamps.pkl' + ] + + missing_files = [] + corrupted_files = [] + file_sizes = {} + + # 检查核心数据文件 + for filename in required_files: + filepath = episode_dir / filename + if not filepath.exists(): + missing_files.append(filename) + else: + try: + with open(filepath, 'rb') as f: + data = pickle.load(f) + file_sizes[filename] = len(data) if hasattr(data, '__len__') else 'N/A' + except Exception as e: + corrupted_files.append(f"{filename}: {str(e)}") + + # 检查相机数据 + camera_dirs = list(episode_dir.glob("camera_*")) + camera_status = {} + + for cam_dir in camera_dirs: + cam_name = cam_dir.name + rgb_file = cam_dir / "rgb.pkl" + + if rgb_file.exists(): + try: + with open(rgb_file, 'rb') as f: + rgb_data = pickle.load(f) + camera_status[cam_name] = { + 'frames': len(rgb_data), + 'shape': rgb_data[0].shape if len(rgb_data) > 0 else None, + 'size_mb': rgb_data.nbytes / (1024 * 1024) if hasattr(rgb_data, 'nbytes') else 'N/A' + } + except Exception as e: + camera_status[cam_name] = {'error': str(e)} + else: + camera_status[cam_name] = {'error': 'rgb.pkl not found'} + + all_files_exist = len(missing_files) == 0 and len(corrupted_files) == 0 + + return { + 'all_files_exist': all_files_exist, + 'missing_files': missing_files, + 'corrupted_files': corrupted_files, + 'file_sizes': file_sizes, + 'camera_status': camera_status, + 'camera_count': len(camera_dirs) + } + + def _check_trajectory_length(self, episode_dir: Path) -> Dict: + """检查轨迹长度""" + try: + with open(episode_dir / 'pose.pkl', 'rb') as f: + pose_data = pickle.load(f) + + length = len(pose_data) + is_too_short = length < self.min_traj_len + is_too_long = length > self.max_traj_len + + return { + 'length': length, + 'is_too_short': is_too_short, + 'is_too_long': is_too_long, + 'is_normal': not (is_too_short or is_too_long) + } + except 
Exception as e: + return {'error': str(e)} + + def _check_data_distribution(self, episode_dir: Path) -> Dict: + """检查数据分布是否合理""" + issues = [] + statistics = {} + + try: + # 检查TCP位置 + with open(episode_dir / 'pose.pkl', 'rb') as f: + pose_data = pickle.load(f) + + positions = pose_data[:, :3] # x, y, z + statistics['tcp_position'] = { + 'min': positions.min(axis=0), + 'max': positions.max(axis=0), + 'mean': positions.mean(axis=0), + 'std': positions.std(axis=0) + } + + # 检查是否超出工作空间 + for i, axis in enumerate(['x', 'y', 'z']): + min_val, max_val = positions[:, i].min(), positions[:, i].max() + workspace_min, workspace_max = self.workspace_bounds[axis] + + if min_val < workspace_min or max_val > workspace_max: + issues.append(f"TCP {axis}轴超出工作空间: [{min_val:.3f}, {max_val:.3f}] vs [{workspace_min}, {workspace_max}]") + + # 检查关节角度 + with open(episode_dir / 'proprioception.pkl', 'rb') as f: + proprioception_data = pickle.load(f) + + joint_positions = proprioception_data[:, :7] # 前7个是关节位置 + joint_velocities = proprioception_data[:, 7:14] # 后7个是关节速度 + + statistics['joint_positions'] = { + 'min': joint_positions.min(axis=0), + 'max': joint_positions.max(axis=0), + 'range': joint_positions.max(axis=0) - joint_positions.min(axis=0) + } + + # 检查关节限位 (Franka的典型关节限位) + joint_limits = [ + (-2.9, 2.9), (-1.8, 1.8), (-2.9, 2.9), (-3.1, 0.0), + (-2.9, 2.9), (-0.0, 3.8), (-2.9, 2.9) + ] + + for i, (min_limit, max_limit) in enumerate(joint_limits): + joint_range = joint_positions[:, i] + if joint_range.min() < min_limit or joint_range.max() > max_limit: + issues.append(f"关节{i+1}超出限位: [{joint_range.min():.3f}, {joint_range.max():.3f}] vs [{min_limit}, {max_limit}]") + + # 检查触觉数据 + with open(episode_dir / 'fsr.pkl', 'rb') as f: + fsr_data = pickle.load(f) + + # 处理不同的FSR数据形状 + if fsr_data.ndim == 3: # (frames, sensors, values) + # 展平最后两个维度进行统计 + fsr_flat = fsr_data.reshape(fsr_data.shape[0], -1) + else: # (frames, values) + fsr_flat = fsr_data + + statistics['fsr'] = { + 'shape': fsr_data.shape, + 'min': fsr_flat.min(axis=0), + 'max': fsr_flat.max(axis=0), + 'mean': fsr_flat.mean(axis=0) + } + + # 检查FSR数据合理性 (允许小幅负值,可能是传感器偏移) + negative_ratio = (fsr_flat < 0).sum() / fsr_flat.size + extreme_negative = (fsr_flat < -10).any() # 检查是否有极端负值 + + if negative_ratio > 0.5: # 超过50%是负值才报告问题 + issues.append(f"FSR触觉数据负值比例过高: {negative_ratio*100:.1f}%") + elif extreme_negative: + issues.append("FSR触觉数据存在极端负值 (<-10)") + + return { + 'has_issues': len(issues) > 0, + 'issues': issues, + 'statistics': statistics + } + + except Exception as e: + return {'error': str(e)} + + def _check_anomalies(self, episode_dir: Path) -> Dict: + """检查异常模式""" + anomalies = [] + + try: + # 检查静止轨迹 + with open(episode_dir / 'pose.pkl', 'rb') as f: + pose_data = pickle.load(f) + + positions = pose_data[:, :3] + position_changes = np.diff(positions, axis=0) + movement_magnitude = np.linalg.norm(position_changes, axis=1) + + # 如果90%以上的时间移动幅度小于1mm,认为是静止轨迹 + static_threshold = 0.001 # 1mm + static_ratio = np.sum(movement_magnitude < static_threshold) / len(movement_magnitude) + + if static_ratio > 0.9: + anomalies.append(f"疑似静止轨迹: {static_ratio*100:.1f}%的时间移动<1mm") + + # 检查异常跳跃 + max_movement = movement_magnitude.max() + mean_movement = movement_magnitude.mean() + + if max_movement > mean_movement * 10: # 如果最大移动超过平均移动的10倍 + anomalies.append(f"检测到异常跳跃: 最大移动{max_movement*1000:.1f}mm, 平均{mean_movement*1000:.1f}mm") + + # 检查图像质量(如果有相机数据) + camera_dirs = list(episode_dir.glob("camera_*")) + for cam_dir in camera_dirs: + rgb_file = cam_dir / "rgb.pkl" + if 
rgb_file.exists(): + try: + with open(rgb_file, 'rb') as f: + rgb_data = pickle.load(f) + + if len(rgb_data) > 0: + # 检查前几帧图像 + for i in range(min(5, len(rgb_data))): + img = rgb_data[i] + + # 检查全黑图像 + if img.max() < 10: + anomalies.append(f"{cam_dir.name}: 检测到全黑图像 (帧{i})") + break + + # 检查全白图像 + if img.min() > 245: + anomalies.append(f"{cam_dir.name}: 检测到全白图像 (帧{i})") + break + + # 检查图像标准差过低 (可能表示图像质量问题) + if img.std() < 5: + anomalies.append(f"{cam_dir.name}: 图像对比度过低 (帧{i}, std={img.std():.1f})") + break + + except Exception as e: + anomalies.append(f"{cam_dir.name}: 图像数据读取错误 - {str(e)}") + + return { + 'has_anomalies': len(anomalies) > 0, + 'anomalies': anomalies, + 'movement_stats': { + 'max_movement_mm': max_movement * 1000, + 'mean_movement_mm': mean_movement * 1000, + 'static_ratio': static_ratio + } + } + + except Exception as e: + return {'error': str(e)} + + def _assess_overall_quality(self, length_check: Dict, distribution_check: Dict, anomaly_check: Dict) -> str: + """评估整体数据质量""" + issues = 0 + + # 长度问题 + if length_check.get('is_too_short') or length_check.get('is_too_long'): + issues += 2 # 长度问题权重较高 + + # 分布问题 + if distribution_check.get('has_issues'): + issues += 1 + + # 异常检测 + if anomaly_check.get('has_anomalies'): + anomaly_count = len(anomaly_check.get('anomalies', [])) + if anomaly_count >= 3: + issues += 2 + elif anomaly_count >= 1: + issues += 1 + + # 质量评级 + if issues == 0: + return 'good' + elif issues <= 2: + return 'warning' + else: + return 'bad' + + def _compute_overall_statistics(self): + """计算整体统计信息""" + if not self.results['episodes']: + return + + # 统计质量分布 + quality_counts = defaultdict(int) + lengths = [] + + for episode in self.results['episodes']: + quality = episode.get('overall_quality', 'unknown') + quality_counts[quality] += 1 + + length_info = episode.get('length_check', {}) + if 'length' in length_info: + lengths.append(length_info['length']) + + # 长度统计 + if lengths: + lengths = np.array(lengths) + length_stats = { + 'count': len(lengths), + 'min': int(lengths.min()), + 'max': int(lengths.max()), + 'mean': float(lengths.mean()), + 'median': float(np.median(lengths)), + 'std': float(lengths.std()), + 'q25': float(np.percentile(lengths, 25)), + 'q75': float(np.percentile(lengths, 75)) + } + else: + length_stats = {} + + self.results['statistics'] = { + 'total_episodes': len(self.results['episodes']), + 'quality_distribution': dict(quality_counts), + 'length_statistics': length_stats, + 'issue_summary': { + 'length_issues': len(self.results['length_issues']), + 'distribution_issues': len(self.results['distribution_issues']), + 'anomaly_issues': len(self.results['anomaly_issues']), + 'completeness_issues': len(self.results['completeness_issues']) + } + } + + def _generate_report(self): + """生成检查报告""" + stats = self.results['statistics'] + + print("\n" + "=" * 80) + print("📊 数据质量检查报告") + print("=" * 80) + + # 总体统计 + print(f"\n📈 总体统计:") + print(f" 总episode数: {stats['total_episodes']}") + + if stats['length_statistics']: + ls = stats['length_statistics'] + print(f" 轨迹长度: 平均 {ls['mean']:.1f} 帧 (范围: {ls['min']}-{ls['max']})") + print(f" 长度分布: Q25={ls['q25']:.0f}, 中位数={ls['median']:.0f}, Q75={ls['q75']:.0f}") + + # 质量分布 + print(f"\n🎯 质量分布:") + quality_dist = stats['quality_distribution'] + for quality, count in quality_dist.items(): + percentage = (count / stats['total_episodes']) * 100 + emoji = {'good': '✅', 'warning': '⚠️', 'bad': '❌', 'error': '💥', 'unknown': '❓'}.get(quality, '❓') + print(f" {emoji} {quality.capitalize()}: {count} episodes 
({percentage:.1f}%)") + + # 问题总结 + print(f"\n⚠️ 问题总结:") + issue_summary = stats['issue_summary'] + for issue_type, count in issue_summary.items(): + if count > 0: + print(f" - {issue_type.replace('_', ' ').title()}: {count} episodes") + + # 详细问题列表 + if self.results['length_issues']: + print(f"\n📏 长度异常的episodes:") + for episode in self.results['length_issues']: + length_info = episode.get('length_check', {}) + length = length_info.get('length', 'N/A') + if length_info.get('is_too_short'): + print(f" ❌ {episode['name']}: 过短 ({length} < {self.min_traj_len})") + elif length_info.get('is_too_long'): + print(f" ❌ {episode['name']}: 过长 ({length} > {self.max_traj_len})") + + if self.results['anomaly_issues']: + print(f"\n🚨 异常检测结果:") + for episode in self.results['anomaly_issues']: + anomaly_info = episode.get('anomaly_check', {}) + if 'anomalies' in anomaly_info: + print(f" ⚠️ {episode['name']}:") + for anomaly in anomaly_info['anomalies']: + print(f" - {anomaly}") + + if self.results['completeness_issues']: + print(f"\n📋 完整性问题:") + for episode in self.results['completeness_issues']: + comp_info = episode.get('completeness_check', {}) + print(f" ❌ {episode['name']}:") + if comp_info.get('missing_files'): + print(f" - 缺失文件: {', '.join(comp_info['missing_files'])}") + if comp_info.get('corrupted_files'): + print(f" - 损坏文件: {', '.join(comp_info['corrupted_files'])}") + + def visualize_statistics(self, save_plots: bool = True): + """生成可视化统计图表""" + if not self.results['episodes']: + print("❌ 没有数据可以可视化") + return + + # 设置图表样式 + plt.style.use('default') + fig, axes = plt.subplots(2, 2, figsize=(15, 12)) + fig.suptitle('数据质量统计报告', fontsize=16, fontweight='bold') + + # 1. 轨迹长度分布 + lengths = [] + quality_labels = [] + + for episode in self.results['episodes']: + length_info = episode.get('length_check', {}) + if 'length' in length_info: + lengths.append(length_info['length']) + quality_labels.append(episode.get('overall_quality', 'unknown')) + + if lengths: + ax1 = axes[0, 0] + ax1.hist(lengths, bins=20, alpha=0.7, color='skyblue', edgecolor='black') + ax1.axvline(self.min_traj_len, color='red', linestyle='--', label=f'最小长度阈值 ({self.min_traj_len})') + ax1.axvline(self.max_traj_len, color='red', linestyle='--', label=f'最大长度阈值 ({self.max_traj_len})') + ax1.set_xlabel('轨迹长度 (帧数)') + ax1.set_ylabel('Episode数量') + ax1.set_title('轨迹长度分布') + ax1.legend() + ax1.grid(True, alpha=0.3) + + # 2. 质量分布饼图 + ax2 = axes[0, 1] + quality_counts = defaultdict(int) + for episode in self.results['episodes']: + quality = episode.get('overall_quality', 'unknown') + quality_counts[quality] += 1 + + if quality_counts: + colors = {'good': 'lightgreen', 'warning': 'orange', 'bad': 'lightcoral', 'error': 'red', 'unknown': 'gray'} + quality_names = list(quality_counts.keys()) + quality_values = list(quality_counts.values()) + quality_colors = [colors.get(q, 'gray') for q in quality_names] + + ax2.pie(quality_values, labels=quality_names, colors=quality_colors, autopct='%1.1f%%') + ax2.set_title('数据质量分布') + + # 3. 
问题类型统计 + ax3 = axes[1, 0] + issue_types = ['长度问题', '分布问题', '异常检测', '完整性问题'] + issue_counts = [ + len(self.results['length_issues']), + len(self.results['distribution_issues']), + len(self.results['anomaly_issues']), + len(self.results['completeness_issues']) + ] + + bars = ax3.bar(issue_types, issue_counts, color=['red', 'orange', 'yellow', 'purple'], alpha=0.7) + ax3.set_ylabel('Episode数量') + ax3.set_title('问题类型统计') + ax3.set_xticklabels(issue_types, rotation=45) + + # 在柱状图上添加数值 + for bar, count in zip(bars, issue_counts): + if count > 0: + ax3.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.1, + str(count), ha='center', va='bottom') + + # 4. 轨迹长度箱线图(按质量分类) + ax4 = axes[1, 1] + if lengths and quality_labels: + # 按质量分组长度数据 + quality_lengths = defaultdict(list) + for length, quality in zip(lengths, quality_labels): + quality_lengths[quality].append(length) + + if quality_lengths: + qualities = list(quality_lengths.keys()) + length_groups = [quality_lengths[q] for q in qualities] + + box_plot = ax4.boxplot(length_groups, labels=qualities, patch_artist=True) + + # 设置颜色 + quality_colors_box = {'good': 'lightgreen', 'warning': 'orange', 'bad': 'lightcoral', 'error': 'red', 'unknown': 'gray'} + for patch, quality in zip(box_plot['boxes'], qualities): + patch.set_facecolor(quality_colors_box.get(quality, 'gray')) + + ax4.set_ylabel('轨迹长度 (帧数)') + ax4.set_title('不同质量级别的轨迹长度分布') + ax4.grid(True, alpha=0.3) + + plt.tight_layout() + + if save_plots: + plot_path = self.data_dir / 'data_quality_report.png' + plt.savefig(plot_path, dpi=300, bbox_inches='tight') + print(f"📊 统计图表已保存到: {plot_path}") + + plt.show() + + def export_detailed_report(self, output_file: str = None): + """导出详细的检查报告到文件""" + if output_file is None: + output_file = self.data_dir / 'data_quality_detailed_report.txt' + + with open(output_file, 'w', encoding='utf-8') as f: + f.write("XHand多模态数据质量检查详细报告\n") + f.write("=" * 80 + "\n\n") + + # 检查参数 + f.write("检查参数:\n") + f.write(f" 数据目录: {self.data_dir}\n") + f.write(f" 最小轨迹长度: {self.min_traj_len}\n") + f.write(f" 最大轨迹长度: {self.max_traj_len}\n") + f.write(f" 工作空间边界: {self.workspace_bounds}\n\n") + + # 总体统计 + stats = self.results['statistics'] + f.write("总体统计:\n") + f.write(f" 总episode数: {stats['total_episodes']}\n") + + if stats['length_statistics']: + ls = stats['length_statistics'] + f.write(f" 轨迹长度统计:\n") + f.write(f" - 最小: {ls['min']} 帧\n") + f.write(f" - 最大: {ls['max']} 帧\n") + f.write(f" - 平均: {ls['mean']:.1f} 帧\n") + f.write(f" - 中位数: {ls['median']:.1f} 帧\n") + f.write(f" - 标准差: {ls['std']:.1f}\n\n") + + # 每个episode的详细信息 + f.write("详细检查结果:\n") + f.write("-" * 80 + "\n") + + for episode in self.results['episodes']: + f.write(f"\nEpisode: {episode['name']}\n") + f.write(f"质量评级: {episode['overall_quality']}\n") + + # 长度检查 + length_check = episode.get('length_check', {}) + if 'length' in length_check: + f.write(f"轨迹长度: {length_check['length']} 帧\n") + if length_check.get('is_too_short'): + f.write(" ⚠️ 轨迹过短\n") + elif length_check.get('is_too_long'): + f.write(" ⚠️ 轨迹过长\n") + + # 异常检测 + anomaly_check = episode.get('anomaly_check', {}) + if anomaly_check.get('has_anomalies'): + f.write("异常检测:\n") + for anomaly in anomaly_check.get('anomalies', []): + f.write(f" - {anomaly}\n") + + # 分布检查 + distribution_check = episode.get('distribution_check', {}) + if distribution_check.get('has_issues'): + f.write("分布问题:\n") + for issue in distribution_check.get('issues', []): + f.write(f" - {issue}\n") + + # 完整性检查 + completeness_check = episode.get('completeness_check', {}) + if not 
completeness_check.get('all_files_exist', True): + f.write("完整性问题:\n") + if completeness_check.get('missing_files'): + f.write(f" 缺失文件: {', '.join(completeness_check['missing_files'])}\n") + if completeness_check.get('corrupted_files'): + f.write(f" 损坏文件: {', '.join(completeness_check['corrupted_files'])}\n") + + f.write("-" * 40 + "\n") + + print(f"📝 详细报告已导出到: {output_file}") + + +def main(): + """主函数""" + parser = argparse.ArgumentParser(description="XHand多模态数据质量检查工具") + parser.add_argument('data_dir', type=str, help='数据目录路径') + parser.add_argument('--min-length', type=int, default=50, help='最小轨迹长度阈值 (默认: 50)') + parser.add_argument('--max-length', type=int, default=2000, help='最大轨迹长度阈值 (默认: 2000)') + parser.add_argument('--no-plots', action='store_true', help='不生成可视化图表') + parser.add_argument('--no-export', action='store_true', help='不导出详细报告') + parser.add_argument('--workspace-x', nargs=2, type=float, default=[0.2, 0.8], + help='X轴工作空间边界 (默认: 0.2 0.8)') + parser.add_argument('--workspace-y', nargs=2, type=float, default=[-0.4, 0.4], + help='Y轴工作空间边界 (默认: -0.4 0.4)') + parser.add_argument('--workspace-z', nargs=2, type=float, default=[0.0, 0.6], + help='Z轴工作空间边界 (默认: 0.0 0.6)') + + args = parser.parse_args() + + # 检查数据目录是否存在 + if not os.path.exists(args.data_dir): + print(f"❌ 数据目录不存在: {args.data_dir}") + sys.exit(1) + + # 设置工作空间边界 + workspace_bounds = { + 'x': args.workspace_x, + 'y': args.workspace_y, + 'z': args.workspace_z + } + + # 创建检查器并运行检查 + checker = DataQualityChecker( + data_dir=args.data_dir, + min_trajectory_length=args.min_length, + max_trajectory_length=args.max_length, + workspace_bounds=workspace_bounds + ) + + # 执行检查 + results = checker.check_all_episodes() + + # 生成可视化报告 + if not args.no_plots: + try: + checker.visualize_statistics(save_plots=True) + except Exception as e: + print(f"⚠️ 生成可视化图表时出错: {e}") + + # 导出详细报告 + if not args.no_export: + try: + checker.export_detailed_report() + except Exception as e: + print(f"⚠️ 导出详细报告时出错: {e}") + + # 返回状态码 + stats = results['statistics'] + total_issues = sum(stats['issue_summary'].values()) + + if total_issues == 0: + print("\n✅ 所有数据质量检查通过!") + sys.exit(0) + else: + print(f"\n⚠️ 发现 {total_issues} 个问题,请查看详细报告") + sys.exit(1) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/real_script/data_collection/data_check/doc/CONVERT_USAGE.md b/real_script/data_collection/data_check/doc/CONVERT_USAGE.md new file mode 100644 index 0000000..d10b05d --- /dev/null +++ b/real_script/data_collection/data_check/doc/CONVERT_USAGE.md @@ -0,0 +1,101 @@ +# Pickle to Zarr Data Conversion Script + +## Overview +This script converts collected data from pickle format to zarr format that is compatible with DexUMI training pipeline. 
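+
+Besides the command-line interface shown below, the converter can also be driven from Python. This is a sketch using the functions defined in `convert_pickle_to_zarr.py`; the paths are placeholders and the script's directory is assumed to be importable:
+
+```python
+from convert_pickle_to_zarr import create_zarr_dataset, verify_zarr_dataset
+
+create_zarr_dataset(
+    input_dir="XhandData_Multimodal",  # directory of episode_* folders (placeholder)
+    output_path="dataset.zarr",
+    episode_ids=None,                  # None = convert every episode found
+    compression="zstd",
+    overwrite=True,
+    multimodal_format=None,            # None = auto-detect the data format
+    camera_ids=[0],                    # keep only camera_0
+)
+verify_zarr_dataset("dataset.zarr")
+```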
+ +## Usage + +### Basic conversion +```bash +python convert_pickle_to_zarr.py --input_dir collected_data --output_path dataset.zarr +``` + +### With verification +```bash +python convert_pickle_to_zarr.py --input_dir collected_data --output_path dataset.zarr --verify +``` + +### Convert specific episodes +```bash +python convert_pickle_to_zarr.py --input_dir collected_data --output_path dataset.zarr --episodes 0 1 2 +``` + +### Overwrite existing dataset +```bash +python convert_pickle_to_zarr.py --input_dir collected_data --output_path dataset.zarr --overwrite +``` + +### Use different compression +```bash +python convert_pickle_to_zarr.py --input_dir collected_data --output_path dataset.zarr --compression zstd +``` + +## Arguments + +- `--input_dir`: Directory containing pickle episodes (default: `collected_data`) +- `--output_path`: Output zarr file path (default: `dataset.zarr`) +- `--episodes`: Specific episode IDs to convert (optional, default: all) +- `--compression`: Compression algorithm (choices: blosclz, zstd, lz4, default: blosclz) +- `--overwrite`: Overwrite existing zarr file +- `--verify`: Verify the converted dataset after conversion + +## Data Format Conversion + +### Input (Pickle) +``` +collected_data/ +└── episode_0/ + ├── camera_0/ + │ ├── rgb.pkl # [T, H, W, 3] + │ └── receive_time.pkl # [T] + ├── pose.pkl # [T, 7] (xyz + quaternion wxyz) + ├── hand_action.pkl # [T, 12] + ├── proprioception.pkl # [T, 14] + ├── fsr.pkl # [T, 5, 3] + └── timestamps.pkl # dict with various timestamps +``` + +### Output (Zarr) +``` +dataset.zarr/ +└── episode_0/ + ├── pose # [T, 6] (xyz + euler angles xyz) + ├── hand_action # [T, 12] + ├── proprioception # [T, 14] + ├── fsr # [T, 3] (averaged across fingers) + └── camera_0/ + └── rgb # [T, H, W, 3] +``` + +## Key Transformations + +1. **Pose**: Quaternion (w,x,y,z) → Euler angles (roll, pitch, yaw) +2. **FSR**: Averaged across 5 fingers: [T, 5, 3] → [T, 3] +3. **Camera**: RGB images resized to 256x256 if needed +4. **Data types**: + - Pose, hand_action, proprioception, fsr: float32 + - RGB images: uint8 + +## Integration with DexUMI + +The converted zarr dataset can be directly used with DexUMI training: + +```python +from dexumi.diffusion_policy.dataloader.replay_buffer import DexUMIReplayBuffer + +# Load the converted dataset +replay_buffer = DexUMIReplayBuffer( + data_path=["dataset.zarr"], + load_camera_ids=[0], # Load camera_0 + camera_resize_shape=[256, 256], + enable_fsr=True, + fsr_binary_cutoff=[0.5, 0.5, 0.5] # Optional FSR thresholding +) +``` + +## Troubleshooting + +1. **Missing data**: Check that all pickle files exist in the episode directories +2. **Dimension mismatch**: Verify that all frames in an episode have consistent dimensions +3. **Memory issues**: For large datasets, process episodes in batches +4. **Compression errors**: Make sure to use supported compressors (blosclz, zstd, lz4) \ No newline at end of file diff --git a/real_script/data_collection/data_check/doc/DATA_QUALITY_CHECK_USAGE.md b/real_script/data_collection/data_check/doc/DATA_QUALITY_CHECK_USAGE.md new file mode 100644 index 0000000..dafe26e --- /dev/null +++ b/real_script/data_collection/data_check/doc/DATA_QUALITY_CHECK_USAGE.md @@ -0,0 +1,236 @@ +# XHand多模态数据质量检查工具使用说明 + +## 概述 + +`comprehensive_data_quality_check.py` 是一个全面的数据质量检查工具,用于验证XHand多模态数据采集的质量。它可以检测多种数据问题,包括轨迹长度异常、数据分布问题、异常模式和数据完整性问题。 + +## 主要功能 + +### 1. 轨迹长度检查 +- 检测过短的轨迹(可能是意外保存的数据) +- 检测过长的轨迹(可能是忘记停止采集) +- 统计轨迹长度分布 + +### 2. 
数据分布分析 +- TCP位置是否在合理的工作空间内 +- 关节角度是否超出机械臂限位 +- 触觉数据是否合理(非负值) + +### 3. 异常检测 +- 静止轨迹检测(机械臂基本不动) +- 异常跳跃检测(位置突然大幅变化) +- 图像质量问题(全黑、全白、对比度过低) + +### 4. 数据完整性验证 +- 必要文件是否存在 +- pickle文件是否损坏 +- 不同数据源的帧数是否匹配 + +### 5. 可视化报告 +- 轨迹长度分布直方图 +- 数据质量分布饼图 +- 问题类型统计 +- 分质量级别的箱线图 + +## 安装依赖 + +```bash +pip install numpy matplotlib seaborn opencv-python +``` + +## 基本用法 + +### 快速检查 +```bash +# 检查数据目录(使用默认参数) +python comprehensive_data_quality_check.py /path/to/your/XhandData_Multimodal + +# 示例:检查当前目录下的数据 +python comprehensive_data_quality_check.py ./XhandData_Multimodal +``` + +### 自定义参数 +```bash +# 设置自定义的轨迹长度阈值 +python comprehensive_data_quality_check.py ./XhandData_Multimodal \ + --min-length 30 \ + --max-length 1500 + +# 设置自定义工作空间边界 +python comprehensive_data_quality_check.py ./XhandData_Multimodal \ + --workspace-x 0.3 0.7 \ + --workspace-y -0.3 0.3 \ + --workspace-z 0.1 0.5 +``` + +### 禁用某些功能 +```bash +# 不生成可视化图表 +python comprehensive_data_quality_check.py ./XhandData_Multimodal --no-plots + +# 不导出详细报告文件 +python comprehensive_data_quality_check.py ./XhandData_Multimodal --no-export + +# 只进行基本检查 +python comprehensive_data_quality_check.py ./XhandData_Multimodal --no-plots --no-export +``` + +## 输出文件 + +运行检查后,工具会在数据目录下生成以下文件: + +1. **data_quality_report.png** - 可视化统计图表 +2. **data_quality_detailed_report.txt** - 详细的文本报告 + +## 报告解读 + +### 质量等级 +- **✅ Good**: 数据质量良好,无明显问题 +- **⚠️ Warning**: 有轻微问题,但可以使用 +- **❌ Bad**: 有严重问题,建议重新采集 +- **💥 Error**: 数据读取出错 +- **❓ Unknown**: 无法确定质量 + +### 常见问题类型 + +#### 长度问题 +- **过短轨迹**: 通常<50帧,可能是意外保存 +- **过长轨迹**: 通常>2000帧,可能忘记停止采集 + +#### 分布问题 +- **超出工作空间**: TCP位置超出机械臂安全工作范围 +- **关节限位违反**: 关节角度超出硬件限制 +- **触觉数据异常**: FSR传感器数据异常(如负值) + +#### 异常检测 +- **静止轨迹**: 90%以上时间移动<1mm +- **异常跳跃**: 位置突变超过平均移动的10倍 +- **图像质量**: 全黑、全白或对比度过低的图像 + +#### 完整性问题 +- **缺失文件**: 缺少必要的数据文件 +- **损坏文件**: pickle文件无法正常读取 +- **帧数不匹配**: 不同数据源的帧数不一致 + +## 使用示例 + +### 示例1:检查新采集的数据 +```bash +# 检查刚采集的数据,使用较宽松的参数 +python comprehensive_data_quality_check.py ./XhandData_Multimodal \ + --min-length 20 \ + --max-length 3000 +``` + +### 示例2:严格质量控制 +```bash +# 用于训练前的严格质量检查 +python comprehensive_data_quality_check.py ./XhandData_Multimodal \ + --min-length 100 \ + --max-length 1000 \ + --workspace-x 0.4 0.7 \ + --workspace-y -0.2 0.2 +``` + +### 示例3:快速批量检查 +```bash +# 快速检查多个数据集(脚本化使用) +for dataset in dataset1 dataset2 dataset3; do + echo "检查 $dataset..." + python comprehensive_data_quality_check.py ./data/$dataset --no-plots --no-export +done +``` + +## 质量控制建议 + +### 数据采集时 +1. 确保轨迹长度适中(50-1000帧为佳) +2. 保持机械臂在安全工作空间内操作 +3. 避免过度静止或剧烈运动 +4. 检查相机视野和光照条件 + +### 数据后处理 +1. 删除质量为"Bad"的episodes +2. 对"Warning"级别的数据进行人工审查 +3. 确保训练集中的数据分布均匀 +4. 保留原始数据的备份 + +### 训练前验证 +1. 运行完整的质量检查 +2. 确保至少80%的数据质量为"Good" +3. 检查数据集的整体分布 +4. 验证与训练需求的匹配性 + +## 参数说明 + +| 参数 | 默认值 | 说明 | +|------|--------|------| +| `data_dir` | - | 数据目录路径(必需) | +| `--min-length` | 50 | 最小轨迹长度阈值 | +| `--max-length` | 2000 | 最大轨迹长度阈值 | +| `--workspace-x` | [0.2, 0.8] | X轴工作空间边界 | +| `--workspace-y` | [-0.4, 0.4] | Y轴工作空间边界 | +| `--workspace-z` | [0.0, 0.6] | Z轴工作空间边界 | +| `--no-plots` | False | 禁用可视化图表生成 | +| `--no-export` | False | 禁用详细报告导出 | + +## 返回状态码 + +- **0**: 所有数据质量检查通过 +- **1**: 发现质量问题,需要人工审查 + +## 故障排除 + +### 常见错误 + +1. **ImportError**: 缺少依赖库 + ```bash + pip install numpy matplotlib seaborn opencv-python + ``` + +2. **FileNotFoundError**: 数据目录不存在 + - 检查路径是否正确 + - 确保目录包含episode_*子目录 + +3. **PermissionError**: 权限不足 + - 检查文件和目录的读写权限 + - 使用sudo或更改文件所有权 + +4. **MemoryError**: 内存不足 + - 处理大数据集时可能发生 + - 考虑分批处理或增加系统内存 + +### 性能优化 + +1. 
+
+## Troubleshooting
+
+### Common errors
+
+1. **ImportError**: missing dependencies
+   ```bash
+   pip install numpy matplotlib seaborn opencv-python
+   ```
+
+2. **FileNotFoundError**: the data directory does not exist
+   - Check that the path is correct
+   - Make sure the directory contains episode_* subdirectories
+
+3. **PermissionError**: insufficient permissions
+   - Check read/write permissions on the files and directories
+   - Use sudo or change the file ownership
+
+4. **MemoryError**: out of memory
+   - Can happen with large datasets
+   - Process in batches or add more system memory
+
+### Performance tips
+
+1. **Large datasets**:
+   - Use `--no-plots` to skip visualization
+   - Process subdirectories in batches
+
+2. **Network storage**:
+   - Copy the data to a local SSD first
+   - Use a fast network connection
+
+## Integration with Other Tools
+
+### Relationship to the existing check scripts
+- `check_data_consistency.py`: focuses on frame-count consistency
+- `validate_dexumi_data.py`: focuses on zarr format validation
+- `comprehensive_data_quality_check.py`: provides the full quality analysis
+
+### Suggested workflow
+1. Run a basic check right after data collection
+2. Use the comprehensive tool for an in-depth analysis
+3. Run a final validation before converting to zarr
+4. Confirm data quality once more before training
+
+## Support
+
+If you run into problems or have suggestions:
+1. Check the error messages and logs
+2. Confirm the data format is compatible with XhandMultimodalCollection.py
+3. Read the generated detailed report file
+4. Refer to the code comments for the check logic
+
+---
+
+*Last updated: 2024-09-09*
\ No newline at end of file
diff --git a/real_script/data_collection/data_check/export_episode_sequence.py b/real_script/data_collection/data_check/export_episode_sequence.py
new file mode 100644
index 0000000..64b79a1
--- /dev/null
+++ b/real_script/data_collection/data_check/export_episode_sequence.py
@@ -0,0 +1,220 @@
+#!/usr/bin/env python3
+"""
+Export all images from a specific episode in sequential order for inspection.
+This script saves every frame from an episode as individual image files.
+"""
+
+import os
+import pickle
+import numpy as np
+from PIL import Image
+from pathlib import Path
+import argparse
+from tqdm import tqdm
+
+
+def load_pickle_data(file_path):
+    """Load data from pickle file."""
+    with open(file_path, 'rb') as f:
+        return pickle.load(f)
+
+
+def export_episode_sequence(episode_dir, output_dir, image_format='jpg', quality=95):
+    """
+    Export all images from an episode as individual files.
+
+    Args:
+        episode_dir: Path to episode directory
+        output_dir: Directory to save individual image files
+        image_format: Output format ('jpg', 'png')
+        quality: JPEG quality (1-100, only for jpg)
+    """
+    episode_path = Path(episode_dir)
+    output_path = Path(output_dir)
+
+    # Load RGB images
+    rgb_file = episode_path / "camera_0" / "rgb.pkl"
+    if not rgb_file.exists():
+        print(f"Error: RGB file not found at {rgb_file}")
+        return False
+
+    # Load timestamps for reference
+    timestamps_file = episode_path / "timestamps.pkl"
+    timestamps_data = None
+    if timestamps_file.exists():
+        timestamps_data = load_pickle_data(timestamps_file)
+
+    print(f"Loading images from {rgb_file}...")
+    images = load_pickle_data(rgb_file)
+
+    if images is None or len(images) == 0:
+        print("No images found in the episode")
+        return False
+
+    # Create output directory
+    episode_name = episode_path.name
+    episode_output_dir = output_path / episode_name
+    episode_output_dir.mkdir(parents=True, exist_ok=True)
+
+    print(f"Exporting {len(images)} images to {episode_output_dir}")
+    print(f"Image format: {image_format.upper()}, Quality: {quality if image_format.lower() == 'jpg' else 'N/A'}")
+
+    # Export each image
+    for i, img in enumerate(tqdm(images, desc="Exporting frames")):
+        try:
+            # Handle different image formats
+            if isinstance(img, np.ndarray):
+                # Ensure uint8 format
+                if img.dtype != np.uint8:
+                    # Normalize to uint8 if needed
+                    img = ((img - img.min()) / (img.max() - img.min()) * 255).astype(np.uint8)
+
+                # Convert BGR to RGB for correct color display
+                # RealSense camera outputs BGR format, but PIL expects RGB
+                if len(img.shape) == 3 and img.shape[2] == 3:
+                    img = img[..., ::-1]  # BGR to RGB conversion
+
+                # Convert numpy array to PIL Image
+                pil_img = Image.fromarray(img)
+
+                # Generate filename with frame number (zero-padded)
+                filename = f"frame_{i:06d}.{image_format.lower()}"
+                filepath = episode_output_dir / filename
+
+                # Save image
+                if image_format.lower() == 'jpg':
+                    pil_img.save(filepath, 'JPEG', quality=quality)
+                elif image_format.lower() == 'png':
+                    pil_img.save(filepath, 'PNG')
+                else: 
+ print(f"Unsupported format: {image_format}") + return False + + else: + print(f"Unexpected image format at index {i}: {type(img)}") + continue + + except Exception as e: + print(f"Error processing frame {i}: {e}") + continue + + # Save metadata file + metadata_file = episode_output_dir / "metadata.txt" + with open(metadata_file, 'w') as f: + f.write(f"Episode: {episode_name}\n") + f.write(f"Total frames: {len(images)}\n") + f.write(f"Image format: {image_format.upper()}\n") + f.write(f"Image shape: {images[0].shape if len(images) > 0 else 'Unknown'}\n") + f.write(f"Export format: BGR->RGB converted\n") + + if timestamps_data and 'main_timestamps' in timestamps_data: + main_ts = timestamps_data['main_timestamps'] + if len(main_ts) > 1: + duration = main_ts[-1] - main_ts[0] + fps = len(main_ts) / duration + f.write(f"Duration: {duration:.2f} seconds\n") + f.write(f"Average FPS: {fps:.2f}\n") + + print(f"✓ Successfully exported {len(images)} frames to {episode_output_dir}") + print(f"✓ Metadata saved to {metadata_file}") + + return True + + +def create_video_from_images(image_dir, output_video_path, fps=20): + """ + Create a video from exported images using ffmpeg. + + Args: + image_dir: Directory containing sequential images + output_video_path: Path for output video file + fps: Frames per second for output video + """ + try: + import subprocess + + # Check if ffmpeg is available + subprocess.run(['ffmpeg', '-version'], capture_output=True, check=True) + + # Create video from images + cmd = [ + 'ffmpeg', '-y', # -y to overwrite output file + '-framerate', str(fps), + '-i', str(image_dir / 'frame_%06d.jpg'), + '-c:v', 'libx264', + '-pix_fmt', 'yuv420p', + str(output_video_path) + ] + + result = subprocess.run(cmd, capture_output=True, text=True) + + if result.returncode == 0: + print(f"✓ Video created: {output_video_path}") + return True + else: + print(f"Error creating video: {result.stderr}") + return False + + except (ImportError, subprocess.CalledProcessError, FileNotFoundError): + print("ffmpeg not available. 
Skipping video creation.") + print("To create video manually, run:") + print(f"ffmpeg -framerate {fps} -i {image_dir}/frame_%06d.jpg -c:v libx264 -pix_fmt yuv420p {output_video_path}") + return False + + +def main(): + parser = argparse.ArgumentParser(description='Export all images from an episode sequence') + parser.add_argument('--data_dir', type=str, default='../../../collected_data_optimized', + help='Path to collected_data_optimized directory') + parser.add_argument('--episode', type=str, required=True, + help='Episode to export (e.g., episode_0)') + parser.add_argument('--output_dir', type=str, default='./exported_sequences', + help='Directory to save exported images') + parser.add_argument('--format', type=str, choices=['jpg', 'png'], default='jpg', + help='Output image format') + parser.add_argument('--quality', type=int, default=95, + help='JPEG quality (1-100, only for jpg format)') + parser.add_argument('--create_video', action='store_true', + help='Also create a video file from the images') + parser.add_argument('--fps', type=int, default=20, + help='FPS for video creation') + + args = parser.parse_args() + + # Setup paths + data_dir = Path(args.data_dir) + episode_dir = data_dir / args.episode + output_dir = Path(args.output_dir) + + if not episode_dir.exists(): + print(f"Error: Episode directory {episode_dir} does not exist") + return + + # Export images + success = export_episode_sequence( + episode_dir=episode_dir, + output_dir=output_dir, + image_format=args.format, + quality=args.quality + ) + + if not success: + print("Failed to export images") + return + + # Create video if requested + if args.create_video and args.format.lower() == 'jpg': + episode_output_dir = output_dir / args.episode + video_path = episode_output_dir / f"{args.episode}.mp4" + create_video_from_images(episode_output_dir, video_path, args.fps) + elif args.create_video: + print("Video creation only supported for jpg format") + + print(f"\n✓ Export complete!") + print(f"Images saved to: {output_dir / args.episode}") + if args.create_video and args.format.lower() == 'jpg': + print(f"Video saved to: {output_dir / args.episode / f'{args.episode}.mp4'}") + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/real_script/data_collection/data_check/interactive_zarr_visualizer.py b/real_script/data_collection/data_check/interactive_zarr_visualizer.py new file mode 100644 index 0000000..cd4b185 --- /dev/null +++ b/real_script/data_collection/data_check/interactive_zarr_visualizer.py @@ -0,0 +1,317 @@ +#!/usr/bin/env python3 +""" +Interactive 3D Pose + RGB Video Synchronized Player +Visualizes zarr dataset with synchronized 3D pose trajectories and RGB video playback. 
+ +python interactive_zarr_visualizer.py -z data/xhand_dataset_aligned.zarr -e 35 + +""" + +import zarr +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.widgets import Slider, Button +import matplotlib.animation as animation +from mpl_toolkits.mplot3d import Axes3D +import cv2 +import argparse +import os +from pathlib import Path + + +class InteractiveZarrVisualizer: + def __init__(self, zarr_path, episode_id=0): + """Initialize the visualizer with zarr dataset.""" + self.zarr_path = zarr_path + self.episode_id = episode_id + self.current_frame = 0 + + # Load data + self.load_data() + + # Setup GUI + self.setup_figure() + self.setup_widgets() + + # Animation control + self.is_playing = False + self.animation = None + + def load_data(self): + """Load data from zarr file.""" + print(f"Loading episode {self.episode_id} from {self.zarr_path}") + + try: + self.root = zarr.open(self.zarr_path, mode='r') + episode_key = f'episode_{self.episode_id}' + + if episode_key not in self.root: + available_episodes = [k for k in self.root.keys() if k.startswith('episode_')] + raise ValueError(f"Episode {self.episode_id} not found. Available episodes: {available_episodes[:10]}...") + + self.episode = self.root[episode_key] + + # Load RGB data + self.rgb_data = self.episode['camera_1']['rgb'][:] + + # Load pose data (6D: likely position + orientation) + self.pose_data = self.episode['pose'][:] + + # Additional data for context + self.fsr_data = self.episode['fsr'][:] + self.hand_action = self.episode['hand_action'][:] + + self.total_frames = len(self.rgb_data) + + print(f"Loaded data:") + print(f" RGB shape: {self.rgb_data.shape}") + print(f" Pose shape: {self.pose_data.shape}") + print(f" Total frames: {self.total_frames}") + print(f" FSR (Tactile) shape: {self.fsr_data.shape}") + print(f" Hand Action shape: {self.hand_action.shape}") + + # Extract position (first 3 dims) and orientation (last 3 dims) from pose + self.positions = self.pose_data[:, :3] # XYZ position + self.orientations = self.pose_data[:, 3:6] # Orientation (euler angles or similar) + + except Exception as e: + print(f"Error loading data: {e}") + raise + + def setup_figure(self): + """Setup the main figure with subplots.""" + self.fig = plt.figure(figsize=(16, 10)) + self.fig.suptitle(f'Interactive Zarr Visualizer - Episode {self.episode_id}', fontsize=16) + + # Create subplot layout: 2x2 grid with custom sizing + gs = self.fig.add_gridspec(3, 2, height_ratios=[2, 2, 0.3], hspace=0.3, wspace=0.3) + + # RGB video display (top left) + self.ax_rgb = self.fig.add_subplot(gs[0, 0]) + self.ax_rgb.set_title('RGB Video') + self.ax_rgb.axis('off') + + # 3D pose trajectory (top right) + self.ax_3d = self.fig.add_subplot(gs[0, 1], projection='3d') + self.ax_3d.set_title('3D Pose Trajectory') + + # Additional data plots (bottom row) + self.ax_data1 = self.fig.add_subplot(gs[1, 0]) + self.ax_data1.set_title('FSR (Tactile) Data') + + self.ax_data2 = self.fig.add_subplot(gs[1, 1]) + self.ax_data2.set_title('Hand Joint Angles (First 6)') + + # Control panel (bottom) + self.ax_controls = self.fig.add_subplot(gs[2, :]) + self.ax_controls.axis('off') + + def setup_widgets(self): + """Setup interactive widgets.""" + # Time slider + slider_ax = plt.axes([0.1, 0.05, 0.6, 0.03]) + self.time_slider = Slider( + slider_ax, 'Frame', 0, self.total_frames - 1, + valinit=0, valfmt='%d', valstep=1 + ) + self.time_slider.on_changed(self.update_frame) + + # Play/Pause button + play_ax = plt.axes([0.75, 0.05, 0.08, 0.04]) + self.play_button = 
Button(play_ax, 'Play') + self.play_button.on_clicked(self.toggle_play) + + # Reset button + reset_ax = plt.axes([0.85, 0.05, 0.08, 0.04]) + self.reset_button = Button(reset_ax, 'Reset') + self.reset_button.on_clicked(self.reset_view) + + def update_frame(self, frame_idx=None): + """Update display for given frame.""" + if frame_idx is None: + frame_idx = int(self.time_slider.val) + else: + frame_idx = int(frame_idx) + + self.current_frame = max(0, min(frame_idx, self.total_frames - 1)) + + # Update RGB image (convert BGR to RGB if needed) + self.ax_rgb.clear() + rgb_img = self.rgb_data[self.current_frame] + # Convert BGR to RGB for proper color display + rgb_img_corrected = cv2.cvtColor(rgb_img, cv2.COLOR_BGR2RGB) + self.ax_rgb.imshow(rgb_img_corrected) + self.ax_rgb.set_title(f'RGB Frame {self.current_frame}') + self.ax_rgb.axis('off') + + # Update 3D pose trajectory + self.update_3d_pose() + + # Update additional data plots + self.update_data_plots() + + # Only update slider if value is different to avoid recursion + if frame_idx is not None and int(self.time_slider.val) != self.current_frame: + self.time_slider.set_val(self.current_frame) + + self.fig.canvas.draw_idle() + + def update_3d_pose(self): + """Update 3D pose visualization.""" + self.ax_3d.clear() + + # Plot trajectory up to current frame + if self.current_frame > 0: + trajectory = self.positions[:self.current_frame + 1] + self.ax_3d.plot(trajectory[:, 0], trajectory[:, 1], trajectory[:, 2], + 'b-', alpha=0.6, linewidth=2, label='Trajectory') + + # Plot current position as large point + current_pos = self.positions[self.current_frame] + self.ax_3d.scatter(current_pos[0], current_pos[1], current_pos[2], + c='red', s=100, label='Current Position') + + # Plot orientation as arrow (simplified) + current_orient = self.orientations[self.current_frame] + arrow_length = 0.05 + self.ax_3d.quiver(current_pos[0], current_pos[1], current_pos[2], + arrow_length * np.cos(current_orient[0]), + arrow_length * np.sin(current_orient[1]), + arrow_length * current_orient[2], + color='green', arrow_length_ratio=0.3) + + # Set axis labels and limits + self.ax_3d.set_xlabel('X') + self.ax_3d.set_ylabel('Y') + self.ax_3d.set_zlabel('Z') + + # Auto-scale axes based on trajectory + margin = 0.1 + x_range = [self.positions[:, 0].min() - margin, self.positions[:, 0].max() + margin] + y_range = [self.positions[:, 1].min() - margin, self.positions[:, 1].max() + margin] + z_range = [self.positions[:, 2].min() - margin, self.positions[:, 2].max() + margin] + + self.ax_3d.set_xlim(x_range) + self.ax_3d.set_ylim(y_range) + self.ax_3d.set_zlim(z_range) + + self.ax_3d.legend() + self.ax_3d.set_title(f'3D Pose - Frame {self.current_frame}') + + def update_data_plots(self): + """Update additional data visualization.""" + # FSR (Tactile) data only + self.ax_data1.clear() + + # Plot FSR data with different colors + colors = ['blue', 'green', 'orange', 'purple', 'brown', 'pink'] + for i in range(self.fsr_data.shape[1]): + color = colors[i % len(colors)] + self.ax_data1.plot(self.fsr_data[:, i], + label=f'FSR Sensor {i}', + color=color, + linewidth=2, + alpha=0.8) + + # Highlight current frame + self.ax_data1.axvline(x=self.current_frame, color='red', linestyle='-', alpha=0.8, linewidth=2) + self.ax_data1.set_xlabel('Frame') + self.ax_data1.set_ylabel('FSR Value') + # Place legend inside the plot area to avoid overlap + self.ax_data1.legend(loc='upper right', fontsize=9) + self.ax_data1.grid(True, alpha=0.3) + self.ax_data1.set_title('FSR (Tactile) Data') + + # Hand 
joint angles - only first 6 + self.ax_data2.clear() + joint_names = ['Joint 0', 'Joint 1', 'Joint 2', 'Joint 3', 'Joint 4', 'Joint 5'] + colors2 = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b'] + + # Only plot first 6 joint angles + n_joints = min(6, self.hand_action.shape[1]) + for i in range(n_joints): + self.ax_data2.plot(self.hand_action[:, i], + label=joint_names[i], + color=colors2[i], + linewidth=2, + alpha=0.8) + + self.ax_data2.axvline(x=self.current_frame, color='red', linestyle='-', alpha=0.8, linewidth=2) + self.ax_data2.set_xlabel('Frame') + self.ax_data2.set_ylabel('Joint Angle') + # Place legend inside the plot area to avoid overlap + self.ax_data2.legend(loc='upper left', fontsize=9) + self.ax_data2.grid(True, alpha=0.3) + self.ax_data2.set_title('Hand Joint Angles (First 6)') + + def toggle_play(self, event=None): + """Toggle play/pause animation.""" + if self.is_playing: + self.stop_animation() + else: + self.start_animation() + + def start_animation(self): + """Start automatic playback.""" + self.is_playing = True + self.play_button.label.set_text('Pause') + + def animate(frame): + if self.is_playing and self.current_frame < self.total_frames - 1: + self.update_frame(self.current_frame + 1) + return [] + else: + self.stop_animation() + return [] + + self.animation = animation.FuncAnimation( + self.fig, animate, interval=100, blit=False, repeat=False + ) + + def stop_animation(self): + """Stop automatic playback.""" + self.is_playing = False + self.play_button.label.set_text('Play') + if self.animation: + self.animation.event_source.stop() + + def reset_view(self, event=None): + """Reset to first frame.""" + self.stop_animation() + self.update_frame(0) + + def show(self): + """Display the interactive visualizer.""" + # Initialize with first frame + self.update_frame(0) + plt.tight_layout() + plt.show() + + +def main(): + parser = argparse.ArgumentParser(description='Interactive Zarr Dataset Visualizer') + parser.add_argument('--zarr_path', '-z', type=str, + default='data/xhand_dataset_aligned.zarr', + help='Path to zarr dataset') + parser.add_argument('--episode', '-e', type=int, default=0, + help='Episode ID to visualize') + + args = parser.parse_args() + + # Check if file exists + if not os.path.exists(args.zarr_path): + print(f"Error: Zarr file not found at {args.zarr_path}") + return + + try: + visualizer = InteractiveZarrVisualizer(args.zarr_path, args.episode) + visualizer.show() + + except Exception as e: + print(f"Error creating visualizer: {e}") + import traceback + traceback.print_exc() + + +if __name__ == "__main__": + main() diff --git a/real_script/data_collection/data_check/visualize_rgb.py b/real_script/data_collection/data_check/visualize_rgb.py new file mode 100644 index 0000000..9fce500 --- /dev/null +++ b/real_script/data_collection/data_check/visualize_rgb.py @@ -0,0 +1,149 @@ +#!/usr/bin/env python3 +""" +Script to visualize rgb.pkl files from camera directories and save as videos. +Each episode contains two cameras (camera_0 and camera_1), and this script +processes all episodes in the current directory. 
+
+usage:
+in episode dir:
+python visualize_rgb.py --episode episode_0
+
+if you want to process all episodes, run:
+python visualize_rgb.py
+
+"""
+
+import os
+import pickle
+import numpy as np
+import cv2
+from pathlib import Path
+import argparse
+
+def load_rgb_data(pkl_path):
+    """Load RGB data from pickle file."""
+    try:
+        with open(pkl_path, 'rb') as f:
+            data = pickle.load(f)
+        return data
+    except Exception as e:
+        print(f"Error loading {pkl_path}: {e}")
+        return None
+
+def create_video_from_rgb(rgb_data, output_path, fps=30):
+    """Create video from RGB numpy array."""
+    if rgb_data is None:
+        return False
+
+    # Get dimensions
+    num_frames, height, width, channels = rgb_data.shape
+
+    # Define codec and create VideoWriter
+    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
+
+    print(f"Creating video with {num_frames} frames at {fps} FPS...")
+
+    try:
+        for i in range(num_frames):
+            frame = rgb_data[i]
+            # Data is already in correct format, no color conversion needed
+            out.write(frame)
+
+            # Progress indicator
+            if (i + 1) % 50 == 0 or i == num_frames - 1:
+                print(f"  Progress: {i + 1}/{num_frames} frames")
+
+    except Exception as e:
+        print(f"Error creating video: {e}")
+        return False
+    finally:
+        out.release()
+
+    return True
+
+def process_episode(episode_path, fps=30):
+    """Process all camera directories in an episode."""
+    episode_path = Path(episode_path)
+
+    if not episode_path.exists():
+        print(f"Episode path does not exist: {episode_path}")
+        return
+
+    # Find all camera directories
+    camera_dirs = [d for d in episode_path.iterdir()
+                   if d.is_dir() and d.name.startswith('camera_')]
+
+    if not camera_dirs:
+        print(f"No camera directories found in {episode_path}")
+        return
+
+    print(f"\nProcessing episode: {episode_path.name}")
+    print(f"Found {len(camera_dirs)} camera directories")
+
+    for camera_dir in sorted(camera_dirs):
+        rgb_pkl_path = camera_dir / 'rgb.pkl'
+
+        if not rgb_pkl_path.exists():
+            print(f"  rgb.pkl not found in {camera_dir}")
+            continue
+
+        print(f"  Processing {camera_dir.name}...")
+
+        # Load RGB data
+        rgb_data = load_rgb_data(rgb_pkl_path)
+        if rgb_data is None:
+            continue
+
+        print(f"  Loaded RGB data: {rgb_data.shape} (frames, height, width, channels)")
+
+        # Create output video path
+        video_filename = f"{camera_dir.name}_rgb_video.mp4"
+        output_video_path = camera_dir / video_filename
+
+        # Create video
+        success = create_video_from_rgb(rgb_data, str(output_video_path), fps)
+
+        if success:
+            print(f"  ✓ Video saved: {output_video_path}")
+        else:
+            print(f"  ✗ Failed to create video for {camera_dir}")
+
+def main():
+    """Main function to process all episodes in current directory."""
+    parser = argparse.ArgumentParser(description='Visualize RGB pickle files as videos')
+    parser.add_argument('--fps', type=int, default=30,
+                       help='Frames per second for output videos (default: 30)')
+    parser.add_argument('--episode', type=str,
+                       help='Process specific episode (e.g., episode_0). 
If not specified, processes all episodes.') + + args = parser.parse_args() + + current_dir = Path('.') + + if args.episode: + # Process specific episode + episode_path = current_dir / args.episode + if episode_path.exists(): + process_episode(episode_path, args.fps) + else: + print(f"Episode directory not found: {args.episode}") + else: + # Find all episode directories + episode_dirs = [d for d in current_dir.iterdir() + if d.is_dir() and d.name.startswith('episode_')] + + if not episode_dirs: + print("No episode directories found in current directory") + return + + print(f"Found {len(episode_dirs)} episode directories") + + # Process each episode + for episode_dir in sorted(episode_dirs): + process_episode(episode_dir, args.fps) + + print("\n🎬 Video visualization complete!") + +if __name__ == "__main__": + main() \ No newline at end of file From 3afae90adf86ccbef3b9eb782cc91277c6e7a46f Mon Sep 17 00:00:00 2001 From: Gray Date: Fri, 12 Sep 2025 13:27:30 +0800 Subject: [PATCH 09/10] [feat] add video record when in inference --- real_script/eval_policy/eval_xhand_franka.py | 103 ++++++++++++++++++- 1 file changed, 100 insertions(+), 3 deletions(-) diff --git a/real_script/eval_policy/eval_xhand_franka.py b/real_script/eval_policy/eval_xhand_franka.py index 03d751b..184502a 100644 --- a/real_script/eval_policy/eval_xhand_franka.py +++ b/real_script/eval_policy/eval_xhand_franka.py @@ -1,5 +1,7 @@ +import os import time from collections import deque +from datetime import datetime import click import cv2 @@ -41,6 +43,61 @@ def compute_total_force_per_finger(all_fsr_observations): return total_force +def save_video_offline(frames, timestamps, save_dir, session_start_time): + """ + Save collected frames as video file offline after inference session. + + Parameters: + frames (list): List of RGB frames (numpy arrays) + timestamps (list): List of timestamps for each frame + save_dir (str): Directory to save the video file + session_start_time (float): Start time of the session for filename generation + """ + if not frames: + print("No frames to save") + return + + # Create save directory if not exists + os.makedirs(save_dir, exist_ok=True) + + # Generate unique filename based on session start time + timestamp_str = datetime.fromtimestamp(session_start_time).strftime("%Y%m%d_%H%M%S") + filename = f"inference_session_{timestamp_str}.mp4" + filepath = os.path.join(save_dir, filename) + + # Video parameters + height, width, channels = frames[0].shape + + # Calculate FPS from actual timestamps + if len(timestamps) > 1: + duration = timestamps[-1] - timestamps[0] + fps = (len(frames) - 1) / duration if duration > 0 else 30.0 + else: + fps = 30.0 + + # Ensure reasonable FPS range + fps = max(10.0, min(fps, 60.0)) + + print(f"Saving video with {len(frames)} frames at {fps:.1f} FPS to: {filepath}") + + # Initialize VideoWriter + fourcc = cv2.VideoWriter_fourcc(*'mp4v') + out = cv2.VideoWriter(filepath, fourcc, fps, (width, height)) + + if not out.isOpened(): + print(f"Error: Could not open VideoWriter for {filepath}") + return + + # Write frames + for frame in frames: + # obs_frame.rgb is actually BGR format (despite the name) + # No color conversion needed - save directly + out.write(frame) + + out.release() + print(f"Video saved successfully: {filepath}") + + # Fixed initial positions (consistent with data collection) # Initial robot pose from XhandMultimodalCollection.py (7D: xyz + quaternion) initial_robot_pose = np.array([ @@ -54,7 +111,8 @@ def compute_total_force_per_finger(all_fsr_observations): # 
Initial hand position (open position) - Updated to match open_gripper command initial_hand_pos = np.array([ - 1.516937255859375, + # 1.516937255859375, + 0.156546025276184082, 0.5177657604217529, 0.04799513891339302, 0.01787799410521984, @@ -179,11 +237,35 @@ def main( session_start_time = time.time() session_duration = 20.0 # 20 seconds - # Policy execution loop + # Initialize video recording storage + video_frames = [] + video_timestamps = [] + + # Video recording at full camera FPS (30fps) - independent of inference + import threading + video_recording_active = threading.Event() + video_recording_active.set() + + def video_recording_thread(): + while video_recording_active.is_set() and (time.time() - session_start_time < session_duration): + frame = obs_camera.get_camera_frame() + if frame is not None: + video_frames.append(frame.rgb.copy()) + video_timestamps.append(time.time()) + time.sleep(1/30) # 30 FPS for video recording + + # Start video recording in background + video_thread = threading.Thread(target=video_recording_thread) + video_thread.daemon = True + video_thread.start() + + # Policy execution loop (runs at inference_fps) while time.time() - session_start_time < session_duration: with FrameRateContext(frame_rate=inference_fps): - # Get observation from camera + # Get observation from camera for inference obs_frame = obs_camera.get_camera_frame() + if obs_frame is None: + continue obs_frame_recieved_time = obs_frame.receive_time obs_frame_rgb = obs_frame.rgb.copy() @@ -319,6 +401,13 @@ def main( hand_action += offset + # Fix the first joint to initial position + hand_action[:, 0] = 1.516937255859375 + print(f"Fixed first joint to: {1.516937255859375}") + + # hand_action[:, 0] = 0.156546025276184082 + # print(f"Fixed first joint to: {0.156546025276184082}") + # get the robot pose when images were captured robot_frames = robot_client.get_state_history() robot_timestamp = [] @@ -444,6 +533,14 @@ def main( # Session completed, reset to initial positions print("20-second session completed. 
Resetting to initial positions...") + # Stop video recording thread + video_recording_active.clear() + video_thread.join(timeout=2) # Wait up to 2 seconds for thread to finish + + # Save video offline (does not affect real-time performance) + video_save_dir = "/home/ubuntu/hgw/IL/DexUMI/data/video" + save_video_offline(video_frames, video_timestamps, video_save_dir, session_start_time) + # Reset robot to initial position initial_pose_6d = np.zeros(6) initial_pose_6d[:3] = initial_robot_pose[:3] From 39b0d9fbf0afb87d29e316513ff338f8090101ea Mon Sep 17 00:00:00 2001 From: Hly-123 <452663784@qq.com> Date: Fri, 12 Sep 2025 14:51:47 +0800 Subject: [PATCH 10/10] improve XHand Franka evaluation with configurable session duration MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add session_duration parameter to control inference duration - Add model configuration display with input modality info - Improve coordinate reference frame handling in inference - Add comprehensive debug output for action verification - Update shell script with better user experience and countdown - Fix proprioception handling based on model config 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- real_script/eval_policy/eval_xhand_franka.py | 120 +++++++++++++------ real_script/eval_policy/eval_xhand_franka.sh | 25 ++-- 2 files changed, 102 insertions(+), 43 deletions(-) diff --git a/real_script/eval_policy/eval_xhand_franka.py b/real_script/eval_policy/eval_xhand_franka.py index 184502a..7df535e 100644 --- a/real_script/eval_policy/eval_xhand_franka.py +++ b/real_script/eval_policy/eval_xhand_franka.py @@ -149,6 +149,7 @@ def save_video_offline(frames, timestamps, save_dir, session_start_time): help="Robot action latency", ) @click.option("-eh", "--exec_horizon", type=int, default=8, help="Execution horizon") +@click.option("-sd", "--session_duration", type=float, default=20.0, help="Session duration in seconds") def main( frequency, model_path, @@ -157,6 +158,7 @@ def main( hand_action_latency, robot_action_latency, exec_horizon, + session_duration, ): # Initialize HTTP clients for robot and hand control robot_client = HTTPRobotClient(base_url="http://127.0.0.1:5000") @@ -186,7 +188,7 @@ def main( # Main control loop while True: - print("Ready! Starting 20-second inference session...") + print(f"Ready! 
Starting {session_duration}-second inference session...") # Reset robot to initial position print("Moving robot to initial position...") @@ -227,15 +229,27 @@ def main( model_path=model_path, ckpt=ckpt, ) + + # Check model configuration for input modalities + print("\n" + "="*50) + print("MODEL CONFIGURATION:") + print(f" Model path: {model_path}") + print(f" Checkpoint: {ckpt}") + skip_proprioception = getattr(policy.model_cfg.dataset, 'skip_proprioception', False) + enable_fsr = getattr(policy.model_cfg.dataset, 'enable_fsr', False) + print(f" Skip proprioception: {skip_proprioception}") + print(f" Enable FSR: {enable_fsr}") + print(f" Global cond dim: {policy.model_cfg.model.diffusion_policy_head.global_cond_dim}") + print("="*50 + "\n") # Calculate inference parameters inference_iter_time = exec_horizon * dt inference_fps = 1 / inference_iter_time print("inference_fps", inference_fps) - # Start 20-second inference session + # Start inference session session_start_time = time.time() - session_duration = 20.0 # 20 seconds + # session_duration is now passed as parameter # Initialize video recording storage video_frames = [] @@ -259,6 +273,10 @@ def video_recording_thread(): video_thread.daemon = True video_thread.start() + # Initialize inference reference tracking + inference_start_pose = None + inference_start_T = None + # Policy execution loop (runs at inference_fps) while time.time() - session_start_time < session_duration: with FrameRateContext(frame_rate=inference_fps): @@ -272,7 +290,8 @@ def video_recording_thread(): # Note: real_policy.py will handle all image preprocessing # The image should be in BGR format to match training data print(f"Time remaining: {session_duration - (time.time() - session_start_time):.1f}s") - if policy.model_cfg.dataset.enable_fsr: + enable_fsr = getattr(policy.model_cfg.dataset, 'enable_fsr', False) + if enable_fsr: print("Using FSR") fsr_raw_obs = dexhand_client.get_tactile(calc=True) print("raw", fsr_raw_obs) @@ -326,29 +345,33 @@ def video_recording_thread(): print(f"FSR obs shape: {np.array(list(fsr_obs)).shape}") print(f"FSR obs values: {np.array(list(fsr_obs))[-1]}") # Last FSR reading - # Get robot proprioception data for model input - robot_state = robot_client.get_state() + # Get robot proprioception data for model input (only if model uses it) proprioception = None + skip_proprioception = getattr(policy.model_cfg.dataset, 'skip_proprioception', False) - if robot_state and "state" in robot_state: - # Extract joint positions and velocities - joint_q = np.array(robot_state["state"]["ActualQ"]) # 7D joint positions - joint_dq = np.array(robot_state["state"]["ActualQd"]) # 7D joint velocities - # Create 14D proprioception vector: [joint_q, joint_dq] - proprioception = np.concatenate([joint_q, joint_dq]).astype(np.float32) - - print(f"Proprioception shape: {proprioception.shape}") - print(f"Joint positions: {joint_q}") - print(f"Joint velocities: {joint_dq}") + if not skip_proprioception: + robot_state = robot_client.get_state() + if robot_state and "state" in robot_state: + # Extract joint positions and velocities + joint_q = np.array(robot_state["state"]["ActualQ"]) # 7D joint positions + joint_dq = np.array(robot_state["state"]["ActualQd"]) # 7D joint velocities + # Create 14D proprioception vector: [joint_q, joint_dq] + proprioception = np.concatenate([joint_q, joint_dq]).astype(np.float32) + + print(f"Proprioception shape: {proprioception.shape}") + print(f"Joint positions: {joint_q}") + print(f"Joint velocities: {joint_dq}") + else: + # 
Fallback to zeros if state unavailable + proprioception = np.zeros(14, dtype=np.float32) + print("Warning: Robot state unavailable, using zero proprioception") else: - # Fallback to zeros if state unavailable - proprioception = np.zeros(14, dtype=np.float32) - print("Warning: Robot state unavailable, using zero proprioception") + print("Skipping proprioception (vision-only model)") action = policy.predict_action( proprioception[None, ...] if proprioception is not None else None, # Add batch dimension np.array(list(fsr_obs)).astype(np.float32) - if policy.model_cfg.dataset.enable_fsr + if enable_fsr else None, obs_frame_rgb[None, ...], # Use original image, let real_policy.py handle preprocessing ) @@ -357,6 +380,13 @@ def video_recording_thread(): print(f"Action shape: {action.shape}") print(f"Action min/max: [{action.min():.4f}, {action.max():.4f}]") print(f"Action mean/std: [mean={action.mean():.4f}, std={action.std():.4f}]") + print(f"Action first row: {action[0]}") + + # Check if actions are near zero (potential inference issue) + action_magnitude = np.linalg.norm(action, axis=1).mean() + print(f"Average action magnitude: {action_magnitude:.6f}") + if action_magnitude < 0.001: + print("⚠️ WARNING: Actions are near zero! Model may not be generating proper outputs.") # convert to abs action relative_pose = action[:, :6] @@ -401,12 +431,9 @@ def video_recording_thread(): hand_action += offset - # Fix the first joint to initial position - hand_action[:, 0] = 1.516937255859375 - print(f"Fixed first joint to: {1.516937255859375}") - - # hand_action[:, 0] = 0.156546025276184082 - # print(f"Fixed first joint to: {0.156546025276184082}") + # Fix the first joint to specified position during inference + # hand_action[:, 0] = 0.5154829025268555 + # print(f"Fixed first joint to: {0.5154829025268555}") # get the robot pose when images were captured robot_frames = robot_client.get_state_history() @@ -471,18 +498,28 @@ def video_recording_thread(): ).as_matrix() T_BE[:3, -1] = ee_aligned_pose[:3] - # Calculate target poses - SIMPLIFIED without T_ET - # Since we're directly in end-effector frame, just apply relative transform + # Set inference start reference frame (only on first iteration) + if inference_start_pose is None: + inference_start_pose = ee_aligned_pose.copy() + inference_start_T = T_BE.copy() + print(f"🎯 INFERENCE START POSE: {inference_start_pose[:3]} (position)") + print(f"🎯 INFERENCE START ROTVEC: {inference_start_pose[3:]} (rotation)") + + # Calculate target poses - FIXED coordinate reference + # Model predicts relative to SEQUENCE START, not current frame T_BN = np.zeros_like(relative_pose) for iter_idx in range(len(relative_pose)): - # Direct application: T_BN = T_BE @ relative_pose - T_BN[iter_idx] = T_BE @ relative_pose[iter_idx] + # CORRECT: Apply relative transform to inference start position + T_BN[iter_idx] = inference_start_T @ relative_pose[iter_idx] # ============ DEBUG: Target Poses ============ print("\nDEBUG: Target Transformation") - print(f"Current EE pose: {ee_aligned_pose[:3]}") # Current position - print(f"First target pose: {T_BN[0, :3, -1]}") # First target position - print(f"Position change: {T_BN[0, :3, -1] - ee_aligned_pose[:3]}") # Delta position + print(f"🔄 Current EE pose: {ee_aligned_pose[:3]}") # Current position + print(f"📍 Inference start pose: {inference_start_pose[:3]}") # Reference position + print(f"🎯 First target pose: {T_BN[0, :3, -1]}") # First target position + print(f"📊 Target relative to start: {T_BN[0, :3, -1] - inference_start_pose[:3]}") # 
Relative to start + print(f"📈 Target relative to current: {T_BN[0, :3, -1] - ee_aligned_pose[:3]}") # Relative to current + print(f"📏 Target trajectory span: {np.linalg.norm(T_BN[-1, :3, -1] - T_BN[0, :3, -1]):.4f}m") # Trajectory span # ============ DEBUG END ============ # discard actions which in the past n_action = T_BN.shape[0] @@ -523,15 +560,25 @@ def video_recording_thread(): hand_scheduled += 1 print( - f"Scheduled actions: {robot_scheduled} robot waypoints, {hand_scheduled} hand waypoints" + f"✅ Scheduled actions: {robot_scheduled} robot waypoints, {hand_scheduled} hand waypoints" ) + + # ============ EXECUTION VERIFICATION ============ + if robot_scheduled > 0: + scheduled_targets = [T_BN[k, :3, -1] for k in np.where(valid_robot_idx)[0]] + if scheduled_targets: + print(f"🎯 First scheduled target: {scheduled_targets[0]}") + print(f"📏 Distance to first target: {np.linalg.norm(scheduled_targets[0] - ee_aligned_pose[:3]):.4f}m") + else: + print(f"⚠️ WARNING: No robot actions scheduled! Check timing parameters.") + # ============ EXECUTION VERIFICATION END ============ if len(hand_action) > exec_horizon + 1: virtual_hand_pos = hand_action[exec_horizon + 1] else: virtual_hand_pos = hand_action[-1] # Session completed, reset to initial positions - print("20-second session completed. Resetting to initial positions...") + print(f"{session_duration}-second session completed. Resetting to initial positions...") # Stop video recording thread video_recording_active.clear() @@ -539,6 +586,9 @@ def video_recording_thread(): # Save video offline (does not affect real-time performance) video_save_dir = "/home/ubuntu/hgw/IL/DexUMI/data/video" + print(f"\nVideo Recording Summary:") + print(f" Frames captured: {len(video_frames)}") + print(f" Duration: {video_timestamps[-1] - video_timestamps[0]:.1f}s" if video_timestamps else "0s") save_video_offline(video_frames, video_timestamps, video_save_dir, session_start_time) # Reset robot to initial position diff --git a/real_script/eval_policy/eval_xhand_franka.sh b/real_script/eval_policy/eval_xhand_franka.sh index bf13a5d..34f4a48 100755 --- a/real_script/eval_policy/eval_xhand_franka.sh +++ b/real_script/eval_policy/eval_xhand_franka.sh @@ -6,13 +6,17 @@ source ~/anaconda3/etc/profile.d/conda.sh conda activate dexumi +# Set Python path for dexumi module +export PYTHONPATH="/home/ubuntu/hgw/IL/DexUMI:$PYTHONPATH" + # Path to your trained model -MODEL_PATH="/home/ubuntu/hgw/IL/DexUMI/data/weight/vision_tactile_propio" # TODO: Update this path -CHECKPOINT=600 +MODEL_PATH="/home/ubuntu/hgw/IL/DexUMI/data/weight/vision_only_0909" +CHECKPOINT=100 # Control parameters -FREQUENCY=15 # Control frequency in Hz +FREQUENCY=20 # Control frequency in Hz EXEC_HORIZON=8 # Number of action steps to execute before re-predicting +SESSION_DURATION=120.0 # Session duration in seconds # Visualization settings ENABLE_VISUALIZATION=false # Set to true to enable real-time camera visualization @@ -37,6 +41,7 @@ echo "Checkpoint: $CHECKPOINT" echo "Camera Type: $CAMERA_TYPE" echo "Frequency: $FREQUENCY Hz" echo "Execution Horizon: $EXEC_HORIZON steps" +echo "Session Duration: $SESSION_DURATION seconds" echo "" echo "Latency Settings:" echo " Camera: ${CAMERA_LATENCY}s" @@ -50,15 +55,18 @@ echo "✓ HTTP control interface" echo "✓ RealSense/OAK camera support" echo "✓ Multi-step action execution" echo "" -echo "Make sure the robot server is running:" -echo " python franka_server.py" +echo "Note: Robot server will be checked during runtime" echo "" -echo "Press Ctrl+C to abort, or 
wait 3 seconds to continue..." +echo "Starting in 3 seconds... (Press Ctrl+C to abort)" echo "=========================================" echo "" -# Wait for user to check -sleep 3 +# Countdown +for i in 3 2 1; do + echo -n "$i... " + sleep 1 +done +echo "Starting!" # Run the evaluation script python real_script/eval_policy/eval_xhand_franka.py \ @@ -66,6 +74,7 @@ python real_script/eval_policy/eval_xhand_franka.py \ --ckpt $CHECKPOINT \ --frequency $FREQUENCY \ --exec_horizon $EXEC_HORIZON \ + --session_duration $SESSION_DURATION \ --camera_latency $CAMERA_LATENCY \ --hand_action_latency $HAND_ACTION_LATENCY \ --robot_action_latency $ROBOT_ACTION_LATENCY \ No newline at end of file