diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
deleted file mode 100644
index 00a16eb..0000000
--- a/.github/workflows/docs.yml
+++ /dev/null
@@ -1,42 +0,0 @@
-name: docs
-
-on:
- push:
- branches: [main]
-
-permissions:
- contents: write
-
-jobs:
- docs:
- runs-on: ubuntu-latest
- steps:
- # Check out source.
- - uses: actions/checkout@v2
- with:
- fetch-depth: 0 # This ensures the entire history is fetched so we can switch branches
-
- - name: Set up Python
- uses: actions/setup-python@v1
- with:
- python-version: "3.12"
-
- - name: Set up dependencies
- run: |
- sudo apt update
- sudo apt install -y libsuitesparse-dev
- pip install uv
- uv pip install --system -e ".[dev,examples]"
- uv pip install --system -r docs/requirements.txt
-
- # Build documentation.
- - name: Building documentation
- run: |
- sphinx-build docs/source docs/build -b dirhtml
-
- # Deploy to version-dependent subdirectory.
- - name: Deploy to GitHub Pages
- uses: peaceiris/actions-gh-pages@v4
- with:
- github_token: ${{ secrets.GITHUB_TOKEN }}
- publish_dir: ./docs/build
diff --git a/.github/workflows/formatting.yaml b/.github/workflows/formatting.yml
similarity index 100%
rename from .github/workflows/formatting.yaml
rename to .github/workflows/formatting.yml
diff --git a/.gitignore b/.gitignore
index 6ae5297..54063c7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -15,3 +15,4 @@ htmlcov
.envrc
.vite
build
+tmp/
\ No newline at end of file
diff --git a/README.md b/README.md
index 88864e5..e2789c2 100644
--- a/README.md
+++ b/README.md
@@ -1,32 +1,13 @@
-# `PyRoNot`: A Better Python Robot Kinematics Library
+# `PyRoNot`: A Python Library for Robot Kinematics Using Spherical Approximations
-## Citation
+[](https://github.com/CoMMALab/pyronot/actions/workflows/formatting.yml)
+[](https://github.com/CoMMALab/pyronot/actions/workflows/pyright.yml)
+[](https://github.com/CoMMALab/pyronot/actions/workflows/pytest.yml)
+[](https://pypi.org/project/pyronot/)
-This repository is based on pyroki.
-
|
- Chung Min Kim*, Brent Yi*, Hongsuk Choi, Yi Ma, Ken Goldberg, Angjoo Kanazawa.
- PyRoki: A Modular Toolkit for Robot Kinematic Optimization
- arXiV, 2025.
- |
-
-
-\*Equal Contribution, UC Berkeley.
-
-Please cite PyRoki if you find this work useful for your research:
-
-```
-@inproceedings{kim2025pyroki,
- title={PyRoki: A Modular Toolkit for Robot Kinematic Optimization},
- author={Kim*, Chung Min and Yi*, Brent and Choi, Hongsuk and Ma, Yi and Goldberg, Ken and Kanazawa, Angjoo},
- booktitle={2025 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)},
- year={2025},
- url={https://arxiv.org/abs/2505.03728},
-}
-```
-
-Thanks!
+This repository is based on [pyroki](https://github.com/chungmin99/pyroki).
## Installation
```
diff --git a/benchmark/robot_coll_benchmark.py b/benchmark/robot_coll_benchmark.py
new file mode 100644
index 0000000..d3a7539
--- /dev/null
+++ b/benchmark/robot_coll_benchmark.py
@@ -0,0 +1,195 @@
+import jax
+jax.config.update('jax_platform_name', 'cpu')
+print(f"================================================")
+print(f"Jax platform: {jax.lib.xla_bridge.get_backend().platform}")
+print(f"================================================")
+import numpy as np
+import pyronot as prn
+import pyroki as pk
+
+from pyroki.collision import Sphere, RobotCollision
+from pyronot.collision import Sphere as SphereNot, RobotCollisionSpherized as RobotCollisionSpherizedNot
+import yourdfpy
+import pinocchio as pin
+import hppfcl
+import time
+
+np.random.seed(41)
+
+NUM_SAMPLES = 1000
+
+SPHERE_CENTERS = [
+ [0.55, 0, 0.25],
+ [0.35, 0.35, 0.25],
+ [0, 0.55, 0.25],
+ [-0.55, 0, 0.25],
+ [-0.35, -0.35, 0.25],
+ [0, -0.55, 0.25],
+ [0.35, -0.35, 0.25],
+ [0.35, 0.35, 0.8],
+ [0, 0.55, 0.8],
+ [-0.35, 0.35, 0.8],
+ [-0.55, 0, 0.8],
+ [-0.35, -0.35, 0.8],
+ [0, -0.55, 0.8],
+ [0.35, -0.35, 0.8],
+ ]
+
+SPHERE_R = [0.2] * len(SPHERE_CENTERS)
+
+urdf_path = "resources/ur5/ur5_spherized.urdf"
+mesh_dir = "resources/ur5/meshes"
+
+
+
+# ============ Pinocchio Ground Truth Setup ============
+def setup_pinocchio_collision(urdf_path, mesh_dir, sphere_centers, sphere_radii):
+ """Setup Pinocchio collision model with environment spheres."""
+ pin_model = pin.buildModelFromUrdf(urdf_path)
+ pin_geom_model = pin.buildGeomFromUrdf(pin_model, urdf_path, pin.COLLISION, mesh_dir)
+
+ num_robot_geoms = len(pin_geom_model.geometryObjects)
+
+ # Add obstacle spheres to the geometry model
+ for i, (center, radius) in enumerate(zip(sphere_centers, sphere_radii)):
+ sphere_shape = hppfcl.Sphere(radius)
+ placement = pin.SE3(np.eye(3), np.array(center))
+ geom_obj = pin.GeometryObject(
+ f"obstacle_sphere_{i}",
+ 0, # parent frame (universe)
+ 0, # parent joint (universe)
+ sphere_shape,
+ placement,
+ )
+ pin_geom_model.addGeometryObject(geom_obj)
+
+ # Add collision pairs: robot geometries vs obstacle spheres
+ num_obstacles = len(sphere_centers)
+ for robot_geom_id in range(num_robot_geoms):
+ for obs_idx in range(num_obstacles):
+ obs_geom_id = num_robot_geoms + obs_idx
+ pin_geom_model.addCollisionPair(
+ pin.CollisionPair(robot_geom_id, obs_geom_id)
+ )
+
+ pin_data = pin_model.createData()
+ pin_geom_data = pin.GeometryData(pin_geom_model)
+
+ return pin_model, pin_data, pin_geom_model, pin_geom_data
+
+
+def check_collision_pinocchio(pin_model, pin_data, pin_geom_model, pin_geom_data, q):
+ """Check if configuration q is in collision using Pinocchio (ground truth)."""
+ pin.updateGeometryPlacements(pin_model, pin_data, pin_geom_model, pin_geom_data, q)
+ return pin.computeCollisions(pin_geom_model, pin_geom_data, True)
+
+
+# Setup Pinocchio collision checker
+pin_model, pin_data, pin_geom_model, pin_geom_data = setup_pinocchio_collision(
+ urdf_path, mesh_dir, SPHERE_CENTERS, SPHERE_R
+)
+
+world_coll_not = SphereNot.from_center_and_radius(SPHERE_CENTERS, SPHERE_R)
+world_coll = Sphere.from_center_and_radius(SPHERE_CENTERS, SPHERE_R)
+
+urdf = yourdfpy.URDF.load(urdf_path, mesh_dir=mesh_dir)
+robot_not = prn.Robot.from_urdf(urdf)
+robot = pk.Robot.from_urdf(urdf)
+robot_coll = RobotCollision.from_urdf(urdf)
+robot_coll_not = RobotCollisionSpherizedNot.from_urdf(urdf)
+
+def generate_dataset(num_samples):
+ q_batch = []
+ robot_lower_limits = robot.joints.lower_limits
+ robot_upper_limits = robot.joints.upper_limits
+ for _ in range(num_samples):
+ q = np.random.uniform(robot_lower_limits, robot_upper_limits)
+ q_batch.append(q)
+ return np.array(q_batch)
+
+q_batch = generate_dataset(NUM_SAMPLES)
+print(f"Generated {q_batch.shape[0]} samples")
+
+# ============ Generate Ground Truth with Pinocchio ============
+print("\n=== Generating Pinocchio Ground Truth ===")
+start_time = time.time()
+ground_truth = []
+for q in q_batch:
+ collision = check_collision_pinocchio(pin_model, pin_data, pin_geom_model, pin_geom_data, q)
+ ground_truth.append(collision)
+ground_truth = np.array(ground_truth)
+end_time = time.time()
+time_taken_ms = (end_time - start_time) * 1000
+print(f"Time taken for Pinocchio ground truth (ms): {time_taken_ms:.2f}")
+print(f"Time per collision check (ms): {time_taken_ms/NUM_SAMPLES:.4f}")
+print(f"Collision rate: {ground_truth.sum()}/{NUM_SAMPLES} ({100*ground_truth.mean():.1f}%)")
+
+# ============ Benchmark pyronot collision methods ============
+print("\n=== Benchmarking pyronot Collision Methods ===")
+# Warmup for JIT
+q = q_batch[0]
+robot_coll.at_config(robot, q)
+jax.block_until_ready(robot_coll.compute_world_collision_distance(robot, q, world_coll))
+print(f"Type of robot_coll.compute_world_collision_distance: {type(robot_coll.compute_world_collision_distance)}")
+try:
+ print(f"Capsule cache: {robot_coll.compute_world_collision_distance._cache_size()}")
+except AttributeError:
+ print("Capsule method is NOT JIT compiled")
+
+robot_coll_not.at_config(robot_not, q)
+jax.block_until_ready(robot_coll_not.compute_world_collision_distance(robot_not, q, world_coll_not))
+print(f"Type of robot_coll_not.compute_world_collision_distance: {type(robot_coll_not.compute_world_collision_distance)}")
+try:
+ print(f"Sphere cache: {robot_coll_not.compute_world_collision_distance._cache_size()}")
+except AttributeError:
+ print("Sphere method is NOT JIT compiled")
+# End of warmup
+
+with jax.profiler.trace("./tmp/sphere_trace"):
+ q = q_batch[0]
+ result = robot_coll_not.compute_world_collision_distance(robot_not, q, world_coll_not)
+ jax.block_until_ready(result)
+
+print(f"=== Result ===")
+print(f"Time taken for pinocchio for a single collision check (ms): \t{time_taken_ms/NUM_SAMPLES:.4f}")
+start_time = time.time()
+for q in q_batch:
+ # robot_coll_not.at_config(robot_not, q)
+ robot_coll_not.compute_world_collision_distance(robot_not, q, world_coll_not)
+end_time = time.time()
+time_taken_ms = (end_time - start_time) * 1000
+print(f"Time taken for sphere for single collision check (ms): \t\t{time_taken_ms/NUM_SAMPLES:.4f}")
+
+start_time = time.time()
+for q in q_batch:
+ # robot_coll.at_config(robot, q)
+ robot_coll.compute_world_collision_distance(robot, q, world_coll)
+end_time = time.time()
+time_taken_ms = (end_time - start_time) * 1000
+print(f"Time taken for capsule for single collision check (ms): \t{time_taken_ms/NUM_SAMPLES:.4f}")
+
+# ============ Compare with Ground Truth ============
+print("\n=== Accuracy Comparison with Ground Truth ===")
+
+# Check sphere model accuracy
+sphere_predictions = []
+for q in q_batch:
+ dist = robot_coll_not.compute_world_collision_distance(robot_not, q, world_coll_not)
+ # Collision if any distance is negative
+ in_collision = (np.array(dist) < 0).any()
+ sphere_predictions.append(in_collision)
+sphere_predictions = np.array(sphere_predictions)
+sphere_accuracy = (sphere_predictions == ground_truth).mean()
+print(f"Sphere model accuracy: {100*sphere_accuracy:.2f}%")
+
+# Check capsule model accuracy
+capsule_predictions = []
+for q in q_batch:
+ dist = robot_coll.compute_world_collision_distance(robot, q, world_coll)
+ # Collision if any distance is negative
+ in_collision = (np.array(dist) < 0).any()
+ capsule_predictions.append(in_collision)
+capsule_predictions = np.array(capsule_predictions)
+capsule_accuracy = (capsule_predictions == ground_truth).mean()
+print(f"Capsule model accuracy: {100*capsule_accuracy:.2f}%")
+
diff --git a/docs/Makefile b/docs/Makefile
deleted file mode 100644
index aaecadc..0000000
--- a/docs/Makefile
+++ /dev/null
@@ -1,20 +0,0 @@
-# Minimal makefile for Sphinx documentation
-#
-
-# You can set these variables from the command line.
-SPHINXOPTS =
-SPHINXBUILD = sphinx-build
-SPHINXPROJ = viser
-SOURCEDIR = source
-BUILDDIR = ./build
-
-# Put it first so that "make" without argument is like "make help".
-help:
- @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
-
-.PHONY: help Makefile
-
-# Catch-all target: route all unknown targets to Sphinx using the new
-# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
-%: Makefile
- @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/docs/README.md b/docs/README.md
deleted file mode 100644
index f039c3e..0000000
--- a/docs/README.md
+++ /dev/null
@@ -1,61 +0,0 @@
-# Viser Documentation
-
-This directory contains the documentation for Viser.
-
-## Building the Documentation
-
-To build the documentation:
-
-1. Install the documentation dependencies:
-
- ```bash
- pip install -r docs/requirements.txt
- ```
-
-2. Build the documentation:
-
- ```bash
- cd docs
- make html
- ```
-
-3. View the documentation:
-
- ```bash
- # On macOS
- open build/html/index.html
-
- # On Linux
- xdg-open build/html/index.html
- ```
-
-## Contributing Screenshots
-
-When adding new documentation, screenshots and visual examples significantly improve user understanding.
-
-We need screenshots for:
-
-- The Getting Started guide
-- GUI element examples
-- Scene API visualization examples
-- Customization/theming examples
-
-See [Contributing Visuals](./source/contributing_visuals.md) for guidelines on capturing and adding images to the documentation.
-
-## Documentation Structure
-
-- `source/` - Source files for the documentation
- - `_static/` - Static files (CSS, images, etc.)
- - `images/` - Screenshots and other images
- - `examples/` - Example code with documentation
- - `*.md` - Markdown files for documentation pages
- - `conf.py` - Sphinx configuration
-
-## Auto-Generated Example Documentation
-
-Example documentation is automatically generated from the examples in the `examples/` directory using the `update_example_docs.py` script. To update the example documentation after making changes to examples:
-
-```bash
-cd docs
-python update_example_docs.py
-```
diff --git a/docs/requirements.txt b/docs/requirements.txt
deleted file mode 100644
index 74185c7..0000000
--- a/docs/requirements.txt
+++ /dev/null
@@ -1,10 +0,0 @@
-sphinx==8.0.2
-furo==2024.8.6
-docutils==0.20.1
-toml==0.10.2
-sphinxcontrib-video==0.4.1
-git+https://github.com/brentyi/sphinxcontrib-programoutput.git
-git+https://github.com/brentyi/ansi.git
-git+https://github.com/sphinx-contrib/googleanalytics.git
-
-snowballstemmer==2.2.0 # https://github.com/snowballstem/snowball/issues/229
diff --git a/docs/source/_static/basic_ik.mov b/docs/source/_static/basic_ik.mov
deleted file mode 100644
index 44c3bb0..0000000
Binary files a/docs/source/_static/basic_ik.mov and /dev/null differ
diff --git a/docs/source/_static/css/custom.css b/docs/source/_static/css/custom.css
deleted file mode 100644
index 3472b02..0000000
--- a/docs/source/_static/css/custom.css
+++ /dev/null
@@ -1,4 +0,0 @@
-img.sidebar-logo {
- width: 10em;
- margin: 1em 0 0 0;
-}
diff --git a/docs/source/_static/logo.svg b/docs/source/_static/logo.svg
deleted file mode 100644
index 68c28ee..0000000
--- a/docs/source/_static/logo.svg
+++ /dev/null
@@ -1,34 +0,0 @@
-
-
\ No newline at end of file
diff --git a/docs/source/_static/logo_dark.svg b/docs/source/_static/logo_dark.svg
deleted file mode 100644
index 2d42e27..0000000
--- a/docs/source/_static/logo_dark.svg
+++ /dev/null
@@ -1,40 +0,0 @@
-
-
\ No newline at end of file
diff --git a/docs/source/_templates/sidebar/brand.html b/docs/source/_templates/sidebar/brand.html
deleted file mode 100644
index 88151c0..0000000
--- a/docs/source/_templates/sidebar/brand.html
+++ /dev/null
@@ -1,41 +0,0 @@
-
- {% block brand_content %} {%- if logo_url %}
-
- {%- endif %} {%- if theme_light_logo and theme_dark_logo %}
-
- {%- endif %}
-
- {% endblock brand_content %}
-
-
-
\ No newline at end of file
diff --git a/docs/source/conf.py b/docs/source/conf.py
deleted file mode 100644
index a76de8f..0000000
--- a/docs/source/conf.py
+++ /dev/null
@@ -1,266 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Configuration file for the Sphinx documentation builder.
-#
-# This file does only contain a selection of the most common options. For a
-# full list see the documentation:
-# http://www.sphinx-doc.org/en/stable/config
-
-import os
-from typing import Dict, List
-
-import pyronot
-
-# -- Path setup --------------------------------------------------------------
-
-# If extensions (or modules to document with autodoc) are in another directory,
-# add these directories to sys.path here. If the directory is relative to the
-# documentation root, use os.path.abspath to make it absolute, like shown here.
-#
-
-# -- Project information -----------------------------------------------------
-
-project = "pyronot" # Change project name
-copyright = "2025" # Update copyright year/holder if needed
-author = "chungmin99" # Update author name
-
-version: str = os.environ.get(
- "PYRONOT_VERSION_STR_OVERRIDE", pyronot.__version__
-) # Remove this
-
-# Formatting!
-# 0.1.30 => v0.1.30
-# dev => dev
-if not version.isalpha():
- version = "v" + version
-
-# The full version, including alpha/beta/rc tags
-release = version # Use the same version for release for now
-
-
-# -- General configuration ---------------------------------------------------
-
-# If your documentation needs a minimal Sphinx version, state it here.
-#
-# needs_sphinx = '1.0'
-
-# Add any Sphinx extension module names here, as strings. They can be
-# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
-# ones.
-extensions = [
- "sphinx.ext.autodoc",
- "sphinx.ext.todo",
- "sphinx.ext.coverage",
- "sphinx.ext.mathjax",
- "sphinx.ext.githubpages",
- "sphinx.ext.napoleon",
- # "sphinx.ext.inheritance_diagram",
- "sphinxcontrib.video",
- "sphinx.ext.viewcode",
- "sphinxcontrib.programoutput",
- "sphinxcontrib.ansi",
- # "sphinxcontrib.googleanalytics", # google analytics extension https://github.com/sphinx-contrib/googleanalytics/tree/master
-]
-programoutput_use_ansi = True
-html_ansi_stylesheet = "black-on-white.css"
-html_static_path = ["_static"]
-html_theme_options = {
- "light_css_variables": {
- "color-code-background": "#f4f4f4",
- "color-code-foreground": "#000",
- },
- # Remove viser-specific footer icon
- "footer_icons": [
- {
- "name": "GitHub",
- "url": "https://github.com/chungmin99/pyronot-dev",
- "html": """
-
- """,
- "class": "",
- },
- ],
- # Remove viser-specific logos
- "light_logo": "logo.svg",
- "dark_logo": "logo_dark.svg",
-}
-
-# Pull documentation types from hints
-autodoc_typehints = "both"
-autodoc_class_signature = "separated"
-autodoc_default_options = {
- "members": True,
- "member-order": "bysource",
- "undoc-members": True,
- "inherited-members": True,
- "exclude-members": "__init__, __post_init__",
- "imported-members": True,
-}
-
-# Add any paths that contain templates here, relative to this directory.
-templates_path = ["_templates"]
-
-# The suffix(es) of source filenames.
-# You can specify multiple suffix as a list of string:
-#
-source_suffix = ".rst"
-# source_suffix = ".rst"
-
-# The master toctree document.
-master_doc = "index"
-
-# The language for content autogenerated by Sphinx. Refer to documentation
-# for a list of supported languages.
-#
-# This is also used if you do content translation via gettext catalogs.
-# Usually you set "language" from the command line for these cases.
-language: str = "en"
-
-# List of patterns, relative to source directory, that match files and
-# directories to ignore when looking for source files.
-# This pattern also affects html_static_path and html_extra_path .
-exclude_patterns: List[str] = []
-
-# The name of the Pygments (syntax highlighting) style to use.
-pygments_style = "default"
-
-
-# -- Options for HTML output -------------------------------------------------
-
-# The theme to use for HTML and HTML Help pages. See the documentation for
-# a list of builtin themes.
-#
-html_theme = "furo"
-html_title = "pyronot" # Update title
-
-
-# Theme options are theme-specific and customize the look and feel of a theme
-# further. For a list of options available for each theme, see the
-# documentation.
-#
-# html_theme_options = {}
-
-# Add any paths that contain custom static files (such as style sheets) here,
-# relative to this directory. They are copied after the builtin static files,
-# so a file named "default.css" will overwrite the builtin "default.css".
-# html_static_path = ["_static"]
-
-# Custom sidebar templates, must be a dictionary that maps document names
-# to template names.
-#
-# The default sidebars (for documents that don't match any pattern) are
-# defined by theme itself. Builtin themes are using these templates by
-# default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
-# 'searchbox.html']``.
-#
-# html_sidebars = {}
-
-
-# -- Options for HTMLHelp output ---------------------------------------------
-
-# Output file base name for HTML help builder.
-htmlhelp_basename = "pyronot_doc" # Update basename
-
-
-# -- Options for Github output ------------------------------------------------
-
-sphinx_to_github = True
-sphinx_to_github_verbose = True
-sphinx_to_github_encoding = "utf-8"
-
-
-# -- Options for LaTeX output ------------------------------------------------
-
-latex_elements: Dict[str, str] = {
- # The paper size ('letterpaper' or 'a4paper').
- #
- # 'papersize': 'letterpaper',
- # The font size ('10pt', '11pt' or '12pt').
- #
- # 'pointsize': '10pt',
- # Additional stuff for the LaTeX preamble.
- #
- # 'preamble': '',
- # Latex figure (float) alignment
- #
- # 'figure_align': 'htbp',
-}
-
-# Grouping the document tree into LaTeX files. List of tuples
-# (source start file, target name, title,
-# author, documentclass [howto, manual, or own class]).
-latex_documents = [
- (
- master_doc,
- "pyronot.tex", # Update target name
- "pyronot", # Update title
- "Your Name", # Update author
- "manual",
- ),
-]
-
-
-# -- Options for manual page output ------------------------------------------
-
-# One entry per manual page. List of tuples
-# (source start file, name, description, authors, manual section).
-man_pages = [
- (master_doc, "pyronot", "pyronot documentation", [author], 1)
-] # Update name and description
-
-
-# -- Options for Texinfo output ----------------------------------------------
-
-# Grouping the document tree into Texinfo files. List of tuples
-# (source start file, target name, title, author,
-# dir menu entry, description, category)
-texinfo_documents = [
- (
- master_doc,
- "pyronot", # Update target name
- "pyronot", # Update title
- author,
- "pyronot", # Update dir menu entry
- "Python Robot Kinematics library", # Update description
- "Miscellaneous",
- ),
-]
-
-
-# -- Extension configuration --------------------------------------------------
-
-# Google Analytics ID
-# googleanalytics_id = "G-RRGY51J5ZH" # Remove this
-
-# -- Options for todo extension ----------------------------------------------
-
-# If true, `todo` and `todoList` produce output, else they produce nothing.
-todo_include_todos = True
-
-# -- Setup function ----------------------------------------
-
-
-def setup(app):
- app.add_css_file("css/custom.css")
-
-
-# -- Napoleon settings -------------------------------------------------------
-
-# Settings for parsing non-sphinx style docstrings. We use Google style in this
-# project.
-napoleon_google_docstring = True
-napoleon_numpy_docstring = False
-napoleon_include_init_with_doc = False
-napoleon_include_private_with_doc = False
-napoleon_include_special_with_doc = True
-napoleon_use_admonition_for_examples = False
-napoleon_use_admonition_for_notes = False
-napoleon_use_admonition_for_references = False
-napoleon_use_ivar = False
-napoleon_use_param = True
-napoleon_use_rtype = True
-napoleon_preprocess_types = True
-napoleon_type_aliases = None
-napoleon_attr_annotations = True
diff --git a/docs/source/examples/01_basic_ik.rst b/docs/source/examples/01_basic_ik.rst
deleted file mode 100644
index 660f6c9..0000000
--- a/docs/source/examples/01_basic_ik.rst
+++ /dev/null
@@ -1,68 +0,0 @@
-.. Comment: this file is automatically generated by `update_example_docs.py`.
- It should not be modified manually.
-
-Basic IK
-==========================================
-
-
-Simplest Inverse Kinematics Example using PyRoNot.
-
-All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details.
-
-
-
-.. code-block:: python
- :linenos:
-
-
- import time
-
- import numpy as np
- import pyronot as pk
- import viser
- from robot_descriptions.loaders.yourdfpy import load_robot_description
- from viser.extras import ViserUrdf
-
- import pyronot_snippets as pks
-
-
- def main():
- """Main function for basic IK."""
-
- urdf = load_robot_description("panda_description")
- target_link_name = "panda_hand"
-
- # Create robot.
- robot = pk.Robot.from_urdf(urdf)
-
- # Set up visualizer.
- server = viser.ViserServer()
- server.scene.add_grid("/ground", width=2, height=2)
- urdf_vis = ViserUrdf(server, urdf, root_node_name="/base")
-
- # Create interactive controller with initial position.
- ik_target = server.scene.add_transform_controls(
- "/ik_target", scale=0.2, position=(0.61, 0.0, 0.56), wxyz=(0, 0, 1, 0)
- )
- timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True)
-
- while True:
- # Solve IK.
- start_time = time.time()
- solution = pks.solve_ik(
- robot=robot,
- target_link_name=target_link_name,
- target_position=np.array(ik_target.position),
- target_wxyz=np.array(ik_target.wxyz),
- )
-
- # Update timing handle.
- elapsed_time = time.time() - start_time
- timing_handle.value = 0.99 * timing_handle.value + 0.01 * (elapsed_time * 1000)
-
- # Update visualizer.
- urdf_vis.update_cfg(solution)
-
-
- if __name__ == "__main__":
- main()
diff --git a/docs/source/examples/02_bimanual_ik.rst b/docs/source/examples/02_bimanual_ik.rst
deleted file mode 100644
index 9bdc034..0000000
--- a/docs/source/examples/02_bimanual_ik.rst
+++ /dev/null
@@ -1,70 +0,0 @@
-.. Comment: this file is automatically generated by `update_example_docs.py`.
- It should not be modified manually.
-
-Bimanual IK
-==========================================
-
-
-Same as 01_basic_ik.py, but with two end effectors!
-
-All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details.
-
-
-
-.. code-block:: python
- :linenos:
-
-
- import time
- import viser
- from robot_descriptions.loaders.yourdfpy import load_robot_description
- import numpy as np
-
- import pyronot as pk
- from viser.extras import ViserUrdf
- import pyronot_snippets as pks
-
-
- def main():
- """Main function for bimanual IK."""
-
- urdf = load_robot_description("yumi_description")
- target_link_names = ["yumi_link_7_r", "yumi_link_7_l"]
-
- # Create robot.
- robot = pk.Robot.from_urdf(urdf)
-
- # Set up visualizer.
- server = viser.ViserServer()
- server.scene.add_grid("/ground", width=2, height=2)
- urdf_vis = ViserUrdf(server, urdf, root_node_name="/base")
-
- # Create interactive controller with initial position.
- ik_target_0 = server.scene.add_transform_controls(
- "/ik_target_0", scale=0.2, position=(0.41, -0.3, 0.56), wxyz=(0, 0, 1, 0)
- )
- ik_target_1 = server.scene.add_transform_controls(
- "/ik_target_1", scale=0.2, position=(0.41, 0.3, 0.56), wxyz=(0, 0, 1, 0)
- )
- timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True)
-
- while True:
- # Solve IK.
- start_time = time.time()
- solution = pks.solve_ik_with_multiple_targets(
- robot=robot,
- target_link_names=target_link_names,
- target_positions=np.array([ik_target_0.position, ik_target_1.position]),
- target_wxyzs=np.array([ik_target_0.wxyz, ik_target_1.wxyz]),
- )
-
- # Update timing handle.
- elapsed_time = time.time() - start_time
- timing_handle.value = 0.99 * timing_handle.value + 0.01 * (elapsed_time * 1000)
-
- # Update visualizer.
- urdf_vis.update_cfg(solution)
-
-
- if __name__ == "__main__":
- main()
diff --git a/docs/source/examples/03_mobile_ik.rst b/docs/source/examples/03_mobile_ik.rst
deleted file mode 100644
index 128bda2..0000000
--- a/docs/source/examples/03_mobile_ik.rst
+++ /dev/null
@@ -1,79 +0,0 @@
-.. Comment: this file is automatically generated by `update_example_docs.py`.
- It should not be modified manually.
-
-Mobile IK
-==========================================
-
-
-Same as 01_basic_ik.py, but with a mobile base!
-
-All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details.
-
-
-
-.. code-block:: python
- :linenos:
-
-
- import time
- import viser
- from robot_descriptions.loaders.yourdfpy import load_robot_description
- import numpy as np
-
- import pyronot as pk
- from viser.extras import ViserUrdf
- import pyronot_snippets as pks
-
-
- def main():
- """Main function for IK with a mobile base.
- The base is fixed along the xy plane, and is biased towards being at the origin.
- """
-
- urdf = load_robot_description("fetch_description")
- target_link_name = "gripper_link"
-
- # Create robot.
- robot = pk.Robot.from_urdf(urdf)
-
- # Set up visualizer.
- server = viser.ViserServer()
- server.scene.add_grid("/ground", width=2, height=2)
- base_frame = server.scene.add_frame("/base", show_axes=False)
- urdf_vis = ViserUrdf(server, urdf, root_node_name="/base")
-
- # Create interactive controller with initial position.
- ik_target = server.scene.add_transform_controls(
- "/ik_target", scale=0.2, position=(0.61, 0.0, 0.56), wxyz=(0, 0.707, 0, -0.707)
- )
- timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True)
-
- cfg = np.array(robot.joint_var_cls(0).default_factory())
-
- while True:
- # Solve IK.
- start_time = time.time()
- base_pos, base_wxyz, cfg = pks.solve_ik_with_base(
- robot=robot,
- target_link_name=target_link_name,
- target_position=np.array(ik_target.position),
- target_wxyz=np.array(ik_target.wxyz),
- fix_base_position=(False, False, True), # Only free along xy plane.
- fix_base_orientation=(True, True, False), # Free along z-axis rotation.
- prev_pos=base_frame.position,
- prev_wxyz=base_frame.wxyz,
- prev_cfg=cfg,
- )
-
- # Update timing handle.
- elapsed_time = time.time() - start_time
- timing_handle.value = 0.99 * timing_handle.value + 0.01 * (elapsed_time * 1000)
-
- # Update visualizer.
- urdf_vis.update_cfg(cfg)
- base_frame.position = np.array(base_pos)
- base_frame.wxyz = np.array(base_wxyz)
-
-
- if __name__ == "__main__":
- main()
diff --git a/docs/source/examples/04_ik_with_coll.rst b/docs/source/examples/04_ik_with_coll.rst
deleted file mode 100644
index f879355..0000000
--- a/docs/source/examples/04_ik_with_coll.rst
+++ /dev/null
@@ -1,88 +0,0 @@
-.. Comment: this file is automatically generated by `update_example_docs.py`.
- It should not be modified manually.
-
-IK with Collision
-==========================================
-
-
-Basic Inverse Kinematics with Collision Avoidance using PyRoNot.
-
-All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details.
-
-
-
-.. code-block:: python
- :linenos:
-
-
- import time
-
- import numpy as np
- import pyronot as pk
- import viser
- from pyronot.collision import HalfSpace, RobotCollision, Sphere
- from robot_descriptions.loaders.yourdfpy import load_robot_description
- from viser.extras import ViserUrdf
-
- import pyronot_snippets as pks
-
-
- def main():
- """Main function for basic IK with collision."""
- urdf = load_robot_description("panda_description")
- target_link_name = "panda_hand"
- robot = pk.Robot.from_urdf(urdf)
-
- robot_coll = RobotCollision.from_urdf(urdf)
- plane_coll = HalfSpace.from_point_and_normal(
- np.array([0.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0])
- )
- sphere_coll = Sphere.from_center_and_radius(
- np.array([0.0, 0.0, 0.0]), np.array([0.05])
- )
-
- # Set up visualizer.
- server = viser.ViserServer()
- server.scene.add_grid("/ground", width=2, height=2, cell_size=0.1)
- urdf_vis = ViserUrdf(server, urdf, root_node_name="/robot")
-
- # Create interactive controller for IK target.
- ik_target_handle = server.scene.add_transform_controls(
- "/ik_target", scale=0.2, position=(0.5, 0.0, 0.5), wxyz=(0, 0, 1, 0)
- )
-
- # Create interactive controller and mesh for the sphere obstacle.
- sphere_handle = server.scene.add_transform_controls(
- "/obstacle", scale=0.2, position=(0.4, 0.3, 0.4)
- )
- server.scene.add_mesh_trimesh("/obstacle/mesh", mesh=sphere_coll.to_trimesh())
-
- timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True)
-
- while True:
- start_time = time.time()
-
- sphere_coll_world_current = sphere_coll.transform_from_wxyz_position(
- wxyz=np.array(sphere_handle.wxyz),
- position=np.array(sphere_handle.position),
- )
-
- world_coll_list = [plane_coll, sphere_coll_world_current]
- solution = pks.solve_ik_with_collision(
- robot=robot,
- coll=robot_coll,
- world_coll_list=world_coll_list,
- target_link_name=target_link_name,
- target_position=np.array(ik_target_handle.position),
- target_wxyz=np.array(ik_target_handle.wxyz),
- )
-
- # Update timing handle.
- timing_handle.value = (time.time() - start_time) * 1000
-
- # Update visualizer.
- urdf_vis.update_cfg(solution)
-
-
- if __name__ == "__main__":
- main()
diff --git a/docs/source/examples/05_ik_with_manipulability.rst b/docs/source/examples/05_ik_with_manipulability.rst
deleted file mode 100644
index f8bd331..0000000
--- a/docs/source/examples/05_ik_with_manipulability.rst
+++ /dev/null
@@ -1,81 +0,0 @@
-.. Comment: this file is automatically generated by `update_example_docs.py`.
- It should not be modified manually.
-
-IK with Manipulability
-==========================================
-
-
-Inverse Kinematics with Manipulability using PyRoNot.
-
-All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details.
-
-
-
-.. code-block:: python
- :linenos:
-
-
- import time
- import viser
- from robot_descriptions.loaders.yourdfpy import load_robot_description
- import numpy as np
-
- import pyronot as pk
- from viser.extras import ViserUrdf
- import pyronot_snippets as pks
-
-
- def main():
- """Main function for basic IK."""
-
- urdf = load_robot_description("panda_description")
- target_link_name = "panda_hand"
-
- # Create robot.
- robot = pk.Robot.from_urdf(urdf)
-
- # Set up visualizer.
- server = viser.ViserServer()
- server.scene.add_grid("/ground", width=2, height=2)
- urdf_vis = ViserUrdf(server, urdf, root_node_name="/base")
-
- # Create interactive controller with initial position.
- ik_target = server.scene.add_transform_controls(
- "/ik_target", scale=0.2, position=(0.61, 0.0, 0.56), wxyz=(0, 0, 1, 0)
- )
- timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True)
- value_handle = server.gui.add_number("Yoshikawa Index", 0.001, disabled=True)
- weight_handle = server.gui.add_slider(
- "Manipulability Weight", 0.0, 10.0, 0.001, 0.0
- )
- manip_ellipse = pk.viewer.ManipulabilityEllipse(
- server,
- robot,
- root_node_name="/manipulability",
- target_link_name=target_link_name,
- )
-
- while True:
- # Solve IK.
- start_time = time.time()
- solution = pks.solve_ik_with_manipulability(
- robot=robot,
- target_link_name=target_link_name,
- target_position=np.array(ik_target.position),
- target_wxyz=np.array(ik_target.wxyz),
- manipulability_weight=weight_handle.value,
- )
-
- manip_ellipse.update(solution)
- value_handle.value = manip_ellipse.manipulability
-
- # Update timing handle.
- elapsed_time = time.time() - start_time
- timing_handle.value = 0.99 * timing_handle.value + 0.01 * (elapsed_time * 1000)
-
- # Update visualizer.
- urdf_vis.update_cfg(solution)
-
-
- if __name__ == "__main__":
- main()
diff --git a/docs/source/examples/06_online_planning.rst b/docs/source/examples/06_online_planning.rst
deleted file mode 100644
index 9177048..0000000
--- a/docs/source/examples/06_online_planning.rst
+++ /dev/null
@@ -1,119 +0,0 @@
-.. Comment: this file is automatically generated by `update_example_docs.py`.
- It should not be modified manually.
-
-Online Planning
-==========================================
-
-
-Run online planning in collision aware environments.
-
-All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details.
-
-
-
-.. code-block:: python
- :linenos:
-
-
- import time
-
- import numpy as np
- import pyronot as pk
- import viser
- from pyronot.collision import HalfSpace, RobotCollision, Sphere
- from robot_descriptions.loaders.yourdfpy import load_robot_description
- from viser.extras import ViserUrdf
-
- import pyronot_snippets as pks
-
-
- def main():
- """Main function for online planning with collision."""
- urdf = load_robot_description("panda_description")
- target_link_name = "panda_hand"
- robot = pk.Robot.from_urdf(urdf)
-
- robot_coll = RobotCollision.from_urdf(urdf)
- plane_coll = HalfSpace.from_point_and_normal(
- np.array([0.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0])
- )
- sphere_coll = Sphere.from_center_and_radius(
- np.array([0.0, 0.0, 0.0]), np.array([0.05])
- )
-
- # Define the online planning parameters.
- len_traj, dt = 5, 0.1
-
- # Set up visualizer.
- server = viser.ViserServer()
- server.scene.add_grid("/ground", width=2, height=2, cell_size=0.1)
- urdf_vis = ViserUrdf(server, urdf, root_node_name="/robot")
-
- # Create interactive controller for IK target.
- ik_target_handle = server.scene.add_transform_controls(
- "/ik_target", scale=0.2, position=(0.3, 0.0, 0.5), wxyz=(0, 0, 1, 0)
- )
-
- # Create interactive controller and mesh for the sphere obstacle.
- sphere_handle = server.scene.add_transform_controls(
- "/obstacle", scale=0.2, position=(0.4, 0.3, 0.4)
- )
- server.scene.add_mesh_trimesh("/obstacle/mesh", mesh=sphere_coll.to_trimesh())
- target_frame_handle = server.scene.add_batched_axes(
- "target_frame",
- axes_length=0.05,
- axes_radius=0.005,
- batched_positions=np.zeros((25, 3)),
- batched_wxyzs=np.array([[1.0, 0.0, 0.0, 0.0]] * 25),
- )
-
- timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True)
-
- sol_pos, sol_wxyz = None, None
- sol_traj = np.array(
- robot.joint_var_cls.default_factory()[None].repeat(len_traj, axis=0)
- )
- while True:
- start_time = time.time()
-
- sphere_coll_world_current = sphere_coll.transform_from_wxyz_position(
- wxyz=np.array(sphere_handle.wxyz),
- position=np.array(sphere_handle.position),
- )
-
- world_coll_list = [plane_coll, sphere_coll_world_current]
- sol_traj, sol_pos, sol_wxyz = pks.solve_online_planning(
- robot=robot,
- robot_coll=robot_coll,
- world_coll=world_coll_list,
- target_link_name=target_link_name,
- target_position=np.array(ik_target_handle.position),
- target_wxyz=np.array(ik_target_handle.wxyz),
- timesteps=len_traj,
- dt=dt,
- start_cfg=sol_traj[0],
- prev_sols=sol_traj,
- )
-
- # Update timing handle.
- timing_handle.value = (
- 0.99 * timing_handle.value + 0.01 * (time.time() - start_time) * 1000
- )
-
- # Update visualizer.
- urdf_vis.update_cfg(
- sol_traj[0]
- ) # The first step of the online trajectory solution.
-
- # Update the planned trajectory visualization.
- if hasattr(target_frame_handle, "batched_positions"):
- target_frame_handle.batched_positions = np.array(sol_pos) # type: ignore[attr-defined]
- target_frame_handle.batched_wxyzs = np.array(sol_wxyz) # type: ignore[attr-defined]
- else:
- # This is an older version of Viser.
- target_frame_handle.positions_batched = np.array(sol_pos) # type: ignore[attr-defined]
- target_frame_handle.wxyzs_batched = np.array(sol_wxyz) # type: ignore[attr-defined]
-
-
- if __name__ == "__main__":
- main()
diff --git a/docs/source/examples/07_trajopt.rst b/docs/source/examples/07_trajopt.rst
deleted file mode 100644
index 99e4348..0000000
--- a/docs/source/examples/07_trajopt.rst
+++ /dev/null
@@ -1,138 +0,0 @@
-.. Comment: this file is automatically generated by `update_example_docs.py`.
- It should not be modified manually.
-
-Trajectory Optimization
-==========================================
-
-
-Basic Trajectory Optimization using PyRoNot.
-
-Robot going over a wall, while avoiding world-collisions.
-
-All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details.
-
-
-
-.. code-block:: python
- :linenos:
-
-
- import time
- from typing import Literal
-
- import numpy as np
- import pyronot as pk
- import trimesh
- import tyro
- import viser
- from viser.extras import ViserUrdf
- from robot_descriptions.loaders.yourdfpy import load_robot_description
-
- import pyronot_snippets as pks
-
-
- def main(robot_name: Literal["ur5", "panda"] = "panda"):
- if robot_name == "ur5":
- urdf = load_robot_description("ur5_description")
- down_wxyz = np.array([0.707, 0, 0.707, 0])
- target_link_name = "ee_link"
-
- # For UR5 it's important to initialize the robot in a safe configuration;
- # the zero-configuration puts the robot aligned with the wall obstacle.
- default_cfg = np.zeros(6)
- default_cfg[1] = -1.308
- robot = pk.Robot.from_urdf(urdf, default_joint_cfg=default_cfg)
-
- elif robot_name == "panda":
- urdf = load_robot_description("panda_description")
- target_link_name = "panda_hand"
- down_wxyz = np.array([0, 0, 1, 0]) # for panda!
- robot = pk.Robot.from_urdf(urdf)
-
- else:
- raise ValueError(f"Invalid robot: {robot_name}")
-
- robot_coll = pk.collision.RobotCollision.from_urdf(urdf)
-
- # Define the trajectory problem:
- # - number of timesteps, timestep size
- timesteps, dt = 25, 0.02
- # - the start and end poses.
- start_pos, end_pos = np.array([0.5, -0.3, 0.2]), np.array([0.5, 0.3, 0.2])
-
- # Define the obstacles:
- # - Ground
- ground_coll = pk.collision.HalfSpace.from_point_and_normal(
- np.array([0.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0])
- )
- # - Wall
- wall_height = 0.4
- wall_width = 0.1
- wall_length = 0.4
- wall_intervals = np.arange(start=0.3, stop=wall_length + 0.3, step=0.05)
- translation = np.concatenate(
- [
- wall_intervals.reshape(-1, 1),
- np.full((wall_intervals.shape[0], 1), 0.0),
- np.full((wall_intervals.shape[0], 1), wall_height / 2),
- ],
- axis=1,
- )
- wall_coll = pk.collision.Capsule.from_radius_height(
- position=translation,
- radius=np.full((translation.shape[0], 1), wall_width / 2),
- height=np.full((translation.shape[0], 1), wall_height),
- )
- world_coll = [ground_coll, wall_coll]
-
- traj = pks.solve_trajopt(
- robot,
- robot_coll,
- world_coll,
- target_link_name,
- start_pos,
- down_wxyz,
- end_pos,
- down_wxyz,
- timesteps,
- dt,
- )
- traj = np.array(traj)
-
- # Visualize!
- server = viser.ViserServer()
- urdf_vis = ViserUrdf(server, urdf)
- server.scene.add_grid("/grid", width=2, height=2, cell_size=0.1)
- server.scene.add_mesh_trimesh(
- "wall_box",
- trimesh.creation.box(
- extents=(wall_length, wall_width, wall_height),
- transform=trimesh.transformations.translation_matrix(
- np.array([0.5, 0.0, wall_height / 2])
- ),
- ),
- )
- for name, pos in zip(["start", "end"], [start_pos, end_pos]):
- server.scene.add_frame(
- f"/{name}",
- position=pos,
- wxyz=down_wxyz,
- axes_length=0.05,
- axes_radius=0.01,
- )
-
- slider = server.gui.add_slider(
- "Timestep", min=0, max=timesteps - 1, step=1, initial_value=0
- )
- playing = server.gui.add_checkbox("Playing", initial_value=True)
-
- while True:
- if playing.value:
- slider.value = (slider.value + 1) % timesteps
-
- urdf_vis.update_cfg(traj[slider.value])
- time.sleep(1.0 / 10.0)
-
-
- if __name__ == "__main__":
- tyro.cli(main)
diff --git a/docs/source/examples/08_ik_with_mimic_joints.rst b/docs/source/examples/08_ik_with_mimic_joints.rst
deleted file mode 100644
index 7760754..0000000
--- a/docs/source/examples/08_ik_with_mimic_joints.rst
+++ /dev/null
@@ -1,145 +0,0 @@
-.. Comment: this file is automatically generated by `update_example_docs.py`.
- It should not be modified manually.
-
-IK with Mimic Joints
-==========================================
-
-
-This is a simple test to ensure that mimic joints are handled correctly in the IK solver.
-
-We procedurally generate a "zig-zag" chain of links with mimic joints, where:
-
-
-* the first joint is driven directly,
-* and the remaining joints are driven indirectly via mimic joints.
- The multipliers alternate between -1 and 1, and the offsets are all 0.
-
-All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details.
-
-
-
-.. code-block:: python
- :linenos:
-
-
- import tempfile
- import time
-
- import numpy as np
- import pyronot as pk
- import viser
- import yourdfpy
- from viser.extras import ViserUrdf
-
- import pyronot_snippets as pks
-
-
- def create_chain_xml(length: float = 0.2, num_chains: int = 5) -> str:
- def create_link(idx):
- return f"""
-
-
-
-
-
-
-
-
-
-
-
- """
-
- def create_joint(idx, multiplier=1.0, offset=0.0):
- mimic = f''
- return f"""
-
-
-
-
-
- {mimic if idx != 0 else ""}
-
-
- """
-
- world_joint_origin_z = length / 2.0
- xml = f"""
-
-
-
-
-
-
-
-
-
-
- """
- # Create the definition + first link.
- xml += create_link(0)
- xml += create_link(1)
- xml += create_joint(0)
-
- # Procedurally add more links.
- assert num_chains >= 2
- for idx in range(2, num_chains):
- xml += create_link(idx)
- current_offset = 0.0
- current_multiplier = 1.0 * ((-1) ** (idx % 2))
- xml += create_joint(idx - 1, current_multiplier, current_offset)
-
- xml += """
-
- """
- return xml
-
-
- def main():
- """Main function for basic IK."""
-
- xml = create_chain_xml(num_chains=10, length=0.1)
- with tempfile.NamedTemporaryFile(mode="w", suffix=".urdf") as f:
- f.write(xml)
- f.flush()
- urdf = yourdfpy.URDF.load(f.name)
-
- # Create robot.
- robot = pk.Robot.from_urdf(urdf)
-
- # Set up visualizer.
- server = viser.ViserServer()
- server.scene.add_grid("/ground", width=2, height=2)
- urdf_vis = ViserUrdf(server, urdf, root_node_name="/base")
- target_link_name_handle = server.gui.add_dropdown(
- "Target Link",
- robot.links.names,
- initial_value=robot.links.names[-1],
- )
-
- # Create interactive controller with initial position.
- ik_target = server.scene.add_transform_controls(
- "/ik_target", scale=0.2, position=(0.0, 0.1, 0.1), wxyz=(0, 0, 1, 0)
- )
- timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True)
-
- while True:
- # Solve IK.
- start_time = time.time()
- solution = pks.solve_ik(
- robot=robot,
- target_link_name=target_link_name_handle.value,
- target_position=np.array(ik_target.position),
- target_wxyz=np.array(ik_target.wxyz),
- )
-
- # Update timing handle.
- elapsed_time = time.time() - start_time
- timing_handle.value = 0.99 * timing_handle.value + 0.01 * (elapsed_time * 1000)
-
- # Update visualizer.
- urdf_vis.update_cfg(solution)
-
-
- if __name__ == "__main__":
- main()
diff --git a/docs/source/examples/09_hand_retargeting.rst b/docs/source/examples/09_hand_retargeting.rst
deleted file mode 100644
index a1a3c7f..0000000
--- a/docs/source/examples/09_hand_retargeting.rst
+++ /dev/null
@@ -1,376 +0,0 @@
-.. Comment: this file is automatically generated by `update_example_docs.py`.
- It should not be modified manually.
-
-Hand Retargeting
-==========================================
-
-
-Simpler shadow hand retargeting example.
-Find and unzip the shadowhand URDF at ``assets/hand_retargeting/shadowhand_urdf.zip``.
-
-All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details.
-
-
-
-.. code-block:: python
- :linenos:
-
-
- import pickle
- import time
- from pathlib import Path
- from typing import Tuple, TypedDict
-
- import jax
- import jax.numpy as jnp
- import jax_dataclasses as jdc
- import jaxlie
- import jaxls
- import numpy as onp
- import pyronot as pk
- import trimesh
- import viser
- import yourdfpy
- from scipy.spatial.transform import Rotation as R
- from viser.extras import ViserUrdf
-
- from retarget_helpers._utils import (
- MANO_TO_SHADOW_MAPPING,
- create_conn_tree,
- get_mapping_from_mano_to_shadow,
- )
-
-
- class RetargetingWeights(TypedDict):
- local_alignment: float
- """Local alignment weight, by matching the relative joint/keypoint positions and angles."""
- global_alignment: float
- """Global alignment weight, by matching the keypoint positions to the robot."""
- joint_smoothness: float
- """Joint smoothness weight."""
- root_smoothness: float
- """Root translation smoothness weight."""
-
-
- def main():
- """Main function for hand retargeting."""
-
- asset_dir = Path(__file__).parent / "retarget_helpers" / "hand"
-
- robot_urdf_path = asset_dir / "shadowhand_urdf" / "shadow_hand_right.urdf"
-
- def filename_handler(fname: str) -> str:
- base_path = robot_urdf_path.parent
- return yourdfpy.filename_handler_magic(fname, dir=base_path)
-
- try:
- urdf = yourdfpy.URDF.load(robot_urdf_path, filename_handler=filename_handler)
- except FileNotFoundError:
- raise FileNotFoundError(
- "Please unzip the included URDF at `retarget_helpers/hand/shadowhand_urdf.zip`."
- )
-
- robot = pk.Robot.from_urdf(urdf)
-
- # Get the mapping from MANO to Shadow Hand joints.
- shadow_link_idx, mano_joint_idx = get_mapping_from_mano_to_shadow(robot)
-
- # Create a mask for the MANO joints that are connected to the Shadow Hand.
- mano_mask = create_conn_tree(robot, shadow_link_idx)
-
- # Load source motion data.
- dexycb_motion_path = asset_dir / "dexycb_motion.pkl"
- with open(dexycb_motion_path, "rb") as f:
- dexycb_motion_data = pickle.load(f, encoding="latin1")
-
- # Load keypoints.
- keypoints = dexycb_motion_data["world_hand_joints"]
- assert not onp.isnan(keypoints).any()
- num_timesteps = keypoints.shape[0]
- num_mano_joints = len(MANO_TO_SHADOW_MAPPING)
-
- # Load mano hand contact information -- these are lists of lists,
- # len(contact_points_per_frame) = num_timesteps,
- # len(contact_points_per_frame[i]) = number of contacts in frame i,
- contact_points_per_frame = dexycb_motion_data["contact_object_points"]
- contact_indices_per_frame = dexycb_motion_data["contact_joint_indices"]
-
- # Now, we're going to pad this info + make a mask to indicate the padded regions.
- # We will also track the shadowhand joint indices, NOT the MANO joint indices.
- max_num_contacts = max(len(c) for c in contact_points_per_frame)
- padded_contact_points_per_frame = onp.zeros((num_timesteps, max_num_contacts, 3))
- padded_contact_indices_per_frame = onp.zeros(
- (num_timesteps, max_num_contacts), dtype=onp.int32
- )
- padded_contact_mask = onp.zeros((num_timesteps, max_num_contacts), dtype=onp.bool_)
- for i in range(num_timesteps):
- num_contacts = len(contact_points_per_frame[i])
- if num_contacts == 0:
- continue
- contact_shadowhand_indices = [
- robot.links.names.index(MANO_TO_SHADOW_MAPPING[j])
- for j in contact_indices_per_frame[i]
- ]
- padded_contact_points_per_frame[i, :num_contacts] = contact_points_per_frame[i]
- padded_contact_indices_per_frame[i, :num_contacts] = contact_shadowhand_indices
- padded_contact_mask[i, :num_contacts] = True
-
- # Load the object.
- object_mesh_vertices = dexycb_motion_data["object_mesh_vertices"]
- object_mesh_faces = dexycb_motion_data["object_mesh_faces"]
- object_pose_list = dexycb_motion_data["object_poses"] # (N, 4, 4)
- mesh = trimesh.Trimesh(object_mesh_vertices, object_mesh_faces)
-
- server = viser.ViserServer()
-
- # We will transform everything by the transform below, for aesthetics.
- server.scene.add_frame(
- "/scene_offset",
- show_axes=False,
- position=(-0.15415953, -0.73598871, 0.93434792),
- wxyz=(-0.381870867, 0.92421569, 0.0, 2.0004992e-32),
- )
- hand_mesh = server.scene.add_mesh_simple(
- "/scene_offset/hand_mesh",
- vertices=dexycb_motion_data["world_hand_vertices"][0, :, :],
- faces=dexycb_motion_data["hand_mesh_faces"],
- opacity=0.5,
- )
- base_frame = server.scene.add_frame("/scene_offset/base", show_axes=False)
- urdf_vis = ViserUrdf(server, urdf, root_node_name="/scene_offset/base")
- playing = server.gui.add_checkbox("playing", True)
- timestep_slider = server.gui.add_slider("timestep", 0, num_timesteps - 1, 1, 0)
- object_handle = server.scene.add_mesh_trimesh("/scene_offset/object", mesh)
- server.scene.add_grid("/grid", 2.0, 2.0)
-
- default_weights = RetargetingWeights(
- local_alignment=10.0,
- global_alignment=1.0,
- joint_smoothness=2.0,
- root_smoothness=2.0,
- )
-
- weights = pk.viewer.WeightTuner(
- server,
- default_weights, # type: ignore
- )
-
- Ts_world_root, joints = None, None
-
- def generate_trajectory():
- nonlocal Ts_world_root, joints
- gen_button.disabled = True
- Ts_world_root, joints = solve_retargeting(
- robot=robot,
- target_keypoints=keypoints,
- shadow_hand_link_retarget_indices=shadow_link_idx,
- mano_joint_retarget_indices=mano_joint_idx,
- mano_mask=mano_mask,
- weights=weights.get_weights(), # type: ignore
- )
- gen_button.disabled = False
-
- gen_button = server.gui.add_button("Retarget!")
- gen_button.on_click(lambda _: generate_trajectory())
-
- generate_trajectory()
- assert Ts_world_root is not None and joints is not None
-
- while True:
- with server.atomic():
- if playing.value:
- timestep_slider.value = (timestep_slider.value + 1) % num_timesteps
- tstep = timestep_slider.value
- base_frame.wxyz = onp.array(Ts_world_root.wxyz_xyz[tstep][:4])
- base_frame.position = onp.array(Ts_world_root.wxyz_xyz[tstep][4:])
- urdf_vis.update_cfg(onp.array(joints[tstep]))
-
- server.scene.add_point_cloud(
- "/scene_offset/target_keypoints",
- onp.array(keypoints[tstep]).reshape(-1, 3),
- onp.array((0, 0, 255))[None]
- .repeat(num_mano_joints, axis=0)
- .reshape(-1, 3),
- point_size=0.005,
- point_shape="sparkle",
- )
- server.scene.add_point_cloud(
- "/scene_offset/contact_points",
- onp.array(contact_points_per_frame[tstep]).reshape(-1, 3),
- onp.array((255, 0, 0))[None]
- .repeat(len(contact_points_per_frame[tstep]), axis=0)
- .reshape(-1, 3),
- point_size=0.005,
- point_shape="circle",
- )
- hand_mesh.vertices = dexycb_motion_data["world_hand_vertices"][tstep, :, :]
- object_handle.position = object_pose_list[tstep][:3, 3]
- object_handle.wxyz = R.from_matrix(object_pose_list[tstep][:3, :3]).as_quat(
- scalar_first=True
- )
-
- time.sleep(0.05)
-
-
- @jdc.jit
- def solve_retargeting(
- robot: pk.Robot,
- target_keypoints: jnp.ndarray,
- shadow_hand_link_retarget_indices: jnp.ndarray,
- mano_joint_retarget_indices: jnp.ndarray,
- mano_mask: jnp.ndarray,
- weights: RetargetingWeights,
- ) -> Tuple[jaxlie.SE3, jnp.ndarray]:
- """Solve the retargeting problem."""
-
- n_retarget = len(mano_joint_retarget_indices)
- timesteps = target_keypoints.shape[0]
-
- # Variables.
- class ManoJointsScaleVar(
- jaxls.Var[jax.Array], default_factory=lambda: jnp.ones((n_retarget, n_retarget))
- ): ...
-
- class OffsetVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros((3,))): ...
-
- var_joints = robot.joint_var_cls(jnp.arange(timesteps))
- var_Ts_world_root = jaxls.SE3Var(jnp.arange(timesteps))
- var_smpl_joints_scale = ManoJointsScaleVar(jnp.zeros(timesteps))
- var_offset = OffsetVar(jnp.zeros(timesteps))
-
- # Costs.
- costs: list[jaxls.Cost] = []
-
- @jaxls.Cost.create_factory
- def retargeting_cost(
- var_values: jaxls.VarValues,
- var_Ts_world_root: jaxls.SE3Var,
- var_robot_cfg: jaxls.Var[jnp.ndarray],
- var_smpl_joints_scale: ManoJointsScaleVar,
- keypoints: jnp.ndarray,
- ) -> jax.Array:
- """Retargeting factor, with a focus on:
- - matching the relative joint/keypoint positions (vectors).
- - and matching the relative angles between the vectors.
- """
- robot_cfg = var_values[var_robot_cfg]
- T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg))
- T_world_root = var_values[var_Ts_world_root]
- T_world_link = T_world_root @ T_root_link
-
- mano_pos = keypoints[jnp.array(mano_joint_retarget_indices)]
- robot_pos = T_world_link.translation()[
- jnp.array(shadow_hand_link_retarget_indices)
- ]
-
- # NxN grid of relative positions.
- delta_mano = mano_pos[:, None] - mano_pos[None, :]
- delta_robot = robot_pos[:, None] - robot_pos[None, :]
-
- # Vector regularization.
- position_scale = var_values[var_smpl_joints_scale][..., None]
- residual_position_delta = (
- (delta_mano - delta_robot * position_scale)
- * (1 - jnp.eye(delta_mano.shape[0])[..., None])
- * mano_mask[..., None]
- )
-
- # Vector angle regularization.
- delta_mano_normalized = delta_mano / jnp.linalg.norm(
- delta_mano + 1e-6, axis=-1, keepdims=True
- )
- delta_robot_normalized = delta_robot / jnp.linalg.norm(
- delta_robot + 1e-6, axis=-1, keepdims=True
- )
- residual_angle_delta = 1 - (delta_mano_normalized * delta_robot_normalized).sum(
- axis=-1
- )
- residual_angle_delta = (
- residual_angle_delta
- * (1 - jnp.eye(residual_angle_delta.shape[0]))
- * mano_mask
- )
-
- residual = (
- jnp.concatenate(
- [
- residual_position_delta.flatten(),
- residual_angle_delta.flatten(),
- ],
- axis=0,
- )
- * weights["local_alignment"]
- )
- return residual
-
- @jaxls.Cost.create_factory
- def pc_alignment_cost(
- var_values: jaxls.VarValues,
- var_Ts_world_root: jaxls.SE3Var,
- var_robot_cfg: jaxls.Var[jnp.ndarray],
- keypoints: jnp.ndarray,
- ) -> jax.Array:
- """Soft cost to align the human keypoints to the robot, in the world frame."""
- T_world_root = var_values[var_Ts_world_root]
- robot_cfg = var_values[var_robot_cfg]
- T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg))
- T_world_link = T_world_root @ T_root_link
- link_pos = T_world_link.translation()[shadow_hand_link_retarget_indices]
- keypoint_pos = keypoints[mano_joint_retarget_indices]
- return (link_pos - keypoint_pos).flatten() * weights["global_alignment"]
-
- @jaxls.Cost.create_factory
- def root_smoothness(
- var_values: jaxls.VarValues,
- var_Ts_world_root: jaxls.SE3Var,
- var_Ts_world_root_prev: jaxls.SE3Var,
- ) -> jax.Array:
- """Smoothness cost for the robot root translation."""
- return (
- var_values[var_Ts_world_root].translation()
- - var_values[var_Ts_world_root_prev].translation()
- ).flatten() * weights["root_smoothness"]
-
- costs = [
- retargeting_cost(
- var_Ts_world_root,
- var_joints,
- var_smpl_joints_scale,
- target_keypoints,
- ),
- pk.costs.limit_cost(
- jax.tree.map(lambda x: x[None], robot),
- var_joints,
- 100.0,
- ),
- pk.costs.smoothness_cost(
- robot.joint_var_cls(jnp.arange(1, timesteps)),
- robot.joint_var_cls(jnp.arange(0, timesteps - 1)),
- jnp.array([weights["joint_smoothness"]]),
- ),
- pc_alignment_cost(
- var_Ts_world_root,
- var_joints,
- target_keypoints,
- ),
- root_smoothness(
- jaxls.SE3Var(jnp.arange(1, timesteps)),
- jaxls.SE3Var(jnp.arange(0, timesteps - 1)),
- ),
- ]
-
- solution = (
- jaxls.LeastSquaresProblem(
- costs, [var_joints, var_Ts_world_root, var_smpl_joints_scale, var_offset]
- )
- .analyze()
- .solve()
- )
- transform = solution[var_Ts_world_root]
- offset = solution[var_offset]
- transform = jaxlie.SE3.from_translation(offset) @ transform
- return transform, solution[var_joints]
-
-
- if __name__ == "__main__":
- main()
diff --git a/docs/source/examples/10_humanoid_retargeting.rst b/docs/source/examples/10_humanoid_retargeting.rst
deleted file mode 100644
index 83b9f90..0000000
--- a/docs/source/examples/10_humanoid_retargeting.rst
+++ /dev/null
@@ -1,300 +0,0 @@
-.. Comment: this file is automatically generated by `update_example_docs.py`.
- It should not be modified manually.
-
-Humanoid Retargeting
-==========================================
-
-
-Simpler motion retargeting to the G1 humanoid.
-
-All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details.
-
-
-
-.. code-block:: python
- :linenos:
-
-
- import time
- from pathlib import Path
- from typing import Tuple, TypedDict
-
- import jax
- import jax.numpy as jnp
- import jax_dataclasses as jdc
- import jaxlie
- import jaxls
- import numpy as onp
- import pyronot as pk
- import viser
- from robot_descriptions.loaders.yourdfpy import load_robot_description
- from viser.extras import ViserUrdf
-
- from retarget_helpers._utils import (
- SMPL_JOINT_NAMES,
- create_conn_tree,
- get_humanoid_retarget_indices,
- )
-
-
- class RetargetingWeights(TypedDict):
- local_alignment: float
- """Local alignment weight, by matching the relative joint/keypoint positions and angles."""
- global_alignment: float
- """Global alignment weight, by matching the keypoint positions to the robot."""
-
-
- def main():
- """Main function for humanoid retargeting."""
-
- urdf = load_robot_description("g1_description")
- robot = pk.Robot.from_urdf(urdf)
-
- # Load source motion data:
- # - keypoints [N, 45, 3],
- # - left/right foot contact (boolean) 2 x [N],
- # - heightmap [H, W].
- asset_dir = Path(__file__).parent / "retarget_helpers" / "humanoid"
- smpl_keypoints = onp.load(asset_dir / "smpl_keypoints.npy")
- is_left_foot_contact = onp.load(asset_dir / "left_foot_contact.npy")
- is_right_foot_contact = onp.load(asset_dir / "right_foot_contact.npy")
- heightmap = onp.load(asset_dir / "heightmap.npy")
-
- num_timesteps = smpl_keypoints.shape[0]
- assert smpl_keypoints.shape == (num_timesteps, 45, 3)
- assert is_left_foot_contact.shape == (num_timesteps,)
- assert is_right_foot_contact.shape == (num_timesteps,)
-
- heightmap = pk.collision.Heightmap(
- pose=jaxlie.SE3.identity(),
- size=jnp.array([0.01, 0.01, 1.0]),
- height_data=heightmap,
- )
-
- # Get the left and right foot keypoints, projected on the heightmap.
- left_foot_keypoint_idx = SMPL_JOINT_NAMES.index("left_foot")
- right_foot_keypoint_idx = SMPL_JOINT_NAMES.index("right_foot")
- left_foot_keypoints = smpl_keypoints[..., left_foot_keypoint_idx, :].reshape(-1, 3)
- right_foot_keypoints = smpl_keypoints[..., right_foot_keypoint_idx, :].reshape(
- -1, 3
- )
- left_foot_keypoints = heightmap.project_points(left_foot_keypoints)
- right_foot_keypoints = heightmap.project_points(right_foot_keypoints)
-
- smpl_joint_retarget_indices, g1_joint_retarget_indices = (
- get_humanoid_retarget_indices()
- )
- smpl_mask = create_conn_tree(robot, g1_joint_retarget_indices)
-
- server = viser.ViserServer()
- base_frame = server.scene.add_frame("/base", show_axes=False)
- urdf_vis = ViserUrdf(server, urdf, root_node_name="/base")
- playing = server.gui.add_checkbox("playing", True)
- timestep_slider = server.gui.add_slider("timestep", 0, num_timesteps - 1, 1, 0)
- server.scene.add_mesh_trimesh("/heightmap", heightmap.to_trimesh())
-
- weights = pk.viewer.WeightTuner(
- server,
- RetargetingWeights( # type: ignore
- local_alignment=2.0,
- global_alignment=1.0,
- ),
- )
-
- Ts_world_root, joints = None, None
-
- def generate_trajectory():
- nonlocal Ts_world_root, joints
- gen_button.disabled = True
- Ts_world_root, joints = solve_retargeting(
- robot=robot,
- target_keypoints=smpl_keypoints,
- smpl_joint_retarget_indices=smpl_joint_retarget_indices,
- g1_joint_retarget_indices=g1_joint_retarget_indices,
- smpl_mask=smpl_mask,
- weights=weights.get_weights(), # type: ignore
- )
- gen_button.disabled = False
-
- gen_button = server.gui.add_button("Retarget!")
- gen_button.on_click(lambda _: generate_trajectory())
-
- generate_trajectory()
- assert Ts_world_root is not None and joints is not None
-
- while True:
- with server.atomic():
- if playing.value:
- timestep_slider.value = (timestep_slider.value + 1) % num_timesteps
- tstep = timestep_slider.value
- base_frame.wxyz = onp.array(Ts_world_root.wxyz_xyz[tstep][:4])
- base_frame.position = onp.array(Ts_world_root.wxyz_xyz[tstep][4:])
- urdf_vis.update_cfg(onp.array(joints[tstep]))
- server.scene.add_point_cloud(
- "/target_keypoints",
- onp.array(smpl_keypoints[tstep]),
- onp.array((0, 0, 255))[None].repeat(45, axis=0),
- point_size=0.01,
- )
-
- time.sleep(0.05)
-
-
- @jdc.jit
- def solve_retargeting(
- robot: pk.Robot,
- target_keypoints: jnp.ndarray,
- smpl_joint_retarget_indices: jnp.ndarray,
- g1_joint_retarget_indices: jnp.ndarray,
- smpl_mask: jnp.ndarray,
- weights: RetargetingWeights,
- ) -> Tuple[jaxlie.SE3, jnp.ndarray]:
- """Solve the retargeting problem."""
-
- n_retarget = len(smpl_joint_retarget_indices)
- timesteps = target_keypoints.shape[0]
-
- # Robot properties.
- # - Joints that should move less for natural humanoid motion.
- joints_to_move_less = jnp.array(
- [
- robot.joints.actuated_names.index(name)
- for name in ["left_hip_yaw_joint", "right_hip_yaw_joint", "torso_joint"]
- ]
- )
-
- # Variables.
- class SmplJointsScaleVarG1(
- jaxls.Var[jax.Array], default_factory=lambda: jnp.ones((n_retarget, n_retarget))
- ): ...
-
- class OffsetVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros((3,))): ...
-
- var_joints = robot.joint_var_cls(jnp.arange(timesteps))
- var_Ts_world_root = jaxls.SE3Var(jnp.arange(timesteps))
- var_smpl_joints_scale = SmplJointsScaleVarG1(jnp.zeros(timesteps))
- var_offset = OffsetVar(jnp.zeros(timesteps))
-
- # Costs.
- costs: list[jaxls.Cost] = []
-
- @jaxls.Cost.create_factory
- def retargeting_cost(
- var_values: jaxls.VarValues,
- var_Ts_world_root: jaxls.SE3Var,
- var_robot_cfg: jaxls.Var[jnp.ndarray],
- var_smpl_joints_scale: SmplJointsScaleVarG1,
- keypoints: jnp.ndarray,
- ) -> jax.Array:
- """Retargeting factor, with a focus on:
- - matching the relative joint/keypoint positions (vectors).
- - and matching the relative angles between the vectors.
- """
- robot_cfg = var_values[var_robot_cfg]
- T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg))
- T_world_root = var_values[var_Ts_world_root]
- T_world_link = T_world_root @ T_root_link
-
- smpl_pos = keypoints[jnp.array(smpl_joint_retarget_indices)]
- robot_pos = T_world_link.translation()[jnp.array(g1_joint_retarget_indices)]
-
- # NxN grid of relative positions.
- delta_smpl = smpl_pos[:, None] - smpl_pos[None, :]
- delta_robot = robot_pos[:, None] - robot_pos[None, :]
-
- # Vector regularization.
- position_scale = var_values[var_smpl_joints_scale][..., None]
- residual_position_delta = (
- (delta_smpl - delta_robot * position_scale)
- * (1 - jnp.eye(delta_smpl.shape[0])[..., None])
- * smpl_mask[..., None]
- )
-
- # Vector angle regularization.
- delta_smpl_normalized = delta_smpl / jnp.linalg.norm(
- delta_smpl + 1e-6, axis=-1, keepdims=True
- )
- delta_robot_normalized = delta_robot / jnp.linalg.norm(
- delta_robot + 1e-6, axis=-1, keepdims=True
- )
- residual_angle_delta = 1 - (delta_smpl_normalized * delta_robot_normalized).sum(
- axis=-1
- )
- residual_angle_delta = (
- residual_angle_delta
- * (1 - jnp.eye(residual_angle_delta.shape[0]))
- * smpl_mask
- )
-
- residual = (
- jnp.concatenate(
- [residual_position_delta.flatten(), residual_angle_delta.flatten()]
- )
- * weights["local_alignment"]
- )
- return residual
-
- @jaxls.Cost.create_factory
- def pc_alignment_cost(
- var_values: jaxls.VarValues,
- var_Ts_world_root: jaxls.SE3Var,
- var_robot_cfg: jaxls.Var[jnp.ndarray],
- keypoints: jnp.ndarray,
- ) -> jax.Array:
- """Soft cost to align the human keypoints to the robot, in the world frame."""
- T_world_root = var_values[var_Ts_world_root]
- robot_cfg = var_values[var_robot_cfg]
- T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg))
- T_world_link = T_world_root @ T_root_link
- link_pos = T_world_link.translation()[g1_joint_retarget_indices]
- keypoint_pos = keypoints[smpl_joint_retarget_indices]
- return (link_pos - keypoint_pos).flatten() * weights["global_alignment"]
-
- costs = [
- # Costs that are relatively self-contained to the robot.
- retargeting_cost(
- var_Ts_world_root,
- var_joints,
- var_smpl_joints_scale,
- target_keypoints,
- ),
- pk.costs.limit_cost(
- jax.tree.map(lambda x: x[None], robot),
- var_joints,
- 100.0,
- ),
- pk.costs.smoothness_cost(
- robot.joint_var_cls(jnp.arange(1, timesteps)),
- robot.joint_var_cls(jnp.arange(0, timesteps - 1)),
- jnp.array([0.2]),
- ),
- pk.costs.rest_cost(
- var_joints,
- var_joints.default_factory()[None],
- jnp.full(var_joints.default_factory().shape, 0.2)
- .at[joints_to_move_less]
- .set(2.0)[None],
- ),
- # Costs that are scene-centric.
- pc_alignment_cost(
- var_Ts_world_root,
- var_joints,
- target_keypoints,
- ),
- ]
-
- solution = (
- jaxls.LeastSquaresProblem(
- costs, [var_joints, var_Ts_world_root, var_smpl_joints_scale, var_offset]
- )
- .analyze()
- .solve()
- )
- transform = solution[var_Ts_world_root]
- offset = solution[var_offset]
- transform = jaxlie.SE3.from_translation(offset) @ transform
- return transform, solution[var_joints]
-
-
- if __name__ == "__main__":
- main()
diff --git a/docs/source/examples/11_hand_retargeting_fancy.rst b/docs/source/examples/11_hand_retargeting_fancy.rst
deleted file mode 100644
index 9f5e77e..0000000
--- a/docs/source/examples/11_hand_retargeting_fancy.rst
+++ /dev/null
@@ -1,450 +0,0 @@
-.. Comment: this file is automatically generated by `update_example_docs.py`.
- It should not be modified manually.
-
-Hand Retargeting (Fancy)
-==========================================
-
-
-Shadow Hand retargeting example, with costs to maintain contact with the object.
-Find and unzip the shadowhand URDF at ``assets/hand_retargeting/shadowhand_urdf.zip``.
-
-All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details.
-
-
-
-.. code-block:: python
- :linenos:
-
-
- import time
- from typing import Tuple, TypedDict
- from pathlib import Path
- import pickle
- import trimesh
- from scipy.spatial.transform import Rotation as R
-
- import jax
- import jax.numpy as jnp
- import jax_dataclasses as jdc
- import jaxlie
- import jaxls
- import numpy as onp
- import viser
- from viser.extras import ViserUrdf
- import yourdfpy
-
- import pyronot as pk
-
- from retarget_helpers._utils import (
- create_conn_tree,
- get_mapping_from_mano_to_shadow,
- MANO_TO_SHADOW_MAPPING,
- )
-
-
- class RetargetingWeights(TypedDict):
- local_alignment: float
- """Local alignment weight, by matching the relative joint/keypoint positions and angles."""
- global_alignment: float
- """Global alignment weight, by matching the keypoint positions to the robot."""
- contact: float
- """Contact weight, to maintain contact between the robot and the object."""
- contact_margin: float
- """Contact margin, to stop penalizing contact when the robot is already close to the object."""
- joint_smoothness: float
- """Joint smoothness weight."""
- root_smoothness: float
- """Root translation smoothness weight."""
-
-
- def main():
- """Main function for hand retargeting."""
-
- asset_dir = Path(__file__).parent / "retarget_helpers" / "hand"
-
- robot_urdf_path = asset_dir / "shadowhand_urdf" / "shadow_hand_right.urdf"
-
- def filename_handler(fname: str) -> str:
- base_path = robot_urdf_path.parent
- return yourdfpy.filename_handler_magic(fname, dir=base_path)
-
- try:
- urdf = yourdfpy.URDF.load(robot_urdf_path, filename_handler=filename_handler)
- except FileNotFoundError:
- raise FileNotFoundError(
- "Please unzip the included URDF at `retarget_helpers/hand/shadowhand_urdf.zip`."
- )
-
- robot = pk.Robot.from_urdf(urdf)
-
- # Get the mapping from MANO to Shadow Hand joints.
- shadow_link_idx, mano_joint_idx = get_mapping_from_mano_to_shadow(robot)
-
- # Create a mask for the MANO joints that are connected to the Shadow Hand.
- mano_mask = create_conn_tree(robot, shadow_link_idx)
-
- # Load source motion data.
- dexycb_motion_path = asset_dir / "dexycb_motion.pkl"
- with open(dexycb_motion_path, "rb") as f:
- dexycb_motion_data = pickle.load(f, encoding="latin1")
-
- # Load keypoints.
- keypoints = dexycb_motion_data["world_hand_joints"]
- assert not onp.isnan(keypoints).any()
- num_timesteps = keypoints.shape[0]
- num_mano_joints = len(MANO_TO_SHADOW_MAPPING)
-
- # Load mano hand contact information -- these are lists of lists,
- # len(contact_points_per_frame) = num_timesteps,
- # len(contact_points_per_frame[i]) = number of contacts in frame i,
- contact_points_per_frame = dexycb_motion_data["contact_object_points"]
- contact_indices_per_frame = dexycb_motion_data["contact_joint_indices"]
-
- # Now, we're going to pad this info + make a mask to indicate the padded regions.
- # We will also track the shadowhand joint indices, NOT the MANO joint indices.
- max_num_contacts = max(len(c) for c in contact_points_per_frame)
- padded_contact_points_per_frame = onp.zeros((num_timesteps, max_num_contacts, 3))
- padded_contact_indices_per_frame = onp.zeros(
- (num_timesteps, max_num_contacts), dtype=onp.int32
- )
- padded_contact_mask = onp.zeros((num_timesteps, max_num_contacts), dtype=onp.bool_)
- for i in range(num_timesteps):
- num_contacts = len(contact_points_per_frame[i])
- if num_contacts == 0:
- continue
- contact_shadowhand_indices = [
- robot.links.names.index(MANO_TO_SHADOW_MAPPING[j])
- for j in contact_indices_per_frame[i]
- ]
- padded_contact_points_per_frame[i, :num_contacts] = contact_points_per_frame[i]
- padded_contact_indices_per_frame[i, :num_contacts] = contact_shadowhand_indices
- padded_contact_mask[i, :num_contacts] = True
-
- # Load the object.
- object_mesh_vertices = dexycb_motion_data["object_mesh_vertices"]
- object_mesh_faces = dexycb_motion_data["object_mesh_faces"]
- object_pose_list = dexycb_motion_data["object_poses"] # (N, 4, 4)
- mesh = trimesh.Trimesh(object_mesh_vertices, object_mesh_faces)
-
- server = viser.ViserServer()
-
- # We will transform everything by the transform below, for aesthetics.
- server.scene.add_frame(
- "/scene_offset",
- show_axes=False,
- position=(-0.15415953, -0.73598871, 0.93434792),
- wxyz=(-0.381870867, 0.92421569, 0.0, 2.0004992e-32),
- )
- base_frame = server.scene.add_frame("/scene_offset/base", show_axes=False)
- urdf_vis = ViserUrdf(server, urdf, root_node_name="/scene_offset/base")
- playing = server.gui.add_checkbox("playing", True)
- timestep_slider = server.gui.add_slider("timestep", 0, num_timesteps - 1, 1, 0)
- object_handle = server.scene.add_mesh_trimesh("/scene_offset/object", mesh)
- server.scene.add_grid("/grid", 2.0, 2.0)
-
- default_weights = RetargetingWeights(
- local_alignment=10.0,
- global_alignment=1.0,
- contact=5.0,
- contact_margin=0.01,
- joint_smoothness=2.0,
- root_smoothness=2.0,
- )
-
- weights = pk.viewer.WeightTuner(
- server,
- default_weights, # type: ignore
- )
-
- Ts_world_root, joints = None, None
-
- def generate_trajectory():
- nonlocal Ts_world_root, joints
- gen_button.disabled = True
- Ts_world_root, joints = solve_retargeting(
- robot=robot,
- target_keypoints=keypoints,
- shadow_hand_link_retarget_indices=shadow_link_idx,
- mano_joint_retarget_indices=mano_joint_idx,
- mano_mask=mano_mask,
- contact_points_per_frame=jnp.array(padded_contact_points_per_frame),
- contact_indices_per_frame=jnp.array(padded_contact_indices_per_frame),
- contact_mask=jnp.array(padded_contact_mask),
- weights=weights.get_weights(), # type: ignore
- )
- gen_button.disabled = False
-
- gen_button = server.gui.add_button("Retarget!")
- gen_button.on_click(lambda _: generate_trajectory())
-
- generate_trajectory()
- assert Ts_world_root is not None and joints is not None
-
- while True:
- with server.atomic():
- if playing.value:
- timestep_slider.value = (timestep_slider.value + 1) % num_timesteps
- tstep = timestep_slider.value
- base_frame.wxyz = onp.array(Ts_world_root.wxyz_xyz[tstep][:4])
- base_frame.position = onp.array(Ts_world_root.wxyz_xyz[tstep][4:])
- urdf_vis.update_cfg(onp.array(joints[tstep]))
-
- server.scene.add_point_cloud(
- "/scene_offset/target_keypoints",
- onp.array(keypoints[tstep]).reshape(-1, 3),
- onp.array((0, 0, 255))[None]
- .repeat(num_mano_joints, axis=0)
- .reshape(-1, 3),
- point_size=0.005,
- point_shape="sparkle",
- )
- server.scene.add_point_cloud(
- "/scene_offset/contact_points",
- onp.array(contact_points_per_frame[tstep]).reshape(-1, 3),
- onp.array((255, 0, 0))[None]
- .repeat(len(contact_points_per_frame[tstep]), axis=0)
- .reshape(-1, 3),
- point_size=0.005,
- point_shape="circle",
- )
- object_handle.position = object_pose_list[tstep][:3, 3]
- object_handle.wxyz = R.from_matrix(object_pose_list[tstep][:3, :3]).as_quat(
- scalar_first=True
- )
-
- time.sleep(0.05)
-
-
- @jdc.jit
- def solve_retargeting(
- robot: pk.Robot,
- target_keypoints: jnp.ndarray,
- shadow_hand_link_retarget_indices: jnp.ndarray,
- mano_joint_retarget_indices: jnp.ndarray,
- mano_mask: jnp.ndarray,
- contact_points_per_frame: jnp.ndarray,
- contact_indices_per_frame: jnp.ndarray,
- contact_mask: jnp.ndarray,
- weights: RetargetingWeights,
- ) -> Tuple[jaxlie.SE3, jnp.ndarray]:
- """Solve the retargeting problem."""
-
- n_retarget = len(mano_joint_retarget_indices)
- timesteps = target_keypoints.shape[0]
-
- # Variables.
- class ManoJointsScaleVar(
- jaxls.Var[jax.Array], default_factory=lambda: jnp.ones((n_retarget, n_retarget))
- ): ...
-
- class OffsetVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros((3,))): ...
-
- var_joints = robot.joint_var_cls(jnp.arange(timesteps))
- var_Ts_world_root = jaxls.SE3Var(jnp.arange(timesteps))
- var_smpl_joints_scale = ManoJointsScaleVar(jnp.zeros(timesteps))
- var_offset = OffsetVar(jnp.zeros(timesteps))
-
- # Costs.
- costs: list[jaxls.Cost] = []
-
- @jaxls.Cost.create_factory
- def retargeting_cost(
- var_values: jaxls.VarValues,
- var_Ts_world_root: jaxls.SE3Var,
- var_robot_cfg: jaxls.Var[jnp.ndarray],
- var_smpl_joints_scale: ManoJointsScaleVar,
- keypoints: jnp.ndarray,
- ) -> jax.Array:
- """Retargeting factor, with a focus on:
- - matching the relative joint/keypoint positions (vectors).
- - and matching the relative angles between the vectors.
- """
- robot_cfg = var_values[var_robot_cfg]
- T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg))
- T_world_root = var_values[var_Ts_world_root]
- T_world_link = T_world_root @ T_root_link
-
- mano_pos = keypoints[jnp.array(mano_joint_retarget_indices)]
- robot_pos = T_world_link.translation()[
- jnp.array(shadow_hand_link_retarget_indices)
- ]
-
- # NxN grid of relative positions.
- delta_mano = mano_pos[:, None] - mano_pos[None, :]
- delta_robot = robot_pos[:, None] - robot_pos[None, :]
-
- # Vector regularization.
- position_scale = var_values[var_smpl_joints_scale][..., None]
- residual_position_delta = (
- (delta_mano - delta_robot * position_scale)
- * (1 - jnp.eye(delta_mano.shape[0])[..., None])
- * mano_mask[..., None]
- )
-
- # Vector angle regularization.
- delta_mano_normalized = delta_mano / jnp.linalg.norm(
- delta_mano + 1e-6, axis=-1, keepdims=True
- )
- delta_robot_normalized = delta_robot / jnp.linalg.norm(
- delta_robot + 1e-6, axis=-1, keepdims=True
- )
- residual_angle_delta = 1 - (delta_mano_normalized * delta_robot_normalized).sum(
- axis=-1
- )
- residual_angle_delta = (
- residual_angle_delta
- * (1 - jnp.eye(residual_angle_delta.shape[0]))
- * mano_mask
- )
-
- residual = (
- jnp.concatenate(
- [
- residual_position_delta.flatten(),
- residual_angle_delta.flatten(),
- ],
- axis=0,
- )
- * weights["local_alignment"]
- )
- return residual
-
- @jaxls.Cost.create_factory
- def scale_regularization(
- var_values: jaxls.VarValues,
- var_smpl_joints_scale: ManoJointsScaleVar,
- ) -> jax.Array:
- """Regularize the scale of the retargeted joints."""
- # Close to 1.
- res_0 = (var_values[var_smpl_joints_scale] - 1.0).flatten() * 1.0
- # Symmetric.
- res_1 = (
- var_values[var_smpl_joints_scale] - var_values[var_smpl_joints_scale].T
- ).flatten() * 100.0
- # Non-negative.
- res_2 = jnp.clip(-var_values[var_smpl_joints_scale], min=0).flatten() * 100.0
- return jnp.concatenate([res_0, res_1, res_2])
-
- @jaxls.Cost.create_factory
- def pc_alignment_cost(
- var_values: jaxls.VarValues,
- var_Ts_world_root: jaxls.SE3Var,
- var_robot_cfg: jaxls.Var[jnp.ndarray],
- keypoints: jnp.ndarray,
- ) -> jax.Array:
- """Soft cost to align the human keypoints to the robot, in the world frame."""
- T_world_root = var_values[var_Ts_world_root]
- robot_cfg = var_values[var_robot_cfg]
- T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg))
- T_world_link = T_world_root @ T_root_link
- link_pos = T_world_link.translation()[shadow_hand_link_retarget_indices]
- keypoint_pos = keypoints[mano_joint_retarget_indices]
- return (link_pos - keypoint_pos).flatten() * weights["global_alignment"]
-
- @jaxls.Cost.create_factory
- def root_smoothness(
- var_values: jaxls.VarValues,
- var_Ts_world_root: jaxls.SE3Var,
- var_Ts_world_root_prev: jaxls.SE3Var,
- ) -> jax.Array:
- """Smoothness cost for the robot root translation."""
- return (
- var_values[var_Ts_world_root].translation()
- - var_values[var_Ts_world_root_prev].translation()
- ).flatten() * weights["root_smoothness"]
-
- @jaxls.Cost.create_factory
- def contact_cost(
- var_values: jaxls.VarValues,
- var_T_world_root: jaxls.SE3Var,
- var_robot_cfg: jaxls.Var[jnp.ndarray],
- contact_points: jax.Array, # (J, P, 3)
- contact_indices: jax.Array, # (J,) - Actual robot joint indices.
- contact_points_mask: jax.Array, # (J, P)
- ) -> jax.Array:
- """Cost for maintaining contact between specified robot joints and object points."""
- robot_cfg = var_values[var_robot_cfg]
- T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg))
- T_world_root = var_values[var_T_world_root]
- T_world_link = T_world_root @ T_root_link
-
- contact_joint_positions_world = T_world_link.translation()[contact_indices]
-
- # Contact points are already in world frame (as processed in dexycb).
- # Calculate distances from each joint to its set of contact points
- # Shape contact_points: (J, P, 3), contact_joint_positions_world: (J, 3)
- # We want distance between joint J and points P for that joint.
- # residual: (J, P, 3)
- residual = contact_points - contact_joint_positions_world
-
- # Penalize distance beyond a margin.
- residual_penalty = jnp.maximum(
- jnp.abs(residual) - weights["contact_margin"], 0.0
- ) # (J, P, 3)
-
- # Apply mask.
- residual_penalty = (
- residual_penalty * contact_points_mask[..., None]
- ) # (J, P, 3)
- residual = residual_penalty.flatten() * weights["contact"]
-
- return residual
-
- costs = [
- # Costs that are relatively self-contained to the robot.
- retargeting_cost(
- var_Ts_world_root,
- var_joints,
- var_smpl_joints_scale,
- target_keypoints,
- ),
- scale_regularization(var_smpl_joints_scale),
- pk.costs.limit_cost(
- jax.tree.map(lambda x: x[None], robot),
- var_joints,
- 100.0,
- ),
- pk.costs.smoothness_cost(
- robot.joint_var_cls(jnp.arange(1, timesteps)),
- robot.joint_var_cls(jnp.arange(0, timesteps - 1)),
- jnp.array([weights["joint_smoothness"]]),
- ),
- pk.costs.rest_cost(
- var_joints,
- var_joints.default_factory()[None],
- jnp.array([0.2]),
- ),
- # Costs that are scene-centric.
- pc_alignment_cost(
- var_Ts_world_root,
- var_joints,
- target_keypoints,
- ),
- root_smoothness(
- jaxls.SE3Var(jnp.arange(1, timesteps)),
- jaxls.SE3Var(jnp.arange(0, timesteps - 1)),
- ),
- contact_cost(
- var_T_world_root=var_Ts_world_root,
- var_robot_cfg=var_joints,
- contact_points=contact_points_per_frame,
- contact_indices=contact_indices_per_frame,
- contact_points_mask=contact_mask,
- ),
- ]
-
- solution = (
- jaxls.LeastSquaresProblem(
- costs, [var_joints, var_Ts_world_root, var_smpl_joints_scale, var_offset]
- )
- .analyze()
- .solve()
- )
- transform = solution[var_Ts_world_root]
- offset = solution[var_offset]
- transform = jaxlie.SE3.from_translation(offset) @ transform
- return transform, solution[var_joints]
-
-
- if __name__ == "__main__":
- main()
diff --git a/docs/source/examples/12_humanoid_retargeting_fancy.rst b/docs/source/examples/12_humanoid_retargeting_fancy.rst
deleted file mode 100644
index f1929c4..0000000
--- a/docs/source/examples/12_humanoid_retargeting_fancy.rst
+++ /dev/null
@@ -1,519 +0,0 @@
-.. Comment: this file is automatically generated by `update_example_docs.py`.
- It should not be modified manually.
-
-Humanoid Retargeting (Fancy)
-==========================================
-
-
-Retarget motion to G1 humanoid, with scene contacts (keep feet close to contact
-points, while avoiding world-collisions).
-
-All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details.
-
-
-
-.. code-block:: python
- :linenos:
-
-
- import time
- from typing import Tuple, TypedDict
- from pathlib import Path
-
- import jax
- import jax.numpy as jnp
- import jax_dataclasses as jdc
- import jaxlie
- import jaxls
- import numpy as onp
- import pyronot as pk
- import viser
- from viser.extras import ViserUrdf
- from pyronot.collision import colldist_from_sdf, collide
- from robot_descriptions.loaders.yourdfpy import load_robot_description
-
- from retarget_helpers._utils import (
- SMPL_JOINT_NAMES,
- create_conn_tree,
- get_humanoid_retarget_indices,
- )
-
-
- class RetargetingWeights(TypedDict):
- local_alignment: float
- """Local alignment weight, by matching the relative joint/keypoint positions and angles."""
- global_alignment: float
- """Global alignment weight, by matching the keypoint positions to the robot."""
- floor_contact: float
- """Floor contact weight, to place the robot's foot on the floor."""
- root_smoothness: float
- """Root smoothness weight, to penalize the robot's root from jittering too much."""
- foot_skating: float
- """Foot skating weight, to penalize the robot's foot from moving when it is in contact with the floor."""
- world_collision: float
- """World collision weight, to penalize the robot from colliding with the world."""
-
-
- def main():
- """Main function for humanoid retargeting."""
-
- urdf = load_robot_description("g1_description")
- robot = pk.Robot.from_urdf(urdf)
- robot_coll = pk.collision.RobotCollision.from_urdf(urdf)
-
- # Load source motion data:
- # - keypoints [N, 45, 3],
- # - left/right foot contact (boolean) 2 x [N],
- # - heightmap [H, W].
- asset_dir = Path(__file__).parent / "retarget_helpers" / "humanoid"
- smpl_keypoints = onp.load(asset_dir / "smpl_keypoints.npy")
- is_left_foot_contact = onp.load(asset_dir / "left_foot_contact.npy")
- is_right_foot_contact = onp.load(asset_dir / "right_foot_contact.npy")
- heightmap = onp.load(asset_dir / "heightmap.npy")
-
- num_timesteps = smpl_keypoints.shape[0]
- assert smpl_keypoints.shape == (num_timesteps, 45, 3)
- assert is_left_foot_contact.shape == (num_timesteps,)
- assert is_right_foot_contact.shape == (num_timesteps,)
-
- heightmap = pk.collision.Heightmap(
- pose=jaxlie.SE3.identity(),
- size=jnp.array([0.01, 0.01, 1.0]),
- height_data=heightmap,
- )
-
- # Get the left and right foot keypoints, projected on the heightmap.
- left_foot_keypoint_idx = SMPL_JOINT_NAMES.index("left_foot")
- right_foot_keypoint_idx = SMPL_JOINT_NAMES.index("right_foot")
- left_foot_keypoints = smpl_keypoints[..., left_foot_keypoint_idx, :].reshape(-1, 3)
- right_foot_keypoints = smpl_keypoints[..., right_foot_keypoint_idx, :].reshape(
- -1, 3
- )
- left_foot_keypoints = heightmap.project_points(left_foot_keypoints)
- right_foot_keypoints = heightmap.project_points(right_foot_keypoints)
-
- smpl_joint_retarget_indices, g1_joint_retarget_indices = (
- get_humanoid_retarget_indices()
- )
- smpl_mask = create_conn_tree(robot, g1_joint_retarget_indices)
-
- server = viser.ViserServer()
- base_frame = server.scene.add_frame("/base", show_axes=False)
- urdf_vis = ViserUrdf(server, urdf, root_node_name="/base")
- playing = server.gui.add_checkbox("playing", True)
- timestep_slider = server.gui.add_slider("timestep", 0, num_timesteps - 1, 1, 0)
- server.scene.add_mesh_trimesh("/heightmap", heightmap.to_trimesh())
-
- weights = pk.viewer.WeightTuner(
- server,
- RetargetingWeights(
- local_alignment=2.0,
- global_alignment=1.0,
- floor_contact=1.0,
- root_smoothness=1.0,
- foot_skating=1.0,
- world_collision=1.0,
- ), # type: ignore
- )
-
- Ts_world_root, joints = None, None
-
- def generate_trajectory():
- nonlocal Ts_world_root, joints
- gen_button.disabled = True
- Ts_world_root, joints = solve_retargeting(
- robot=robot,
- robot_coll=robot_coll,
- target_keypoints=smpl_keypoints,
- is_left_foot_contact=is_left_foot_contact,
- is_right_foot_contact=is_right_foot_contact,
- left_foot_keypoints=left_foot_keypoints,
- right_foot_keypoints=right_foot_keypoints,
- smpl_joint_retarget_indices=smpl_joint_retarget_indices,
- g1_joint_retarget_indices=g1_joint_retarget_indices,
- smpl_mask=smpl_mask,
- heightmap=heightmap,
- weights=weights.get_weights(), # type: ignore
- )
- gen_button.disabled = False
-
- gen_button = server.gui.add_button("Retarget!")
- gen_button.on_click(lambda _: generate_trajectory())
-
- generate_trajectory()
- assert Ts_world_root is not None and joints is not None
-
- while True:
- with server.atomic():
- if playing.value:
- timestep_slider.value = (timestep_slider.value + 1) % num_timesteps
- tstep = timestep_slider.value
- base_frame.wxyz = onp.array(Ts_world_root.wxyz_xyz[tstep][:4])
- base_frame.position = onp.array(Ts_world_root.wxyz_xyz[tstep][4:])
- urdf_vis.update_cfg(onp.array(joints[tstep]))
- server.scene.add_point_cloud(
- "/target_keypoints",
- onp.array(smpl_keypoints[tstep]),
- onp.array((0, 0, 255))[None].repeat(45, axis=0),
- point_size=0.01,
- )
-
- time.sleep(0.05)
-
-
- @jdc.jit
- def solve_retargeting(
- robot: pk.Robot,
- robot_coll: pk.collision.RobotCollision,
- target_keypoints: jnp.ndarray,
- is_left_foot_contact: jnp.ndarray,
- is_right_foot_contact: jnp.ndarray,
- left_foot_keypoints: jnp.ndarray,
- right_foot_keypoints: jnp.ndarray,
- smpl_joint_retarget_indices: jnp.ndarray,
- g1_joint_retarget_indices: jnp.ndarray,
- smpl_mask: jnp.ndarray,
- heightmap: pk.collision.Heightmap,
- weights: RetargetingWeights,
- ) -> Tuple[jaxlie.SE3, jnp.ndarray]:
- """Solve the retargeting problem."""
-
- n_retarget = len(smpl_joint_retarget_indices)
- timesteps = target_keypoints.shape[0]
-
- # Robot properties.
- # - Joints that should move less for natural humanoid motion.
- joints_to_move_less = jnp.array(
- [
- robot.joints.actuated_names.index(name)
- for name in ["left_hip_yaw_joint", "right_hip_yaw_joint", "torso_joint"]
- ]
- )
- # - Foot indices.
- left_foot_idx = robot.links.names.index("left_ankle_roll_link")
- right_foot_idx = robot.links.names.index("right_ankle_roll_link")
-
- # Variables.
- class SmplJointsScaleVarG1(
- jaxls.Var[jax.Array], default_factory=lambda: jnp.ones((n_retarget, n_retarget))
- ): ...
-
- class OffsetVar(jaxls.Var[jax.Array], default_factory=lambda: jnp.zeros((3,))): ...
-
- var_joints = robot.joint_var_cls(jnp.arange(timesteps))
- var_Ts_world_root = jaxls.SE3Var(jnp.arange(timesteps))
- var_smpl_joints_scale = SmplJointsScaleVarG1(jnp.zeros(timesteps))
- var_offset = OffsetVar(jnp.zeros(timesteps))
-
- # Costs.
- costs: list[jaxls.Cost] = []
-
- @jaxls.Cost.create_factory
- def retargeting_cost(
- var_values: jaxls.VarValues,
- var_Ts_world_root: jaxls.SE3Var,
- var_robot_cfg: jaxls.Var[jnp.ndarray],
- var_smpl_joints_scale: SmplJointsScaleVarG1,
- keypoints: jnp.ndarray,
- ) -> jax.Array:
- """Retargeting factor, with a focus on:
- - matching the relative joint/keypoint positions (vectors).
- - and matching the relative angles between the vectors.
- """
- robot_cfg = var_values[var_robot_cfg]
- T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg))
- T_world_root = var_values[var_Ts_world_root]
- T_world_link = T_world_root @ T_root_link
-
- smpl_pos = keypoints[jnp.array(smpl_joint_retarget_indices)]
- robot_pos = T_world_link.translation()[jnp.array(g1_joint_retarget_indices)]
-
- # NxN grid of relative positions.
- delta_smpl = smpl_pos[:, None] - smpl_pos[None, :]
- delta_robot = robot_pos[:, None] - robot_pos[None, :]
-
- # Vector regularization.
- position_scale = var_values[var_smpl_joints_scale][..., None]
- residual_position_delta = (
- (delta_smpl - delta_robot * position_scale)
- * (1 - jnp.eye(delta_smpl.shape[0])[..., None])
- * smpl_mask[..., None]
- )
-
- # Vector angle regularization.
- delta_smpl_normalized = delta_smpl / jnp.linalg.norm(
- delta_smpl + 1e-6, axis=-1, keepdims=True
- )
- delta_robot_normalized = delta_robot / jnp.linalg.norm(
- delta_robot + 1e-6, axis=-1, keepdims=True
- )
- residual_angle_delta = 1 - (delta_smpl_normalized * delta_robot_normalized).sum(
- axis=-1
- )
- residual_angle_delta = (
- residual_angle_delta
- * (1 - jnp.eye(residual_angle_delta.shape[0]))
- * smpl_mask
- )
-
- residual = (
- jnp.concatenate(
- [residual_position_delta.flatten(), residual_angle_delta.flatten()]
- )
- * weights["local_alignment"]
- )
- return residual
-
- @jaxls.Cost.create_factory
- def scale_regularization(
- var_values: jaxls.VarValues,
- var_smpl_joints_scale: SmplJointsScaleVarG1,
- ) -> jax.Array:
- """Regularize the scale of the retargeted joints."""
- # Close to 1.
- res_0 = (var_values[var_smpl_joints_scale] - 1.0).flatten() * 1.0
- # Symmetric.
- res_1 = (
- var_values[var_smpl_joints_scale] - var_values[var_smpl_joints_scale].T
- ).flatten() * 100.0
- # Non-negative.
- res_2 = jnp.clip(-var_values[var_smpl_joints_scale], min=0).flatten() * 100.0
- return jnp.concatenate([res_0, res_1, res_2])
-
- @jaxls.Cost.create_factory
- def pc_alignment_cost(
- var_values: jaxls.VarValues,
- var_Ts_world_root: jaxls.SE3Var,
- var_robot_cfg: jaxls.Var[jnp.ndarray],
- keypoints: jnp.ndarray,
- ) -> jax.Array:
- """Soft cost to align the human keypoints to the robot, in the world frame."""
- T_world_root = var_values[var_Ts_world_root]
- robot_cfg = var_values[var_robot_cfg]
- T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg))
- T_world_link = T_world_root @ T_root_link
- link_pos = T_world_link.translation()[g1_joint_retarget_indices]
- keypoint_pos = keypoints[smpl_joint_retarget_indices]
- return (link_pos - keypoint_pos).flatten() * weights["global_alignment"]
-
- @jaxls.Cost.create_factory
- def floor_contact_cost(
- var_values: jaxls.VarValues,
- var_Ts_world_root: jaxls.SE3Var,
- var_robot_cfg: jaxls.Var[jnp.ndarray],
- var_offset: OffsetVar,
- is_left_foot_contact: jnp.ndarray,
- is_right_foot_contact: jnp.ndarray,
- left_foot_keypoints: jnp.ndarray,
- right_foot_keypoints: jnp.ndarray,
- ) -> jax.Array:
- """Cost to place the robot on the floor:
- - match foot keypoint positions, and
- - penalize the foot from tilting too much.
- """
- T_world_root = var_values[var_Ts_world_root]
- T_root_link = jaxlie.SE3(
- robot.forward_kinematics(cfg=var_values[var_robot_cfg])
- )
-
- offset = var_values[var_offset]
- left_foot_pos = (T_world_root @ T_root_link).translation()[
- left_foot_idx
- ] + offset
- right_foot_pos = (T_world_root @ T_root_link).translation()[
- right_foot_idx
- ] + offset
- left_foot_contact_cost = (
- is_left_foot_contact * (left_foot_pos - left_foot_keypoints) ** 2
- )
- right_foot_contact_cost = (
- is_right_foot_contact * (right_foot_pos - right_foot_keypoints) ** 2
- )
-
- # Also penalize the foot from tilting too much -- keep z axis up!
- left_foot_ori = (
- (T_world_root @ T_root_link).rotation().as_matrix()[left_foot_idx]
- )
- right_foot_ori = (
- (T_world_root @ T_root_link).rotation().as_matrix()[right_foot_idx]
- )
- left_foot_contact_residual_rot = jnp.where(
- is_left_foot_contact,
- left_foot_ori[2, 2] - 1,
- 0.0,
- )
- right_foot_contact_residual_rot = jnp.where(
- is_right_foot_contact,
- right_foot_ori[2, 2] - 1,
- 0.0,
- )
-
- return (
- jnp.concatenate(
- [
- left_foot_contact_cost.flatten(),
- right_foot_contact_cost.flatten(),
- left_foot_contact_residual_rot.flatten(),
- right_foot_contact_residual_rot.flatten(),
- ]
- )
- * weights["floor_contact"]
- )
-
- @jaxls.Cost.create_factory
- def root_smoothness(
- var_values: jaxls.VarValues,
- var_Ts_world_root: jaxls.SE3Var,
- var_Ts_world_root_prev: jaxls.SE3Var,
- ) -> jax.Array:
- """Smoothness cost for the robot root pose."""
- return (
- var_values[var_Ts_world_root].inverse() @ var_values[var_Ts_world_root_prev]
- ).log().flatten() * weights["root_smoothness"]
-
- @jaxls.Cost.create_factory
- def skating_cost(
- var_values: jaxls.VarValues,
- var_Ts_world_root: jaxls.SE3Var,
- var_robot_cfg: jaxls.Var[jnp.ndarray],
- var_offset: OffsetVar,
- var_Ts_world_root_prev: jaxls.SE3Var,
- var_robot_cfg_prev: jaxls.Var[jnp.ndarray],
- var_offset_prev: OffsetVar,
- is_left_foot_contact: jnp.ndarray,
- is_right_foot_contact: jnp.ndarray,
- ) -> jax.Array:
- """Cost to penalize the robot for skating."""
- T_world_root = var_values[var_Ts_world_root]
- robot_cfg = var_values[var_robot_cfg]
- T_root_link = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg))
- offset = var_values[var_offset]
- T_link = T_world_root @ T_root_link
- left_foot_pos = T_link.translation()[left_foot_idx] + offset
- right_foot_pos = T_link.translation()[right_foot_idx] + offset
-
- T_world_root_prev = var_values[var_Ts_world_root_prev]
- robot_cfg_prev = var_values[var_robot_cfg_prev]
- T_root_link_prev = jaxlie.SE3(robot.forward_kinematics(cfg=robot_cfg_prev))
- offset_prev = var_values[var_offset_prev]
- T_link_prev = T_world_root_prev @ T_root_link_prev
- left_foot_pos_prev = T_link_prev.translation()[left_foot_idx] + offset_prev
- right_foot_pos_prev = T_link_prev.translation()[right_foot_idx] + offset_prev
-
- skating_cost_left = is_left_foot_contact * (left_foot_pos - left_foot_pos_prev)
- skating_cost_right = is_right_foot_contact * (
- right_foot_pos - right_foot_pos_prev
- )
-
- return (
- jnp.stack([skating_cost_left, skating_cost_right]) * weights["foot_skating"]
- )
-
- @jaxls.Cost.create_factory
- def world_collision_cost(
- var_values: jaxls.VarValues,
- var_Ts_world_root: jaxls.SE3Var,
- var_robot_cfg: jaxls.Var[jnp.ndarray],
- var_offset: OffsetVar,
- ) -> jax.Array:
- """
- World collision; we intentionally use a low weight --
- high enough to lift the robot up from the ground, but
- low enough to not interfere with the retargeting.
- """
- Ts_world_root = var_values[var_Ts_world_root]
- T_offset = jaxlie.SE3.from_translation(var_values[var_offset])
- transform = T_offset @ Ts_world_root
-
- robot_cfg = var_values[var_robot_cfg]
- coll = robot_coll.at_config(robot, robot_cfg)
- coll = coll.transform(transform)
-
- dist = collide(coll, heightmap)
- act = colldist_from_sdf(dist, activation_dist=0.005)
- return act.flatten() * weights["world_collision"]
-
- costs = [
- # Costs that are relatively self-contained to the robot.
- retargeting_cost(
- var_Ts_world_root,
- var_joints,
- var_smpl_joints_scale,
- target_keypoints,
- ),
- scale_regularization(var_smpl_joints_scale),
- pk.costs.limit_cost(
- jax.tree.map(lambda x: x[None], robot),
- var_joints,
- 100.0,
- ),
- pk.costs.smoothness_cost(
- robot.joint_var_cls(jnp.arange(1, timesteps)),
- robot.joint_var_cls(jnp.arange(0, timesteps - 1)),
- jnp.array([0.2]),
- ),
- pk.costs.rest_cost(
- var_joints,
- var_joints.default_factory()[None],
- jnp.full(var_joints.default_factory().shape, 0.2)
- .at[joints_to_move_less]
- .set(2.0)[None],
- ),
- pk.costs.self_collision_cost(
- jax.tree.map(lambda x: x[None], robot),
- jax.tree.map(lambda x: x[None], robot_coll),
- var_joints,
- margin=0.05,
- weight=2.0,
- ),
- # Costs that are scene-centric.
- pc_alignment_cost(
- var_Ts_world_root,
- var_joints,
- target_keypoints,
- ),
- floor_contact_cost(
- var_Ts_world_root,
- var_joints,
- var_offset,
- is_left_foot_contact,
- is_right_foot_contact,
- left_foot_keypoints,
- right_foot_keypoints,
- ),
- root_smoothness(
- jaxls.SE3Var(jnp.arange(1, timesteps)),
- jaxls.SE3Var(jnp.arange(0, timesteps - 1)),
- ),
- skating_cost(
- jaxls.SE3Var(jnp.arange(1, timesteps)),
- robot.joint_var_cls(jnp.arange(1, timesteps)),
- OffsetVar(jnp.arange(1, timesteps)),
- jaxls.SE3Var(jnp.arange(0, timesteps - 1)),
- robot.joint_var_cls(jnp.arange(0, timesteps - 1)),
- OffsetVar(jnp.arange(0, timesteps - 1)),
- is_left_foot_contact[:-1],
- is_right_foot_contact[:-1],
- ),
- world_collision_cost(
- var_Ts_world_root,
- var_joints,
- var_offset,
- ),
- ]
-
- solution = (
- jaxls.LeastSquaresProblem(
- costs, [var_joints, var_Ts_world_root, var_smpl_joints_scale, var_offset]
- )
- .analyze()
- .solve()
- )
- transform = solution[var_Ts_world_root]
- offset = solution[var_offset]
- transform = jaxlie.SE3.from_translation(offset) @ transform
- return transform, solution[var_joints]
-
-
- if __name__ == "__main__":
- main()
diff --git a/docs/source/examples/13_spherized_robot_ik.rst b/docs/source/examples/13_spherized_robot_ik.rst
deleted file mode 100644
index e592167..0000000
--- a/docs/source/examples/13_spherized_robot_ik.rst
+++ /dev/null
@@ -1,80 +0,0 @@
-.. Comment: this file is automatically generated by `update_example_docs.py`.
- It should not be modified manually.
-
-Spherized Robot IK
-==========================================
-
-
-All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details.
-
-
-
-.. code-block:: python
- :linenos:
-
-
- import time
-
- import numpy as np
- import pyronot as pk
- import viser
- from viser.extras import ViserUrdf
-
- import pyronot_snippets as pks
- import yourdfpy
-
- def main():
-
- # Load the spherized panda urdf do not work!!!
- urdf_path = "resources/ur5/ur5_spherized.urdf"
- mesh_dir = "resources/ur5/meshes"
- target_link_name = "tool0"
-
- # urdf_path = "resources/panda/panda_spherized.urdf"
- # mesh_dir = "resources/panda/meshes"
- # target_link_name = "panda_hand"
-
- urdf = yourdfpy.URDF.load(urdf_path, mesh_dir=mesh_dir)
-
- # urdf = load_robot_description("ur5_description")
- # target_link_name = "ee_link"
-
- # Create robot.
- robot = pk.Robot.from_urdf(urdf)
- robot_coll = pk.collision.RobotCollisionSpherized.from_urdf(urdf)
- # robot_coll = pk.collision.RobotCollision.from_urdf(urdf)
-
- # Set up visualizer.
- server = viser.ViserServer()
- server.scene.add_grid("/ground", width=2, height=2)
- urdf_vis = ViserUrdf(server, urdf, root_node_name="/base")
-
- # Add the collision mesh to the visualizer.
-
- # Create interactive controller with initial position.
- ik_target = server.scene.add_transform_controls(
- "/ik_target", scale=0.2, position=(0.61, 0.0, 0.56), wxyz=(0, 0, 1, 0)
- )
- timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True)
-
- while True:
- # Solve IK.
- start_time = time.time()
- solution = pks.solve_ik(
- robot=robot,
- target_link_name=target_link_name,
- target_position=np.array(ik_target.position),
- target_wxyz=np.array(ik_target.wxyz),
- )
-
- # Update timing handle.
- elapsed_time = time.time() - start_time
- timing_handle.value = 0.99 * timing_handle.value + 0.01 * (elapsed_time * 1000)
-
- # Update visualizer.
- urdf_vis.update_cfg(solution)
- # Update the collision mesh.
- robot_coll_mesh = robot_coll.at_config(robot, solution).to_trimesh()
- server.scene.add_mesh_trimesh("/robot/collision", mesh=robot_coll_mesh)
- if __name__ == "__main__":
- main()
diff --git a/docs/source/examples/14_spherized_ik_with_coll.rst b/docs/source/examples/14_spherized_ik_with_coll.rst
deleted file mode 100644
index 396d446..0000000
--- a/docs/source/examples/14_spherized_ik_with_coll.rst
+++ /dev/null
@@ -1,95 +0,0 @@
-.. Comment: this file is automatically generated by `update_example_docs.py`.
- It should not be modified manually.
-
-IK with Collision
-==========================================
-
-
-Basic Inverse Kinematics with Collision Avoidance using PyRoNot.
-
-All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details.
-
-
-
-.. code-block:: python
- :linenos:
-
-
- import time
-
- import numpy as np
- import pyronot as pk
- import viser
- from pyronot.collision import HalfSpace, RobotCollision, RobotCollisionSpherized, Sphere
- from robot_descriptions.loaders.yourdfpy import load_robot_description
- from viser.extras import ViserUrdf
-
- import pyronot_snippets as pks
- import yourdfpy
-
-
- def main():
- """Main function for basic IK with collision."""
- urdf_path = "resources/ur5/ur5_spherized.urdf"
- mesh_dir = "resources/ur5/meshes"
- target_link_name = "tool0"
-
- # urdf_path = "resources/panda/panda_spherized.urdf"
- # mesh_dir = "resources/panda/meshes"
- # target_link_name = "panda_hand"
- urdf = yourdfpy.URDF.load(urdf_path, mesh_dir=mesh_dir)
- robot = pk.Robot.from_urdf(urdf)
-
- robot_coll = RobotCollisionSpherized.from_urdf(urdf)
- plane_coll = HalfSpace.from_point_and_normal(
- np.array([0.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0])
- )
- sphere_coll = Sphere.from_center_and_radius(
- np.array([0.0, 0.0, 0.0]), np.array([0.05])
- )
-
- # Set up visualizer.
- server = viser.ViserServer()
- server.scene.add_grid("/ground", width=2, height=2, cell_size=0.1)
- urdf_vis = ViserUrdf(server, urdf, root_node_name="/robot")
-
- # Create interactive controller for IK target.
- ik_target_handle = server.scene.add_transform_controls(
- "/ik_target", scale=0.2, position=(0.5, 0.0, 0.5), wxyz=(0, 0, 1, 0)
- )
-
- # Create interactive controller and mesh for the sphere obstacle.
- sphere_handle = server.scene.add_transform_controls(
- "/obstacle", scale=0.2, position=(0.4, 0.3, 0.4)
- )
- server.scene.add_mesh_trimesh("/obstacle/mesh", mesh=sphere_coll.to_trimesh())
-
- timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True)
-
- while True:
- start_time = time.time()
-
- sphere_coll_world_current = sphere_coll.transform_from_wxyz_position(
- wxyz=np.array(sphere_handle.wxyz),
- position=np.array(sphere_handle.position),
- )
-
- world_coll_list = [plane_coll, sphere_coll_world_current]
- solution = pks.solve_ik_with_collision(
- robot=robot,
- coll=robot_coll,
- world_coll_list=world_coll_list,
- target_link_name=target_link_name,
- target_position=np.array(ik_target_handle.position),
- target_wxyz=np.array(ik_target_handle.wxyz),
- )
-
- # Update timing handle.
- timing_handle.value = (time.time() - start_time) * 1000
-
- # Update visualizer.
- urdf_vis.update_cfg(solution)
-
-
- if __name__ == "__main__":
- main()
diff --git a/docs/source/examples/15_spherized_ik_then_coll.rst b/docs/source/examples/15_spherized_ik_then_coll.rst
deleted file mode 100644
index ae43390..0000000
--- a/docs/source/examples/15_spherized_ik_then_coll.rst
+++ /dev/null
@@ -1,101 +0,0 @@
-.. Comment: this file is automatically generated by `update_example_docs.py`.
- It should not be modified manually.
-
-IK with Collision
-==========================================
-
-
-Basic Inverse Kinematics with Collision Avoidance using PyRoNot.
-
-All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details.
-
-
-
-.. code-block:: python
- :linenos:
-
-
- import time
-
- import numpy as np
- import pyronot as pk
- import viser
- from pyronot.collision import HalfSpace, RobotCollision, RobotCollisionSpherized, Sphere
- from robot_descriptions.loaders.yourdfpy import load_robot_description
- from viser.extras import ViserUrdf
-
- import pyronot_snippets as pks
- import yourdfpy
-
-
- def main():
- """Main function for basic IK with collision."""
- urdf_path = "resources/ur5/ur5_spherized.urdf"
- mesh_dir = "resources/ur5/meshes"
- target_link_name = "tool0"
-
- # urdf_path = "resources/panda/panda_spherized.urdf"
- # mesh_dir = "resources/panda/meshes"
- # target_link_name = "panda_hand"
- urdf = yourdfpy.URDF.load(urdf_path, mesh_dir=mesh_dir)
- robot = pk.Robot.from_urdf(urdf)
-
- robot_coll = RobotCollisionSpherized.from_urdf(urdf)
- plane_coll = HalfSpace.from_point_and_normal(
- np.array([0.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0])
- )
- sphere_coll = Sphere.from_center_and_radius(
- np.array([0.0, 0.0, 0.0]), np.array([0.05])
- )
-
- # Set up visualizer.
- server = viser.ViserServer()
- server.scene.add_grid("/ground", width=2, height=2, cell_size=0.1)
- urdf_vis = ViserUrdf(server, urdf, root_node_name="/robot")
-
- # Create interactive controller for IK target.
- ik_target_handle = server.scene.add_transform_controls(
- "/ik_target", scale=0.2, position=(0.5, 0.0, 0.5), wxyz=(0, 0, 1, 0)
- )
-
- # Create interactive controller and mesh for the sphere obstacle.
- sphere_handle = server.scene.add_transform_controls(
- "/obstacle", scale=0.2, position=(0.4, 0.3, 0.4)
- )
- server.scene.add_mesh_trimesh("/obstacle/mesh", mesh=sphere_coll.to_trimesh())
-
- just_ik_timing_handle = server.gui.add_number("just ik (ms)", 0.001, disabled=True)
- coll_ik_timing_handle = server.gui.add_number("coll ik (ms)", 0.001, disabled=True)
- while True:
-
- sphere_coll_world_current = sphere_coll.transform_from_wxyz_position(
- wxyz=np.array(sphere_handle.wxyz),
- position=np.array(sphere_handle.position),
- )
- start_time = time.time()
- just_ik = pks.solve_ik(
- robot=robot,
- target_link_name=target_link_name,
- target_position=np.array(ik_target_handle.position),
- target_wxyz=np.array(ik_target_handle.wxyz),
- )
- just_ik_timing_handle.value = (time.time() - start_time) * 1000
-
- world_coll_list = [plane_coll, sphere_coll_world_current]
- start_time = time.time()
- solution = pks.solve_collision_with_config(
- robot=robot,
- coll=robot_coll,
- world_coll_list=world_coll_list,
- cfg=just_ik,
- )
- coll_ik_timing_handle.value = (time.time() - start_time) * 1000
-
- # Update timing handle.
-
- # Update visualizer.
- urdf_vis.update_cfg(solution)
-
-
- if __name__ == "__main__":
- main()
diff --git a/docs/source/examples/16_spherized_online_planning.rst b/docs/source/examples/16_spherized_online_planning.rst
deleted file mode 100644
index 813a845..0000000
--- a/docs/source/examples/16_spherized_online_planning.rst
+++ /dev/null
@@ -1,125 +0,0 @@
-.. Comment: this file is automatically generated by `update_example_docs.py`.
- It should not be modified manually.
-
-Online Planning
-==========================================
-
-
-Run online planning in collision aware environments.
-
-All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details.
-
-
-
-.. code-block:: python
- :linenos:
-
-
- import time
-
- import numpy as np
- import pyronot as pk
- import viser
- from pyronot.collision import HalfSpace, RobotCollision, RobotCollisionSpherized, Sphere
- from robot_descriptions.loaders.yourdfpy import load_robot_description
- from viser.extras import ViserUrdf
-
- import pyronot_snippets as pks
- import yourdfpy
-
-
- def main():
- """Main function for online planning with collision."""
- # urdf_path = "resources/ur5/ur5_spherized.urdf"
- # mesh_dir = "resources/ur5/meshes"
- # target_link_name = "robotiq_85_tool_link"
- urdf_path = "resources/panda/panda_spherized.urdf"
- mesh_dir = "resources/panda/meshes"
- target_link_name = "panda_hand"
- urdf = yourdfpy.URDF.load(urdf_path, mesh_dir=mesh_dir)
- robot = pk.Robot.from_urdf(urdf)
-
- robot_coll = RobotCollisionSpherized.from_urdf(urdf)
- plane_coll = HalfSpace.from_point_and_normal(
- np.array([0.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0])
- )
- sphere_coll = Sphere.from_center_and_radius(
- np.array([0.0, 0.0, 0.0]), np.array([0.05])
- )
-
- # Define the online planning parameters.
- len_traj, dt = 5, 0.1
-
- # Set up visualizer.
- server = viser.ViserServer()
- server.scene.add_grid("/ground", width=2, height=2, cell_size=0.1)
- urdf_vis = ViserUrdf(server, urdf, root_node_name="/robot")
-
- # Create interactive controller for IK target.
- ik_target_handle = server.scene.add_transform_controls(
- "/ik_target", scale=0.2, position=(0.3, 0.0, 0.5), wxyz=(0, 0, 1, 0)
- )
-
- # Create interactive controller and mesh for the sphere obstacle.
- sphere_handle = server.scene.add_transform_controls(
- "/obstacle", scale=0.2, position=(0.4, 0.3, 0.4)
- )
- server.scene.add_mesh_trimesh("/obstacle/mesh", mesh=sphere_coll.to_trimesh())
- target_frame_handle = server.scene.add_batched_axes(
- "target_frame",
- axes_length=0.05,
- axes_radius=0.005,
- batched_positions=np.zeros((25, 3)),
- batched_wxyzs=np.array([[1.0, 0.0, 0.0, 0.0]] * 25),
- )
-
- timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True)
-
- sol_pos, sol_wxyz = None, None
- sol_traj = np.array(
- robot.joint_var_cls.default_factory()[None].repeat(len_traj, axis=0)
- )
- while True:
- start_time = time.time()
-
- sphere_coll_world_current = sphere_coll.transform_from_wxyz_position(
- wxyz=np.array(sphere_handle.wxyz),
- position=np.array(sphere_handle.position),
- )
-
- world_coll_list = [plane_coll, sphere_coll_world_current]
- sol_traj, sol_pos, sol_wxyz = pks.solve_online_planning(
- robot=robot,
- robot_coll=robot_coll,
- world_coll=world_coll_list,
- target_link_name=target_link_name,
- target_position=np.array(ik_target_handle.position),
- target_wxyz=np.array(ik_target_handle.wxyz),
- timesteps=len_traj,
- dt=dt,
- start_cfg=sol_traj[0],
- prev_sols=sol_traj,
- )
-
- # Update timing handle.
- timing_handle.value = (
- 0.99 * timing_handle.value + 0.01 * (time.time() - start_time) * 1000
- )
-
- # Update visualizer.
- urdf_vis.update_cfg(
- sol_traj[0]
- ) # The first step of the online trajectory solution.
-
- # Update the planned trajectory visualization.
- if hasattr(target_frame_handle, "batched_positions"):
- target_frame_handle.batched_positions = np.array(sol_pos) # type: ignore[attr-defined]
- target_frame_handle.batched_wxyzs = np.array(sol_wxyz) # type: ignore[attr-defined]
- else:
- # This is an older version of Viser.
- target_frame_handle.positions_batched = np.array(sol_pos) # type: ignore[attr-defined]
- target_frame_handle.wxyzs_batched = np.array(sol_wxyz) # type: ignore[attr-defined]
-
-
- if __name__ == "__main__":
- main()
diff --git a/docs/source/examples/17_geometries_example.rst b/docs/source/examples/17_geometries_example.rst
deleted file mode 100644
index cbb852f..0000000
--- a/docs/source/examples/17_geometries_example.rst
+++ /dev/null
@@ -1,114 +0,0 @@
-.. Comment: this file is automatically generated by `update_example_docs.py`.
- It should not be modified manually.
-
-Visualize Box and its six HalfSpace faces in viser.
-==========================================
-
-
-Run this example and use the transform controls to move/rotate the box
-and the GUI numbers to change length/width/height interactively.
-
-All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details.
-
-
-
-.. code-block:: python
- :linenos:
-
-
- import time
-
- import numpy as np
- import pyronot as pk
- import viser
- import trimesh
- from pyronot.collision._geometry import Box, Sphere, Capsule
- from pyronot.collision import collide
-
-
- def main():
- """Start viser and visualize a Box and its HalfSpace faces."""
-
- # Initial box parameters
- center = np.array([0.0, 0.0, 0.5])
- length, width, height = 0.1, 0.1, 0.1
-
- server = viser.ViserServer()
- server.scene.add_grid("/ground", width=2, height=2, cell_size=0.1)
-
- box_handle = server.scene.add_transform_controls(
- "/box_handle", scale=0.2, position=tuple(center), wxyz=(0, 0, 1, 0)
- )
-
- # Add transform controls for a sphere and a capsule
- sphere_handle = server.scene.add_transform_controls(
- "/sphere_handle", scale=0.15, position=(0.3, 0.0, 0.5), wxyz=(0, 0, 1, 0)
- )
- sphere_radius_handle = server.gui.add_number("Sphere Radius", 0.05)
-
- cap_handle = server.scene.add_transform_controls(
- "/cap_handle", scale=0.2, position=(-0.3, 0.0, 0.5), wxyz=(0, 0, 1, 0)
- )
- cap_radius_handle = server.gui.add_number("Capsule Radius", 0.03)
- cap_height_handle = server.gui.add_number("Capsule Height", 0.2)
-
- length_handle = server.gui.add_number("Length", length)
- width_handle = server.gui.add_number("Width", width)
- height_handle = server.gui.add_number("Height", height)
-
- server.scene.add_mesh_trimesh(
- "/box/mesh", mesh=Box.from_center_and_dimensions(center, length, width, height).to_trimesh()
- )
-
- server.scene.add_mesh_trimesh("/box/polytope", mesh=trimesh.Trimesh())
-
- while True:
- pos = np.array(box_handle.position)
- wxyz = np.array(box_handle.wxyz)
- length = float(length_handle.value) if hasattr(length_handle, "value") else float(length_handle)
- width = float(width_handle.value) if hasattr(width_handle, "value") else float(width_handle)
- height = float(height_handle.value) if hasattr(height_handle, "value") else float(height_handle)
-
- box = Box.from_center_and_dimensions(center=pos, length=length, width=width, height=height, wxyz=wxyz)
-
- # Sphere
- sph_pos = np.array(sphere_handle.position)
- sph_wxyz = np.array(sphere_handle.wxyz)
- sph_rad = float(sphere_radius_handle.value) if hasattr(sphere_radius_handle, "value") else float(sphere_radius_handle)
- sphere = Sphere.from_center_and_radius(center=sph_pos, radius=sph_rad)
-
- # Capsule
- cap_pos = np.array(cap_handle.position)
- cap_wxyz = np.array(cap_handle.wxyz)
- cap_rad = float(cap_radius_handle.value) if hasattr(cap_radius_handle, "value") else float(cap_radius_handle)
- cap_h = float(cap_height_handle.value) if hasattr(cap_height_handle, "value") else float(cap_height_handle)
- capsule = Capsule.from_radius_height(radius=cap_rad, height=cap_h, position=cap_pos, wxyz=cap_wxyz)
-
- server.scene.add_mesh_trimesh("/box/mesh", mesh=box.to_trimesh())
- server.scene.add_mesh_trimesh("/sphere/mesh", mesh=sphere.to_trimesh())
- server.scene.add_mesh_trimesh("/cap/mesh", mesh=capsule.to_trimesh())
-
- poly_mesh = box.to_trimesh()
- server.scene.add_mesh_trimesh("/box/polytope", mesh=poly_mesh)
-
- # Collision checks between all unique pairs
- pairs = [
- ("Box", box, "Sphere", sphere),
- ("Box", box, "Capsule", capsule),
- ("Sphere", sphere, "Capsule", capsule),
- ]
- for name1, g1, name2, g2 in pairs:
- try:
- d = collide(g1, g2)
- # d is a jax Array; convert to python float if scalar
- d_val = float(d)
- if d_val < 0.0:
- print(f"Collision detected {name1} vs {name2}: distance={d_val:.6f}")
- except Exception as e:
- print(f"Error computing collision {name1} vs {name2}: {e}")
-
- time.sleep(1.0 / 60.0)
-
-
- if __name__ == "__main__":
- main()
diff --git a/docs/source/examples/18_spherized_trajopt.rst b/docs/source/examples/18_spherized_trajopt.rst
deleted file mode 100644
index 0e79fdc..0000000
--- a/docs/source/examples/18_spherized_trajopt.rst
+++ /dev/null
@@ -1,133 +0,0 @@
-.. Comment: this file is automatically generated by `update_example_docs.py`.
- It should not be modified manually.
-
-Trajectory Optimization
-==========================================
-
-
-Basic Trajectory Optimization using PyRoNot.
-
-Robot going over a wall, while avoiding world-collisions.
-
-All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details.
-
-
-
-.. code-block:: python
- :linenos:
-
-
- import time
- from typing import Literal
-
- import numpy as np
- import pyronot as pk
- import trimesh
- import tyro
- import viser
- from viser.extras import ViserUrdf
- from robot_descriptions.loaders.yourdfpy import load_robot_description
-
- import pyronot_snippets as pks
- import yourdfpy
-
- def main(robot_name: Literal["ur5", "panda"] = "panda"):
- if robot_name == "ur5":
- raise ValueError("UR5 not supported yet")
-
- elif robot_name == "panda":
- urdf_path = "resources/panda/panda_spherized.urdf"
- mesh_dir = "resources/panda/meshes"
- target_link_name = "panda_hand"
- urdf = yourdfpy.URDF.load(urdf_path, mesh_dir=mesh_dir)
-
- down_wxyz = np.array([0, 0, 1, 0]) # for panda!
- robot = pk.Robot.from_urdf(urdf)
-
- else:
- raise ValueError(f"Invalid robot: {robot_name}")
-
- robot_coll = pk.collision.RobotCollisionSpherized.from_urdf(urdf)
-
- # Define the trajectory problem:
- # - number of timesteps, timestep size
- timesteps, dt = 25, 0.02
- # - the start and end poses.
- start_pos, end_pos = np.array([0.5, -0.3, 0.2]), np.array([0.5, 0.3, 0.2])
-
- # Define the obstacles:
- # - Ground
- ground_coll = pk.collision.HalfSpace.from_point_and_normal(
- np.array([0.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0])
- )
- # - Wall
- wall_height = 0.4
- wall_width = 0.1
- wall_length = 0.4
- wall_intervals = np.arange(start=0.3, stop=wall_length + 0.3, step=0.05)
- translation = np.concatenate(
- [
- wall_intervals.reshape(-1, 1),
- np.full((wall_intervals.shape[0], 1), 0.0),
- np.full((wall_intervals.shape[0], 1), wall_height / 2),
- ],
- axis=1,
- )
- wall_coll = pk.collision.Capsule.from_radius_height(
- position=translation,
- radius=np.full((translation.shape[0], 1), wall_width / 2),
- height=np.full((translation.shape[0], 1), wall_height),
- )
- world_coll = [ground_coll, wall_coll]
-
- traj = pks.solve_trajopt(
- robot,
- robot_coll,
- world_coll,
- target_link_name,
- start_pos,
- down_wxyz,
- end_pos,
- down_wxyz,
- timesteps,
- dt,
- )
- traj = np.array(traj)
-
- # Visualize!
- server = viser.ViserServer()
- urdf_vis = ViserUrdf(server, urdf)
- server.scene.add_grid("/grid", width=2, height=2, cell_size=0.1)
- server.scene.add_mesh_trimesh(
- "wall_box",
- trimesh.creation.box(
- extents=(wall_length, wall_width, wall_height),
- transform=trimesh.transformations.translation_matrix(
- np.array([0.5, 0.0, wall_height / 2])
- ),
- ),
- )
- for name, pos in zip(["start", "end"], [start_pos, end_pos]):
- server.scene.add_frame(
- f"/{name}",
- position=pos,
- wxyz=down_wxyz,
- axes_length=0.05,
- axes_radius=0.01,
- )
-
- slider = server.gui.add_slider(
- "Timestep", min=0, max=timesteps - 1, step=1, initial_value=0
- )
- playing = server.gui.add_checkbox("Playing", initial_value=True)
-
- while True:
- if playing.value:
- slider.value = (slider.value + 1) % timesteps
-
- urdf_vis.update_cfg(traj[slider.value])
- time.sleep(1.0 / 10.0)
-
-
- if __name__ == "__main__":
- tyro.cli(main)
diff --git a/docs/source/examples/19_spherized_ik_with_coll_exclude_link.rst b/docs/source/examples/19_spherized_ik_with_coll_exclude_link.rst
deleted file mode 100644
index 3c214a7..0000000
--- a/docs/source/examples/19_spherized_ik_with_coll_exclude_link.rst
+++ /dev/null
@@ -1,123 +0,0 @@
-.. Comment: this file is automatically generated by `update_example_docs.py`.
- It should not be modified manually.
-
-IK with Collision
-==========================================
-
-
-Basic Inverse Kinematics with Collision Avoidance using PyRoNot.
-
-All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details.
-
-
-
-.. code-block:: python
- :linenos:
-
-
- import time
-
- import numpy as np
- import pyronot as pk
- import viser
- from pyronot.collision import HalfSpace, RobotCollision, RobotCollisionSpherized, Sphere
- from robot_descriptions.loaders.yourdfpy import load_robot_description
- from viser.extras import ViserUrdf
-
- import pyronot_snippets as pks
- import yourdfpy
-
-
- def main():
- """Main function for basic IK with collision."""
- urdf_path = "resources/ur5/ur5_spherized.urdf"
- mesh_dir = "resources/ur5/meshes"
- target_link_name = "robotiq_85_tool_link"
-
- # urdf_path = "resources/panda/panda_spherized.urdf"
- # mesh_dir = "resources/panda/meshes"
- # target_link_name = "panda_hand"
- urdf = yourdfpy.URDF.load(urdf_path, mesh_dir=mesh_dir)
- robot = pk.Robot.from_urdf(urdf, default_joint_cfg=[0, -1.57, 1.57, -1.57, -1.57, 0])
-
- robot_coll = RobotCollisionSpherized.from_urdf(urdf)
- plane_coll = HalfSpace.from_point_and_normal(
- np.array([0.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0])
- )
- sphere_coll = Sphere.from_center_and_radius(
- np.array([0.0, 0.0, 0.0]), np.array([0.05])
- )
-
- # Set up visualizer.
- server = viser.ViserServer()
- server.scene.add_grid("/ground", width=2, height=2, cell_size=0.1)
- urdf_vis = ViserUrdf(server, urdf, root_node_name="/robot")
-
- # Create interactive controller for IK target.
- ik_target_handle = server.scene.add_transform_controls(
- "/ik_target", scale=0.2, position=(0.0, 0.6, 0.2), wxyz=(0, 0, 1, 0)
- )
-
- # Create interactive controller and mesh for the sphere obstacle.
- sphere_handle = server.scene.add_transform_controls(
- "/obstacle", scale=0.2, position=(0.4, 0.3, 0.4)
- )
- server.scene.add_mesh_trimesh("/obstacle/mesh", mesh=sphere_coll.to_trimesh())
-
- timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True)
- exclude_links_from_cc = ["offset_link", "base_link"]
- exclude_link_mask = robot.links.get_link_mask_from_names(exclude_links_from_cc)
- print(exclude_link_mask)
- while True:
- start_time = time.time()
-
- sphere_coll_world_current = sphere_coll.transform_from_wxyz_position(
- wxyz=np.array(sphere_handle.wxyz),
- position=np.array(sphere_handle.position),
- )
-
- world_coll_list = [plane_coll, sphere_coll_world_current]
- solution = pks.solve_ik_with_collision(
- robot=robot,
- coll=robot_coll,
- world_coll_list=world_coll_list,
- target_link_name=target_link_name,
- target_position=np.array(ik_target_handle.position),
- target_wxyz=np.array(ik_target_handle.wxyz),
- )
-
- # Update timing handle.
- timing_handle.value = (time.time() - start_time) * 1000
-
- # Update visualizer.
- urdf_vis.update_cfg(solution)
- # print(robot.links.names)
- # Compute the collision of the solution
- distance_link_to_plane = robot_coll.compute_world_collision_distance(
- robot,
- solution,
- plane_coll
- )
- distance_link_to_plane = RobotCollisionSpherized.mask_collision_distance(distance_link_to_plane, exclude_link_mask)
- # print(distance_link_to_plane)
- distance_link_to_sphere = robot_coll.compute_world_collision_distance(
- robot,
- solution,
- sphere_coll
- )
- distance_link_to_sphere = RobotCollisionSpherized.mask_collision_distance(distance_link_to_sphere, exclude_link_mask)
- # print(distance_link_to_sphere)
- # Visualize collision representation
- robot_coll_config: Sphere = robot_coll.at_config(robot, solution)
- # print(robot_coll_config.get_batch_axes()[-1])
- robot_coll_mesh = robot_coll_config.to_trimesh()
- server.scene.add_mesh_trimesh(
- "/robot_coll",
- mesh=robot_coll_mesh,
- wxyz=(1.0, 0.0, 0.0, 0.0),
- position=(0.0, 0.0, 0.0),
- )
-
-
- if __name__ == "__main__":
- main()
diff --git a/docs/source/examples/20_negative_distance.rst b/docs/source/examples/20_negative_distance.rst
deleted file mode 100644
index ee45839..0000000
--- a/docs/source/examples/20_negative_distance.rst
+++ /dev/null
@@ -1,119 +0,0 @@
-.. Comment: this file is automatically generated by `update_example_docs.py`.
- It should not be modified manually.
-
-IK with Collision
-==========================================
-
-
-Basic Inverse Kinematics with Collision Avoidance using PyRoNot.
-
-All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details.
-
-
-
-.. code-block:: python
- :linenos:
-
-
- import time
-
- import numpy as np
- import pyronot as pk
- import jax.numpy as jnp
- import viser
- from pyronot.collision import HalfSpace, RobotCollision, RobotCollisionSpherized, Sphere
- from robot_descriptions.loaders.yourdfpy import load_robot_description
- from viser.extras import ViserUrdf
-
- import pyronot_snippets as pks
- import yourdfpy
-
-
- def main():
- """Main function for basic IK with collision."""
- urdf_path = "resources/ur5/ur5_spherized.urdf"
- mesh_dir = "resources/ur5/meshes"
- target_link_name = "tool0"
- srdf_path = "resources/ur5/ur5.srdf"
- # urdf_path = "resources/panda/panda_spherized.urdf"
- # mesh_dir = "resources/panda/meshes"
- # target_link_name = "panda_hand"
- urdf = yourdfpy.URDF.load(urdf_path, mesh_dir=mesh_dir)
- robot = pk.Robot.from_urdf(urdf)
-
- robot_coll = RobotCollisionSpherized.from_urdf(urdf, srdf_path=srdf_path)
- print(robot_coll.link_names)
- plane_coll = HalfSpace.from_point_and_normal(
- np.array([0.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0])
- )
- sphere_coll = Sphere.from_center_and_radius(
- np.array([0.0, 0.0, 0.0]), np.array([0.05])
- )
-
- # Set up visualizer.
- server = viser.ViserServer()
- server.scene.add_grid("/ground", width=2, height=2, cell_size=0.1)
- urdf_vis = ViserUrdf(server, urdf, root_node_name="/robot")
-
- # Create interactive controller for IK target.
- ik_target_handle = server.scene.add_transform_controls(
- "/ik_target", scale=0.2, position=(0.5, 0.0, 0.5), wxyz=(0, 0, 1, 0)
- )
-
- # Create interactive controller and mesh for the sphere obstacle.
- sphere_handle = server.scene.add_transform_controls(
- "/obstacle", scale=0.2, position=(0.4, 0.3, 0.4)
- )
- server.scene.add_mesh_trimesh("/obstacle/mesh", mesh=sphere_coll.to_trimesh())
-
- timing_handle = server.gui.add_number("Elapsed (ms)", 0.001, disabled=True)
- distance_self_collision_handle = server.gui.add_number("Distance Self Collision", 0.001, disabled=True)
- link1_handle = server.gui.add_text("Closest Link 1", initial_value="", disabled=True)
- link2_handle = server.gui.add_text("Closest Link 2", initial_value="", disabled=True)
-
-
- while True:
- start_time = time.time()
-
- sphere_coll_world_current = sphere_coll.transform_from_wxyz_position(
- wxyz=np.array(sphere_handle.wxyz),
- position=np.array(sphere_handle.position),
- )
-
- world_coll_list = [plane_coll, sphere_coll_world_current]
- solution = pks.solve_ik_with_collision(
- robot=robot,
- coll=robot_coll,
- world_coll_list=world_coll_list,
- target_link_name=target_link_name,
- target_position=np.array(ik_target_handle.position),
- target_wxyz=np.array(ik_target_handle.wxyz),
- )
- # Compute self-collision distances
- distance_self_collision = robot_coll.compute_self_collision_distance(robot, solution)
-
- # Find the closest pair
- min_idx = int(jnp.argmin(distance_self_collision))
- min_distance = float(distance_self_collision[min_idx])
-
- # Get the link names for this pair
- link_i_idx = int(robot_coll.active_idx_i[min_idx])
- link_j_idx = int(robot_coll.active_idx_j[min_idx])
- link_i_name = robot_coll.link_names[link_i_idx]
- link_j_name = robot_coll.link_names[link_j_idx]
-
- # Update GUI
- distance_self_collision_handle.value = min_distance
- link1_handle.value = link_i_name
- link2_handle.value = link_j_name
- print(f"link_i_name: {link_i_name}, link_j_name: {link_j_name}")
- timing_handle.value = (time.time() - start_time) * 1000
-
- # Update visualizer.
- robot_coll_mesh = robot_coll.at_config(robot, solution).to_trimesh()
- server.scene.add_mesh_trimesh("/robot/collision", mesh=robot_coll_mesh)
- urdf_vis.update_cfg(solution)
-
-
- if __name__ == "__main__":
- main()
diff --git a/docs/source/examples/21_quantize_collision.rst b/docs/source/examples/21_quantize_collision.rst
deleted file mode 100644
index b6772f5..0000000
--- a/docs/source/examples/21_quantize_collision.rst
+++ /dev/null
@@ -1,131 +0,0 @@
-.. Comment: this file is automatically generated by `update_example_docs.py`.
- It should not be modified manually.
-
-Quantize Collision Example
-==========================================
-
-
-Script to generate obstacles and robot configurations for quantization testing.
-
-All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details.
-
-
-
-.. code-block:: python
- :linenos:
-
-
- import jax
- import jax.numpy as jnp
- import numpy as np
- import time
- import pyronot as pk
- from pyronot.collision import RobotCollisionSpherized, Sphere
- from pyronot._robot_urdf_parser import RobotURDFParser
- import yourdfpy
- from pyronot.utils import quantize
-
- def generate_spheres(n_spheres):
- print(f"Generating {n_spheres} random spheres...")
- spheres = []
- for _ in range(n_spheres):
- center = np.random.uniform(low=-1.0, high=1.0, size=(3,))
- radius = np.random.uniform(low=0.05, high=0.2)
- sphere = Sphere.from_center_and_radius(center, np.array([radius]))
- spheres.append(sphere)
-
- # Tree map them to create a batch of spheres
- spheres_batch = jax.tree.map(lambda *args: jnp.stack(args), *spheres)
- print(f"Generated {n_spheres} spheres.")
- return spheres_batch
-
- def generate_configs(joints, n_configs):
- print(f"Generating {n_configs} random robot configurations...")
- q_batch = np.random.uniform(
- low=joints.lower_limits,
- high=joints.upper_limits,
- size=(n_configs, joints.num_actuated_joints)
- )
- print(f"Generated {n_configs} robot configurations.")
- print(f"Configurations shape: {q_batch.shape}")
- return q_batch
-
- def make_collision_checker(robot, robot_coll):
- @jax.jit
- def check_collisions(q_batch, obstacles):
- # q_batch: (N_configs, dof)
- # obstacles: Sphere batch (N_spheres)
-
- # Define single config check
- def check_single(q, obs):
- return robot_coll.compute_world_collision_distance(robot, q, obs)
-
- # Vmap over configs
- # in_axes: q=(0), obs=(None) -> we want to check each q against ALL obs
- return jax.vmap(check_single, in_axes=(0, None))(q_batch, obstacles)
-
- return check_collisions
-
- def run_benchmark(name, check_fn, q_batch, obstacles):
- print(f"\n{name}:")
-
- # Metrics
- q_size_mb = q_batch.nbytes / 1024 / 1024
- spheres_size_mb = sum(x.nbytes for x in jax.tree_util.tree_leaves(obstacles)) / 1024 / 1024
-
- print(f"q_batch size: {q_size_mb:.2f} MB")
- print(f"Obstacles (spheres) size: {spheres_size_mb:.2f} MB")
-
- # Warmup
- print(f"Warming up JIT ({name})...")
- _ = check_fn(q_batch, obstacles)
- # _.block_until_ready()
-
- # Run collision checking
- print(f"Executing collision checking ({name})...")
- start_time = time.perf_counter()
- dists = check_fn(q_batch, obstacles)
- # dists.block_until_ready()
- end_time = time.perf_counter()
-
- print(f"Time to compute: {end_time - start_time:.6f} seconds")
- print(f"Collision distances shape: {dists.shape}")
- print(f"Min distance: {jnp.min(dists)}")
-
- in_collision = dists < 0
- print(f"Number of collision pairs: {jnp.sum(in_collision)}")
-
- def main():
- global robot, robot_coll
- # Load robot
- urdf_path = "resources/ur5/ur5_spherized.urdf"
- mesh_dir = "resources/ur5/meshes"
- urdf = yourdfpy.URDF.load(urdf_path, mesh_dir=mesh_dir)
- robot = pk.Robot.from_urdf(urdf)
- joints, links = RobotURDFParser.parse(urdf)
-
- # Initialize collision model
- print("Initializing collision model...")
- robot_coll = RobotCollisionSpherized.from_urdf(urdf)
-
- # Create collision checker
- check_collisions = make_collision_checker(robot, robot_coll)
-
- # Generate data
- spheres_batch = generate_spheres(10000)
- q_batch = generate_configs(joints, 10000)
-
- # Run benchmarks
- run_benchmark("Default (float32)", check_collisions, q_batch, spheres_batch)
-
- # Quantized
- q_batch_f16 = quantize(q_batch)
- spheres_batch_f16 = quantize(spheres_batch)
- run_benchmark("Quantized (float16)", check_collisions, q_batch_f16, spheres_batch_f16)
-
- q_batch_int8 = quantize(q_batch, jax.numpy.int8)
- spheres_batch_int8 = quantize(spheres_batch, jax.numpy.int8)
- run_benchmark("Quantized (int8)", check_collisions, q_batch_int8, spheres_batch_int8)
-
- if __name__ == "__main__":
- main()
diff --git a/docs/source/examples/22_neural_sdf.rst b/docs/source/examples/22_neural_sdf.rst
deleted file mode 100644
index 8e6c9ff..0000000
--- a/docs/source/examples/22_neural_sdf.rst
+++ /dev/null
@@ -1,345 +0,0 @@
-.. Comment: this file is automatically generated by `update_example_docs.py`.
- It should not be modified manually.
-
-Quantize Collision Example
-==========================================
-
-
-Script to generate obstacles and robot configurations for quantization testing.
-
-All examples can be run by first cloning the PyRoNot repository, which includes the ``pyronot_snippets`` implementation details.
-
-
-
-.. code-block:: python
- :linenos:
-
-
- import jax
- import jax.numpy as jnp
- import numpy as np
- import time
- import pyronot as pk
- from pyronot.collision import RobotCollision, RobotCollisionSpherized, NeuralRobotCollision, Sphere
- from pyronot._robot_urdf_parser import RobotURDFParser
- import yourdfpy
- from pyronot.utils import quantize
-
- # Global configuration: Set to True to run positional encoding comparison
- RUN_POSITIONAL_ENCODING = False
-
- def generate_spheres(n_spheres):
- print(f"Generating {n_spheres} random spheres...")
- spheres = []
- for _ in range(n_spheres):
- center = np.random.uniform(low=-1.0, high=1.0, size=(3,))
- radius = np.random.uniform(low=0.05, high=0.2)
- sphere = Sphere.from_center_and_radius(center, np.array(radius))
- spheres.append(sphere)
-
- # Tree map them to create a batch of spheres
- spheres_batch = jax.tree.map(lambda *args: jnp.stack(args), *spheres)
- print(f"Generated {n_spheres} spheres.")
- return spheres_batch
-
- def generate_configs(joints, n_configs):
- print(f"Generating {n_configs} random robot configurations...")
- q_batch = np.random.uniform(
- low=joints.lower_limits,
- high=joints.upper_limits,
- size=(n_configs, joints.num_actuated_joints)
- )
- print(f"Generated {n_configs} robot configurations.")
- print(f"Configurations shape: {q_batch.shape}")
- return q_batch
-
- def make_collision_checker(robot, robot_coll):
- @jax.jit
- def check_collisions(q_batch, obstacles):
- # q_batch: (N_configs, dof)
- # obstacles: Sphere batch (N_spheres)
-
- # Define single config check
- def check_single(q, obs):
- return robot_coll.compute_world_collision_distance(robot, q, obs)
-
- # Vmap over configs
- # in_axes: q=(0), obs=(None) -> we want to check each q against ALL obs
- return jax.vmap(check_single, in_axes=(0, None))(q_batch, obstacles)
-
- return check_collisions
-
-
- def make_neural_collision_checker(robot, robot_coll, spheres_batch, use_positional_encoding=True, pe_min_deg=0, pe_max_deg=6):
- """Train a neural collision model on the given static world and return its checker.
-
- This will:
- - build a NeuralRobotCollision from the exact model
- - train it on random configs for the provided world (spheres_batch)
- - expose a vmap'ed collision function with the same signature as make_collision_checker's output
-
- Args:
- robot: The Robot instance
- robot_coll: The exact collision model (RobotCollisionSpherized)
- spheres_batch: Batch of sphere obstacles
- use_positional_encoding: If True, use iSDF-inspired positional encoding for
- better capture of fine geometric details (default True)
- pe_min_deg: Minimum frequency degree for positional encoding (default 0)
- pe_max_deg: Maximum frequency degree for positional encoding (default 6)
- """
-
- # Wrap the world geometry in the same structure RobotCollisionSpherized expects
- # RobotCollisionSpherized.from_urdf constructs a CollGeom internally when used in examples,
- # so `spheres_batch` is already a valid batch of Sphere geometry.
-
- # Create neural collision model from existing exact model
- # With positional encoding enabled for better accuracy near collision boundaries
- neural_coll = NeuralRobotCollision.from_existing(
- robot_coll,
- use_positional_encoding=use_positional_encoding,
- pe_min_deg=pe_min_deg,
- pe_max_deg=pe_max_deg,
- )
-
- pe_status = f"with PE (deg {pe_min_deg}-{pe_max_deg})" if use_positional_encoding else "without PE"
- print(f"Created neural collision model {pe_status}")
-
- # Train neural model on this specific world
- neural_coll = neural_coll.train(
- robot=robot,
- world_geom=spheres_batch,
- num_samples=10000,
- batch_size=1000, # Smaller batch = more gradient updates per epoch
- epochs=50, # More epochs for better convergence
- learning_rate=1e-3,
- )
-
- # Now build a collision checker that calls the neural model
- @jax.jit
- def check_collisions(q_batch, obstacles):
- # q_batch: (N_configs, dof)
- # obstacles: same spheres_batch used for training
-
- def check_single(q, obs):
- return neural_coll.compute_world_collision_distance(robot, q, obs)
-
- return jax.vmap(check_single, in_axes=(0, None))(q_batch, obstacles)
-
- return check_collisions
-
- def run_benchmark(name, check_fn, q_batch, obstacles):
- print(f"\n{name}:")
-
- # Metrics
- q_size_mb = q_batch.nbytes / 1024 / 1024
- spheres_size_mb = sum(x.nbytes for x in jax.tree_util.tree_leaves(obstacles)) / 1024 / 1024
-
- print(f"q_batch size: {q_size_mb:.2f} MB")
- print(f"Obstacles (spheres) size: {spheres_size_mb:.2f} MB")
-
- # Warmup
- print(f"Warming up JIT ({name})...")
- _ = check_fn(q_batch, obstacles)
-
- # Run collision checking
- print(f"Executing collision checking ({name})...")
- start_time = time.perf_counter()
- dists = check_fn(q_batch, obstacles)
- end_time = time.perf_counter()
- elapsed_time = end_time - start_time
-
- print(f"Time to compute: {elapsed_time:.6f} seconds")
- print(f"Collision distances shape: {dists.shape}")
- print(f"Min distance: {jnp.min(dists):.6f}")
- print(f"Max distance: {jnp.max(dists):.6f}")
- print(f"Mean distance: {jnp.mean(dists):.6f}")
- print(f"Std distance: {jnp.std(dists):.6f}")
-
- in_collision = dists < 0
- print(f"Number of collision pairs: {jnp.sum(in_collision)}")
-
- return dists, elapsed_time
-
-
- def compare_results(name, neural_dists, exact_dists):
- """Compare neural network predictions against exact results."""
- import gc
-
- print(f"\n {name} Comparison ")
-
-
- # Compute metrics in a memory-efficient way
- diff = neural_dists - exact_dists
- mae = float(jnp.mean(jnp.abs(diff)))
- max_ae = float(jnp.max(jnp.abs(diff)))
- rmse = float(jnp.sqrt(jnp.mean(diff ** 2)))
- bias = float(jnp.mean(diff))
- del diff # Free memory
- gc.collect()
-
- print(f"Mean absolute error: {mae:.6f}")
- print(f"Max absolute error: {max_ae:.6f}")
- print(f"RMSE: {rmse:.6f}")
- print(f"Mean error (bias): {bias:.6f}")
-
- # Check accuracy at collision boundary
- exact_in_collision = exact_dists < 0.05
- neural_in_collision = neural_dists < 0.05
-
- # Compute metrics and convert to Python ints immediately
- true_positives = int(jnp.sum(exact_in_collision & neural_in_collision))
- false_positives = int(jnp.sum(~exact_in_collision & neural_in_collision))
- false_negatives = int(jnp.sum(exact_in_collision & ~neural_in_collision))
- true_negatives = int(jnp.sum(~exact_in_collision & ~neural_in_collision))
-
- # Free the boolean arrays
- del exact_in_collision, neural_in_collision
- gc.collect()
-
- print(f"\nCollision Detection Accuracy (threshold=0.05):")
- print(f" True Positives: {true_positives}")
- print(f" False Positives: {false_positives}")
- print(f" False Negatives: {false_negatives}")
- print(f" True Negatives: {true_negatives}")
-
- precision = true_positives / (true_positives + false_positives + 1e-8)
- recall = true_positives / (true_positives + false_negatives + 1e-8)
- f1 = 2 * precision * recall / (precision + recall + 1e-8)
- print(f" Precision: {precision:.4f}")
- print(f" Recall: {recall:.4f}")
- print(f" F1 Score: {f1:.4f}")
-
- return {
- 'mae': mae,
- 'max_ae': max_ae,
- 'rmse': rmse,
- 'bias': bias,
- 'precision': precision,
- 'recall': recall,
- 'f1': f1,
- }
-
- def main():
- import gc
-
- # Load robot
- urdf_path = "resources/ur5/ur5_spherized.urdf"
- mesh_dir = "resources/ur5/meshes"
- urdf = yourdfpy.URDF.load(urdf_path, mesh_dir=mesh_dir)
- robot = pk.Robot.from_urdf(urdf)
- joints, links = RobotURDFParser.parse(urdf)
-
- # Initialize collision model
- print("Initializing collision model...")
- robot_coll = RobotCollisionSpherized.from_urdf(urdf)
-
- # Generate data (world is fixed for both exact and neural models)
- spheres_batch = generate_spheres(100)
- q_batch = generate_configs(joints, 50000)
-
- # Create collision checker using exact model
- exact_check_collisions = make_collision_checker(robot, robot_coll)
-
- # Train neural model without positional encoding (default)
- print("Training neural collision model WITHOUT positional encoding...")
- print("(Raw link poses as input)")
- neural_check_without_pe = make_neural_collision_checker(
- robot, robot_coll, spheres_batch,
- use_positional_encoding=False,
- )
-
- # Optionally train with positional encoding
- neural_check_with_pe = None
- if RUN_POSITIONAL_ENCODING:
- print("\n" + "="*70)
- print("Training neural collision model WITH positional encoding...")
- print("(iSDF-inspired: projects onto icosahedron directions with frequency bands)")
- print("="*70)
- neural_check_with_pe = make_neural_collision_checker(
- robot, robot_coll, spheres_batch,
- use_positional_encoding=True,
- pe_min_deg=0,
- pe_max_deg=6, # 7 frequency bands: 2^0, 2^1, ..., 2^6
- )
-
- # Run benchmarks
- print("\n" + "="*70)
- print("Running benchmarks...")
- print("="*70)
-
- exact_dists, exact_time = run_benchmark(
- "Exact (RobotCollisionSpherized)",
- exact_check_collisions, q_batch, spheres_batch
- )
-
- neural_without_pe_dists, neural_without_pe_time = run_benchmark(
- "Neural WITHOUT Positional Encoding",
- neural_check_without_pe, q_batch, spheres_batch
- )
-
- neural_with_pe_dists = None
- neural_with_pe_time = None
- if RUN_POSITIONAL_ENCODING and neural_check_with_pe is not None:
- neural_with_pe_dists, neural_with_pe_time = run_benchmark(
- "Neural WITH Positional Encoding",
- neural_check_with_pe, q_batch, spheres_batch
- )
-
- # Clear JAX caches and force garbage collection to free GPU memory
- print("\nClearing memory before comparison...")
- jax.clear_caches()
- gc.collect()
-
- # Compare results
- metrics_without_pe = compare_results(
- "Neural WITHOUT Positional Encoding vs Exact",
- neural_without_pe_dists, exact_dists
- )
-
- metrics_with_pe = None
- if RUN_POSITIONAL_ENCODING and neural_with_pe_dists is not None:
- metrics_with_pe = compare_results(
- "Neural WITH Positional Encoding vs Exact",
- neural_with_pe_dists, exact_dists
- )
-
- # Summary comparison (only if positional encoding was tested)
- if RUN_POSITIONAL_ENCODING and metrics_with_pe is not None:
- print("SUMMARY: Positional Encoding Impact")
- print(f"\n{'Metric':<25} {'With PE':<15} {'Without PE':<15} {'Improvement':<15}")
-
- for metric in ['mae', 'rmse', 'max_ae', 'precision', 'recall', 'f1']:
- with_pe = metrics_with_pe[metric]
- without_pe = metrics_without_pe[metric]
-
- # For error metrics, lower is better; for accuracy metrics, higher is better
- if metric in ['mae', 'rmse', 'max_ae', 'bias']:
- improvement = (without_pe - with_pe) / (without_pe + 1e-8) * 100
- better = "↓" if with_pe < without_pe else "↑"
- else:
- improvement = (with_pe - without_pe) / (without_pe + 1e-8) * 100
- better = "↑" if with_pe > without_pe else "↓"
-
- print(f"{metric:<25} {with_pe:<15.6f} {without_pe:<15.6f} {improvement:+.1f}% {better}")
-
- print(f"\n{'Inference Time (s)':<25} {neural_with_pe_time:<15.6f} {neural_without_pe_time:<15.6f}")
- print(f"{'Exact Time (s)':<25} {exact_time:<15.6f}")
- else:
- # Simple summary without PE comparison
- print("SUMMARY: Neural vs Exact")
- print(f"\n{'Metric':<25} {'Neural (no PE)':<15}")
- for metric in ['mae', 'rmse', 'max_ae', 'precision', 'recall', 'f1']:
- print(f"{metric:<25} {metrics_without_pe[metric]:<15.6f}")
- print(f"\n{'Neural Inference Time (s)':<25} {neural_without_pe_time:<15.6f}")
- print(f"{'Exact Time (s)':<25} {exact_time:<15.6f}")
- print(f"{'Speedup':<25} {exact_time / neural_without_pe_time:<15.2f}x")
-
- # Cleanup
- del exact_dists, neural_without_pe_dists
- if neural_with_pe_dists is not None:
- del neural_with_pe_dists
- gc.collect()
-
-
- if __name__ == "__main__":
- main()
diff --git a/docs/source/index.rst b/docs/source/index.rst
deleted file mode 100644
index 8710800..0000000
--- a/docs/source/index.rst
+++ /dev/null
@@ -1,104 +0,0 @@
-PyRoNot
-==========
-
-`Project page `_ `•` `arXiv `_ `•` `Code `_
-
-**PyRoNot** is a library for robot kinematic optimization (Python Robot Kinematics).
-
-1. **Modular**: Optimization variables and cost functions are decoupled, enabling reusable components across tasks. Objectives like collision avoidance and pose matching can be applied to both IK and trajectory optimization without reimplementation.
-
-2. **Extensible**: ``PyRoNot`` supports automatic differentiation for user-defined costs with Jacobian computation, a real-time cost-weight tuning interface, and optional analytical Jacobians for performance-critical use cases.
-
-3. **Cross-Platform**: ``PyRoNot`` runs on CPU, GPU, and TPU, allowing efficient scaling from single-robot use cases to large-scale parallel processing for motion datasets or planning.
-
-We demonstrate how ``PyRoNot`` solves IK, trajectory optimization, and motion retargeting for robot hands and humanoids in a unified framework. It uses a Levenberg-Marquardt optimizer to efficiently solve these tasks, and we evaluate its performance on batched IK.
-
-Features include:
-
-- Differentiable robot forward kinematics model from a URDF.
-- Automatic generation of robot collision primitives (e.g., capsules).
-- Differentiable collision bodies with numpy broadcasting logic.
-- Common cost factors (e.g., end effector pose, self/world-collision, manipulability).
-- Arbitrary costs, getting Jacobians either calculated :doc:`through autodiff or defined manually`.
-- Integration with a `Levenberg-Marquardt Solver `_ that supports optimization on manifolds (e.g., `lie groups `_).
-- Cross-platform support (CPU, GPU, TPU) via JAX.
-
-
-
-Installation
-------------
-
-You can install ``pyronot`` with ``pip``, on Python 3.12+:
-
-.. code-block:: bash
-
- git clone https://github.com/chungmin99/pyronot.git
- cd pyronot
- pip install -e .
-
-
-Python 3.10-3.11 should also work, but support may be dropped in the future.
-
-Limitations
------------
-
-- **Soft constraints only**: We use a nonlinear least-squares formulation and model joint limits, collision avoidance, etc. as soft penalties with high weights rather than hard constraints.
-- **Static shapes & JIT overhead**: JAX JIT compilation is triggered on first run and when input shapes change (e.g., number of targets, obstacles). Arrays can be pre-padded to vectorize over inputs with different shapes.
-- **No sampling-based planners**: We don't include sampling-based planners (e.g., graphs, trees).
-- **Collision performance**: Speed and accuracy comparisons against other robot toolkits such as CuRobo have not been extensively performed, and is likely slower than other toolkits for collision-heavy scenarios.
-
-The following are current implementation limitations that could potentially be addressed in future versions:
-
-- **Joint types**: We only support revolute, continuous, prismatic, and fixed joints. Other URDF joint types are treated as fixed joints.
-- **Collision geometry**: We are limited to sphere, capsule, halfspace, and heightmap geometries. Mesh collision is approximated as capsules.
-- **Kinematic structures**: We only support kinematic chains; no closed-loop mechanisms or parallel manipulators.
-
-Examples
---------
-
-.. toctree::
- :maxdepth: 1
- :caption: Examples
-
- examples/01_basic_ik
- examples/02_bimanual_ik
- examples/03_mobile_ik
- examples/04_ik_with_coll
- examples/05_ik_with_manipulability
- examples/06_online_planning
- examples/07_trajopt
- examples/08_ik_with_mimic_joints
- examples/09_hand_retargeting
- examples/10_humanoid_retargeting
- examples/11_hand_retargeting_fancy
- examples/12_humanoid_retargeting_fancy
-
-
-Acknowledgements
-----------------
-``PyRoNot`` is heavily inspired by the prior work, including but not limited to
-`Trac-IK `_,
-`cuRobo `_,
-`pink `_,
-`mink `_,
-`Drake `_, and
-`Dex-Retargeting `_.
-Thank you so much for your great work!
-
-
-Citation
---------
-
-If you find this work useful, please cite it as follows:
-
-.. code-block:: bibtex
-
- @inproceedings{kim2025pyronot,
- title={PyRoNot: A Modular Toolkit for Robot Kinematic Optimization},
- author={Kim*, Chung Min and Yi*, Brent and Choi, Hongsuk and Ma, Yi and Goldberg, Ken and Kanazawa, Angjoo},
- booktitle={2025 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)},
- year={2025},
- url={https://arxiv.org/abs/2505.03728},
- }
-
-Thanks for using ``PyRoNot``!
diff --git a/docs/source/misc/writing_manual_jac.rst b/docs/source/misc/writing_manual_jac.rst
deleted file mode 100644
index db2d8c6..0000000
--- a/docs/source/misc/writing_manual_jac.rst
+++ /dev/null
@@ -1,49 +0,0 @@
-:orphan:
-
-Defining Jacobians Manually
-=====================================
-
-``pyronot`` supports both autodiff and manually defined Jacobians for computing cost gradients.
-
-For reference, this is the robot pose matching cost :math:`C_\text{pose}`:
-
-.. math::
-
- \sum_{i} \left( w_{p,i} \left\| \mathbf{p}_{i}(q) - \mathbf{p}_{i}^* \right\|^2 + w_{R,i} \left\| \text{log}(\mathbf{R}_{i}(q)^{-1} \mathbf{R}_{i}^*) \right\|^2 \right)
-
-
-where :math:`q` is the robot joint configuration, :math:`\mathbf{p}_{i}(q)` is the position of the :math:`i`-th link, :math:`\mathbf{R}_{i}(q)` is the rotation matrix of the :math:`i`-th link, and :math:`w_{p,i}` and :math:`w_{R,i}` are the position and orientation weights, respectively.
-
-The following is the most common way to define costs in ``pyronot`` -- with autodiff:
-
-.. code-block:: python
-
- @Cost.create_factory
- def pose_cost(
- vals: VarValues,
- robot: Robot,
- joint_var: Var[Array],
- target_pose: jaxlie.SE3,
- target_link_index: Array,
- pos_weight: Array | float,
- ori_weight: Array | float,
- ) -> Array:
- """Computes the residual for matching link poses to target poses."""
- assert target_link_index.dtype == jnp.int32
- joint_cfg = vals[joint_var]
- Ts_link_world = robot.forward_kinematics(joint_cfg)
- pose_actual = jaxlie.SE3(Ts_link_world[..., target_link_index, :])
-
- # Position residual = position error * weight
- pos_residual = (pose_actual.translation() - target_pose.translation()) * pos_weight
- # Orientation residual = log(actual_inv * target) * weight
- ori_residual = (pose_actual.rotation().inverse() @ target_pose.rotation()).log() * ori_weight
-
- return jnp.concatenate([pos_residual, ori_residual]).flatten()
-
-The alternative is to manually write out the Jacobian -- while automatic differentiation is convenient and works well for most use cases, analytical Jacobians can provide better performance, which we show in the `paper `_.
-
-We provide two implementations of pose matching cost with custom Jacobians:
-
-- an `analytically derived Jacobian `_ (~200 lines), or
-- a `numerically approximated Jacobian `_ through finite differences (~50 lines).
diff --git a/docs/update_example_docs.py b/docs/update_example_docs.py
deleted file mode 100644
index f6e3ace..0000000
--- a/docs/update_example_docs.py
+++ /dev/null
@@ -1,104 +0,0 @@
-"""Helper script for updating the auto-generated examples pages in the documentation."""
-
-from __future__ import annotations
-
-import dataclasses
-import pathlib
-import shutil
-from typing import Iterable
-
-import m2r2
-import tyro
-
-
-@dataclasses.dataclass
-class ExampleMetadata:
- index: str
- index_with_zero: str
- source: str
- title: str
- description: str
-
- @staticmethod
- def from_path(path: pathlib.Path) -> ExampleMetadata:
- # 01_functions -> 01, _, functions.
- index, _, _ = path.stem.partition("_")
-
- # 01 -> 1.
- index_with_zero = index
- index = str(int(index))
-
- print("Parsing", path)
- source = path.read_text().strip()
- docstring = source.split('"""')[1].strip()
-
- title, _, description = docstring.partition("\n")
-
- description = description.strip()
- description += "\n"
- description += "\n"
- description += "All examples can be run by first cloning the PyRoNot repository, which includes the `pyronot_snippets` implementation details."
-
- return ExampleMetadata(
- index=index,
- index_with_zero=index_with_zero,
- source=source.partition('"""')[2].partition('"""')[2].strip(),
- title=title,
- description=description,
- )
-
-
-def get_example_paths(examples_dir: pathlib.Path) -> Iterable[pathlib.Path]:
- return filter(
- lambda p: not p.name.startswith("_"), sorted(examples_dir.glob("*.py"))
- )
-
-
-REPO_ROOT = pathlib.Path(__file__).absolute().parent.parent
-
-
-def main(
- examples_dir: pathlib.Path = REPO_ROOT / "examples",
- sphinx_source_dir: pathlib.Path = REPO_ROOT / "docs" / "source",
-) -> None:
- example_doc_dir = sphinx_source_dir / "examples"
- shutil.rmtree(example_doc_dir)
- example_doc_dir.mkdir()
-
- for path in get_example_paths(examples_dir):
- ex = ExampleMetadata.from_path(path)
-
- relative_dir = path.parent.relative_to(examples_dir)
- target_dir = example_doc_dir / relative_dir
- target_dir.mkdir(exist_ok=True, parents=True)
-
- (target_dir / f"{path.stem}.rst").write_text(
- "\n".join(
- [
- (
- ".. Comment: this file is automatically generated by"
- " `update_example_docs.py`."
- ),
- " It should not be modified manually.",
- "",
- f"{ex.title}",
- "==========================================",
- "",
- m2r2.convert(ex.description),
- "",
- "",
- ".. code-block:: python",
- " :linenos:",
- "",
- "",
- "\n".join(
- f" {line}".rstrip() for line in ex.source.split("\n")
- ),
- "",
- ]
- )
- )
-
-
-if __name__ == "__main__":
- tyro.cli(main, description=__doc__)
diff --git a/examples/13_spherized_robot_ik.py b/examples/13_spherized_robot_ik.py
index 1a47853..7667a9a 100644
--- a/examples/13_spherized_robot_ik.py
+++ b/examples/13_spherized_robot_ik.py
@@ -1,7 +1,4 @@
-"""Spherized Robot IK
-
-
-"""
+"""Spherized Robot IK"""
import time
@@ -11,10 +8,10 @@
from viser.extras import ViserUrdf
import pyronot_snippets as pks
-import yourdfpy
+import yourdfpy
-def main():
+def main():
# Load the spherized panda urdf do not work!!!
urdf_path = "resources/ur5/ur5_spherized.urdf"
mesh_dir = "resources/ur5/meshes"
@@ -66,5 +63,7 @@ def main():
# Update the collision mesh.
robot_coll_mesh = robot_coll.at_config(robot, solution).to_trimesh()
server.scene.add_mesh_trimesh("/robot/collision", mesh=robot_coll_mesh)
+
+
if __name__ == "__main__":
- main()
\ No newline at end of file
+ main()
diff --git a/examples/14_spherized_ik_with_coll.py b/examples/14_spherized_ik_with_coll.py
index 6a39d24..38ed6de 100644
--- a/examples/14_spherized_ik_with_coll.py
+++ b/examples/14_spherized_ik_with_coll.py
@@ -13,7 +13,7 @@
from viser.extras import ViserUrdf
import pyronot_snippets as pks
-import yourdfpy
+import yourdfpy
def main():
diff --git a/examples/15_spherized_ik_then_coll.py b/examples/15_spherized_ik_then_coll.py
index e3e7c4c..18400d1 100644
--- a/examples/15_spherized_ik_then_coll.py
+++ b/examples/15_spherized_ik_then_coll.py
@@ -13,7 +13,7 @@
from viser.extras import ViserUrdf
import pyronot_snippets as pks
-import yourdfpy
+import yourdfpy
def main():
@@ -55,15 +55,14 @@ def main():
just_ik_timing_handle = server.gui.add_number("just ik (ms)", 0.001, disabled=True)
coll_ik_timing_handle = server.gui.add_number("coll ik (ms)", 0.001, disabled=True)
while True:
-
sphere_coll_world_current = sphere_coll.transform_from_wxyz_position(
wxyz=np.array(sphere_handle.wxyz),
position=np.array(sphere_handle.position),
)
start_time = time.time()
just_ik = pks.solve_ik(
- robot=robot,
- target_link_name=target_link_name,
+ robot=robot,
+ target_link_name=target_link_name,
target_position=np.array(ik_target_handle.position),
target_wxyz=np.array(ik_target_handle.wxyz),
)
diff --git a/examples/16_spherized_online_planning.py b/examples/16_spherized_online_planning.py
index 6123eaf..ff0d65d 100644
--- a/examples/16_spherized_online_planning.py
+++ b/examples/16_spherized_online_planning.py
@@ -13,7 +13,7 @@
from viser.extras import ViserUrdf
import pyronot_snippets as pks
-import yourdfpy
+import yourdfpy
def main():
diff --git a/examples/17_geometries_example.py b/examples/17_geometries_example.py
index a95e5ac..dc7ca97 100644
--- a/examples/17_geometries_example.py
+++ b/examples/17_geometries_example.py
@@ -45,7 +45,8 @@ def main():
height_handle = server.gui.add_number("Height", height)
server.scene.add_mesh_trimesh(
- "/box/mesh", mesh=Box.from_center_and_dimensions(center, length, width, height).to_trimesh()
+ "/box/mesh",
+ mesh=Box.from_center_and_dimensions(center, length, width, height).to_trimesh(),
)
server.scene.add_mesh_trimesh("/box/polytope", mesh=trimesh.Trimesh())
@@ -53,24 +54,52 @@ def main():
while True:
pos = np.array(box_handle.position)
wxyz = np.array(box_handle.wxyz)
- length = float(length_handle.value) if hasattr(length_handle, "value") else float(length_handle)
- width = float(width_handle.value) if hasattr(width_handle, "value") else float(width_handle)
- height = float(height_handle.value) if hasattr(height_handle, "value") else float(height_handle)
-
- box = Box.from_center_and_dimensions(center=pos, length=length, width=width, height=height, wxyz=wxyz)
+ length = (
+ float(length_handle.value)
+ if hasattr(length_handle, "value")
+ else float(length_handle)
+ )
+ width = (
+ float(width_handle.value)
+ if hasattr(width_handle, "value")
+ else float(width_handle)
+ )
+ height = (
+ float(height_handle.value)
+ if hasattr(height_handle, "value")
+ else float(height_handle)
+ )
+
+ box = Box.from_center_and_dimensions(
+ center=pos, length=length, width=width, height=height, wxyz=wxyz
+ )
# Sphere
sph_pos = np.array(sphere_handle.position)
sph_wxyz = np.array(sphere_handle.wxyz)
- sph_rad = float(sphere_radius_handle.value) if hasattr(sphere_radius_handle, "value") else float(sphere_radius_handle)
+ sph_rad = (
+ float(sphere_radius_handle.value)
+ if hasattr(sphere_radius_handle, "value")
+ else float(sphere_radius_handle)
+ )
sphere = Sphere.from_center_and_radius(center=sph_pos, radius=sph_rad)
# Capsule
cap_pos = np.array(cap_handle.position)
cap_wxyz = np.array(cap_handle.wxyz)
- cap_rad = float(cap_radius_handle.value) if hasattr(cap_radius_handle, "value") else float(cap_radius_handle)
- cap_h = float(cap_height_handle.value) if hasattr(cap_height_handle, "value") else float(cap_height_handle)
- capsule = Capsule.from_radius_height(radius=cap_rad, height=cap_h, position=cap_pos, wxyz=cap_wxyz)
+ cap_rad = (
+ float(cap_radius_handle.value)
+ if hasattr(cap_radius_handle, "value")
+ else float(cap_radius_handle)
+ )
+ cap_h = (
+ float(cap_height_handle.value)
+ if hasattr(cap_height_handle, "value")
+ else float(cap_height_handle)
+ )
+ capsule = Capsule.from_radius_height(
+ radius=cap_rad, height=cap_h, position=cap_pos, wxyz=cap_wxyz
+ )
server.scene.add_mesh_trimesh("/box/mesh", mesh=box.to_trimesh())
server.scene.add_mesh_trimesh("/sphere/mesh", mesh=sphere.to_trimesh())
@@ -91,7 +120,9 @@ def main():
# d is a jax Array; convert to python float if scalar
d_val = float(d)
if d_val < 0.0:
- print(f"Collision detected {name1} vs {name2}: distance={d_val:.6f}")
+ print(
+ f"Collision detected {name1} vs {name2}: distance={d_val:.6f}"
+ )
except Exception as e:
print(f"Error computing collision {name1} vs {name2}: {e}")
diff --git a/examples/18_spherized_trajopt.py b/examples/18_spherized_trajopt.py
index 40917a9..970206f 100644
--- a/examples/18_spherized_trajopt.py
+++ b/examples/18_spherized_trajopt.py
@@ -17,7 +17,8 @@
from robot_descriptions.loaders.yourdfpy import load_robot_description
import pyronot_snippets as pks
-import yourdfpy
+import yourdfpy
+
def main(robot_name: Literal["ur5", "panda"] = "panda"):
if robot_name == "ur5":
diff --git a/examples/19_spherized_ik_with_coll_exclude_link.py b/examples/19_spherized_ik_with_coll_exclude_link.py
index da822b2..10ea766 100644
--- a/examples/19_spherized_ik_with_coll_exclude_link.py
+++ b/examples/19_spherized_ik_with_coll_exclude_link.py
@@ -13,8 +13,9 @@
from viser.extras import ViserUrdf
import pyronot_snippets as pks
-import yourdfpy
+import yourdfpy
+import jax.numpy as jnp
def main():
"""Main function for basic IK with collision."""
@@ -26,7 +27,9 @@ def main():
# mesh_dir = "resources/panda/meshes"
# target_link_name = "panda_hand"
urdf = yourdfpy.URDF.load(urdf_path, mesh_dir=mesh_dir)
- robot = pk.Robot.from_urdf(urdf, default_joint_cfg=[0, -1.57, 1.57, -1.57, -1.57, 0])
+ robot = pk.Robot.from_urdf(
+ urdf, default_joint_cfg=[0, -1.57, 1.57, -1.57, -1.57, 0]
+ )
robot_coll = RobotCollisionSpherized.from_urdf(urdf)
plane_coll = HalfSpace.from_point_and_normal(
@@ -80,27 +83,34 @@ def main():
# Update visualizer.
urdf_vis.update_cfg(solution)
# print(robot.links.names)
- # Compute the collision of the solution
+ # Compute the collision of the solution
+ # Method 1:
distance_link_to_plane = robot_coll.compute_world_collision_distance(
- robot,
- solution,
- plane_coll
+ robot, solution, plane_coll
+ )
+ distance_link_to_plane = RobotCollisionSpherized.mask_collision_distance(
+ distance_link_to_plane, exclude_link_mask
+ )
+ # Method 2:
+ distance_link_to_plane_2 = robot_coll.compute_world_collision_distance_with_exclude_links(
+ robot, solution, plane_coll, exclude_link_mask
)
- distance_link_to_plane = RobotCollisionSpherized.mask_collision_distance(distance_link_to_plane, exclude_link_mask)
- # print(distance_link_to_plane)
+ assert jnp.allclose(distance_link_to_plane, distance_link_to_plane_2)
distance_link_to_sphere = robot_coll.compute_world_collision_distance(
- robot,
- solution,
- sphere_coll
+ robot, solution, sphere_coll
)
- distance_link_to_sphere = RobotCollisionSpherized.mask_collision_distance(distance_link_to_sphere, exclude_link_mask)
- # print(distance_link_to_sphere)
- # Visualize collision representation
+ distance_link_to_sphere = RobotCollisionSpherized.mask_collision_distance(
+ distance_link_to_sphere, exclude_link_mask
+ )
+ distance_link_to_sphere_2 = robot_coll.compute_world_collision_distance_with_exclude_links(
+ robot, solution, sphere_coll, exclude_link_mask
+ )
+ assert jnp.allclose(distance_link_to_sphere, distance_link_to_sphere_2)
+
robot_coll_config: Sphere = robot_coll.at_config(robot, solution)
- # print(robot_coll_config.get_batch_axes()[-1])
robot_coll_mesh = robot_coll_config.to_trimesh()
server.scene.add_mesh_trimesh(
- "/robot_coll",
+ "/robot_coll",
mesh=robot_coll_mesh,
wxyz=(1.0, 0.0, 0.0, 0.0),
position=(0.0, 0.0, 0.0),
diff --git a/examples/21_quantize_collision.py b/examples/21_quantize_collision.py
index c12de17..c06b763 100644
--- a/examples/21_quantize_collision.py
+++ b/examples/21_quantize_collision.py
@@ -13,6 +13,7 @@
import yourdfpy
from pyronot.utils import quantize
+
def generate_spheres(n_spheres):
print(f"Generating {n_spheres} random spheres...")
spheres = []
@@ -21,46 +22,51 @@ def generate_spheres(n_spheres):
radius = np.random.uniform(low=0.05, high=0.2)
sphere = Sphere.from_center_and_radius(center, np.array([radius]))
spheres.append(sphere)
-
+
# Tree map them to create a batch of spheres
spheres_batch = jax.tree.map(lambda *args: jnp.stack(args), *spheres)
print(f"Generated {n_spheres} spheres.")
return spheres_batch
+
def generate_configs(joints, n_configs):
print(f"Generating {n_configs} random robot configurations...")
q_batch = np.random.uniform(
- low=joints.lower_limits,
- high=joints.upper_limits,
- size=(n_configs, joints.num_actuated_joints)
+ low=joints.lower_limits,
+ high=joints.upper_limits,
+ size=(n_configs, joints.num_actuated_joints),
)
print(f"Generated {n_configs} robot configurations.")
print(f"Configurations shape: {q_batch.shape}")
return q_batch
+
def make_collision_checker(robot, robot_coll):
@jax.jit
def check_collisions(q_batch, obstacles):
# q_batch: (N_configs, dof)
# obstacles: Sphere batch (N_spheres)
-
+
# Define single config check
def check_single(q, obs):
return robot_coll.compute_world_collision_distance(robot, q, obs)
-
+
# Vmap over configs
# in_axes: q=(0), obs=(None) -> we want to check each q against ALL obs
return jax.vmap(check_single, in_axes=(0, None))(q_batch, obstacles)
-
+
return check_collisions
+
def run_benchmark(name, check_fn, q_batch, obstacles):
print(f"\n{name}:")
-
+
# Metrics
q_size_mb = q_batch.nbytes / 1024 / 1024
- spheres_size_mb = sum(x.nbytes for x in jax.tree_util.tree_leaves(obstacles)) / 1024 / 1024
-
+ spheres_size_mb = (
+ sum(x.nbytes for x in jax.tree_util.tree_leaves(obstacles)) / 1024 / 1024
+ )
+
print(f"q_batch size: {q_size_mb:.2f} MB")
print(f"Obstacles (spheres) size: {spheres_size_mb:.2f} MB")
@@ -79,10 +85,11 @@ def run_benchmark(name, check_fn, q_batch, obstacles):
print(f"Time to compute: {end_time - start_time:.6f} seconds")
print(f"Collision distances shape: {dists.shape}")
print(f"Min distance: {jnp.min(dists)}")
-
+
in_collision = dists < 0
print(f"Number of collision pairs: {jnp.sum(in_collision)}")
+
def main():
global robot, robot_coll
# Load robot
@@ -91,11 +98,11 @@ def main():
urdf = yourdfpy.URDF.load(urdf_path, mesh_dir=mesh_dir)
robot = pk.Robot.from_urdf(urdf)
joints, links = RobotURDFParser.parse(urdf)
-
+
# Initialize collision model
print("Initializing collision model...")
robot_coll = RobotCollisionSpherized.from_urdf(urdf)
-
+
# Create collision checker
check_collisions = make_collision_checker(robot, robot_coll)
@@ -105,15 +112,20 @@ def main():
# Run benchmarks
run_benchmark("Default (float32)", check_collisions, q_batch, spheres_batch)
-
+
# Quantized
q_batch_f16 = quantize(q_batch)
spheres_batch_f16 = quantize(spheres_batch)
- run_benchmark("Quantized (float16)", check_collisions, q_batch_f16, spheres_batch_f16)
+ run_benchmark(
+ "Quantized (float16)", check_collisions, q_batch_f16, spheres_batch_f16
+ )
q_batch_int8 = quantize(q_batch, jax.numpy.int8)
spheres_batch_int8 = quantize(spheres_batch, jax.numpy.int8)
- run_benchmark("Quantized (int8)", check_collisions, q_batch_int8, spheres_batch_int8)
+ run_benchmark(
+ "Quantized (int8)", check_collisions, q_batch_int8, spheres_batch_int8
+ )
+
if __name__ == "__main__":
main()
diff --git a/examples/22_neural_sdf.py b/examples/22_neural_sdf.py
index f62ff54..6db1ecc 100644
--- a/examples/22_neural_sdf.py
+++ b/examples/22_neural_sdf.py
@@ -8,7 +8,12 @@
import numpy as np
import time
import pyronot as pk
-from pyronot.collision import RobotCollision, RobotCollisionSpherized, NeuralRobotCollision, Sphere
+from pyronot.collision import (
+ RobotCollision,
+ RobotCollisionSpherized,
+ NeuralRobotCollision,
+ Sphere,
+)
from pyronot._robot_urdf_parser import RobotURDFParser
import yourdfpy
from pyronot.utils import quantize
@@ -16,6 +21,7 @@
# Global configuration: Set to True to run positional encoding comparison
RUN_POSITIONAL_ENCODING = False
+
def generate_spheres(n_spheres):
print(f"Generating {n_spheres} random spheres...")
spheres = []
@@ -24,48 +30,57 @@ def generate_spheres(n_spheres):
radius = np.random.uniform(low=0.05, high=0.2)
sphere = Sphere.from_center_and_radius(center, np.array(radius))
spheres.append(sphere)
-
+
# Tree map them to create a batch of spheres
spheres_batch = jax.tree.map(lambda *args: jnp.stack(args), *spheres)
print(f"Generated {n_spheres} spheres.")
return spheres_batch
+
def generate_configs(joints, n_configs):
print(f"Generating {n_configs} random robot configurations...")
q_batch = np.random.uniform(
- low=joints.lower_limits,
- high=joints.upper_limits,
- size=(n_configs, joints.num_actuated_joints)
+ low=joints.lower_limits,
+ high=joints.upper_limits,
+ size=(n_configs, joints.num_actuated_joints),
)
print(f"Generated {n_configs} robot configurations.")
print(f"Configurations shape: {q_batch.shape}")
return q_batch
+
def make_collision_checker(robot, robot_coll):
@jax.jit
def check_collisions(q_batch, obstacles):
# q_batch: (N_configs, dof)
# obstacles: Sphere batch (N_spheres)
-
+
# Define single config check
def check_single(q, obs):
return robot_coll.compute_world_collision_distance(robot, q, obs)
-
+
# Vmap over configs
# in_axes: q=(0), obs=(None) -> we want to check each q against ALL obs
return jax.vmap(check_single, in_axes=(0, None))(q_batch, obstacles)
-
+
return check_collisions
-def make_neural_collision_checker(robot, robot_coll, spheres_batch, use_positional_encoding=True, pe_min_deg=0, pe_max_deg=6):
+def make_neural_collision_checker(
+ robot,
+ robot_coll,
+ spheres_batch,
+ use_positional_encoding=True,
+ pe_min_deg=0,
+ pe_max_deg=6,
+):
"""Train a neural collision model on the given static world and return its checker.
This will:
- build a NeuralRobotCollision from the exact model
- train it on random configs for the provided world (spheres_batch)
- expose a vmap'ed collision function with the same signature as make_collision_checker's output
-
+
Args:
robot: The Robot instance
robot_coll: The exact collision model (RobotCollisionSpherized)
@@ -88,8 +103,12 @@ def make_neural_collision_checker(robot, robot_coll, spheres_batch, use_position
pe_min_deg=pe_min_deg,
pe_max_deg=pe_max_deg,
)
-
- pe_status = f"with PE (deg {pe_min_deg}-{pe_max_deg})" if use_positional_encoding else "without PE"
+
+ pe_status = (
+ f"with PE (deg {pe_min_deg}-{pe_max_deg})"
+ if use_positional_encoding
+ else "without PE"
+ )
print(f"Created neural collision model {pe_status}")
# Train neural model on this specific world
@@ -98,7 +117,7 @@ def make_neural_collision_checker(robot, robot_coll, spheres_batch, use_position
world_geom=spheres_batch,
num_samples=10000,
batch_size=1000, # Smaller batch = more gradient updates per epoch
- epochs=50, # More epochs for better convergence
+ epochs=50, # More epochs for better convergence
learning_rate=1e-3,
)
@@ -115,13 +134,16 @@ def check_single(q, obs):
return check_collisions
+
def run_benchmark(name, check_fn, q_batch, obstacles):
print(f"\n{name}:")
-
+
# Metrics
q_size_mb = q_batch.nbytes / 1024 / 1024
- spheres_size_mb = sum(x.nbytes for x in jax.tree_util.tree_leaves(obstacles)) / 1024 / 1024
-
+ spheres_size_mb = (
+ sum(x.nbytes for x in jax.tree_util.tree_leaves(obstacles)) / 1024 / 1024
+ )
+
print(f"q_batch size: {q_size_mb:.2f} MB")
print(f"Obstacles (spheres) size: {spheres_size_mb:.2f} MB")
@@ -142,81 +164,81 @@ def run_benchmark(name, check_fn, q_batch, obstacles):
print(f"Max distance: {jnp.max(dists):.6f}")
print(f"Mean distance: {jnp.mean(dists):.6f}")
print(f"Std distance: {jnp.std(dists):.6f}")
-
+
in_collision = dists < 0
print(f"Number of collision pairs: {jnp.sum(in_collision)}")
-
+
return dists, elapsed_time
def compare_results(name, neural_dists, exact_dists):
"""Compare neural network predictions against exact results."""
import gc
-
+
print(f"\n {name} Comparison ")
-
# Compute metrics in a memory-efficient way
diff = neural_dists - exact_dists
mae = float(jnp.mean(jnp.abs(diff)))
max_ae = float(jnp.max(jnp.abs(diff)))
- rmse = float(jnp.sqrt(jnp.mean(diff ** 2)))
+ rmse = float(jnp.sqrt(jnp.mean(diff**2)))
bias = float(jnp.mean(diff))
del diff # Free memory
gc.collect()
-
+
print(f"Mean absolute error: {mae:.6f}")
print(f"Max absolute error: {max_ae:.6f}")
print(f"RMSE: {rmse:.6f}")
print(f"Mean error (bias): {bias:.6f}")
-
+
# Check accuracy at collision boundary
exact_in_collision = exact_dists < 0.05
neural_in_collision = neural_dists < 0.05
-
+
# Compute metrics and convert to Python ints immediately
true_positives = int(jnp.sum(exact_in_collision & neural_in_collision))
false_positives = int(jnp.sum(~exact_in_collision & neural_in_collision))
false_negatives = int(jnp.sum(exact_in_collision & ~neural_in_collision))
true_negatives = int(jnp.sum(~exact_in_collision & ~neural_in_collision))
-
+
# Free the boolean arrays
del exact_in_collision, neural_in_collision
gc.collect()
-
+
print(f"\nCollision Detection Accuracy (threshold=0.05):")
print(f" True Positives: {true_positives}")
print(f" False Positives: {false_positives}")
print(f" False Negatives: {false_negatives}")
print(f" True Negatives: {true_negatives}")
-
+
precision = true_positives / (true_positives + false_positives + 1e-8)
recall = true_positives / (true_positives + false_negatives + 1e-8)
f1 = 2 * precision * recall / (precision + recall + 1e-8)
print(f" Precision: {precision:.4f}")
print(f" Recall: {recall:.4f}")
print(f" F1 Score: {f1:.4f}")
-
+
return {
- 'mae': mae,
- 'max_ae': max_ae,
- 'rmse': rmse,
- 'bias': bias,
- 'precision': precision,
- 'recall': recall,
- 'f1': f1,
+ "mae": mae,
+ "max_ae": max_ae,
+ "rmse": rmse,
+ "bias": bias,
+ "precision": precision,
+ "recall": recall,
+ "f1": f1,
}
+
def main():
import gc
-
+
# Load robot
urdf_path = "resources/ur5/ur5_spherized.urdf"
mesh_dir = "resources/ur5/meshes"
urdf = yourdfpy.URDF.load(urdf_path, mesh_dir=mesh_dir)
robot = pk.Robot.from_urdf(urdf)
joints, links = RobotURDFParser.parse(urdf)
-
+
# Initialize collision model
print("Initializing collision model...")
robot_coll = RobotCollisionSpherized.from_urdf(urdf)
@@ -232,96 +254,116 @@ def main():
print("Training neural collision model WITHOUT positional encoding...")
print("(Raw link poses as input)")
neural_check_without_pe = make_neural_collision_checker(
- robot, robot_coll, spheres_batch,
+ robot,
+ robot_coll,
+ spheres_batch,
use_positional_encoding=False,
)
# Optionally train with positional encoding
neural_check_with_pe = None
if RUN_POSITIONAL_ENCODING:
- print("\n" + "="*70)
+ print("\n" + "=" * 70)
print("Training neural collision model WITH positional encoding...")
- print("(iSDF-inspired: projects onto icosahedron directions with frequency bands)")
- print("="*70)
+ print(
+ "(iSDF-inspired: projects onto icosahedron directions with frequency bands)"
+ )
+ print("=" * 70)
neural_check_with_pe = make_neural_collision_checker(
- robot, robot_coll, spheres_batch,
+ robot,
+ robot_coll,
+ spheres_batch,
use_positional_encoding=True,
pe_min_deg=0,
pe_max_deg=6, # 7 frequency bands: 2^0, 2^1, ..., 2^6
)
# Run benchmarks
- print("\n" + "="*70)
+ print("\n" + "=" * 70)
print("Running benchmarks...")
- print("="*70)
-
+ print("=" * 70)
+
exact_dists, exact_time = run_benchmark(
- "Exact (RobotCollisionSpherized)",
- exact_check_collisions, q_batch, spheres_batch
+ "Exact (RobotCollisionSpherized)",
+ exact_check_collisions,
+ q_batch,
+ spheres_batch,
)
-
+
neural_without_pe_dists, neural_without_pe_time = run_benchmark(
- "Neural WITHOUT Positional Encoding",
- neural_check_without_pe, q_batch, spheres_batch
+ "Neural WITHOUT Positional Encoding",
+ neural_check_without_pe,
+ q_batch,
+ spheres_batch,
)
-
+
neural_with_pe_dists = None
neural_with_pe_time = None
if RUN_POSITIONAL_ENCODING and neural_check_with_pe is not None:
neural_with_pe_dists, neural_with_pe_time = run_benchmark(
- "Neural WITH Positional Encoding",
- neural_check_with_pe, q_batch, spheres_batch
+ "Neural WITH Positional Encoding",
+ neural_check_with_pe,
+ q_batch,
+ spheres_batch,
)
-
+
# Clear JAX caches and force garbage collection to free GPU memory
print("\nClearing memory before comparison...")
jax.clear_caches()
gc.collect()
-
+
# Compare results
metrics_without_pe = compare_results(
"Neural WITHOUT Positional Encoding vs Exact",
- neural_without_pe_dists, exact_dists
+ neural_without_pe_dists,
+ exact_dists,
)
-
+
metrics_with_pe = None
if RUN_POSITIONAL_ENCODING and neural_with_pe_dists is not None:
metrics_with_pe = compare_results(
"Neural WITH Positional Encoding vs Exact",
- neural_with_pe_dists, exact_dists
+ neural_with_pe_dists,
+ exact_dists,
)
-
+
# Summary comparison (only if positional encoding was tested)
if RUN_POSITIONAL_ENCODING and metrics_with_pe is not None:
print("SUMMARY: Positional Encoding Impact")
- print(f"\n{'Metric':<25} {'With PE':<15} {'Without PE':<15} {'Improvement':<15}")
-
- for metric in ['mae', 'rmse', 'max_ae', 'precision', 'recall', 'f1']:
+ print(
+ f"\n{'Metric':<25} {'With PE':<15} {'Without PE':<15} {'Improvement':<15}"
+ )
+
+ for metric in ["mae", "rmse", "max_ae", "precision", "recall", "f1"]:
with_pe = metrics_with_pe[metric]
without_pe = metrics_without_pe[metric]
-
+
# For error metrics, lower is better; for accuracy metrics, higher is better
- if metric in ['mae', 'rmse', 'max_ae', 'bias']:
+ if metric in ["mae", "rmse", "max_ae", "bias"]:
improvement = (without_pe - with_pe) / (without_pe + 1e-8) * 100
better = "↓" if with_pe < without_pe else "↑"
else:
improvement = (with_pe - without_pe) / (without_pe + 1e-8) * 100
better = "↑" if with_pe > without_pe else "↓"
-
- print(f"{metric:<25} {with_pe:<15.6f} {without_pe:<15.6f} {improvement:+.1f}% {better}")
-
- print(f"\n{'Inference Time (s)':<25} {neural_with_pe_time:<15.6f} {neural_without_pe_time:<15.6f}")
+
+ print(
+ f"{metric:<25} {with_pe:<15.6f} {without_pe:<15.6f} {improvement:+.1f}% {better}"
+ )
+
+ print(
+ f"\n{'Inference Time (s)':<25} {neural_with_pe_time:<15.6f} {neural_without_pe_time:<15.6f}"
+ )
print(f"{'Exact Time (s)':<25} {exact_time:<15.6f}")
else:
# Simple summary without PE comparison
print("SUMMARY: Neural vs Exact")
print(f"\n{'Metric':<25} {'Neural (no PE)':<15}")
- for metric in ['mae', 'rmse', 'max_ae', 'precision', 'recall', 'f1']:
+ for metric in ["mae", "rmse", "max_ae", "precision", "recall", "f1"]:
print(f"{metric:<25} {metrics_without_pe[metric]:<15.6f}")
print(f"\n{'Neural Inference Time (s)':<25} {neural_without_pe_time:<15.6f}")
print(f"{'Exact Time (s)':<25} {exact_time:<15.6f}")
print(f"{'Speedup':<25} {exact_time / neural_without_pe_time:<15.2f}x")
-
+
# Cleanup
del exact_dists, neural_without_pe_dists
if neural_with_pe_dists is not None:
diff --git a/examples/pyronot_snippets/__init__.py b/examples/pyronot_snippets/__init__.py
index f9a774b..f535768 100644
--- a/examples/pyronot_snippets/__init__.py
+++ b/examples/pyronot_snippets/__init__.py
@@ -9,4 +9,6 @@
from ._solve_ik_with_multiple_targets import (
solve_ik_with_multiple_targets as solve_ik_with_multiple_targets,
)
-from ._solve_collision_with_config import solve_collision_with_config as solve_collision_with_config
\ No newline at end of file
+from ._solve_collision_with_config import (
+ solve_collision_with_config as solve_collision_with_config,
+)
diff --git a/examples/pyronot_snippets/_solve_collision_with_config.py b/examples/pyronot_snippets/_solve_collision_with_config.py
index 7a19797..78030ab 100644
--- a/examples/pyronot_snippets/_solve_collision_with_config.py
+++ b/examples/pyronot_snippets/_solve_collision_with_config.py
@@ -12,6 +12,7 @@
import numpy as onp
import pyronot as pk
+
def solve_collision_with_config(
robot: pk.Robot,
coll: pk.collision.RobotCollision,
diff --git a/src/pyronot/_robot_srdf_parser.py b/src/pyronot/_robot_srdf_parser.py
index 52ef5b2..6e3e237 100644
--- a/src/pyronot/_robot_srdf_parser.py
+++ b/src/pyronot/_robot_srdf_parser.py
@@ -1,51 +1,50 @@
from xml.etree import ElementTree as ET
+
def read_disabled_collisions_from_srdf(srdf_path):
"""Read disabled collision pairs from SRDF file."""
-
+
tree = ET.parse(srdf_path)
root = tree.getroot()
-
+
disabled_pairs = []
-
- for disable_elem in root.findall('disable_collisions'):
- link1 = disable_elem.get('link1')
- link2 = disable_elem.get('link2')
- reason = disable_elem.get('reason', 'Unknown')
-
- disabled_pairs.append({
- 'link1': link1,
- 'link2': link2,
- 'reason': reason
- })
-
+
+ for disable_elem in root.findall("disable_collisions"):
+ link1 = disable_elem.get("link1")
+ link2 = disable_elem.get("link2")
+ reason = disable_elem.get("reason", "Unknown")
+
+ disabled_pairs.append({"link1": link1, "link2": link2, "reason": reason})
+
return disabled_pairs
+
def read_group_states_from_srdf(srdf_path):
"""Read named configurations from SRDF file."""
-
+
tree = ET.parse(srdf_path)
root = tree.getroot()
-
+
group_states = {}
-
- for state_elem in root.findall('group_state'):
- group_name = state_elem.get('group')
- state_name = state_elem.get('name')
-
+
+ for state_elem in root.findall("group_state"):
+ group_name = state_elem.get("group")
+ state_name = state_elem.get("name")
+
joints = {}
- for joint_elem in state_elem.findall('joint'):
- joint_name = joint_elem.get('name')
- joint_value = float(joint_elem.get('value'))
+ for joint_elem in state_elem.findall("joint"):
+ joint_name = joint_elem.get("name")
+ joint_value = float(joint_elem.get("value"))
joints[joint_name] = joint_value
-
+
if group_name not in group_states:
group_states[group_name] = {}
-
+
group_states[group_name][state_name] = joints
-
+
return group_states
+
# Usage
if __name__ == "__main__":
srdf_path = "resources/ur5/ur5.srdf"
@@ -62,4 +61,4 @@ def read_group_states_from_srdf(srdf_path):
for group_name, states in group_states.items():
print(f"\nGroup: {group_name}")
for state_name, joints in states.items():
- print(f" State '{state_name}': {joints}")
\ No newline at end of file
+ print(f" State '{state_name}': {joints}")
diff --git a/src/pyronot/_robot_urdf_parser.py b/src/pyronot/_robot_urdf_parser.py
index 49247b3..3ab451f 100644
--- a/src/pyronot/_robot_urdf_parser.py
+++ b/src/pyronot/_robot_urdf_parser.py
@@ -131,15 +131,23 @@ def get_link_indices(self, link_names: tuple[str, ...]) -> Int[Array, " n_matche
"""Get the indices of links by names."""
matches = jnp.array([name in link_names for name in self.names])
return jnp.where(matches)[0]
- def get_link_mask_from_indices(self, link_indices: Int[Array, " n_matches"]) -> Int[Array, " num_links"]:
+
+ def get_link_mask_from_indices(
+ self, link_indices: Int[Array, " n_matches"]
+ ) -> Int[Array, " num_links"]:
"""Get a mask of links by indices."""
mask = jnp.zeros(self.num_links, dtype=bool)
mask = mask.at[link_indices].set(True)
return mask
- def get_link_mask_from_names(self, link_names: tuple[str, ...]) -> Int[Array, " num_links"]:
+
+ def get_link_mask_from_names(
+ self, link_names: tuple[str, ...]
+ ) -> Int[Array, " num_links"]:
"""Get a mask of links by names."""
link_indices = self.get_link_indices(link_names)
return self.get_link_mask_from_indices(link_indices)
+
+
class RobotURDFParser:
"""Parser for creating Robot instances from URDF files."""
diff --git a/src/pyronot/collision/__init__.py b/src/pyronot/collision/__init__.py
index 5a465b3..9c83bba 100644
--- a/src/pyronot/collision/__init__.py
+++ b/src/pyronot/collision/__init__.py
@@ -10,4 +10,6 @@
from ._robot_collision import RobotCollision as RobotCollision
from ._robot_collision import RobotCollisionSpherized as RobotCollisionSpherized
from ._neural_collision import NeuralRobotCollision as NeuralRobotCollision
-from ._neural_collision import NeuralRobotCollisionSpherized as NeuralRobotCollisionSpherized
\ No newline at end of file
+from ._neural_collision import (
+ NeuralRobotCollisionSpherized as NeuralRobotCollisionSpherized,
+)
diff --git a/src/pyronot/collision/_geometry.py b/src/pyronot/collision/_geometry.py
index c1550ca..46595b4 100644
--- a/src/pyronot/collision/_geometry.py
+++ b/src/pyronot/collision/_geometry.py
@@ -57,22 +57,23 @@ def reshape(self, shape: tuple[int, ...]) -> Self:
def __getitem__(self, index) -> Self:
"""Get a subset of the geometry by indexing into the batch dimensions.
-
+
Args:
index: Index or slice to apply to the batch dimensions
-
+
Returns:
New CollGeom object with indexed batch dimensions
-
+
Example:
>>> sphere[..., 0] # Get first element of last batch dimension
>>> sphere[0:2] # Get first two elements of first batch dimension
>>> sphere[..., exclude_indices] # Get elements at specific indices
"""
return jax.tree.map(
- lambda x: x[index] if hasattr(x, '__getitem__') else x,
+ lambda x: x[index] if hasattr(x, "__getitem__") else x,
self,
)
+
def transform(self, transform: jaxlie.SE3) -> Self:
"""Left-multiples geometry's pose with an SE(3) transformation."""
with jdc.copy_and_mutate(self) as out:
@@ -211,8 +212,8 @@ def from_trimesh(mesh: trimesh.Trimesh) -> Sphere:
Returns:
Sphere: A Sphere geometry fit to the mesh.
-
- Author:
+
+ Author:
S
"""
if mesh.is_empty:
@@ -223,7 +224,10 @@ def from_trimesh(mesh: trimesh.Trimesh) -> Sphere:
# Compute the bounding sphere (center and radius)
try:
- center_np, radius_val = mesh.bounding_sphere.center, mesh.bounding_sphere.primitive.radius
+ center_np, radius_val = (
+ mesh.bounding_sphere.center,
+ mesh.bounding_sphere.primitive.radius,
+ )
center = jnp.array(center_np, dtype=jnp.float32)
radius = jnp.array(radius_val, dtype=jnp.float32)
except Exception:
@@ -235,9 +239,11 @@ def from_trimesh(mesh: trimesh.Trimesh) -> Sphere:
return Sphere.from_center_and_radius(center=center, radius=radius)
+
@jdc.pytree_dataclass
class Box(CollGeom):
"""Box (Rectangular Prism) geometry."""
+
@property
def half_lengths(self) -> Float[Array, "*batch 3"]:
"""Half-lengths along local X, Y, Z (size stores three values)."""
@@ -261,7 +267,13 @@ def from_center_and_half_lengths(
"""
half_lengths = jnp.array(half_lengths)
lengths = 2.0 * half_lengths
- return Box.from_center_and_dimensions(center=center, length=lengths[..., 0], width=lengths[..., 1], height=lengths[..., 2], wxyz=wxyz)
+ return Box.from_center_and_dimensions(
+ center=center,
+ length=lengths[..., 0],
+ width=lengths[..., 1],
+ height=lengths[..., 2],
+ wxyz=wxyz,
+ )
@staticmethod
def from_center_and_dimensions(
@@ -365,6 +377,7 @@ def as_halfspaces(self) -> tuple[HalfSpace, ...]:
return (hs_px, hs_nx, hs_py, hs_ny, hs_pz, hs_nz)
+
@jdc.pytree_dataclass
class Capsule(CollGeom):
"""Capsule geometry."""
diff --git a/src/pyronot/collision/_geometry_pairs.py b/src/pyronot/collision/_geometry_pairs.py
index 3228c83..f3c352f 100644
--- a/src/pyronot/collision/_geometry_pairs.py
+++ b/src/pyronot/collision/_geometry_pairs.py
@@ -62,10 +62,17 @@ def _sphere_sphere_dist(
pos2: Float[Array, "*batch 3"],
radius2: Float[Array, "*batch"],
) -> Float[Array, "*batch"]:
- """Helper: Calculates distance between two spheres."""
- _, dist_center = _utils.normalize_with_norm(pos2 - pos1)
- dist = dist_center - (radius1 + radius2)
- return dist
+    """Helper: Computes a squared-distance collision indicator for two spheres.
+
+    Note: Returns dist_center² - radii_sum² (NOT the signed surface distance
+    dist_center - radii_sum). Its sign matches the true signed distance, so
+    result < 0 means collision, but magnitudes are in squared units."""
+ diff = pos2 - pos1
+ dist_center_sq = jnp.sum(diff * diff, axis=-1) # ||p2 - p1||², no sqrt
+ radii_sum = radius1 + radius2
+ radii_sum_sq = radii_sum * radii_sum
+ # Return signed indicator: negative means collision
+ return dist_center_sq - radii_sum_sq
def sphere_sphere(sphere1: Sphere, sphere2: Sphere) -> Float[Array, "*batch"]:
@@ -217,6 +224,7 @@ def heightmap_halfspace(
assert min_dist.shape == batch_axes
return min_dist
+
def box_sphere(box: Box, sphere: Sphere) -> Float[Array, "*batch"]:
"""Compute signed distance between an oriented box and a sphere.
@@ -247,22 +255,23 @@ def box_capsule(box: Box, capsule: Capsule) -> Float[Array, "*batch"]:
"""
cap_pos = capsule.pose.translation()
- cap_axis = capsule.axis
+ cap_axis = capsule.axis
half_h = capsule.height[..., None] * 0.5
- a_w = cap_pos - cap_axis * half_h
- b_w = cap_pos + cap_axis * half_h
+ a_w = cap_pos - cap_axis * half_h
+ b_w = cap_pos + cap_axis * half_h
a = box.pose.inverse().apply(a_w)
b = box.pose.inverse().apply(b_w)
- hl = box.half_lengths
+ hl = box.half_lengths
ab = b - a
ab_len2 = jnp.sum(ab * ab, axis=-1, keepdims=True)
t = jnp.clip(
jnp.sum((0.0 - a) * ab, axis=-1, keepdims=True) / (ab_len2 + 1e-12),
- 0.0, 1.0,
+ 0.0,
+ 1.0,
)
p = a + t * ab
q = jnp.abs(p) - hl
@@ -276,7 +285,6 @@ def box_capsule(box: Box, capsule: Capsule) -> Float[Array, "*batch"]:
return sdist_box - capsule.radius
-
def box_halfspace(box: Box, halfspace: HalfSpace) -> Float[Array, "*batch"]:
"""Compute signed distance between box and a halfspace plane.
@@ -301,7 +309,9 @@ def box_halfspace(box: Box, halfspace: HalfSpace) -> Float[Array, "*batch"]:
hs_n_bc = jnp.broadcast_to(hs_n, verts_world.shape[:-1] + (3,))[..., None, :]
hs_pt_bc = jnp.broadcast_to(hs_pt, verts_world.shape[:-1] + (3,))[..., None, :]
- vertex_distances = jnp.einsum("...vi,...i->...v", verts_world - hs_pt_bc, hs_n_bc.squeeze(-2))
+ vertex_distances = jnp.einsum(
+ "...vi,...i->...v", verts_world - hs_pt_bc, hs_n_bc.squeeze(-2)
+ )
min_dist = jnp.min(vertex_distances, axis=-1)
return min_dist
diff --git a/src/pyronot/collision/_neural_collision.py b/src/pyronot/collision/_neural_collision.py
index 74a005d..0c38bf9 100644
--- a/src/pyronot/collision/_neural_collision.py
+++ b/src/pyronot/collision/_neural_collision.py
@@ -28,46 +28,48 @@ class NeuralRobotCollision:
"""
A wrapper class that adds neural network-based collision distance approximation
to either RobotCollision or RobotCollisionSpherized.
-
+
The network is trained to overfit to a specific scene, mapping robot link poses
directly to collision distances between robot links and the static obstacles.
-
+
Input: Flattened link poses (N links × 7 pose params = N*7 dimensions), optionally
with positional encoding for capturing fine geometric details.
Output: Flattened distance matrix (N links × M obstacles = N*M dimensions)
-
+
This class uses composition to wrap either collision model type, delegating
non-neural methods to the underlying collision model.
-
+
Positional Encoding (inspired by iSDF):
When enabled, the input is augmented with sinusoidal positional encodings
at multiple frequency scales. This allows the network to learn high-frequency
spatial features that are critical for accurate collision distance prediction,
especially near obstacle boundaries where distances change rapidly.
"""
-
+
# The underlying collision model (either RobotCollision or RobotCollisionSpherized)
_collision_model: Union[RobotCollision, RobotCollisionSpherized]
-
+
# Neural network parameters (weights and biases for each layer)
- nn_params: List[Tuple[Float[Array, "fan_in fan_out"], Float[Array, "fan_out"]]] = jdc.field(default_factory=list)
-
+ nn_params: List[Tuple[Float[Array, "fan_in fan_out"], Float[Array, "fan_out"]]] = (
+ jdc.field(default_factory=list)
+ )
+
# Metadata about the training - these must be static for use in JIT conditionals
is_trained: jdc.Static[bool] = False
-
+
# We keep track of the number of obstacles this network was trained for (M)
trained_num_obstacles: jdc.Static[int] = 0
-
+
# Input normalization parameters (computed during training)
input_mean: jax.Array = jdc.field(default_factory=lambda: jnp.zeros(1))
input_std: jax.Array = jdc.field(default_factory=lambda: jnp.ones(1))
-
+
# Positional encoding parameters (static for JIT)
use_positional_encoding: jdc.Static[bool] = False
pe_min_deg: jdc.Static[int] = 0
pe_max_deg: jdc.Static[int] = 6
pe_scale: jdc.Static[float] = 1.0
-
+
# Computed PE scale (stored after training for use at inference)
pe_scale_computed: jax.Array = jdc.field(default_factory=lambda: jnp.array(1.0))
@@ -75,23 +77,23 @@ class NeuralRobotCollision:
@property
def num_links(self) -> int:
return self._collision_model.num_links
-
+
@property
def link_names(self) -> tuple[str, ...]:
return self._collision_model.link_names
-
+
@property
def coll(self) -> CollGeom:
return self._collision_model.coll
-
+
@property
def active_idx_i(self):
return self._collision_model.active_idx_i
-
+
@property
def active_idx_j(self):
return self._collision_model.active_idx_j
-
+
@property
def is_spherized(self) -> bool:
"""Returns True if the underlying model is RobotCollisionSpherized."""
@@ -110,7 +112,7 @@ def from_existing(
"""
Creates a NeuralRobotCollision instance from an existing collision model.
Initializes the neural network with random weights.
-
+
Args:
original: The original collision model (RobotCollision or RobotCollisionSpherized).
layer_sizes: List of hidden layer sizes. The input size is determined by robot DOF,
@@ -125,15 +127,15 @@ def from_existing(
"""
if layer_sizes is None:
layer_sizes = [256, 256, 256]
-
+
if key is None:
key = jax.random.PRNGKey(0)
# We can't fully initialize the network structure until we know the output dimension (N*M),
- # which depends on the number of obstacles M.
+ # which depends on the number of obstacles M.
# For now, we just copy the fields and return an untrained instance.
# The actual weights will be initialized/shaped during the training setup or first call.
-
+
return NeuralRobotCollision(
_collision_model=original,
nn_params=[],
@@ -223,23 +225,25 @@ def compute_world_collision_distance(
) -> Float[Array, "*batch_combined N M"]:
"""
Computes collision distances, using the trained neural network if available.
-
+
This assumes that world_geom represents the SAME static obstacles that the network
was trained on. The network uses link poses (from forward kinematics) as input
and predicts distances based on those poses.
-
+
If positional encoding is enabled, the input is augmented with sinusoidal
embeddings at multiple frequency scales to capture fine geometric details.
-
+
Falls back to the underlying collision model's exact computation if not trained.
"""
if not self.is_trained:
# Fallback to the original exact computation if not trained
- return self._collision_model.compute_world_collision_distance(robot, cfg, world_geom)
+ return self._collision_model.compute_world_collision_distance(
+ robot, cfg, world_geom
+ )
# Determine batch shapes
batch_cfg_shape = cfg.shape[:-1]
-
+
# Check world geom shape to ensure consistency with training (M)
world_axes = world_geom.get_batch_axes()
if len(world_axes) == 0:
@@ -248,23 +252,25 @@ def compute_world_collision_distance(
else:
M = world_axes[-1]
batch_world_shape = world_axes[:-1]
-
+
if M != self.trained_num_obstacles:
logger.warning(
f"Neural network was trained for {self.trained_num_obstacles} obstacles, "
f"but current world_geom has {M}. Falling back to exact computation."
)
- return self._collision_model.compute_world_collision_distance(robot, cfg, world_geom)
+ return self._collision_model.compute_world_collision_distance(
+ robot, cfg, world_geom
+ )
# Compute link poses via forward kinematics
# Shape: (*batch_cfg, num_links, 7) where 7 = wxyz (4) + xyz (3)
link_poses = robot.forward_kinematics(cfg)
N = self.num_links
-
+
# Flatten link poses to use as network input
# Shape: (*batch_cfg, num_links * 7)
link_poses_flat = link_poses.reshape(*batch_cfg_shape, N * 7)
-
+
# Apply positional encoding BEFORE normalization if enabled
# (matching the training procedure)
if self.use_positional_encoding:
@@ -279,22 +285,24 @@ def compute_world_collision_distance(
else:
# Just normalize
link_poses_normalized = (link_poses_flat - self.input_mean) / self.input_std
-
+
# Flatten batch for inference
input_dim = link_poses_normalized.shape[-1]
input_flat = link_poses_normalized.reshape(-1, input_dim)
-
+
# Run inference
predict_fn = jax.vmap(self._forward_nn)
dists_flat = predict_fn(input_flat) # Shape: (batch_size, N * M)
-
+
# Reshape output to (*batch_cfg, N, M)
dists = dists_flat.reshape(*batch_cfg_shape, N, M)
-
+
# Handle broadcasting with world batch shape if necessary.
if batch_world_shape:
- expected_batch_combined = jnp.broadcast_shapes(batch_cfg_shape, batch_world_shape)
- dists = jnp.broadcast_to(dists, (*expected_batch_combined, N, M))
+ expected_batch_combined = jnp.broadcast_shapes(
+ batch_cfg_shape, batch_world_shape
+ )
+ dists = jnp.broadcast_to(dists, (*expected_batch_combined, N, M))
return dists
@@ -307,7 +315,7 @@ def train(
epochs: int = 50,
learning_rate: float = 1e-3,
key: jax.Array = None,
- layer_sizes: List[int] = [256, 256, 256, 256]
+ layer_sizes: List[int] = [256, 256, 256, 256],
) -> "NeuralRobotCollision":
"""
Trains the neural network to approximate the collision distances for the given world_geom.
@@ -318,7 +326,7 @@ def train(
where collision spheres end up in world space.
"""
logger.info("Starting neural collision training...")
-
+
if key is None:
key = jax.random.PRNGKey(0)
@@ -329,30 +337,32 @@ def train(
M = world_axes[-1] if len(world_axes) > 0 else 1
# 1. Generate training data with collision-aware sampling
- logger.info(f"Generating {num_samples} samples with collision-aware sampling...")
+ logger.info(
+ f"Generating {num_samples} samples with collision-aware sampling..."
+ )
# Sample configurations using Halton sequence for better space coverage
dof = robot.joints.num_actuated_joints
lower_limits = robot.joints.lower_limits
upper_limits = robot.joints.upper_limits
-
+
# Generate initial Halton sequence samples (2x to have pool for filtering)
initial_pool_size = num_samples * 2
halton_samples = halton_sequence(initial_pool_size, dof)
q_pool = lower_limits + halton_samples * (upper_limits - lower_limits)
-
+
# Compute distances for the pool to identify collision samples
logger.info("Computing distances to identify collision samples...")
-
+
def compute_min_dist(q):
dists = self._collision_model.compute_world_collision_distance(
robot, q, world_geom
)
return jnp.min(dists)
-
+
compute_all_min_dists = jax.vmap(compute_min_dist)
min_dists = compute_all_min_dists(q_pool) # Shape: (initial_pool_size,)
-
+
# Rebalance samples to have more collision and near-collision samples
collision_threshold = 0.1 # Samples within 10cm of collision
q_train = rebalance_samples(
@@ -370,34 +380,38 @@ def compute_min_dist(q):
# Shape: (num_samples, num_links, 7) where 7 = wxyz (4) + xyz (3)
logger.info("Computing link poses via forward kinematics...")
link_poses_all = robot.forward_kinematics(q_train)
-
+
# Flatten link poses to (num_samples, num_links * 7)
X_train_raw = link_poses_all.reshape(num_samples, N * 7)
-
+
# Apply positional encoding BEFORE normalization if enabled
# This is important because positional encoding should operate on the
# original spatial scale of the data, not normalized values
if self.use_positional_encoding:
- # For positional encoding, we want the input scaled so that the
+ # For positional encoding, we want the input scaled so that the
# lowest frequency (2^min_deg) captures the full data range,
# and higher frequencies capture finer details.
# A scale of 1.0 is typically fine when data is already in reasonable ranges.
# We use pi as scale so that the range [-1, 1] maps to [-pi, pi]
auto_scale = jnp.pi
- logger.info(f"Applying positional encoding (min_deg={self.pe_min_deg}, max_deg={self.pe_max_deg}, scale={auto_scale:.4f})...")
+ logger.info(
+ f"Applying positional encoding (min_deg={self.pe_min_deg}, max_deg={self.pe_max_deg}, scale={auto_scale:.4f})..."
+ )
X_train_pe = positional_encoding(
X_train_raw,
min_deg=self.pe_min_deg,
max_deg=self.pe_max_deg,
scale=auto_scale,
)
- logger.info(f"Input dimension after positional encoding: {X_train_pe.shape[-1]} (was {X_train_raw.shape[-1]})")
-
+ logger.info(
+ f"Input dimension after positional encoding: {X_train_pe.shape[-1]} (was {X_train_raw.shape[-1]})"
+ )
+
# Now normalize the positional-encoded features
X_mean = jnp.mean(X_train_pe, axis=0, keepdims=True)
X_std = jnp.std(X_train_pe, axis=0, keepdims=True) + 1e-8
X_train = (X_train_pe - X_mean) / X_std
-
+
# Store the auto-computed scale for inference
self_pe_scale_computed = auto_scale
else:
@@ -409,22 +423,22 @@ def compute_min_dist(q):
# 2. Compute ground truth labels using vmap for acceleration
logger.info("Computing ground truth distances (vectorized)...")
-
+
# Use vmap to compute distances for all configurations in parallel
def compute_single_dist(q):
dists = self._collision_model.compute_world_collision_distance(
robot, q, world_geom
)
return dists.reshape(-1) # Flatten to (N*M,)
-
+
# Vectorize over all training samples
compute_all_dists = jax.vmap(compute_single_dist)
Y_train = compute_all_dists(q_train) # Shape: (num_samples, N*M)
-
+
# Compute sample weights based on minimum distance
# Give higher weight to collision and near-collision samples
Y_min_per_sample = jnp.min(Y_train, axis=1) # Shape: (num_samples,)
-
+
# Weight function: higher weight for collision (dist <= 0) and near-collision
# collision: weight = 3.0, near-collision: weight = 2.0, free: weight = 1.0
sample_weights = jnp.where(
@@ -433,19 +447,23 @@ def compute_single_dist(q):
jnp.where(
Y_min_per_sample < collision_threshold,
2.0, # Near-collision samples get 2x weight
- 1.0 # Free space samples get normal weight
- )
+ 1.0, # Free space samples get normal weight
+ ),
)
# Normalize weights so they sum to num_samples (to maintain loss scale)
sample_weights = sample_weights * (num_samples / jnp.sum(sample_weights))
-
- logger.info(f"Sample weights - collision (3x): {jnp.sum(Y_min_per_sample <= 0)}, near-collision (2x): {jnp.sum((Y_min_per_sample > 0) & (Y_min_per_sample < collision_threshold))}")
+
+ logger.info(
+ f"Sample weights - collision (3x): {jnp.sum(Y_min_per_sample <= 0)}, near-collision (2x): {jnp.sum((Y_min_per_sample > 0) & (Y_min_per_sample < collision_threshold))}"
+ )
# 3. Initialize Network
# Input dimension depends on whether positional encoding is enabled
raw_input_dim = N * 7 # num_links * 7 (wxyz_xyz pose representation)
if self.use_positional_encoding:
- input_dim = compute_positional_encoding_dim(raw_input_dim, self.pe_min_deg, self.pe_max_deg)
+ input_dim = compute_positional_encoding_dim(
+ raw_input_dim, self.pe_min_deg, self.pe_max_deg
+ )
else:
input_dim = raw_input_dim
output_dim = N * M # num_links * num_obstacles
@@ -460,7 +478,11 @@ def compute_single_dist(q):
b = jnp.zeros((fan_out,))
params.append((w, b))
- pe_info = f" with positional encoding (deg {self.pe_min_deg}-{self.pe_max_deg})" if self.use_positional_encoding else ""
+ pe_info = (
+ f" with positional encoding (deg {self.pe_min_deg}-{self.pe_max_deg})"
+ if self.use_positional_encoding
+ else ""
+ )
logger.info(
f"Training neural network{pe_info} (Input: {input_dim}, Output: {output_dim} [distances])..."
)
@@ -488,50 +510,52 @@ def train_step(params, opt_state, x_batch, y_batch, w_batch, t):
"""Single training step with Adam optimizer."""
m, v = opt_state
beta1, beta2, epsilon = 0.9, 0.999, 1e-8
-
+
# Compute gradients
- loss_val, grads = jax.value_and_grad(loss_fn)(params, x_batch, y_batch, w_batch)
-
+ loss_val, grads = jax.value_and_grad(loss_fn)(
+ params, x_batch, y_batch, w_batch
+ )
+
# Adam update
new_params = []
new_m = []
new_v = []
-
+
for i in range(len(params)):
w, b = params[i]
dw, db = grads[i]
mw, mb = m[i]
vw, vb = v[i]
-
+
# Update biased first moment estimate
mw = beta1 * mw + (1.0 - beta1) * dw
mb = beta1 * mb + (1.0 - beta1) * db
-
+
# Update biased second moment estimate
- vw = beta2 * vw + (1.0 - beta2) * (dw ** 2)
- vb = beta2 * vb + (1.0 - beta2) * (db ** 2)
-
+ vw = beta2 * vw + (1.0 - beta2) * (dw**2)
+ vb = beta2 * vb + (1.0 - beta2) * (db**2)
+
# Bias correction
- m_hat_w = mw / (1.0 - beta1 ** t)
- m_hat_b = mb / (1.0 - beta1 ** t)
- v_hat_w = vw / (1.0 - beta2 ** t)
- v_hat_b = vb / (1.0 - beta2 ** t)
-
+ m_hat_w = mw / (1.0 - beta1**t)
+ m_hat_b = mb / (1.0 - beta1**t)
+ v_hat_w = vw / (1.0 - beta2**t)
+ v_hat_b = vb / (1.0 - beta2**t)
+
# Update parameters
w_new = w - learning_rate * m_hat_w / (jnp.sqrt(v_hat_w) + epsilon)
b_new = b - learning_rate * m_hat_b / (jnp.sqrt(v_hat_b) + epsilon)
-
+
new_params.append((w_new, b_new))
new_m.append((mw, mb))
new_v.append((vw, vb))
-
+
return new_params, (new_m, new_v), loss_val
# Initialize Adam state
m = [(jnp.zeros_like(w), jnp.zeros_like(b)) for w, b in params]
v = [(jnp.zeros_like(w), jnp.zeros_like(b)) for w, b in params]
opt_state = (m, v)
-
+
params_state = params
t = 0
num_batches = num_samples // batch_size
@@ -560,9 +584,7 @@ def train_step(params, opt_state, x_batch, y_batch, w_batch, t):
epoch_loss += loss_val
if epoch % 10 == 0:
- logger.info(
- f"Epoch {epoch}: Loss = {epoch_loss / num_batches:.6f}"
- )
+ logger.info(f"Epoch {epoch}: Loss = {epoch_loss / num_batches:.6f}")
logger.info("Training complete.")
@@ -578,4 +600,4 @@ def train_step(params, opt_state, x_batch, y_batch, w_batch, t):
# Backward compatibility alias
-NeuralRobotCollisionSpherized = NeuralRobotCollision
\ No newline at end of file
+NeuralRobotCollisionSpherized = NeuralRobotCollision
diff --git a/src/pyronot/collision/_robot_collision.py b/src/pyronot/collision/_robot_collision.py
index c223dd6..a4b6eb4 100644
--- a/src/pyronot/collision/_robot_collision.py
+++ b/src/pyronot/collision/_robot_collision.py
@@ -1,7 +1,6 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Optional, Tuple, cast
-import math
import jax
import jax.numpy as jnp
@@ -10,7 +9,6 @@
import trimesh
import yourdfpy
from jaxtyping import Array, Float, Int
-from jax import lax
from loguru import logger
if TYPE_CHECKING:
@@ -153,8 +151,6 @@ def _compute_active_pair_indices(
return active_i, active_j
-
-
@staticmethod
def _get_trimesh_collision_geometries(
urdf: yourdfpy.URDF, link_name: str
@@ -425,12 +421,15 @@ def compute_world_collision_distance(
# 5. Return the distance matrix
return dist_matrix
+
@jdc.pytree_dataclass
class RobotCollisionSpherized:
"""Collision model for a robot, integrated with pyronot kinematics."""
num_links: jdc.Static[int]
"""Number of links in the model (matches kinematics links)."""
+ spheres_per_link: jdc.Static[tuple[int, ...]]
+ """Number of spheres for each link (tuple of length num_links)."""
link_names: jdc.Static[tuple[str, ...]]
"""Names of the links corresponding to link indices."""
coll: CollGeom
@@ -476,50 +475,42 @@ def from_urdf(
# Gather all collision meshes.
link_sphere_meshes: list[list[trimesh.Trimesh]] = []
for link_name in link_name_list:
- spheres = RobotCollisionSpherized._get_trimesh_collision_spheres_for_link(urdf, link_name)
+ spheres = RobotCollisionSpherized._get_trimesh_collision_spheres_for_link(
+ urdf, link_name
+ )
link_sphere_meshes.append(spheres)
-
- sphere_list_per_link: list[list[Sphere]] = []
+ # Build flat list of spheres and track count per link
+ sphere_list: list[Sphere] = []
+ spheres_per_link: list[int] = []
for sphere_meshes in link_sphere_meshes:
per_link_spheres = [
Sphere.from_trimesh(mesh) for mesh in sphere_meshes if mesh is not None
]
- sphere_list_per_link.append(per_link_spheres)
-
- ############ Weihang: Please check this part #############
- # Add padding to the spheres list to make it a batched sphere object
- max_spheres = max(len(spheres) for spheres in sphere_list_per_link)
- padded_sphere_list: list[Sphere] = []
- for per_link_spheres in sphere_list_per_link:
- if len(per_link_spheres) < max_spheres:
- # Create dummy/invalid spheres for padding (e.g., zero radius)
- dummy_sphere = Sphere.from_center_and_radius(
- center=jnp.zeros(3),
- radius=jnp.array(0.0) # or negative to mark as invalid
- )
- padded = per_link_spheres + [dummy_sphere] * (max_spheres - len(per_link_spheres))
- padded_sphere_list.append(padded)
- else:
- padded_sphere_list.append(per_link_spheres)
- spheres_2d = cast(Sphere, jax.tree.map(lambda *args: jnp.stack(args), *padded_sphere_list))
+ sphere_list.extend(per_link_spheres)
+ spheres_per_link.append(len(per_link_spheres))
- ##########################################################
+ # Stack all spheres into a batched Sphere object
+ spheres = cast(Sphere, jax.tree.map(lambda *args: jnp.stack(args), *sphere_list))
# Directly compute active pair indices
# Weihang: Have not checked this part yet!!!
# Should be fine, generates active_indices per links - Sai
if srdf_path:
- active_idx_i, active_idx_j = RobotCollisionSpherized._compute_active_pair_indices_from_srdf(
- link_names=link_name_list,
- srdf_path=srdf_path,
+ active_idx_i, active_idx_j = (
+ RobotCollisionSpherized._compute_active_pair_indices_from_srdf(
+ link_names=link_name_list,
+ srdf_path=srdf_path,
+ )
)
else:
- active_idx_i, active_idx_j = RobotCollisionSpherized._compute_active_pair_indices(
- link_names=link_name_list,
- urdf=urdf,
- user_ignore_pairs=user_ignore_pairs,
- ignore_immediate_adjacents=ignore_immediate_adjacents,
+ active_idx_i, active_idx_j = (
+ RobotCollisionSpherized._compute_active_pair_indices(
+ link_names=link_name_list,
+ urdf=urdf,
+ user_ignore_pairs=user_ignore_pairs,
+ ignore_immediate_adjacents=ignore_immediate_adjacents,
+ )
)
logger.info(
@@ -530,18 +521,20 @@ def from_urdf(
return RobotCollisionSpherized(
num_links=link_info.num_links,
+ spheres_per_link=tuple(spheres_per_link),
link_names=link_name_list,
active_idx_i=active_idx_i,
active_idx_j=active_idx_j,
- coll=spheres_2d, # now stores lists of Sphere objects per link
+ coll=spheres,
)
@staticmethod
def _get_trimesh_collision_spheres_for_link(
- urdf: yourdfpy.URDF, link_name: str) -> list[trimesh.Trimesh]:
+ urdf: yourdfpy.URDF, link_name: str
+ ) -> list[trimesh.Trimesh]:
if link_name not in urdf.link_map:
return [trimesh.Trimesh()]
-
+
link = urdf.link_map[link_name]
filename_handler = urdf._filename_handler
coll_meshes = []
@@ -557,11 +550,11 @@ def _get_trimesh_collision_spheres_for_link(
transform = jaxlie.SE3.identity().as_matrix()
if geom.sphere is not None:
mesh = trimesh.creation.icosphere(radius=geom.sphere.radius)
- else:
+ else:
logger.warning(
f"Unsupported collision geometry type for link '{link_name}'."
)
- if mesh is not None:
+ if mesh is not None:
mesh.apply_transform(transform)
coll_meshes.append(mesh)
return coll_meshes
@@ -583,11 +576,11 @@ def _compute_active_pair_indices_from_srdf(
Tuple of (active_i, active_j) index arrays.
"""
from .._robot_srdf_parser import read_disabled_collisions_from_srdf
-
+
num_links = len(link_names)
link_name_to_idx = {name: i for i, name in enumerate(link_names)}
ignore_matrix = jnp.zeros((num_links, num_links), dtype=bool)
-
+
# Ignore self-collisions (diagonal)
ignore_matrix = ignore_matrix.at[
jnp.arange(num_links), jnp.arange(num_links)
@@ -596,32 +589,34 @@ def _compute_active_pair_indices_from_srdf(
# Read disabled collision pairs from SRDF
try:
disabled_pairs = read_disabled_collisions_from_srdf(srdf_path)
-
+
disabled_count = 0
for pair in disabled_pairs:
- link1_name = pair['link1']
- link2_name = pair['link2']
-
+ link1_name = pair["link1"]
+ link2_name = pair["link2"]
+
# Only process if both links are in our link_names
if link1_name in link_name_to_idx and link2_name in link_name_to_idx:
idx1 = link_name_to_idx[link1_name]
idx2 = link_name_to_idx[link2_name]
-
+
# Set both directions as disabled (symmetric)
ignore_matrix = ignore_matrix.at[idx1, idx2].set(True)
ignore_matrix = ignore_matrix.at[idx2, idx1].set(True)
disabled_count += 1
-
+
logger.info(f"Loaded {disabled_count} disabled collision pairs from SRDF")
-
+
except FileNotFoundError:
- logger.warning(f"SRDF file not found: {srdf_path}. Using all collision pairs.")
+ logger.warning(
+ f"SRDF file not found: {srdf_path}. Using all collision pairs."
+ )
except Exception as e:
logger.warning(f"Error parsing SRDF file: {e}. Using all collision pairs.")
# Get all lower triangular indices (i < j)
idx_i, idx_j = jnp.tril_indices(num_links, k=-1)
-
+
# Filter out ignored pairs
should_check = ~ignore_matrix[idx_i, idx_j]
active_i = idx_i[should_check]
@@ -690,7 +685,7 @@ def _get_trimesh_collision_spheres(urdf: yourdfpy.URDF) -> list[trimesh.Trimesh]
Returns:
list[trimesh.Trimesh]: A list of trimesh sphere meshes (each transformed to link frame).
- Author:
+ Author:
Sai Coumar
"""
sphere_meshes = []
@@ -751,17 +746,20 @@ def at_config(
)
# TODO: Override with passed in result of fk so i don't have to recompute
Ts_link_world_wxyz_xyz = robot.forward_kinematics(cfg)
- Ts_link_world = jaxlie.SE3(Ts_link_world_wxyz_xyz)
- ############ Weihang: Please check this part #############
- coll_transformed = []
- for link in range(len(self.coll)):
- coll_transformed.append(self.coll[link].transform(Ts_link_world))
- coll_transformed = cast(CollGeom, jax.tree.map(lambda *args: jnp.stack(args), *coll_transformed))
- ##########################################################
- return coll_transformed
- # return self.coll.transform(Ts_link_world)
-
+ # Spheres are laid out as [link0_s0..sN0, link1_s0..sN1, ...]
+ # where Ni = spheres_per_link[i]
+ # Repeat each link's transform by its sphere count
+ total_spheres = sum(self.spheres_per_link)
+ expanded_transforms = jnp.repeat(
+ Ts_link_world_wxyz_xyz,
+ jnp.array(self.spheres_per_link),
+ axis=-2,
+ total_repeat_length=total_spheres
+ )
+
+ Ts_expanded = jaxlie.SE3(expanded_transforms)
+ return self.coll.transform(Ts_expanded)
def compute_self_collision_distance(
self,
@@ -780,10 +778,10 @@ def compute_self_collision_distance(
Signed distances for each active pair.
Shape: (*batch, num_active_pairs).
Positive distance means separation, negative means penetration.
-
+
Author: Sai Coumar
"""
-
+
# 1. Transform all spheres to world frame
coll = self.at_config(robot, cfg) # CollGeom: (*batch, n_spheres, num_links)
@@ -796,15 +794,14 @@ def compute_self_collision_distance(
# Return same format of active_distances as the capsule implementaiton
active_distances = dist_matrix_links[..., self.active_idx_i, self.active_idx_j]
return active_distances
-
@staticmethod
def collide_link_vs_world(link_geom, world_geom):
- # Map collide over spheres in this link (S)
- # link_geom: (S, ...)
- collide_spheres_vs_world = jax.vmap(collide, in_axes=(0, None), out_axes=0)
- dist_spheres = collide_spheres_vs_world(link_geom, world_geom) # (S, M)
- return dist_spheres.min(axis=0) # reduce over spheres → (M,)
+ # Map collide over spheres in this link (S)
+ # link_geom: (S, ...)
+ collide_spheres_vs_world = jax.vmap(collide, in_axes=(0, None), out_axes=0)
+ dist_spheres = collide_spheres_vs_world(link_geom, world_geom) # (S, M)
+ return dist_spheres.min(axis=0) # reduce over spheres → (M,)
@jdc.jit
def compute_world_collision_distance(
@@ -814,144 +811,91 @@ def compute_world_collision_distance(
world_geom: CollGeom, # Shape: (*batch_world, M, ...)
) -> Float[Array, "*batch_combined N M"]:
"""
- Computes signed distances between all robot links (N) and world obstacles (M),
- accounting for multiple primitives (S) per link.
+ Computes the signed distances between all robot links (N) and all world obstacles (M).
- The minimum distance over all primitives in each link is used as the link’s
- representative distance to each world object.
+ Args:
+            robot: The robot's kinematic model.
+            cfg: The robot configuration (actuated joints),
+                with batch shape (*batch_cfg,).
+ world_geom: Collision geometry representing world obstacles. If representing a
+ single obstacle, it should have batch shape (). If multiple, the last axis
+ is interpreted as the collection of world objects (M).
+ The batch dimensions (*batch_world) must be broadcast-compatible with cfg's
+ batch axes (*batch_cfg).
+
+ Returns:
+ Matrix of signed distances between each robot link and each world object.
+ Shape: (*batch_combined, N, M), where N=num_links, M=num_world_objects.
+ Positive distance means separation, negative means penetration.
"""
- # 1. Get robot collision geometry at configuration
- # Shape: (*batch_cfg, S, N, ...)
+ # 1. Get robot collision geometry at the current config
+    # Shape: (*batch_cfg, total_spheres, ...) — NOTE(review): with the new at_config, the batch axis is flat spheres, not links (N); confirm a per-link min over spheres still occurs before returning (*batch, N, M)
coll_robot_world = self.at_config(robot, cfg)
- batch_cfg_shape = coll_robot_world.get_batch_axes()[:-2]
- S, N = coll_robot_world.get_batch_axes()[-2:]
-
+ N = self.num_links
+ batch_cfg_shape = coll_robot_world.get_batch_axes()[:-1]
# 2. Normalize world_geom shape and determine M
world_axes = world_geom.get_batch_axes()
- if len(world_axes) == 0:
+ if len(world_axes) == 0: # Single world object
+ # Use the object's broadcast_to method to add the M=1 axis correctly
_world_geom = world_geom.broadcast_to((1,))
M = 1
batch_world_shape = ()
- else:
+ else: # Multiple world objects
_world_geom = world_geom
M = world_axes[-1]
batch_world_shape = world_axes[:-1]
- # 3. Define how to collide a single link (with S primitives) against the world
- # Each link_geom has shape (S, ...). We map over the S primitives, then take min.
-
-
- # 4. Now map that over links (N)
- # coll_robot_world: (*batch_cfg, S, N, ...)
- # We map over the link axis (-2 from end, N)
- _collide_links_vs_world = jax.vmap(
- self.collide_link_vs_world, in_axes=(-2, None), out_axes=-2
- )
-
- # # 5. Compute final distance matrix
- dist_matrix = _collide_links_vs_world(coll_robot_world, _world_geom) # (*batch, N, M)
-
+ # 3. Compute distances: Map collide over robot links (axis -2) vs _world_geom (None)
+ # _world_geom is guaranteed to have the M axis now.
+ _collide_links_vs_world = jax.vmap(collide, in_axes=(-2, None), out_axes=(-2))
+ dist_matrix = _collide_links_vs_world(coll_robot_world, _world_geom)
+ # 5. Return the distance matrix
+ return dist_matrix
- # 6. Verify shape consistency
- expected_batch_combined = jnp.broadcast_shapes(batch_cfg_shape, batch_world_shape)
- expected_shape = (*expected_batch_combined, N, M)
- assert dist_matrix.shape == expected_shape, (
- f"Output shape mismatch. Expected {expected_shape}, got {dist_matrix.shape}. "
- f"Robot axes: {coll_robot_world.get_batch_axes()}, "
- f"World axes: {world_geom.get_batch_axes()}"
- )
+ def compute_world_collision_distance_with_exclude_links(
+ self,
+ robot: Robot,
+ cfg: Float[Array, "*batch_cfg actuated_count"],
+ world_geom: CollGeom, # Shape: (*batch_world, M, ...)
+ exclude_link_mask: Int[Array, " num_links"],
+ ) -> Float[Array, "*batch_combined N M"]:
+ """
+    Computes signed distances between all robot links (N) and world
+    obstacles (M), then masks out the rows for links flagged in
+    exclude_link_mask (their distances are replaced with a large value)
+    so excluded links never register as colliding.
+ """
+ dist_matrix = self.compute_world_collision_distance(robot, cfg, world_geom)
+ dist_matrix = RobotCollisionSpherized.mask_collision_distance(dist_matrix, exclude_link_mask)
return dist_matrix
- # @jdc.jit
- # def compute_world_collision_distance(
- # self,
- # robot: Robot,
- # cfg: Float[Array, "*batch_cfg actuated_count"],
- # world_geom: CollGeom, # Shape: (*batch_world, M, ...)
- # ) -> Float[Array, "*batch_combined N M"]:
- # """
- # Computes signed distances between all robot links (N) and world obstacles (M),
- # accounting for multiple primitives (S) per link.
-
- # The minimum distance over all primitives in each link is used as the link’s
- # representative distance to each world object.
- # """
- # CHUNK_SIZE = 1000
-
- # # 1. Get robot collision geometry at configuration
- # # Shape: (S, *batch_cfg, N, ...)
- # coll_robot_world = self.at_config(robot, cfg)
-
- # # 2. Normalize world_geom shape and determine M
- # world_axes = world_geom.get_batch_axes()
- # if len(world_axes) == 0:
- # _world_geom = world_geom.broadcast_to((1,))
- # M = 1
- # else:
- # _world_geom = world_geom
- # M = world_axes[-1]
-
- # # 3. Prepare for scan over links (N)
- # # We want to iterate over the N axis.
- # # coll_robot_world has N at the last batch axis.
- # # We move N to the front (0) to scan over it.
- # n_batch = len(coll_robot_world.get_batch_axes())
- # coll_robot_world_scannable = jax.tree.map(
- # lambda x: jnp.moveaxis(x, -2 if x.ndim > n_batch else -1, 0),
- # coll_robot_world,
- # )
-
- # # Pad to multiple of CHUNK_SIZE
- # N = self.num_links
- # pad_size = (CHUNK_SIZE - (N % CHUNK_SIZE)) % CHUNK_SIZE
-
- # @jdc.jit
- # def pad_fn(x):
- # # x shape: (N, ...)
- # padding = [(0, 0)] * x.ndim
- # padding[0] = (0, pad_size)
- # return jnp.pad(x, padding, mode='edge')
-
- # coll_padded = jax.tree.map(pad_fn, coll_robot_world_scannable)
-
- # # Reshape to (num_chunks, CHUNK_SIZE, ...)
- # num_chunks = (N + pad_size) // CHUNK_SIZE
-
- # @jdc.jit
- # def reshape_fn(x):
- # return x.reshape((num_chunks, CHUNK_SIZE) + x.shape[1:])
-
- # coll_chunked = jax.tree.map(reshape_fn, coll_padded)
-
- # # 4. Define scan function
- # @jdc.jit
- # def scan_fn(carry, link_chunk_geom):
- # # link_chunk_geom: (CHUNK_SIZE, S, *batch_cfg, ...)
- # # _world_geom: (*batch_world, M, ...)
-
- # # vmap over the chunk (axis 0)
- # _collide_chunk = jax.vmap(
- # self.collide_link_vs_world, in_axes=(0, None), out_axes=0
- # )
- # d_chunk = _collide_chunk(link_chunk_geom, _world_geom)
- # return carry, d_chunk
-
- # # 5. Run scan
- # # dists_scanned: (num_chunks, CHUNK_SIZE, *batch_combined, M)
- # _, dists_scanned = lax.scan(scan_fn, None, coll_chunked)
-
- # # 6. Restore shape
- # # Flatten chunks
- # dists_flattened = dists_scanned.reshape((-1,) + dists_scanned.shape[2:])
- # # Slice to original N
- # dists_N = dists_flattened[:N]
- # # Move N to -2
- # dist_matrix = jnp.moveaxis(dists_N, 0, -2)
-
- # return dist_matrix
+ def is_in_collision(
+ self,
+ robot: Robot,
+ cfg: Float[Array, "*batch_cfg actuated_count"],
+ world_geom: CollGeom, # Shape: (*batch_world, M, ...)
+ ) -> Bool[Array, "*batch_combined N M"]:
+ """
+ Checks if the robot is in collision with the world obstacles.
+ """
+ dist_matrix = self.compute_world_collision_distance(robot, cfg, world_geom)
+ return dist_matrix < 0
+
+ def is_in_collision_with_exclude_links(
+ self,
+ robot: Robot,
+ cfg: Float[Array, "*batch_cfg actuated_count"],
+ world_geom: CollGeom, # Shape: (*batch_world, M, ...)
+ exclude_link_mask: Int[Array, " num_links"],
+ ) -> Bool[Array, "*batch_combined N M"]:
+ """
+ Checks if the robot is in collision with the world obstacles, excluding the specified links.
+ """
+ dist_matrix = self.compute_world_collision_distance_with_exclude_links(robot, cfg, world_geom, exclude_link_mask)
+ return dist_matrix < 0
def get_swept_capsules(
self,
@@ -993,16 +937,20 @@ def get_swept_capsules(
)
return swept_capsules
-
+
@staticmethod
- def mask_collision_distance(solution: Float[Array, "1D array = num_links"], exclude_link_mask: Int[Array, " num_links"], replace_value: float = 1e6) -> Float[Array, "1D array = num_links"]:
+ def mask_collision_distance(
+ solution: Float[Array, "1D array = num_links"],
+ exclude_link_mask: Int[Array, " num_links"],
+ replace_value: float = 1e6,
+ ) -> Float[Array, "1D array = num_links"]:
"""Mask collision distances at specified link indices by replacing them with a value.
-
+
Args:
solution: Collision distance array with shape (*batch, actuated_count)
exclude_link_indices: Indices of links to exclude from collision checking
replace_value: Value to replace at the excluded indices (default: 1e6 for large distance)
-
+
Returns:
Masked collision distance array with the same shape as solution
"""
diff --git a/src/pyronot/utils.py b/src/pyronot/utils.py
index 3f7dd0d..3761b04 100644
--- a/src/pyronot/utils.py
+++ b/src/pyronot/utils.py
@@ -12,36 +12,87 @@
# Icosahedron vertices for positional encoding directions (from iSDF)
# These 20 directions provide good coverage of the unit sphere
-_ICOSAHEDRON_DIRS = jnp.array([
- [0.8506508, 0, 0.5257311],
- [0.809017, 0.5, 0.309017],
- [0.5257311, 0.8506508, 0],
- [1, 0, 0],
- [0.809017, 0.5, -0.309017],
- [0.8506508, 0, -0.5257311],
- [0.309017, 0.809017, -0.5],
- [0, 0.5257311, -0.8506508],
- [0.5, 0.309017, -0.809017],
- [0, 1, 0],
- [-0.5257311, 0.8506508, 0],
- [-0.309017, 0.809017, -0.5],
- [0, 0.5257311, 0.8506508],
- [-0.309017, 0.809017, 0.5],
- [0.309017, 0.809017, 0.5],
- [0.5, 0.309017, 0.809017],
- [0.5, -0.309017, 0.809017],
- [0, 0, 1],
- [-0.5, 0.309017, 0.809017],
- [-0.809017, 0.5, 0.309017],
-]).T # Shape: (3, 20) for efficient matmul
+_ICOSAHEDRON_DIRS = jnp.array(
+ [
+ [0.8506508, 0, 0.5257311],
+ [0.809017, 0.5, 0.309017],
+ [0.5257311, 0.8506508, 0],
+ [1, 0, 0],
+ [0.809017, 0.5, -0.309017],
+ [0.8506508, 0, -0.5257311],
+ [0.309017, 0.809017, -0.5],
+ [0, 0.5257311, -0.8506508],
+ [0.5, 0.309017, -0.809017],
+ [0, 1, 0],
+ [-0.5257311, 0.8506508, 0],
+ [-0.309017, 0.809017, -0.5],
+ [0, 0.5257311, 0.8506508],
+ [-0.309017, 0.809017, 0.5],
+ [0.309017, 0.809017, 0.5],
+ [0.5, 0.309017, 0.809017],
+ [0.5, -0.309017, 0.809017],
+ [0, 0, 1],
+ [-0.5, 0.309017, 0.809017],
+ [-0.809017, 0.5, 0.309017],
+ ]
+).T # Shape: (3, 20) for efficient matmul
# First 50 prime numbers for Halton sequence bases
-_HALTON_PRIMES = jnp.array([
- 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71,
- 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151,
- 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229
-])
+_HALTON_PRIMES = jnp.array(
+ [
+ 2,
+ 3,
+ 5,
+ 7,
+ 11,
+ 13,
+ 17,
+ 19,
+ 23,
+ 29,
+ 31,
+ 37,
+ 41,
+ 43,
+ 47,
+ 53,
+ 59,
+ 61,
+ 67,
+ 71,
+ 73,
+ 79,
+ 83,
+ 89,
+ 97,
+ 101,
+ 103,
+ 107,
+ 109,
+ 113,
+ 127,
+ 131,
+ 137,
+ 139,
+ 149,
+ 151,
+ 157,
+ 163,
+ 167,
+ 173,
+ 179,
+ 181,
+ 191,
+ 193,
+ 197,
+ 199,
+ 211,
+ 223,
+ 227,
+ 229,
+ ]
+)
def positional_encoding(
@@ -52,47 +103,49 @@ def positional_encoding(
) -> Float[Array, "... embed_dim"]:
"""
Standard sinusoidal positional encoding (NeRF-style).
-
+
Applies sinusoidal encoding at multiple frequency bands to each input dimension.
This helps the network capture high-frequency spatial details.
-
+
Args:
x: Input tensor of shape (..., D).
min_deg: Minimum frequency degree (default 0).
max_deg: Maximum frequency degree (default 6).
scale: Scale factor applied to input before encoding (default 1.0).
-
+
Returns:
Positional encoding of shape (..., D + D * 2 * num_freqs).
"""
n_freqs = max_deg - min_deg + 1
# Frequency bands: 2^min_deg, 2^(min_deg+1), ..., 2^max_deg
frequency_bands = 2.0 ** jnp.arange(min_deg, max_deg + 1)
-
+
# Scale input
x_scaled = x * scale
-
+
# Start with original features
embeddings = [x_scaled]
-
+
# Apply sin and cos at each frequency to each dimension
for freq in frequency_bands:
embeddings.append(jnp.sin(freq * x_scaled))
embeddings.append(jnp.cos(freq * x_scaled))
-
+
# Concatenate all embeddings along last axis
return jnp.concatenate(embeddings, axis=-1)
-def compute_positional_encoding_dim(input_dim: int, min_deg: int = 0, max_deg: int = 6) -> int:
+def compute_positional_encoding_dim(
+ input_dim: int, min_deg: int = 0, max_deg: int = 6
+) -> int:
"""
Compute the output dimension of positional encoding.
-
+
Args:
input_dim: Input dimension D.
min_deg: Minimum frequency degree.
max_deg: Maximum frequency degree.
-
+
Returns:
Output dimension after positional encoding.
"""
@@ -104,47 +157,49 @@ def compute_positional_encoding_dim(input_dim: int, min_deg: int = 0, max_deg: i
def halton_sequence(num_samples: int, dim: int, skip: int = 100) -> jax.Array:
"""
Generate a Halton sequence for quasi-random sampling.
-
+
The Halton sequence provides better coverage of the sample space compared to
uniform random sampling, which is beneficial for diverse sample collection.
-
+
Args:
num_samples: Number of samples to generate.
dim: Dimensionality of each sample.
skip: Number of initial samples to skip (improves uniformity).
-
+
Returns:
JAX array of shape (num_samples, dim) with values in [0, 1].
"""
if dim > len(_HALTON_PRIMES):
- raise ValueError(f"Halton sequence dimension {dim} exceeds available primes ({len(_HALTON_PRIMES)})")
-
+ raise ValueError(
+ f"Halton sequence dimension {dim} exceeds available primes ({len(_HALTON_PRIMES)})"
+ )
+
# Generate samples using vectorized operations where possible
indices = jnp.arange(skip, skip + num_samples)
bases = _HALTON_PRIMES[:dim]
-
+
# Compute maximum number of digits needed for the largest index
max_index = skip + num_samples
max_digits = int(jnp.ceil(jnp.log(max_index + 1) / jnp.log(2))) + 1
-
+
def halton_for_base(base: int) -> jax.Array:
"""Vectorized Halton sequence for a single base."""
# Compute radical inverse for all indices at once
result = jnp.zeros(num_samples)
f = 1.0 / base
current = indices.astype(jnp.float32)
-
+
for _ in range(max_digits):
digit = jnp.mod(current, base)
result = result + f * digit
current = jnp.floor(current / base)
f = f / base
-
+
return result
-
+
# Stack results for all dimensions
samples = jnp.stack([halton_for_base(int(b)) for b in bases], axis=1)
-
+
return samples
@@ -166,7 +221,7 @@ def rebalance_samples(
"""
Rebalance a sample pool to have a target distribution of collision, near-collision,
and free-space samples.
-
+
Args:
samples: Initial sample pool of shape (pool_size, dof).
distances: Minimum distances for each sample, shape (pool_size,).
@@ -182,140 +237,184 @@ def rebalance_samples(
max_augment_iterations: Maximum iterations for augmentation (default 10).
perturbation_scale_collision: Perturbation scale for collision augmentation (default 0.05).
perturbation_scale_near: Perturbation scale for near-collision augmentation (default 0.03).
-
+
Returns:
Rebalanced samples of shape (num_samples, dof).
"""
dof = samples.shape[1]
-
+
# Separate samples into categories
is_in_collision = distances <= 0
is_near_collision = (distances > 0) & (distances < collision_threshold)
is_free_space = distances >= collision_threshold
-
+
collision_samples = samples[is_in_collision]
near_collision_samples = samples[is_near_collision]
free_space_samples = samples[is_free_space]
-
+
num_collision = collision_samples.shape[0]
num_near_collision = near_collision_samples.shape[0]
num_free = free_space_samples.shape[0]
-
- logger.info(f"Sample distribution from pool: collision={num_collision}, near-collision={num_near_collision}, free-space={num_free}")
-
+
+ logger.info(
+ f"Sample distribution from pool: collision={num_collision}, near-collision={num_near_collision}, free-space={num_free}"
+ )
+
# Target distribution
target_collision = int(num_samples * target_collision_ratio)
target_near = int(num_samples * target_near_ratio)
target_free = num_samples - target_collision - target_near
-
+
key_augment = key
-
+
# Augment collision samples if needed
if num_collision < target_collision and num_collision > 0:
- logger.info(f"Augmenting collision samples from {num_collision} to {target_collision}...")
-
+ logger.info(
+ f"Augmenting collision samples from {num_collision} to {target_collision}..."
+ )
+
samples_needed = target_collision - num_collision
augmented_list = []
iteration = 0
-
- while len(augmented_list) < samples_needed and iteration < max_augment_iterations:
+
+ while (
+ len(augmented_list) < samples_needed and iteration < max_augment_iterations
+ ):
iteration += 1
batch_size_aug = min(samples_needed * 2, 5000)
key_augment, subk1, subk2 = jax.random.split(key_augment, 3)
indices = jax.random.randint(subk1, (batch_size_aug,), 0, num_collision)
base_samples = collision_samples[indices]
-
- perturbation_range = perturbation_scale_collision * (upper_limits - lower_limits)
- perturbations = jax.random.uniform(subk2, (batch_size_aug, dof), minval=-1, maxval=1) * perturbation_range
- candidates = jnp.clip(base_samples + perturbations, lower_limits, upper_limits)
-
+
+ perturbation_range = perturbation_scale_collision * (
+ upper_limits - lower_limits
+ )
+ perturbations = (
+ jax.random.uniform(subk2, (batch_size_aug, dof), minval=-1, maxval=1)
+ * perturbation_range
+ )
+ candidates = jnp.clip(
+ base_samples + perturbations, lower_limits, upper_limits
+ )
+
candidate_dists = distance_fn(candidates)
valid_mask = candidate_dists <= 0
valid_candidates = candidates[valid_mask]
-
+
if valid_candidates.shape[0] > 0:
augmented_list.append(valid_candidates)
-
- logger.debug(f" Iteration {iteration}: {valid_candidates.shape[0]} valid collision samples generated")
-
+
+ logger.debug(
+ f" Iteration {iteration}: {valid_candidates.shape[0]} valid collision samples generated"
+ )
+
if augmented_list:
all_augmented = jnp.concatenate(augmented_list, axis=0)[:samples_needed]
- collision_samples = jnp.concatenate([collision_samples, all_augmented], axis=0)
+ collision_samples = jnp.concatenate(
+ [collision_samples, all_augmented], axis=0
+ )
num_collision = collision_samples.shape[0]
logger.info(f" Final collision sample count: {num_collision}")
-
+
# Augment near-collision samples if needed
if num_near_collision < target_near and num_near_collision > 0:
- logger.info(f"Augmenting near-collision samples from {num_near_collision} to {target_near}...")
-
+ logger.info(
+ f"Augmenting near-collision samples from {num_near_collision} to {target_near}..."
+ )
+
samples_needed = target_near - num_near_collision
augmented_list = []
iteration = 0
-
- while len(augmented_list) < samples_needed and iteration < max_augment_iterations:
+
+ while (
+ len(augmented_list) < samples_needed and iteration < max_augment_iterations
+ ):
iteration += 1
batch_size_aug = min(samples_needed * 2, 5000)
key_augment, subk1, subk2 = jax.random.split(key_augment, 3)
- indices = jax.random.randint(subk1, (batch_size_aug,), 0, num_near_collision)
+ indices = jax.random.randint(
+ subk1, (batch_size_aug,), 0, num_near_collision
+ )
base_samples = near_collision_samples[indices]
-
+
perturbation_range = perturbation_scale_near * (upper_limits - lower_limits)
- perturbations = jax.random.uniform(subk2, (batch_size_aug, dof), minval=-1, maxval=1) * perturbation_range
- candidates = jnp.clip(base_samples + perturbations, lower_limits, upper_limits)
-
+ perturbations = (
+ jax.random.uniform(subk2, (batch_size_aug, dof), minval=-1, maxval=1)
+ * perturbation_range
+ )
+ candidates = jnp.clip(
+ base_samples + perturbations, lower_limits, upper_limits
+ )
+
candidate_dists = distance_fn(candidates)
valid_mask = (candidate_dists > 0) & (candidate_dists < collision_threshold)
valid_candidates = candidates[valid_mask]
-
+
if valid_candidates.shape[0] > 0:
augmented_list.append(valid_candidates)
-
+
if augmented_list:
all_augmented = jnp.concatenate(augmented_list, axis=0)[:samples_needed]
- near_collision_samples = jnp.concatenate([near_collision_samples, all_augmented], axis=0)
+ near_collision_samples = jnp.concatenate(
+ [near_collision_samples, all_augmented], axis=0
+ )
num_near_collision = near_collision_samples.shape[0]
logger.info(f" Final near-collision sample count: {num_near_collision}")
-
+
# Construct final training set
actual_collision = min(num_collision, target_collision)
actual_near = min(num_near_collision, target_near)
actual_free = max(0, num_samples - actual_collision - actual_near)
actual_free = min(actual_free, num_free)
-
- logger.info(f"Assembling training set: collision={actual_collision}, near={actual_near}, free={actual_free}")
-
+
+ logger.info(
+ f"Assembling training set: collision={actual_collision}, near={actual_near}, free={actual_free}"
+ )
+
# Select samples from each category
key_augment, subk = jax.random.split(key_augment)
-
- selected_collision = collision_samples[:actual_collision] if actual_collision > 0 else jnp.empty((0, dof))
- selected_near = near_collision_samples[:actual_near] if actual_near > 0 else jnp.empty((0, dof))
-
+
+ selected_collision = (
+ collision_samples[:actual_collision]
+ if actual_collision > 0
+ else jnp.empty((0, dof))
+ )
+ selected_near = (
+ near_collision_samples[:actual_near] if actual_near > 0 else jnp.empty((0, dof))
+ )
+
if actual_free > 0 and num_free > 0:
- free_indices = jax.random.choice(subk, num_free, shape=(actual_free,), replace=False)
+ free_indices = jax.random.choice(
+ subk, num_free, shape=(actual_free,), replace=False
+ )
selected_free = free_space_samples[free_indices]
else:
selected_free = jnp.empty((0, dof))
-
+
# Combine all samples
- parts = [p for p in [selected_collision, selected_near, selected_free] if p.shape[0] > 0]
+ parts = [
+ p for p in [selected_collision, selected_near, selected_free] if p.shape[0] > 0
+ ]
result = jnp.concatenate(parts, axis=0) if parts else jnp.empty((0, dof))
-
+
# Fill shortfall from original pool if needed
if result.shape[0] < num_samples:
shortfall = num_samples - result.shape[0]
logger.info(f"Filling shortfall of {shortfall} samples from original pool...")
key_augment, subk = jax.random.split(key_augment)
- extra_indices = jax.random.choice(subk, samples.shape[0], shape=(shortfall,), replace=True)
+ extra_indices = jax.random.choice(
+ subk, samples.shape[0], shape=(shortfall,), replace=True
+ )
extra_samples = samples[extra_indices]
result = jnp.concatenate([result, extra_samples], axis=0)
-
+
# Shuffle the result
key_augment, subk = jax.random.split(key_augment)
shuffle_perm = jax.random.permutation(subk, result.shape[0])
result = result[shuffle_perm][:num_samples]
-
+
logger.info(f"Final training set: {result.shape[0]} samples")
-
+
return result
@@ -338,6 +437,7 @@ def jax_log(fmt: str, *args, **kwargs) -> None:
"""Emit a loguru info message from a JITed JAX function."""
jax.debug.callback(partial(_log, fmt), *args, **kwargs)
-@partial(jax.jit, static_argnames=['dtype'])
+
+@partial(jax.jit, static_argnames=["dtype"])
def quantize(tree, dtype=jax.numpy.float16):
- return jax.tree.map(lambda x: x.astype(dtype) if hasattr(x, 'astype') else x, tree)
\ No newline at end of file
+ return jax.tree.map(lambda x: x.astype(dtype) if hasattr(x, "astype") else x, tree)