diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml deleted file mode 100644 index 00a16eb..0000000 --- a/.github/workflows/docs.yml +++ /dev/null @@ -1,42 +0,0 @@ -name: docs - -on: - push: - branches: [main] - -permissions: - contents: write - -jobs: - docs: - runs-on: ubuntu-latest - steps: - # Check out source. - - uses: actions/checkout@v2 - with: - fetch-depth: 0 # This ensures the entire history is fetched so we can switch branches - - - name: Set up Python - uses: actions/setup-python@v1 - with: - python-version: "3.12" - - - name: Set up dependencies - run: | - sudo apt update - sudo apt install -y libsuitesparse-dev - pip install uv - uv pip install --system -e ".[dev,examples]" - uv pip install --system -r docs/requirements.txt - - # Build documentation. - - name: Building documentation - run: | - sphinx-build docs/source docs/build -b dirhtml - - # Deploy to version-dependent subdirectory. - - name: Deploy to GitHub Pages - uses: peaceiris/actions-gh-pages@v4 - with: - github_token: ${{ secrets.GITHUB_TOKEN }} - publish_dir: ./docs/build diff --git a/.github/workflows/formatting.yaml b/.github/workflows/formatting.yml similarity index 100% rename from .github/workflows/formatting.yaml rename to .github/workflows/formatting.yml diff --git a/.gitignore b/.gitignore index 6ae5297..54063c7 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,4 @@ htmlcov .envrc .vite build +tmp/ \ No newline at end of file diff --git a/README.md b/README.md index 88864e5..e2789c2 100644 --- a/README.md +++ b/README.md @@ -1,32 +1,13 @@ -# `PyRoNot`: A Better Python Robot Kinematics Library +# `PyRoNot`: A Python Library for Robot Kinematics Using Spherical Approximations -## Citation +[![Format Check](https://github.com/CoMMALab/pyronot/actions/workflows/formatting.yml/badge.svg)](https://github.com/CoMMALab/pyronot/actions/workflows/formatting.yml) 
+[![Pyright](https://github.com/CoMMALab/pyronot/actions/workflows/pyright.yml/badge.svg)](https://github.com/CoMMALab/pyronot/actions/workflows/pyright.yml) +[![Pytest](https://github.com/CoMMALab/pyronot/actions/workflows/pytest.yml/badge.svg)](https://github.com/CoMMALab/pyronot/actions/workflows/pytest.yml) +[![PyPI - Version](https://img.shields.io/pypi/v/pyronot)](https://pypi.org/project/pyronot/) -This repository is based on pyroki. - -
- Chung Min Kim*, Brent Yi*, Hongsuk Choi, Yi Ma, Ken Goldberg, Angjoo Kanazawa. - PyRoki: A Modular Toolkit for Robot Kinematic Optimization - arXiV, 2025. -
- -\*Equal Contribution, UC Berkeley. - -Please cite PyRoki if you find this work useful for your research: - -``` -@inproceedings{kim2025pyroki, - title={PyRoki: A Modular Toolkit for Robot Kinematic Optimization}, - author={Kim*, Chung Min and Yi*, Brent and Choi, Hongsuk and Ma, Yi and Goldberg, Ken and Kanazawa, Angjoo}, - booktitle={2025 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)}, - year={2025}, - url={https://arxiv.org/abs/2505.03728}, -} -``` - -Thanks! +This repository is based on [pyroki](https://github.com/chungmin99/pyroki). ## Installation ``` diff --git a/benchmark/robot_coll_benchmark.py b/benchmark/robot_coll_benchmark.py new file mode 100644 index 0000000..d3a7539 --- /dev/null +++ b/benchmark/robot_coll_benchmark.py @@ -0,0 +1,195 @@ +import jax +jax.config.update('jax_platform_name', 'cpu') +print(f"================================================") +print(f"Jax platform: {jax.lib.xla_bridge.get_backend().platform}") +print(f"================================================") +import numpy as np +import pyronot as prn +import pyroki as pk + +from pyroki.collision import Sphere, RobotCollision +from pyronot.collision import Sphere as SphereNot, RobotCollisionSpherized as RobotCollisionSpherizedNot +import yourdfpy +import pinocchio as pin +import hppfcl +import time + +np.random.seed(41) + +NUM_SAMPLES = 1000 + +SPHERE_CENTERS = [ + [0.55, 0, 0.25], + [0.35, 0.35, 0.25], + [0, 0.55, 0.25], + [-0.55, 0, 0.25], + [-0.35, -0.35, 0.25], + [0, -0.55, 0.25], + [0.35, -0.35, 0.25], + [0.35, 0.35, 0.8], + [0, 0.55, 0.8], + [-0.35, 0.35, 0.8], + [-0.55, 0, 0.8], + [-0.35, -0.35, 0.8], + [0, -0.55, 0.8], + [0.35, -0.35, 0.8], + ] + +SPHERE_R = [0.2] * len(SPHERE_CENTERS) + +urdf_path = "resources/ur5/ur5_spherized.urdf" +mesh_dir = "resources/ur5/meshes" + + + +# ============ Pinocchio Ground Truth Setup ============ +def setup_pinocchio_collision(urdf_path, mesh_dir, sphere_centers, sphere_radii): + """Setup Pinocchio 
collision model with environment spheres.""" + pin_model = pin.buildModelFromUrdf(urdf_path) + pin_geom_model = pin.buildGeomFromUrdf(pin_model, urdf_path, pin.COLLISION, mesh_dir) + + num_robot_geoms = len(pin_geom_model.geometryObjects) + + # Add obstacle spheres to the geometry model + for i, (center, radius) in enumerate(zip(sphere_centers, sphere_radii)): + sphere_shape = hppfcl.Sphere(radius) + placement = pin.SE3(np.eye(3), np.array(center)) + geom_obj = pin.GeometryObject( + f"obstacle_sphere_{i}", + 0, # parent frame (universe) + 0, # parent joint (universe) + sphere_shape, + placement, + ) + pin_geom_model.addGeometryObject(geom_obj) + + # Add collision pairs: robot geometries vs obstacle spheres + num_obstacles = len(sphere_centers) + for robot_geom_id in range(num_robot_geoms): + for obs_idx in range(num_obstacles): + obs_geom_id = num_robot_geoms + obs_idx + pin_geom_model.addCollisionPair( + pin.CollisionPair(robot_geom_id, obs_geom_id) + ) + + pin_data = pin_model.createData() + pin_geom_data = pin.GeometryData(pin_geom_model) + + return pin_model, pin_data, pin_geom_model, pin_geom_data + + +def check_collision_pinocchio(pin_model, pin_data, pin_geom_model, pin_geom_data, q): + """Check if configuration q is in collision using Pinocchio (ground truth).""" + pin.updateGeometryPlacements(pin_model, pin_data, pin_geom_model, pin_geom_data, q) + return pin.computeCollisions(pin_geom_model, pin_geom_data, True) + + +# Setup Pinocchio collision checker +pin_model, pin_data, pin_geom_model, pin_geom_data = setup_pinocchio_collision( + urdf_path, mesh_dir, SPHERE_CENTERS, SPHERE_R +) + +world_coll_not = SphereNot.from_center_and_radius(SPHERE_CENTERS, SPHERE_R) +world_coll = Sphere.from_center_and_radius(SPHERE_CENTERS, SPHERE_R) + +urdf = yourdfpy.URDF.load(urdf_path, mesh_dir=mesh_dir) +robot_not = prn.Robot.from_urdf(urdf) +robot = pk.Robot.from_urdf(urdf) +robot_coll = RobotCollision.from_urdf(urdf) +robot_coll_not = 
RobotCollisionSpherizedNot.from_urdf(urdf) + +def generate_dataset(num_samples): + q_batch = [] + robot_lower_limits = robot.joints.lower_limits + robot_upper_limits = robot.joints.upper_limits + for _ in range(num_samples): + q = np.random.uniform(robot_lower_limits, robot_upper_limits) + q_batch.append(q) + return np.array(q_batch) + +q_batch = generate_dataset(NUM_SAMPLES) +print(f"Generated {q_batch.shape[0]} samples") + +# ============ Generate Ground Truth with Pinocchio ============ +print("\n=== Generating Pinocchio Ground Truth ===") +start_time = time.time() +ground_truth = [] +for q in q_batch: + collision = check_collision_pinocchio(pin_model, pin_data, pin_geom_model, pin_geom_data, q) + ground_truth.append(collision) +ground_truth = np.array(ground_truth) +end_time = time.time() +time_taken_ms = (end_time - start_time) * 1000 +print(f"Time taken for Pinocchio ground truth (ms): {time_taken_ms:.2f}") +print(f"Time per collision check (ms): {time_taken_ms/NUM_SAMPLES:.4f}") +print(f"Collision rate: {ground_truth.sum()}/{NUM_SAMPLES} ({100*ground_truth.mean():.1f}%)") + +# ============ Benchmark pyronot collision methods ============ +print("\n=== Benchmarking pyronot Collision Methods ===") +# Warmup for JIT +q = q_batch[0] +robot_coll.at_config(robot, q) +jax.block_until_ready(robot_coll.compute_world_collision_distance(robot, q, world_coll)) +print(f"Type of robot_coll.compute_world_collision_distance: {type(robot_coll.compute_world_collision_distance)}") +try: + print(f"Capsule cache: {robot_coll.compute_world_collision_distance._cache_size()}") +except AttributeError: + print("Capsule method is NOT JIT compiled") + +robot_coll_not.at_config(robot_not, q) +jax.block_until_ready(robot_coll_not.compute_world_collision_distance(robot_not, q, world_coll_not)) +print(f"Type of robot_coll_not.compute_world_collision_distance: {type(robot_coll_not.compute_world_collision_distance)}") +try: + print(f"Sphere cache: 
{robot_coll_not.compute_world_collision_distance._cache_size()}") +except AttributeError: + print("Sphere method is NOT JIT compiled") +# End of warmup + +with jax.profiler.trace("./tmp/sphere_trace"): + q = q_batch[0] + result = robot_coll_not.compute_world_collision_distance(robot_not, q, world_coll_not) + jax.block_until_ready(result) + +print(f"=== Result ===") +print(f"Time taken for pinocchio for a signle collision check (ms): \t{time_taken_ms/NUM_SAMPLES:.4f}") +start_time = time.time() +for q in q_batch: + # robot_coll_not.at_config(robot_not, q) + robot_coll_not.compute_world_collision_distance(robot_not, q, world_coll_not) +end_time = time.time() +time_taken_ms = (end_time - start_time) * 1000 +print(f"Time taken for sphere for single collision check (ms): \t\t{time_taken_ms/NUM_SAMPLES:.4f}") + +start_time = time.time() +for q in q_batch: + # robot_coll.at_config(robot, q) + robot_coll.compute_world_collision_distance(robot, q, world_coll) +end_time = time.time() +time_taken_ms = (end_time - start_time) * 1000 +print(f"Time taken for capsule for single collision check (ms): \t{time_taken_ms/NUM_SAMPLES:.4f}") + +# ============ Compare with Ground Truth ============ +print("\n=== Accuracy Comparison with Ground Truth ===") + +# Check sphere model accuracy +sphere_predictions = [] +for q in q_batch: + dist = robot_coll_not.compute_world_collision_distance(robot_not, q, world_coll_not) + # Collision if any distance is negative + in_collision = (np.array(dist) < 0).any() + sphere_predictions.append(in_collision) +sphere_predictions = np.array(sphere_predictions) +sphere_accuracy = (sphere_predictions == ground_truth).mean() +print(f"Sphere model accuracy: {100*sphere_accuracy:.2f}%") + +# Check capsule model accuracy +capsule_predictions = [] +for q in q_batch: + dist = robot_coll.compute_world_collision_distance(robot, q, world_coll) + # Collision if any distance is negative + in_collision = (np.array(dist) < 0).any() + 
capsule_predictions.append(in_collision) +capsule_predictions = np.array(capsule_predictions) +capsule_accuracy = (capsule_predictions == ground_truth).mean() +print(f"Capsule model accuracy: {100*capsule_accuracy:.2f}%") + diff --git a/docs/Makefile b/docs/Makefile deleted file mode 100644 index aaecadc..0000000 --- a/docs/Makefile +++ /dev/null @@ -1,20 +0,0 @@ -# Minimal makefile for Sphinx documentation -# - -# You can set these variables from the command line. -SPHINXOPTS = -SPHINXBUILD = sphinx-build -SPHINXPROJ = viser -SOURCEDIR = source -BUILDDIR = ./build - -# Put it first so that "make" without argument is like "make help". -help: - @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) - -.PHONY: help Makefile - -# Catch-all target: route all unknown targets to Sphinx using the new -# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). -%: Makefile - @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/README.md b/docs/README.md deleted file mode 100644 index f039c3e..0000000 --- a/docs/README.md +++ /dev/null @@ -1,61 +0,0 @@ -# Viser Documentation - -This directory contains the documentation for Viser. - -## Building the Documentation - -To build the documentation: - -1. Install the documentation dependencies: - - ```bash - pip install -r docs/requirements.txt - ``` - -2. Build the documentation: - - ```bash - cd docs - make html - ``` - -3. View the documentation: - - ```bash - # On macOS - open build/html/index.html - - # On Linux - xdg-open build/html/index.html - ``` - -## Contributing Screenshots - -When adding new documentation, screenshots and visual examples significantly improve user understanding. 
- -We need screenshots for: - -- The Getting Started guide -- GUI element examples -- Scene API visualization examples -- Customization/theming examples - -See [Contributing Visuals](./source/contributing_visuals.md) for guidelines on capturing and adding images to the documentation. - -## Documentation Structure - -- `source/` - Source files for the documentation - - `_static/` - Static files (CSS, images, etc.) - - `images/` - Screenshots and other images - - `examples/` - Example code with documentation - - `*.md` - Markdown files for documentation pages - - `conf.py` - Sphinx configuration - -## Auto-Generated Example Documentation - -Example documentation is automatically generated from the examples in the `examples/` directory using the `update_example_docs.py` script. To update the example documentation after making changes to examples: - -```bash -cd docs -python update_example_docs.py -``` diff --git a/docs/requirements.txt b/docs/requirements.txt deleted file mode 100644 index 74185c7..0000000 --- a/docs/requirements.txt +++ /dev/null @@ -1,10 +0,0 @@ -sphinx==8.0.2 -furo==2024.8.6 -docutils==0.20.1 -toml==0.10.2 -sphinxcontrib-video==0.4.1 -git+https://github.com/brentyi/sphinxcontrib-programoutput.git -git+https://github.com/brentyi/ansi.git -git+https://github.com/sphinx-contrib/googleanalytics.git - -snowballstemmer==2.2.0 # https://github.com/snowballstem/snowball/issues/229 diff --git a/docs/source/_static/basic_ik.mov b/docs/source/_static/basic_ik.mov deleted file mode 100644 index 44c3bb0..0000000 Binary files a/docs/source/_static/basic_ik.mov and /dev/null differ diff --git a/docs/source/_static/css/custom.css b/docs/source/_static/css/custom.css deleted file mode 100644 index 3472b02..0000000 --- a/docs/source/_static/css/custom.css +++ /dev/null @@ -1,4 +0,0 @@ -img.sidebar-logo { - width: 10em; - margin: 1em 0 0 0; -} diff --git a/docs/source/_static/logo.svg b/docs/source/_static/logo.svg deleted file mode 100644 index 68c28ee..0000000 --- 
a/docs/source/_static/logo.svg +++ /dev/null @@ -1,34 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/docs/source/_static/logo_dark.svg b/docs/source/_static/logo_dark.svg deleted file mode 100644 index 2d42e27..0000000 --- a/docs/source/_static/logo_dark.svg +++ /dev/null @@ -1,40 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/docs/source/_templates/sidebar/brand.html b/docs/source/_templates/sidebar/brand.html deleted file mode 100644 index 88151c0..0000000 --- a/docs/source/_templates/sidebar/brand.html +++ /dev/null @@ -1,41 +0,0 @@ - - {%- endif %} {%- if theme_light_logo and theme_dark_logo %} - - {%- endif %} - - {% endblock brand_content %} - - -
- - - Github - -
\ No newline at end of file diff --git a/docs/source/conf.py b/docs/source/conf.py deleted file mode 100644 index a76de8f..0000000 --- a/docs/source/conf.py +++ /dev/null @@ -1,266 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Configuration file for the Sphinx documentation builder. -# -# This file does only contain a selection of the most common options. For a -# full list see the documentation: -# http://www.sphinx-doc.org/en/stable/config - -import os -from typing import Dict, List - -import pyronot - -# -- Path setup -------------------------------------------------------------- - -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. -# - -# -- Project information ----------------------------------------------------- - -project = "pyronot" # Change project name -copyright = "2025" # Update copyright year/holder if needed -author = "chungmin99" # Update author name - -version: str = os.environ.get( - "PYRONOT_VERSION_STR_OVERRIDE", pyronot.__version__ -) # Remove this - -# Formatting! -# 0.1.30 => v0.1.30 -# dev => dev -if not version.isalpha(): - version = "v" + version - -# The full version, including alpha/beta/rc tags -release = version # Use the same version for release for now - - -# -- General configuration --------------------------------------------------- - -# If your documentation needs a minimal Sphinx version, state it here. -# -# needs_sphinx = '1.0' - -# Add any Sphinx extension module names here, as strings. They can be -# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom -# ones. 
-extensions = [ - "sphinx.ext.autodoc", - "sphinx.ext.todo", - "sphinx.ext.coverage", - "sphinx.ext.mathjax", - "sphinx.ext.githubpages", - "sphinx.ext.napoleon", - # "sphinx.ext.inheritance_diagram", - "sphinxcontrib.video", - "sphinx.ext.viewcode", - "sphinxcontrib.programoutput", - "sphinxcontrib.ansi", - # "sphinxcontrib.googleanalytics", # google analytics extension https://github.com/sphinx-contrib/googleanalytics/tree/master -] -programoutput_use_ansi = True -html_ansi_stylesheet = "black-on-white.css" -html_static_path = ["_static"] -html_theme_options = { - "light_css_variables": { - "color-code-background": "#f4f4f4", - "color-code-foreground": "#000", - }, - # Remove viser-specific footer icon - "footer_icons": [ - { - "name": "GitHub", - "url": "https://github.com/chungmin99/pyronot-dev", - "html": """ - - - - """, - "class": "", - }, - ], - # Remove viser-specific logos - "light_logo": "logo.svg", - "dark_logo": "logo_dark.svg", -} - -# Pull documentation types from hints -autodoc_typehints = "both" -autodoc_class_signature = "separated" -autodoc_default_options = { - "members": True, - "member-order": "bysource", - "undoc-members": True, - "inherited-members": True, - "exclude-members": "__init__, __post_init__", - "imported-members": True, -} - -# Add any paths that contain templates here, relative to this directory. -templates_path = ["_templates"] - -# The suffix(es) of source filenames. -# You can specify multiple suffix as a list of string: -# -source_suffix = ".rst" -# source_suffix = ".rst" - -# The master toctree document. -master_doc = "index" - -# The language for content autogenerated by Sphinx. Refer to documentation -# for a list of supported languages. -# -# This is also used if you do content translation via gettext catalogs. -# Usually you set "language" from the command line for these cases. 
-language: str = "en" - -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files. -# This pattern also affects html_static_path and html_extra_path . -exclude_patterns: List[str] = [] - -# The name of the Pygments (syntax highlighting) style to use. -pygments_style = "default" - - -# -- Options for HTML output ------------------------------------------------- - -# The theme to use for HTML and HTML Help pages. See the documentation for -# a list of builtin themes. -# -html_theme = "furo" -html_title = "pyronot" # Update title - - -# Theme options are theme-specific and customize the look and feel of a theme -# further. For a list of options available for each theme, see the -# documentation. -# -# html_theme_options = {} - -# Add any paths that contain custom static files (such as style sheets) here, -# relative to this directory. They are copied after the builtin static files, -# so a file named "default.css" will overwrite the builtin "default.css". -# html_static_path = ["_static"] - -# Custom sidebar templates, must be a dictionary that maps document names -# to template names. -# -# The default sidebars (for documents that don't match any pattern) are -# defined by theme itself. Builtin themes are using these templates by -# default: ``['localtoc.html', 'relations.html', 'sourcelink.html', -# 'searchbox.html']``. -# -# html_sidebars = {} - - -# -- Options for HTMLHelp output --------------------------------------------- - -# Output file base name for HTML help builder. -htmlhelp_basename = "pyronot_doc" # Update basename - - -# -- Options for Github output ------------------------------------------------ - -sphinx_to_github = True -sphinx_to_github_verbose = True -sphinx_to_github_encoding = "utf-8" - - -# -- Options for LaTeX output ------------------------------------------------ - -latex_elements: Dict[str, str] = { - # The paper size ('letterpaper' or 'a4paper'). 
- # - # 'papersize': 'letterpaper', - # The font size ('10pt', '11pt' or '12pt'). - # - # 'pointsize': '10pt', - # Additional stuff for the LaTeX preamble. - # - # 'preamble': '', - # Latex figure (float) alignment - # - # 'figure_align': 'htbp', -} - -# Grouping the document tree into LaTeX files. List of tuples -# (source start file, target name, title, -# author, documentclass [howto, manual, or own class]). -latex_documents = [ - ( - master_doc, - "pyronot.tex", # Update target name - "pyronot", # Update title - "Your Name", # Update author - "manual", - ), -] - - -# -- Options for manual page output ------------------------------------------ - -# One entry per manual page. List of tuples -# (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, "pyronot", "pyronot documentation", [author], 1) -] # Update name and description - - -# -- Options for Texinfo output ---------------------------------------------- - -# Grouping the document tree into Texinfo files. List of tuples -# (source start file, target name, title, author, -# dir menu entry, description, category) -texinfo_documents = [ - ( - master_doc, - "pyronot", # Update target name - "pyronot", # Update title - author, - "pyronot", # Update dir menu entry - "Python Robot Kinematics library", # Update description - "Miscellaneous", - ), -] - - -# -- Extension configuration -------------------------------------------------- - -# Google Analytics ID -# googleanalytics_id = "G-RRGY51J5ZH" # Remove this - -# -- Options for todo extension ---------------------------------------------- - -# If true, `todo` and `todoList` produce output, else they produce nothing. -todo_include_todos = True - -# -- Setup function ---------------------------------------- - - -def setup(app): - app.add_css_file("css/custom.css") - - -# -- Napoleon settings ------------------------------------------------------- - -# Settings for parsing non-sphinx style docstrings. 
We use Google style in this -# project. -napoleon_google_docstring = True -napoleon_numpy_docstring = False -napoleon_include_init_with_doc = False -napoleon_include_private_with_doc = False -napoleon_include_special_with_doc = True -napoleon_use_admonition_for_examples = False -napoleon_use_admonition_for_notes = False -napoleon_use_admonition_for_references = False -napoleon_use_ivar = False -napoleon_use_param = True -napoleon_use_rtype = True -napoleon_preprocess_types = True -napoleon_type_aliases = None -napoleon_attr_annotations = True diff --git a/docs/source/examples/01_basic_ik.rst b/docs/source/examples/01_basic_ik.rst deleted file mode 100644 index 660f6c9..0000000 --- a/docs/source/examples/01_basic_ik.rst +++ /dev/null @@ -1,68 +0,0 @@ -.. Comment: this file is automatically generated by `update_example_docs.py`. - It should not be modified manually. - -Basic IK -========================================== - - -Simplest Inverse Kinematics Example using PyRoNot. - -All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details. - - - -.. code-block:: python - :linenos: - - - import time - - import numpy as np - import pyronot as pk - import viser - from robot_descriptions.loaders.yourdfpy import load_robot_description - from viser.extras import ViserUrdf - - import pyronot_snippets as pks - - - def main(): - """Main function for basic IK.""" - - urdf = load_robot_description("panda_description") - target_link_name = "panda_hand" - - # Create robot. - robot = pk.Robot.from_urdf(urdf) - - # Set up visualizer. - server = viser.ViserServer() - server.scene.add_grid("/ground", width=2, height=2) - urdf_vis = ViserUrdf(server, urdf, root_node_name="/base") - - # Create interactive controller with initial position. 
- ik_target = server.scene.add_transform_controls( - "/ik_target", scale=0.2, position=(0.61, 0.0, 0.56), wxyz=(0, 0, 1, 0) - ) - timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True) - - while True: - # Solve IK. - start_time = time.time() - solution = pks.solve_ik( - robot=robot, - target_link_name=target_link_name, - target_position=np.array(ik_target.position), - target_wxyz=np.array(ik_target.wxyz), - ) - - # Update timing handle. - elapsed_time = time.time() - start_time - timing_handle.value = 0.99 * timing_handle.value + 0.01 * (elapsed_time * 1000) - - # Update visualizer. - urdf_vis.update_cfg(solution) - - - if __name__ == "__main__": - main() diff --git a/docs/source/examples/02_bimanual_ik.rst b/docs/source/examples/02_bimanual_ik.rst deleted file mode 100644 index 9bdc034..0000000 --- a/docs/source/examples/02_bimanual_ik.rst +++ /dev/null @@ -1,70 +0,0 @@ -.. Comment: this file is automatically generated by `update_example_docs.py`. - It should not be modified manually. - -Bimanual IK -========================================== - - -Same as 01_basic_ik.py, but with two end effectors! - -All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details. - - - -.. code-block:: python - :linenos: - - - import time - import viser - from robot_descriptions.loaders.yourdfpy import load_robot_description - import numpy as np - - import pyronot as pk - from viser.extras import ViserUrdf - import pyronot_snippets as pks - - - def main(): - """Main function for bimanual IK.""" - - urdf = load_robot_description("yumi_description") - target_link_names = ["yumi_link_7_r", "yumi_link_7_l"] - - # Create robot. - robot = pk.Robot.from_urdf(urdf) - - # Set up visualizer. - server = viser.ViserServer() - server.scene.add_grid("/ground", width=2, height=2) - urdf_vis = ViserUrdf(server, urdf, root_node_name="/base") - - # Create interactive controller with initial position. 
- ik_target_0 = server.scene.add_transform_controls( - "/ik_target_0", scale=0.2, position=(0.41, -0.3, 0.56), wxyz=(0, 0, 1, 0) - ) - ik_target_1 = server.scene.add_transform_controls( - "/ik_target_1", scale=0.2, position=(0.41, 0.3, 0.56), wxyz=(0, 0, 1, 0) - ) - timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True) - - while True: - # Solve IK. - start_time = time.time() - solution = pks.solve_ik_with_multiple_targets( - robot=robot, - target_link_names=target_link_names, - target_positions=np.array([ik_target_0.position, ik_target_1.position]), - target_wxyzs=np.array([ik_target_0.wxyz, ik_target_1.wxyz]), - ) - - # Update timing handle. - elapsed_time = time.time() - start_time - timing_handle.value = 0.99 * timing_handle.value + 0.01 * (elapsed_time * 1000) - - # Update visualizer. - urdf_vis.update_cfg(solution) - - - if __name__ == "__main__": - main() diff --git a/docs/source/examples/03_mobile_ik.rst b/docs/source/examples/03_mobile_ik.rst deleted file mode 100644 index 128bda2..0000000 --- a/docs/source/examples/03_mobile_ik.rst +++ /dev/null @@ -1,79 +0,0 @@ -.. Comment: this file is automatically generated by `update_example_docs.py`. - It should not be modified manually. - -Mobile IK -========================================== - - -Same as 01_basic_ik.py, but with a mobile base! - -All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details. - - - -.. code-block:: python - :linenos: - - - import time - import viser - from robot_descriptions.loaders.yourdfpy import load_robot_description - import numpy as np - - import pyronot as pk - from viser.extras import ViserUrdf - import pyronot_snippets as pks - - - def main(): - """Main function for IK with a mobile base. - The base is fixed along the xy plane, and is biased towards being at the origin. - """ - - urdf = load_robot_description("fetch_description") - target_link_name = "gripper_link" - - # Create robot. 
- robot = pk.Robot.from_urdf(urdf) - - # Set up visualizer. - server = viser.ViserServer() - server.scene.add_grid("/ground", width=2, height=2) - base_frame = server.scene.add_frame("/base", show_axes=False) - urdf_vis = ViserUrdf(server, urdf, root_node_name="/base") - - # Create interactive controller with initial position. - ik_target = server.scene.add_transform_controls( - "/ik_target", scale=0.2, position=(0.61, 0.0, 0.56), wxyz=(0, 0.707, 0, -0.707) - ) - timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True) - - cfg = np.array(robot.joint_var_cls(0).default_factory()) - - while True: - # Solve IK. - start_time = time.time() - base_pos, base_wxyz, cfg = pks.solve_ik_with_base( - robot=robot, - target_link_name=target_link_name, - target_position=np.array(ik_target.position), - target_wxyz=np.array(ik_target.wxyz), - fix_base_position=(False, False, True), # Only free along xy plane. - fix_base_orientation=(True, True, False), # Free along z-axis rotation. - prev_pos=base_frame.position, - prev_wxyz=base_frame.wxyz, - prev_cfg=cfg, - ) - - # Update timing handle. - elapsed_time = time.time() - start_time - timing_handle.value = 0.99 * timing_handle.value + 0.01 * (elapsed_time * 1000) - - # Update visualizer. - urdf_vis.update_cfg(cfg) - base_frame.position = np.array(base_pos) - base_frame.wxyz = np.array(base_wxyz) - - - if __name__ == "__main__": - main() diff --git a/docs/source/examples/04_ik_with_coll.rst b/docs/source/examples/04_ik_with_coll.rst deleted file mode 100644 index f879355..0000000 --- a/docs/source/examples/04_ik_with_coll.rst +++ /dev/null @@ -1,88 +0,0 @@ -.. Comment: this file is automatically generated by `update_example_docs.py`. - It should not be modified manually. - -IK with Collision -========================================== - - -Basic Inverse Kinematics with Collision Avoidance using PyRoNot. 
- -All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details. - - - -.. code-block:: python - :linenos: - - - import time - - import numpy as np - import pyronot as pk - import viser - from pyronot.collision import HalfSpace, RobotCollision, Sphere - from robot_descriptions.loaders.yourdfpy import load_robot_description - from viser.extras import ViserUrdf - - import pyronot_snippets as pks - - - def main(): - """Main function for basic IK with collision.""" - urdf = load_robot_description("panda_description") - target_link_name = "panda_hand" - robot = pk.Robot.from_urdf(urdf) - - robot_coll = RobotCollision.from_urdf(urdf) - plane_coll = HalfSpace.from_point_and_normal( - np.array([0.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0]) - ) - sphere_coll = Sphere.from_center_and_radius( - np.array([0.0, 0.0, 0.0]), np.array([0.05]) - ) - - # Set up visualizer. - server = viser.ViserServer() - server.scene.add_grid("/ground", width=2, height=2, cell_size=0.1) - urdf_vis = ViserUrdf(server, urdf, root_node_name="/robot") - - # Create interactive controller for IK target. - ik_target_handle = server.scene.add_transform_controls( - "/ik_target", scale=0.2, position=(0.5, 0.0, 0.5), wxyz=(0, 0, 1, 0) - ) - - # Create interactive controller and mesh for the sphere obstacle. 
- sphere_handle = server.scene.add_transform_controls( - "/obstacle", scale=0.2, position=(0.4, 0.3, 0.4) - ) - server.scene.add_mesh_trimesh("/obstacle/mesh", mesh=sphere_coll.to_trimesh()) - - timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True) - - while True: - start_time = time.time() - - sphere_coll_world_current = sphere_coll.transform_from_wxyz_position( - wxyz=np.array(sphere_handle.wxyz), - position=np.array(sphere_handle.position), - ) - - world_coll_list = [plane_coll, sphere_coll_world_current] - solution = pks.solve_ik_with_collision( - robot=robot, - coll=robot_coll, - world_coll_list=world_coll_list, - target_link_name=target_link_name, - target_position=np.array(ik_target_handle.position), - target_wxyz=np.array(ik_target_handle.wxyz), - ) - - # Update timing handle. - timing_handle.value = (time.time() - start_time) * 1000 - - # Update visualizer. - urdf_vis.update_cfg(solution) - - - if __name__ == "__main__": - main() diff --git a/docs/source/examples/05_ik_with_manipulability.rst b/docs/source/examples/05_ik_with_manipulability.rst deleted file mode 100644 index f8bd331..0000000 --- a/docs/source/examples/05_ik_with_manipulability.rst +++ /dev/null @@ -1,81 +0,0 @@ -.. Comment: this file is automatically generated by `update_example_docs.py`. - It should not be modified manually. - -IK with Manipulability -========================================== - - -Inverse Kinematics with Manipulability using PyRoNot. - -All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details. - - - -.. 
code-block:: python - :linenos: - - - import time - import viser - from robot_descriptions.loaders.yourdfpy import load_robot_description - import numpy as np - - import pyronot as pk - from viser.extras import ViserUrdf - import pyronot_snippets as pks - - - def main(): - """Main function for basic IK.""" - - urdf = load_robot_description("panda_description") - target_link_name = "panda_hand" - - # Create robot. - robot = pk.Robot.from_urdf(urdf) - - # Set up visualizer. - server = viser.ViserServer() - server.scene.add_grid("/ground", width=2, height=2) - urdf_vis = ViserUrdf(server, urdf, root_node_name="/base") - - # Create interactive controller with initial position. - ik_target = server.scene.add_transform_controls( - "/ik_target", scale=0.2, position=(0.61, 0.0, 0.56), wxyz=(0, 0, 1, 0) - ) - timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True) - value_handle = server.gui.add_number("Yoshikawa Index", 0.001, disabled=True) - weight_handle = server.gui.add_slider( - "Manipulability Weight", 0.0, 10.0, 0.001, 0.0 - ) - manip_ellipse = pk.viewer.ManipulabilityEllipse( - server, - robot, - root_node_name="/manipulability", - target_link_name=target_link_name, - ) - - while True: - # Solve IK. - start_time = time.time() - solution = pks.solve_ik_with_manipulability( - robot=robot, - target_link_name=target_link_name, - target_position=np.array(ik_target.position), - target_wxyz=np.array(ik_target.wxyz), - manipulability_weight=weight_handle.value, - ) - - manip_ellipse.update(solution) - value_handle.value = manip_ellipse.manipulability - - # Update timing handle. - elapsed_time = time.time() - start_time - timing_handle.value = 0.99 * timing_handle.value + 0.01 * (elapsed_time * 1000) - - # Update visualizer. 
- urdf_vis.update_cfg(solution) - - - if __name__ == "__main__": - main() diff --git a/docs/source/examples/06_online_planning.rst b/docs/source/examples/06_online_planning.rst deleted file mode 100644 index 9177048..0000000 --- a/docs/source/examples/06_online_planning.rst +++ /dev/null @@ -1,119 +0,0 @@ -.. Comment: this file is automatically generated by `update_example_docs.py`. - It should not be modified manually. - -Online Planning -========================================== - - -Run online planning in collision aware environments. - -All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details. - - - -.. code-block:: python - :linenos: - - - import time - - import numpy as np - import pyronot as pk - import viser - from pyronot.collision import HalfSpace, RobotCollision, Sphere - from robot_descriptions.loaders.yourdfpy import load_robot_description - from viser.extras import ViserUrdf - - import pyronot_snippets as pks - - - def main(): - """Main function for online planning with collision.""" - urdf = load_robot_description("panda_description") - target_link_name = "panda_hand" - robot = pk.Robot.from_urdf(urdf) - - robot_coll = RobotCollision.from_urdf(urdf) - plane_coll = HalfSpace.from_point_and_normal( - np.array([0.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0]) - ) - sphere_coll = Sphere.from_center_and_radius( - np.array([0.0, 0.0, 0.0]), np.array([0.05]) - ) - - # Define the online planning parameters. - len_traj, dt = 5, 0.1 - - # Set up visualizer. - server = viser.ViserServer() - server.scene.add_grid("/ground", width=2, height=2, cell_size=0.1) - urdf_vis = ViserUrdf(server, urdf, root_node_name="/robot") - - # Create interactive controller for IK target. - ik_target_handle = server.scene.add_transform_controls( - "/ik_target", scale=0.2, position=(0.3, 0.0, 0.5), wxyz=(0, 0, 1, 0) - ) - - # Create interactive controller and mesh for the sphere obstacle. 
- sphere_handle = server.scene.add_transform_controls( - "/obstacle", scale=0.2, position=(0.4, 0.3, 0.4) - ) - server.scene.add_mesh_trimesh("/obstacle/mesh", mesh=sphere_coll.to_trimesh()) - target_frame_handle = server.scene.add_batched_axes( - "target_frame", - axes_length=0.05, - axes_radius=0.005, - batched_positions=np.zeros((25, 3)), - batched_wxyzs=np.array([[1.0, 0.0, 0.0, 0.0]] * 25), - ) - - timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True) - - sol_pos, sol_wxyz = None, None - sol_traj = np.array( - robot.joint_var_cls.default_factory()[None].repeat(len_traj, axis=0) - ) - while True: - start_time = time.time() - - sphere_coll_world_current = sphere_coll.transform_from_wxyz_position( - wxyz=np.array(sphere_handle.wxyz), - position=np.array(sphere_handle.position), - ) - - world_coll_list = [plane_coll, sphere_coll_world_current] - sol_traj, sol_pos, sol_wxyz = pks.solve_online_planning( - robot=robot, - robot_coll=robot_coll, - world_coll=world_coll_list, - target_link_name=target_link_name, - target_position=np.array(ik_target_handle.position), - target_wxyz=np.array(ik_target_handle.wxyz), - timesteps=len_traj, - dt=dt, - start_cfg=sol_traj[0], - prev_sols=sol_traj, - ) - - # Update timing handle. - timing_handle.value = ( - 0.99 * timing_handle.value + 0.01 * (time.time() - start_time) * 1000 - ) - - # Update visualizer. - urdf_vis.update_cfg( - sol_traj[0] - ) # The first step of the online trajectory solution. - - # Update the planned trajectory visualization. - if hasattr(target_frame_handle, "batched_positions"): - target_frame_handle.batched_positions = np.array(sol_pos) # type: ignore[attr-defined] - target_frame_handle.batched_wxyzs = np.array(sol_wxyz) # type: ignore[attr-defined] - else: - # This is an older version of Viser. 
- target_frame_handle.positions_batched = np.array(sol_pos) # type: ignore[attr-defined] - target_frame_handle.wxyzs_batched = np.array(sol_wxyz) # type: ignore[attr-defined] - - - if __name__ == "__main__": - main() diff --git a/docs/source/examples/07_trajopt.rst b/docs/source/examples/07_trajopt.rst deleted file mode 100644 index 99e4348..0000000 --- a/docs/source/examples/07_trajopt.rst +++ /dev/null @@ -1,138 +0,0 @@ -.. Comment: this file is automatically generated by `update_example_docs.py`. - It should not be modified manually. - -Trajectory Optimization -========================================== - - -Basic Trajectory Optimization using PyRoNot. - -Robot going over a wall, while avoiding world-collisions. - -All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details. - - - -.. code-block:: python - :linenos: - - - import time - from typing import Literal - - import numpy as np - import pyronot as pk - import trimesh - import tyro - import viser - from viser.extras import ViserUrdf - from robot_descriptions.loaders.yourdfpy import load_robot_description - - import pyronot_snippets as pks - - - def main(robot_name: Literal["ur5", "panda"] = "panda"): - if robot_name == "ur5": - urdf = load_robot_description("ur5_description") - down_wxyz = np.array([0.707, 0, 0.707, 0]) - target_link_name = "ee_link" - - # For UR5 it's important to initialize the robot in a safe configuration; - # the zero-configuration puts the robot aligned with the wall obstacle. - default_cfg = np.zeros(6) - default_cfg[1] = -1.308 - robot = pk.Robot.from_urdf(urdf, default_joint_cfg=default_cfg) - - elif robot_name == "panda": - urdf = load_robot_description("panda_description") - target_link_name = "panda_hand" - down_wxyz = np.array([0, 0, 1, 0]) # for panda! 
- robot = pk.Robot.from_urdf(urdf) - - else: - raise ValueError(f"Invalid robot: {robot_name}") - - robot_coll = pk.collision.RobotCollision.from_urdf(urdf) - - # Define the trajectory problem: - # - number of timesteps, timestep size - timesteps, dt = 25, 0.02 - # - the start and end poses. - start_pos, end_pos = np.array([0.5, -0.3, 0.2]), np.array([0.5, 0.3, 0.2]) - - # Define the obstacles: - # - Ground - ground_coll = pk.collision.HalfSpace.from_point_and_normal( - np.array([0.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0]) - ) - # - Wall - wall_height = 0.4 - wall_width = 0.1 - wall_length = 0.4 - wall_intervals = np.arange(start=0.3, stop=wall_length + 0.3, step=0.05) - translation = np.concatenate( - [ - wall_intervals.reshape(-1, 1), - np.full((wall_intervals.shape[0], 1), 0.0), - np.full((wall_intervals.shape[0], 1), wall_height / 2), - ], - axis=1, - ) - wall_coll = pk.collision.Capsule.from_radius_height( - position=translation, - radius=np.full((translation.shape[0], 1), wall_width / 2), - height=np.full((translation.shape[0], 1), wall_height), - ) - world_coll = [ground_coll, wall_coll] - - traj = pks.solve_trajopt( - robot, - robot_coll, - world_coll, - target_link_name, - start_pos, - down_wxyz, - end_pos, - down_wxyz, - timesteps, - dt, - ) - traj = np.array(traj) - - # Visualize! 
- server = viser.ViserServer() - urdf_vis = ViserUrdf(server, urdf) - server.scene.add_grid("/grid", width=2, height=2, cell_size=0.1) - server.scene.add_mesh_trimesh( - "wall_box", - trimesh.creation.box( - extents=(wall_length, wall_width, wall_height), - transform=trimesh.transformations.translation_matrix( - np.array([0.5, 0.0, wall_height / 2]) - ), - ), - ) - for name, pos in zip(["start", "end"], [start_pos, end_pos]): - server.scene.add_frame( - f"/{name}", - position=pos, - wxyz=down_wxyz, - axes_length=0.05, - axes_radius=0.01, - ) - - slider = server.gui.add_slider( - "Timestep", min=0, max=timesteps - 1, step=1, initial_value=0 - ) - playing = server.gui.add_checkbox("Playing", initial_value=True) - - while True: - if playing.value: - slider.value = (slider.value + 1) % timesteps - - urdf_vis.update_cfg(traj[slider.value]) - time.sleep(1.0 / 10.0) - - - if __name__ == "__main__": - tyro.cli(main) diff --git a/docs/source/examples/08_ik_with_mimic_joints.rst b/docs/source/examples/08_ik_with_mimic_joints.rst deleted file mode 100644 index 7760754..0000000 --- a/docs/source/examples/08_ik_with_mimic_joints.rst +++ /dev/null @@ -1,145 +0,0 @@ -.. Comment: this file is automatically generated by `update_example_docs.py`. - It should not be modified manually. - -IK with Mimic Joints -========================================== - - -This is a simple test to ensure that mimic joints are handled correctly in the IK solver. - -We procedurally generate a "zig-zag" chain of links with mimic joints, where: - - -* the first joint is driven directly, -* and the remaining joints are driven indirectly via mimic joints. - The multipliers alternate between -1 and 1, and the offsets are all 0. - -All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details. - - - -.. 
code-block:: python - :linenos: - - - import tempfile - import time - - import numpy as np - import pyronot as pk - import viser - import yourdfpy - from viser.extras import ViserUrdf - - import pyronot_snippets as pks - - - def create_chain_xml(length: float = 0.2, num_chains: int = 5) -> str: - def create_link(idx): - return f""" - - - - - - - - - - - - """ - - def create_joint(idx, multiplier=1.0, offset=0.0): - mimic = f'' - return f""" - - - - - - {mimic if idx != 0 else ""} - - - """ - - world_joint_origin_z = length / 2.0 - xml = f""" - - - - - - - - - - - """ - # Create the definition + first link. - xml += create_link(0) - xml += create_link(1) - xml += create_joint(0) - - # Procedurally add more links. - assert num_chains >= 2 - for idx in range(2, num_chains): - xml += create_link(idx) - current_offset = 0.0 - current_multiplier = 1.0 * ((-1) ** (idx % 2)) - xml += create_joint(idx - 1, current_multiplier, current_offset) - - xml += """ - - """ - return xml - - - def main(): - """Main function for basic IK.""" - - xml = create_chain_xml(num_chains=10, length=0.1) - with tempfile.NamedTemporaryFile(mode="w", suffix=".urdf") as f: - f.write(xml) - f.flush() - urdf = yourdfpy.URDF.load(f.name) - - # Create robot. - robot = pk.Robot.from_urdf(urdf) - - # Set up visualizer. - server = viser.ViserServer() - server.scene.add_grid("/ground", width=2, height=2) - urdf_vis = ViserUrdf(server, urdf, root_node_name="/base") - target_link_name_handle = server.gui.add_dropdown( - "Target Link", - robot.links.names, - initial_value=robot.links.names[-1], - ) - - # Create interactive controller with initial position. - ik_target = server.scene.add_transform_controls( - "/ik_target", scale=0.2, position=(0.0, 0.1, 0.1), wxyz=(0, 0, 1, 0) - ) - timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True) - - while True: - # Solve IK. 
- start_time = time.time() - solution = pks.solve_ik( - robot=robot, - target_link_name=target_link_name_handle.value, - target_position=np.array(ik_target.position), - target_wxyz=np.array(ik_target.wxyz), - ) - - # Update timing handle. - elapsed_time = time.time() - start_time - timing_handle.value = 0.99 * timing_handle.value + 0.01 * (elapsed_time * 1000) - - # Update visualizer. - urdf_vis.update_cfg(solution) - - - if __name__ == "__main__": - main() diff --git a/docs/source/examples/09_hand_retargeting.rst b/docs/source/examples/09_hand_retargeting.rst deleted file mode 100644 index a1a3c7f..0000000 --- a/docs/source/examples/09_hand_retargeting.rst +++ /dev/null @@ -1,376 +0,0 @@ -.. Comment: this file is automatically generated by `update_example_docs.py`. - It should not be modified manually. - -Hand Retargeting -========================================== - - -Simpler shadow hand retargeting example. -Find and unzip the shadowhand URDF at ``assets/hand_retargeting/shadowhand_urdf.zip``. - -All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details. - - - -.. 
code-block:: python - :linenos: - - - import pickle - import time - from pathlib import Path - from typing import Tuple, TypedDict - - import jax - import jax.numpy as jnp - import jax_dataclasses as jdc - import jaxlie - import jaxls - import numpy as onp - import pyronot as pk - import trimesh - import viser - import yourdfpy - from scipy.spatial.transform import Rotation as R - from viser.extras import ViserUrdf - - from retarget_helpers._utils import ( - MANO_TO_SHADOW_MAPPING, - create_conn_tree, - get_mapping_from_mano_to_shadow, - ) - - - class RetargetingWeights(TypedDict): - local_alignment: float - """Local alignment weight, by matching the relative joint/keypoint positions and angles.""" - global_alignment: float - """Global alignment weight, by matching the keypoint positions to the robot.""" - joint_smoothness: float - """Joint smoothness weight.""" - root_smoothness: float - """Root translation smoothness weight.""" - - - def main(): - """Main function for hand retargeting.""" - - asset_dir = Path(__file__).parent / "retarget_helpers" / "hand" - - robot_urdf_path = asset_dir / "shadowhand_urdf" / "shadow_hand_right.urdf" - - def filename_handler(fname: str) -> str: - base_path = robot_urdf_path.parent - return yourdfpy.filename_handler_magic(fname, dir=base_path) - - try: - urdf = yourdfpy.URDF.load(robot_urdf_path, filename_handler=filename_handler) - except FileNotFoundError: - raise FileNotFoundError( - "Please unzip the included URDF at `retarget_helpers/hand/shadowhand_urdf.zip`." - ) - - robot = pk.Robot.from_urdf(urdf) - - # Get the mapping from MANO to Shadow Hand joints. - shadow_link_idx, mano_joint_idx = get_mapping_from_mano_to_shadow(robot) - - # Create a mask for the MANO joints that are connected to the Shadow Hand. - mano_mask = create_conn_tree(robot, shadow_link_idx) - - # Load source motion data. 
- dexycb_motion_path = asset_dir / "dexycb_motion.pkl" - with open(dexycb_motion_path, "rb") as f: - dexycb_motion_data = pickle.load(f, encoding="latin1") - - # Load keypoints. - keypoints = dexycb_motion_data["world_hand_joints"] - assert not onp.isnan(keypoints).any() - num_timesteps = keypoints.shape[0] - num_mano_joints = len(MANO_TO_SHADOW_MAPPING) - - # Load mano hand contact information -- these are lists of lists, - # len(contact_points_per_frame) = num_timesteps, - # len(contact_points_per_frame[i]) = number of contacts in frame i, - contact_points_per_frame = dexycb_motion_data["contact_object_points"] - contact_indices_per_frame = dexycb_motion_data["contact_joint_indices"] - - # Now, we're going to pad this info + make a mask to indicate the padded regions. - # We will also track the shadowhand joint indices, NOT the MANO joint indices. - max_num_contacts = max(len(c) for c in contact_points_per_frame) - padded_contact_points_per_frame = onp.zeros((num_timesteps, max_num_contacts, 3)) - padded_contact_indices_per_frame = onp.zeros( - (num_timesteps, max_num_contacts), dtype=onp.int32 - ) - padded_contact_mask = onp.zeros((num_timesteps, max_num_contacts), dtype=onp.bool_) - for i in range(num_timesteps): - num_contacts = len(contact_points_per_frame[i]) - if num_contacts == 0: - continue - contact_shadowhand_indices = [ - robot.links.names.index(MANO_TO_SHADOW_MAPPING[j]) - for j in contact_indices_per_frame[i] - ] - padded_contact_points_per_frame[i, :num_contacts] = contact_points_per_frame[i] - padded_contact_indices_per_frame[i, :num_contacts] = contact_shadowhand_indices - padded_contact_mask[i, :num_contacts] = True - - # Load the object. 
- object_mesh_vertices = dexycb_motion_data["object_mesh_vertices"] - object_mesh_faces = dexycb_motion_data["object_mesh_faces"] - object_pose_list = dexycb_motion_data["object_poses"] # (N, 4, 4) - mesh = trimesh.Trimesh(object_mesh_vertices, object_mesh_faces) - - server = viser.ViserServer() - - # We will transform everything by the transform below, for aesthetics. - server.scene.add_frame( - "/scene_offset", - show_axes=False, - position=(-0.15415953, -0.73598871, 0.93434792), - wxyz=(-0.381870867, 0.92421569, 0.0, 2.0004992e-32), - ) - hand_mesh = server.scene.add_mesh_simple( - "/scene_offset/hand_mesh", - vertices=dexycb_motion_data["world_hand_vertices"][0, :, :], - faces=dexycb_motion_data["hand_mesh_faces"], - opacity=0.5, - ) - base_frame = server.scene.add_frame("/scene_offset/base", show_axes=False) - urdf_vis = ViserUrdf(server, urdf, root_node_name="/scene_offset/base") - playing = server.gui.add_checkbox("playing", True) - timestep_slider = server.gui.add_slider("timestep", 0, num_timesteps - 1, 1, 0) - object_handle = server.scene.add_mesh_trimesh("/scene_offset/object", mesh) - server.scene.add_grid("/grid", 2.0, 2.0) - - default_weights = RetargetingWeights( - local_alignment=10.0, - global_alignment=1.0, - joint_smoothness=2.0, - root_smoothness=2.0, - ) - - weights = pk.viewer.WeightTuner( - server, - default_weights, # type: ignore - ) - - Ts_world_root, joints = None, None - - def generate_trajectory(): - nonlocal Ts_world_root, joints - gen_button.disabled = True - Ts_world_root, joints = solve_retargeting( - robot=robot, - target_keypoints=keypoints, - shadow_hand_link_retarget_indices=shadow_link_idx, - mano_joint_retarget_indices=mano_joint_idx, - mano_mask=mano_mask, - weights=weights.get_weights(), # type: ignore - ) - gen_button.disabled = False - - gen_button = server.gui.add_button("Retarget!") - gen_button.on_click(lambda _: generate_trajectory()) - - generate_trajectory() - assert Ts_world_root is not None and joints is not None - 
- while True: - with server.atomic(): - if playing.value: - timestep_slider.value = (timestep_slider.value + 1) % num_timesteps - tstep = timestep_slider.value - base_frame.wxyz = onp.array(Ts_world_root.wxyz_xyz[tstep][:4]) - base_frame.position = onp.array(Ts_world_root.wxyz_xyz[tstep][4:]) - urdf_vis.update_cfg(onp.array(joints[tstep])) - - server.scene.add_point_cloud( - "/scene_offset/target_keypoints", - onp.array(keypoints[tstep]).reshape(-1, 3), - onp.array((0, 0, 255))[None] - .repeat(num_mano_joints, axis=0) - .reshape(-1, 3), - point_size=0.005, - point_shape="sparkle", - ) - server.scene.add_point_cloud( - "/scene_offset/contact_points", - onp.array(contact_points_per_frame[tstep]).reshape(-1, 3), - onp.array((255, 0, 0))[None] - .repeat(len(contact_points_per_frame[tstep]), axis=0) - .reshape(-1, 3), - point_size=0.005, - point_shape="circle", - ) - hand_mesh.vertices = dexycb_motion_data["world_hand_vertices"][tstep, :, :] - object_handle.position = object_pose_list[tstep][:3, 3] - object_handle.wxyz = R.from_matrix(object_pose_list[tstep][:3, :3]).as_quat( - scalar_first=True - ) - - time.sleep(0.05) - - - @jdc.jit - def solve_retargeting( - robot: pk.Robot, - target_keypoints: jnp.ndarray, - shadow_hand_link_retarget_indices: jnp.ndarray, - mano_joint_retarget_indices: jnp.ndarray, - mano_mask: jnp.ndarray, - weights: RetargetingWeights, - ) -> Tuple[jaxlie.SE3, jnp.ndarray]: - """Solve the retargeting problem.""" - - n_retarget = len(mano_joint_retarget_indices) - timesteps = target_keypoints.shape[0] - - # Variables. - class ManoJointsScaleVar( - jaxls.Var[jax.Array], default_factory=lambda: jnp.ones((n_retarget, n_retarget)) - ): ... - - class OffsetVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros((3,))): ... 
- - var_joints = robot.joint_var_cls(jnp.arange(timesteps)) - var_Ts_world_root = jaxls.SE3Var(jnp.arange(timesteps)) - var_smpl_joints_scale = ManoJointsScaleVar(jnp.zeros(timesteps)) - var_offset = OffsetVar(jnp.zeros(timesteps)) - - # Costs. - costs: list[jaxls.Cost] = [] - - @jaxls.Cost.create_factory - def retargeting_cost( - var_values: jaxls.VarValues, - var_Ts_world_root: jaxls.SE3Var, - var_robot_cfg: jaxls.Var[jnp.ndarray], - var_smpl_joints_scale: ManoJointsScaleVar, - keypoints: jnp.ndarray, - ) -> jax.Array: - """Retargeting factor, with a focus on: - - matching the relative joint/keypoint positions (vectors). - - and matching the relative angles between the vectors. - """ - robot_cfg = var_values[var_robot_cfg] - T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg)) - T_world_root = var_values[var_Ts_world_root] - T_world_link = T_world_root @ T_root_link - - mano_pos = keypoints[jnp.array(mano_joint_retarget_indices)] - robot_pos = T_world_link.translation()[ - jnp.array(shadow_hand_link_retarget_indices) - ] - - # NxN grid of relative positions. - delta_mano = mano_pos[:, None] - mano_pos[None, :] - delta_robot = robot_pos[:, None] - robot_pos[None, :] - - # Vector regularization. - position_scale = var_values[var_smpl_joints_scale][..., None] - residual_position_delta = ( - (delta_mano - delta_robot * position_scale) - * (1 - jnp.eye(delta_mano.shape[0])[..., None]) - * mano_mask[..., None] - ) - - # Vector angle regularization. 
- delta_mano_normalized = delta_mano / jnp.linalg.norm( - delta_mano + 1e-6, axis=-1, keepdims=True - ) - delta_robot_normalized = delta_robot / jnp.linalg.norm( - delta_robot + 1e-6, axis=-1, keepdims=True - ) - residual_angle_delta = 1 - (delta_mano_normalized * delta_robot_normalized).sum( - axis=-1 - ) - residual_angle_delta = ( - residual_angle_delta - * (1 - jnp.eye(residual_angle_delta.shape[0])) - * mano_mask - ) - - residual = ( - jnp.concatenate( - [ - residual_position_delta.flatten(), - residual_angle_delta.flatten(), - ], - axis=0, - ) - * weights["local_alignment"] - ) - return residual - - @jaxls.Cost.create_factory - def pc_alignment_cost( - var_values: jaxls.VarValues, - var_Ts_world_root: jaxls.SE3Var, - var_robot_cfg: jaxls.Var[jnp.ndarray], - keypoints: jnp.ndarray, - ) -> jax.Array: - """Soft cost to align the human keypoints to the robot, in the world frame.""" - T_world_root = var_values[var_Ts_world_root] - robot_cfg = var_values[var_robot_cfg] - T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg)) - T_world_link = T_world_root @ T_root_link - link_pos = T_world_link.translation()[shadow_hand_link_retarget_indices] - keypoint_pos = keypoints[mano_joint_retarget_indices] - return (link_pos - keypoint_pos).flatten() * weights["global_alignment"] - - @jaxls.Cost.create_factory - def root_smoothness( - var_values: jaxls.VarValues, - var_Ts_world_root: jaxls.SE3Var, - var_Ts_world_root_prev: jaxls.SE3Var, - ) -> jax.Array: - """Smoothness cost for the robot root translation.""" - return ( - var_values[var_Ts_world_root].translation() - - var_values[var_Ts_world_root_prev].translation() - ).flatten() * weights["root_smoothness"] - - costs = [ - retargeting_cost( - var_Ts_world_root, - var_joints, - var_smpl_joints_scale, - target_keypoints, - ), - pk.costs.limit_cost( - jax.tree.map(lambda x: x[None], robot), - var_joints, - 100.0, - ), - pk.costs.smoothness_cost( - robot.joint_var_cls(jnp.arange(1, timesteps)), - 
robot.joint_var_cls(jnp.arange(0, timesteps - 1)), - jnp.array([weights["joint_smoothness"]]), - ), - pc_alignment_cost( - var_Ts_world_root, - var_joints, - target_keypoints, - ), - root_smoothness( - jaxls.SE3Var(jnp.arange(1, timesteps)), - jaxls.SE3Var(jnp.arange(0, timesteps - 1)), - ), - ] - - solution = ( - jaxls.LeastSquaresProblem( - costs, [var_joints, var_Ts_world_root, var_smpl_joints_scale, var_offset] - ) - .analyze() - .solve() - ) - transform = solution[var_Ts_world_root] - offset = solution[var_offset] - transform = jaxlie.SE3.from_translation(offset) @ transform - return transform, solution[var_joints] - - - if __name__ == "__main__": - main() diff --git a/docs/source/examples/10_humanoid_retargeting.rst b/docs/source/examples/10_humanoid_retargeting.rst deleted file mode 100644 index 83b9f90..0000000 --- a/docs/source/examples/10_humanoid_retargeting.rst +++ /dev/null @@ -1,300 +0,0 @@ -.. Comment: this file is automatically generated by `update_example_docs.py`. - It should not be modified manually. - -Humanoid Retargeting -========================================== - - -Simpler motion retargeting to the G1 humanoid. - -All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details. - - - -.. 
code-block:: python - :linenos: - - - import time - from pathlib import Path - from typing import Tuple, TypedDict - - import jax - import jax.numpy as jnp - import jax_dataclasses as jdc - import jaxlie - import jaxls - import numpy as onp - import pyronot as pk - import viser - from robot_descriptions.loaders.yourdfpy import load_robot_description - from viser.extras import ViserUrdf - - from retarget_helpers._utils import ( - SMPL_JOINT_NAMES, - create_conn_tree, - get_humanoid_retarget_indices, - ) - - - class RetargetingWeights(TypedDict): - local_alignment: float - """Local alignment weight, by matching the relative joint/keypoint positions and angles.""" - global_alignment: float - """Global alignment weight, by matching the keypoint positions to the robot.""" - - - def main(): - """Main function for humanoid retargeting.""" - - urdf = load_robot_description("g1_description") - robot = pk.Robot.from_urdf(urdf) - - # Load source motion data: - # - keypoints [N, 45, 3], - # - left/right foot contact (boolean) 2 x [N], - # - heightmap [H, W]. - asset_dir = Path(__file__).parent / "retarget_helpers" / "humanoid" - smpl_keypoints = onp.load(asset_dir / "smpl_keypoints.npy") - is_left_foot_contact = onp.load(asset_dir / "left_foot_contact.npy") - is_right_foot_contact = onp.load(asset_dir / "right_foot_contact.npy") - heightmap = onp.load(asset_dir / "heightmap.npy") - - num_timesteps = smpl_keypoints.shape[0] - assert smpl_keypoints.shape == (num_timesteps, 45, 3) - assert is_left_foot_contact.shape == (num_timesteps,) - assert is_right_foot_contact.shape == (num_timesteps,) - - heightmap = pk.collision.Heightmap( - pose=jaxlie.SE3.identity(), - size=jnp.array([0.01, 0.01, 1.0]), - height_data=heightmap, - ) - - # Get the left and right foot keypoints, projected on the heightmap. 
- left_foot_keypoint_idx = SMPL_JOINT_NAMES.index("left_foot") - right_foot_keypoint_idx = SMPL_JOINT_NAMES.index("right_foot") - left_foot_keypoints = smpl_keypoints[..., left_foot_keypoint_idx, :].reshape(-1, 3) - right_foot_keypoints = smpl_keypoints[..., right_foot_keypoint_idx, :].reshape( - -1, 3 - ) - left_foot_keypoints = heightmap.project_points(left_foot_keypoints) - right_foot_keypoints = heightmap.project_points(right_foot_keypoints) - - smpl_joint_retarget_indices, g1_joint_retarget_indices = ( - get_humanoid_retarget_indices() - ) - smpl_mask = create_conn_tree(robot, g1_joint_retarget_indices) - - server = viser.ViserServer() - base_frame = server.scene.add_frame("/base", show_axes=False) - urdf_vis = ViserUrdf(server, urdf, root_node_name="/base") - playing = server.gui.add_checkbox("playing", True) - timestep_slider = server.gui.add_slider("timestep", 0, num_timesteps - 1, 1, 0) - server.scene.add_mesh_trimesh("/heightmap", heightmap.to_trimesh()) - - weights = pk.viewer.WeightTuner( - server, - RetargetingWeights( # type: ignore - local_alignment=2.0, - global_alignment=1.0, - ), - ) - - Ts_world_root, joints = None, None - - def generate_trajectory(): - nonlocal Ts_world_root, joints - gen_button.disabled = True - Ts_world_root, joints = solve_retargeting( - robot=robot, - target_keypoints=smpl_keypoints, - smpl_joint_retarget_indices=smpl_joint_retarget_indices, - g1_joint_retarget_indices=g1_joint_retarget_indices, - smpl_mask=smpl_mask, - weights=weights.get_weights(), # type: ignore - ) - gen_button.disabled = False - - gen_button = server.gui.add_button("Retarget!") - gen_button.on_click(lambda _: generate_trajectory()) - - generate_trajectory() - assert Ts_world_root is not None and joints is not None - - while True: - with server.atomic(): - if playing.value: - timestep_slider.value = (timestep_slider.value + 1) % num_timesteps - tstep = timestep_slider.value - base_frame.wxyz = onp.array(Ts_world_root.wxyz_xyz[tstep][:4]) - 
base_frame.position = onp.array(Ts_world_root.wxyz_xyz[tstep][4:]) - urdf_vis.update_cfg(onp.array(joints[tstep])) - server.scene.add_point_cloud( - "/target_keypoints", - onp.array(smpl_keypoints[tstep]), - onp.array((0, 0, 255))[None].repeat(45, axis=0), - point_size=0.01, - ) - - time.sleep(0.05) - - - @jdc.jit - def solve_retargeting( - robot: pk.Robot, - target_keypoints: jnp.ndarray, - smpl_joint_retarget_indices: jnp.ndarray, - g1_joint_retarget_indices: jnp.ndarray, - smpl_mask: jnp.ndarray, - weights: RetargetingWeights, - ) -> Tuple[jaxlie.SE3, jnp.ndarray]: - """Solve the retargeting problem.""" - - n_retarget = len(smpl_joint_retarget_indices) - timesteps = target_keypoints.shape[0] - - # Robot properties. - # - Joints that should move less for natural humanoid motion. - joints_to_move_less = jnp.array( - [ - robot.joints.actuated_names.index(name) - for name in ["left_hip_yaw_joint", "right_hip_yaw_joint", "torso_joint"] - ] - ) - - # Variables. - class SmplJointsScaleVarG1( - jaxls.Var[jax.Array], default_factory=lambda: jnp.ones((n_retarget, n_retarget)) - ): ... - - class OffsetVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros((3,))): ... - - var_joints = robot.joint_var_cls(jnp.arange(timesteps)) - var_Ts_world_root = jaxls.SE3Var(jnp.arange(timesteps)) - var_smpl_joints_scale = SmplJointsScaleVarG1(jnp.zeros(timesteps)) - var_offset = OffsetVar(jnp.zeros(timesteps)) - - # Costs. - costs: list[jaxls.Cost] = [] - - @jaxls.Cost.create_factory - def retargeting_cost( - var_values: jaxls.VarValues, - var_Ts_world_root: jaxls.SE3Var, - var_robot_cfg: jaxls.Var[jnp.ndarray], - var_smpl_joints_scale: SmplJointsScaleVarG1, - keypoints: jnp.ndarray, - ) -> jax.Array: - """Retargeting factor, with a focus on: - - matching the relative joint/keypoint positions (vectors). - - and matching the relative angles between the vectors. 
- """ - robot_cfg = var_values[var_robot_cfg] - T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg)) - T_world_root = var_values[var_Ts_world_root] - T_world_link = T_world_root @ T_root_link - - smpl_pos = keypoints[jnp.array(smpl_joint_retarget_indices)] - robot_pos = T_world_link.translation()[jnp.array(g1_joint_retarget_indices)] - - # NxN grid of relative positions. - delta_smpl = smpl_pos[:, None] - smpl_pos[None, :] - delta_robot = robot_pos[:, None] - robot_pos[None, :] - - # Vector regularization. - position_scale = var_values[var_smpl_joints_scale][..., None] - residual_position_delta = ( - (delta_smpl - delta_robot * position_scale) - * (1 - jnp.eye(delta_smpl.shape[0])[..., None]) - * smpl_mask[..., None] - ) - - # Vector angle regularization. - delta_smpl_normalized = delta_smpl / jnp.linalg.norm( - delta_smpl + 1e-6, axis=-1, keepdims=True - ) - delta_robot_normalized = delta_robot / jnp.linalg.norm( - delta_robot + 1e-6, axis=-1, keepdims=True - ) - residual_angle_delta = 1 - (delta_smpl_normalized * delta_robot_normalized).sum( - axis=-1 - ) - residual_angle_delta = ( - residual_angle_delta - * (1 - jnp.eye(residual_angle_delta.shape[0])) - * smpl_mask - ) - - residual = ( - jnp.concatenate( - [residual_position_delta.flatten(), residual_angle_delta.flatten()] - ) - * weights["local_alignment"] - ) - return residual - - @jaxls.Cost.create_factory - def pc_alignment_cost( - var_values: jaxls.VarValues, - var_Ts_world_root: jaxls.SE3Var, - var_robot_cfg: jaxls.Var[jnp.ndarray], - keypoints: jnp.ndarray, - ) -> jax.Array: - """Soft cost to align the human keypoints to the robot, in the world frame.""" - T_world_root = var_values[var_Ts_world_root] - robot_cfg = var_values[var_robot_cfg] - T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg)) - T_world_link = T_world_root @ T_root_link - link_pos = T_world_link.translation()[g1_joint_retarget_indices] - keypoint_pos = keypoints[smpl_joint_retarget_indices] - return (link_pos - 
keypoint_pos).flatten() * weights["global_alignment"] - - costs = [ - # Costs that are relatively self-contained to the robot. - retargeting_cost( - var_Ts_world_root, - var_joints, - var_smpl_joints_scale, - target_keypoints, - ), - pk.costs.limit_cost( - jax.tree.map(lambda x: x[None], robot), - var_joints, - 100.0, - ), - pk.costs.smoothness_cost( - robot.joint_var_cls(jnp.arange(1, timesteps)), - robot.joint_var_cls(jnp.arange(0, timesteps - 1)), - jnp.array([0.2]), - ), - pk.costs.rest_cost( - var_joints, - var_joints.default_factory()[None], - jnp.full(var_joints.default_factory().shape, 0.2) - .at[joints_to_move_less] - .set(2.0)[None], - ), - # Costs that are scene-centric. - pc_alignment_cost( - var_Ts_world_root, - var_joints, - target_keypoints, - ), - ] - - solution = ( - jaxls.LeastSquaresProblem( - costs, [var_joints, var_Ts_world_root, var_smpl_joints_scale, var_offset] - ) - .analyze() - .solve() - ) - transform = solution[var_Ts_world_root] - offset = solution[var_offset] - transform = jaxlie.SE3.from_translation(offset) @ transform - return transform, solution[var_joints] - - - if __name__ == "__main__": - main() diff --git a/docs/source/examples/11_hand_retargeting_fancy.rst b/docs/source/examples/11_hand_retargeting_fancy.rst deleted file mode 100644 index 9f5e77e..0000000 --- a/docs/source/examples/11_hand_retargeting_fancy.rst +++ /dev/null @@ -1,450 +0,0 @@ -.. Comment: this file is automatically generated by `update_example_docs.py`. - It should not be modified manually. - -Hand Retargeting (Fancy) -========================================== - - -Shadow Hand retargeting example, with costs to maintain contact with the object. -Find and unzip the shadowhand URDF at ``assets/hand_retargeting/shadowhand_urdf.zip``. - -All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details. - - - -.. 
code-block:: python - :linenos: - - - import time - from typing import Tuple, TypedDict - from pathlib import Path - import pickle - import trimesh - from scipy.spatial.transform import Rotation as R - - import jax - import jax.numpy as jnp - import jax_dataclasses as jdc - import jaxlie - import jaxls - import numpy as onp - import viser - from viser.extras import ViserUrdf - import yourdfpy - - import pyronot as pk - - from retarget_helpers._utils import ( - create_conn_tree, - get_mapping_from_mano_to_shadow, - MANO_TO_SHADOW_MAPPING, - ) - - - class RetargetingWeights(TypedDict): - local_alignment: float - """Local alignment weight, by matching the relative joint/keypoint positions and angles.""" - global_alignment: float - """Global alignment weight, by matching the keypoint positions to the robot.""" - contact: float - """Contact weight, to maintain contact between the robot and the object.""" - contact_margin: float - """Contact margin, to stop penalizing contact when the robot is already close to the object.""" - joint_smoothness: float - """Joint smoothness weight.""" - root_smoothness: float - """Root translation smoothness weight.""" - - - def main(): - """Main function for hand retargeting.""" - - asset_dir = Path(__file__).parent / "retarget_helpers" / "hand" - - robot_urdf_path = asset_dir / "shadowhand_urdf" / "shadow_hand_right.urdf" - - def filename_handler(fname: str) -> str: - base_path = robot_urdf_path.parent - return yourdfpy.filename_handler_magic(fname, dir=base_path) - - try: - urdf = yourdfpy.URDF.load(robot_urdf_path, filename_handler=filename_handler) - except FileNotFoundError: - raise FileNotFoundError( - "Please unzip the included URDF at `retarget_helpers/hand/shadowhand_urdf.zip`." - ) - - robot = pk.Robot.from_urdf(urdf) - - # Get the mapping from MANO to Shadow Hand joints. - shadow_link_idx, mano_joint_idx = get_mapping_from_mano_to_shadow(robot) - - # Create a mask for the MANO joints that are connected to the Shadow Hand. 
- mano_mask = create_conn_tree(robot, shadow_link_idx) - - # Load source motion data. - dexycb_motion_path = asset_dir / "dexycb_motion.pkl" - with open(dexycb_motion_path, "rb") as f: - dexycb_motion_data = pickle.load(f, encoding="latin1") - - # Load keypoints. - keypoints = dexycb_motion_data["world_hand_joints"] - assert not onp.isnan(keypoints).any() - num_timesteps = keypoints.shape[0] - num_mano_joints = len(MANO_TO_SHADOW_MAPPING) - - # Load mano hand contact information -- these are lists of lists, - # len(contact_points_per_frame) = num_timesteps, - # len(contact_points_per_frame[i]) = number of contacts in frame i, - contact_points_per_frame = dexycb_motion_data["contact_object_points"] - contact_indices_per_frame = dexycb_motion_data["contact_joint_indices"] - - # Now, we're going to pad this info + make a mask to indicate the padded regions. - # We will also track the shadowhand joint indices, NOT the MANO joint indices. - max_num_contacts = max(len(c) for c in contact_points_per_frame) - padded_contact_points_per_frame = onp.zeros((num_timesteps, max_num_contacts, 3)) - padded_contact_indices_per_frame = onp.zeros( - (num_timesteps, max_num_contacts), dtype=onp.int32 - ) - padded_contact_mask = onp.zeros((num_timesteps, max_num_contacts), dtype=onp.bool_) - for i in range(num_timesteps): - num_contacts = len(contact_points_per_frame[i]) - if num_contacts == 0: - continue - contact_shadowhand_indices = [ - robot.links.names.index(MANO_TO_SHADOW_MAPPING[j]) - for j in contact_indices_per_frame[i] - ] - padded_contact_points_per_frame[i, :num_contacts] = contact_points_per_frame[i] - padded_contact_indices_per_frame[i, :num_contacts] = contact_shadowhand_indices - padded_contact_mask[i, :num_contacts] = True - - # Load the object. 
- object_mesh_vertices = dexycb_motion_data["object_mesh_vertices"] - object_mesh_faces = dexycb_motion_data["object_mesh_faces"] - object_pose_list = dexycb_motion_data["object_poses"] # (N, 4, 4) - mesh = trimesh.Trimesh(object_mesh_vertices, object_mesh_faces) - - server = viser.ViserServer() - - # We will transform everything by the transform below, for aesthetics. - server.scene.add_frame( - "/scene_offset", - show_axes=False, - position=(-0.15415953, -0.73598871, 0.93434792), - wxyz=(-0.381870867, 0.92421569, 0.0, 2.0004992e-32), - ) - base_frame = server.scene.add_frame("/scene_offset/base", show_axes=False) - urdf_vis = ViserUrdf(server, urdf, root_node_name="/scene_offset/base") - playing = server.gui.add_checkbox("playing", True) - timestep_slider = server.gui.add_slider("timestep", 0, num_timesteps - 1, 1, 0) - object_handle = server.scene.add_mesh_trimesh("/scene_offset/object", mesh) - server.scene.add_grid("/grid", 2.0, 2.0) - - default_weights = RetargetingWeights( - local_alignment=10.0, - global_alignment=1.0, - contact=5.0, - contact_margin=0.01, - joint_smoothness=2.0, - root_smoothness=2.0, - ) - - weights = pk.viewer.WeightTuner( - server, - default_weights, # type: ignore - ) - - Ts_world_root, joints = None, None - - def generate_trajectory(): - nonlocal Ts_world_root, joints - gen_button.disabled = True - Ts_world_root, joints = solve_retargeting( - robot=robot, - target_keypoints=keypoints, - shadow_hand_link_retarget_indices=shadow_link_idx, - mano_joint_retarget_indices=mano_joint_idx, - mano_mask=mano_mask, - contact_points_per_frame=jnp.array(padded_contact_points_per_frame), - contact_indices_per_frame=jnp.array(padded_contact_indices_per_frame), - contact_mask=jnp.array(padded_contact_mask), - weights=weights.get_weights(), # type: ignore - ) - gen_button.disabled = False - - gen_button = server.gui.add_button("Retarget!") - gen_button.on_click(lambda _: generate_trajectory()) - - generate_trajectory() - assert Ts_world_root is not 
None and joints is not None - - while True: - with server.atomic(): - if playing.value: - timestep_slider.value = (timestep_slider.value + 1) % num_timesteps - tstep = timestep_slider.value - base_frame.wxyz = onp.array(Ts_world_root.wxyz_xyz[tstep][:4]) - base_frame.position = onp.array(Ts_world_root.wxyz_xyz[tstep][4:]) - urdf_vis.update_cfg(onp.array(joints[tstep])) - - server.scene.add_point_cloud( - "/scene_offset/target_keypoints", - onp.array(keypoints[tstep]).reshape(-1, 3), - onp.array((0, 0, 255))[None] - .repeat(num_mano_joints, axis=0) - .reshape(-1, 3), - point_size=0.005, - point_shape="sparkle", - ) - server.scene.add_point_cloud( - "/scene_offset/contact_points", - onp.array(contact_points_per_frame[tstep]).reshape(-1, 3), - onp.array((255, 0, 0))[None] - .repeat(len(contact_points_per_frame[tstep]), axis=0) - .reshape(-1, 3), - point_size=0.005, - point_shape="circle", - ) - object_handle.position = object_pose_list[tstep][:3, 3] - object_handle.wxyz = R.from_matrix(object_pose_list[tstep][:3, :3]).as_quat( - scalar_first=True - ) - - time.sleep(0.05) - - - @jdc.jit - def solve_retargeting( - robot: pk.Robot, - target_keypoints: jnp.ndarray, - shadow_hand_link_retarget_indices: jnp.ndarray, - mano_joint_retarget_indices: jnp.ndarray, - mano_mask: jnp.ndarray, - contact_points_per_frame: jnp.ndarray, - contact_indices_per_frame: jnp.ndarray, - contact_mask: jnp.ndarray, - weights: RetargetingWeights, - ) -> Tuple[jaxlie.SE3, jnp.ndarray]: - """Solve the retargeting problem.""" - - n_retarget = len(mano_joint_retarget_indices) - timesteps = target_keypoints.shape[0] - - # Variables. - class ManoJointsScaleVar( - jaxls.Var[jax.Array], default_factory=lambda: jnp.ones((n_retarget, n_retarget)) - ): ... - - class OffsetVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros((3,))): ... 
- - var_joints = robot.joint_var_cls(jnp.arange(timesteps)) - var_Ts_world_root = jaxls.SE3Var(jnp.arange(timesteps)) - var_smpl_joints_scale = ManoJointsScaleVar(jnp.zeros(timesteps)) - var_offset = OffsetVar(jnp.zeros(timesteps)) - - # Costs. - costs: list[jaxls.Cost] = [] - - @jaxls.Cost.create_factory - def retargeting_cost( - var_values: jaxls.VarValues, - var_Ts_world_root: jaxls.SE3Var, - var_robot_cfg: jaxls.Var[jnp.ndarray], - var_smpl_joints_scale: ManoJointsScaleVar, - keypoints: jnp.ndarray, - ) -> jax.Array: - """Retargeting factor, with a focus on: - - matching the relative joint/keypoint positions (vectors). - - and matching the relative angles between the vectors. - """ - robot_cfg = var_values[var_robot_cfg] - T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg)) - T_world_root = var_values[var_Ts_world_root] - T_world_link = T_world_root @ T_root_link - - mano_pos = keypoints[jnp.array(mano_joint_retarget_indices)] - robot_pos = T_world_link.translation()[ - jnp.array(shadow_hand_link_retarget_indices) - ] - - # NxN grid of relative positions. - delta_mano = mano_pos[:, None] - mano_pos[None, :] - delta_robot = robot_pos[:, None] - robot_pos[None, :] - - # Vector regularization. - position_scale = var_values[var_smpl_joints_scale][..., None] - residual_position_delta = ( - (delta_mano - delta_robot * position_scale) - * (1 - jnp.eye(delta_mano.shape[0])[..., None]) - * mano_mask[..., None] - ) - - # Vector angle regularization. 
- delta_mano_normalized = delta_mano / jnp.linalg.norm( - delta_mano + 1e-6, axis=-1, keepdims=True - ) - delta_robot_normalized = delta_robot / jnp.linalg.norm( - delta_robot + 1e-6, axis=-1, keepdims=True - ) - residual_angle_delta = 1 - (delta_mano_normalized * delta_robot_normalized).sum( - axis=-1 - ) - residual_angle_delta = ( - residual_angle_delta - * (1 - jnp.eye(residual_angle_delta.shape[0])) - * mano_mask - ) - - residual = ( - jnp.concatenate( - [ - residual_position_delta.flatten(), - residual_angle_delta.flatten(), - ], - axis=0, - ) - * weights["local_alignment"] - ) - return residual - - @jaxls.Cost.create_factory - def scale_regularization( - var_values: jaxls.VarValues, - var_smpl_joints_scale: ManoJointsScaleVar, - ) -> jax.Array: - """Regularize the scale of the retargeted joints.""" - # Close to 1. - res_0 = (var_values[var_smpl_joints_scale] - 1.0).flatten() * 1.0 - # Symmetric. - res_1 = ( - var_values[var_smpl_joints_scale] - var_values[var_smpl_joints_scale].T - ).flatten() * 100.0 - # Non-negative. 
- res_2 = jnp.clip(-var_values[var_smpl_joints_scale], min=0).flatten() * 100.0 - return jnp.concatenate([res_0, res_1, res_2]) - - @jaxls.Cost.create_factory - def pc_alignment_cost( - var_values: jaxls.VarValues, - var_Ts_world_root: jaxls.SE3Var, - var_robot_cfg: jaxls.Var[jnp.ndarray], - keypoints: jnp.ndarray, - ) -> jax.Array: - """Soft cost to align the human keypoints to the robot, in the world frame.""" - T_world_root = var_values[var_Ts_world_root] - robot_cfg = var_values[var_robot_cfg] - T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg)) - T_world_link = T_world_root @ T_root_link - link_pos = T_world_link.translation()[shadow_hand_link_retarget_indices] - keypoint_pos = keypoints[mano_joint_retarget_indices] - return (link_pos - keypoint_pos).flatten() * weights["global_alignment"] - - @jaxls.Cost.create_factory - def root_smoothness( - var_values: jaxls.VarValues, - var_Ts_world_root: jaxls.SE3Var, - var_Ts_world_root_prev: jaxls.SE3Var, - ) -> jax.Array: - """Smoothness cost for the robot root translation.""" - return ( - var_values[var_Ts_world_root].translation() - - var_values[var_Ts_world_root_prev].translation() - ).flatten() * weights["root_smoothness"] - - @jaxls.Cost.create_factory - def contact_cost( - var_values: jaxls.VarValues, - var_T_world_root: jaxls.SE3Var, - var_robot_cfg: jaxls.Var[jnp.ndarray], - contact_points: jax.Array, # (J, P, 3) - contact_indices: jax.Array, # (J,) - Actual robot joint indices. - contact_points_mask: jax.Array, # (J, P) - ) -> jax.Array: - """Cost for maintaining contact between specified robot joints and object points.""" - robot_cfg = var_values[var_robot_cfg] - T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg)) - T_world_root = var_values[var_T_world_root] - T_world_link = T_world_root @ T_root_link - - contact_joint_positions_world = T_world_link.translation()[contact_indices] - - # Contact points are already in world frame (as processed in dexycb). 
- # Calculate distances from each joint to its set of contact points - # Shape contact_points: (J, P, 3), contact_joint_positions_world: (J, 3) - # We want distance between joint J and points P for that joint. - # residual: (J, P, 3) - residual = contact_points - contact_joint_positions_world - - # Penalize distance beyond a margin. - residual_penalty = jnp.maximum( - jnp.abs(residual) - weights["contact_margin"], 0.0 - ) # (J, P, 3) - - # Apply mask. - residual_penalty = ( - residual_penalty * contact_points_mask[..., None] - ) # (J, P, 3) - residual = residual_penalty.flatten() * weights["contact"] - - return residual - - costs = [ - # Costs that are relatively self-contained to the robot. - retargeting_cost( - var_Ts_world_root, - var_joints, - var_smpl_joints_scale, - target_keypoints, - ), - scale_regularization(var_smpl_joints_scale), - pk.costs.limit_cost( - jax.tree.map(lambda x: x[None], robot), - var_joints, - 100.0, - ), - pk.costs.smoothness_cost( - robot.joint_var_cls(jnp.arange(1, timesteps)), - robot.joint_var_cls(jnp.arange(0, timesteps - 1)), - jnp.array([weights["joint_smoothness"]]), - ), - pk.costs.rest_cost( - var_joints, - var_joints.default_factory()[None], - jnp.array([0.2]), - ), - # Costs that are scene-centric. 
- pc_alignment_cost( - var_Ts_world_root, - var_joints, - target_keypoints, - ), - root_smoothness( - jaxls.SE3Var(jnp.arange(1, timesteps)), - jaxls.SE3Var(jnp.arange(0, timesteps - 1)), - ), - contact_cost( - var_T_world_root=var_Ts_world_root, - var_robot_cfg=var_joints, - contact_points=contact_points_per_frame, - contact_indices=contact_indices_per_frame, - contact_points_mask=contact_mask, - ), - ] - - solution = ( - jaxls.LeastSquaresProblem( - costs, [var_joints, var_Ts_world_root, var_smpl_joints_scale, var_offset] - ) - .analyze() - .solve() - ) - transform = solution[var_Ts_world_root] - offset = solution[var_offset] - transform = jaxlie.SE3.from_translation(offset) @ transform - return transform, solution[var_joints] - - - if __name__ == "__main__": - main() diff --git a/docs/source/examples/12_humanoid_retargeting_fancy.rst b/docs/source/examples/12_humanoid_retargeting_fancy.rst deleted file mode 100644 index f1929c4..0000000 --- a/docs/source/examples/12_humanoid_retargeting_fancy.rst +++ /dev/null @@ -1,519 +0,0 @@ -.. Comment: this file is automatically generated by `update_example_docs.py`. - It should not be modified manually. - -Humanoid Retargeting (Fancy) -========================================== - - -Retarget motion to G1 humanoid, with scene contacts (keep feet close to contact -points, while avoiding world-collisions). - -All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details. - - - -.. 
code-block:: python - :linenos: - - - import time - from typing import Tuple, TypedDict - from pathlib import Path - - import jax - import jax.numpy as jnp - import jax_dataclasses as jdc - import jaxlie - import jaxls - import numpy as onp - import pyronot as pk - import viser - from viser.extras import ViserUrdf - from pyronot.collision import colldist_from_sdf, collide - from robot_descriptions.loaders.yourdfpy import load_robot_description - - from retarget_helpers._utils import ( - SMPL_JOINT_NAMES, - create_conn_tree, - get_humanoid_retarget_indices, - ) - - - class RetargetingWeights(TypedDict): - local_alignment: float - """Local alignment weight, by matching the relative joint/keypoint positions and angles.""" - global_alignment: float - """Global alignment weight, by matching the keypoint positions to the robot.""" - floor_contact: float - """Floor contact weight, to place the robot's foot on the floor.""" - root_smoothness: float - """Root smoothness weight, to penalize the robot's root from jittering too much.""" - foot_skating: float - """Foot skating weight, to penalize the robot's foot from moving when it is in contact with the floor.""" - world_collision: float - """World collision weight, to penalize the robot from colliding with the world.""" - - - def main(): - """Main function for humanoid retargeting.""" - - urdf = load_robot_description("g1_description") - robot = pk.Robot.from_urdf(urdf) - robot_coll = pk.collision.RobotCollision.from_urdf(urdf) - - # Load source motion data: - # - keypoints [N, 45, 3], - # - left/right foot contact (boolean) 2 x [N], - # - heightmap [H, W]. 
- asset_dir = Path(__file__).parent / "retarget_helpers" / "humanoid" - smpl_keypoints = onp.load(asset_dir / "smpl_keypoints.npy") - is_left_foot_contact = onp.load(asset_dir / "left_foot_contact.npy") - is_right_foot_contact = onp.load(asset_dir / "right_foot_contact.npy") - heightmap = onp.load(asset_dir / "heightmap.npy") - - num_timesteps = smpl_keypoints.shape[0] - assert smpl_keypoints.shape == (num_timesteps, 45, 3) - assert is_left_foot_contact.shape == (num_timesteps,) - assert is_right_foot_contact.shape == (num_timesteps,) - - heightmap = pk.collision.Heightmap( - pose=jaxlie.SE3.identity(), - size=jnp.array([0.01, 0.01, 1.0]), - height_data=heightmap, - ) - - # Get the left and right foot keypoints, projected on the heightmap. - left_foot_keypoint_idx = SMPL_JOINT_NAMES.index("left_foot") - right_foot_keypoint_idx = SMPL_JOINT_NAMES.index("right_foot") - left_foot_keypoints = smpl_keypoints[..., left_foot_keypoint_idx, :].reshape(-1, 3) - right_foot_keypoints = smpl_keypoints[..., right_foot_keypoint_idx, :].reshape( - -1, 3 - ) - left_foot_keypoints = heightmap.project_points(left_foot_keypoints) - right_foot_keypoints = heightmap.project_points(right_foot_keypoints) - - smpl_joint_retarget_indices, g1_joint_retarget_indices = ( - get_humanoid_retarget_indices() - ) - smpl_mask = create_conn_tree(robot, g1_joint_retarget_indices) - - server = viser.ViserServer() - base_frame = server.scene.add_frame("/base", show_axes=False) - urdf_vis = ViserUrdf(server, urdf, root_node_name="/base") - playing = server.gui.add_checkbox("playing", True) - timestep_slider = server.gui.add_slider("timestep", 0, num_timesteps - 1, 1, 0) - server.scene.add_mesh_trimesh("/heightmap", heightmap.to_trimesh()) - - weights = pk.viewer.WeightTuner( - server, - RetargetingWeights( - local_alignment=2.0, - global_alignment=1.0, - floor_contact=1.0, - root_smoothness=1.0, - foot_skating=1.0, - world_collision=1.0, - ), # type: ignore - ) - - Ts_world_root, joints = None, None - - 
def generate_trajectory(): - nonlocal Ts_world_root, joints - gen_button.disabled = True - Ts_world_root, joints = solve_retargeting( - robot=robot, - robot_coll=robot_coll, - target_keypoints=smpl_keypoints, - is_left_foot_contact=is_left_foot_contact, - is_right_foot_contact=is_right_foot_contact, - left_foot_keypoints=left_foot_keypoints, - right_foot_keypoints=right_foot_keypoints, - smpl_joint_retarget_indices=smpl_joint_retarget_indices, - g1_joint_retarget_indices=g1_joint_retarget_indices, - smpl_mask=smpl_mask, - heightmap=heightmap, - weights=weights.get_weights(), # type: ignore - ) - gen_button.disabled = False - - gen_button = server.gui.add_button("Retarget!") - gen_button.on_click(lambda _: generate_trajectory()) - - generate_trajectory() - assert Ts_world_root is not None and joints is not None - - while True: - with server.atomic(): - if playing.value: - timestep_slider.value = (timestep_slider.value + 1) % num_timesteps - tstep = timestep_slider.value - base_frame.wxyz = onp.array(Ts_world_root.wxyz_xyz[tstep][:4]) - base_frame.position = onp.array(Ts_world_root.wxyz_xyz[tstep][4:]) - urdf_vis.update_cfg(onp.array(joints[tstep])) - server.scene.add_point_cloud( - "/target_keypoints", - onp.array(smpl_keypoints[tstep]), - onp.array((0, 0, 255))[None].repeat(45, axis=0), - point_size=0.01, - ) - - time.sleep(0.05) - - - @jdc.jit - def solve_retargeting( - robot: pk.Robot, - robot_coll: pk.collision.RobotCollision, - target_keypoints: jnp.ndarray, - is_left_foot_contact: jnp.ndarray, - is_right_foot_contact: jnp.ndarray, - left_foot_keypoints: jnp.ndarray, - right_foot_keypoints: jnp.ndarray, - smpl_joint_retarget_indices: jnp.ndarray, - g1_joint_retarget_indices: jnp.ndarray, - smpl_mask: jnp.ndarray, - heightmap: pk.collision.Heightmap, - weights: RetargetingWeights, - ) -> Tuple[jaxlie.SE3, jnp.ndarray]: - """Solve the retargeting problem.""" - - n_retarget = len(smpl_joint_retarget_indices) - timesteps = target_keypoints.shape[0] - - # Robot 
properties. - # - Joints that should move less for natural humanoid motion. - joints_to_move_less = jnp.array( - [ - robot.joints.actuated_names.index(name) - for name in ["left_hip_yaw_joint", "right_hip_yaw_joint", "torso_joint"] - ] - ) - # - Foot indices. - left_foot_idx = robot.links.names.index("left_ankle_roll_link") - right_foot_idx = robot.links.names.index("right_ankle_roll_link") - - # Variables. - class SmplJointsScaleVarG1( - jaxls.Var[jax.Array], default_factory=lambda: jnp.ones((n_retarget, n_retarget)) - ): ... - - class OffsetVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros((3,))): ... - - var_joints = robot.joint_var_cls(jnp.arange(timesteps)) - var_Ts_world_root = jaxls.SE3Var(jnp.arange(timesteps)) - var_smpl_joints_scale = SmplJointsScaleVarG1(jnp.zeros(timesteps)) - var_offset = OffsetVar(jnp.zeros(timesteps)) - - # Costs. - costs: list[jaxls.Cost] = [] - - @jaxls.Cost.create_factory - def retargeting_cost( - var_values: jaxls.VarValues, - var_Ts_world_root: jaxls.SE3Var, - var_robot_cfg: jaxls.Var[jnp.ndarray], - var_smpl_joints_scale: SmplJointsScaleVarG1, - keypoints: jnp.ndarray, - ) -> jax.Array: - """Retargeting factor, with a focus on: - - matching the relative joint/keypoint positions (vectors). - - and matching the relative angles between the vectors. - """ - robot_cfg = var_values[var_robot_cfg] - T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg)) - T_world_root = var_values[var_Ts_world_root] - T_world_link = T_world_root @ T_root_link - - smpl_pos = keypoints[jnp.array(smpl_joint_retarget_indices)] - robot_pos = T_world_link.translation()[jnp.array(g1_joint_retarget_indices)] - - # NxN grid of relative positions. - delta_smpl = smpl_pos[:, None] - smpl_pos[None, :] - delta_robot = robot_pos[:, None] - robot_pos[None, :] - - # Vector regularization. 
- position_scale = var_values[var_smpl_joints_scale][..., None] - residual_position_delta = ( - (delta_smpl - delta_robot * position_scale) - * (1 - jnp.eye(delta_smpl.shape[0])[..., None]) - * smpl_mask[..., None] - ) - - # Vector angle regularization. - delta_smpl_normalized = delta_smpl / jnp.linalg.norm( - delta_smpl + 1e-6, axis=-1, keepdims=True - ) - delta_robot_normalized = delta_robot / jnp.linalg.norm( - delta_robot + 1e-6, axis=-1, keepdims=True - ) - residual_angle_delta = 1 - (delta_smpl_normalized * delta_robot_normalized).sum( - axis=-1 - ) - residual_angle_delta = ( - residual_angle_delta - * (1 - jnp.eye(residual_angle_delta.shape[0])) - * smpl_mask - ) - - residual = ( - jnp.concatenate( - [residual_position_delta.flatten(), residual_angle_delta.flatten()] - ) - * weights["local_alignment"] - ) - return residual - - @jaxls.Cost.create_factory - def scale_regularization( - var_values: jaxls.VarValues, - var_smpl_joints_scale: SmplJointsScaleVarG1, - ) -> jax.Array: - """Regularize the scale of the retargeted joints.""" - # Close to 1. - res_0 = (var_values[var_smpl_joints_scale] - 1.0).flatten() * 1.0 - # Symmetric. - res_1 = ( - var_values[var_smpl_joints_scale] - var_values[var_smpl_joints_scale].T - ).flatten() * 100.0 - # Non-negative. 
- res_2 = jnp.clip(-var_values[var_smpl_joints_scale], min=0).flatten() * 100.0 - return jnp.concatenate([res_0, res_1, res_2]) - - @jaxls.Cost.create_factory - def pc_alignment_cost( - var_values: jaxls.VarValues, - var_Ts_world_root: jaxls.SE3Var, - var_robot_cfg: jaxls.Var[jnp.ndarray], - keypoints: jnp.ndarray, - ) -> jax.Array: - """Soft cost to align the human keypoints to the robot, in the world frame.""" - T_world_root = var_values[var_Ts_world_root] - robot_cfg = var_values[var_robot_cfg] - T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg)) - T_world_link = T_world_root @ T_root_link - link_pos = T_world_link.translation()[g1_joint_retarget_indices] - keypoint_pos = keypoints[smpl_joint_retarget_indices] - return (link_pos - keypoint_pos).flatten() * weights["global_alignment"] - - @jaxls.Cost.create_factory - def floor_contact_cost( - var_values: jaxls.VarValues, - var_Ts_world_root: jaxls.SE3Var, - var_robot_cfg: jaxls.Var[jnp.ndarray], - var_offset: OffsetVar, - is_left_foot_contact: jnp.ndarray, - is_right_foot_contact: jnp.ndarray, - left_foot_keypoints: jnp.ndarray, - right_foot_keypoints: jnp.ndarray, - ) -> jax.Array: - """Cost to place the robot on the floor: - - match foot keypoint positions, and - - penalize the foot from tilting too much. - """ - T_world_root = var_values[var_Ts_world_root] - T_root_link = jaxlie.SE3( - robot.forward_kinematics(cfg=var_values[var_robot_cfg]) - ) - - offset = var_values[var_offset] - left_foot_pos = (T_world_root @ T_root_link).translation()[ - left_foot_idx - ] + offset - right_foot_pos = (T_world_root @ T_root_link).translation()[ - right_foot_idx - ] + offset - left_foot_contact_cost = ( - is_left_foot_contact * (left_foot_pos - left_foot_keypoints) ** 2 - ) - right_foot_contact_cost = ( - is_right_foot_contact * (right_foot_pos - right_foot_keypoints) ** 2 - ) - - # Also penalize the foot from tilting too much -- keep z axis up! 
- left_foot_ori = ( - (T_world_root @ T_root_link).rotation().as_matrix()[left_foot_idx] - ) - right_foot_ori = ( - (T_world_root @ T_root_link).rotation().as_matrix()[right_foot_idx] - ) - left_foot_contact_residual_rot = jnp.where( - is_left_foot_contact, - left_foot_ori[2, 2] - 1, - 0.0, - ) - right_foot_contact_residual_rot = jnp.where( - is_right_foot_contact, - right_foot_ori[2, 2] - 1, - 0.0, - ) - - return ( - jnp.concatenate( - [ - left_foot_contact_cost.flatten(), - right_foot_contact_cost.flatten(), - left_foot_contact_residual_rot.flatten(), - right_foot_contact_residual_rot.flatten(), - ] - ) - * weights["floor_contact"] - ) - - @jaxls.Cost.create_factory - def root_smoothness( - var_values: jaxls.VarValues, - var_Ts_world_root: jaxls.SE3Var, - var_Ts_world_root_prev: jaxls.SE3Var, - ) -> jax.Array: - """Smoothness cost for the robot root pose.""" - return ( - var_values[var_Ts_world_root].inverse() @ var_values[var_Ts_world_root_prev] - ).log().flatten() * weights["root_smoothness"] - - @jaxls.Cost.create_factory - def skating_cost( - var_values: jaxls.VarValues, - var_Ts_world_root: jaxls.SE3Var, - var_robot_cfg: jaxls.Var[jnp.ndarray], - var_offset: OffsetVar, - var_Ts_world_root_prev: jaxls.SE3Var, - var_robot_cfg_prev: jaxls.Var[jnp.ndarray], - var_offset_prev: OffsetVar, - is_left_foot_contact: jnp.ndarray, - is_right_foot_contact: jnp.ndarray, - ) -> jax.Array: - """Cost to penalize the robot for skating.""" - T_world_root = var_values[var_Ts_world_root] - robot_cfg = var_values[var_robot_cfg] - T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg)) - offset = var_values[var_offset] - T_link = T_world_root @ T_root_link - left_foot_pos = T_link.translation()[left_foot_idx] + offset - right_foot_pos = T_link.translation()[right_foot_idx] + offset - - T_world_root_prev = var_values[var_Ts_world_root_prev] - robot_cfg_prev = var_values[var_robot_cfg_prev] - T_root_link_prev = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg_prev)) - 
offset_prev = var_values[var_offset_prev] - T_link_prev = T_world_root_prev @ T_root_link_prev - left_foot_pos_prev = T_link_prev.translation()[left_foot_idx] + offset_prev - right_foot_pos_prev = T_link_prev.translation()[right_foot_idx] + offset_prev - - skating_cost_left = is_left_foot_contact * (left_foot_pos - left_foot_pos_prev) - skating_cost_right = is_right_foot_contact * ( - right_foot_pos - right_foot_pos_prev - ) - - return ( - jnp.stack([skating_cost_left, skating_cost_right]) * weights["foot_skating"] - ) - - @jaxls.Cost.create_factory - def world_collision_cost( - var_values: jaxls.VarValues, - var_Ts_world_root: jaxls.SE3Var, - var_robot_cfg: jaxls.Var[jnp.ndarray], - var_offset: OffsetVar, - ) -> jax.Array: - """ - World collision; we intentionally use a low weight -- - high enough to lift the robot up from the ground, but - low enough to not interfere with the retargeting. - """ - Ts_world_root = var_values[var_Ts_world_root] - T_offset = jaxlie.SE3.from_translation(var_values[var_offset]) - transform = T_offset @ Ts_world_root - - robot_cfg = var_values[var_robot_cfg] - coll = robot_coll.at_config(robot, robot_cfg) - coll = coll.transform(transform) - - dist = collide(coll, heightmap) - act = colldist_from_sdf(dist, activation_dist=0.005) - return act.flatten() * weights["world_collision"] - - costs = [ - # Costs that are relatively self-contained to the robot. 
- retargeting_cost( - var_Ts_world_root, - var_joints, - var_smpl_joints_scale, - target_keypoints, - ), - scale_regularization(var_smpl_joints_scale), - pk.costs.limit_cost( - jax.tree.map(lambda x: x[None], robot), - var_joints, - 100.0, - ), - pk.costs.smoothness_cost( - robot.joint_var_cls(jnp.arange(1, timesteps)), - robot.joint_var_cls(jnp.arange(0, timesteps - 1)), - jnp.array([0.2]), - ), - pk.costs.rest_cost( - var_joints, - var_joints.default_factory()[None], - jnp.full(var_joints.default_factory().shape, 0.2) - .at[joints_to_move_less] - .set(2.0)[None], - ), - pk.costs.self_collision_cost( - jax.tree.map(lambda x: x[None], robot), - jax.tree.map(lambda x: x[None], robot_coll), - var_joints, - margin=0.05, - weight=2.0, - ), - # Costs that are scene-centric. - pc_alignment_cost( - var_Ts_world_root, - var_joints, - target_keypoints, - ), - floor_contact_cost( - var_Ts_world_root, - var_joints, - var_offset, - is_left_foot_contact, - is_right_foot_contact, - left_foot_keypoints, - right_foot_keypoints, - ), - root_smoothness( - jaxls.SE3Var(jnp.arange(1, timesteps)), - jaxls.SE3Var(jnp.arange(0, timesteps - 1)), - ), - skating_cost( - jaxls.SE3Var(jnp.arange(1, timesteps)), - robot.joint_var_cls(jnp.arange(1, timesteps)), - OffsetVar(jnp.arange(1, timesteps)), - jaxls.SE3Var(jnp.arange(0, timesteps - 1)), - robot.joint_var_cls(jnp.arange(0, timesteps - 1)), - OffsetVar(jnp.arange(0, timesteps - 1)), - is_left_foot_contact[:-1], - is_right_foot_contact[:-1], - ), - world_collision_cost( - var_Ts_world_root, - var_joints, - var_offset, - ), - ] - - solution = ( - jaxls.LeastSquaresProblem( - costs, [var_joints, var_Ts_world_root, var_smpl_joints_scale, var_offset] - ) - .analyze() - .solve() - ) - transform = solution[var_Ts_world_root] - offset = solution[var_offset] - transform = jaxlie.SE3.from_translation(offset) @ transform - return transform, solution[var_joints] - - - if __name__ == "__main__": - main() diff --git 
a/docs/source/examples/13_spherized_robot_ik.rst b/docs/source/examples/13_spherized_robot_ik.rst deleted file mode 100644 index e592167..0000000 --- a/docs/source/examples/13_spherized_robot_ik.rst +++ /dev/null @@ -1,80 +0,0 @@ -.. Comment: this file is automatically generated by `update_example_docs.py`. - It should not be modified manually. - -Spherized Robot IK -========================================== - - -All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details. - - - -.. code-block:: python - :linenos: - - - import time - - import numpy as np - import pyronot as pk - import viser - from viser.extras import ViserUrdf - - import pyronot_snippets as pks - import yourdfpy - - def main(): - - # Load the spherized panda urdf do not work!!! - urdf_path = "resources/ur5/ur5_spherized.urdf" - mesh_dir = "resources/ur5/meshes" - target_link_name = "tool0" - - # urdf_path = "resources/panda/panda_spherized.urdf" - # mesh_dir = "resources/panda/meshes" - # target_link_name = "panda_hand" - - urdf = yourdfpy.URDF.load(urdf_path, mesh_dir=mesh_dir) - - # urdf = load_robot_description("ur5_description") - # target_link_name = "ee_link" - - # Create robot. - robot = pk.Robot.from_urdf(urdf) - robot_coll = pk.collision.RobotCollisionSpherized.from_urdf(urdf) - # robot_coll = pk.collision.RobotCollision.from_urdf(urdf) - - # Set up visualizer. - server = viser.ViserServer() - server.scene.add_grid("/ground", width=2, height=2) - urdf_vis = ViserUrdf(server, urdf, root_node_name="/base") - - # Add the collision mesh to the visualizer. - - # Create interactive controller with initial position. - ik_target = server.scene.add_transform_controls( - "/ik_target", scale=0.2, position=(0.61, 0.0, 0.56), wxyz=(0, 0, 1, 0) - ) - timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True) - - while True: - # Solve IK. 
- start_time = time.time() - solution = pks.solve_ik( - robot=robot, - target_link_name=target_link_name, - target_position=np.array(ik_target.position), - target_wxyz=np.array(ik_target.wxyz), - ) - - # Update timing handle. - elapsed_time = time.time() - start_time - timing_handle.value = 0.99 * timing_handle.value + 0.01 * (elapsed_time * 1000) - - # Update visualizer. - urdf_vis.update_cfg(solution) - # Update the collision mesh. - robot_coll_mesh = robot_coll.at_config(robot, solution).to_trimesh() - server.scene.add_mesh_trimesh("/robot/collision", mesh=robot_coll_mesh) - if __name__ == "__main__": - main() diff --git a/docs/source/examples/14_spherized_ik_with_coll.rst b/docs/source/examples/14_spherized_ik_with_coll.rst deleted file mode 100644 index 396d446..0000000 --- a/docs/source/examples/14_spherized_ik_with_coll.rst +++ /dev/null @@ -1,95 +0,0 @@ -.. Comment: this file is automatically generated by `update_example_docs.py`. - It should not be modified manually. - -IK with Collision -========================================== - - -Basic Inverse Kinematics with Collision Avoidance using PyRoNot. - -All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details. - - - -.. 
code-block:: python - :linenos: - - - import time - - import numpy as np - import pyronot as pk - import viser - from pyronot.collision import HalfSpace, RobotCollision, RobotCollisionSpherized, Sphere - from robot_descriptions.loaders.yourdfpy import load_robot_description - from viser.extras import ViserUrdf - - import pyronot_snippets as pks - import yourdfpy - - - def main(): - """Main function for basic IK with collision.""" - urdf_path = "resources/ur5/ur5_spherized.urdf" - mesh_dir = "resources/ur5/meshes" - target_link_name = "tool0" - - # urdf_path = "resources/panda/panda_spherized.urdf" - # mesh_dir = "resources/panda/meshes" - # target_link_name = "panda_hand" - urdf = yourdfpy.URDF.load(urdf_path, mesh_dir=mesh_dir) - robot = pk.Robot.from_urdf(urdf) - - robot_coll = RobotCollisionSpherized.from_urdf(urdf) - plane_coll = HalfSpace.from_point_and_normal( - np.array([0.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0]) - ) - sphere_coll = Sphere.from_center_and_radius( - np.array([0.0, 0.0, 0.0]), np.array([0.05]) - ) - - # Set up visualizer. - server = viser.ViserServer() - server.scene.add_grid("/ground", width=2, height=2, cell_size=0.1) - urdf_vis = ViserUrdf(server, urdf, root_node_name="/robot") - - # Create interactive controller for IK target. - ik_target_handle = server.scene.add_transform_controls( - "/ik_target", scale=0.2, position=(0.5, 0.0, 0.5), wxyz=(0, 0, 1, 0) - ) - - # Create interactive controller and mesh for the sphere obstacle. 
- sphere_handle = server.scene.add_transform_controls( - "/obstacle", scale=0.2, position=(0.4, 0.3, 0.4) - ) - server.scene.add_mesh_trimesh("/obstacle/mesh", mesh=sphere_coll.to_trimesh()) - - timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True) - - while True: - start_time = time.time() - - sphere_coll_world_current = sphere_coll.transform_from_wxyz_position( - wxyz=np.array(sphere_handle.wxyz), - position=np.array(sphere_handle.position), - ) - - world_coll_list = [plane_coll, sphere_coll_world_current] - solution = pks.solve_ik_with_collision( - robot=robot, - coll=robot_coll, - world_coll_list=world_coll_list, - target_link_name=target_link_name, - target_position=np.array(ik_target_handle.position), - target_wxyz=np.array(ik_target_handle.wxyz), - ) - - # Update timing handle. - timing_handle.value = (time.time() - start_time) * 1000 - - # Update visualizer. - urdf_vis.update_cfg(solution) - - - if __name__ == "__main__": - main() diff --git a/docs/source/examples/15_spherized_ik_then_coll.rst b/docs/source/examples/15_spherized_ik_then_coll.rst deleted file mode 100644 index ae43390..0000000 --- a/docs/source/examples/15_spherized_ik_then_coll.rst +++ /dev/null @@ -1,101 +0,0 @@ -.. Comment: this file is automatically generated by `update_example_docs.py`. - It should not be modified manually. - -IK with Collision -========================================== - - -Basic Inverse Kinematics with Collision Avoidance using PyRoNot. - -All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details. - - - -.. 
code-block:: python - :linenos: - - - import time - - import numpy as np - import pyronot as pk - import viser - from pyronot.collision import HalfSpace, RobotCollision, RobotCollisionSpherized, Sphere - from robot_descriptions.loaders.yourdfpy import load_robot_description - from viser.extras import ViserUrdf - - import pyronot_snippets as pks - import yourdfpy - - - def main(): - """Main function for basic IK with collision.""" - urdf_path = "resources/ur5/ur5_spherized.urdf" - mesh_dir = "resources/ur5/meshes" - target_link_name = "tool0" - - # urdf_path = "resources/panda/panda_spherized.urdf" - # mesh_dir = "resources/panda/meshes" - # target_link_name = "panda_hand" - urdf = yourdfpy.URDF.load(urdf_path, mesh_dir=mesh_dir) - robot = pk.Robot.from_urdf(urdf) - - robot_coll = RobotCollisionSpherized.from_urdf(urdf) - plane_coll = HalfSpace.from_point_and_normal( - np.array([0.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0]) - ) - sphere_coll = Sphere.from_center_and_radius( - np.array([0.0, 0.0, 0.0]), np.array([0.05]) - ) - - # Set up visualizer. - server = viser.ViserServer() - server.scene.add_grid("/ground", width=2, height=2, cell_size=0.1) - urdf_vis = ViserUrdf(server, urdf, root_node_name="/robot") - - # Create interactive controller for IK target. - ik_target_handle = server.scene.add_transform_controls( - "/ik_target", scale=0.2, position=(0.5, 0.0, 0.5), wxyz=(0, 0, 1, 0) - ) - - # Create interactive controller and mesh for the sphere obstacle. 
- sphere_handle = server.scene.add_transform_controls( - "/obstacle", scale=0.2, position=(0.4, 0.3, 0.4) - ) - server.scene.add_mesh_trimesh("/obstacle/mesh", mesh=sphere_coll.to_trimesh()) - - just_ik_timing_handle = server.gui.add_number("just ik (ms)", 0.001, disabled=True) - coll_ik_timing_handle = server.gui.add_number("coll ik (ms)", 0.001, disabled=True) - while True: - - sphere_coll_world_current = sphere_coll.transform_from_wxyz_position( - wxyz=np.array(sphere_handle.wxyz), - position=np.array(sphere_handle.position), - ) - start_time = time.time() - just_ik = pks.solve_ik( - robot=robot, - target_link_name=target_link_name, - target_position=np.array(ik_target_handle.position), - target_wxyz=np.array(ik_target_handle.wxyz), - ) - just_ik_timing_handle.value = (time.time() - start_time) * 1000 - - world_coll_list = [plane_coll, sphere_coll_world_current] - start_time = time.time() - solution = pks.solve_collision_with_config( - robot=robot, - coll=robot_coll, - world_coll_list=world_coll_list, - cfg=just_ik, - ) - coll_ik_timing_handle.value = (time.time() - start_time) * 1000 - - # Update timing handle. - - # Update visualizer. - urdf_vis.update_cfg(solution) - - - if __name__ == "__main__": - main() diff --git a/docs/source/examples/16_spherized_online_planning.rst b/docs/source/examples/16_spherized_online_planning.rst deleted file mode 100644 index 813a845..0000000 --- a/docs/source/examples/16_spherized_online_planning.rst +++ /dev/null @@ -1,125 +0,0 @@ -.. Comment: this file is automatically generated by `update_example_docs.py`. - It should not be modified manually. - -Online Planning -========================================== - - -Run online planning in collision aware environments. - -All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details. - - - -.. 
code-block:: python - :linenos: - - - import time - - import numpy as np - import pyronot as pk - import viser - from pyronot.collision import HalfSpace, RobotCollision, RobotCollisionSpherized, Sphere - from robot_descriptions.loaders.yourdfpy import load_robot_description - from viser.extras import ViserUrdf - - import pyronot_snippets as pks - import yourdfpy - - - def main(): - """Main function for online planning with collision.""" - # urdf_path = "resources/ur5/ur5_spherized.urdf" - # mesh_dir = "resources/ur5/meshes" - # target_link_name = "robotiq_85_tool_link" - urdf_path = "resources/panda/panda_spherized.urdf" - mesh_dir = "resources/panda/meshes" - target_link_name = "panda_hand" - urdf = yourdfpy.URDF.load(urdf_path, mesh_dir=mesh_dir) - robot = pk.Robot.from_urdf(urdf) - - robot_coll = RobotCollisionSpherized.from_urdf(urdf) - plane_coll = HalfSpace.from_point_and_normal( - np.array([0.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0]) - ) - sphere_coll = Sphere.from_center_and_radius( - np.array([0.0, 0.0, 0.0]), np.array([0.05]) - ) - - # Define the online planning parameters. - len_traj, dt = 5, 0.1 - - # Set up visualizer. - server = viser.ViserServer() - server.scene.add_grid("/ground", width=2, height=2, cell_size=0.1) - urdf_vis = ViserUrdf(server, urdf, root_node_name="/robot") - - # Create interactive controller for IK target. - ik_target_handle = server.scene.add_transform_controls( - "/ik_target", scale=0.2, position=(0.3, 0.0, 0.5), wxyz=(0, 0, 1, 0) - ) - - # Create interactive controller and mesh for the sphere obstacle. 
- sphere_handle = server.scene.add_transform_controls( - "/obstacle", scale=0.2, position=(0.4, 0.3, 0.4) - ) - server.scene.add_mesh_trimesh("/obstacle/mesh", mesh=sphere_coll.to_trimesh()) - target_frame_handle = server.scene.add_batched_axes( - "target_frame", - axes_length=0.05, - axes_radius=0.005, - batched_positions=np.zeros((25, 3)), - batched_wxyzs=np.array([[1.0, 0.0, 0.0, 0.0]] * 25), - ) - - timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True) - - sol_pos, sol_wxyz = None, None - sol_traj = np.array( - robot.joint_var_cls.default_factory()[None].repeat(len_traj, axis=0) - ) - while True: - start_time = time.time() - - sphere_coll_world_current = sphere_coll.transform_from_wxyz_position( - wxyz=np.array(sphere_handle.wxyz), - position=np.array(sphere_handle.position), - ) - - world_coll_list = [plane_coll, sphere_coll_world_current] - sol_traj, sol_pos, sol_wxyz = pks.solve_online_planning( - robot=robot, - robot_coll=robot_coll, - world_coll=world_coll_list, - target_link_name=target_link_name, - target_position=np.array(ik_target_handle.position), - target_wxyz=np.array(ik_target_handle.wxyz), - timesteps=len_traj, - dt=dt, - start_cfg=sol_traj[0], - prev_sols=sol_traj, - ) - - # Update timing handle. - timing_handle.value = ( - 0.99 * timing_handle.value + 0.01 * (time.time() - start_time) * 1000 - ) - - # Update visualizer. - urdf_vis.update_cfg( - sol_traj[0] - ) # The first step of the online trajectory solution. - - # Update the planned trajectory visualization. - if hasattr(target_frame_handle, "batched_positions"): - target_frame_handle.batched_positions = np.array(sol_pos) # type: ignore[attr-defined] - target_frame_handle.batched_wxyzs = np.array(sol_wxyz) # type: ignore[attr-defined] - else: - # This is an older version of Viser. 
- target_frame_handle.positions_batched = np.array(sol_pos) # type: ignore[attr-defined] - target_frame_handle.wxyzs_batched = np.array(sol_wxyz) # type: ignore[attr-defined] - - - if __name__ == "__main__": - main() diff --git a/docs/source/examples/17_geometries_example.rst b/docs/source/examples/17_geometries_example.rst deleted file mode 100644 index cbb852f..0000000 --- a/docs/source/examples/17_geometries_example.rst +++ /dev/null @@ -1,114 +0,0 @@ -.. Comment: this file is automatically generated by `update_example_docs.py`. - It should not be modified manually. - -Visualize Box and its six HalfSpace faces in viser. -========================================== - - -Run this example and use the transform controls to move/rotate the box -and the GUI numbers to change length/width/height interactively. - -All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details. - - - -.. code-block:: python - :linenos: - - - import time - - import numpy as np - import pyronot as pk - import viser - import trimesh - from pyronot.collision._geometry import Box, Sphere, Capsule - from pyronot.collision import collide - - - def main(): - """Start viser and visualize a Box and its HalfSpace faces.""" - - # Initial box parameters - center = np.array([0.0, 0.0, 0.5]) - length, width, height = 0.1, 0.1, 0.1 - - server = viser.ViserServer() - server.scene.add_grid("/ground", width=2, height=2, cell_size=0.1) - - box_handle = server.scene.add_transform_controls( - "/box_handle", scale=0.2, position=tuple(center), wxyz=(0, 0, 1, 0) - ) - - # Add transform controls for a sphere and a capsule - sphere_handle = server.scene.add_transform_controls( - "/sphere_handle", scale=0.15, position=(0.3, 0.0, 0.5), wxyz=(0, 0, 1, 0) - ) - sphere_radius_handle = server.gui.add_number("Sphere Radius", 0.05) - - cap_handle = server.scene.add_transform_controls( - "/cap_handle", scale=0.2, position=(-0.3, 0.0, 0.5), wxyz=(0, 0, 1, 0) - 
) - cap_radius_handle = server.gui.add_number("Capsule Radius", 0.03) - cap_height_handle = server.gui.add_number("Capsule Height", 0.2) - - length_handle = server.gui.add_number("Length", length) - width_handle = server.gui.add_number("Width", width) - height_handle = server.gui.add_number("Height", height) - - server.scene.add_mesh_trimesh( - "/box/mesh", mesh=Box.from_center_and_dimensions(center, length, width, height).to_trimesh() - ) - - server.scene.add_mesh_trimesh("/box/polytope", mesh=trimesh.Trimesh()) - - while True: - pos = np.array(box_handle.position) - wxyz = np.array(box_handle.wxyz) - length = float(length_handle.value) if hasattr(length_handle, "value") else float(length_handle) - width = float(width_handle.value) if hasattr(width_handle, "value") else float(width_handle) - height = float(height_handle.value) if hasattr(height_handle, "value") else float(height_handle) - - box = Box.from_center_and_dimensions(center=pos, length=length, width=width, height=height, wxyz=wxyz) - - # Sphere - sph_pos = np.array(sphere_handle.position) - sph_wxyz = np.array(sphere_handle.wxyz) - sph_rad = float(sphere_radius_handle.value) if hasattr(sphere_radius_handle, "value") else float(sphere_radius_handle) - sphere = Sphere.from_center_and_radius(center=sph_pos, radius=sph_rad) - - # Capsule - cap_pos = np.array(cap_handle.position) - cap_wxyz = np.array(cap_handle.wxyz) - cap_rad = float(cap_radius_handle.value) if hasattr(cap_radius_handle, "value") else float(cap_radius_handle) - cap_h = float(cap_height_handle.value) if hasattr(cap_height_handle, "value") else float(cap_height_handle) - capsule = Capsule.from_radius_height(radius=cap_rad, height=cap_h, position=cap_pos, wxyz=cap_wxyz) - - server.scene.add_mesh_trimesh("/box/mesh", mesh=box.to_trimesh()) - server.scene.add_mesh_trimesh("/sphere/mesh", mesh=sphere.to_trimesh()) - server.scene.add_mesh_trimesh("/cap/mesh", mesh=capsule.to_trimesh()) - - poly_mesh = box.to_trimesh() - 
server.scene.add_mesh_trimesh("/box/polytope", mesh=poly_mesh) - - # Collision checks between all unique pairs - pairs = [ - ("Box", box, "Sphere", sphere), - ("Box", box, "Capsule", capsule), - ("Sphere", sphere, "Capsule", capsule), - ] - for name1, g1, name2, g2 in pairs: - try: - d = collide(g1, g2) - # d is a jax Array; convert to python float if scalar - d_val = float(d) - if d_val < 0.0: - print(f"Collision detected {name1} vs {name2}: distance={d_val:.6f}") - except Exception as e: - print(f"Error computing collision {name1} vs {name2}: {e}") - - time.sleep(1.0 / 60.0) - - - if __name__ == "__main__": - main() diff --git a/docs/source/examples/18_spherized_trajopt.rst b/docs/source/examples/18_spherized_trajopt.rst deleted file mode 100644 index 0e79fdc..0000000 --- a/docs/source/examples/18_spherized_trajopt.rst +++ /dev/null @@ -1,133 +0,0 @@ -.. Comment: this file is automatically generated by `update_example_docs.py`. - It should not be modified manually. - -Trajectory Optimization -========================================== - - -Basic Trajectory Optimization using PyRoNot. - -Robot going over a wall, while avoiding world-collisions. - -All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details. - - - -.. 
code-block:: python - :linenos: - - - import time - from typing import Literal - - import numpy as np - import pyronot as pk - import trimesh - import tyro - import viser - from viser.extras import ViserUrdf - from robot_descriptions.loaders.yourdfpy import load_robot_description - - import pyronot_snippets as pks - import yourdfpy - - def main(robot_name: Literal["ur5", "panda"] = "panda"): - if robot_name == "ur5": - raise ValueError("UR5 not supported yet") - - elif robot_name == "panda": - urdf_path = "resources/panda/panda_spherized.urdf" - mesh_dir = "resources/panda/meshes" - target_link_name = "panda_hand" - urdf = yourdfpy.URDF.load(urdf_path, mesh_dir=mesh_dir) - - down_wxyz = np.array([0, 0, 1, 0]) # for panda! - robot = pk.Robot.from_urdf(urdf) - - else: - raise ValueError(f"Invalid robot: {robot_name}") - - robot_coll = pk.collision.RobotCollisionSpherized.from_urdf(urdf) - - # Define the trajectory problem: - # - number of timesteps, timestep size - timesteps, dt = 25, 0.02 - # - the start and end poses. 
- start_pos, end_pos = np.array([0.5, -0.3, 0.2]), np.array([0.5, 0.3, 0.2]) - - # Define the obstacles: - # - Ground - ground_coll = pk.collision.HalfSpace.from_point_and_normal( - np.array([0.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0]) - ) - # - Wall - wall_height = 0.4 - wall_width = 0.1 - wall_length = 0.4 - wall_intervals = np.arange(start=0.3, stop=wall_length + 0.3, step=0.05) - translation = np.concatenate( - [ - wall_intervals.reshape(-1, 1), - np.full((wall_intervals.shape[0], 1), 0.0), - np.full((wall_intervals.shape[0], 1), wall_height / 2), - ], - axis=1, - ) - wall_coll = pk.collision.Capsule.from_radius_height( - position=translation, - radius=np.full((translation.shape[0], 1), wall_width / 2), - height=np.full((translation.shape[0], 1), wall_height), - ) - world_coll = [ground_coll, wall_coll] - - traj = pks.solve_trajopt( - robot, - robot_coll, - world_coll, - target_link_name, - start_pos, - down_wxyz, - end_pos, - down_wxyz, - timesteps, - dt, - ) - traj = np.array(traj) - - # Visualize! 
- server = viser.ViserServer() - urdf_vis = ViserUrdf(server, urdf) - server.scene.add_grid("/grid", width=2, height=2, cell_size=0.1) - server.scene.add_mesh_trimesh( - "wall_box", - trimesh.creation.box( - extents=(wall_length, wall_width, wall_height), - transform=trimesh.transformations.translation_matrix( - np.array([0.5, 0.0, wall_height / 2]) - ), - ), - ) - for name, pos in zip(["start", "end"], [start_pos, end_pos]): - server.scene.add_frame( - f"/{name}", - position=pos, - wxyz=down_wxyz, - axes_length=0.05, - axes_radius=0.01, - ) - - slider = server.gui.add_slider( - "Timestep", min=0, max=timesteps - 1, step=1, initial_value=0 - ) - playing = server.gui.add_checkbox("Playing", initial_value=True) - - while True: - if playing.value: - slider.value = (slider.value + 1) % timesteps - - urdf_vis.update_cfg(traj[slider.value]) - time.sleep(1.0 / 10.0) - - - if __name__ == "__main__": - tyro.cli(main) diff --git a/docs/source/examples/19_spherized_ik_with_coll_exclude_link.rst b/docs/source/examples/19_spherized_ik_with_coll_exclude_link.rst deleted file mode 100644 index 3c214a7..0000000 --- a/docs/source/examples/19_spherized_ik_with_coll_exclude_link.rst +++ /dev/null @@ -1,123 +0,0 @@ -.. Comment: this file is automatically generated by `update_example_docs.py`. - It should not be modified manually. - -IK with Collision -========================================== - - -Basic Inverse Kinematics with Collision Avoidance using PyRoNot. - -All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details. - - - -.. 
code-block:: python - :linenos: - - - import time - - import numpy as np - import pyronot as pk - import viser - from pyronot.collision import HalfSpace, RobotCollision, RobotCollisionSpherized, Sphere - from robot_descriptions.loaders.yourdfpy import load_robot_description - from viser.extras import ViserUrdf - - import pyronot_snippets as pks - import yourdfpy - - - def main(): - """Main function for basic IK with collision.""" - urdf_path = "resources/ur5/ur5_spherized.urdf" - mesh_dir = "resources/ur5/meshes" - target_link_name = "robotiq_85_tool_link" - - # urdf_path = "resources/panda/panda_spherized.urdf" - # mesh_dir = "resources/panda/meshes" - # target_link_name = "panda_hand" - urdf = yourdfpy.URDF.load(urdf_path, mesh_dir=mesh_dir) - robot = pk.Robot.from_urdf(urdf, default_joint_cfg=[0, -1.57, 1.57, -1.57, -1.57, 0]) - - robot_coll = RobotCollisionSpherized.from_urdf(urdf) - plane_coll = HalfSpace.from_point_and_normal( - np.array([0.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0]) - ) - sphere_coll = Sphere.from_center_and_radius( - np.array([0.0, 0.0, 0.0]), np.array([0.05]) - ) - - # Set up visualizer. - server = viser.ViserServer() - server.scene.add_grid("/ground", width=2, height=2, cell_size=0.1) - urdf_vis = ViserUrdf(server, urdf, root_node_name="/robot") - - # Create interactive controller for IK target. - ik_target_handle = server.scene.add_transform_controls( - "/ik_target", scale=0.2, position=(0.0, 0.6, 0.2), wxyz=(0, 0, 1, 0) - ) - - # Create interactive controller and mesh for the sphere obstacle. 
- sphere_handle = server.scene.add_transform_controls( - "/obstacle", scale=0.2, position=(0.4, 0.3, 0.4) - ) - server.scene.add_mesh_trimesh("/obstacle/mesh", mesh=sphere_coll.to_trimesh()) - - timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True) - exclude_links_from_cc = ["offset_link", "base_link"] - exclude_link_mask = robot.links.get_link_mask_from_names(exclude_links_from_cc) - print(exclude_link_mask) - while True: - start_time = time.time() - - sphere_coll_world_current = sphere_coll.transform_from_wxyz_position( - wxyz=np.array(sphere_handle.wxyz), - position=np.array(sphere_handle.position), - ) - - world_coll_list = [plane_coll, sphere_coll_world_current] - solution = pks.solve_ik_with_collision( - robot=robot, - coll=robot_coll, - world_coll_list=world_coll_list, - target_link_name=target_link_name, - target_position=np.array(ik_target_handle.position), - target_wxyz=np.array(ik_target_handle.wxyz), - ) - - # Update timing handle. - timing_handle.value = (time.time() - start_time) * 1000 - - # Update visualizer. 
- urdf_vis.update_cfg(solution) - # print(robot.links.names) - # Compute the collision of the solution - distance_link_to_plane = robot_coll.compute_world_collision_distance( - robot, - solution, - plane_coll - ) - distance_link_to_plane = RobotCollisionSpherized.mask_collision_distance(distance_link_to_plane, exclude_link_mask) - # print(distance_link_to_plane) - distance_link_to_sphere = robot_coll.compute_world_collision_distance( - robot, - solution, - sphere_coll - ) - distance_link_to_sphere = RobotCollisionSpherized.mask_collision_distance(distance_link_to_sphere, exclude_link_mask) - # print(distance_link_to_sphere) - # Visualize collision representation - robot_coll_config: Sphere = robot_coll.at_config(robot, solution) - # print(robot_coll_config.get_batch_axes()[-1]) - robot_coll_mesh = robot_coll_config.to_trimesh() - server.scene.add_mesh_trimesh( - "/robot_coll", - mesh=robot_coll_mesh, - wxyz=(1.0, 0.0, 0.0, 0.0), - position=(0.0, 0.0, 0.0), - ) - - - if __name__ == "__main__": - main() diff --git a/docs/source/examples/20_negative_distance.rst b/docs/source/examples/20_negative_distance.rst deleted file mode 100644 index ee45839..0000000 --- a/docs/source/examples/20_negative_distance.rst +++ /dev/null @@ -1,119 +0,0 @@ -.. Comment: this file is automatically generated by `update_example_docs.py`. - It should not be modified manually. - -IK with Collision -========================================== - - -Basic Inverse Kinematics with Collision Avoidance using PyRoNot. - -All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details. - - - -.. 
code-block:: python - :linenos: - - - import time - - import numpy as np - import pyronot as pk - import jax.numpy as jnp - import viser - from pyronot.collision import HalfSpace, RobotCollision, RobotCollisionSpherized, Sphere - from robot_descriptions.loaders.yourdfpy import load_robot_description - from viser.extras import ViserUrdf - - import pyronot_snippets as pks - import yourdfpy - - - def main(): - """Main function for basic IK with collision.""" - urdf_path = "resources/ur5/ur5_spherized.urdf" - mesh_dir = "resources/ur5/meshes" - target_link_name = "tool0" - srdf_path = "resources/ur5/ur5.srdf" - # urdf_path = "resources/panda/panda_spherized.urdf" - # mesh_dir = "resources/panda/meshes" - # target_link_name = "panda_hand" - urdf = yourdfpy.URDF.load(urdf_path, mesh_dir=mesh_dir) - robot = pk.Robot.from_urdf(urdf) - - robot_coll = RobotCollisionSpherized.from_urdf(urdf, srdf_path=srdf_path) - print(robot_coll.link_names) - plane_coll = HalfSpace.from_point_and_normal( - np.array([0.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0]) - ) - sphere_coll = Sphere.from_center_and_radius( - np.array([0.0, 0.0, 0.0]), np.array([0.05]) - ) - - # Set up visualizer. - server = viser.ViserServer() - server.scene.add_grid("/ground", width=2, height=2, cell_size=0.1) - urdf_vis = ViserUrdf(server, urdf, root_node_name="/robot") - - # Create interactive controller for IK target. - ik_target_handle = server.scene.add_transform_controls( - "/ik_target", scale=0.2, position=(0.5, 0.0, 0.5), wxyz=(0, 0, 1, 0) - ) - - # Create interactive controller and mesh for the sphere obstacle. 
- sphere_handle = server.scene.add_transform_controls( - "/obstacle", scale=0.2, position=(0.4, 0.3, 0.4) - ) - server.scene.add_mesh_trimesh("/obstacle/mesh", mesh=sphere_coll.to_trimesh()) - - timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True) - distance_self_collision_handle = server.gui.add_number("Distance Self Collision", 0.001, disabled=True) - link1_handle = server.gui.add_text("Closest Link 1", initial_value="", disabled=True) - link2_handle = server.gui.add_text("Closest Link 2", initial_value="", disabled=True) - - - while True: - start_time = time.time() - - sphere_coll_world_current = sphere_coll.transform_from_wxyz_position( - wxyz=np.array(sphere_handle.wxyz), - position=np.array(sphere_handle.position), - ) - - world_coll_list = [plane_coll, sphere_coll_world_current] - solution = pks.solve_ik_with_collision( - robot=robot, - coll=robot_coll, - world_coll_list=world_coll_list, - target_link_name=target_link_name, - target_position=np.array(ik_target_handle.position), - target_wxyz=np.array(ik_target_handle.wxyz), - ) - # Compute self-collision distances - distance_self_collision = robot_coll.compute_self_collision_distance(robot, solution) - - # Find the closest pair - min_idx = int(jnp.argmin(distance_self_collision)) - min_distance = float(distance_self_collision[min_idx]) - - # Get the link names for this pair - link_i_idx = int(robot_coll.active_idx_i[min_idx]) - link_j_idx = int(robot_coll.active_idx_j[min_idx]) - link_i_name = robot_coll.link_names[link_i_idx] - link_j_name = robot_coll.link_names[link_j_idx] - - # Update GUI - distance_self_collision_handle.value = min_distance - link1_handle.value = link_i_name - link2_handle.value = link_j_name - print(f"link_i_name: {link_i_name}, link_j_name: {link_j_name}") - timing_handle.value = (time.time() - start_time) * 1000 - - # Update visualizer. 
- robot_coll_mesh = robot_coll.at_config(robot, solution).to_trimesh() - server.scene.add_mesh_trimesh("/robot/collision", mesh=robot_coll_mesh) - urdf_vis.update_cfg(solution) - - - if __name__ == "__main__": - main() diff --git a/docs/source/examples/21_quantize_collision.rst b/docs/source/examples/21_quantize_collision.rst deleted file mode 100644 index b6772f5..0000000 --- a/docs/source/examples/21_quantize_collision.rst +++ /dev/null @@ -1,131 +0,0 @@ -.. Comment: this file is automatically generated by `update_example_docs.py`. - It should not be modified manually. - -Quantize Collision Example -========================================== - - -Script to generate obstacles and robot configurations for quantization testing. - -All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details. - - - -.. code-block:: python - :linenos: - - - import jax - import jax.numpy as jnp - import numpy as np - import time - import pyronot as pk - from pyronot.collision import RobotCollisionSpherized, Sphere - from pyronot._robot_urdf_parser import RobotURDFParser - import yourdfpy - from pyronot.utils import quantize - - def generate_spheres(n_spheres): - print(f"Generating {n_spheres} random spheres...") - spheres = [] - for _ in range(n_spheres): - center = np.random.uniform(low=-1.0, high=1.0, size=(3,)) - radius = np.random.uniform(low=0.05, high=0.2) - sphere = Sphere.from_center_and_radius(center, np.array([radius])) - spheres.append(sphere) - - # Tree map them to create a batch of spheres - spheres_batch = jax.tree.map(lambda *args: jnp.stack(args), *spheres) - print(f"Generated {n_spheres} spheres.") - return spheres_batch - - def generate_configs(joints, n_configs): - print(f"Generating {n_configs} random robot configurations...") - q_batch = np.random.uniform( - low=joints.lower_limits, - high=joints.upper_limits, - size=(n_configs, joints.num_actuated_joints) - ) - print(f"Generated {n_configs} robot 
configurations.") - print(f"Configurations shape: {q_batch.shape}") - return q_batch - - def make_collision_checker(robot, robot_coll): - @jax.jit - def check_collisions(q_batch, obstacles): - # q_batch: (N_configs, dof) - # obstacles: Sphere batch (N_spheres) - - # Define single config check - def check_single(q, obs): - return robot_coll.compute_world_collision_distance(robot, q, obs) - - # Vmap over configs - # in_axes: q=(0), obs=(None) -> we want to check each q against ALL obs - return jax.vmap(check_single, in_axes=(0, None))(q_batch, obstacles) - - return check_collisions - - def run_benchmark(name, check_fn, q_batch, obstacles): - print(f"\n{name}:") - - # Metrics - q_size_mb = q_batch.nbytes / 1024 / 1024 - spheres_size_mb = sum(x.nbytes for x in jax.tree_util.tree_leaves(obstacles)) / 1024 / 1024 - - print(f"q_batch size: {q_size_mb:.2f} MB") - print(f"Obstacles (spheres) size: {spheres_size_mb:.2f} MB") - - # Warmup - print(f"Warming up JIT ({name})...") - _ = check_fn(q_batch, obstacles) - # _.block_until_ready() - - # Run collision checking - print(f"Executing collision checking ({name})...") - start_time = time.perf_counter() - dists = check_fn(q_batch, obstacles) - # dists.block_until_ready() - end_time = time.perf_counter() - - print(f"Time to compute: {end_time - start_time:.6f} seconds") - print(f"Collision distances shape: {dists.shape}") - print(f"Min distance: {jnp.min(dists)}") - - in_collision = dists < 0 - print(f"Number of collision pairs: {jnp.sum(in_collision)}") - - def main(): - global robot, robot_coll - # Load robot - urdf_path = "resources/ur5/ur5_spherized.urdf" - mesh_dir = "resources/ur5/meshes" - urdf = yourdfpy.URDF.load(urdf_path, mesh_dir=mesh_dir) - robot = pk.Robot.from_urdf(urdf) - joints, links = RobotURDFParser.parse(urdf) - - # Initialize collision model - print("Initializing collision model...") - robot_coll = RobotCollisionSpherized.from_urdf(urdf) - - # Create collision checker - check_collisions = 
make_collision_checker(robot, robot_coll) - - # Generate data - spheres_batch = generate_spheres(10000) - q_batch = generate_configs(joints, 10000) - - # Run benchmarks - run_benchmark("Default (float32)", check_collisions, q_batch, spheres_batch) - - # Quantized - q_batch_f16 = quantize(q_batch) - spheres_batch_f16 = quantize(spheres_batch) - run_benchmark("Quantized (float16)", check_collisions, q_batch_f16, spheres_batch_f16) - - q_batch_int8 = quantize(q_batch, jax.numpy.int8) - spheres_batch_int8 = quantize(spheres_batch, jax.numpy.int8) - run_benchmark("Quantized (int8)", check_collisions, q_batch_int8, spheres_batch_int8) - - if __name__ == "__main__": - main() diff --git a/docs/source/examples/22_neural_sdf.rst b/docs/source/examples/22_neural_sdf.rst deleted file mode 100644 index 8e6c9ff..0000000 --- a/docs/source/examples/22_neural_sdf.rst +++ /dev/null @@ -1,345 +0,0 @@ -.. Comment: this file is automatically generated by `update_example_docs.py`. - It should not be modified manually. - -Quantize Collision Example -========================================== - - -Script to generate obstacles and robot configurations for quantization testing. - -All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details. - - - -.. 
code-block:: python - :linenos: - - - import jax - import jax.numpy as jnp - import numpy as np - import time - import pyronot as pk - from pyronot.collision import RobotCollision, RobotCollisionSpherized, NeuralRobotCollision, Sphere - from pyronot._robot_urdf_parser import RobotURDFParser - import yourdfpy - from pyronot.utils import quantize - - # Global configuration: Set to True to run positional encoding comparison - RUN_POSITIONAL_ENCODING = False - - def generate_spheres(n_spheres): - print(f"Generating {n_spheres} random spheres...") - spheres = [] - for _ in range(n_spheres): - center = np.random.uniform(low=-1.0, high=1.0, size=(3,)) - radius = np.random.uniform(low=0.05, high=0.2) - sphere = Sphere.from_center_and_radius(center, np.array(radius)) - spheres.append(sphere) - - # Tree map them to create a batch of spheres - spheres_batch = jax.tree.map(lambda *args: jnp.stack(args), *spheres) - print(f"Generated {n_spheres} spheres.") - return spheres_batch - - def generate_configs(joints, n_configs): - print(f"Generating {n_configs} random robot configurations...") - q_batch = np.random.uniform( - low=joints.lower_limits, - high=joints.upper_limits, - size=(n_configs, joints.num_actuated_joints) - ) - print(f"Generated {n_configs} robot configurations.") - print(f"Configurations shape: {q_batch.shape}") - return q_batch - - def make_collision_checker(robot, robot_coll): - @jax.jit - def check_collisions(q_batch, obstacles): - # q_batch: (N_configs, dof) - # obstacles: Sphere batch (N_spheres) - - # Define single config check - def check_single(q, obs): - return robot_coll.compute_world_collision_distance(robot, q, obs) - - # Vmap over configs - # in_axes: q=(0), obs=(None) -> we want to check each q against ALL obs - return jax.vmap(check_single, in_axes=(0, None))(q_batch, obstacles) - - return check_collisions - - - def make_neural_collision_checker(robot, robot_coll, spheres_batch, use_positional_encoding=True, pe_min_deg=0, pe_max_deg=6): - """Train a 
neural collision model on the given static world and return its checker. - - This will: - - build a NeuralRobotCollision from the exact model - - train it on random configs for the provided world (spheres_batch) - - expose a vmap'ed collision function with the same signature as make_collision_checker's output - - Args: - robot: The Robot instance - robot_coll: The exact collision model (RobotCollisionSpherized) - spheres_batch: Batch of sphere obstacles - use_positional_encoding: If True, use iSDF-inspired positional encoding for - better capture of fine geometric details (default True) - pe_min_deg: Minimum frequency degree for positional encoding (default 0) - pe_max_deg: Maximum frequency degree for positional encoding (default 6) - """ - - # Wrap the world geometry in the same structure RobotCollisionSpherized expects - # RobotCollisionSpherized.from_urdf constructs a CollGeom internally when used in examples, - # so `spheres_batch` is already a valid batch of Sphere geometry. - - # Create neural collision model from existing exact model - # With positional encoding enabled for better accuracy near collision boundaries - neural_coll = NeuralRobotCollision.from_existing( - robot_coll, - use_positional_encoding=use_positional_encoding, - pe_min_deg=pe_min_deg, - pe_max_deg=pe_max_deg, - ) - - pe_status = f"with PE (deg {pe_min_deg}-{pe_max_deg})" if use_positional_encoding else "without PE" - print(f"Created neural collision model {pe_status}") - - # Train neural model on this specific world - neural_coll = neural_coll.train( - robot=robot, - world_geom=spheres_batch, - num_samples=10000, - batch_size=1000, # Smaller batch = more gradient updates per epoch - epochs=50, # More epochs for better convergence - learning_rate=1e-3, - ) - - # Now build a collision checker that calls the neural model - @jax.jit - def check_collisions(q_batch, obstacles): - # q_batch: (N_configs, dof) - # obstacles: same spheres_batch used for training - - def check_single(q, obs): - 
return neural_coll.compute_world_collision_distance(robot, q, obs) - - return jax.vmap(check_single, in_axes=(0, None))(q_batch, obstacles) - - return check_collisions - - def run_benchmark(name, check_fn, q_batch, obstacles): - print(f"\n{name}:") - - # Metrics - q_size_mb = q_batch.nbytes / 1024 / 1024 - spheres_size_mb = sum(x.nbytes for x in jax.tree_util.tree_leaves(obstacles)) / 1024 / 1024 - - print(f"q_batch size: {q_size_mb:.2f} MB") - print(f"Obstacles (spheres) size: {spheres_size_mb:.2f} MB") - - # Warmup - print(f"Warming up JIT ({name})...") - _ = check_fn(q_batch, obstacles) - - # Run collision checking - print(f"Executing collision checking ({name})...") - start_time = time.perf_counter() - dists = check_fn(q_batch, obstacles) - end_time = time.perf_counter() - elapsed_time = end_time - start_time - - print(f"Time to compute: {elapsed_time:.6f} seconds") - print(f"Collision distances shape: {dists.shape}") - print(f"Min distance: {jnp.min(dists):.6f}") - print(f"Max distance: {jnp.max(dists):.6f}") - print(f"Mean distance: {jnp.mean(dists):.6f}") - print(f"Std distance: {jnp.std(dists):.6f}") - - in_collision = dists < 0 - print(f"Number of collision pairs: {jnp.sum(in_collision)}") - - return dists, elapsed_time - - - def compare_results(name, neural_dists, exact_dists): - """Compare neural network predictions against exact results.""" - import gc - - print(f"\n {name} Comparison ") - - - # Compute metrics in a memory-efficient way - diff = neural_dists - exact_dists - mae = float(jnp.mean(jnp.abs(diff))) - max_ae = float(jnp.max(jnp.abs(diff))) - rmse = float(jnp.sqrt(jnp.mean(diff ** 2))) - bias = float(jnp.mean(diff)) - del diff # Free memory - gc.collect() - - print(f"Mean absolute error: {mae:.6f}") - print(f"Max absolute error: {max_ae:.6f}") - print(f"RMSE: {rmse:.6f}") - print(f"Mean error (bias): {bias:.6f}") - - # Check accuracy at collision boundary - exact_in_collision = exact_dists < 0.05 - neural_in_collision = neural_dists < 0.05 - - 
# Compute metrics and convert to Python ints immediately - true_positives = int(jnp.sum(exact_in_collision & neural_in_collision)) - false_positives = int(jnp.sum(~exact_in_collision & neural_in_collision)) - false_negatives = int(jnp.sum(exact_in_collision & ~neural_in_collision)) - true_negatives = int(jnp.sum(~exact_in_collision & ~neural_in_collision)) - - # Free the boolean arrays - del exact_in_collision, neural_in_collision - gc.collect() - - print(f"\nCollision Detection Accuracy (threshold=0.05):") - print(f" True Positives: {true_positives}") - print(f" False Positives: {false_positives}") - print(f" False Negatives: {false_negatives}") - print(f" True Negatives: {true_negatives}") - - precision = true_positives / (true_positives + false_positives + 1e-8) - recall = true_positives / (true_positives + false_negatives + 1e-8) - f1 = 2 * precision * recall / (precision + recall + 1e-8) - print(f" Precision: {precision:.4f}") - print(f" Recall: {recall:.4f}") - print(f" F1 Score: {f1:.4f}") - - return { - 'mae': mae, - 'max_ae': max_ae, - 'rmse': rmse, - 'bias': bias, - 'precision': precision, - 'recall': recall, - 'f1': f1, - } - - def main(): - import gc - - # Load robot - urdf_path = "resources/ur5/ur5_spherized.urdf" - mesh_dir = "resources/ur5/meshes" - urdf = yourdfpy.URDF.load(urdf_path, mesh_dir=mesh_dir) - robot = pk.Robot.from_urdf(urdf) - joints, links = RobotURDFParser.parse(urdf) - - # Initialize collision model - print("Initializing collision model...") - robot_coll = RobotCollisionSpherized.from_urdf(urdf) - - # Generate data (world is fixed for both exact and neural models) - spheres_batch = generate_spheres(100) - q_batch = generate_configs(joints, 50000) - - # Create collision checker using exact model - exact_check_collisions = make_collision_checker(robot, robot_coll) - - # Train neural model without positional encoding (default) - print("Training neural collision model WITHOUT positional encoding...") - print("(Raw link poses as input)") 
- neural_check_without_pe = make_neural_collision_checker( - robot, robot_coll, spheres_batch, - use_positional_encoding=False, - ) - - # Optionally train with positional encoding - neural_check_with_pe = None - if RUN_POSITIONAL_ENCODING: - print("\n" + "="*70) - print("Training neural collision model WITH positional encoding...") - print("(iSDF-inspired: projects onto icosahedron directions with frequency bands)") - print("="*70) - neural_check_with_pe = make_neural_collision_checker( - robot, robot_coll, spheres_batch, - use_positional_encoding=True, - pe_min_deg=0, - pe_max_deg=6, # 7 frequency bands: 2^0, 2^1, ..., 2^6 - ) - - # Run benchmarks - print("\n" + "="*70) - print("Running benchmarks...") - print("="*70) - - exact_dists, exact_time = run_benchmark( - "Exact (RobotCollisionSpherized)", - exact_check_collisions, q_batch, spheres_batch - ) - - neural_without_pe_dists, neural_without_pe_time = run_benchmark( - "Neural WITHOUT Positional Encoding", - neural_check_without_pe, q_batch, spheres_batch - ) - - neural_with_pe_dists = None - neural_with_pe_time = None - if RUN_POSITIONAL_ENCODING and neural_check_with_pe is not None: - neural_with_pe_dists, neural_with_pe_time = run_benchmark( - "Neural WITH Positional Encoding", - neural_check_with_pe, q_batch, spheres_batch - ) - - # Clear JAX caches and force garbage collection to free GPU memory - print("\nClearing memory before comparison...") - jax.clear_caches() - gc.collect() - - # Compare results - metrics_without_pe = compare_results( - "Neural WITHOUT Positional Encoding vs Exact", - neural_without_pe_dists, exact_dists - ) - - metrics_with_pe = None - if RUN_POSITIONAL_ENCODING and neural_with_pe_dists is not None: - metrics_with_pe = compare_results( - "Neural WITH Positional Encoding vs Exact", - neural_with_pe_dists, exact_dists - ) - - # Summary comparison (only if positional encoding was tested) - if RUN_POSITIONAL_ENCODING and metrics_with_pe is not None: - print("SUMMARY: Positional Encoding 
Impact") - print(f"\n{'Metric':<25} {'With PE':<15} {'Without PE':<15} {'Improvement':<15}") - - for metric in ['mae', 'rmse', 'max_ae', 'precision', 'recall', 'f1']: - with_pe = metrics_with_pe[metric] - without_pe = metrics_without_pe[metric] - - # For error metrics, lower is better; for accuracy metrics, higher is better - if metric in ['mae', 'rmse', 'max_ae', 'bias']: - improvement = (without_pe - with_pe) / (without_pe + 1e-8) * 100 - better = "↓" if with_pe < without_pe else "↑" - else: - improvement = (with_pe - without_pe) / (without_pe + 1e-8) * 100 - better = "↑" if with_pe > without_pe else "↓" - - print(f"{metric:<25} {with_pe:<15.6f} {without_pe:<15.6f} {improvement:+.1f}% {better}") - - print(f"\n{'Inference Time (s)':<25} {neural_with_pe_time:<15.6f} {neural_without_pe_time:<15.6f}") - print(f"{'Exact Time (s)':<25} {exact_time:<15.6f}") - else: - # Simple summary without PE comparison - print("SUMMARY: Neural vs Exact") - print(f"\n{'Metric':<25} {'Neural (no PE)':<15}") - for metric in ['mae', 'rmse', 'max_ae', 'precision', 'recall', 'f1']: - print(f"{metric:<25} {metrics_without_pe[metric]:<15.6f}") - print(f"\n{'Neural Inference Time (s)':<25} {neural_without_pe_time:<15.6f}") - print(f"{'Exact Time (s)':<25} {exact_time:<15.6f}") - print(f"{'Speedup':<25} {exact_time / neural_without_pe_time:<15.2f}x") - - # Cleanup - del exact_dists, neural_without_pe_dists - if neural_with_pe_dists is not None: - del neural_with_pe_dists - gc.collect() - - - if __name__ == "__main__": - main() diff --git a/docs/source/index.rst b/docs/source/index.rst deleted file mode 100644 index 8710800..0000000 --- a/docs/source/index.rst +++ /dev/null @@ -1,104 +0,0 @@ -PyRoNot -========== - -`Project page `_ `•` `arXiv `_ `•` `Code `_ - -**PyRoNot** is a library for robot kinematic optimization (Python Robot Kinematics). - -1. **Modular**: Optimization variables and cost functions are decoupled, enabling reusable components across tasks. 
Objectives like collision avoidance and pose matching can be applied to both IK and trajectory optimization without reimplementation. - -2. **Extensible**: ``PyRoNot`` supports automatic differentiation for user-defined costs with Jacobian computation, a real-time cost-weight tuning interface, and optional analytical Jacobians for performance-critical use cases. - -3. **Cross-Platform**: ``PyRoNot`` runs on CPU, GPU, and TPU, allowing efficient scaling from single-robot use cases to large-scale parallel processing for motion datasets or planning. - -We demonstrate how ``PyRoNot`` solves IK, trajectory optimization, and motion retargeting for robot hands and humanoids in a unified framework. It uses a Levenberg-Marquardt optimizer to efficiently solve these tasks, and we evaluate its performance on batched IK. - -Features include: - -- Differentiable robot forward kinematics model from a URDF. -- Automatic generation of robot collision primitives (e.g., capsules). -- Differentiable collision bodies with numpy broadcasting logic. -- Common cost factors (e.g., end effector pose, self/world-collision, manipulability). -- Arbitrary costs, getting Jacobians either calculated :doc:`through autodiff or defined manually`. -- Integration with a `Levenberg-Marquardt Solver `_ that supports optimization on manifolds (e.g., `lie groups `_). -- Cross-platform support (CPU, GPU, TPU) via JAX. - - - -Installation ------------- - -You can install ``pyronot`` with ``pip``, on Python 3.12+: - -.. code-block:: bash - - git clone https://github.com/chungmin99/pyronot.git - cd pyronot - pip install -e . - - -Python 3.10-3.11 should also work, but support may be dropped in the future. - -Limitations ------------ - -- **Soft constraints only**: We use a nonlinear least-squares formulation and model joint limits, collision avoidance, etc. as soft penalties with high weights rather than hard constraints. 
-- **Static shapes & JIT overhead**: JAX JIT compilation is triggered on first run and when input shapes change (e.g., number of targets, obstacles). Arrays can be pre-padded to vectorize over inputs with different shapes. -- **No sampling-based planners**: We don't include sampling-based planners (e.g., graphs, trees). -- **Collision performance**: Speed and accuracy comparisons against other robot toolkits such as CuRobo have not been extensively performed, and is likely slower than other toolkits for collision-heavy scenarios. - -The following are current implementation limitations that could potentially be addressed in future versions: - -- **Joint types**: We only support revolute, continuous, prismatic, and fixed joints. Other URDF joint types are treated as fixed joints. -- **Collision geometry**: We are limited to sphere, capsule, halfspace, and heightmap geometries. Mesh collision is approximated as capsules. -- **Kinematic structures**: We only support kinematic chains; no closed-loop mechanisms or parallel manipulators. - -Examples --------- - -.. toctree:: - :maxdepth: 1 - :caption: Examples - - examples/01_basic_ik - examples/02_bimanual_ik - examples/03_mobile_ik - examples/04_ik_with_coll - examples/05_ik_with_manipulability - examples/06_online_planning - examples/07_trajopt - examples/08_ik_with_mimic_joints - examples/09_hand_retargeting - examples/10_humanoid_retargeting - examples/11_hand_retargeting_fancy - examples/12_humanoid_retargeting_fancy - - -Acknowledgements ----------------- -``PyRoNot`` is heavily inspired by the prior work, including but not limited to -`Trac-IK `_, -`cuRobo `_, -`pink `_, -`mink `_, -`Drake `_, and -`Dex-Retargeting `_. -Thank you so much for your great work! - - -Citation --------- - -If you find this work useful, please cite it as follows: - -.. 
code-block:: bibtex - - @inproceedings{kim2025pyronot, - title={PyRoNot: A Modular Toolkit for Robot Kinematic Optimization}, - author={Kim*, Chung Min and Yi*, Brent and Choi, Hongsuk and Ma, Yi and Goldberg, Ken and Kanazawa, Angjoo}, - booktitle={2025 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)}, - year={2025}, - url={https://arxiv.org/abs/2505.03728}, - } - -Thanks for using ``PyRoNot``! diff --git a/docs/source/misc/writing_manual_jac.rst b/docs/source/misc/writing_manual_jac.rst deleted file mode 100644 index db2d8c6..0000000 --- a/docs/source/misc/writing_manual_jac.rst +++ /dev/null @@ -1,49 +0,0 @@ -:orphan: - -Defining Jacobians Manually -===================================== - -``pyronot`` supports both autodiff and manually defined Jacobians for computing cost gradients. - -For reference, this is the robot pose matching cost :math:`C_\text{pose}`: - -.. math:: - - \sum_{i} \left( w_{p,i} \left\| \mathbf{p}_{i}(q) - \mathbf{p}_{i}^* \right\|^2 + w_{R,i} \left\| \text{log}(\mathbf{R}_{i}(q)^{-1} \mathbf{R}_{i}^*) \right\|^2 \right) - - -where :math:`q` is the robot joint configuration, :math:`\mathbf{p}_{i}(q)` is the position of the :math:`i`-th link, :math:`\mathbf{R}_{i}(q)` is the rotation matrix of the :math:`i`-th link, and :math:`w_{p,i}` and :math:`w_{R,i}` are the position and orientation weights, respectively. - -The following is the most common way to define costs in ``pyronot`` -- with autodiff: - -.. 
code-block:: python - - @Cost.create_factory - def pose_cost( - vals: VarValues, - robot: Robot, - joint_var: Var[Array], - target_pose: jaxlie.SE3, - target_link_index: Array, - pos_weight: Array | float, - ori_weight: Array | float, - ) -> Array: - """Computes the residual for matching link poses to target poses.""" - assert target_link_index.dtype == jnp.int32 - joint_cfg = vals[joint_var] - Ts_link_world = robot.forward_kinematics(joint_cfg) - pose_actual = jaxlie.SE3(Ts_link_world[..., target_link_index, :]) - - # Position residual = position error * weight - pos_residual = (pose_actual.translation() - target_pose.translation()) * pos_weight - # Orientation residual = log(actual_inv * target) * weight - ori_residual = (pose_actual.rotation().inverse() @ target_pose.rotation()).log() * ori_weight - - return jnp.concatenate([pos_residual, ori_residual]).flatten() - -The alternative is to manually write out the Jacobian -- while automatic differentiation is convenient and works well for most use cases, analytical Jacobians can provide better performance, which we show in the `paper `_. - -We provide two implementations of pose matching cost with custom Jacobians: - -- an `analytically derived Jacobian `_ (~200 lines), or -- a `numerically approximated Jacobian `_ through finite differences (~50 lines). diff --git a/docs/update_example_docs.py b/docs/update_example_docs.py deleted file mode 100644 index f6e3ace..0000000 --- a/docs/update_example_docs.py +++ /dev/null @@ -1,104 +0,0 @@ -"""Helper script for updating the auto-generated examples pages in the documentation.""" - -from __future__ import annotations - -import dataclasses -import pathlib -import shutil -from typing import Iterable - -import m2r2 -import tyro - - -@dataclasses.dataclass -class ExampleMetadata: - index: str - index_with_zero: str - source: str - title: str - description: str - - @staticmethod - def from_path(path: pathlib.Path) -> ExampleMetadata: - # 01_functions -> 01, _, functions. 
- index, _, _ = path.stem.partition("_") - - # 01 -> 1. - index_with_zero = index - index = str(int(index)) - - print("Parsing", path) - source = path.read_text().strip() - docstring = source.split('"""')[1].strip() - - title, _, description = docstring.partition("\n") - - description = description.strip() - description += "\n" - description += "\n" - description += "All examples can be run by first cloning the PyRoNot repository, which includes the `pyronot_snippets` implementation details." - - return ExampleMetadata( - index=index, - index_with_zero=index_with_zero, - source=source.partition('"""')[2].partition('"""')[2].strip(), - title=title, - description=description, - ) - - -def get_example_paths(examples_dir: pathlib.Path) -> Iterable[pathlib.Path]: - return filter( - lambda p: not p.name.startswith("_"), sorted(examples_dir.glob("*.py")) - ) - - -REPO_ROOT = pathlib.Path(__file__).absolute().parent.parent - - -def main( - examples_dir: pathlib.Path = REPO_ROOT / "examples", - sphinx_source_dir: pathlib.Path = REPO_ROOT / "docs" / "source", -) -> None: - example_doc_dir = sphinx_source_dir / "examples" - shutil.rmtree(example_doc_dir) - example_doc_dir.mkdir() - - for path in get_example_paths(examples_dir): - ex = ExampleMetadata.from_path(path) - - relative_dir = path.parent.relative_to(examples_dir) - target_dir = example_doc_dir / relative_dir - target_dir.mkdir(exist_ok=True, parents=True) - - (target_dir / f"{path.stem}.rst").write_text( - "\n".join( - [ - ( - ".. Comment: this file is automatically generated by" - " `update_example_docs.py`." - ), - " It should not be modified manually.", - "", - f"{ex.title}", - "==========================================", - "", - m2r2.convert(ex.description), - "", - "", - ".. 
code-block:: python", - " :linenos:", - "", - "", - "\n".join( - f" {line}".rstrip() for line in ex.source.split("\n") - ), - "", - ] - ) - ) - - -if __name__ == "__main__": - tyro.cli(main, description=__doc__) diff --git a/examples/13_spherized_robot_ik.py b/examples/13_spherized_robot_ik.py index 1a47853..7667a9a 100644 --- a/examples/13_spherized_robot_ik.py +++ b/examples/13_spherized_robot_ik.py @@ -1,7 +1,4 @@ -"""Spherized Robot IK - - -""" +"""Spherized Robot IK""" import time @@ -11,10 +8,10 @@ from viser.extras import ViserUrdf import pyronot_snippets as pks -import yourdfpy +import yourdfpy -def main(): +def main(): # Load the spherized panda urdf do not work!!! urdf_path = "resources/ur5/ur5_spherized.urdf" mesh_dir = "resources/ur5/meshes" @@ -66,5 +63,7 @@ def main(): # Update the collision mesh. robot_coll_mesh = robot_coll.at_config(robot, solution).to_trimesh() server.scene.add_mesh_trimesh("/robot/collision", mesh=robot_coll_mesh) + + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/14_spherized_ik_with_coll.py b/examples/14_spherized_ik_with_coll.py index 6a39d24..38ed6de 100644 --- a/examples/14_spherized_ik_with_coll.py +++ b/examples/14_spherized_ik_with_coll.py @@ -13,7 +13,7 @@ from viser.extras import ViserUrdf import pyronot_snippets as pks -import yourdfpy +import yourdfpy def main(): diff --git a/examples/15_spherized_ik_then_coll.py b/examples/15_spherized_ik_then_coll.py index e3e7c4c..18400d1 100644 --- a/examples/15_spherized_ik_then_coll.py +++ b/examples/15_spherized_ik_then_coll.py @@ -13,7 +13,7 @@ from viser.extras import ViserUrdf import pyronot_snippets as pks -import yourdfpy +import yourdfpy def main(): @@ -55,15 +55,14 @@ def main(): just_ik_timing_handle = server.gui.add_number("just ik (ms)", 0.001, disabled=True) coll_ik_timing_handle = server.gui.add_number("coll ik (ms)", 0.001, disabled=True) while True: - sphere_coll_world_current = 
sphere_coll.transform_from_wxyz_position( wxyz=np.array(sphere_handle.wxyz), position=np.array(sphere_handle.position), ) start_time = time.time() just_ik = pks.solve_ik( - robot=robot, - target_link_name=target_link_name, + robot=robot, + target_link_name=target_link_name, target_position=np.array(ik_target_handle.position), target_wxyz=np.array(ik_target_handle.wxyz), ) diff --git a/examples/16_spherized_online_planning.py b/examples/16_spherized_online_planning.py index 6123eaf..ff0d65d 100644 --- a/examples/16_spherized_online_planning.py +++ b/examples/16_spherized_online_planning.py @@ -13,7 +13,7 @@ from viser.extras import ViserUrdf import pyronot_snippets as pks -import yourdfpy +import yourdfpy def main(): diff --git a/examples/17_geometries_example.py b/examples/17_geometries_example.py index a95e5ac..dc7ca97 100644 --- a/examples/17_geometries_example.py +++ b/examples/17_geometries_example.py @@ -45,7 +45,8 @@ def main(): height_handle = server.gui.add_number("Height", height) server.scene.add_mesh_trimesh( - "/box/mesh", mesh=Box.from_center_and_dimensions(center, length, width, height).to_trimesh() + "/box/mesh", + mesh=Box.from_center_and_dimensions(center, length, width, height).to_trimesh(), ) server.scene.add_mesh_trimesh("/box/polytope", mesh=trimesh.Trimesh()) @@ -53,24 +54,52 @@ def main(): while True: pos = np.array(box_handle.position) wxyz = np.array(box_handle.wxyz) - length = float(length_handle.value) if hasattr(length_handle, "value") else float(length_handle) - width = float(width_handle.value) if hasattr(width_handle, "value") else float(width_handle) - height = float(height_handle.value) if hasattr(height_handle, "value") else float(height_handle) - - box = Box.from_center_and_dimensions(center=pos, length=length, width=width, height=height, wxyz=wxyz) + length = ( + float(length_handle.value) + if hasattr(length_handle, "value") + else float(length_handle) + ) + width = ( + float(width_handle.value) + if hasattr(width_handle, 
"value") + else float(width_handle) + ) + height = ( + float(height_handle.value) + if hasattr(height_handle, "value") + else float(height_handle) + ) + + box = Box.from_center_and_dimensions( + center=pos, length=length, width=width, height=height, wxyz=wxyz + ) # Sphere sph_pos = np.array(sphere_handle.position) sph_wxyz = np.array(sphere_handle.wxyz) - sph_rad = float(sphere_radius_handle.value) if hasattr(sphere_radius_handle, "value") else float(sphere_radius_handle) + sph_rad = ( + float(sphere_radius_handle.value) + if hasattr(sphere_radius_handle, "value") + else float(sphere_radius_handle) + ) sphere = Sphere.from_center_and_radius(center=sph_pos, radius=sph_rad) # Capsule cap_pos = np.array(cap_handle.position) cap_wxyz = np.array(cap_handle.wxyz) - cap_rad = float(cap_radius_handle.value) if hasattr(cap_radius_handle, "value") else float(cap_radius_handle) - cap_h = float(cap_height_handle.value) if hasattr(cap_height_handle, "value") else float(cap_height_handle) - capsule = Capsule.from_radius_height(radius=cap_rad, height=cap_h, position=cap_pos, wxyz=cap_wxyz) + cap_rad = ( + float(cap_radius_handle.value) + if hasattr(cap_radius_handle, "value") + else float(cap_radius_handle) + ) + cap_h = ( + float(cap_height_handle.value) + if hasattr(cap_height_handle, "value") + else float(cap_height_handle) + ) + capsule = Capsule.from_radius_height( + radius=cap_rad, height=cap_h, position=cap_pos, wxyz=cap_wxyz + ) server.scene.add_mesh_trimesh("/box/mesh", mesh=box.to_trimesh()) server.scene.add_mesh_trimesh("/sphere/mesh", mesh=sphere.to_trimesh()) @@ -91,7 +120,9 @@ def main(): # d is a jax Array; convert to python float if scalar d_val = float(d) if d_val < 0.0: - print(f"Collision detected {name1} vs {name2}: distance={d_val:.6f}") + print( + f"Collision detected {name1} vs {name2}: distance={d_val:.6f}" + ) except Exception as e: print(f"Error computing collision {name1} vs {name2}: {e}") diff --git a/examples/18_spherized_trajopt.py 
b/examples/18_spherized_trajopt.py index 40917a9..970206f 100644 --- a/examples/18_spherized_trajopt.py +++ b/examples/18_spherized_trajopt.py @@ -17,7 +17,8 @@ from robot_descriptions.loaders.yourdfpy import load_robot_description import pyronot_snippets as pks -import yourdfpy +import yourdfpy + def main(robot_name: Literal["ur5", "panda"] = "panda"): if robot_name == "ur5": diff --git a/examples/19_spherized_ik_with_coll_exclude_link.py b/examples/19_spherized_ik_with_coll_exclude_link.py index da822b2..10ea766 100644 --- a/examples/19_spherized_ik_with_coll_exclude_link.py +++ b/examples/19_spherized_ik_with_coll_exclude_link.py @@ -13,8 +13,9 @@ from viser.extras import ViserUrdf import pyronot_snippets as pks -import yourdfpy +import yourdfpy +import jax.numpy as jnp def main(): """Main function for basic IK with collision.""" @@ -26,7 +27,9 @@ def main(): # mesh_dir = "resources/panda/meshes" # target_link_name = "panda_hand" urdf = yourdfpy.URDF.load(urdf_path, mesh_dir=mesh_dir) - robot = pk.Robot.from_urdf(urdf, default_joint_cfg=[0, -1.57, 1.57, -1.57, -1.57, 0]) + robot = pk.Robot.from_urdf( + urdf, default_joint_cfg=[0, -1.57, 1.57, -1.57, -1.57, 0] + ) robot_coll = RobotCollisionSpherized.from_urdf(urdf) plane_coll = HalfSpace.from_point_and_normal( @@ -80,27 +83,34 @@ def main(): # Update visualizer. 
urdf_vis.update_cfg(solution) # print(robot.links.names) - # Compute the collision of the solution + # Compute the collision of the solution + # Method 1: distance_link_to_plane = robot_coll.compute_world_collision_distance( - robot, - solution, - plane_coll + robot, solution, plane_coll + ) + distance_link_to_plane = RobotCollisionSpherized.mask_collision_distance( + distance_link_to_plane, exclude_link_mask + ) + # Method 2: + distance_link_to_plane_2 = robot_coll.compute_world_collision_distance_with_exclude_links( + robot, solution, plane_coll, exclude_link_mask ) - distance_link_to_plane = RobotCollisionSpherized.mask_collision_distance(distance_link_to_plane, exclude_link_mask) - # print(distance_link_to_plane) + assert jnp.allclose(distance_link_to_plane, distance_link_to_plane_2) distance_link_to_sphere = robot_coll.compute_world_collision_distance( - robot, - solution, - sphere_coll + robot, solution, sphere_coll ) - distance_link_to_sphere = RobotCollisionSpherized.mask_collision_distance(distance_link_to_sphere, exclude_link_mask) - # print(distance_link_to_sphere) - # Visualize collision representation + distance_link_to_sphere = RobotCollisionSpherized.mask_collision_distance( + distance_link_to_sphere, exclude_link_mask + ) + distance_link_to_sphere_2 = robot_coll.compute_world_collision_distance_with_exclude_links( + robot, solution, sphere_coll, exclude_link_mask + ) + assert jnp.allclose(distance_link_to_sphere, distance_link_to_sphere_2) + robot_coll_config: Sphere = robot_coll.at_config(robot, solution) - # print(robot_coll_config.get_batch_axes()[-1]) robot_coll_mesh = robot_coll_config.to_trimesh() server.scene.add_mesh_trimesh( - "/robot_coll", + "/robot_coll", mesh=robot_coll_mesh, wxyz=(1.0, 0.0, 0.0, 0.0), position=(0.0, 0.0, 0.0), diff --git a/examples/21_quantize_collision.py b/examples/21_quantize_collision.py index c12de17..c06b763 100644 --- a/examples/21_quantize_collision.py +++ b/examples/21_quantize_collision.py @@ -13,6 +13,7 @@ 
import yourdfpy from pyronot.utils import quantize + def generate_spheres(n_spheres): print(f"Generating {n_spheres} random spheres...") spheres = [] @@ -21,46 +22,51 @@ def generate_spheres(n_spheres): radius = np.random.uniform(low=0.05, high=0.2) sphere = Sphere.from_center_and_radius(center, np.array([radius])) spheres.append(sphere) - + # Tree map them to create a batch of spheres spheres_batch = jax.tree.map(lambda *args: jnp.stack(args), *spheres) print(f"Generated {n_spheres} spheres.") return spheres_batch + def generate_configs(joints, n_configs): print(f"Generating {n_configs} random robot configurations...") q_batch = np.random.uniform( - low=joints.lower_limits, - high=joints.upper_limits, - size=(n_configs, joints.num_actuated_joints) + low=joints.lower_limits, + high=joints.upper_limits, + size=(n_configs, joints.num_actuated_joints), ) print(f"Generated {n_configs} robot configurations.") print(f"Configurations shape: {q_batch.shape}") return q_batch + def make_collision_checker(robot, robot_coll): @jax.jit def check_collisions(q_batch, obstacles): # q_batch: (N_configs, dof) # obstacles: Sphere batch (N_spheres) - + # Define single config check def check_single(q, obs): return robot_coll.compute_world_collision_distance(robot, q, obs) - + # Vmap over configs # in_axes: q=(0), obs=(None) -> we want to check each q against ALL obs return jax.vmap(check_single, in_axes=(0, None))(q_batch, obstacles) - + return check_collisions + def run_benchmark(name, check_fn, q_batch, obstacles): print(f"\n{name}:") - + # Metrics q_size_mb = q_batch.nbytes / 1024 / 1024 - spheres_size_mb = sum(x.nbytes for x in jax.tree_util.tree_leaves(obstacles)) / 1024 / 1024 - + spheres_size_mb = ( + sum(x.nbytes for x in jax.tree_util.tree_leaves(obstacles)) / 1024 / 1024 + ) + print(f"q_batch size: {q_size_mb:.2f} MB") print(f"Obstacles (spheres) size: {spheres_size_mb:.2f} MB") @@ -79,10 +85,11 @@ def run_benchmark(name, check_fn, q_batch, obstacles): print(f"Time to 
compute: {end_time - start_time:.6f} seconds") print(f"Collision distances shape: {dists.shape}") print(f"Min distance: {jnp.min(dists)}") - + in_collision = dists < 0 print(f"Number of collision pairs: {jnp.sum(in_collision)}") + def main(): global robot, robot_coll # Load robot @@ -91,11 +98,11 @@ def main(): urdf = yourdfpy.URDF.load(urdf_path, mesh_dir=mesh_dir) robot = pk.Robot.from_urdf(urdf) joints, links = RobotURDFParser.parse(urdf) - + # Initialize collision model print("Initializing collision model...") robot_coll = RobotCollisionSpherized.from_urdf(urdf) - + # Create collision checker check_collisions = make_collision_checker(robot, robot_coll) @@ -105,15 +112,20 @@ def main(): # Run benchmarks run_benchmark("Default (float32)", check_collisions, q_batch, spheres_batch) - + # Quantized q_batch_f16 = quantize(q_batch) spheres_batch_f16 = quantize(spheres_batch) - run_benchmark("Quantized (float16)", check_collisions, q_batch_f16, spheres_batch_f16) + run_benchmark( + "Quantized (float16)", check_collisions, q_batch_f16, spheres_batch_f16 + ) q_batch_int8 = quantize(q_batch, jax.numpy.int8) spheres_batch_int8 = quantize(spheres_batch, jax.numpy.int8) - run_benchmark("Quantized (int8)", check_collisions, q_batch_int8, spheres_batch_int8) + run_benchmark( + "Quantized (int8)", check_collisions, q_batch_int8, spheres_batch_int8 + ) + if __name__ == "__main__": main() diff --git a/examples/22_neural_sdf.py b/examples/22_neural_sdf.py index f62ff54..6db1ecc 100644 --- a/examples/22_neural_sdf.py +++ b/examples/22_neural_sdf.py @@ -8,7 +8,12 @@ import numpy as np import time import pyronot as pk -from pyronot.collision import RobotCollision, RobotCollisionSpherized, NeuralRobotCollision, Sphere +from pyronot.collision import ( + RobotCollision, + RobotCollisionSpherized, + NeuralRobotCollision, + Sphere, +) from pyronot._robot_urdf_parser import RobotURDFParser import yourdfpy from pyronot.utils import quantize @@ -16,6 +21,7 @@ # Global configuration: Set to 
True to run positional encoding comparison RUN_POSITIONAL_ENCODING = False + def generate_spheres(n_spheres): print(f"Generating {n_spheres} random spheres...") spheres = [] @@ -24,48 +30,57 @@ def generate_spheres(n_spheres): radius = np.random.uniform(low=0.05, high=0.2) sphere = Sphere.from_center_and_radius(center, np.array(radius)) spheres.append(sphere) - + # Tree map them to create a batch of spheres spheres_batch = jax.tree.map(lambda *args: jnp.stack(args), *spheres) print(f"Generated {n_spheres} spheres.") return spheres_batch + def generate_configs(joints, n_configs): print(f"Generating {n_configs} random robot configurations...") q_batch = np.random.uniform( - low=joints.lower_limits, - high=joints.upper_limits, - size=(n_configs, joints.num_actuated_joints) + low=joints.lower_limits, + high=joints.upper_limits, + size=(n_configs, joints.num_actuated_joints), ) print(f"Generated {n_configs} robot configurations.") print(f"Configurations shape: {q_batch.shape}") return q_batch + def make_collision_checker(robot, robot_coll): @jax.jit def check_collisions(q_batch, obstacles): # q_batch: (N_configs, dof) # obstacles: Sphere batch (N_spheres) - + # Define single config check def check_single(q, obs): return robot_coll.compute_world_collision_distance(robot, q, obs) - + # Vmap over configs # in_axes: q=(0), obs=(None) -> we want to check each q against ALL obs return jax.vmap(check_single, in_axes=(0, None))(q_batch, obstacles) - + return check_collisions -def make_neural_collision_checker(robot, robot_coll, spheres_batch, use_positional_encoding=True, pe_min_deg=0, pe_max_deg=6): +def make_neural_collision_checker( + robot, + robot_coll, + spheres_batch, + use_positional_encoding=True, + pe_min_deg=0, + pe_max_deg=6, +): """Train a neural collision model on the given static world and return its checker. 
This will: - build a NeuralRobotCollision from the exact model - train it on random configs for the provided world (spheres_batch) - expose a vmap'ed collision function with the same signature as make_collision_checker's output - + Args: robot: The Robot instance robot_coll: The exact collision model (RobotCollisionSpherized) @@ -88,8 +103,12 @@ def make_neural_collision_checker(robot, robot_coll, spheres_batch, use_position pe_min_deg=pe_min_deg, pe_max_deg=pe_max_deg, ) - - pe_status = f"with PE (deg {pe_min_deg}-{pe_max_deg})" if use_positional_encoding else "without PE" + + pe_status = ( + f"with PE (deg {pe_min_deg}-{pe_max_deg})" + if use_positional_encoding + else "without PE" + ) print(f"Created neural collision model {pe_status}") # Train neural model on this specific world @@ -98,7 +117,7 @@ def make_neural_collision_checker(robot, robot_coll, spheres_batch, use_position world_geom=spheres_batch, num_samples=10000, batch_size=1000, # Smaller batch = more gradient updates per epoch - epochs=50, # More epochs for better convergence + epochs=50, # More epochs for better convergence learning_rate=1e-3, ) @@ -115,13 +134,16 @@ def check_single(q, obs): return check_collisions + def run_benchmark(name, check_fn, q_batch, obstacles): print(f"\n{name}:") - + # Metrics q_size_mb = q_batch.nbytes / 1024 / 1024 - spheres_size_mb = sum(x.nbytes for x in jax.tree_util.tree_leaves(obstacles)) / 1024 / 1024 - + spheres_size_mb = ( + sum(x.nbytes for x in jax.tree_util.tree_leaves(obstacles)) / 1024 / 1024 + ) + print(f"q_batch size: {q_size_mb:.2f} MB") print(f"Obstacles (spheres) size: {spheres_size_mb:.2f} MB") @@ -142,81 +164,81 @@ def run_benchmark(name, check_fn, q_batch, obstacles): print(f"Max distance: {jnp.max(dists):.6f}") print(f"Mean distance: {jnp.mean(dists):.6f}") print(f"Std distance: {jnp.std(dists):.6f}") - + in_collision = dists < 0 print(f"Number of collision pairs: {jnp.sum(in_collision)}") - + return dists, elapsed_time def compare_results(name, 
neural_dists, exact_dists): """Compare neural network predictions against exact results.""" import gc - + print(f"\n {name} Comparison ") - # Compute metrics in a memory-efficient way diff = neural_dists - exact_dists mae = float(jnp.mean(jnp.abs(diff))) max_ae = float(jnp.max(jnp.abs(diff))) - rmse = float(jnp.sqrt(jnp.mean(diff ** 2))) + rmse = float(jnp.sqrt(jnp.mean(diff**2))) bias = float(jnp.mean(diff)) del diff # Free memory gc.collect() - + print(f"Mean absolute error: {mae:.6f}") print(f"Max absolute error: {max_ae:.6f}") print(f"RMSE: {rmse:.6f}") print(f"Mean error (bias): {bias:.6f}") - + # Check accuracy at collision boundary exact_in_collision = exact_dists < 0.05 neural_in_collision = neural_dists < 0.05 - + # Compute metrics and convert to Python ints immediately true_positives = int(jnp.sum(exact_in_collision & neural_in_collision)) false_positives = int(jnp.sum(~exact_in_collision & neural_in_collision)) false_negatives = int(jnp.sum(exact_in_collision & ~neural_in_collision)) true_negatives = int(jnp.sum(~exact_in_collision & ~neural_in_collision)) - + # Free the boolean arrays del exact_in_collision, neural_in_collision gc.collect() - + print(f"\nCollision Detection Accuracy (threshold=0.05):") print(f" True Positives: {true_positives}") print(f" False Positives: {false_positives}") print(f" False Negatives: {false_negatives}") print(f" True Negatives: {true_negatives}") - + precision = true_positives / (true_positives + false_positives + 1e-8) recall = true_positives / (true_positives + false_negatives + 1e-8) f1 = 2 * precision * recall / (precision + recall + 1e-8) print(f" Precision: {precision:.4f}") print(f" Recall: {recall:.4f}") print(f" F1 Score: {f1:.4f}") - + return { - 'mae': mae, - 'max_ae': max_ae, - 'rmse': rmse, - 'bias': bias, - 'precision': precision, - 'recall': recall, - 'f1': f1, + "mae": mae, + "max_ae": max_ae, + "rmse": rmse, + "bias": bias, + "precision": precision, + "recall": recall, + "f1": f1, } + def main(): import 
gc - + # Load robot urdf_path = "resources/ur5/ur5_spherized.urdf" mesh_dir = "resources/ur5/meshes" urdf = yourdfpy.URDF.load(urdf_path, mesh_dir=mesh_dir) robot = pk.Robot.from_urdf(urdf) joints, links = RobotURDFParser.parse(urdf) - + # Initialize collision model print("Initializing collision model...") robot_coll = RobotCollisionSpherized.from_urdf(urdf) @@ -232,96 +254,116 @@ def main(): print("Training neural collision model WITHOUT positional encoding...") print("(Raw link poses as input)") neural_check_without_pe = make_neural_collision_checker( - robot, robot_coll, spheres_batch, + robot, + robot_coll, + spheres_batch, use_positional_encoding=False, ) # Optionally train with positional encoding neural_check_with_pe = None if RUN_POSITIONAL_ENCODING: - print("\n" + "="*70) + print("\n" + "=" * 70) print("Training neural collision model WITH positional encoding...") - print("(iSDF-inspired: projects onto icosahedron directions with frequency bands)") - print("="*70) + print( + "(iSDF-inspired: projects onto icosahedron directions with frequency bands)" + ) + print("=" * 70) neural_check_with_pe = make_neural_collision_checker( - robot, robot_coll, spheres_batch, + robot, + robot_coll, + spheres_batch, use_positional_encoding=True, pe_min_deg=0, pe_max_deg=6, # 7 frequency bands: 2^0, 2^1, ..., 2^6 ) # Run benchmarks - print("\n" + "="*70) + print("\n" + "=" * 70) print("Running benchmarks...") - print("="*70) - + print("=" * 70) + exact_dists, exact_time = run_benchmark( - "Exact (RobotCollisionSpherized)", - exact_check_collisions, q_batch, spheres_batch + "Exact (RobotCollisionSpherized)", + exact_check_collisions, + q_batch, + spheres_batch, ) - + neural_without_pe_dists, neural_without_pe_time = run_benchmark( - "Neural WITHOUT Positional Encoding", - neural_check_without_pe, q_batch, spheres_batch + "Neural WITHOUT Positional Encoding", + neural_check_without_pe, + q_batch, + spheres_batch, ) - + neural_with_pe_dists = None neural_with_pe_time = None if 
RUN_POSITIONAL_ENCODING and neural_check_with_pe is not None: neural_with_pe_dists, neural_with_pe_time = run_benchmark( - "Neural WITH Positional Encoding", - neural_check_with_pe, q_batch, spheres_batch + "Neural WITH Positional Encoding", + neural_check_with_pe, + q_batch, + spheres_batch, ) - + # Clear JAX caches and force garbage collection to free GPU memory print("\nClearing memory before comparison...") jax.clear_caches() gc.collect() - + # Compare results metrics_without_pe = compare_results( "Neural WITHOUT Positional Encoding vs Exact", - neural_without_pe_dists, exact_dists + neural_without_pe_dists, + exact_dists, ) - + metrics_with_pe = None if RUN_POSITIONAL_ENCODING and neural_with_pe_dists is not None: metrics_with_pe = compare_results( "Neural WITH Positional Encoding vs Exact", - neural_with_pe_dists, exact_dists + neural_with_pe_dists, + exact_dists, ) - + # Summary comparison (only if positional encoding was tested) if RUN_POSITIONAL_ENCODING and metrics_with_pe is not None: print("SUMMARY: Positional Encoding Impact") - print(f"\n{'Metric':<25} {'With PE':<15} {'Without PE':<15} {'Improvement':<15}") - - for metric in ['mae', 'rmse', 'max_ae', 'precision', 'recall', 'f1']: + print( + f"\n{'Metric':<25} {'With PE':<15} {'Without PE':<15} {'Improvement':<15}" + ) + + for metric in ["mae", "rmse", "max_ae", "precision", "recall", "f1"]: with_pe = metrics_with_pe[metric] without_pe = metrics_without_pe[metric] - + # For error metrics, lower is better; for accuracy metrics, higher is better - if metric in ['mae', 'rmse', 'max_ae', 'bias']: + if metric in ["mae", "rmse", "max_ae", "bias"]: improvement = (without_pe - with_pe) / (without_pe + 1e-8) * 100 better = "↓" if with_pe < without_pe else "↑" else: improvement = (with_pe - without_pe) / (without_pe + 1e-8) * 100 better = "↑" if with_pe > without_pe else "↓" - - print(f"{metric:<25} {with_pe:<15.6f} {without_pe:<15.6f} {improvement:+.1f}% {better}") - - print(f"\n{'Inference Time (s)':<25} 
{neural_with_pe_time:<15.6f} {neural_without_pe_time:<15.6f}") + + print( + f"{metric:<25} {with_pe:<15.6f} {without_pe:<15.6f} {improvement:+.1f}% {better}" + ) + + print( + f"\n{'Inference Time (s)':<25} {neural_with_pe_time:<15.6f} {neural_without_pe_time:<15.6f}" + ) print(f"{'Exact Time (s)':<25} {exact_time:<15.6f}") else: # Simple summary without PE comparison print("SUMMARY: Neural vs Exact") print(f"\n{'Metric':<25} {'Neural (no PE)':<15}") - for metric in ['mae', 'rmse', 'max_ae', 'precision', 'recall', 'f1']: + for metric in ["mae", "rmse", "max_ae", "precision", "recall", "f1"]: print(f"{metric:<25} {metrics_without_pe[metric]:<15.6f}") print(f"\n{'Neural Inference Time (s)':<25} {neural_without_pe_time:<15.6f}") print(f"{'Exact Time (s)':<25} {exact_time:<15.6f}") print(f"{'Speedup':<25} {exact_time / neural_without_pe_time:<15.2f}x") - + # Cleanup del exact_dists, neural_without_pe_dists if neural_with_pe_dists is not None: diff --git a/examples/pyronot_snippets/__init__.py b/examples/pyronot_snippets/__init__.py index f9a774b..f535768 100644 --- a/examples/pyronot_snippets/__init__.py +++ b/examples/pyronot_snippets/__init__.py @@ -9,4 +9,6 @@ from ._solve_ik_with_multiple_targets import ( solve_ik_with_multiple_targets as solve_ik_with_multiple_targets, ) -from ._solve_collision_with_config import solve_collision_with_config as solve_collision_with_config \ No newline at end of file +from ._solve_collision_with_config import ( + solve_collision_with_config as solve_collision_with_config, +) diff --git a/examples/pyronot_snippets/_solve_collision_with_config.py b/examples/pyronot_snippets/_solve_collision_with_config.py index 7a19797..78030ab 100644 --- a/examples/pyronot_snippets/_solve_collision_with_config.py +++ b/examples/pyronot_snippets/_solve_collision_with_config.py @@ -12,6 +12,7 @@ import numpy as onp import pyronot as pk + def solve_collision_with_config( robot: pk.Robot, coll: pk.collision.RobotCollision, diff --git 
a/src/pyronot/_robot_srdf_parser.py b/src/pyronot/_robot_srdf_parser.py index 52ef5b2..6e3e237 100644 --- a/src/pyronot/_robot_srdf_parser.py +++ b/src/pyronot/_robot_srdf_parser.py @@ -1,51 +1,50 @@ from xml.etree import ElementTree as ET + def read_disabled_collisions_from_srdf(srdf_path): """Read disabled collision pairs from SRDF file.""" - + tree = ET.parse(srdf_path) root = tree.getroot() - + disabled_pairs = [] - - for disable_elem in root.findall('disable_collisions'): - link1 = disable_elem.get('link1') - link2 = disable_elem.get('link2') - reason = disable_elem.get('reason', 'Unknown') - - disabled_pairs.append({ - 'link1': link1, - 'link2': link2, - 'reason': reason - }) - + + for disable_elem in root.findall("disable_collisions"): + link1 = disable_elem.get("link1") + link2 = disable_elem.get("link2") + reason = disable_elem.get("reason", "Unknown") + + disabled_pairs.append({"link1": link1, "link2": link2, "reason": reason}) + return disabled_pairs + def read_group_states_from_srdf(srdf_path): """Read named configurations from SRDF file.""" - + tree = ET.parse(srdf_path) root = tree.getroot() - + group_states = {} - - for state_elem in root.findall('group_state'): - group_name = state_elem.get('group') - state_name = state_elem.get('name') - + + for state_elem in root.findall("group_state"): + group_name = state_elem.get("group") + state_name = state_elem.get("name") + joints = {} - for joint_elem in state_elem.findall('joint'): - joint_name = joint_elem.get('name') - joint_value = float(joint_elem.get('value')) + for joint_elem in state_elem.findall("joint"): + joint_name = joint_elem.get("name") + joint_value = float(joint_elem.get("value")) joints[joint_name] = joint_value - + if group_name not in group_states: group_states[group_name] = {} - + group_states[group_name][state_name] = joints - + return group_states + # Usage if __name__ == "__main__": srdf_path = "resources/ur5/ur5.srdf" @@ -62,4 +61,4 @@ def read_group_states_from_srdf(srdf_path): 
for group_name, states in group_states.items(): print(f"\nGroup: {group_name}") for state_name, joints in states.items(): - print(f" State '{state_name}': {joints}") \ No newline at end of file + print(f" State '{state_name}': {joints}") diff --git a/src/pyronot/_robot_urdf_parser.py b/src/pyronot/_robot_urdf_parser.py index 49247b3..3ab451f 100644 --- a/src/pyronot/_robot_urdf_parser.py +++ b/src/pyronot/_robot_urdf_parser.py @@ -131,15 +131,23 @@ def get_link_indices(self, link_names: tuple[str, ...]) -> Int[Array, " n_matche """Get the indices of links by names.""" matches = jnp.array([name in link_names for name in self.names]) return jnp.where(matches)[0] - def get_link_mask_from_indices(self, link_indices: Int[Array, " n_matches"]) -> Int[Array, " num_links"]: + + def get_link_mask_from_indices( + self, link_indices: Int[Array, " n_matches"] + ) -> Int[Array, " num_links"]: """Get a mask of links by indices.""" mask = jnp.zeros(self.num_links, dtype=bool) mask = mask.at[link_indices].set(True) return mask - def get_link_mask_from_names(self, link_names: tuple[str, ...]) -> Int[Array, " num_links"]: + + def get_link_mask_from_names( + self, link_names: tuple[str, ...] 
+ ) -> Int[Array, " num_links"]: """Get a mask of links by names.""" link_indices = self.get_link_indices(link_names) return self.get_link_mask_from_indices(link_indices) + + class RobotURDFParser: """Parser for creating Robot instances from URDF files.""" diff --git a/src/pyronot/collision/__init__.py b/src/pyronot/collision/__init__.py index 5a465b3..9c83bba 100644 --- a/src/pyronot/collision/__init__.py +++ b/src/pyronot/collision/__init__.py @@ -10,4 +10,6 @@ from ._robot_collision import RobotCollision as RobotCollision from ._robot_collision import RobotCollisionSpherized as RobotCollisionSpherized from ._neural_collision import NeuralRobotCollision as NeuralRobotCollision -from ._neural_collision import NeuralRobotCollisionSpherized as NeuralRobotCollisionSpherized \ No newline at end of file +from ._neural_collision import ( + NeuralRobotCollisionSpherized as NeuralRobotCollisionSpherized, +) diff --git a/src/pyronot/collision/_geometry.py b/src/pyronot/collision/_geometry.py index c1550ca..46595b4 100644 --- a/src/pyronot/collision/_geometry.py +++ b/src/pyronot/collision/_geometry.py @@ -57,22 +57,23 @@ def reshape(self, shape: tuple[int, ...]) -> Self: def __getitem__(self, index) -> Self: """Get a subset of the geometry by indexing into the batch dimensions. 
- + Args: index: Index or slice to apply to the batch dimensions - + Returns: New CollGeom object with indexed batch dimensions - + Example: >>> sphere[..., 0] # Get first element of last batch dimension >>> sphere[0:2] # Get first two elements of first batch dimension >>> sphere[..., exclude_indices] # Get elements at specific indices """ return jax.tree.map( - lambda x: x[index] if hasattr(x, '__getitem__') else x, + lambda x: x[index] if hasattr(x, "__getitem__") else x, self, ) + def transform(self, transform: jaxlie.SE3) -> Self: """Left-multiples geometry's pose with an SE(3) transformation.""" with jdc.copy_and_mutate(self) as out: @@ -211,8 +212,8 @@ def from_trimesh(mesh: trimesh.Trimesh) -> Sphere: Returns: Sphere: A Sphere geometry fit to the mesh. - - Author: + + Author: S """ if mesh.is_empty: @@ -223,7 +224,10 @@ def from_trimesh(mesh: trimesh.Trimesh) -> Sphere: # Compute the bounding sphere (center and radius) try: - center_np, radius_val = mesh.bounding_sphere.center, mesh.bounding_sphere.primitive.radius + center_np, radius_val = ( + mesh.bounding_sphere.center, + mesh.bounding_sphere.primitive.radius, + ) center = jnp.array(center_np, dtype=jnp.float32) radius = jnp.array(radius_val, dtype=jnp.float32) except Exception: @@ -235,9 +239,11 @@ def from_trimesh(mesh: trimesh.Trimesh) -> Sphere: return Sphere.from_center_and_radius(center=center, radius=radius) + @jdc.pytree_dataclass class Box(CollGeom): """Box (Rectangular Prism) geometry.""" + @property def half_lengths(self) -> Float[Array, "*batch 3"]: """Half-lengths along local X, Y, Z (size stores three values).""" @@ -261,7 +267,13 @@ def from_center_and_half_lengths( """ half_lengths = jnp.array(half_lengths) lengths = 2.0 * half_lengths - return Box.from_center_and_dimensions(center=center, length=lengths[..., 0], width=lengths[..., 1], height=lengths[..., 2], wxyz=wxyz) + return Box.from_center_and_dimensions( + center=center, + length=lengths[..., 0], + width=lengths[..., 1], + 
height=lengths[..., 2], + wxyz=wxyz, + ) @staticmethod def from_center_and_dimensions( @@ -365,6 +377,7 @@ def as_halfspaces(self) -> tuple[HalfSpace, ...]: return (hs_px, hs_nx, hs_py, hs_ny, hs_pz, hs_nz) + @jdc.pytree_dataclass class Capsule(CollGeom): """Capsule geometry.""" diff --git a/src/pyronot/collision/_geometry_pairs.py b/src/pyronot/collision/_geometry_pairs.py index 3228c83..f3c352f 100644 --- a/src/pyronot/collision/_geometry_pairs.py +++ b/src/pyronot/collision/_geometry_pairs.py @@ -62,10 +62,17 @@ def _sphere_sphere_dist( pos2: Float[Array, "*batch 3"], radius2: Float[Array, "*batch"], ) -> Float[Array, "*batch"]: - """Helper: Calculates distance between two spheres.""" - _, dist_center = _utils.normalize_with_norm(pos2 - pos1) - dist = dist_center - (radius1 + radius2) - return dist + """Helper: Signed squared-distance indicator between two sphere surfaces. + + Note: Returns ||p2 - p1||² - (radius1 + radius2)², not the true distance. + For collision check: result < 0 means collision (same sign as true distance). + """ + diff = pos2 - pos1 + dist_center_sq = jnp.sum(diff * diff, axis=-1) # ||p2 - p1||², no sqrt + radii_sum = radius1 + radius2 + radii_sum_sq = radii_sum * radii_sum + # Return signed indicator: negative means collision + return dist_center_sq - radii_sum_sq def sphere_sphere(sphere1: Sphere, sphere2: Sphere) -> Float[Array, "*batch"]: @@ -217,6 +224,7 @@ def heightmap_halfspace( assert min_dist.shape == batch_axes return min_dist + def box_sphere(box: Box, sphere: Sphere) -> Float[Array, "*batch"]: """Compute signed distance between an oriented box and a sphere.
@@ -247,22 +255,23 @@ def box_capsule(box: Box, capsule: Capsule) -> Float[Array, "*batch"]: """ cap_pos = capsule.pose.translation() - cap_axis = capsule.axis + cap_axis = capsule.axis half_h = capsule.height[..., None] * 0.5 - a_w = cap_pos - cap_axis * half_h - b_w = cap_pos + cap_axis * half_h + a_w = cap_pos - cap_axis * half_h + b_w = cap_pos + cap_axis * half_h a = box.pose.inverse().apply(a_w) b = box.pose.inverse().apply(b_w) - hl = box.half_lengths + hl = box.half_lengths ab = b - a ab_len2 = jnp.sum(ab * ab, axis=-1, keepdims=True) t = jnp.clip( jnp.sum((0.0 - a) * ab, axis=-1, keepdims=True) / (ab_len2 + 1e-12), - 0.0, 1.0, + 0.0, + 1.0, ) p = a + t * ab q = jnp.abs(p) - hl @@ -276,7 +285,6 @@ def box_capsule(box: Box, capsule: Capsule) -> Float[Array, "*batch"]: return sdist_box - capsule.radius - def box_halfspace(box: Box, halfspace: HalfSpace) -> Float[Array, "*batch"]: """Compute signed distance between box and a halfspace plane. @@ -301,7 +309,9 @@ def box_halfspace(box: Box, halfspace: HalfSpace) -> Float[Array, "*batch"]: hs_n_bc = jnp.broadcast_to(hs_n, verts_world.shape[:-1] + (3,))[..., None, :] hs_pt_bc = jnp.broadcast_to(hs_pt, verts_world.shape[:-1] + (3,))[..., None, :] - vertex_distances = jnp.einsum("...vi,...i->...v", verts_world - hs_pt_bc, hs_n_bc.squeeze(-2)) + vertex_distances = jnp.einsum( + "...vi,...i->...v", verts_world - hs_pt_bc, hs_n_bc.squeeze(-2) + ) min_dist = jnp.min(vertex_distances, axis=-1) return min_dist diff --git a/src/pyronot/collision/_neural_collision.py b/src/pyronot/collision/_neural_collision.py index 74a005d..0c38bf9 100644 --- a/src/pyronot/collision/_neural_collision.py +++ b/src/pyronot/collision/_neural_collision.py @@ -28,46 +28,48 @@ class NeuralRobotCollision: """ A wrapper class that adds neural network-based collision distance approximation to either RobotCollision or RobotCollisionSpherized. 
- + The network is trained to overfit to a specific scene, mapping robot link poses directly to collision distances between robot links and the static obstacles. - + Input: Flattened link poses (N links × 7 pose params = N*7 dimensions), optionally with positional encoding for capturing fine geometric details. Output: Flattened distance matrix (N links × M obstacles = N*M dimensions) - + This class uses composition to wrap either collision model type, delegating non-neural methods to the underlying collision model. - + Positional Encoding (inspired by iSDF): When enabled, the input is augmented with sinusoidal positional encodings at multiple frequency scales. This allows the network to learn high-frequency spatial features that are critical for accurate collision distance prediction, especially near obstacle boundaries where distances change rapidly. """ - + # The underlying collision model (either RobotCollision or RobotCollisionSpherized) _collision_model: Union[RobotCollision, RobotCollisionSpherized] - + # Neural network parameters (weights and biases for each layer) - nn_params: List[Tuple[Float[Array, "fan_in fan_out"], Float[Array, "fan_out"]]] = jdc.field(default_factory=list) - + nn_params: List[Tuple[Float[Array, "fan_in fan_out"], Float[Array, "fan_out"]]] = ( + jdc.field(default_factory=list) + ) + # Metadata about the training - these must be static for use in JIT conditionals is_trained: jdc.Static[bool] = False - + # We keep track of the number of obstacles this network was trained for (M) trained_num_obstacles: jdc.Static[int] = 0 - + # Input normalization parameters (computed during training) input_mean: jax.Array = jdc.field(default_factory=lambda: jnp.zeros(1)) input_std: jax.Array = jdc.field(default_factory=lambda: jnp.ones(1)) - + # Positional encoding parameters (static for JIT) use_positional_encoding: jdc.Static[bool] = False pe_min_deg: jdc.Static[int] = 0 pe_max_deg: jdc.Static[int] = 6 pe_scale: jdc.Static[float] = 1.0 - + # Computed PE 
scale (stored after training for use at inference) pe_scale_computed: jax.Array = jdc.field(default_factory=lambda: jnp.array(1.0)) @@ -75,23 +77,23 @@ class NeuralRobotCollision: @property def num_links(self) -> int: return self._collision_model.num_links - + @property def link_names(self) -> tuple[str, ...]: return self._collision_model.link_names - + @property def coll(self) -> CollGeom: return self._collision_model.coll - + @property def active_idx_i(self): return self._collision_model.active_idx_i - + @property def active_idx_j(self): return self._collision_model.active_idx_j - + @property def is_spherized(self) -> bool: """Returns True if the underlying model is RobotCollisionSpherized.""" @@ -110,7 +112,7 @@ def from_existing( """ Creates a NeuralRobotCollision instance from an existing collision model. Initializes the neural network with random weights. - + Args: original: The original collision model (RobotCollision or RobotCollisionSpherized). layer_sizes: List of hidden layer sizes. The input size is determined by robot DOF, @@ -125,15 +127,15 @@ def from_existing( """ if layer_sizes is None: layer_sizes = [256, 256, 256] - + if key is None: key = jax.random.PRNGKey(0) # We can't fully initialize the network structure until we know the output dimension (N*M), - # which depends on the number of obstacles M. + # which depends on the number of obstacles M. # For now, we just copy the fields and return an untrained instance. # The actual weights will be initialized/shaped during the training setup or first call. - + return NeuralRobotCollision( _collision_model=original, nn_params=[], @@ -223,23 +225,25 @@ def compute_world_collision_distance( ) -> Float[Array, "*batch_combined N M"]: """ Computes collision distances, using the trained neural network if available. - + This assumes that world_geom represents the SAME static obstacles that the network was trained on. 
The network uses link poses (from forward kinematics) as input and predicts distances based on those poses. - + If positional encoding is enabled, the input is augmented with sinusoidal embeddings at multiple frequency scales to capture fine geometric details. - + Falls back to the underlying collision model's exact computation if not trained. """ if not self.is_trained: # Fallback to the original exact computation if not trained - return self._collision_model.compute_world_collision_distance(robot, cfg, world_geom) + return self._collision_model.compute_world_collision_distance( + robot, cfg, world_geom + ) # Determine batch shapes batch_cfg_shape = cfg.shape[:-1] - + # Check world geom shape to ensure consistency with training (M) world_axes = world_geom.get_batch_axes() if len(world_axes) == 0: @@ -248,23 +252,25 @@ def compute_world_collision_distance( else: M = world_axes[-1] batch_world_shape = world_axes[:-1] - + if M != self.trained_num_obstacles: logger.warning( f"Neural network was trained for {self.trained_num_obstacles} obstacles, " f"but current world_geom has {M}. Falling back to exact computation." 
) - return self._collision_model.compute_world_collision_distance(robot, cfg, world_geom) + return self._collision_model.compute_world_collision_distance( + robot, cfg, world_geom + ) # Compute link poses via forward kinematics # Shape: (*batch_cfg, num_links, 7) where 7 = wxyz (4) + xyz (3) link_poses = robot.forward_kinematics(cfg) N = self.num_links - + # Flatten link poses to use as network input # Shape: (*batch_cfg, num_links * 7) link_poses_flat = link_poses.reshape(*batch_cfg_shape, N * 7) - + # Apply positional encoding BEFORE normalization if enabled # (matching the training procedure) if self.use_positional_encoding: @@ -279,22 +285,24 @@ def compute_world_collision_distance( else: # Just normalize link_poses_normalized = (link_poses_flat - self.input_mean) / self.input_std - + # Flatten batch for inference input_dim = link_poses_normalized.shape[-1] input_flat = link_poses_normalized.reshape(-1, input_dim) - + # Run inference predict_fn = jax.vmap(self._forward_nn) dists_flat = predict_fn(input_flat) # Shape: (batch_size, N * M) - + # Reshape output to (*batch_cfg, N, M) dists = dists_flat.reshape(*batch_cfg_shape, N, M) - + # Handle broadcasting with world batch shape if necessary. if batch_world_shape: - expected_batch_combined = jnp.broadcast_shapes(batch_cfg_shape, batch_world_shape) - dists = jnp.broadcast_to(dists, (*expected_batch_combined, N, M)) + expected_batch_combined = jnp.broadcast_shapes( + batch_cfg_shape, batch_world_shape + ) + dists = jnp.broadcast_to(dists, (*expected_batch_combined, N, M)) return dists @@ -307,7 +315,7 @@ def train( epochs: int = 50, learning_rate: float = 1e-3, key: jax.Array = None, - layer_sizes: List[int] = [256, 256, 256, 256] + layer_sizes: List[int] = [256, 256, 256, 256], ) -> "NeuralRobotCollision": """ Trains the neural network to approximate the collision distances for the given world_geom. @@ -318,7 +326,7 @@ def train( where collision spheres end up in world space. 
""" logger.info("Starting neural collision training...") - + if key is None: key = jax.random.PRNGKey(0) @@ -329,30 +337,32 @@ def train( M = world_axes[-1] if len(world_axes) > 0 else 1 # 1. Generate training data with collision-aware sampling - logger.info(f"Generating {num_samples} samples with collision-aware sampling...") + logger.info( + f"Generating {num_samples} samples with collision-aware sampling..." + ) # Sample configurations using Halton sequence for better space coverage dof = robot.joints.num_actuated_joints lower_limits = robot.joints.lower_limits upper_limits = robot.joints.upper_limits - + # Generate initial Halton sequence samples (2x to have pool for filtering) initial_pool_size = num_samples * 2 halton_samples = halton_sequence(initial_pool_size, dof) q_pool = lower_limits + halton_samples * (upper_limits - lower_limits) - + # Compute distances for the pool to identify collision samples logger.info("Computing distances to identify collision samples...") - + def compute_min_dist(q): dists = self._collision_model.compute_world_collision_distance( robot, q, world_geom ) return jnp.min(dists) - + compute_all_min_dists = jax.vmap(compute_min_dist) min_dists = compute_all_min_dists(q_pool) # Shape: (initial_pool_size,) - + # Rebalance samples to have more collision and near-collision samples collision_threshold = 0.1 # Samples within 10cm of collision q_train = rebalance_samples( @@ -370,34 +380,38 @@ def compute_min_dist(q): # Shape: (num_samples, num_links, 7) where 7 = wxyz (4) + xyz (3) logger.info("Computing link poses via forward kinematics...") link_poses_all = robot.forward_kinematics(q_train) - + # Flatten link poses to (num_samples, num_links * 7) X_train_raw = link_poses_all.reshape(num_samples, N * 7) - + # Apply positional encoding BEFORE normalization if enabled # This is important because positional encoding should operate on the # original spatial scale of the data, not normalized values if self.use_positional_encoding: - # For 
positional encoding, we want the input scaled so that the + # For positional encoding, we want the input scaled so that the # lowest frequency (2^min_deg) captures the full data range, # and higher frequencies capture finer details. # A scale of 1.0 is typically fine when data is already in reasonable ranges. # We use pi as scale so that the range [-1, 1] maps to [-pi, pi] auto_scale = jnp.pi - logger.info(f"Applying positional encoding (min_deg={self.pe_min_deg}, max_deg={self.pe_max_deg}, scale={auto_scale:.4f})...") + logger.info( + f"Applying positional encoding (min_deg={self.pe_min_deg}, max_deg={self.pe_max_deg}, scale={auto_scale:.4f})..." + ) X_train_pe = positional_encoding( X_train_raw, min_deg=self.pe_min_deg, max_deg=self.pe_max_deg, scale=auto_scale, ) - logger.info(f"Input dimension after positional encoding: {X_train_pe.shape[-1]} (was {X_train_raw.shape[-1]})") - + logger.info( + f"Input dimension after positional encoding: {X_train_pe.shape[-1]} (was {X_train_raw.shape[-1]})" + ) + # Now normalize the positional-encoded features X_mean = jnp.mean(X_train_pe, axis=0, keepdims=True) X_std = jnp.std(X_train_pe, axis=0, keepdims=True) + 1e-8 X_train = (X_train_pe - X_mean) / X_std - + # Store the auto-computed scale for inference self_pe_scale_computed = auto_scale else: @@ -409,22 +423,22 @@ def compute_min_dist(q): # 2. 
Compute ground truth labels using vmap for acceleration logger.info("Computing ground truth distances (vectorized)...") - + # Use vmap to compute distances for all configurations in parallel def compute_single_dist(q): dists = self._collision_model.compute_world_collision_distance( robot, q, world_geom ) return dists.reshape(-1) # Flatten to (N*M,) - + # Vectorize over all training samples compute_all_dists = jax.vmap(compute_single_dist) Y_train = compute_all_dists(q_train) # Shape: (num_samples, N*M) - + # Compute sample weights based on minimum distance # Give higher weight to collision and near-collision samples Y_min_per_sample = jnp.min(Y_train, axis=1) # Shape: (num_samples,) - + # Weight function: higher weight for collision (dist <= 0) and near-collision # collision: weight = 3.0, near-collision: weight = 2.0, free: weight = 1.0 sample_weights = jnp.where( @@ -433,19 +447,23 @@ def compute_single_dist(q): jnp.where( Y_min_per_sample < collision_threshold, 2.0, # Near-collision samples get 2x weight - 1.0 # Free space samples get normal weight - ) + 1.0, # Free space samples get normal weight + ), ) # Normalize weights so they sum to num_samples (to maintain loss scale) sample_weights = sample_weights * (num_samples / jnp.sum(sample_weights)) - - logger.info(f"Sample weights - collision (3x): {jnp.sum(Y_min_per_sample <= 0)}, near-collision (2x): {jnp.sum((Y_min_per_sample > 0) & (Y_min_per_sample < collision_threshold))}") + + logger.info( + f"Sample weights - collision (3x): {jnp.sum(Y_min_per_sample <= 0)}, near-collision (2x): {jnp.sum((Y_min_per_sample > 0) & (Y_min_per_sample < collision_threshold))}" + ) # 3. 
Initialize Network # Input dimension depends on whether positional encoding is enabled raw_input_dim = N * 7 # num_links * 7 (wxyz_xyz pose representation) if self.use_positional_encoding: - input_dim = compute_positional_encoding_dim(raw_input_dim, self.pe_min_deg, self.pe_max_deg) + input_dim = compute_positional_encoding_dim( + raw_input_dim, self.pe_min_deg, self.pe_max_deg + ) else: input_dim = raw_input_dim output_dim = N * M # num_links * num_obstacles @@ -460,7 +478,11 @@ def compute_single_dist(q): b = jnp.zeros((fan_out,)) params.append((w, b)) - pe_info = f" with positional encoding (deg {self.pe_min_deg}-{self.pe_max_deg})" if self.use_positional_encoding else "" + pe_info = ( + f" with positional encoding (deg {self.pe_min_deg}-{self.pe_max_deg})" + if self.use_positional_encoding + else "" + ) logger.info( f"Training neural network{pe_info} (Input: {input_dim}, Output: {output_dim} [distances])..." ) @@ -488,50 +510,52 @@ def train_step(params, opt_state, x_batch, y_batch, w_batch, t): """Single training step with Adam optimizer.""" m, v = opt_state beta1, beta2, epsilon = 0.9, 0.999, 1e-8 - + # Compute gradients - loss_val, grads = jax.value_and_grad(loss_fn)(params, x_batch, y_batch, w_batch) - + loss_val, grads = jax.value_and_grad(loss_fn)( + params, x_batch, y_batch, w_batch + ) + # Adam update new_params = [] new_m = [] new_v = [] - + for i in range(len(params)): w, b = params[i] dw, db = grads[i] mw, mb = m[i] vw, vb = v[i] - + # Update biased first moment estimate mw = beta1 * mw + (1.0 - beta1) * dw mb = beta1 * mb + (1.0 - beta1) * db - + # Update biased second moment estimate - vw = beta2 * vw + (1.0 - beta2) * (dw ** 2) - vb = beta2 * vb + (1.0 - beta2) * (db ** 2) - + vw = beta2 * vw + (1.0 - beta2) * (dw**2) + vb = beta2 * vb + (1.0 - beta2) * (db**2) + # Bias correction - m_hat_w = mw / (1.0 - beta1 ** t) - m_hat_b = mb / (1.0 - beta1 ** t) - v_hat_w = vw / (1.0 - beta2 ** t) - v_hat_b = vb / (1.0 - beta2 ** t) - + m_hat_w = mw / (1.0 - 
beta1**t) + m_hat_b = mb / (1.0 - beta1**t) + v_hat_w = vw / (1.0 - beta2**t) + v_hat_b = vb / (1.0 - beta2**t) + # Update parameters w_new = w - learning_rate * m_hat_w / (jnp.sqrt(v_hat_w) + epsilon) b_new = b - learning_rate * m_hat_b / (jnp.sqrt(v_hat_b) + epsilon) - + new_params.append((w_new, b_new)) new_m.append((mw, mb)) new_v.append((vw, vb)) - + return new_params, (new_m, new_v), loss_val # Initialize Adam state m = [(jnp.zeros_like(w), jnp.zeros_like(b)) for w, b in params] v = [(jnp.zeros_like(w), jnp.zeros_like(b)) for w, b in params] opt_state = (m, v) - + params_state = params t = 0 num_batches = num_samples // batch_size @@ -560,9 +584,7 @@ def train_step(params, opt_state, x_batch, y_batch, w_batch, t): epoch_loss += loss_val if epoch % 10 == 0: - logger.info( - f"Epoch {epoch}: Loss = {epoch_loss / num_batches:.6f}" - ) + logger.info(f"Epoch {epoch}: Loss = {epoch_loss / num_batches:.6f}") logger.info("Training complete.") @@ -578,4 +600,4 @@ def train_step(params, opt_state, x_batch, y_batch, w_batch, t): # Backward compatibility alias -NeuralRobotCollisionSpherized = NeuralRobotCollision \ No newline at end of file +NeuralRobotCollisionSpherized = NeuralRobotCollision diff --git a/src/pyronot/collision/_robot_collision.py b/src/pyronot/collision/_robot_collision.py index c223dd6..a4b6eb4 100644 --- a/src/pyronot/collision/_robot_collision.py +++ b/src/pyronot/collision/_robot_collision.py @@ -1,7 +1,6 @@ from __future__ import annotations from typing import TYPE_CHECKING, Optional, Tuple, cast -import math import jax import jax.numpy as jnp @@ -10,7 +9,6 @@ import trimesh import yourdfpy from jaxtyping import Array, Float, Int -from jax import lax from loguru import logger if TYPE_CHECKING: @@ -153,8 +151,6 @@ def _compute_active_pair_indices( return active_i, active_j - - @staticmethod def _get_trimesh_collision_geometries( urdf: yourdfpy.URDF, link_name: str @@ -425,12 +421,15 @@ def compute_world_collision_distance( # 5. 
Return the distance matrix return dist_matrix + @jdc.pytree_dataclass class RobotCollisionSpherized: """Collision model for a robot, integrated with pyronot kinematics.""" num_links: jdc.Static[int] """Number of links in the model (matches kinematics links).""" + spheres_per_link: jdc.Static[tuple[int, ...]] + """Number of spheres for each link (tuple of length num_links).""" link_names: jdc.Static[tuple[str, ...]] """Names of the links corresponding to link indices.""" coll: CollGeom @@ -476,50 +475,42 @@ def from_urdf( # Gather all collision meshes. link_sphere_meshes: list[list[trimesh.Trimesh]] = [] for link_name in link_name_list: - spheres = RobotCollisionSpherized._get_trimesh_collision_spheres_for_link(urdf, link_name) + spheres = RobotCollisionSpherized._get_trimesh_collision_spheres_for_link( + urdf, link_name + ) link_sphere_meshes.append(spheres) - - sphere_list_per_link: list[list[Sphere]] = [] + # Build flat list of spheres and track count per link + sphere_list: list[Sphere] = [] + spheres_per_link: list[int] = [] for sphere_meshes in link_sphere_meshes: per_link_spheres = [ Sphere.from_trimesh(mesh) for mesh in sphere_meshes if mesh is not None ] - sphere_list_per_link.append(per_link_spheres) - - ############ Weihang: Please check this part ############# - # Add padding to the spheres list to make it a batched sphere object - max_spheres = max(len(spheres) for spheres in sphere_list_per_link) - padded_sphere_list: list[Sphere] = [] - for per_link_spheres in sphere_list_per_link: - if len(per_link_spheres) < max_spheres: - # Create dummy/invalid spheres for padding (e.g., zero radius) - dummy_sphere = Sphere.from_center_and_radius( - center=jnp.zeros(3), - radius=jnp.array(0.0) # or negative to mark as invalid - ) - padded = per_link_spheres + [dummy_sphere] * (max_spheres - len(per_link_spheres)) - padded_sphere_list.append(padded) - else: - padded_sphere_list.append(per_link_spheres) - spheres_2d = cast(Sphere, jax.tree.map(lambda *args: 
jnp.stack(args), *padded_sphere_list)) + sphere_list.extend(per_link_spheres) + spheres_per_link.append(len(per_link_spheres)) - ########################################################## + # Stack all spheres into a batched Sphere object + spheres = cast(Sphere, jax.tree.map(lambda *args: jnp.stack(args), *sphere_list)) # Directly compute active pair indices # Weihang: Have not checked this part yet!!! # Should be fine, generates active_indices per links - Sai if srdf_path: - active_idx_i, active_idx_j = RobotCollisionSpherized._compute_active_pair_indices_from_srdf( - link_names=link_name_list, - srdf_path=srdf_path, + active_idx_i, active_idx_j = ( + RobotCollisionSpherized._compute_active_pair_indices_from_srdf( + link_names=link_name_list, + srdf_path=srdf_path, + ) ) else: - active_idx_i, active_idx_j = RobotCollisionSpherized._compute_active_pair_indices( - link_names=link_name_list, - urdf=urdf, - user_ignore_pairs=user_ignore_pairs, - ignore_immediate_adjacents=ignore_immediate_adjacents, + active_idx_i, active_idx_j = ( + RobotCollisionSpherized._compute_active_pair_indices( + link_names=link_name_list, + urdf=urdf, + user_ignore_pairs=user_ignore_pairs, + ignore_immediate_adjacents=ignore_immediate_adjacents, + ) ) logger.info( @@ -530,18 +521,20 @@ def from_urdf( return RobotCollisionSpherized( num_links=link_info.num_links, + spheres_per_link=tuple(spheres_per_link), link_names=link_name_list, active_idx_i=active_idx_i, active_idx_j=active_idx_j, - coll=spheres_2d, # now stores lists of Sphere objects per link + coll=spheres, ) @staticmethod def _get_trimesh_collision_spheres_for_link( - urdf: yourdfpy.URDF, link_name: str) -> list[trimesh.Trimesh]: + urdf: yourdfpy.URDF, link_name: str + ) -> list[trimesh.Trimesh]: if link_name not in urdf.link_map: return [trimesh.Trimesh()] - + link = urdf.link_map[link_name] filename_handler = urdf._filename_handler coll_meshes = [] @@ -557,11 +550,11 @@ def _get_trimesh_collision_spheres_for_link( transform = 
jaxlie.SE3.identity().as_matrix() if geom.sphere is not None: mesh = trimesh.creation.icosphere(radius=geom.sphere.radius) - else: + else: logger.warning( f"Unsupported collision geometry type for link '{link_name}'." ) - if mesh is not None: + if mesh is not None: mesh.apply_transform(transform) coll_meshes.append(mesh) return coll_meshes @@ -583,11 +576,11 @@ def _compute_active_pair_indices_from_srdf( Tuple of (active_i, active_j) index arrays. """ from .._robot_srdf_parser import read_disabled_collisions_from_srdf - + num_links = len(link_names) link_name_to_idx = {name: i for i, name in enumerate(link_names)} ignore_matrix = jnp.zeros((num_links, num_links), dtype=bool) - + # Ignore self-collisions (diagonal) ignore_matrix = ignore_matrix.at[ jnp.arange(num_links), jnp.arange(num_links) @@ -596,32 +589,34 @@ def _compute_active_pair_indices_from_srdf( # Read disabled collision pairs from SRDF try: disabled_pairs = read_disabled_collisions_from_srdf(srdf_path) - + disabled_count = 0 for pair in disabled_pairs: - link1_name = pair['link1'] - link2_name = pair['link2'] - + link1_name = pair["link1"] + link2_name = pair["link2"] + # Only process if both links are in our link_names if link1_name in link_name_to_idx and link2_name in link_name_to_idx: idx1 = link_name_to_idx[link1_name] idx2 = link_name_to_idx[link2_name] - + # Set both directions as disabled (symmetric) ignore_matrix = ignore_matrix.at[idx1, idx2].set(True) ignore_matrix = ignore_matrix.at[idx2, idx1].set(True) disabled_count += 1 - + logger.info(f"Loaded {disabled_count} disabled collision pairs from SRDF") - + except FileNotFoundError: - logger.warning(f"SRDF file not found: {srdf_path}. Using all collision pairs.") + logger.warning( + f"SRDF file not found: {srdf_path}. Using all collision pairs." + ) except Exception as e: logger.warning(f"Error parsing SRDF file: {e}. 
Using all collision pairs.") # Get all lower triangular indices (i < j) idx_i, idx_j = jnp.tril_indices(num_links, k=-1) - + # Filter out ignored pairs should_check = ~ignore_matrix[idx_i, idx_j] active_i = idx_i[should_check] @@ -690,7 +685,7 @@ def _get_trimesh_collision_spheres(urdf: yourdfpy.URDF) -> list[trimesh.Trimesh] Returns: list[trimesh.Trimesh]: A list of trimesh sphere meshes (each transformed to link frame). - Author: + Author: Sai Coumar """ sphere_meshes = [] @@ -751,17 +746,20 @@ def at_config( ) # TODO: Override with passed in result of fk so i don't have to recompute Ts_link_world_wxyz_xyz = robot.forward_kinematics(cfg) - Ts_link_world = jaxlie.SE3(Ts_link_world_wxyz_xyz) - ############ Weihang: Please check this part ############# - coll_transformed = [] - for link in range(len(self.coll)): - coll_transformed.append(self.coll[link].transform(Ts_link_world)) - coll_transformed = cast(CollGeom, jax.tree.map(lambda *args: jnp.stack(args), *coll_transformed)) - ########################################################## - return coll_transformed - # return self.coll.transform(Ts_link_world) - + # Spheres are laid out as [link0_s0..sN0, link1_s0..sN1, ...] + # where Ni = spheres_per_link[i] + # Repeat each link's transform by its sphere count + total_spheres = sum(self.spheres_per_link) + expanded_transforms = jnp.repeat( + Ts_link_world_wxyz_xyz, + jnp.array(self.spheres_per_link), + axis=-2, + total_repeat_length=total_spheres + ) + + Ts_expanded = jaxlie.SE3(expanded_transforms) + return self.coll.transform(Ts_expanded) def compute_self_collision_distance( self, @@ -780,10 +778,10 @@ def compute_self_collision_distance( Signed distances for each active pair. Shape: (*batch, num_active_pairs). Positive distance means separation, negative means penetration. - + Author: Sai Coumar """ - + # 1. 
Transform all spheres to world frame coll = self.at_config(robot, cfg) # CollGeom: (*batch, n_spheres, num_links) @@ -796,15 +794,14 @@ def compute_self_collision_distance( # Return same format of active_distances as the capsule implementaiton active_distances = dist_matrix_links[..., self.active_idx_i, self.active_idx_j] return active_distances - @staticmethod def collide_link_vs_world(link_geom, world_geom): - # Map collide over spheres in this link (S) - # link_geom: (S, ...) - collide_spheres_vs_world = jax.vmap(collide, in_axes=(0, None), out_axes=0) - dist_spheres = collide_spheres_vs_world(link_geom, world_geom) # (S, M) - return dist_spheres.min(axis=0) # reduce over spheres → (M,) + # Map collide over spheres in this link (S) + # link_geom: (S, ...) + collide_spheres_vs_world = jax.vmap(collide, in_axes=(0, None), out_axes=0) + dist_spheres = collide_spheres_vs_world(link_geom, world_geom) # (S, M) + return dist_spheres.min(axis=0) # reduce over spheres → (M,) @jdc.jit def compute_world_collision_distance( @@ -814,144 +811,91 @@ def compute_world_collision_distance( world_geom: CollGeom, # Shape: (*batch_world, M, ...) ) -> Float[Array, "*batch_combined N M"]: """ - Computes signed distances between all robot links (N) and world obstacles (M), - accounting for multiple primitives (S) per link. + Computes the signed distances between all robot links (N) and all world obstacles (M). - The minimum distance over all primitives in each link is used as the link’s - representative distance to each world object. + Args: + robot_coll: The robot's collision model. + robot: The robot's kinematic model. + cfg: The robot configuration (actuated joints). + world_geom: Collision geometry representing world obstacles. If representing a + single obstacle, it should have batch shape (). If multiple, the last axis + is interpreted as the collection of world objects (M). + The batch dimensions (*batch_world) must be broadcast-compatible with cfg's + batch axes (*batch_cfg). 
+ + Returns: + Matrix of signed distances between each robot link and each world object. + Shape: (*batch_combined, N, M), where N=num_links, M=num_world_objects. + Positive distance means separation, negative means penetration. """ - # 1. Get robot collision geometry at configuration - # Shape: (*batch_cfg, S, N, ...) + # 1. Get robot collision geometry at the current config + # Shape: (*batch_cfg, N, ...) coll_robot_world = self.at_config(robot, cfg) - batch_cfg_shape = coll_robot_world.get_batch_axes()[:-2] - S, N = coll_robot_world.get_batch_axes()[-2:] - + N = self.num_links + batch_cfg_shape = coll_robot_world.get_batch_axes()[:-1] # 2. Normalize world_geom shape and determine M world_axes = world_geom.get_batch_axes() - if len(world_axes) == 0: + if len(world_axes) == 0: # Single world object + # Use the object's broadcast_to method to add the M=1 axis correctly _world_geom = world_geom.broadcast_to((1,)) M = 1 batch_world_shape = () - else: + else: # Multiple world objects _world_geom = world_geom M = world_axes[-1] batch_world_shape = world_axes[:-1] - # 3. Define how to collide a single link (with S primitives) against the world - # Each link_geom has shape (S, ...). We map over the S primitives, then take min. - - - # 4. Now map that over links (N) - # coll_robot_world: (*batch_cfg, S, N, ...) - # We map over the link axis (-2 from end, N) - _collide_links_vs_world = jax.vmap( - self.collide_link_vs_world, in_axes=(-2, None), out_axes=-2 - ) - - # # 5. Compute final distance matrix - dist_matrix = _collide_links_vs_world(coll_robot_world, _world_geom) # (*batch, N, M) - + # 3. Compute distances: Map collide over robot links (axis -2) vs _world_geom (None) + # _world_geom is guaranteed to have the M axis now. + _collide_links_vs_world = jax.vmap(collide, in_axes=(-2, None), out_axes=(-2)) + dist_matrix = _collide_links_vs_world(coll_robot_world, _world_geom) + # 5. Return the distance matrix + return dist_matrix - # 6. 
Verify shape consistency - expected_batch_combined = jnp.broadcast_shapes(batch_cfg_shape, batch_world_shape) - expected_shape = (*expected_batch_combined, N, M) - assert dist_matrix.shape == expected_shape, ( - f"Output shape mismatch. Expected {expected_shape}, got {dist_matrix.shape}. " - f"Robot axes: {coll_robot_world.get_batch_axes()}, " - f"World axes: {world_geom.get_batch_axes()}" - ) + def compute_world_collision_distance_with_exclude_links( + self, + robot: Robot, + cfg: Float[Array, "*batch_cfg actuated_count"], + world_geom: CollGeom, # Shape: (*batch_world, M, ...) + exclude_link_mask: Int[Array, " num_links"], + ) -> Float[Array, "*batch_combined N M"]: + """ + Computes signed distances between all robot links (N) and world obstacles (M), + accounting for multiple primitives (S) per link. + The minimum distance over all primitives in each link is used as the link's + representative distance to each world object. + """ + dist_matrix = self.compute_world_collision_distance(robot, cfg, world_geom) + dist_matrix = RobotCollisionSpherized.mask_collision_distance(dist_matrix, exclude_link_mask) return dist_matrix - # @jdc.jit - # def compute_world_collision_distance( - # self, - # robot: Robot, - # cfg: Float[Array, "*batch_cfg actuated_count"], - # world_geom: CollGeom, # Shape: (*batch_world, M, ...) - # ) -> Float[Array, "*batch_combined N M"]: - # """ - # Computes signed distances between all robot links (N) and world obstacles (M), - # accounting for multiple primitives (S) per link. - - # The minimum distance over all primitives in each link is used as the link’s - # representative distance to each world object. - # """ - # CHUNK_SIZE = 1000 - - # # 1. Get robot collision geometry at configuration - # # Shape: (S, *batch_cfg, N, ...) - # coll_robot_world = self.at_config(robot, cfg) - - # # 2. 
Normalize world_geom shape and determine M - # world_axes = world_geom.get_batch_axes() - # if len(world_axes) == 0: - # _world_geom = world_geom.broadcast_to((1,)) - # M = 1 - # else: - # _world_geom = world_geom - # M = world_axes[-1] - - # # 3. Prepare for scan over links (N) - # # We want to iterate over the N axis. - # # coll_robot_world has N at the last batch axis. - # # We move N to the front (0) to scan over it. - # n_batch = len(coll_robot_world.get_batch_axes()) - # coll_robot_world_scannable = jax.tree.map( - # lambda x: jnp.moveaxis(x, -2 if x.ndim > n_batch else -1, 0), - # coll_robot_world, - # ) - - # # Pad to multiple of CHUNK_SIZE - # N = self.num_links - # pad_size = (CHUNK_SIZE - (N % CHUNK_SIZE)) % CHUNK_SIZE - - # @jdc.jit - # def pad_fn(x): - # # x shape: (N, ...) - # padding = [(0, 0)] * x.ndim - # padding[0] = (0, pad_size) - # return jnp.pad(x, padding, mode='edge') - - # coll_padded = jax.tree.map(pad_fn, coll_robot_world_scannable) - - # # Reshape to (num_chunks, CHUNK_SIZE, ...) - # num_chunks = (N + pad_size) // CHUNK_SIZE - - # @jdc.jit - # def reshape_fn(x): - # return x.reshape((num_chunks, CHUNK_SIZE) + x.shape[1:]) - - # coll_chunked = jax.tree.map(reshape_fn, coll_padded) - - # # 4. Define scan function - # @jdc.jit - # def scan_fn(carry, link_chunk_geom): - # # link_chunk_geom: (CHUNK_SIZE, S, *batch_cfg, ...) - # # _world_geom: (*batch_world, M, ...) - - # # vmap over the chunk (axis 0) - # _collide_chunk = jax.vmap( - # self.collide_link_vs_world, in_axes=(0, None), out_axes=0 - # ) - # d_chunk = _collide_chunk(link_chunk_geom, _world_geom) - # return carry, d_chunk - - # # 5. Run scan - # # dists_scanned: (num_chunks, CHUNK_SIZE, *batch_combined, M) - # _, dists_scanned = lax.scan(scan_fn, None, coll_chunked) - - # # 6. 
Restore shape - # # Flatten chunks - # dists_flattened = dists_scanned.reshape((-1,) + dists_scanned.shape[2:]) - # # Slice to original N - # dists_N = dists_flattened[:N] - # # Move N to -2 - # dist_matrix = jnp.moveaxis(dists_N, 0, -2) - - # return dist_matrix + def is_in_collision( + self, + robot: Robot, + cfg: Float[Array, "*batch_cfg actuated_count"], + world_geom: CollGeom, # Shape: (*batch_world, M, ...) + ) -> Bool[Array, "*batch_combined N M"]: + """ + Checks if the robot is in collision with the world obstacles. + """ + dist_matrix = self.compute_world_collision_distance(robot, cfg, world_geom) + return dist_matrix < 0 + + def is_in_collision_with_exclude_links( + self, + robot: Robot, + cfg: Float[Array, "*batch_cfg actuated_count"], + world_geom: CollGeom, # Shape: (*batch_world, M, ...) + exclude_link_mask: Int[Array, " num_links"], + ) -> Bool[Array, "*batch_combined N M"]: + """ + Checks if the robot is in collision with the world obstacles, excluding the specified links. + """ + dist_matrix = self.compute_world_collision_distance_with_exclude_links(robot, cfg, world_geom, exclude_link_mask) + return dist_matrix < 0 def get_swept_capsules( self, @@ -993,16 +937,20 @@ def get_swept_capsules( ) return swept_capsules - + @staticmethod - def mask_collision_distance(solution: Float[Array, "1D array = num_links"], exclude_link_mask: Int[Array, " num_links"], replace_value: float = 1e6) -> Float[Array, "1D array = num_links"]: + def mask_collision_distance( + solution: Float[Array, "1D array = num_links"], + exclude_link_mask: Int[Array, " num_links"], + replace_value: float = 1e6, + ) -> Float[Array, "1D array = num_links"]: """Mask collision distances at specified link indices by replacing them with a value. 
- + Args: solution: Collision distance array with shape (*batch, actuated_count) exclude_link_indices: Indices of links to exclude from collision checking replace_value: Value to replace at the excluded indices (default: 1e6 for large distance) - + Returns: Masked collision distance array with the same shape as solution """ diff --git a/src/pyronot/utils.py b/src/pyronot/utils.py index 3f7dd0d..3761b04 100644 --- a/src/pyronot/utils.py +++ b/src/pyronot/utils.py @@ -12,36 +12,87 @@ # Icosahedron vertices for positional encoding directions (from iSDF) # These 20 directions provide good coverage of the unit sphere -_ICOSAHEDRON_DIRS = jnp.array([ - [0.8506508, 0, 0.5257311], - [0.809017, 0.5, 0.309017], - [0.5257311, 0.8506508, 0], - [1, 0, 0], - [0.809017, 0.5, -0.309017], - [0.8506508, 0, -0.5257311], - [0.309017, 0.809017, -0.5], - [0, 0.5257311, -0.8506508], - [0.5, 0.309017, -0.809017], - [0, 1, 0], - [-0.5257311, 0.8506508, 0], - [-0.309017, 0.809017, -0.5], - [0, 0.5257311, 0.8506508], - [-0.309017, 0.809017, 0.5], - [0.309017, 0.809017, 0.5], - [0.5, 0.309017, 0.809017], - [0.5, -0.309017, 0.809017], - [0, 0, 1], - [-0.5, 0.309017, 0.809017], - [-0.809017, 0.5, 0.309017], -]).T # Shape: (3, 20) for efficient matmul +_ICOSAHEDRON_DIRS = jnp.array( + [ + [0.8506508, 0, 0.5257311], + [0.809017, 0.5, 0.309017], + [0.5257311, 0.8506508, 0], + [1, 0, 0], + [0.809017, 0.5, -0.309017], + [0.8506508, 0, -0.5257311], + [0.309017, 0.809017, -0.5], + [0, 0.5257311, -0.8506508], + [0.5, 0.309017, -0.809017], + [0, 1, 0], + [-0.5257311, 0.8506508, 0], + [-0.309017, 0.809017, -0.5], + [0, 0.5257311, 0.8506508], + [-0.309017, 0.809017, 0.5], + [0.309017, 0.809017, 0.5], + [0.5, 0.309017, 0.809017], + [0.5, -0.309017, 0.809017], + [0, 0, 1], + [-0.5, 0.309017, 0.809017], + [-0.809017, 0.5, 0.309017], + ] +).T # Shape: (3, 20) for efficient matmul # First 50 prime numbers for Halton sequence bases -_HALTON_PRIMES = jnp.array([ - 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 
43, 47, 53, 59, 61, 67, 71, - 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, - 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229 -]) +_HALTON_PRIMES = jnp.array( + [ + 2, + 3, + 5, + 7, + 11, + 13, + 17, + 19, + 23, + 29, + 31, + 37, + 41, + 43, + 47, + 53, + 59, + 61, + 67, + 71, + 73, + 79, + 83, + 89, + 97, + 101, + 103, + 107, + 109, + 113, + 127, + 131, + 137, + 139, + 149, + 151, + 157, + 163, + 167, + 173, + 179, + 181, + 191, + 193, + 197, + 199, + 211, + 223, + 227, + 229, + ] +) def positional_encoding( @@ -52,47 +103,49 @@ def positional_encoding( ) -> Float[Array, "... embed_dim"]: """ Standard sinusoidal positional encoding (NeRF-style). - + Applies sinusoidal encoding at multiple frequency bands to each input dimension. This helps the network capture high-frequency spatial details. - + Args: x: Input tensor of shape (..., D). min_deg: Minimum frequency degree (default 0). max_deg: Maximum frequency degree (default 6). scale: Scale factor applied to input before encoding (default 1.0). - + Returns: Positional encoding of shape (..., D + D * 2 * num_freqs). """ n_freqs = max_deg - min_deg + 1 # Frequency bands: 2^min_deg, 2^(min_deg+1), ..., 2^max_deg frequency_bands = 2.0 ** jnp.arange(min_deg, max_deg + 1) - + # Scale input x_scaled = x * scale - + # Start with original features embeddings = [x_scaled] - + # Apply sin and cos at each frequency to each dimension for freq in frequency_bands: embeddings.append(jnp.sin(freq * x_scaled)) embeddings.append(jnp.cos(freq * x_scaled)) - + # Concatenate all embeddings along last axis return jnp.concatenate(embeddings, axis=-1) -def compute_positional_encoding_dim(input_dim: int, min_deg: int = 0, max_deg: int = 6) -> int: +def compute_positional_encoding_dim( + input_dim: int, min_deg: int = 0, max_deg: int = 6 +) -> int: """ Compute the output dimension of positional encoding. - + Args: input_dim: Input dimension D. min_deg: Minimum frequency degree. 
max_deg: Maximum frequency degree. - + Returns: Output dimension after positional encoding. """ @@ -104,47 +157,49 @@ def compute_positional_encoding_dim(input_dim: int, min_deg: int = 0, max_deg: i def halton_sequence(num_samples: int, dim: int, skip: int = 100) -> jax.Array: """ Generate a Halton sequence for quasi-random sampling. - + The Halton sequence provides better coverage of the sample space compared to uniform random sampling, which is beneficial for diverse sample collection. - + Args: num_samples: Number of samples to generate. dim: Dimensionality of each sample. skip: Number of initial samples to skip (improves uniformity). - + Returns: JAX array of shape (num_samples, dim) with values in [0, 1]. """ if dim > len(_HALTON_PRIMES): - raise ValueError(f"Halton sequence dimension {dim} exceeds available primes ({len(_HALTON_PRIMES)})") - + raise ValueError( + f"Halton sequence dimension {dim} exceeds available primes ({len(_HALTON_PRIMES)})" + ) + # Generate samples using vectorized operations where possible indices = jnp.arange(skip, skip + num_samples) bases = _HALTON_PRIMES[:dim] - + # Compute maximum number of digits needed for the largest index max_index = skip + num_samples max_digits = int(jnp.ceil(jnp.log(max_index + 1) / jnp.log(2))) + 1 - + def halton_for_base(base: int) -> jax.Array: """Vectorized Halton sequence for a single base.""" # Compute radical inverse for all indices at once result = jnp.zeros(num_samples) f = 1.0 / base current = indices.astype(jnp.float32) - + for _ in range(max_digits): digit = jnp.mod(current, base) result = result + f * digit current = jnp.floor(current / base) f = f / base - + return result - + # Stack results for all dimensions samples = jnp.stack([halton_for_base(int(b)) for b in bases], axis=1) - + return samples @@ -166,7 +221,7 @@ def rebalance_samples( """ Rebalance a sample pool to have a target distribution of collision, near-collision, and free-space samples. 
- + Args: samples: Initial sample pool of shape (pool_size, dof). distances: Minimum distances for each sample, shape (pool_size,). @@ -182,140 +237,184 @@ def rebalance_samples( max_augment_iterations: Maximum iterations for augmentation (default 10). perturbation_scale_collision: Perturbation scale for collision augmentation (default 0.05). perturbation_scale_near: Perturbation scale for near-collision augmentation (default 0.03). - + Returns: Rebalanced samples of shape (num_samples, dof). """ dof = samples.shape[1] - + # Separate samples into categories is_in_collision = distances <= 0 is_near_collision = (distances > 0) & (distances < collision_threshold) is_free_space = distances >= collision_threshold - + collision_samples = samples[is_in_collision] near_collision_samples = samples[is_near_collision] free_space_samples = samples[is_free_space] - + num_collision = collision_samples.shape[0] num_near_collision = near_collision_samples.shape[0] num_free = free_space_samples.shape[0] - - logger.info(f"Sample distribution from pool: collision={num_collision}, near-collision={num_near_collision}, free-space={num_free}") - + + logger.info( + f"Sample distribution from pool: collision={num_collision}, near-collision={num_near_collision}, free-space={num_free}" + ) + # Target distribution target_collision = int(num_samples * target_collision_ratio) target_near = int(num_samples * target_near_ratio) target_free = num_samples - target_collision - target_near - + key_augment = key - + # Augment collision samples if needed if num_collision < target_collision and num_collision > 0: - logger.info(f"Augmenting collision samples from {num_collision} to {target_collision}...") - + logger.info( + f"Augmenting collision samples from {num_collision} to {target_collision}..." 
+ ) + samples_needed = target_collision - num_collision augmented_list = [] iteration = 0 - - while len(augmented_list) < samples_needed and iteration < max_augment_iterations: + + while ( + len(augmented_list) < samples_needed and iteration < max_augment_iterations + ): iteration += 1 batch_size_aug = min(samples_needed * 2, 5000) key_augment, subk1, subk2 = jax.random.split(key_augment, 3) indices = jax.random.randint(subk1, (batch_size_aug,), 0, num_collision) base_samples = collision_samples[indices] - - perturbation_range = perturbation_scale_collision * (upper_limits - lower_limits) - perturbations = jax.random.uniform(subk2, (batch_size_aug, dof), minval=-1, maxval=1) * perturbation_range - candidates = jnp.clip(base_samples + perturbations, lower_limits, upper_limits) - + + perturbation_range = perturbation_scale_collision * ( + upper_limits - lower_limits + ) + perturbations = ( + jax.random.uniform(subk2, (batch_size_aug, dof), minval=-1, maxval=1) + * perturbation_range + ) + candidates = jnp.clip( + base_samples + perturbations, lower_limits, upper_limits + ) + candidate_dists = distance_fn(candidates) valid_mask = candidate_dists <= 0 valid_candidates = candidates[valid_mask] - + if valid_candidates.shape[0] > 0: augmented_list.append(valid_candidates) - - logger.debug(f" Iteration {iteration}: {valid_candidates.shape[0]} valid collision samples generated") - + + logger.debug( + f" Iteration {iteration}: {valid_candidates.shape[0]} valid collision samples generated" + ) + if augmented_list: all_augmented = jnp.concatenate(augmented_list, axis=0)[:samples_needed] - collision_samples = jnp.concatenate([collision_samples, all_augmented], axis=0) + collision_samples = jnp.concatenate( + [collision_samples, all_augmented], axis=0 + ) num_collision = collision_samples.shape[0] logger.info(f" Final collision sample count: {num_collision}") - + # Augment near-collision samples if needed if num_near_collision < target_near and num_near_collision > 0: - 
logger.info(f"Augmenting near-collision samples from {num_near_collision} to {target_near}...") - + logger.info( + f"Augmenting near-collision samples from {num_near_collision} to {target_near}..." + ) + samples_needed = target_near - num_near_collision augmented_list = [] iteration = 0 - - while len(augmented_list) < samples_needed and iteration < max_augment_iterations: + + while ( + len(augmented_list) < samples_needed and iteration < max_augment_iterations + ): iteration += 1 batch_size_aug = min(samples_needed * 2, 5000) key_augment, subk1, subk2 = jax.random.split(key_augment, 3) - indices = jax.random.randint(subk1, (batch_size_aug,), 0, num_near_collision) + indices = jax.random.randint( + subk1, (batch_size_aug,), 0, num_near_collision + ) base_samples = near_collision_samples[indices] - + perturbation_range = perturbation_scale_near * (upper_limits - lower_limits) - perturbations = jax.random.uniform(subk2, (batch_size_aug, dof), minval=-1, maxval=1) * perturbation_range - candidates = jnp.clip(base_samples + perturbations, lower_limits, upper_limits) - + perturbations = ( + jax.random.uniform(subk2, (batch_size_aug, dof), minval=-1, maxval=1) + * perturbation_range + ) + candidates = jnp.clip( + base_samples + perturbations, lower_limits, upper_limits + ) + candidate_dists = distance_fn(candidates) valid_mask = (candidate_dists > 0) & (candidate_dists < collision_threshold) valid_candidates = candidates[valid_mask] - + if valid_candidates.shape[0] > 0: augmented_list.append(valid_candidates) - + if augmented_list: all_augmented = jnp.concatenate(augmented_list, axis=0)[:samples_needed] - near_collision_samples = jnp.concatenate([near_collision_samples, all_augmented], axis=0) + near_collision_samples = jnp.concatenate( + [near_collision_samples, all_augmented], axis=0 + ) num_near_collision = near_collision_samples.shape[0] logger.info(f" Final near-collision sample count: {num_near_collision}") - + # Construct final training set actual_collision = 
min(num_collision, target_collision) actual_near = min(num_near_collision, target_near) actual_free = max(0, num_samples - actual_collision - actual_near) actual_free = min(actual_free, num_free) - - logger.info(f"Assembling training set: collision={actual_collision}, near={actual_near}, free={actual_free}") - + + logger.info( + f"Assembling training set: collision={actual_collision}, near={actual_near}, free={actual_free}" + ) + # Select samples from each category key_augment, subk = jax.random.split(key_augment) - - selected_collision = collision_samples[:actual_collision] if actual_collision > 0 else jnp.empty((0, dof)) - selected_near = near_collision_samples[:actual_near] if actual_near > 0 else jnp.empty((0, dof)) - + + selected_collision = ( + collision_samples[:actual_collision] + if actual_collision > 0 + else jnp.empty((0, dof)) + ) + selected_near = ( + near_collision_samples[:actual_near] if actual_near > 0 else jnp.empty((0, dof)) + ) + if actual_free > 0 and num_free > 0: - free_indices = jax.random.choice(subk, num_free, shape=(actual_free,), replace=False) + free_indices = jax.random.choice( + subk, num_free, shape=(actual_free,), replace=False + ) selected_free = free_space_samples[free_indices] else: selected_free = jnp.empty((0, dof)) - + # Combine all samples - parts = [p for p in [selected_collision, selected_near, selected_free] if p.shape[0] > 0] + parts = [ + p for p in [selected_collision, selected_near, selected_free] if p.shape[0] > 0 + ] result = jnp.concatenate(parts, axis=0) if parts else jnp.empty((0, dof)) - + # Fill shortfall from original pool if needed if result.shape[0] < num_samples: shortfall = num_samples - result.shape[0] logger.info(f"Filling shortfall of {shortfall} samples from original pool...") key_augment, subk = jax.random.split(key_augment) - extra_indices = jax.random.choice(subk, samples.shape[0], shape=(shortfall,), replace=True) + extra_indices = jax.random.choice( + subk, samples.shape[0], shape=(shortfall,), 
replace=True + ) extra_samples = samples[extra_indices] result = jnp.concatenate([result, extra_samples], axis=0) - + # Shuffle the result key_augment, subk = jax.random.split(key_augment) shuffle_perm = jax.random.permutation(subk, result.shape[0]) result = result[shuffle_perm][:num_samples] - + logger.info(f"Final training set: {result.shape[0]} samples") - + return result @@ -338,6 +437,7 @@ def jax_log(fmt: str, *args, **kwargs) -> None: """Emit a loguru info message from a JITed JAX function.""" jax.debug.callback(partial(_log, fmt), *args, **kwargs) -@partial(jax.jit, static_argnames=['dtype']) + +@partial(jax.jit, static_argnames=["dtype"]) def quantize(tree, dtype=jax.numpy.float16): - return jax.tree.map(lambda x: x.astype(dtype) if hasattr(x, 'astype') else x, tree) \ No newline at end of file + return jax.tree.map(lambda x: x.astype(dtype) if hasattr(x, "astype") else x, tree)