diff --git a/.gitignore b/.gitignore index c14d2fd..e0d3934 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,5 @@ /output.data /output.json /*.hh +/*.rs /.cpm-cache/ diff --git a/resources/templates/ccfk_template.rs b/resources/templates/ccfk_template.rs new file mode 100644 index 0000000..f37caf9 --- /dev/null +++ b/resources/templates/ccfk_template.rs @@ -0,0 +1,81 @@ +{% for i in range(length(links_with_geometry)) %} +{% set array_index = length(links_with_geometry) - i - 1 %} +{% set link_index = at(links_with_geometry, array_index) %} +{% set link_spheres = at(per_link_spheres, link_index) %} +{% set bs_loc = (n_spheres + array_index) * 4 %} + +// +// environment vs. robot collisions +// + +// {{ at(link_names, link_index) }} +if sphere_environment_in_collision(environment, + y[{{bs_loc + 0}}], + y[{{bs_loc + 1}}], + y[{{bs_loc + 2}}], + y[{{bs_loc + 3}}]) +{ + {% for j in range(length(link_spheres)) %} + {% set sphere_loc = at(link_spheres, j) * 4 %} + if sphere_environment_in_collision(environment, + y[{{ sphere_loc + 0 }}], + y[{{ sphere_loc + 1 }}], + y[{{ sphere_loc + 2 }}], + y[{{ sphere_loc + 3 }}]) + { + return false; + } + {% endfor %} +} + +{% endfor %} + +// +// robot self-collisions +// + +{% for i in range(length(allowed_link_pairs)) %} +{% set pair = at(allowed_link_pairs, i) %} +{% set link_1_index = at(pair, 0) %} +{% set link_2_index = at(pair, 1) %} +{% set link_1_bs = at(bounding_sphere_index, link_1_index) %} +{% set link_2_bs = at(bounding_sphere_index, link_2_index) %} +{% set link_1_spheres = at(per_link_spheres, link_1_index) %} +{% set link_2_spheres = at(per_link_spheres, link_2_index) %} +{% set link_1_bs_loc = (n_spheres + link_1_bs) * 4 %} +{% set link_2_bs_loc = (n_spheres + link_2_bs) * 4 %} + +// {{ at(link_names, link_1_index) }} vs. {{ at(link_names, link_2_index) }} +if sphere_sphere_self_collision( + y[{{link_1_bs_loc + 0}}], + y[{{link_1_bs_loc + 1}}], + y[{{link_1_bs_loc + 2}}], + y[{{link_1_bs_loc + 3}}], + y[{{link_2_bs_loc + 0}}], + y[{{link_2_bs_loc + 1}}], + y[{{link_2_bs_loc + 2}}], + y[{{link_2_bs_loc + 3}}] +) { + {% for j in range(length(link_1_spheres)) %} + {% for k in range(length(link_2_spheres)) %} + + {% set sphere_1_loc = at(link_1_spheres, j) %} + {% set sphere_2_loc = at(link_2_spheres, k) %} + + if sphere_sphere_self_collision( + y[{{ sphere_1_loc * 4 + 0}} ], + y[{{ sphere_1_loc * 4 + 1}} ], + y[{{ sphere_1_loc * 4 + 2}} ], + y[{{ sphere_1_loc * 4 + 3}} ], + y[{{ sphere_2_loc * 4 + 0}} ], + y[{{ sphere_2_loc * 4 + 1}} ], + y[{{ sphere_2_loc * 4 + 2}} ], + y[{{ sphere_2_loc * 4 + 3}} ] + ) { + return false; + } + + {% endfor %} + {% endfor %} +} +{% endfor %} diff --git a/resources/templates/fk_template.rs b/resources/templates/fk_template.rs new file mode 100644 index 0000000..66a1245 --- /dev/null +++ b/resources/templates/fk_template.rs @@ -0,0 +1,29 @@ +use core::simd::Simd; + +use elain::{Align, Alignment}; + +use crate::{ + env::World3d, + robot::{sphere_environment_in_collision, sphere_sphere_self_collision}, + cos, sin, +}; + +#[expect( + non_snake_case, + clippy::too_many_lines, + clippy::cognitive_complexity, + clippy::unreadable_literal, + clippy::approx_constant, + clippy::collapsible_if +)] +pub fn fkcc(x: &super::ConfigurationBlock, environment: &World3d) -> bool +where + Align: Alignment, +{ + let mut v = [Simd::splat(0.0); {{ccfk_code_vars}}]; + let mut y = [Simd::splat(0.0); {{ccfk_code_output}}]; + + {{ccfk_code}} + {% include "ccfk" %} + true +} diff --git a/src/fkcc_gen.cc b/src/fkcc_gen.cc index a5ab49e..c17055e 100644 --- a/src/fkcc_gen.cc +++ b/src/fkcc_gen.cc @@ -25,7 +25,8 @@ #include #include -#include "lang_gen.hh" +#include "lang_cpp.hh" +#include "lang_rust.hh" using namespace pinocchio; using namespace CppAD; @@ -461,6 +462,7 @@ struct Traced auto trace_sphere_cc_fk( const RobotInfo &info, + const std::string &language, bool spheres = true, bool bounding_spheres = true, bool fk = true) -> Traced @@ -523,11 +525,23 @@ auto trace_sphere_cc_fk( CppAD::vector result = collision_sphere_func.Forward(0, ind_vars); - LanguageCCustom langC("double"); LangCDefaultVariableNameGenerator nameGen; - std::ostringstream function_code; - handler.generateCode(function_code, langC, result, nameGen); + + if (language == "c++") + { + LanguageCCustom langC("double"); + handler.generateCode(function_code, langC, result, nameGen); + } + else if (language == "rust") + { + LanguageRust langRust("double"); + handler.generateCode(function_code, langRust, result, nameGen); + } + else + { + throw std::runtime_error(fmt::format("unsupported language {}", language)); + } return Traced{function_code.str(), handler.getTemporaryVariableCount(), n_out}; } @@ -598,26 +612,32 @@ int main(int argc, char **argv) end_effector_name = data["end_effector"]; } + std::string language = "c++"; + if (data.contains("language")) + { + language = data["language"]; + } + RobotInfo robot(parent_path / data["urdf"], srdf_path, end_effector_name); data.update(robot.json()); - auto traced_eefk_code = trace_sphere_cc_fk(robot, false, false, true); + auto traced_eefk_code = trace_sphere_cc_fk(robot, language, false, false, true); data["eefk_code"] = traced_eefk_code.code; data["eefk_code_vars"] = traced_eefk_code.temp_variables; data["eefk_code_output"] = traced_eefk_code.outputs; - auto traced_spherefk_code = trace_sphere_cc_fk(robot, true, false, false); + auto traced_spherefk_code = trace_sphere_cc_fk(robot, language, true, false, false); data["spherefk_code"] = traced_spherefk_code.code; data["spherefk_code_vars"] = traced_spherefk_code.temp_variables; data["spherefk_code_output"] = traced_spherefk_code.outputs; - auto traced_ccfk_code = trace_sphere_cc_fk(robot, true, true, false); + auto traced_ccfk_code = trace_sphere_cc_fk(robot, language, true, true, false); data["ccfk_code"] = traced_ccfk_code.code; data["ccfk_code_vars"] = traced_ccfk_code.temp_variables; data["ccfk_code_output"] = traced_ccfk_code.outputs; - auto traced_ccfkee_code = trace_sphere_cc_fk(robot, true, true, true); + auto traced_ccfkee_code = trace_sphere_cc_fk(robot, language, true, true, true); data["ccfkee_code"] = traced_ccfkee_code.code; data["ccfkee_code_vars"] = traced_ccfkee_code.temp_variables; data["ccfkee_code_output"] = traced_ccfkee_code.outputs; diff --git a/src/lang_gen.hh b/src/lang_cpp.hh similarity index 100% rename from src/lang_gen.hh rename to src/lang_cpp.hh diff --git a/src/lang_rust.hh b/src/lang_rust.hh new file mode 100644 index 0000000..3ffd3c4 --- /dev/null +++ b/src/lang_rust.hh @@ -0,0 +1,51 @@ +#pragma once + +#include +#include "cppad/cg/lang/c/language_c.hpp" + +namespace CppAD +{ + namespace cg + { + + template + class LanguageRust : public LanguageC + { + public: + explicit LanguageRust(std::string varTypeName, size_t spaces = 3) + : LanguageC(varTypeName, spaces) + { + } + + virtual void printParameter(const Base &value) + { + writeParameter(value, LanguageRust::_code); + } + + virtual void pushParameter(const Base &value) + { + writeParameter(value, LanguageRust::_streamStack); + } + + template + void writeParameter(const Base &value, Output &output) + { + // make sure all digits of floating point values are printed + std::ostringstream os; + os << std::setprecision(LanguageRust::_parameterPrecision) << value; + + std::string number = os.str(); + output << "Simd::::splat("; + output << number; + + if (number.find('.') == std::string::npos && number.find('e') == std::string::npos) + { + // also make sure there is always a '.' after the number in + // order to avoid integer overflows + output << '.'; + } + output << ")"; + } + }; + } // namespace cg +} // namespace CppAD