diff --git a/README.md b/README.md index 32b8bec..b556c03 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,19 @@ A simple interface that prints the teleop responses. You can use it as a referen python3 -m teleop.basic ``` +### xArm + +Interface to teleoperate the [uFactory Lite 6](https://www.ufactory.cc/lite-6-collaborative-robot/) robot. +Minor changes are probably necessary to support other xArm robots. + +```bash +python3 -m teleop.xarm +``` + +Note that the interface is very simple, it doesn't implement any kind of filtering. +Therefore, you probably want to teleoperate it with a device with high frequency. +Smart phones are typically 30fps while VR joysticks 90fps which is much more preferable for teleoperation without filtering. + ### ROS 2 Interface The ROS 2 interface is designed primarily for use with the [cartesian_controllers](https://github.com/fzi-forschungszentrum-informatik/cartesian_controllers) package, but it can also be adapted for [MoveIt Servo](https://moveit.picknik.ai/main/doc/examples/realtime_servo/realtime_servo_tutorial.html) or other packages. diff --git a/examples/webots/controllers/inverse_kinematics/inverse_kinematics.py b/examples/webots/controllers/inverse_kinematics/inverse_kinematics.py index 9b48db1..3ae5086 100644 --- a/examples/webots/controllers/inverse_kinematics/inverse_kinematics.py +++ b/examples/webots/controllers/inverse_kinematics/inverse_kinematics.py @@ -16,16 +16,16 @@ def __init__(self): self.__jacobi = JacobiRobot(file.name, ee_link="wrist_3_link") # Initialize the arm motors and encoders. - timestep = int(self.getBasicTimeStep()) + self.timestep = int(self.getBasicTimeStep()) for joint_name in self.__jacobi.get_joint_names(): motor = self.getDevice(joint_name) motor.setVelocity(1.0) position_sensor = motor.getPositionSensor() - position_sensor.enable(timestep) + position_sensor.enable(self.timestep) self.__motors[joint_name] = motor def move_to_position(self, pose): - self.__jacobi.servo_to_pose(pose) + self.__jacobi.servo_to_pose(pose, dt=self.timestep / 1000.0) for name, motor in self.__motors.items(): motor.setPosition(self.__jacobi.get_joint_position(name)) @@ -40,7 +40,7 @@ def get_current_pose(self): def main(): robot = RobotArm() - teleop = Teleop() + teleop = Teleop(natural_phone_orientation_euler=[0, 0, 0]) target_pose = None def on_teleop_callback(pose, message): diff --git a/examples/webots/worlds/.teleop.wbproj b/examples/webots/worlds/.teleop.wbproj index d8be894..04b3863 100644 --- a/examples/webots/worlds/.teleop.wbproj +++ b/examples/webots/worlds/.teleop.wbproj @@ -1,5 +1,5 @@ -Webots Project File version R2023b -perspectives: 000000ff00000000fd00000003000000000000000000000000fc0100000001fb0000001a0044006f00630075006d0065006e0074006100740069006f006e0000000000ffffffff0000000000000000000000010000012400000391fc0200000001fb0000001400540065007800740045006400690074006f00720000000016000003910000008900ffffff0000000300000d7000000042fc0100000002fb0000000e0043006f006e0073006f006c00650100000000000006900000000000000000fb0000001a0043006f006e0073006f006c00650041006c006c0041006c006c010000000000000d700000006900ffffff00000d70000004f000000001000000020000000100000008fc00000000 +Webots Project File version R2025a +perspectives: 000000ff00000000fd00000003000000000000000000000000fc0100000001fb0000001a0044006f00630075006d0065006e0074006100740069006f006e0000000000ffffffff0000000000000000000000010000012400000391fc0200000001fb0000001400540065007800740045006400690074006f00720000000016000003910000008900ffffff00000003000004f400000039fc0100000002fb0000000e0043006f006e0073006f006c00650100000000000006900000000000000000fb0000001a0043006f006e0073006f006c00650041006c006c0041006c006c0100000000000004f40000006900ffffff000004f40000017800000001000000020000000100000008fc00000000 simulationViewPerspectives: 000000ff000000010000000200000133000006470100000002010000000100 sceneTreePerspectives: 000000ff00000001000000030000001c000000c0000000fa0100000002010000000200 maximizedDockId: -1 diff --git a/examples/webots/worlds/teleop.wbt b/examples/webots/worlds/teleop.wbt index b6a0683..fced0d6 100644 --- a/examples/webots/worlds/teleop.wbt +++ b/examples/webots/worlds/teleop.wbt @@ -1,13 +1,14 @@ -#VRML_SIM R2023b utf8 +#VRML_SIM R2025a utf8 -EXTERNPROTO "https://raw.githubusercontent.com/cyberbotics/webots/R2023b/projects/objects/backgrounds/protos/TexturedBackground.proto" -EXTERNPROTO "https://raw.githubusercontent.com/cyberbotics/webots/R2023b/projects/objects/backgrounds/protos/TexturedBackgroundLight.proto" -EXTERNPROTO "https://raw.githubusercontent.com/cyberbotics/webots/R2023b/projects/objects/floors/protos/Floor.proto" -EXTERNPROTO "https://raw.githubusercontent.com/cyberbotics/webots/R2023b/projects/robots/universal_robots/protos/UR5e.proto" -EXTERNPROTO "https://raw.githubusercontent.com/cyberbotics/webots/R2023b/projects/objects/tables/protos/Table.proto" -EXTERNPROTO "https://raw.githubusercontent.com/cyberbotics/webots/R2023b/projects/appearances/protos/CementTiles.proto" +EXTERNPROTO "https://raw.githubusercontent.com/cyberbotics/webots/R2025a/projects/objects/backgrounds/protos/TexturedBackground.proto" +EXTERNPROTO "https://raw.githubusercontent.com/cyberbotics/webots/R2025a/projects/objects/backgrounds/protos/TexturedBackgroundLight.proto" +EXTERNPROTO "https://raw.githubusercontent.com/cyberbotics/webots/R2025a/projects/objects/floors/protos/Floor.proto" +EXTERNPROTO "https://raw.githubusercontent.com/cyberbotics/webots/R2025a/projects/robots/universal_robots/protos/UR5e.proto" +EXTERNPROTO "https://raw.githubusercontent.com/cyberbotics/webots/R2025a/projects/objects/tables/protos/Table.proto" +EXTERNPROTO "https://raw.githubusercontent.com/cyberbotics/webots/R2025a/projects/appearances/protos/CementTiles.proto" WorldInfo { + basicTimeStep 10 } Viewpoint { orientation 0.12086428145072244 0.9228327289041858 -0.36574797324476627 0.6877052442483729 diff --git a/setup.py b/setup.py index d8b9e52..7d9a3dc 100644 --- a/setup.py +++ b/setup.py @@ -16,25 +16,25 @@ setup( name="teleop", - version="0.0.9", - packages=["teleop", "teleop.basic", "teleop.ros2", "teleop.utils", "teleop.ros2_ik"], + version="0.1.0", + packages=["teleop", "teleop.basic", "teleop.ros2", "teleop.utils", "teleop.ros2_ik", "teleop.xarm"], long_description=long_description, long_description_content_type="text/markdown", description="Turns your phone into a robot arm teleoperation device by leveraging the WebXR API", install_requires=[ - "Flask", - "Flask-SocketIO", + "fastapi", + "uvicorn[standard]", "numpy", "transforms3d", - "werkzeug", "pytest", "requests", + "websocket-client", ], extras_require={ "utils": ["pin"], }, package_data={ - "teleop": ["cert.pem", "key.pem", "index.html"], + "teleop": ["cert.pem", "key.pem", "index.html", "teleop-ui.js"], }, license="Apache 2.0", author="Spes Robotics", diff --git a/teleop/__init__.py b/teleop/__init__.py index d4abbf9..eca0836 100644 --- a/teleop/__init__.py +++ b/teleop/__init__.py @@ -1,15 +1,14 @@ -import ssl import os import math import socket import logging -from werkzeug.serving import ThreadedWSGIServer -from typing import Callable -from flask import Flask, send_from_directory, request -from flask_socketio import SocketIO, emit +from typing import Callable, List +import uvicorn +from fastapi import FastAPI, WebSocket, WebSocketDisconnect +from fastapi.responses import FileResponse import transforms3d as t3d import numpy as np - +import json TF_RUB2FLU = np.array([[0, 0, -1, 0], [-1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) THIS_DIR = os.path.dirname(os.path.realpath(__file__)) @@ -45,25 +44,114 @@ def are_close(a, b=None, lin_tol=1e-9, ang_tol=1e-9): d = np.linalg.inv(a) @ b if not np.allclose(d[:3, 3], np.zeros(3), atol=lin_tol): return False - rpy = t3d.euler.mat2euler(d[:3, :3]) + yaw = math.atan2(d[1, 0], d[0, 0]) + pitch = math.asin(-d[2, 0]) + roll = math.atan2(d[2, 1], d[2, 2]) + rpy = np.array([roll, pitch, yaw]) return np.allclose(rpy, np.zeros(3), atol=ang_tol) +def slerp(q1, q2, t): + """Spherical linear interpolation between two quaternions.""" + q1 = q1 / np.linalg.norm(q1) + q2 = q2 / np.linalg.norm(q2) + + dot = np.dot(q1, q2) + + # If the dot product is negative, use the shortest path + if dot < 0.0: + q2 = -q2 + dot = -dot + + DOT_THRESHOLD = 0.9995 + if dot > DOT_THRESHOLD: + # Linear interpolation fallback for nearly identical quaternions + result = q1 + t * (q2 - q1) + return result / np.linalg.norm(result) + + theta_0 = np.arccos(dot) + theta = theta_0 * t + + q3 = q2 - q1 * dot + q3 = q3 / np.linalg.norm(q3) + + return q1 * np.cos(theta) + q3 * np.sin(theta) + + +def interpolate_transforms(T1, T2, alpha): + """ + Interpolate between two 4x4 transformation matrices using SLERP + linear translation. + + Args: + T1 (np.ndarray): Start transform (4x4) + T2 (np.ndarray): End transform (4x4) + alpha (float): Interpolation factor [0, 1] + + Returns: + np.ndarray: Interpolated transform (4x4) + """ + assert T1.shape == (4, 4) and T2.shape == (4, 4) + assert 0.0 <= alpha <= 1.0 + + # Translation + t1 = T1[:3, 3] + t2 = T2[:3, 3] + t_interp = (1 - alpha) * t1 + alpha * t2 + + # Rotation + R1 = T1[:3, :3] + R2 = T2[:3, :3] + q1 = t3d.quaternions.mat2quat(R1) + q2 = t3d.quaternions.mat2quat(R2) + + # SLERP + q_interp = slerp(q1, q2, alpha) + R_interp = t3d.quaternions.quat2mat(q_interp) + + # Final transform + T_interp = np.eye(4) + T_interp[:3, :3] = R_interp + T_interp[:3, 3] = t_interp + + return T_interp + + +class ConnectionManager: + def __init__(self): + self.active_connections: List[WebSocket] = [] + + async def connect(self, websocket: WebSocket): + await websocket.accept() + self.active_connections.append(websocket) + + def disconnect(self, websocket: WebSocket): + self.active_connections.remove(websocket) + + async def send_personal_message(self, message: str, websocket: WebSocket): + await websocket.send_text(message) + + async def broadcast(self, message: str): + for connection in self.active_connections: + try: + await connection.send_text(message) + except: + # Remove broken connections + self.active_connections.remove(connection) + + class Teleop: """ - Teleop class for controlling a robot remotely. + Teleop class for controlling a robot remotely using FastAPI and WebSockets. Args: host (str, optional): The host IP address. Defaults to "0.0.0.0". port (int, optional): The port number. Defaults to 4443. - ssl_context (ssl.SSLContext, optional): The SSL context for secure communication. Defaults to None. """ def __init__( self, host="0.0.0.0", port=4443, - ssl_context=None, natural_phone_orientation_euler=None, natural_phone_position=None, ): @@ -71,10 +159,8 @@ def __init__( self.__logger.setLevel(logging.INFO) self.__logger.addHandler(logging.StreamHandler()) - self.__server = None self.__host = host self.__port = port - self.__ssl_context = ssl_context self.__relative_pose_init = None self.__absolute_pose_init = None @@ -92,21 +178,12 @@ def __init__( [1, 1, 1], ) - if self.__ssl_context is None: - self.__ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS) - self.__ssl_context.load_cert_chain( - certfile=os.path.join(THIS_DIR, "cert.pem"), - keyfile=os.path.join(THIS_DIR, "key.pem"), - ) + self.__app = FastAPI() + self.__manager = ConnectionManager() - self.__app = Flask(__name__) - self.__app.config['SECRET_KEY'] = 'teleop_secret_key' - self.__socketio = SocketIO(self.__app, cors_allowed_origins="*") - - log = logging.getLogger("werkzeug") - log.setLevel(logging.ERROR) - self.__register_routes() - self.__register_socketio_events() + # Configure logging + logging.getLogger("uvicorn.access").setLevel(logging.WARNING) + self.__setup_routes() def set_pose(self, pose: np.ndarray) -> None: """ @@ -137,6 +214,7 @@ def __update(self, message): move = message["move"] position = message["position"] orientation = message["orientation"] + scale = message.get("scale", 1.0) position = np.array([position["x"], position["y"], position["z"]]) quat = np.array( @@ -163,7 +241,7 @@ def __update(self, message): if not are_close( received_pose, self.__previous_received_pose, - lin_tol=10e-2, + lin_tol=0.05, ang_tol=math.radians(35), ): self.__logger.warning("Pose jump detected, resetting the pose") @@ -178,44 +256,56 @@ def __update(self, message): self.__absolute_pose_init = self.__pose self.__previous_received_pose = None - relative_pose = np.linalg.inv(self.__relative_pose_init) @ received_pose - self.__pose = np.eye(4) - self.__pose[:3, 3] = self.__absolute_pose_init[:3, 3] + relative_pose[:3, 3] - self.__pose[:3, :3] = ( - relative_pose[:3, :3] @ self.__absolute_pose_init[:3, :3] + relative_position = received_pose[:3, 3] - self.__relative_pose_init[:3, 3] + relative_orientation = received_pose[:3, :3] @ np.linalg.inv( + self.__relative_pose_init[:3, :3] ) + self.__pose = np.eye(4) + self.__pose[:3, 3] = self.__absolute_pose_init[:3, 3] + relative_position + self.__pose[:3, :3] = relative_orientation @ self.__absolute_pose_init[:3, :3] + + # Apply scale + if scale != 1.0: + self.__pose = interpolate_transforms( + self.__absolute_pose_init, self.__pose, scale + ) # Notify the subscribers self.__notify_subscribers(self.__pose, message) - def __register_routes(self): - @self.__app.route("/") - def serve_file(filename): - self.__logger.debug(f"Serving the {filename} file") - return send_from_directory(THIS_DIR, filename) - - @self.__app.route("/") - def index(): + def __setup_routes(self): + @self.__app.get("/") + async def index(): self.__logger.debug("Serving the index.html file") - return send_from_directory(THIS_DIR, "index.html") + return FileResponse(os.path.join(THIS_DIR, "index.html")) - def __register_socketio_events(self): - @self.__socketio.on('connect') - def handle_connect(): - self.__logger.info('Client connected') - - @self.__socketio.on('disconnect') - def handle_disconnect(): - self.__logger.info('Client disconnected') - - @self.__socketio.on('pose') - def handle_pose(data): - self.__logger.debug(f"Received pose data: {data}") - self.__update(data) - - @self.__socketio.on('log') - def handle_log(data): - self.__logger.info(f"Received log message: {data}") + @self.__app.get("/{filename:path}") + async def serve_file(filename: str): + self.__logger.debug(f"Serving the {filename} file") + file_path = os.path.join(THIS_DIR, filename) + if os.path.exists(file_path): + return FileResponse(file_path) + return {"error": "File not found"} + + @self.__app.websocket("/ws") + async def websocket_endpoint(websocket: WebSocket): + await self.__manager.connect(websocket) + self.__logger.info("Client connected") + + try: + while True: + data = await websocket.receive_text() + message = json.loads(data) + + if message.get("type") == "pose": + self.__logger.debug(f"Received pose data: {message['data']}") + self.__update(message["data"]) + elif message.get("type") == "log": + self.__logger.info(f"Received log message: {message['data']}") + + except WebSocketDisconnect: + self.__manager.disconnect(websocket) + self.__logger.info("Client disconnected") def run(self) -> None: """ @@ -226,17 +316,21 @@ def run(self) -> None: f"The phone web app should be available at https://{get_local_ip()}:{self.__port}" ) - self.__server = ThreadedWSGIServer( - app=self.__app, + ssl_keyfile = os.path.join(THIS_DIR, "key.pem") + ssl_certfile = os.path.join(THIS_DIR, "cert.pem") + + uvicorn.run( + self.__app, host=self.__host, port=self.__port, - ssl_context=self.__ssl_context, + ssl_keyfile=ssl_keyfile, + ssl_certfile=ssl_certfile, + log_level="warning", ) - self.__server.serve_forever() def stop(self) -> None: """ Stops the teleop server. """ - if self.__server: - self.__server.shutdown() + # FastAPI/uvicorn handles shutdown automatically + pass diff --git a/teleop/index.html b/teleop/index.html index d8effe2..9be2849 100644 --- a/teleop/index.html +++ b/teleop/index.html @@ -5,9 +5,8 @@ Teleop + - - + +
+
+
+ +
+ +
+
+
Local Stats
+
Waiting...
+
+ +
+
Server Status
+
Waiting...
+
+
+ +
+ +
scale 1.0
+
+ +
+
A
+
B
+
+ +
+ + + +
+
+ `; + } + + setupEventListeners() { + const exitButton = this.shadowRoot.getElementById('exitButton'); + const scaleSlider = this.shadowRoot.getElementById('scaleSlider'); + const gripperButton = this.shadowRoot.getElementById('gripperButton'); + const motionButton = this.shadowRoot.getElementById('motionButton'); + const reservedButtonA = this.shadowRoot.getElementById('reservedButtonA'); + const reservedButtonB = this.shadowRoot.getElementById('reservedButtonB'); + + // Exit button + exitButton.addEventListener('click', () => { + this.dispatchEvent(new CustomEvent('exit')); + }); + + // Scale slider + scaleSlider.addEventListener('input', (event) => { + const sliderValues = [0.1, 0.25, 0.5, 1.0]; + this.sliderValue = sliderValues[event.target.value]; + this.shadowRoot.getElementById('scaleValue').textContent = `scale scale ${this.sliderValue.toFixed(2)}`; + + this.dispatchEvent(new CustomEvent('scalechange', { + detail: { scale: this.sliderValue } + })); + }); + + // Gripper button + gripperButton.addEventListener('click', () => { + this.gripperEngaged = !this.gripperEngaged; + gripperButton.textContent = `Gripper ${this.gripperEngaged ? 'Engaged' : 'Disengaged'}`; + if (this.gripperEngaged) { + gripperButton.classList.add('engaged'); + } else { + gripperButton.classList.remove('engaged'); + } + + this.dispatchEvent(new CustomEvent('gripperchange', { + detail: { engaged: this.gripperEngaged } + })); + }); + + // Motion button (touch and mouse events) + const handleMotionStart = (event) => { + event.preventDefault(); + this.motionEnabled = true; + motionButton.classList.add('active'); + motionButton.textContent = 'Moving...'; + + this.dispatchEvent(new CustomEvent('motionchange', { + detail: { enabled: true } + })); + }; + + const handleMotionEnd = () => { + if (!this.motionEnabled) return; + + this.motionEnabled = false; + motionButton.classList.remove('active'); + motionButton.textContent = 'Hold to Move'; + + this.dispatchEvent(new CustomEvent('motionchange', { + detail: { enabled: false } + })); + }; + + const handleReservedButtonStart = (event, buttonName) => { + event.preventDefault(); + const buttonId = event.currentTarget.id; + const button = this.shadowRoot.getElementById(buttonId); + button.classList.add('active'); + + const buttonNameLower = buttonName.toLowerCase(); + + if (buttonName === 'A') + this.reservedButtonAActive = true; + else if (buttonName === 'B') + this.reservedButtonBActive = true; + this.dispatchEvent(new CustomEvent(`reservedbutton${buttonNameLower}change`, { + detail: { active: true } + })); + } + + const handleReservedButtonEnd = (buttonName) => { + if (!this.reservedButtonAActive && buttonName === 'A') return; + if (!this.reservedButtonBActive && buttonName === 'B') return; + + const button = (buttonName === 'A') ? reservedButtonA : reservedButtonB; + button.classList.remove('active'); + + const buttonNameLower = buttonName.toLowerCase(); + + if (buttonName === 'A') + this.reservedButtonAActive = false; + else if (buttonName === 'B') + this.reservedButtonBActive = false; + this.dispatchEvent(new CustomEvent(`reservedbutton${buttonNameLower}change`, { + detail: { active: false } + })); + } + + motionButton.addEventListener('mousedown', handleMotionStart); + motionButton.addEventListener('touchstart', handleMotionStart); + document.addEventListener('mouseup', handleMotionEnd); + document.addEventListener('touchend', handleMotionEnd); + + reservedButtonA.addEventListener('mousedown', e => handleReservedButtonStart(e, 'A')); + reservedButtonA.addEventListener('touchstart', e => handleReservedButtonStart(e, 'A')); + document.addEventListener('mouseup', () => handleReservedButtonEnd('A')); + document.addEventListener('touchend', () => handleReservedButtonEnd('A')); + + reservedButtonB.addEventListener('mousedown', e => handleReservedButtonStart(e, 'B')); + reservedButtonB.addEventListener('touchstart', e => handleReservedButtonStart(e, 'B')); + document.addEventListener('mouseup', () => handleReservedButtonEnd('B')); + document.addEventListener('touchend', () => handleReservedButtonEnd('B')); + } + + // Public methods to update displays + updateLocalStats(stats) { + this.localStats = stats; + const statsContent = this.shadowRoot.getElementById('statsContent'); + + statsContent.textContent = ''; + if (stats.position) { + const position = stats.position; + statsContent.textContent += `Position: X: ${position.x.toFixed(3)}, Y: ${position.y.toFixed(3)}, Z: ${position.z.toFixed(3)} +`; + } + + if (stats.orientation) { + const orientation = stats.orientation; + statsContent.textContent += `Orientation: X: ${orientation.x.toFixed(2)}, Y: ${orientation.y.toFixed(2)}, Z: ${orientation.z.toFixed(2)}, W: ${orientation.w.toFixed(2)} +`; + } + + if (stats.fps) { + const fps = stats.fps.toFixed(2); + statsContent.textContent += `FPS: ${fps} +`; + } + } + + updateServerDiagnostics(data) { + this.serverDiagnostics = data; + const diagnosticsContent = this.shadowRoot.getElementById('diagnosticsContent'); + + if (typeof data === 'object') { + diagnosticsContent.textContent = JSON.stringify(data, null, 2); + } else { + diagnosticsContent.textContent = data; + } + } + + // Getters + getScale() { + return this.sliderValue; + } + + isGripperEngaged() { + return this.gripperEngaged; + } + + isMotionEnabled() { + return this.motionEnabled; + } + + isReservedButtonAActive() { + return this.reservedButtonAActive; + } + + isReservedButtonBActive() { + return this.reservedButtonBActive; + } +} + +// Register the custom element +customElements.define('teleop-ui', TeleopUI); diff --git a/teleop/test-teleop-ui.html b/teleop/test-teleop-ui.html new file mode 100644 index 0000000..07f7130 --- /dev/null +++ b/teleop/test-teleop-ui.html @@ -0,0 +1,56 @@ + + + + + + + Teleop + + + + + + + + + + + diff --git a/teleop/utils/transform_limiter.py b/teleop/utils/transform_limiter.py new file mode 100644 index 0000000..8f5e1b5 --- /dev/null +++ b/teleop/utils/transform_limiter.py @@ -0,0 +1,101 @@ +import numpy as np +from transforms3d.quaternions import mat2quat, quat2mat +from transforms3d.quaternions import axangle2quat, quat2axangle + + +def se3_to_twist(T1, T2): + """Compute linear and angular velocity from T1 to T2.""" + dT = np.linalg.inv(T1) @ T2 + # Linear velocity + v = dT[:3, 3] + + # Angular velocity + R = dT[:3, :3] + q = mat2quat(R) + angle, axis = quat2axangle(q) + w = angle * np.array(axis) + + return v, w + + +def apply_twist(T, v, w, dt=1.0): + """Apply twist (v, w) to transformation T over time dt.""" + # Translation update + t_new = T[:3, 3] + v * dt + + # Rotation update + angle = np.linalg.norm(w) + if angle < 1e-8: + R_update = np.eye(3) + else: + axis = w / angle + q_update = axangle2quat(axis, angle * dt) + R_update = quat2mat(q_update) + + R_new = R_update @ T[:3, :3] + + T_new = np.eye(4) + T_new[:3, :3] = R_new + T_new[:3, 3] = t_new + return T_new + + +def limit_magnitude(vec, max_val): + norm = np.linalg.norm(vec) + if norm > max_val: + return vec * (max_val / norm) + return vec + + +def clamp_twist(v_prev, v_desired, max_vel, max_acc): + # Acceleration + a = v_desired - v_prev + a = limit_magnitude(a, max_acc) + v_new = v_prev + a + return limit_magnitude(v_new, max_vel) + + +def compute_next_transform( + T_tm1, + T_t, + T_tp1_desired, + max_lin_vel, + max_ang_vel, + max_lin_acc, + max_ang_acc, + dt=1.0, +): + # Previous and desired velocities + v_prev, w_prev = se3_to_twist(T_tm1, T_t) + v_desired, w_desired = se3_to_twist(T_t, T_tp1_desired) + + # Clamp to acceleration and velocity limits + v_new = clamp_twist(v_prev, v_desired, max_lin_vel * dt, max_lin_acc * dt) + w_new = clamp_twist(w_prev, w_desired, max_ang_vel * dt, max_ang_acc * dt) + + # Apply twist to current pose + T_tp1 = apply_twist(T_t, v_new, w_new, dt) + + return T_tp1 + + +if __name__ == "__main__": + T_tm1 = np.eye(4) + T_t = np.array( + [[0.998, -0.05, 0, 0.5], [0.05, 0.998, 0, 0.0], [0, 0, 1, 0.0], [0, 0, 0, 1]] + ) + T_tp1_desired = np.array( + [[0.992, -0.12, 0, 1.2], [0.12, 0.992, 0, 0.3], [0, 0, 1, 0.0], [0, 0, 0, 1]] + ) + + T_tp1_clamped = compute_next_transform( + T_tm1, + T_t, + T_tp1_desired, + max_lin_vel=1.0, # m/s + max_ang_vel=np.radians(45), # rad/s + max_lin_acc=0.5, # m/s² + max_ang_acc=np.radians(30), # rad/s² + ) + + print("New transform at t+1 (clamped):\n", T_tp1_clamped) diff --git a/teleop/xarm/__main__.py b/teleop/xarm/__main__.py new file mode 100644 index 0000000..5e5efed --- /dev/null +++ b/teleop/xarm/__main__.py @@ -0,0 +1,114 @@ +import numpy as np +from teleop import Teleop +import time +import transforms3d as t3d + +try: + from xarm.wrapper import XArmAPI +except ImportError: + raise ImportError( + "xarm-python-sdk is not installed. Please install the dependency with `pip install xarm-python-sdk`." + ) + + +class Lite6Gripper: + def __init__(self, arm: XArmAPI): + self._arm = arm + self._prev_gripper_state = None + self._gripper_state = 0.0 + self._gripper_open_time = 0.0 + self._gripper_stopped = False + + def open(self): + self._arm.open_lite6_gripper() + + def close(self): + self._arm.close_lite6_gripper() + + def stop(self): + self._arm.stop_lite6_gripper() + + def set_gripper_state(self, gripper_state: float) -> None: + """Set gripper state and handle opening/closing logic.""" + self._gripper_state = gripper_state + + if self._gripper_state is not None: + if ( + self._prev_gripper_state is None + or self._prev_gripper_state != self._gripper_state + ): + if self._gripper_state < 1.0: + self._gripper_stopped = False + self.close() + else: + self._gripper_open_time = time.time() + self._gripper_stopped = False + self.open() + + if ( + not self._gripper_stopped + and self._gripper_state >= 1.0 + and time.time() - self._gripper_open_time > 1.0 + ): + # If gripper was closed and now is open, stop the gripper + self._gripper_stopped = True + self.stop() + + self._prev_gripper_state = self._gripper_state + + def get_gripper_state(self) -> float: + """Get current gripper state.""" + return self._gripper_state + + def reset_gripper(self) -> None: + """Reset gripper state.""" + self._prev_gripper_state = None + self._gripper_state = 0.0 + self._gripper_open_time = 0.0 + self._gripper_stopped = False + + +def get_pose(arm): + ok, pose = arm.get_position() + if ok != 0: + raise RuntimeError(f"Failed to get arm position: {ok}") + + translation = np.array(pose[:3]) / 1000 + eulers = np.array(pose[3:]) + rotation = t3d.euler.euler2mat(eulers[0], eulers[1], eulers[2], "sxyz") + pose = t3d.affines.compose(translation, rotation, np.ones(3)) + return pose + + +def servo(arm, pose): + x = pose[0, 3] * 1000 + y = pose[1, 3] * 1000 + z = pose[2, 3] * 1000 + roll, pitch, yaw = t3d.euler.mat2euler(pose[:3, :3]) + error = arm.set_servo_cartesian([x, y, z, roll, pitch, yaw], speed=100, mvacc=100) + return error + + +def main(): + arm = XArmAPI("192.168.1.184", is_radian=True) + arm.connect() + arm.motion_enable(enable=True) + arm.set_mode(1) + arm.set_state(state=0) + + gripper = Lite6Gripper(arm) + + def callback(pose, message): + servo(arm, pose) + + gripper_state = 1.0 if message['gripper'] == "close" else 0.0 + gripper.set_gripper_state(gripper_state) + + teleop = Teleop(natural_phone_orientation_euler=[0, 0, 0]) + teleop.set_pose(get_pose(arm)) + teleop.subscribe(callback) + teleop.run() + + +if __name__ == "__main__": + main() diff --git a/tests/test_pose_compounding.py b/tests/test_pose_compounding.py index daf8d30..e73c656 100644 --- a/tests/test_pose_compounding.py +++ b/tests/test_pose_compounding.py @@ -1,8 +1,9 @@ import unittest import threading import time -import socketio +import json import ssl +import websocket from teleop import Teleop @@ -15,7 +16,7 @@ def get_message(): } -BASE_URL = "https://localhost:4443" +BASE_URL = "ws://localhost:4443/ws" class TestPoseCompounding(unittest.TestCase): @@ -25,6 +26,9 @@ def setUpClass(cls): cls.__last_pose = None cls.__last_message = None cls.__callback_event = threading.Event() + cls.__ws = None + cls.__ws_connected = threading.Event() # Use Event for connection tracking + cls.__ws_error = None def callback(pose, message): cls.__last_pose = pose @@ -40,34 +44,64 @@ def callback(pose, message): time.sleep(3) - cls.sio = socketio.Client(ssl_verify=False, logger=False, engineio_logger=False) - - @cls.sio.event - def connect(): - print("Connected to server") - - @cls.sio.event - def disconnect(): - print("Disconnected from server") - - @cls.sio.event - def connect_error(data): - print(f"Connection error: {data}") - - max_retries = 3 - for attempt in range(max_retries): + # WebSocket event handlers + def on_message(ws, message): try: - cls.sio.connect( - BASE_URL, transports=["polling"], wait_timeout=10, retry=True - ) - break - except Exception as e: - print(f"Connection attempt {attempt + 1} failed: {e}") - if attempt == max_retries - 1: - raise - time.sleep(2) - - time.sleep(2) + data = json.loads(message) + print(f"Received from server: {data}") + except json.JSONDecodeError: + print(f"Received non-JSON message: {message}") + + def on_error(ws, error): + print(f"WebSocket error: {error}") + cls.__ws_error = error + cls.__ws_connected.clear() + + def on_close(ws, close_status_code, close_msg): + print(f"WebSocket connection closed: {close_status_code} - {close_msg}") + cls.__ws_connected.clear() + + def on_open(ws): + print("WebSocket connection opened") + cls.__ws_connected.set() # Signal that connection is established + + # Create WebSocket connection + websocket.enableTrace(True) # Enable for debugging + + # Use wss:// for HTTPS or ws:// for HTTP + url = BASE_URL.replace("ws://", "wss://") if "4443" in BASE_URL else BASE_URL + + cls.__ws = websocket.WebSocketApp( + url, + on_open=on_open, + on_message=on_message, + on_error=on_error, + on_close=on_close + ) + + # Start WebSocket in separate thread + cls.__ws_thread = threading.Thread( + target=cls.__ws.run_forever, + kwargs={ + 'sslopt': { + "cert_reqs": ssl.CERT_NONE, + "check_hostname": False + } if url.startswith("wss://") else None + } + ) + cls.__ws_thread.daemon = True + cls.__ws_thread.start() + + # Wait for connection to be established + connection_established = cls.__ws_connected.wait(timeout=10) + if not connection_established: + if cls.__ws_error: + raise Exception(f"WebSocket connection failed: {cls.__ws_error}") + else: + raise Exception("WebSocket connection timeout") + + # Give a small buffer after connection + time.sleep(0.5) def setUp(self): self.__class__.__last_pose = None @@ -77,14 +111,35 @@ def setUp(self): def _wait_for_callback(self, timeout=10.0): return self.__callback_event.wait(timeout=timeout) + def _send_message(self, payload): + """Send message to WebSocket server""" + # Check if connection is still active + if not self.__ws_connected.is_set(): + raise Exception("WebSocket connection not active") + + if not hasattr(self.__ws, 'sock') or self.__ws.sock is None: + raise Exception("WebSocket socket is None") + + message = { + "type": "pose", + "data": payload + } + + try: + self.__ws.send(json.dumps(message)) + except Exception as e: + # Connection might have been closed, try to reconnect + self.__ws_connected.clear() + raise Exception(f"Failed to send message: {e}") + def test_response(self): - if not self.sio.connected: - self.skipTest("Socket.IO client not connected") + if not self.__ws_connected.is_set(): + self.skipTest("WebSocket client not connected") payload = get_message() print(f"Sending payload: {payload}") - self.sio.emit("pose", payload) + self._send_message(payload) if not self._wait_for_callback(timeout=10.0): self.fail("Callback was not triggered within 10 seconds") @@ -94,14 +149,14 @@ def test_response(self): ) def test_single_position_update(self): - if not self.sio.connected: - self.skipTest("Socket.IO client not connected") + if not self.__ws_connected.is_set(): + self.skipTest("WebSocket client not connected") payload = get_message() print(f"Sending first payload: {payload}") payload["move"] = True - self.sio.emit("pose", payload) + self._send_message(payload) if not self._wait_for_callback(timeout=10.0): self.fail("First callback was not triggered within 10 seconds") @@ -119,7 +174,7 @@ def test_single_position_update(self): payload["move"] = True payload["position"]["y"] = 0.05 print(f"Sending second payload: {payload}") - self.sio.emit("pose", payload) + self._send_message(payload) if not self._wait_for_callback(timeout=10.0): self.fail("Second callback was not triggered within 10 seconds") @@ -131,7 +186,7 @@ def test_single_position_update(self): payload["move"] = True payload["position"]["y"] = 0.1 print(f"Sending third payload: {payload}") - self.sio.emit("pose", payload) + self._send_message(payload) if not self._wait_for_callback(timeout=10.0): self.fail("Third callback was not triggered within 10 seconds") @@ -141,10 +196,10 @@ def test_single_position_update(self): @classmethod def tearDownClass(cls): try: - if hasattr(cls, "sio") and cls.sio.connected: - cls.sio.disconnect() + if hasattr(cls, '__ws') and cls.__ws: + cls.__ws.close() except Exception as e: - print(f"Error during disconnect: {e}") + print(f"Error during cleanup: {e}") if __name__ == "__main__":