From b6a5be037c9cae096c6c8ddeabbb1bef03472a04 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 22 Nov 2025 19:02:41 -0800 Subject: [PATCH 001/116] Bare bones of a JAX extension. --- openequivariance/extension/CMakeLists.txt | 38 +++++ openequivariance/extension/libjax_tp_jit.cpp | 146 +++++++++++++++++++ 2 files changed, 184 insertions(+) create mode 100644 openequivariance/extension/CMakeLists.txt create mode 100644 openequivariance/extension/libjax_tp_jit.cpp diff --git a/openequivariance/extension/CMakeLists.txt b/openequivariance/extension/CMakeLists.txt new file mode 100644 index 00000000..c28bc346 --- /dev/null +++ b/openequivariance/extension/CMakeLists.txt @@ -0,0 +1,38 @@ +cmake_minimum_required(VERSION 3.15...3.30) +project(OEQ_JAX LANGUAGES CXX CUDA) # TODO: Add HIP support + +find_package(Python 3.11 REQUIRED COMPONENTS Interpreter Development.Module) + +execute_process( + COMMAND "${Python_EXECUTABLE}" "-c" + "from jax import ffi; print(ffi.include_dir())" + OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE XLA_DIR +) + +message(STATUS "XLA include directory: ${XLA_DIR}") + +execute_process( + COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir + OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE nanobind_ROOT +) + +find_package(nanobind CONFIG REQUIRED) + +set(OEQ_JAX_SOURCES + libjax_tp_jit.cpp +) + +set(OEQ_JAX_HEADERS + convolution.hpp + tensorproducts.hpp + util/backend_cuda.hpp + util/backend_hip.hpp + buffer.hpp +) + +nanobind_add_module(OEQ_JAX NB_STATIC ${OEQ_JAX_SOURCES} ${OEQ_JAX_HEADERS}) + +target_include_directories(OEQ_JAX PUBLIC ${XLA_DIR}) +set_target_properties(OEQ_JAX PROPERTIES CUDA_STANDARD 17 POSITION_INDEPENDENT_CODE ON) + +install(TARGETS OEQ_JAX LIBRARY DESTINATION lib) \ No newline at end of file diff --git a/openequivariance/extension/libjax_tp_jit.cpp b/openequivariance/extension/libjax_tp_jit.cpp new file mode 100644 index 00000000..cbd99db0 --- /dev/null +++ b/openequivariance/extension/libjax_tp_jit.cpp @@ -0,0 +1,146 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include + +#include "nanobind/nanobind.h" +#include "xla/ffi/api/ffi.h" + +namespace nb = nanobind; +namespace ffi = xla::ffi; + +// ---------- +// Attributes +// ---------- +// +// An example demonstrating the different ways that attributes can be passed to +// the FFI. +// +// For example, we can pass arrays, variadic attributes, and user-defined types. +// Full support of user-defined types isn't yet supported by XLA, so that +// example will be added in the future. + +ffi::Error ArrayAttrImpl(ffi::Span array, + ffi::ResultBufferR0 res) { + int64_t total = 0; + for (int32_t x : array) { + total += x; + } + res->typed_data()[0] = total; + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(ArrayAttr, ArrayAttrImpl, + ffi::Ffi::Bind() + .Attr>("array") + .Ret>()); + +ffi::Error DictionaryAttrImpl(ffi::Dictionary attrs, + ffi::ResultBufferR0 secret, + ffi::ResultBufferR0 count) { + auto maybe_secret = attrs.get("secret"); + if (maybe_secret.has_error()) { + return maybe_secret.error(); + } + secret->typed_data()[0] = maybe_secret.value(); + count->typed_data()[0] = attrs.size(); + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(DictionaryAttr, DictionaryAttrImpl, + ffi::Ffi::Bind() + .Attrs() + .Ret>() + .Ret>()); + +// ------- +// Counter +// ------- +// +// An example demonstrating how an FFI call can maintain "state" between calls +// +// In this case, the ``Counter`` call simply accumulates the number of times it +// was executed, but this pattern can also be used for more advanced use cases. +// For example, this pattern is used in jaxlib for: +// +// 1. The GPU solver linear algebra kernels which require an expensive "handler" +// initialization, and +// 2. The ``triton_call`` function which caches the compiled triton modules +// after their first use. + +ffi::Error CounterImpl(std::string_view key, ffi::ResultBufferR0 out) { + static std::mutex mutex; + static auto &cache = *new std::unordered_map(); + { + const std::lock_guard lock(mutex); + /*auto it = cache.find(key); + if (it != cache.end()) { + out->typed_data()[0] = ++it->second; + } else { + cache.insert({key, 0}); + out->typed_data()[0] = 0; + }*/ + } + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + Counter, CounterImpl, + ffi::Ffi::Bind().Attr("key").Ret>()); + +// -------- +// Aliasing +// -------- +// +// This example demonstrates how input-output aliasing works. The handler +// doesn't do anything except to check that the input and output pointers +// address the same data. + +ffi::Error AliasingImpl(ffi::AnyBuffer input, + ffi::Result output) { + if (input.element_type() != output->element_type() || + input.element_count() != output->element_count()) { + return ffi::Error::InvalidArgument( + "The input and output data types and sizes must match."); + } + if (input.untyped_data() != output->untyped_data()) { + return ffi::Error::InvalidArgument( + "When aliased, the input and output buffers should point to the same " + "data."); + } + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + Aliasing, AliasingImpl, + ffi::Ffi::Bind().Arg().Ret()); + +// Boilerplate for exposing handlers to Python +NB_MODULE(_cpu_examples, m) { + m.def("registrations", []() { + nb::dict registrations; + registrations["array_attr"] = + nb::capsule(reinterpret_cast(ArrayAttr)); + registrations["dictionary_attr"] = + nb::capsule(reinterpret_cast(DictionaryAttr)); + registrations["counter"] = nb::capsule(reinterpret_cast(Counter)); + registrations["aliasing"] = nb::capsule(reinterpret_cast(Aliasing)); + return registrations; + }); +} From 8efa3a354cb948bcd2e79a40f8eef9bbff66dd40 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 22 Nov 2025 22:02:47 -0800 Subject: [PATCH 002/116] Continued progress. --- openequivariance/extension/CMakeLists.txt | 13 ++-- openequivariance/extension/libjax_tp_jit.cpp | 81 ++++---------------- 2 files changed, 24 insertions(+), 70 deletions(-) diff --git a/openequivariance/extension/CMakeLists.txt b/openequivariance/extension/CMakeLists.txt index c28bc346..15f7fa48 100644 --- a/openequivariance/extension/CMakeLists.txt +++ b/openequivariance/extension/CMakeLists.txt @@ -1,5 +1,5 @@ cmake_minimum_required(VERSION 3.15...3.30) -project(OEQ_JAX LANGUAGES CXX CUDA) # TODO: Add HIP support +project(oeq_jax_extension LANGUAGES CXX CUDA) # TODO: Add HIP support find_package(Python 3.11 REQUIRED COMPONENTS Interpreter Development.Module) @@ -27,12 +27,13 @@ set(OEQ_JAX_HEADERS tensorproducts.hpp util/backend_cuda.hpp util/backend_hip.hpp - buffer.hpp + util/buffer.hpp ) -nanobind_add_module(OEQ_JAX NB_STATIC ${OEQ_JAX_SOURCES} ${OEQ_JAX_HEADERS}) +nanobind_add_module(oeq_jax_extension NB_STATIC ${OEQ_JAX_SOURCES} ${OEQ_JAX_HEADERS}) -target_include_directories(OEQ_JAX PUBLIC ${XLA_DIR}) -set_target_properties(OEQ_JAX PROPERTIES CUDA_STANDARD 17 POSITION_INDEPENDENT_CODE ON) +target_include_directories(oeq_jax_extension PUBLIC ${XLA_DIR}) +set_target_properties(oeq_jax_extension PROPERTIES CUDA_STANDARD 17 POSITION_INDEPENDENT_CODE ON) +target_compile_options(oeq_jax_extension PRIVATE -Wno-attributes -Wno-return-type) -install(TARGETS OEQ_JAX LIBRARY DESTINATION lib) \ No newline at end of file +install(TARGETS oeq_jax_extension LIBRARY DESTINATION lib) \ No newline at end of file diff --git a/openequivariance/extension/libjax_tp_jit.cpp b/openequivariance/extension/libjax_tp_jit.cpp index cbd99db0..38c8bc1f 100644 --- a/openequivariance/extension/libjax_tp_jit.cpp +++ b/openequivariance/extension/libjax_tp_jit.cpp @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include "nanobind/nanobind.h" #include "xla/ffi/api/ffi.h" @@ -25,17 +26,7 @@ limitations under the License. namespace nb = nanobind; namespace ffi = xla::ffi; -// ---------- -// Attributes -// ---------- -// -// An example demonstrating the different ways that attributes can be passed to -// the FFI. -// -// For example, we can pass arrays, variadic attributes, and user-defined types. -// Full support of user-defined types isn't yet supported by XLA, so that -// example will be added in the future. - +/* ffi::Error ArrayAttrImpl(ffi::Span array, ffi::ResultBufferR0 res) { int64_t total = 0; @@ -68,25 +59,17 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DictionaryAttr, DictionaryAttrImpl, .Attrs() .Ret>() .Ret>()); +*/ -// ------- -// Counter -// ------- -// -// An example demonstrating how an FFI call can maintain "state" between calls -// -// In this case, the ``Counter`` call simply accumulates the number of times it -// was executed, but this pattern can also be used for more advanced use cases. -// For example, this pattern is used in jaxlib for: -// -// 1. The GPU solver linear algebra kernels which require an expensive "handler" -// initialization, and -// 2. The ``triton_call`` function which caches the compiled triton modules -// after their first use. - -ffi::Error CounterImpl(std::string_view key, ffi::ResultBufferR0 out) { +ffi::Error tp_forward_impl(std::string_view kernel, + ffi::Dictionary forward_config, + ffi::ResultBufferR0 out) { static std::mutex mutex; static auto &cache = *new std::unordered_map(); + + auto value = forward_config.get("example_key").value(); + std::cout << value << std::endl; + { const std::lock_guard lock(mutex); /*auto it = cache.find(key); @@ -101,46 +84,16 @@ ffi::Error CounterImpl(std::string_view key, ffi::ResultBufferR0 out) } XLA_FFI_DEFINE_HANDLER_SYMBOL( - Counter, CounterImpl, - ffi::Ffi::Bind().Attr("key").Ret>()); - -// -------- -// Aliasing -// -------- -// -// This example demonstrates how input-output aliasing works. The handler -// doesn't do anything except to check that the input and output pointers -// address the same data. - -ffi::Error AliasingImpl(ffi::AnyBuffer input, - ffi::Result output) { - if (input.element_type() != output->element_type() || - input.element_count() != output->element_count()) { - return ffi::Error::InvalidArgument( - "The input and output data types and sizes must match."); - } - if (input.untyped_data() != output->untyped_data()) { - return ffi::Error::InvalidArgument( - "When aliased, the input and output buffers should point to the same " - "data."); - } - return ffi::Error::Success(); -} - -XLA_FFI_DEFINE_HANDLER_SYMBOL( - Aliasing, AliasingImpl, - ffi::Ffi::Bind().Arg().Ret()); + tp_forward, tp_forward_impl, + ffi::Ffi::Bind() + .Attr("kernel") + .Attr("forward_config") + .Ret>()); -// Boilerplate for exposing handlers to Python -NB_MODULE(_cpu_examples, m) { +NB_MODULE(oeq_jax_extension, m) { m.def("registrations", []() { nb::dict registrations; - registrations["array_attr"] = - nb::capsule(reinterpret_cast(ArrayAttr)); - registrations["dictionary_attr"] = - nb::capsule(reinterpret_cast(DictionaryAttr)); - registrations["counter"] = nb::capsule(reinterpret_cast(Counter)); - registrations["aliasing"] = nb::capsule(reinterpret_cast(Aliasing)); + registrations["tp_forward"] = nb::capsule(reinterpret_cast(tp_forward)); return registrations; }); } From 4be44d7499b34f59c20b732f50b27afc456d1130 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 22 Nov 2025 22:26:36 -0800 Subject: [PATCH 003/116] More changes. --- openequivariance/extension/libjax_tp_jit.cpp | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/openequivariance/extension/libjax_tp_jit.cpp b/openequivariance/extension/libjax_tp_jit.cpp index 38c8bc1f..b3033953 100644 --- a/openequivariance/extension/libjax_tp_jit.cpp +++ b/openequivariance/extension/libjax_tp_jit.cpp @@ -19,6 +19,8 @@ limitations under the License. #include #include #include +#include +#include #include "nanobind/nanobind.h" #include "xla/ffi/api/ffi.h" @@ -61,7 +63,9 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DictionaryAttr, DictionaryAttrImpl, .Ret>()); */ -ffi::Error tp_forward_impl(std::string_view kernel, +ffi::Error tp_forward_impl( + cudaStream_t stream, + std::string_view kernel, ffi::Dictionary forward_config, ffi::ResultBufferR0 out) { static std::mutex mutex; @@ -69,7 +73,6 @@ ffi::Error tp_forward_impl(std::string_view kernel, auto value = forward_config.get("example_key").value(); std::cout << value << std::endl; - { const std::lock_guard lock(mutex); /*auto it = cache.find(key); @@ -86,9 +89,11 @@ ffi::Error tp_forward_impl(std::string_view kernel, XLA_FFI_DEFINE_HANDLER_SYMBOL( tp_forward, tp_forward_impl, ffi::Ffi::Bind() + .Ctx>() .Attr("kernel") .Attr("forward_config") - .Ret>()); + .Ret>(), + {xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled NB_MODULE(oeq_jax_extension, m) { m.def("registrations", []() { From b758e4bc52f3f2d2ba793d617a463c83fe08c3fb Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 22 Nov 2025 23:14:54 -0800 Subject: [PATCH 004/116] Added all relevant parameters. --- openequivariance/extension/libjax_tp_jit.cpp | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/openequivariance/extension/libjax_tp_jit.cpp b/openequivariance/extension/libjax_tp_jit.cpp index b3033953..7cddc203 100644 --- a/openequivariance/extension/libjax_tp_jit.cpp +++ b/openequivariance/extension/libjax_tp_jit.cpp @@ -65,9 +65,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DictionaryAttr, DictionaryAttrImpl, ffi::Error tp_forward_impl( cudaStream_t stream, - std::string_view kernel, - ffi::Dictionary forward_config, - ffi::ResultBufferR0 out) { + std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, + int64_t hash, ffi::ResultBufferR0 out) { static std::mutex mutex; static auto &cache = *new std::unordered_map(); @@ -90,8 +89,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( tp_forward, tp_forward_impl, ffi::Ffi::Bind() .Ctx>() - .Attr("kernel") - .Attr("forward_config") + .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") + .Attr("hash") .Ret>(), {xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled From b0fb34c11a27e266efcb0e28ffc9d34e2ea40fbf Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 23 Nov 2025 00:57:48 -0800 Subject: [PATCH 005/116] Should be able to compile a kernel. --- openequivariance/extension/libjax_tp_jit.cpp | 136 ++++++++++--------- 1 file changed, 73 insertions(+), 63 deletions(-) diff --git a/openequivariance/extension/libjax_tp_jit.cpp b/openequivariance/extension/libjax_tp_jit.cpp index 7cddc203..ebafe66e 100644 --- a/openequivariance/extension/libjax_tp_jit.cpp +++ b/openequivariance/extension/libjax_tp_jit.cpp @@ -1,21 +1,7 @@ -/* Copyright 2024 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - #include #include #include +#include #include #include #include @@ -25,64 +11,88 @@ limitations under the License. #include "nanobind/nanobind.h" #include "xla/ffi/api/ffi.h" +#define CUDA_BACKEND // Stick to CUDA for now + +#ifdef CUDA_BACKEND + #include "util/backend_cuda.hpp" + #include "group_mm_cuda.hpp" + using JITKernel = CUJITKernel; + using GPU_Allocator = CUDA_Allocator; + + template + using GroupMM = GroupMMCUDA; +#endif + +#include "tensorproducts.hpp" + namespace nb = nanobind; namespace ffi = xla::ffi; -/* -ffi::Error ArrayAttrImpl(ffi::Span array, - ffi::ResultBufferR0 res) { - int64_t total = 0; - for (int32_t x : array) { - total += x; - } - res->typed_data()[0] = total; - return ffi::Error::Success(); -} +std::unordered_map>> kernel_cache; +std::mutex mut; -XLA_FFI_DEFINE_HANDLER_SYMBOL(ArrayAttr, ArrayAttrImpl, - ffi::Ffi::Bind() - .Attr>("array") - .Ret>()); +std::unordered_map parse_launch_config(ffi::Dictionary dict) { + std::unordered_map result; + result["num_blocks"] = dict.get("num_blocks").value(); + result["num_threads"] = dict.get("num_threads").value(); + result["warp_size"] = dict.get("warp_size").value(); + result["smem"] = dict.get("smem").value(); + return result; +} -ffi::Error DictionaryAttrImpl(ffi::Dictionary attrs, - ffi::ResultBufferR0 secret, - ffi::ResultBufferR0 count) { - auto maybe_secret = attrs.get("secret"); - if (maybe_secret.has_error()) { - return maybe_secret.error(); - } - secret->typed_data()[0] = maybe_secret.value(); - count->typed_data()[0] = attrs.size(); - return ffi::Error::Success(); +std::unordered_map parse_kernel_prop(ffi::Dictionary dict) { + std::unordered_map result; + result["L1_dim"] = dict.get("L1_dim").value(); + result["L2_dim"] = dict.get("L2_dim").value(); + result["L3_dim"] = dict.get("L3_dim").value(); + result["weight_numel"] = dict.get("weight_numel").value(); + result["shared_weights"] = dict.get("shared_weights").value(); + result["opt_level"] = dict.get("opt_level").value(); + result["irrep_dtype"] = dict.get("irrep_dtype").value(); + result["weight_dtype"] = dict.get("weight_dtype").value(); + return result; } -XLA_FFI_DEFINE_HANDLER_SYMBOL(DictionaryAttr, DictionaryAttrImpl, - ffi::Ffi::Bind() - .Attrs() - .Ret>() - .Ret>()); -*/ +JITTPImpl* compile_kernel_with_caching(std::string_view kernel, + ffi::Dictionary forward_config, + ffi::Dictionary backward_config, + ffi::Dictionary double_backward_config, + ffi::Dictionary kernel_prop, + int64_t hash) { + + JITTPImpl* result = nullptr; + { + const std::lock_guard lock(mut); + auto it = kernel_cache.find(hash); + if (it != kernel_cache.end()) { + result = it->second.get(); + } + else { + auto jit_tp_impl = std::make_unique>( + std::string(kernel), + parse_launch_config(forward_config), + parse_launch_config(backward_config), + parse_launch_config(double_backward_config), + parse_kernel_prop(kernel_prop)); + result = jit_tp_impl.get(); + kernel_cache.insert({hash, std::move(jit_tp_impl)}); + } + } + return result; +} ffi::Error tp_forward_impl( - cudaStream_t stream, - std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, - int64_t hash, ffi::ResultBufferR0 out) { - static std::mutex mutex; - static auto &cache = *new std::unordered_map(); + cudaStream_t stream, + std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, + int64_t hash, ffi::ResultBufferR0 out) { + + JITTPImpl* jit_kernel = compile_kernel_with_caching( + kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash); + + std::cout << "SUCCESSFULLY COMPILED KERNEL!" << std::endl; + // TODO: Launch the forward kernel here - auto value = forward_config.get("example_key").value(); - std::cout << value << std::endl; - { - const std::lock_guard lock(mutex); - /*auto it = cache.find(key); - if (it != cache.end()) { - out->typed_data()[0] = ++it->second; - } else { - cache.insert({key, 0}); - out->typed_data()[0] = 0; - }*/ - } - return ffi::Error::Success(); + return ffi::Error::Success(); } XLA_FFI_DEFINE_HANDLER_SYMBOL( From d932d21a0cdae21fff71b0f6ebbcc86e4d7fe9d0 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 23 Nov 2025 11:45:59 -0800 Subject: [PATCH 006/116] Cleaned up code a bit. --- openequivariance/extension/libjax_tp_jit.cpp | 46 ++++++++++---------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/openequivariance/extension/libjax_tp_jit.cpp b/openequivariance/extension/libjax_tp_jit.cpp index ebafe66e..3cf922ad 100644 --- a/openequivariance/extension/libjax_tp_jit.cpp +++ b/openequivariance/extension/libjax_tp_jit.cpp @@ -31,25 +31,25 @@ namespace ffi = xla::ffi; std::unordered_map>> kernel_cache; std::mutex mut; -std::unordered_map parse_launch_config(ffi::Dictionary dict) { - std::unordered_map result; - result["num_blocks"] = dict.get("num_blocks").value(); - result["num_threads"] = dict.get("num_threads").value(); - result["warp_size"] = dict.get("warp_size").value(); - result["smem"] = dict.get("smem").value(); - return result; -} +std::vector launch_config_keys = { + "num_blocks", + "num_threads", + "smem"}; +std::vector kernel_prop_keys = { + "L1_dim", + "L2_dim", + "L3_dim", + "weight_numel", + "shared_weights", + "opt_level", + "irrep_dtype", + "weight_dtype"}; -std::unordered_map parse_kernel_prop(ffi::Dictionary dict) { +std::unordered_map parse_ffi_dict(ffi::Dictionary &dict, const std::vector &keys) { std::unordered_map result; - result["L1_dim"] = dict.get("L1_dim").value(); - result["L2_dim"] = dict.get("L2_dim").value(); - result["L3_dim"] = dict.get("L3_dim").value(); - result["weight_numel"] = dict.get("weight_numel").value(); - result["shared_weights"] = dict.get("shared_weights").value(); - result["opt_level"] = dict.get("opt_level").value(); - result["irrep_dtype"] = dict.get("irrep_dtype").value(); - result["weight_dtype"] = dict.get("weight_dtype").value(); + for (const auto &key : keys) { + result[key] = dict.get(key).value(); + } return result; } @@ -67,13 +67,13 @@ JITTPImpl* compile_kernel_with_caching(std::string_view kernel, if (it != kernel_cache.end()) { result = it->second.get(); } - else { + else { auto jit_tp_impl = std::make_unique>( std::string(kernel), - parse_launch_config(forward_config), - parse_launch_config(backward_config), - parse_launch_config(double_backward_config), - parse_kernel_prop(kernel_prop)); + parse_ffi_dict(forward_config, launch_config_keys), + parse_ffi_dict(backward_config, launch_config_keys), + parse_ffi_dict(double_backward_config, launch_config_keys), + parse_ffi_dict(kernel_prop, kernel_prop_keys)); result = jit_tp_impl.get(); kernel_cache.insert({hash, std::move(jit_tp_impl)}); } @@ -86,7 +86,7 @@ ffi::Error tp_forward_impl( std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, int64_t hash, ffi::ResultBufferR0 out) { - JITTPImpl* jit_kernel = compile_kernel_with_caching( + auto jit_kernel = compile_kernel_with_caching( kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash); std::cout << "SUCCESSFULLY COMPILED KERNEL!" << std::endl; From 1069ae32cb7c33b3b23baf64b93a4831f4267160 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 23 Nov 2025 22:38:03 -0800 Subject: [PATCH 007/116] Reorganized the repo. --- openequivariance/benchmark/tpp_creation_utils.py | 4 +--- .../{implementations => core}/ComputationSchedule.py | 0 .../{implementations/convolution => core}/ConvolutionBase.py | 0 .../{implementations/convolution => core}/LoopUnrollConv.py | 0 openequivariance/{implementations => core}/LoopUnrollTP.py | 0 .../{implementations => core}/TensorProductBase.py | 0 openequivariance/{implementations => core}/dtype_enum.py | 0 openequivariance/{implementations => core}/e3nn_lite.py | 0 openequivariance/{implementations => core}/utils.py | 0 .../{implementations/convolution => torch}/CUEConv.py | 0 .../{implementations => torch}/CUETensorProduct.py | 0 .../{implementations/convolution => torch}/E3NNConv.py | 0 .../{implementations => torch}/E3NNTensorProduct.py | 0 .../{implementations/convolution => torch}/FlashTPConv.py | 0 openequivariance/{implementations => torch}/TensorProduct.py | 0 .../convolution => torch}/TensorProductConv.py | 0 openequivariance/{ => torch}/extlib/.empty | 0 openequivariance/{ => torch}/extlib/__init__.py | 0 .../symmetric_contraction/__init__.py | 0 .../symmetric_contraction/symmetric_contraction.py | 0 20 files changed, 1 insertion(+), 3 deletions(-) rename openequivariance/{implementations => core}/ComputationSchedule.py (100%) rename openequivariance/{implementations/convolution => core}/ConvolutionBase.py (100%) rename openequivariance/{implementations/convolution => core}/LoopUnrollConv.py (100%) rename openequivariance/{implementations => core}/LoopUnrollTP.py (100%) rename openequivariance/{implementations => core}/TensorProductBase.py (100%) rename openequivariance/{implementations => core}/dtype_enum.py (100%) rename openequivariance/{implementations => core}/e3nn_lite.py (100%) rename openequivariance/{implementations => core}/utils.py (100%) rename openequivariance/{implementations/convolution => torch}/CUEConv.py (100%) rename openequivariance/{implementations => torch}/CUETensorProduct.py (100%) rename openequivariance/{implementations/convolution => torch}/E3NNConv.py (100%) rename openequivariance/{implementations => torch}/E3NNTensorProduct.py (100%) rename openequivariance/{implementations/convolution => torch}/FlashTPConv.py (100%) rename openequivariance/{implementations => torch}/TensorProduct.py (100%) rename openequivariance/{implementations/convolution => torch}/TensorProductConv.py (100%) rename openequivariance/{ => torch}/extlib/.empty (100%) rename openequivariance/{ => torch}/extlib/__init__.py (100%) rename openequivariance/{implementations => torch}/symmetric_contraction/__init__.py (100%) rename openequivariance/{implementations => torch}/symmetric_contraction/symmetric_contraction.py (100%) diff --git a/openequivariance/benchmark/tpp_creation_utils.py b/openequivariance/benchmark/tpp_creation_utils.py index 18f3a84c..7421a71e 100644 --- a/openequivariance/benchmark/tpp_creation_utils.py +++ b/openequivariance/benchmark/tpp_creation_utils.py @@ -5,10 +5,8 @@ """ This was taken from - https://github.com/e3nn/e3nn/blob/0.5.4/e3nn/o3/_tensor_product/_sub.py - -And adopted to create TPP's to avoid torch dependence +Adapted to create TPPs to avoid torch dependence. """ diff --git a/openequivariance/implementations/ComputationSchedule.py b/openequivariance/core/ComputationSchedule.py similarity index 100% rename from openequivariance/implementations/ComputationSchedule.py rename to openequivariance/core/ComputationSchedule.py diff --git a/openequivariance/implementations/convolution/ConvolutionBase.py b/openequivariance/core/ConvolutionBase.py similarity index 100% rename from openequivariance/implementations/convolution/ConvolutionBase.py rename to openequivariance/core/ConvolutionBase.py diff --git a/openequivariance/implementations/convolution/LoopUnrollConv.py b/openequivariance/core/LoopUnrollConv.py similarity index 100% rename from openequivariance/implementations/convolution/LoopUnrollConv.py rename to openequivariance/core/LoopUnrollConv.py diff --git a/openequivariance/implementations/LoopUnrollTP.py b/openequivariance/core/LoopUnrollTP.py similarity index 100% rename from openequivariance/implementations/LoopUnrollTP.py rename to openequivariance/core/LoopUnrollTP.py diff --git a/openequivariance/implementations/TensorProductBase.py b/openequivariance/core/TensorProductBase.py similarity index 100% rename from openequivariance/implementations/TensorProductBase.py rename to openequivariance/core/TensorProductBase.py diff --git a/openequivariance/implementations/dtype_enum.py b/openequivariance/core/dtype_enum.py similarity index 100% rename from openequivariance/implementations/dtype_enum.py rename to openequivariance/core/dtype_enum.py diff --git a/openequivariance/implementations/e3nn_lite.py b/openequivariance/core/e3nn_lite.py similarity index 100% rename from openequivariance/implementations/e3nn_lite.py rename to openequivariance/core/e3nn_lite.py diff --git a/openequivariance/implementations/utils.py b/openequivariance/core/utils.py similarity index 100% rename from openequivariance/implementations/utils.py rename to openequivariance/core/utils.py diff --git a/openequivariance/implementations/convolution/CUEConv.py b/openequivariance/torch/CUEConv.py similarity index 100% rename from openequivariance/implementations/convolution/CUEConv.py rename to openequivariance/torch/CUEConv.py diff --git a/openequivariance/implementations/CUETensorProduct.py b/openequivariance/torch/CUETensorProduct.py similarity index 100% rename from openequivariance/implementations/CUETensorProduct.py rename to openequivariance/torch/CUETensorProduct.py diff --git a/openequivariance/implementations/convolution/E3NNConv.py b/openequivariance/torch/E3NNConv.py similarity index 100% rename from openequivariance/implementations/convolution/E3NNConv.py rename to openequivariance/torch/E3NNConv.py diff --git a/openequivariance/implementations/E3NNTensorProduct.py b/openequivariance/torch/E3NNTensorProduct.py similarity index 100% rename from openequivariance/implementations/E3NNTensorProduct.py rename to openequivariance/torch/E3NNTensorProduct.py diff --git a/openequivariance/implementations/convolution/FlashTPConv.py b/openequivariance/torch/FlashTPConv.py similarity index 100% rename from openequivariance/implementations/convolution/FlashTPConv.py rename to openequivariance/torch/FlashTPConv.py diff --git a/openequivariance/implementations/TensorProduct.py b/openequivariance/torch/TensorProduct.py similarity index 100% rename from openequivariance/implementations/TensorProduct.py rename to openequivariance/torch/TensorProduct.py diff --git a/openequivariance/implementations/convolution/TensorProductConv.py b/openequivariance/torch/TensorProductConv.py similarity index 100% rename from openequivariance/implementations/convolution/TensorProductConv.py rename to openequivariance/torch/TensorProductConv.py diff --git a/openequivariance/extlib/.empty b/openequivariance/torch/extlib/.empty similarity index 100% rename from openequivariance/extlib/.empty rename to openequivariance/torch/extlib/.empty diff --git a/openequivariance/extlib/__init__.py b/openequivariance/torch/extlib/__init__.py similarity index 100% rename from openequivariance/extlib/__init__.py rename to openequivariance/torch/extlib/__init__.py diff --git a/openequivariance/implementations/symmetric_contraction/__init__.py b/openequivariance/torch/symmetric_contraction/__init__.py similarity index 100% rename from openequivariance/implementations/symmetric_contraction/__init__.py rename to openequivariance/torch/symmetric_contraction/__init__.py diff --git a/openequivariance/implementations/symmetric_contraction/symmetric_contraction.py b/openequivariance/torch/symmetric_contraction/symmetric_contraction.py similarity index 100% rename from openequivariance/implementations/symmetric_contraction/symmetric_contraction.py rename to openequivariance/torch/symmetric_contraction/symmetric_contraction.py From 0525b347278fbb550834a90ace360783b41dd9be Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 24 Nov 2025 14:10:55 -0800 Subject: [PATCH 008/116] Refactored imports. --- docs/supported_ops.rst | 2 +- openequivariance/__init__.py | 15 ++++----------- openequivariance/benchmark/ConvBenchmarkSuite.py | 2 +- openequivariance/benchmark/TestBenchmarkSuite.py | 4 ++-- openequivariance/benchmark/benchmark_utils.py | 8 ++++---- openequivariance/benchmark/correctness_utils.py | 12 ++++++------ openequivariance/benchmark/perf_metrics_utils.py | 4 ++-- openequivariance/benchmark/random_buffer_utils.py | 2 +- openequivariance/benchmark/tpp_creation_utils.py | 2 +- openequivariance/core/ComputationSchedule.py | 2 +- openequivariance/core/ConvolutionBase.py | 11 +++++------ openequivariance/core/LoopUnrollConv.py | 8 ++++---- openequivariance/core/LoopUnrollTP.py | 8 ++++---- openequivariance/core/TensorProductBase.py | 4 ++-- openequivariance/core/utils.py | 2 +- openequivariance/torch/CUEConv.py | 5 ++--- openequivariance/torch/CUETensorProduct.py | 6 +++--- openequivariance/torch/E3NNConv.py | 4 ++-- openequivariance/torch/E3NNTensorProduct.py | 4 ++-- openequivariance/torch/FlashTPConv.py | 4 ++-- openequivariance/torch/TensorProduct.py | 4 ++-- openequivariance/torch/TensorProductConv.py | 8 ++++---- .../torch/symmetric_contraction/__init__.py | 2 +- tests/batch_test.py | 2 +- tests/benchmark.py | 12 ++++++------ tests/export_test.py | 2 +- 26 files changed, 65 insertions(+), 74 deletions(-) diff --git a/docs/supported_ops.rst b/docs/supported_ops.rst index 7f5ff78c..0cba136a 100644 --- a/docs/supported_ops.rst +++ b/docs/supported_ops.rst @@ -117,7 +117,7 @@ toplevel. You can use our implementation by running .. code-block:: - from openequivariance.implementations.symmetric_contraction import SymmetricContraction as OEQSymmetricContraction + from openequivariance.torch.symmetric_contraction import SymmetricContraction as OEQSymmetricContraction Some Github users report weak performance for the symmetric contraction backward pass; your mileage may vary. diff --git a/openequivariance/__init__.py b/openequivariance/__init__.py index 9fb67d03..40c1b7f7 100644 --- a/openequivariance/__init__.py +++ b/openequivariance/__init__.py @@ -3,25 +3,21 @@ import torch import numpy as np -try: - import openequivariance.extlib -except Exception as e: - raise ImportError(f"Unable to load OpenEquivariance extension library:\n{e}") from pathlib import Path from importlib.metadata import version -from openequivariance.implementations.e3nn_lite import ( +from openequivariance.core.e3nn_lite import ( TPProblem, Irrep, Irreps, _MulIr, Instruction, ) -from openequivariance.implementations.TensorProduct import TensorProduct -from openequivariance.implementations.convolution.TensorProductConv import ( +from openequivariance.torch.TensorProduct import TensorProduct +from openequivariance.torch.TensorProductConv import ( TensorProductConv, ) -from openequivariance.implementations.utils import torch_to_oeq_dtype +from openequivariance.core.utils import torch_to_oeq_dtype __version__ = None try: @@ -63,9 +59,6 @@ def torch_ext_so_path(): ] ) -LINKED_LIBPYTHON = openequivariance.extlib.LINKED_LIBPYTHON -LINKED_LIBPYTHON_ERROR = openequivariance.extlib.LINKED_LIBPYTHON_ERROR - __all__ = [ "TPProblem", "Irreps", diff --git a/openequivariance/benchmark/ConvBenchmarkSuite.py b/openequivariance/benchmark/ConvBenchmarkSuite.py index a4b7c982..499a33eb 100644 --- a/openequivariance/benchmark/ConvBenchmarkSuite.py +++ b/openequivariance/benchmark/ConvBenchmarkSuite.py @@ -7,7 +7,7 @@ import openequivariance as oeq from openequivariance.benchmark.logging_utils import getLogger -from openequivariance.implementations.convolution.ConvolutionBase import CoordGraph +from openequivariance.core.ConvolutionBase import CoordGraph logger = getLogger() diff --git a/openequivariance/benchmark/TestBenchmarkSuite.py b/openequivariance/benchmark/TestBenchmarkSuite.py index d764be77..155c2444 100644 --- a/openequivariance/benchmark/TestBenchmarkSuite.py +++ b/openequivariance/benchmark/TestBenchmarkSuite.py @@ -8,10 +8,10 @@ import openequivariance as oeq from openequivariance.extlib import DeviceProp -from openequivariance.implementations.TensorProductBase import TensorProductBase +from openequivariance.core.TensorProductBase import TensorProductBase from openequivariance.benchmark.logging_utils import getLogger, bcolors -from openequivariance.implementations.e3nn_lite import TPProblem +from openequivariance.core.e3nn_lite import TPProblem from openequivariance.benchmark.correctness_utils import ( correctness_forward, correctness_backward, diff --git a/openequivariance/benchmark/benchmark_utils.py b/openequivariance/benchmark/benchmark_utils.py index 4dfea422..2b9a902f 100644 --- a/openequivariance/benchmark/benchmark_utils.py +++ b/openequivariance/benchmark/benchmark_utils.py @@ -10,10 +10,10 @@ calculate_minimum_memory_streamed_forward, calculate_minimum_memory_streamed_backward, ) -from openequivariance.implementations.utils import calculate_total_nnz -from openequivariance.implementations.TensorProductBase import TensorProductBase -from openequivariance.implementations.e3nn_lite import TPProblem -from openequivariance.implementations.CUETensorProduct import CUETensorProduct +from openequivariance.core.utils import calculate_total_nnz +from openequivariance.core.TensorProductBase import TensorProductBase +from openequivariance.core.e3nn_lite import TPProblem +from openequivariance.torch.CUETensorProduct import CUETensorProduct from openequivariance.benchmark.logging_utils import getLogger, bcolors logger = getLogger() diff --git a/openequivariance/benchmark/correctness_utils.py b/openequivariance/benchmark/correctness_utils.py index e2cf414b..5d999290 100644 --- a/openequivariance/benchmark/correctness_utils.py +++ b/openequivariance/benchmark/correctness_utils.py @@ -1,8 +1,8 @@ from typing import Optional, Union -from openequivariance.implementations.TensorProductBase import TensorProductBase -from openequivariance.implementations.CUETensorProduct import CUETensorProduct -from openequivariance.implementations.e3nn_lite import TPProblem +from openequivariance.core.TensorProductBase import TensorProductBase +from openequivariance.core.e3nn_lite import TPProblem +from openequivariance.torch.CUETensorProduct import CUETensorProduct from openequivariance.benchmark.random_buffer_utils import ( get_random_buffers_forward, get_random_buffers_backward, @@ -71,7 +71,7 @@ def correctness_forward( prng_seed: int, ) -> dict: if reference_implementation is None: - from openequivariance.implementations.E3NNTensorProduct import E3NNTensorProduct + from openequivariance.torch.E3NNTensorProduct import E3NNTensorProduct reference_implementation = E3NNTensorProduct @@ -115,7 +115,7 @@ def correctness_backward( prng_seed: int, ) -> dict: if reference_implementation is None: - from openequivariance.implementations.E3NNTensorProduct import E3NNTensorProduct + from openequivariance.torch.E3NNTensorProduct import E3NNTensorProduct reference_implementation = E3NNTensorProduct @@ -201,7 +201,7 @@ def correctness_double_backward( dummy_grad = rng.standard_normal(1)[0] if reference_implementation is None: - from openequivariance.implementations.E3NNTensorProduct import E3NNTensorProduct + from openequivariance.torch.E3NNTensorProduct import E3NNTensorProduct reference_implementation = E3NNTensorProduct diff --git a/openequivariance/benchmark/perf_metrics_utils.py b/openequivariance/benchmark/perf_metrics_utils.py index 88a903ab..212f05f4 100644 --- a/openequivariance/benchmark/perf_metrics_utils.py +++ b/openequivariance/benchmark/perf_metrics_utils.py @@ -1,11 +1,11 @@ import math -from openequivariance.implementations.utils import ( +from openequivariance.core.utils import ( count_cg_non_zero, sparse_outer_product_work, ) -from openequivariance.implementations.e3nn_lite import TPProblem, wigner_3j +from openequivariance.core.e3nn_lite import TPProblem, wigner_3j from openequivariance.benchmark.logging_utils import getLogger import numpy as np diff --git a/openequivariance/benchmark/random_buffer_utils.py b/openequivariance/benchmark/random_buffer_utils.py index 41fb7cb6..20e9ac72 100644 --- a/openequivariance/benchmark/random_buffer_utils.py +++ b/openequivariance/benchmark/random_buffer_utils.py @@ -1,6 +1,6 @@ import numpy as np -from openequivariance.implementations.e3nn_lite import TPProblem +from openequivariance.core.e3nn_lite import TPProblem def get_random_buffers_forward( diff --git a/openequivariance/benchmark/tpp_creation_utils.py b/openequivariance/benchmark/tpp_creation_utils.py index 7421a71e..7637f412 100644 --- a/openequivariance/benchmark/tpp_creation_utils.py +++ b/openequivariance/benchmark/tpp_creation_utils.py @@ -1,7 +1,7 @@ import numpy as np from typing import Iterator, Optional -from openequivariance.implementations.e3nn_lite import Irrep, Irreps, TPProblem +from openequivariance.core.e3nn_lite import Irrep, Irreps, TPProblem """ This was taken from diff --git a/openequivariance/core/ComputationSchedule.py b/openequivariance/core/ComputationSchedule.py index 6d3b7215..e52cb7d2 100644 --- a/openequivariance/core/ComputationSchedule.py +++ b/openequivariance/core/ComputationSchedule.py @@ -1,5 +1,5 @@ import numpy as np -from openequivariance.implementations.e3nn_lite import Irreps, TPProblem, wigner_3j +from openequivariance.core.e3nn_lite import Irreps, TPProblem, wigner_3j from itertools import accumulate from openequivariance.benchmark.logging_utils import getLogger diff --git a/openequivariance/core/ConvolutionBase.py b/openequivariance/core/ConvolutionBase.py index 7ed16571..12bc0096 100644 --- a/openequivariance/core/ConvolutionBase.py +++ b/openequivariance/core/ConvolutionBase.py @@ -8,8 +8,8 @@ from openequivariance.benchmark.logging_utils import getLogger, bcolors from openequivariance.benchmark.correctness_utils import check_similiarity -from openequivariance.implementations.e3nn_lite import wigner_3j -from openequivariance.implementations.utils import benchmark +from openequivariance.core.e3nn_lite import wigner_3j +from openequivariance.core.utils import benchmark logger = getLogger() @@ -240,7 +240,7 @@ def test_correctness_forward( high_precision_ref=False, ): if reference_implementation is None: - from openequivariance.implementations.convolution.E3NNConv import E3NNConv + from openequivariance.torch.E3NNConv import E3NNConv reference_implementation = E3NNConv @@ -484,7 +484,7 @@ def test_correctness_backward( high_precision_ref=False, ): if reference_implementation is None: - from openequivariance.implementations.convolution.E3NNConv import E3NNConv + from openequivariance.torch.E3NNConv import E3NNConv reference_implementation = E3NNConv @@ -572,8 +572,7 @@ def test_correctness_double_backward( dummy_grad_value = rng.standard_normal(1)[0] if reference_implementation is None: - from openequivariance.implementations.convolution.E3NNConv import E3NNConv - + from openequivariance.torch.E3NNConv import E3NNConv reference_implementation = E3NNConv reference_problem = self.config diff --git a/openequivariance/core/LoopUnrollConv.py b/openequivariance/core/LoopUnrollConv.py index a5e46ce3..edd346d9 100644 --- a/openequivariance/core/LoopUnrollConv.py +++ b/openequivariance/core/LoopUnrollConv.py @@ -1,12 +1,12 @@ import numpy as np -from openequivariance.implementations.convolution.ConvolutionBase import ConvolutionBase -from openequivariance.implementations.ComputationSchedule import ( +from openequivariance.core.ConvolutionBase import ConvolutionBase +from openequivariance.core.ComputationSchedule import ( ComputationSchedule, SMEMCapacityException, ) -from openequivariance.implementations.dtype_enum import ( +from openequivariance.core.dtype_enum import ( dtype_to_enum, enum_to_torch_dtype, ) @@ -14,7 +14,7 @@ from openequivariance import extlib from openequivariance.extlib import JITConvImpl, postprocess_kernel, DeviceProp -from openequivariance.implementations.utils import filter_and_analyze_problem +from openequivariance.core.utils import filter_and_analyze_problem from openequivariance.benchmark.logging_utils import getLogger logger = getLogger() diff --git a/openequivariance/core/LoopUnrollTP.py b/openequivariance/core/LoopUnrollTP.py index ed6a5395..b99499bb 100644 --- a/openequivariance/core/LoopUnrollTP.py +++ b/openequivariance/core/LoopUnrollTP.py @@ -2,12 +2,12 @@ import openequivariance.extlib as extlib from openequivariance.templates.jinja_utils import get_jinja_environment -from openequivariance.implementations.ComputationSchedule import ComputationSchedule +from openequivariance.core.ComputationSchedule import ComputationSchedule -from openequivariance.implementations.dtype_enum import dtype_to_enum -from openequivariance.implementations.TensorProductBase import TensorProductBase +from openequivariance.core.dtype_enum import dtype_to_enum +from openequivariance.core.TensorProductBase import TensorProductBase from openequivariance.benchmark.logging_utils import getLogger -from openequivariance.implementations.utils import ( +from openequivariance.core.utils import ( filter_and_analyze_problem, count_cg_non_zero, ) diff --git a/openequivariance/core/TensorProductBase.py b/openequivariance/core/TensorProductBase.py index 043c5b77..5e0d25e2 100644 --- a/openequivariance/core/TensorProductBase.py +++ b/openequivariance/core/TensorProductBase.py @@ -1,8 +1,8 @@ import numpy as np -from openequivariance.implementations.e3nn_lite import TPProblem +from openequivariance.core.e3nn_lite import TPProblem from openequivariance.benchmark.logging_utils import getLogger -from openequivariance.implementations.utils import benchmark +from openequivariance.core.utils import benchmark from openequivariance.extlib import DeviceBuffer logger = getLogger() diff --git a/openequivariance/core/utils.py b/openequivariance/core/utils.py index b90993c1..4e4365da 100644 --- a/openequivariance/core/utils.py +++ b/openequivariance/core/utils.py @@ -3,7 +3,7 @@ import numpy as np -from openequivariance.implementations.e3nn_lite import Instruction, TPProblem, wigner_3j +from openequivariance.core.e3nn_lite import Instruction, TPProblem, wigner_3j import json import tempfile diff --git a/openequivariance/torch/CUEConv.py b/openequivariance/torch/CUEConv.py index 9287abe8..83127f74 100644 --- a/openequivariance/torch/CUEConv.py +++ b/openequivariance/torch/CUEConv.py @@ -2,13 +2,12 @@ import itertools from typing import Iterator -from openequivariance.implementations.CUETensorProduct import CUETensorProduct -from openequivariance.implementations.convolution.ConvolutionBase import ( +from openequivariance.torch.CUETensorProduct import CUETensorProduct +from openequivariance.core.ConvolutionBase import ( ConvolutionBase, scatter_add_wrapper, ) - class CUEConv(ConvolutionBase): def __init__(self, config, *, idx_dtype=np.int64, torch_op=True): super().__init__(config, idx_dtype=idx_dtype, torch_op=torch_op) diff --git a/openequivariance/torch/CUETensorProduct.py b/openequivariance/torch/CUETensorProduct.py index a7d027f4..33b8db12 100644 --- a/openequivariance/torch/CUETensorProduct.py +++ b/openequivariance/torch/CUETensorProduct.py @@ -4,15 +4,15 @@ import itertools from typing import Iterator -from openequivariance.implementations.TensorProductBase import TensorProductBase -from openequivariance.implementations.e3nn_lite import TPProblem +from openequivariance.core.TensorProductBase import TensorProductBase +from openequivariance.core.e3nn_lite import TPProblem from openequivariance.benchmark.logging_utils import getLogger from openequivariance.benchmark.tpp_creation_utils import ( ChannelwiseTPP, FullyConnectedTPProblem, SingleInstruction, ) -from openequivariance.implementations.utils import count_cg_non_zero +from openequivariance.core.utils import count_cg_non_zero os.environ["CUEQUIVARIANCE_OPS_USE_JIT"] = "1" diff --git a/openequivariance/torch/E3NNConv.py b/openequivariance/torch/E3NNConv.py index 00b0faa8..618305fe 100644 --- a/openequivariance/torch/E3NNConv.py +++ b/openequivariance/torch/E3NNConv.py @@ -1,10 +1,10 @@ import numpy as np -from openequivariance.implementations.convolution.ConvolutionBase import ( +from openequivariance.core.ConvolutionBase import ( ConvolutionBase, scatter_add_wrapper, ) -from openequivariance.implementations.E3NNTensorProduct import E3NNTensorProduct +from openequivariance.torch.E3NNTensorProduct import E3NNTensorProduct class E3NNConv(ConvolutionBase): diff --git a/openequivariance/torch/E3NNTensorProduct.py b/openequivariance/torch/E3NNTensorProduct.py index 334ba65c..32196235 100644 --- a/openequivariance/torch/E3NNTensorProduct.py +++ b/openequivariance/torch/E3NNTensorProduct.py @@ -9,8 +9,8 @@ import pathlib import numpy as np -from openequivariance.implementations.TensorProductBase import TensorProductBase -from openequivariance.implementations.e3nn_lite import TPProblem +from openequivariance.core.TensorProductBase import TensorProductBase +from openequivariance.core.e3nn_lite import TPProblem from openequivariance.benchmark.logging_utils import getLogger TORCH_COMPILE_AUTOTUNING_DIR = pathlib.Path("triton_autotuning") diff --git a/openequivariance/torch/FlashTPConv.py b/openequivariance/torch/FlashTPConv.py index 0302ef9c..9ec5c409 100644 --- a/openequivariance/torch/FlashTPConv.py +++ b/openequivariance/torch/FlashTPConv.py @@ -4,8 +4,8 @@ import torch import numpy as np -from openequivariance.implementations.convolution.ConvolutionBase import ConvolutionBase -from openequivariance.implementations.utils import oeq_to_torch_dtype +from openequivariance.core.ConvolutionBase import ConvolutionBase +from openequivariance.core.utils import oeq_to_torch_dtype class FlashTPConv(ConvolutionBase): diff --git a/openequivariance/torch/TensorProduct.py b/openequivariance/torch/TensorProduct.py index 54fc4307..a955694a 100644 --- a/openequivariance/torch/TensorProduct.py +++ b/openequivariance/torch/TensorProduct.py @@ -1,9 +1,9 @@ -from openequivariance.implementations.LoopUnrollTP import LoopUnrollTP +from openequivariance.core.LoopUnrollTP import LoopUnrollTP from openequivariance import TPProblem from openequivariance import extlib import torch import typing -from openequivariance.implementations.utils import torch_to_oeq_dtype +from openequivariance.core.utils import torch_to_oeq_dtype class TensorProduct(torch.nn.Module, LoopUnrollTP): diff --git a/openequivariance/torch/TensorProductConv.py b/openequivariance/torch/TensorProductConv.py index 7e860944..d4af06fa 100644 --- a/openequivariance/torch/TensorProductConv.py +++ b/openequivariance/torch/TensorProductConv.py @@ -4,14 +4,14 @@ import torch from openequivariance import extlib -from openequivariance.implementations.convolution.ConvolutionBase import ( +from openequivariance.core.ConvolutionBase import ( ConvolutionBase, scatter_add_wrapper, ) -from openequivariance.implementations.convolution.LoopUnrollConv import LoopUnrollConv -from openequivariance.implementations.TensorProduct import TensorProduct +from openequivariance.core.LoopUnrollConv import LoopUnrollConv +from openequivariance.torch.TensorProduct import TensorProduct from openequivariance import TPProblem -from openequivariance.implementations.utils import torch_to_oeq_dtype +from openequivariance.core.utils import torch_to_oeq_dtype class TensorProductConv(torch.nn.Module, LoopUnrollConv): diff --git a/openequivariance/torch/symmetric_contraction/__init__.py b/openequivariance/torch/symmetric_contraction/__init__.py index 75ac6cc8..43c9fdf4 100644 --- a/openequivariance/torch/symmetric_contraction/__init__.py +++ b/openequivariance/torch/symmetric_contraction/__init__.py @@ -1,4 +1,4 @@ -from openequivariance.implementations.symmetric_contraction.symmetric_contraction import ( +from openequivariance.torch.symmetric_contraction.symmetric_contraction import ( SymmetricContraction, ) diff --git a/tests/batch_test.py b/tests/batch_test.py index 3c7cdf27..b277c80b 100644 --- a/tests/batch_test.py +++ b/tests/batch_test.py @@ -3,7 +3,7 @@ import numpy as np import openequivariance as oeq -from openequivariance.implementations.TensorProduct import TensorProduct +from openequivariance.core.TensorProduct import TensorProduct from openequivariance.benchmark.correctness_utils import ( correctness_forward, correctness_backward, diff --git a/tests/benchmark.py b/tests/benchmark.py index ab005cdf..59f29300 100644 --- a/tests/benchmark.py +++ b/tests/benchmark.py @@ -12,13 +12,13 @@ from openequivariance.benchmark.logging_utils import getLogger from openequivariance.extlib import DeviceProp -from openequivariance.implementations.E3NNTensorProduct import ( +from openequivariance.torch.E3NNTensorProduct import ( E3NNTensorProduct, E3NNTensorProductCompiledCUDAGraphs, E3NNTensorProductCompiledMaxAutotuneCUDAGraphs, ) -from openequivariance.implementations.TensorProduct import TensorProduct -from openequivariance.implementations.CUETensorProduct import CUETensorProduct +from openequivariance.torch.TensorProduct import TensorProduct +from openequivariance.torch.CUETensorProduct import CUETensorProduct from openequivariance.benchmark.TestBenchmarkSuite import ( TestBenchmarkSuite, TestDefinition, @@ -30,15 +30,15 @@ SingleInstruction, ) -from openequivariance.implementations.convolution.TensorProductConv import ( +from openequivariance.torch.TensorProductConv import ( TensorProductConvAtomic, TensorProductConvDeterministic, TensorProductConvKahan, TensorProductConvScatterSum, ) -from openequivariance.implementations.convolution.CUEConv import CUEConv, CUEConvFused -from openequivariance.implementations.convolution.FlashTPConv import FlashTPConv +from openequivariance.torch.CUEConv import CUEConv, CUEConvFused +from openequivariance.torch.FlashTPConv import FlashTPConv from openequivariance.benchmark.ConvBenchmarkSuite import ConvBenchmarkSuite, load_graph from openequivariance.benchmark.problems import ( diff --git a/tests/export_test.py b/tests/export_test.py index e18b38b1..35e52df8 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -11,7 +11,7 @@ from torch_geometric import EdgeIndex import importlib.resources -from openequivariance.implementations.E3NNTensorProduct import E3NNTensorProduct +from openequivariance.torch.E3NNTensorProduct import E3NNTensorProduct @pytest.fixture(scope="session") From 33f823226239325a55e98d9cb14e4195ced1ee2e Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 24 Nov 2025 15:33:03 -0800 Subject: [PATCH 009/116] Tests are passing after the first refactor. --- .gitignore | 1 - MANIFEST.in | 1 - docs/api.rst | 4 ++-- docs/conf.py | 2 +- docs/supported_ops.rst | 2 +- openequivariance/__init__.py | 6 +++--- openequivariance/benchmark/TestBenchmarkSuite.py | 2 +- openequivariance/benchmark/benchmark_utils.py | 2 +- openequivariance/benchmark/correctness_utils.py | 8 ++++---- openequivariance/core/ConvolutionBase.py | 8 ++++---- openequivariance/core/LoopUnrollConv.py | 4 ++-- openequivariance/core/LoopUnrollTP.py | 2 +- openequivariance/core/TensorProductBase.py | 2 +- openequivariance/core/utils.py | 2 +- openequivariance/{torch => impl_torch}/CUEConv.py | 2 +- .../{torch => impl_torch}/CUETensorProduct.py | 0 openequivariance/{torch => impl_torch}/E3NNConv.py | 2 +- .../{torch => impl_torch}/E3NNTensorProduct.py | 0 .../{torch => impl_torch}/FlashTPConv.py | 0 .../{torch => impl_torch}/TensorProduct.py | 2 +- .../{torch => impl_torch}/TensorProductConv.py | 4 ++-- .../{torch => impl_torch}/extlib/.empty | 0 .../{torch => impl_torch}/extlib/__init__.py | 6 +++--- .../impl_torch/symmetric_contraction/__init__.py | 5 +++++ .../symmetric_contraction/symmetric_contraction.py | 2 +- .../torch/symmetric_contraction/__init__.py | 5 ----- tests/batch_test.py | 2 +- tests/benchmark.py | 14 +++++++------- tests/export_test.py | 2 +- 29 files changed, 45 insertions(+), 47 deletions(-) rename openequivariance/{torch => impl_torch}/CUEConv.py (97%) rename openequivariance/{torch => impl_torch}/CUETensorProduct.py (100%) rename openequivariance/{torch => impl_torch}/E3NNConv.py (96%) rename openequivariance/{torch => impl_torch}/E3NNTensorProduct.py (100%) rename openequivariance/{torch => impl_torch}/FlashTPConv.py (100%) rename openequivariance/{torch => impl_torch}/TensorProduct.py (99%) rename openequivariance/{torch => impl_torch}/TensorProductConv.py (99%) rename openequivariance/{torch => impl_torch}/extlib/.empty (100%) rename openequivariance/{torch => impl_torch}/extlib/__init__.py (96%) create mode 100644 openequivariance/impl_torch/symmetric_contraction/__init__.py rename openequivariance/{torch => impl_torch}/symmetric_contraction/symmetric_contraction.py (99%) delete mode 100644 openequivariance/torch/symmetric_contraction/__init__.py diff --git a/.gitignore b/.gitignore index 5c878b1a..64fcaa8d 100644 --- a/.gitignore +++ b/.gitignore @@ -38,7 +38,6 @@ triton_autotuning paper_benchmarks paper_benchmarks_v2 paper_benchmarks_v3 -openequivariance/extlib/*.so get_node.sh *.egg-info \ No newline at end of file diff --git a/MANIFEST.in b/MANIFEST.in index 7eaa4d91..f70b3e00 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,3 @@ -include openequivariance/extlib/*.so include openequivariance/extlib/*.empty include openequivariance/templates/*.cuh diff --git a/docs/api.rst b/docs/api.rst index 3fac1764..b8345747 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -31,9 +31,9 @@ trying our code. OpenEquivariance cannot accelerate all tensor products; see :members: :undoc-members: -.. autofunction:: openequivariance.torch_to_oeq_dtype +.. autofunction:: openequivariance.impl_torch_to_oeq_dtype -.. autofunction:: openequivariance.torch_ext_so_path +.. autofunction:: openequivariance.impl_torch_ext_so_path API Identical to e3nn --------------------- diff --git a/docs/conf.py b/docs/conf.py index 17707552..a360fb6d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -34,5 +34,5 @@ sys.path.insert(0, str(Path("..").resolve())) -autodoc_mock_imports = ["torch", "openequivariance.extlib", "jinja2", "numpy"] +autodoc_mock_imports = ["torch", "openequivariance.impl_torch.extlib", "jinja2", "numpy"] autodoc_typehints = "description" diff --git a/docs/supported_ops.rst b/docs/supported_ops.rst index 0cba136a..02a98282 100644 --- a/docs/supported_ops.rst +++ b/docs/supported_ops.rst @@ -117,7 +117,7 @@ toplevel. You can use our implementation by running .. code-block:: - from openequivariance.torch.symmetric_contraction import SymmetricContraction as OEQSymmetricContraction + from openequivariance.impl_torch.symmetric_contraction import SymmetricContraction as OEQSymmetricContraction Some Github users report weak performance for the symmetric contraction backward pass; your mileage may vary. diff --git a/openequivariance/__init__.py b/openequivariance/__init__.py index 40c1b7f7..04527921 100644 --- a/openequivariance/__init__.py +++ b/openequivariance/__init__.py @@ -13,8 +13,8 @@ _MulIr, Instruction, ) -from openequivariance.torch.TensorProduct import TensorProduct -from openequivariance.torch.TensorProductConv import ( +from openequivariance.impl_torch.TensorProduct import TensorProduct +from openequivariance.impl_torch.TensorProductConv import ( TensorProductConv, ) from openequivariance.core.utils import torch_to_oeq_dtype @@ -42,7 +42,7 @@ def torch_ext_so_path(): :returns: Path to a ``.so`` file that must be linked to use OpenEquivariance from the PyTorch C++ Interface. """ - return openequivariance.extlib.torch_module.__file__ + return openequivariance.impl_torch.extlib.torch_module.__file__ torch.serialization.add_safe_globals( diff --git a/openequivariance/benchmark/TestBenchmarkSuite.py b/openequivariance/benchmark/TestBenchmarkSuite.py index 155c2444..12cbee5e 100644 --- a/openequivariance/benchmark/TestBenchmarkSuite.py +++ b/openequivariance/benchmark/TestBenchmarkSuite.py @@ -7,7 +7,7 @@ from dataclasses import dataclass import openequivariance as oeq -from openequivariance.extlib import DeviceProp +from openequivariance.impl_torch.extlib import DeviceProp from openequivariance.core.TensorProductBase import TensorProductBase from openequivariance.benchmark.logging_utils import getLogger, bcolors diff --git a/openequivariance/benchmark/benchmark_utils.py b/openequivariance/benchmark/benchmark_utils.py index 2b9a902f..b7abaf77 100644 --- a/openequivariance/benchmark/benchmark_utils.py +++ b/openequivariance/benchmark/benchmark_utils.py @@ -13,7 +13,7 @@ from openequivariance.core.utils import calculate_total_nnz from openequivariance.core.TensorProductBase import TensorProductBase from openequivariance.core.e3nn_lite import TPProblem -from openequivariance.torch.CUETensorProduct import CUETensorProduct +from openequivariance.impl_torch.CUETensorProduct import CUETensorProduct from openequivariance.benchmark.logging_utils import getLogger, bcolors logger = getLogger() diff --git a/openequivariance/benchmark/correctness_utils.py b/openequivariance/benchmark/correctness_utils.py index 5d999290..5a3ad87c 100644 --- a/openequivariance/benchmark/correctness_utils.py +++ b/openequivariance/benchmark/correctness_utils.py @@ -2,7 +2,7 @@ from openequivariance.core.TensorProductBase import TensorProductBase from openequivariance.core.e3nn_lite import TPProblem -from openequivariance.torch.CUETensorProduct import CUETensorProduct +from openequivariance.impl_torch.CUETensorProduct import CUETensorProduct from openequivariance.benchmark.random_buffer_utils import ( get_random_buffers_forward, get_random_buffers_backward, @@ -71,7 +71,7 @@ def correctness_forward( prng_seed: int, ) -> dict: if reference_implementation is None: - from openequivariance.torch.E3NNTensorProduct import E3NNTensorProduct + from openequivariance.impl_torch.E3NNTensorProduct import E3NNTensorProduct reference_implementation = E3NNTensorProduct @@ -115,7 +115,7 @@ def correctness_backward( prng_seed: int, ) -> dict: if reference_implementation is None: - from openequivariance.torch.E3NNTensorProduct import E3NNTensorProduct + from openequivariance.impl_torch.E3NNTensorProduct import E3NNTensorProduct reference_implementation = E3NNTensorProduct @@ -201,7 +201,7 @@ def correctness_double_backward( dummy_grad = rng.standard_normal(1)[0] if reference_implementation is None: - from openequivariance.torch.E3NNTensorProduct import E3NNTensorProduct + from openequivariance.impl_torch.E3NNTensorProduct import E3NNTensorProduct reference_implementation = E3NNTensorProduct diff --git a/openequivariance/core/ConvolutionBase.py b/openequivariance/core/ConvolutionBase.py index 12bc0096..35c82307 100644 --- a/openequivariance/core/ConvolutionBase.py +++ b/openequivariance/core/ConvolutionBase.py @@ -1,6 +1,6 @@ import copy import numpy as np -from openequivariance.extlib import DeviceBuffer +from openequivariance.impl_torch.extlib import DeviceBuffer from openequivariance.benchmark.random_buffer_utils import ( get_random_buffers_forward_conv, get_random_buffers_backward_conv, @@ -240,7 +240,7 @@ def test_correctness_forward( high_precision_ref=False, ): if reference_implementation is None: - from openequivariance.torch.E3NNConv import E3NNConv + from openequivariance.impl_torch.E3NNConv import E3NNConv reference_implementation = E3NNConv @@ -484,7 +484,7 @@ def test_correctness_backward( high_precision_ref=False, ): if reference_implementation is None: - from openequivariance.torch.E3NNConv import E3NNConv + from openequivariance.impl_torch.E3NNConv import E3NNConv reference_implementation = E3NNConv @@ -572,7 +572,7 @@ def test_correctness_double_backward( dummy_grad_value = rng.standard_normal(1)[0] if reference_implementation is None: - from openequivariance.torch.E3NNConv import E3NNConv + from openequivariance.impl_torch.E3NNConv import E3NNConv reference_implementation = E3NNConv reference_problem = self.config diff --git a/openequivariance/core/LoopUnrollConv.py b/openequivariance/core/LoopUnrollConv.py index edd346d9..104230fc 100644 --- a/openequivariance/core/LoopUnrollConv.py +++ b/openequivariance/core/LoopUnrollConv.py @@ -11,8 +11,8 @@ enum_to_torch_dtype, ) from openequivariance.templates.jinja_utils import get_jinja_environment -from openequivariance import extlib -from openequivariance.extlib import JITConvImpl, postprocess_kernel, DeviceProp +import openequivariance.impl_torch.extlib as extlib +from openequivariance.impl_torch.extlib import JITConvImpl, postprocess_kernel, DeviceProp from openequivariance.core.utils import filter_and_analyze_problem from openequivariance.benchmark.logging_utils import getLogger diff --git a/openequivariance/core/LoopUnrollTP.py b/openequivariance/core/LoopUnrollTP.py index b99499bb..84bebc42 100644 --- a/openequivariance/core/LoopUnrollTP.py +++ b/openequivariance/core/LoopUnrollTP.py @@ -1,6 +1,6 @@ import numpy as np -import openequivariance.extlib as extlib +import openequivariance.impl_torch.extlib as extlib from openequivariance.templates.jinja_utils import get_jinja_environment from openequivariance.core.ComputationSchedule import ComputationSchedule diff --git a/openequivariance/core/TensorProductBase.py b/openequivariance/core/TensorProductBase.py index 5e0d25e2..f00dcc7c 100644 --- a/openequivariance/core/TensorProductBase.py +++ b/openequivariance/core/TensorProductBase.py @@ -3,7 +3,7 @@ from openequivariance.core.e3nn_lite import TPProblem from openequivariance.benchmark.logging_utils import getLogger from openequivariance.core.utils import benchmark -from openequivariance.extlib import DeviceBuffer +from openequivariance.impl_torch.extlib import DeviceBuffer logger = getLogger() diff --git a/openequivariance/core/utils.py b/openequivariance/core/utils.py index 4e4365da..cf9d038e 100644 --- a/openequivariance/core/utils.py +++ b/openequivariance/core/utils.py @@ -7,7 +7,7 @@ import json import tempfile -from openequivariance.extlib import GPUTimer +from openequivariance.impl_torch.extlib import GPUTimer def sparse_outer_product_work(cg: np.ndarray) -> int: diff --git a/openequivariance/torch/CUEConv.py b/openequivariance/impl_torch/CUEConv.py similarity index 97% rename from openequivariance/torch/CUEConv.py rename to openequivariance/impl_torch/CUEConv.py index 83127f74..a8877ab8 100644 --- a/openequivariance/torch/CUEConv.py +++ b/openequivariance/impl_torch/CUEConv.py @@ -2,7 +2,7 @@ import itertools from typing import Iterator -from openequivariance.torch.CUETensorProduct import CUETensorProduct +from openequivariance.impl_torch.CUETensorProduct import CUETensorProduct from openequivariance.core.ConvolutionBase import ( ConvolutionBase, scatter_add_wrapper, diff --git a/openequivariance/torch/CUETensorProduct.py b/openequivariance/impl_torch/CUETensorProduct.py similarity index 100% rename from openequivariance/torch/CUETensorProduct.py rename to openequivariance/impl_torch/CUETensorProduct.py diff --git a/openequivariance/torch/E3NNConv.py b/openequivariance/impl_torch/E3NNConv.py similarity index 96% rename from openequivariance/torch/E3NNConv.py rename to openequivariance/impl_torch/E3NNConv.py index 618305fe..29137819 100644 --- a/openequivariance/torch/E3NNConv.py +++ b/openequivariance/impl_torch/E3NNConv.py @@ -4,7 +4,7 @@ ConvolutionBase, scatter_add_wrapper, ) -from openequivariance.torch.E3NNTensorProduct import E3NNTensorProduct +from openequivariance.impl_torch.E3NNTensorProduct import E3NNTensorProduct class E3NNConv(ConvolutionBase): diff --git a/openequivariance/torch/E3NNTensorProduct.py b/openequivariance/impl_torch/E3NNTensorProduct.py similarity index 100% rename from openequivariance/torch/E3NNTensorProduct.py rename to openequivariance/impl_torch/E3NNTensorProduct.py diff --git a/openequivariance/torch/FlashTPConv.py b/openequivariance/impl_torch/FlashTPConv.py similarity index 100% rename from openequivariance/torch/FlashTPConv.py rename to openequivariance/impl_torch/FlashTPConv.py diff --git a/openequivariance/torch/TensorProduct.py b/openequivariance/impl_torch/TensorProduct.py similarity index 99% rename from openequivariance/torch/TensorProduct.py rename to openequivariance/impl_torch/TensorProduct.py index a955694a..08ce6cc8 100644 --- a/openequivariance/torch/TensorProduct.py +++ b/openequivariance/impl_torch/TensorProduct.py @@ -1,6 +1,6 @@ from openequivariance.core.LoopUnrollTP import LoopUnrollTP from openequivariance import TPProblem -from openequivariance import extlib +from openequivariance.impl_torch import extlib import torch import typing from openequivariance.core.utils import torch_to_oeq_dtype diff --git a/openequivariance/torch/TensorProductConv.py b/openequivariance/impl_torch/TensorProductConv.py similarity index 99% rename from openequivariance/torch/TensorProductConv.py rename to openequivariance/impl_torch/TensorProductConv.py index d4af06fa..3ee71108 100644 --- a/openequivariance/torch/TensorProductConv.py +++ b/openequivariance/impl_torch/TensorProductConv.py @@ -3,13 +3,13 @@ import numpy as np import torch -from openequivariance import extlib +from openequivariance.impl_torch import extlib from openequivariance.core.ConvolutionBase import ( ConvolutionBase, scatter_add_wrapper, ) from openequivariance.core.LoopUnrollConv import LoopUnrollConv -from openequivariance.torch.TensorProduct import TensorProduct +from openequivariance.impl_torch.TensorProduct import TensorProduct from openequivariance import TPProblem from openequivariance.core.utils import torch_to_oeq_dtype diff --git a/openequivariance/torch/extlib/.empty b/openequivariance/impl_torch/extlib/.empty similarity index 100% rename from openequivariance/torch/extlib/.empty rename to openequivariance/impl_torch/extlib/.empty diff --git a/openequivariance/torch/extlib/__init__.py b/openequivariance/impl_torch/extlib/__init__.py similarity index 96% rename from openequivariance/torch/extlib/__init__.py rename to openequivariance/impl_torch/extlib/__init__.py index 527c4ab3..3d8fd085 100644 --- a/openequivariance/torch/extlib/__init__.py +++ b/openequivariance/impl_torch/extlib/__init__.py @@ -10,7 +10,7 @@ from openequivariance.benchmark.logging_utils import getLogger -oeq_root = str(Path(__file__).parent.parent) +oeq_root = str(Path(__file__).parent.parent.parent) build_ext = True TORCH_COMPILE = True @@ -39,9 +39,9 @@ generic_module = None if not build_ext: - import openequivariance.extlib.generic_module + import openequivariance.impl_torch.extlib.generic_module + generic_module = openequivariance.impl_torch.extlib.generic_module - generic_module = openequivariance.extlib.generic_module elif TORCH_VERSION_CUDA_OR_HIP: from torch.utils.cpp_extension import library_paths, include_paths diff --git a/openequivariance/impl_torch/symmetric_contraction/__init__.py b/openequivariance/impl_torch/symmetric_contraction/__init__.py new file mode 100644 index 00000000..23d4b030 --- /dev/null +++ b/openequivariance/impl_torch/symmetric_contraction/__init__.py @@ -0,0 +1,5 @@ +from openequivariance.impl_torch.symmetric_contraction.symmetric_contraction import ( + SymmetricContraction, +) + +__all__ = ["SymmetricContraction"] diff --git a/openequivariance/torch/symmetric_contraction/symmetric_contraction.py b/openequivariance/impl_torch/symmetric_contraction/symmetric_contraction.py similarity index 99% rename from openequivariance/torch/symmetric_contraction/symmetric_contraction.py rename to openequivariance/impl_torch/symmetric_contraction/symmetric_contraction.py index 9790c2a2..d1f409d0 100644 --- a/openequivariance/torch/symmetric_contraction/symmetric_contraction.py +++ b/openequivariance/impl_torch/symmetric_contraction/symmetric_contraction.py @@ -1,7 +1,7 @@ # ruff: noqa : E402 import torch -from openequivariance.extlib import GroupMM_F32, GroupMM_F64 +from openequivariance.impl_torch.extlib import GroupMM_F32, GroupMM_F64 class GroupMM: diff --git a/openequivariance/torch/symmetric_contraction/__init__.py b/openequivariance/torch/symmetric_contraction/__init__.py deleted file mode 100644 index 43c9fdf4..00000000 --- a/openequivariance/torch/symmetric_contraction/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from openequivariance.torch.symmetric_contraction.symmetric_contraction import ( - SymmetricContraction, -) - -__all__ = ["SymmetricContraction"] diff --git a/tests/batch_test.py b/tests/batch_test.py index b277c80b..9cd032a3 100644 --- a/tests/batch_test.py +++ b/tests/batch_test.py @@ -3,7 +3,7 @@ import numpy as np import openequivariance as oeq -from openequivariance.core.TensorProduct import TensorProduct +from openequivariance.impl_torch.TensorProduct import TensorProduct from openequivariance.benchmark.correctness_utils import ( correctness_forward, correctness_backward, diff --git a/tests/benchmark.py b/tests/benchmark.py index 59f29300..7ef63b9c 100644 --- a/tests/benchmark.py +++ b/tests/benchmark.py @@ -11,14 +11,14 @@ import numpy as np from openequivariance.benchmark.logging_utils import getLogger -from openequivariance.extlib import DeviceProp -from openequivariance.torch.E3NNTensorProduct import ( +from openequivariance.impl_torch.extlib import DeviceProp +from openequivariance.impl_torch.E3NNTensorProduct import ( E3NNTensorProduct, E3NNTensorProductCompiledCUDAGraphs, E3NNTensorProductCompiledMaxAutotuneCUDAGraphs, ) -from openequivariance.torch.TensorProduct import TensorProduct -from openequivariance.torch.CUETensorProduct import CUETensorProduct +from openequivariance.impl_torch.TensorProduct import TensorProduct +from openequivariance.impl_torch.CUETensorProduct import CUETensorProduct from openequivariance.benchmark.TestBenchmarkSuite import ( TestBenchmarkSuite, TestDefinition, @@ -30,15 +30,15 @@ SingleInstruction, ) -from openequivariance.torch.TensorProductConv import ( +from openequivariance.impl_torch.TensorProductConv import ( TensorProductConvAtomic, TensorProductConvDeterministic, TensorProductConvKahan, TensorProductConvScatterSum, ) -from openequivariance.torch.CUEConv import CUEConv, CUEConvFused -from openequivariance.torch.FlashTPConv import FlashTPConv +from openequivariance.impl_torch.CUEConv import CUEConv, CUEConvFused +from openequivariance.impl_torch.FlashTPConv import FlashTPConv from openequivariance.benchmark.ConvBenchmarkSuite import ConvBenchmarkSuite, load_graph from openequivariance.benchmark.problems import ( diff --git a/tests/export_test.py b/tests/export_test.py index 35e52df8..9b64e2fe 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -11,7 +11,7 @@ from torch_geometric import EdgeIndex import importlib.resources -from openequivariance.torch.E3NNTensorProduct import E3NNTensorProduct +from openequivariance.impl_torch.E3NNTensorProduct import E3NNTensorProduct @pytest.fixture(scope="session") From 121cb42c2f1c8055082cd49adfff7790da66cec3 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 24 Nov 2025 16:59:55 -0800 Subject: [PATCH 010/116] Nested directory structure by one level. --- MANIFEST.in => openequivariance/MANIFEST.in | 4 +--- openequivariance/{ => openequivariance}/__init__.py | 0 .../{ => openequivariance}/benchmark/ConvBenchmarkSuite.py | 0 .../{ => openequivariance}/benchmark/TestBenchmarkSuite.py | 0 .../{ => openequivariance}/benchmark/benchmark_utils.py | 0 .../{ => openequivariance}/benchmark/correctness_utils.py | 0 .../{ => openequivariance}/benchmark/logging_utils.py | 0 .../{ => openequivariance}/benchmark/perf_metrics_utils.py | 0 .../{ => openequivariance}/benchmark/plotting/__init__.py | 0 .../benchmark/plotting/plot_convolution.py | 0 .../benchmark/plotting/plot_double_backward.py | 0 .../benchmark/plotting/plot_roofline.py | 0 .../{ => openequivariance}/benchmark/plotting/plot_uvu.py | 0 .../{ => openequivariance}/benchmark/plotting/plot_uvw.py | 0 .../benchmark/plotting/plotting_utils.py | 0 openequivariance/{ => openequivariance}/benchmark/problems.py | 0 .../{ => openequivariance}/benchmark/random_buffer_utils.py | 0 .../{ => openequivariance}/benchmark/tpp_creation_utils.py | 0 .../{ => openequivariance}/core/ComputationSchedule.py | 0 .../{ => openequivariance}/core/ConvolutionBase.py | 0 .../{ => openequivariance}/core/LoopUnrollConv.py | 0 openequivariance/{ => openequivariance}/core/LoopUnrollTP.py | 0 .../{ => openequivariance}/core/TensorProductBase.py | 0 openequivariance/{ => openequivariance}/core/dtype_enum.py | 0 openequivariance/{ => openequivariance}/core/e3nn_lite.py | 0 openequivariance/{ => openequivariance}/core/utils.py | 0 .../{ => openequivariance}/extension/CMakeLists.txt | 0 .../{ => openequivariance}/extension/convolution.hpp | 0 .../{ => openequivariance}/extension/generic_module.cpp | 0 .../{ => openequivariance}/extension/group_mm_cuda.hpp | 0 .../{ => openequivariance}/extension/group_mm_hip.hpp | 0 .../{ => openequivariance}/extension/libjax_tp_jit.cpp | 0 .../{ => openequivariance}/extension/libtorch_tp_jit.cpp | 0 .../{ => openequivariance}/extension/tensorproducts.hpp | 0 .../{ => openequivariance}/extension/test/CMakeLists.txt | 0 .../{ => openequivariance}/extension/test/load_jitscript.cpp | 0 .../{ => openequivariance}/extension/util/backend_cuda.hpp | 0 .../{ => openequivariance}/extension/util/backend_hip.hpp | 0 .../{ => openequivariance}/extension/util/buffer.hpp | 0 openequivariance/{ => openequivariance}/impl_torch/CUEConv.py | 0 .../{ => openequivariance}/impl_torch/CUETensorProduct.py | 0 .../{ => openequivariance}/impl_torch/E3NNConv.py | 0 .../{ => openequivariance}/impl_torch/E3NNTensorProduct.py | 0 .../{ => openequivariance}/impl_torch/FlashTPConv.py | 0 .../{ => openequivariance}/impl_torch/TensorProduct.py | 0 .../{ => openequivariance}/impl_torch/TensorProductConv.py | 0 .../{ => openequivariance}/impl_torch/extlib/.empty | 0 .../{ => openequivariance}/impl_torch/extlib/__init__.py | 0 .../impl_torch/symmetric_contraction/__init__.py | 0 .../impl_torch/symmetric_contraction/symmetric_contraction.py | 0 openequivariance/{ => openequivariance}/templates/common.cuh | 0 .../{ => openequivariance}/templates/jinja_utils.py | 0 .../{ => openequivariance}/templates/loop_unroll_batch.cuh | 0 .../templates/loop_unroll_conv_atomic.cuh | 0 .../{ => openequivariance}/templates/loop_unroll_conv_det.cuh | 0 .../{ => openequivariance}/templates/loop_unroll_tp.cuh | 0 .../{ => openequivariance}/templates/macros.jinja | 0 openequivariance/{ => openequivariance}/templates/wmm.cuh | 0 pyproject.toml => openequivariance/pyproject.toml | 4 ++-- {tests => openequivariance/tests}/batch_test.py | 0 {tests => openequivariance/tests}/benchmark.py | 0 {tests => openequivariance/tests}/conv_test.py | 0 {tests => openequivariance/tests}/examples_test.py | 0 {tests => openequivariance/tests}/export_test.py | 0 {tests => openequivariance/tests}/import_test.py | 0 {tests => openequivariance/tests}/input_validation_test.py | 0 {tests => openequivariance/tests}/mace_driver.py | 0 {tests => openequivariance/tests}/multidevice_test.py | 0 {tests => openequivariance/tests}/stream_test.py | 0 {tests => openequivariance/tests}/torch_determinism_test.py | 0 70 files changed, 3 insertions(+), 5 deletions(-) rename MANIFEST.in => openequivariance/MANIFEST.in (73%) rename openequivariance/{ => openequivariance}/__init__.py (100%) rename openequivariance/{ => openequivariance}/benchmark/ConvBenchmarkSuite.py (100%) rename openequivariance/{ => openequivariance}/benchmark/TestBenchmarkSuite.py (100%) rename openequivariance/{ => openequivariance}/benchmark/benchmark_utils.py (100%) rename openequivariance/{ => openequivariance}/benchmark/correctness_utils.py (100%) rename openequivariance/{ => openequivariance}/benchmark/logging_utils.py (100%) rename openequivariance/{ => openequivariance}/benchmark/perf_metrics_utils.py (100%) rename openequivariance/{ => openequivariance}/benchmark/plotting/__init__.py (100%) rename openequivariance/{ => openequivariance}/benchmark/plotting/plot_convolution.py (100%) rename openequivariance/{ => openequivariance}/benchmark/plotting/plot_double_backward.py (100%) rename openequivariance/{ => openequivariance}/benchmark/plotting/plot_roofline.py (100%) rename openequivariance/{ => openequivariance}/benchmark/plotting/plot_uvu.py (100%) rename openequivariance/{ => openequivariance}/benchmark/plotting/plot_uvw.py (100%) rename openequivariance/{ => openequivariance}/benchmark/plotting/plotting_utils.py (100%) rename openequivariance/{ => openequivariance}/benchmark/problems.py (100%) rename openequivariance/{ => openequivariance}/benchmark/random_buffer_utils.py (100%) rename openequivariance/{ => openequivariance}/benchmark/tpp_creation_utils.py (100%) rename openequivariance/{ => openequivariance}/core/ComputationSchedule.py (100%) rename openequivariance/{ => openequivariance}/core/ConvolutionBase.py (100%) rename openequivariance/{ => openequivariance}/core/LoopUnrollConv.py (100%) rename openequivariance/{ => openequivariance}/core/LoopUnrollTP.py (100%) rename openequivariance/{ => openequivariance}/core/TensorProductBase.py (100%) rename openequivariance/{ => openequivariance}/core/dtype_enum.py (100%) rename openequivariance/{ => openequivariance}/core/e3nn_lite.py (100%) rename openequivariance/{ => openequivariance}/core/utils.py (100%) rename openequivariance/{ => openequivariance}/extension/CMakeLists.txt (100%) rename openequivariance/{ => openequivariance}/extension/convolution.hpp (100%) rename openequivariance/{ => openequivariance}/extension/generic_module.cpp (100%) rename openequivariance/{ => openequivariance}/extension/group_mm_cuda.hpp (100%) rename openequivariance/{ => openequivariance}/extension/group_mm_hip.hpp (100%) rename openequivariance/{ => openequivariance}/extension/libjax_tp_jit.cpp (100%) rename openequivariance/{ => openequivariance}/extension/libtorch_tp_jit.cpp (100%) rename openequivariance/{ => openequivariance}/extension/tensorproducts.hpp (100%) rename openequivariance/{ => openequivariance}/extension/test/CMakeLists.txt (100%) rename openequivariance/{ => openequivariance}/extension/test/load_jitscript.cpp (100%) rename openequivariance/{ => openequivariance}/extension/util/backend_cuda.hpp (100%) rename openequivariance/{ => openequivariance}/extension/util/backend_hip.hpp (100%) rename openequivariance/{ => openequivariance}/extension/util/buffer.hpp (100%) rename openequivariance/{ => openequivariance}/impl_torch/CUEConv.py (100%) rename openequivariance/{ => openequivariance}/impl_torch/CUETensorProduct.py (100%) rename openequivariance/{ => openequivariance}/impl_torch/E3NNConv.py (100%) rename openequivariance/{ => openequivariance}/impl_torch/E3NNTensorProduct.py (100%) rename openequivariance/{ => openequivariance}/impl_torch/FlashTPConv.py (100%) rename openequivariance/{ => openequivariance}/impl_torch/TensorProduct.py (100%) rename openequivariance/{ => openequivariance}/impl_torch/TensorProductConv.py (100%) rename openequivariance/{ => openequivariance}/impl_torch/extlib/.empty (100%) rename openequivariance/{ => openequivariance}/impl_torch/extlib/__init__.py (100%) rename openequivariance/{ => openequivariance}/impl_torch/symmetric_contraction/__init__.py (100%) rename openequivariance/{ => openequivariance}/impl_torch/symmetric_contraction/symmetric_contraction.py (100%) rename openequivariance/{ => openequivariance}/templates/common.cuh (100%) rename openequivariance/{ => openequivariance}/templates/jinja_utils.py (100%) rename openequivariance/{ => openequivariance}/templates/loop_unroll_batch.cuh (100%) rename openequivariance/{ => openequivariance}/templates/loop_unroll_conv_atomic.cuh (100%) rename openequivariance/{ => openequivariance}/templates/loop_unroll_conv_det.cuh (100%) rename openequivariance/{ => openequivariance}/templates/loop_unroll_tp.cuh (100%) rename openequivariance/{ => openequivariance}/templates/macros.jinja (100%) rename openequivariance/{ => openequivariance}/templates/wmm.cuh (100%) rename pyproject.toml => openequivariance/pyproject.toml (94%) rename {tests => openequivariance/tests}/batch_test.py (100%) rename {tests => openequivariance/tests}/benchmark.py (100%) rename {tests => openequivariance/tests}/conv_test.py (100%) rename {tests => openequivariance/tests}/examples_test.py (100%) rename {tests => openequivariance/tests}/export_test.py (100%) rename {tests => openequivariance/tests}/import_test.py (100%) rename {tests => openequivariance/tests}/input_validation_test.py (100%) rename {tests => openequivariance/tests}/mace_driver.py (100%) rename {tests => openequivariance/tests}/multidevice_test.py (100%) rename {tests => openequivariance/tests}/stream_test.py (100%) rename {tests => openequivariance/tests}/torch_determinism_test.py (100%) diff --git a/MANIFEST.in b/openequivariance/MANIFEST.in similarity index 73% rename from MANIFEST.in rename to openequivariance/MANIFEST.in index f70b3e00..2632d44c 100644 --- a/MANIFEST.in +++ b/openequivariance/MANIFEST.in @@ -1,6 +1,4 @@ -include openequivariance/extlib/*.empty - -include openequivariance/templates/*.cuh +include templates/*.cuh include openequivariance/templates/*.jinja include openequivariance/extension/* diff --git a/openequivariance/__init__.py b/openequivariance/openequivariance/__init__.py similarity index 100% rename from openequivariance/__init__.py rename to openequivariance/openequivariance/__init__.py diff --git a/openequivariance/benchmark/ConvBenchmarkSuite.py b/openequivariance/openequivariance/benchmark/ConvBenchmarkSuite.py similarity index 100% rename from openequivariance/benchmark/ConvBenchmarkSuite.py rename to openequivariance/openequivariance/benchmark/ConvBenchmarkSuite.py diff --git a/openequivariance/benchmark/TestBenchmarkSuite.py b/openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py similarity index 100% rename from openequivariance/benchmark/TestBenchmarkSuite.py rename to openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py diff --git a/openequivariance/benchmark/benchmark_utils.py b/openequivariance/openequivariance/benchmark/benchmark_utils.py similarity index 100% rename from openequivariance/benchmark/benchmark_utils.py rename to openequivariance/openequivariance/benchmark/benchmark_utils.py diff --git a/openequivariance/benchmark/correctness_utils.py b/openequivariance/openequivariance/benchmark/correctness_utils.py similarity index 100% rename from openequivariance/benchmark/correctness_utils.py rename to openequivariance/openequivariance/benchmark/correctness_utils.py diff --git a/openequivariance/benchmark/logging_utils.py b/openequivariance/openequivariance/benchmark/logging_utils.py similarity index 100% rename from openequivariance/benchmark/logging_utils.py rename to openequivariance/openequivariance/benchmark/logging_utils.py diff --git a/openequivariance/benchmark/perf_metrics_utils.py b/openequivariance/openequivariance/benchmark/perf_metrics_utils.py similarity index 100% rename from openequivariance/benchmark/perf_metrics_utils.py rename to openequivariance/openequivariance/benchmark/perf_metrics_utils.py diff --git a/openequivariance/benchmark/plotting/__init__.py b/openequivariance/openequivariance/benchmark/plotting/__init__.py similarity index 100% rename from openequivariance/benchmark/plotting/__init__.py rename to openequivariance/openequivariance/benchmark/plotting/__init__.py diff --git a/openequivariance/benchmark/plotting/plot_convolution.py b/openequivariance/openequivariance/benchmark/plotting/plot_convolution.py similarity index 100% rename from openequivariance/benchmark/plotting/plot_convolution.py rename to openequivariance/openequivariance/benchmark/plotting/plot_convolution.py diff --git a/openequivariance/benchmark/plotting/plot_double_backward.py b/openequivariance/openequivariance/benchmark/plotting/plot_double_backward.py similarity index 100% rename from openequivariance/benchmark/plotting/plot_double_backward.py rename to openequivariance/openequivariance/benchmark/plotting/plot_double_backward.py diff --git a/openequivariance/benchmark/plotting/plot_roofline.py b/openequivariance/openequivariance/benchmark/plotting/plot_roofline.py similarity index 100% rename from openequivariance/benchmark/plotting/plot_roofline.py rename to openequivariance/openequivariance/benchmark/plotting/plot_roofline.py diff --git a/openequivariance/benchmark/plotting/plot_uvu.py b/openequivariance/openequivariance/benchmark/plotting/plot_uvu.py similarity index 100% rename from openequivariance/benchmark/plotting/plot_uvu.py rename to openequivariance/openequivariance/benchmark/plotting/plot_uvu.py diff --git a/openequivariance/benchmark/plotting/plot_uvw.py b/openequivariance/openequivariance/benchmark/plotting/plot_uvw.py similarity index 100% rename from openequivariance/benchmark/plotting/plot_uvw.py rename to openequivariance/openequivariance/benchmark/plotting/plot_uvw.py diff --git a/openequivariance/benchmark/plotting/plotting_utils.py b/openequivariance/openequivariance/benchmark/plotting/plotting_utils.py similarity index 100% rename from openequivariance/benchmark/plotting/plotting_utils.py rename to openequivariance/openequivariance/benchmark/plotting/plotting_utils.py diff --git a/openequivariance/benchmark/problems.py b/openequivariance/openequivariance/benchmark/problems.py similarity index 100% rename from openequivariance/benchmark/problems.py rename to openequivariance/openequivariance/benchmark/problems.py diff --git a/openequivariance/benchmark/random_buffer_utils.py b/openequivariance/openequivariance/benchmark/random_buffer_utils.py similarity index 100% rename from openequivariance/benchmark/random_buffer_utils.py rename to openequivariance/openequivariance/benchmark/random_buffer_utils.py diff --git a/openequivariance/benchmark/tpp_creation_utils.py b/openequivariance/openequivariance/benchmark/tpp_creation_utils.py similarity index 100% rename from openequivariance/benchmark/tpp_creation_utils.py rename to openequivariance/openequivariance/benchmark/tpp_creation_utils.py diff --git a/openequivariance/core/ComputationSchedule.py b/openequivariance/openequivariance/core/ComputationSchedule.py similarity index 100% rename from openequivariance/core/ComputationSchedule.py rename to openequivariance/openequivariance/core/ComputationSchedule.py diff --git a/openequivariance/core/ConvolutionBase.py b/openequivariance/openequivariance/core/ConvolutionBase.py similarity index 100% rename from openequivariance/core/ConvolutionBase.py rename to openequivariance/openequivariance/core/ConvolutionBase.py diff --git a/openequivariance/core/LoopUnrollConv.py b/openequivariance/openequivariance/core/LoopUnrollConv.py similarity index 100% rename from openequivariance/core/LoopUnrollConv.py rename to openequivariance/openequivariance/core/LoopUnrollConv.py diff --git a/openequivariance/core/LoopUnrollTP.py b/openequivariance/openequivariance/core/LoopUnrollTP.py similarity index 100% rename from openequivariance/core/LoopUnrollTP.py rename to openequivariance/openequivariance/core/LoopUnrollTP.py diff --git a/openequivariance/core/TensorProductBase.py b/openequivariance/openequivariance/core/TensorProductBase.py similarity index 100% rename from openequivariance/core/TensorProductBase.py rename to openequivariance/openequivariance/core/TensorProductBase.py diff --git a/openequivariance/core/dtype_enum.py b/openequivariance/openequivariance/core/dtype_enum.py similarity index 100% rename from openequivariance/core/dtype_enum.py rename to openequivariance/openequivariance/core/dtype_enum.py diff --git a/openequivariance/core/e3nn_lite.py b/openequivariance/openequivariance/core/e3nn_lite.py similarity index 100% rename from openequivariance/core/e3nn_lite.py rename to openequivariance/openequivariance/core/e3nn_lite.py diff --git a/openequivariance/core/utils.py b/openequivariance/openequivariance/core/utils.py similarity index 100% rename from openequivariance/core/utils.py rename to openequivariance/openequivariance/core/utils.py diff --git a/openequivariance/extension/CMakeLists.txt b/openequivariance/openequivariance/extension/CMakeLists.txt similarity index 100% rename from openequivariance/extension/CMakeLists.txt rename to openequivariance/openequivariance/extension/CMakeLists.txt diff --git a/openequivariance/extension/convolution.hpp b/openequivariance/openequivariance/extension/convolution.hpp similarity index 100% rename from openequivariance/extension/convolution.hpp rename to openequivariance/openequivariance/extension/convolution.hpp diff --git a/openequivariance/extension/generic_module.cpp b/openequivariance/openequivariance/extension/generic_module.cpp similarity index 100% rename from openequivariance/extension/generic_module.cpp rename to openequivariance/openequivariance/extension/generic_module.cpp diff --git a/openequivariance/extension/group_mm_cuda.hpp b/openequivariance/openequivariance/extension/group_mm_cuda.hpp similarity index 100% rename from openequivariance/extension/group_mm_cuda.hpp rename to openequivariance/openequivariance/extension/group_mm_cuda.hpp diff --git a/openequivariance/extension/group_mm_hip.hpp b/openequivariance/openequivariance/extension/group_mm_hip.hpp similarity index 100% rename from openequivariance/extension/group_mm_hip.hpp rename to openequivariance/openequivariance/extension/group_mm_hip.hpp diff --git a/openequivariance/extension/libjax_tp_jit.cpp b/openequivariance/openequivariance/extension/libjax_tp_jit.cpp similarity index 100% rename from openequivariance/extension/libjax_tp_jit.cpp rename to openequivariance/openequivariance/extension/libjax_tp_jit.cpp diff --git a/openequivariance/extension/libtorch_tp_jit.cpp b/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp similarity index 100% rename from openequivariance/extension/libtorch_tp_jit.cpp rename to openequivariance/openequivariance/extension/libtorch_tp_jit.cpp diff --git a/openequivariance/extension/tensorproducts.hpp b/openequivariance/openequivariance/extension/tensorproducts.hpp similarity index 100% rename from openequivariance/extension/tensorproducts.hpp rename to openequivariance/openequivariance/extension/tensorproducts.hpp diff --git a/openequivariance/extension/test/CMakeLists.txt b/openequivariance/openequivariance/extension/test/CMakeLists.txt similarity index 100% rename from openequivariance/extension/test/CMakeLists.txt rename to openequivariance/openequivariance/extension/test/CMakeLists.txt diff --git a/openequivariance/extension/test/load_jitscript.cpp b/openequivariance/openequivariance/extension/test/load_jitscript.cpp similarity index 100% rename from openequivariance/extension/test/load_jitscript.cpp rename to openequivariance/openequivariance/extension/test/load_jitscript.cpp diff --git a/openequivariance/extension/util/backend_cuda.hpp b/openequivariance/openequivariance/extension/util/backend_cuda.hpp similarity index 100% rename from openequivariance/extension/util/backend_cuda.hpp rename to openequivariance/openequivariance/extension/util/backend_cuda.hpp diff --git a/openequivariance/extension/util/backend_hip.hpp b/openequivariance/openequivariance/extension/util/backend_hip.hpp similarity index 100% rename from openequivariance/extension/util/backend_hip.hpp rename to openequivariance/openequivariance/extension/util/backend_hip.hpp diff --git a/openequivariance/extension/util/buffer.hpp b/openequivariance/openequivariance/extension/util/buffer.hpp similarity index 100% rename from openequivariance/extension/util/buffer.hpp rename to openequivariance/openequivariance/extension/util/buffer.hpp diff --git a/openequivariance/impl_torch/CUEConv.py b/openequivariance/openequivariance/impl_torch/CUEConv.py similarity index 100% rename from openequivariance/impl_torch/CUEConv.py rename to openequivariance/openequivariance/impl_torch/CUEConv.py diff --git a/openequivariance/impl_torch/CUETensorProduct.py b/openequivariance/openequivariance/impl_torch/CUETensorProduct.py similarity index 100% rename from openequivariance/impl_torch/CUETensorProduct.py rename to openequivariance/openequivariance/impl_torch/CUETensorProduct.py diff --git a/openequivariance/impl_torch/E3NNConv.py b/openequivariance/openequivariance/impl_torch/E3NNConv.py similarity index 100% rename from openequivariance/impl_torch/E3NNConv.py rename to openequivariance/openequivariance/impl_torch/E3NNConv.py diff --git a/openequivariance/impl_torch/E3NNTensorProduct.py b/openequivariance/openequivariance/impl_torch/E3NNTensorProduct.py similarity index 100% rename from openequivariance/impl_torch/E3NNTensorProduct.py rename to openequivariance/openequivariance/impl_torch/E3NNTensorProduct.py diff --git a/openequivariance/impl_torch/FlashTPConv.py b/openequivariance/openequivariance/impl_torch/FlashTPConv.py similarity index 100% rename from openequivariance/impl_torch/FlashTPConv.py rename to openequivariance/openequivariance/impl_torch/FlashTPConv.py diff --git a/openequivariance/impl_torch/TensorProduct.py b/openequivariance/openequivariance/impl_torch/TensorProduct.py similarity index 100% rename from openequivariance/impl_torch/TensorProduct.py rename to openequivariance/openequivariance/impl_torch/TensorProduct.py diff --git a/openequivariance/impl_torch/TensorProductConv.py b/openequivariance/openequivariance/impl_torch/TensorProductConv.py similarity index 100% rename from openequivariance/impl_torch/TensorProductConv.py rename to openequivariance/openequivariance/impl_torch/TensorProductConv.py diff --git a/openequivariance/impl_torch/extlib/.empty b/openequivariance/openequivariance/impl_torch/extlib/.empty similarity index 100% rename from openequivariance/impl_torch/extlib/.empty rename to openequivariance/openequivariance/impl_torch/extlib/.empty diff --git a/openequivariance/impl_torch/extlib/__init__.py b/openequivariance/openequivariance/impl_torch/extlib/__init__.py similarity index 100% rename from openequivariance/impl_torch/extlib/__init__.py rename to openequivariance/openequivariance/impl_torch/extlib/__init__.py diff --git a/openequivariance/impl_torch/symmetric_contraction/__init__.py b/openequivariance/openequivariance/impl_torch/symmetric_contraction/__init__.py similarity index 100% rename from openequivariance/impl_torch/symmetric_contraction/__init__.py rename to openequivariance/openequivariance/impl_torch/symmetric_contraction/__init__.py diff --git a/openequivariance/impl_torch/symmetric_contraction/symmetric_contraction.py b/openequivariance/openequivariance/impl_torch/symmetric_contraction/symmetric_contraction.py similarity index 100% rename from openequivariance/impl_torch/symmetric_contraction/symmetric_contraction.py rename to openequivariance/openequivariance/impl_torch/symmetric_contraction/symmetric_contraction.py diff --git a/openequivariance/templates/common.cuh b/openequivariance/openequivariance/templates/common.cuh similarity index 100% rename from openequivariance/templates/common.cuh rename to openequivariance/openequivariance/templates/common.cuh diff --git a/openequivariance/templates/jinja_utils.py b/openequivariance/openequivariance/templates/jinja_utils.py similarity index 100% rename from openequivariance/templates/jinja_utils.py rename to openequivariance/openequivariance/templates/jinja_utils.py diff --git a/openequivariance/templates/loop_unroll_batch.cuh b/openequivariance/openequivariance/templates/loop_unroll_batch.cuh similarity index 100% rename from openequivariance/templates/loop_unroll_batch.cuh rename to openequivariance/openequivariance/templates/loop_unroll_batch.cuh diff --git a/openequivariance/templates/loop_unroll_conv_atomic.cuh b/openequivariance/openequivariance/templates/loop_unroll_conv_atomic.cuh similarity index 100% rename from openequivariance/templates/loop_unroll_conv_atomic.cuh rename to openequivariance/openequivariance/templates/loop_unroll_conv_atomic.cuh diff --git a/openequivariance/templates/loop_unroll_conv_det.cuh b/openequivariance/openequivariance/templates/loop_unroll_conv_det.cuh similarity index 100% rename from openequivariance/templates/loop_unroll_conv_det.cuh rename to openequivariance/openequivariance/templates/loop_unroll_conv_det.cuh diff --git a/openequivariance/templates/loop_unroll_tp.cuh b/openequivariance/openequivariance/templates/loop_unroll_tp.cuh similarity index 100% rename from openequivariance/templates/loop_unroll_tp.cuh rename to openequivariance/openequivariance/templates/loop_unroll_tp.cuh diff --git a/openequivariance/templates/macros.jinja b/openequivariance/openequivariance/templates/macros.jinja similarity index 100% rename from openequivariance/templates/macros.jinja rename to openequivariance/openequivariance/templates/macros.jinja diff --git a/openequivariance/templates/wmm.cuh b/openequivariance/openequivariance/templates/wmm.cuh similarity index 100% rename from openequivariance/templates/wmm.cuh rename to openequivariance/openequivariance/templates/wmm.cuh diff --git a/pyproject.toml b/openequivariance/pyproject.toml similarity index 94% rename from pyproject.toml rename to openequivariance/pyproject.toml index 6f7b8771..8551501e 100644 --- a/pyproject.toml +++ b/openequivariance/pyproject.toml @@ -64,10 +64,10 @@ dev = [ ] [tool.setuptools.packages.find] -include = ["openequivariance*"] +include = ["."] [tool.setuptools_scm] -# Presence of this section necessary, even if empty +root = ".." [tool.pytest.ini_options] addopts = [ diff --git a/tests/batch_test.py b/openequivariance/tests/batch_test.py similarity index 100% rename from tests/batch_test.py rename to openequivariance/tests/batch_test.py diff --git a/tests/benchmark.py b/openequivariance/tests/benchmark.py similarity index 100% rename from tests/benchmark.py rename to openequivariance/tests/benchmark.py diff --git a/tests/conv_test.py b/openequivariance/tests/conv_test.py similarity index 100% rename from tests/conv_test.py rename to openequivariance/tests/conv_test.py diff --git a/tests/examples_test.py b/openequivariance/tests/examples_test.py similarity index 100% rename from tests/examples_test.py rename to openequivariance/tests/examples_test.py diff --git a/tests/export_test.py b/openequivariance/tests/export_test.py similarity index 100% rename from tests/export_test.py rename to openequivariance/tests/export_test.py diff --git a/tests/import_test.py b/openequivariance/tests/import_test.py similarity index 100% rename from tests/import_test.py rename to openequivariance/tests/import_test.py diff --git a/tests/input_validation_test.py b/openequivariance/tests/input_validation_test.py similarity index 100% rename from tests/input_validation_test.py rename to openequivariance/tests/input_validation_test.py diff --git a/tests/mace_driver.py b/openequivariance/tests/mace_driver.py similarity index 100% rename from tests/mace_driver.py rename to openequivariance/tests/mace_driver.py diff --git a/tests/multidevice_test.py b/openequivariance/tests/multidevice_test.py similarity index 100% rename from tests/multidevice_test.py rename to openequivariance/tests/multidevice_test.py diff --git a/tests/stream_test.py b/openequivariance/tests/stream_test.py similarity index 100% rename from tests/stream_test.py rename to openequivariance/tests/stream_test.py diff --git a/tests/torch_determinism_test.py b/openequivariance/tests/torch_determinism_test.py similarity index 100% rename from tests/torch_determinism_test.py rename to openequivariance/tests/torch_determinism_test.py From 45348a20acf7a88fc368fd529bc1e60f7e907d53 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 24 Nov 2025 18:15:50 -0800 Subject: [PATCH 011/116] Temp commit. --- openequivariance/openequivariance/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openequivariance/openequivariance/__init__.py b/openequivariance/openequivariance/__init__.py index 04527921..b3bee873 100644 --- a/openequivariance/openequivariance/__init__.py +++ b/openequivariance/openequivariance/__init__.py @@ -34,7 +34,7 @@ def _check_package_editable(): return json.loads(direct_url).get("dir_info", {}).get("editable", False) -_editable_install_output_path = Path(__file__).parent.parent / "outputs" +_editable_install_output_path = Path(__file__).parent.parent.parent / "outputs" def torch_ext_so_path(): From 583c24936bed0e5308a6ce8217764058816492a3 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 24 Nov 2025 20:29:12 -0800 Subject: [PATCH 012/116] Got the editable install working again. --- openequivariance/MANIFEST.in | 7 +------ openequivariance/pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/openequivariance/MANIFEST.in b/openequivariance/MANIFEST.in index 2632d44c..ab5b72e7 100644 --- a/openequivariance/MANIFEST.in +++ b/openequivariance/MANIFEST.in @@ -1,7 +1,2 @@ include templates/*.cuh -include openequivariance/templates/*.jinja - -include openequivariance/extension/* -include openequivariance/extension/convolution/* -include openequivariance/extension/tensorproducts/* -include openequivariance/extension/util/* \ No newline at end of file +include openequivariance/templates/*.jinja \ No newline at end of file diff --git a/openequivariance/pyproject.toml b/openequivariance/pyproject.toml index 8551501e..56c49b54 100644 --- a/openequivariance/pyproject.toml +++ b/openequivariance/pyproject.toml @@ -64,7 +64,7 @@ dev = [ ] [tool.setuptools.packages.find] -include = ["."] +include = ["openequivariance*"] [tool.setuptools_scm] root = ".." From e1aa9518cb9b61de0b98204c7c64385b859d6f1f Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 24 Nov 2025 22:28:46 -0800 Subject: [PATCH 013/116] Extension module in progress. --- openequivariance/openequivariance/__init__.py | 6 +++ openequivariance/pyproject.toml | 11 ++-- .../CMakeLists.txt | 23 ++++++--- openequivariance_extjax/pyproject.toml | 50 +++++++++++++++++++ .../src}/libjax_tp_jit.cpp | 0 5 files changed, 79 insertions(+), 11 deletions(-) rename {openequivariance/openequivariance/extension => openequivariance_extjax}/CMakeLists.txt (61%) create mode 100644 openequivariance_extjax/pyproject.toml rename {openequivariance/openequivariance/extension => openequivariance_extjax/src}/libjax_tp_jit.cpp (100%) diff --git a/openequivariance/openequivariance/__init__.py b/openequivariance/openequivariance/__init__.py index b3bee873..2097a33b 100644 --- a/openequivariance/openequivariance/__init__.py +++ b/openequivariance/openequivariance/__init__.py @@ -45,6 +45,12 @@ def torch_ext_so_path(): return openequivariance.impl_torch.extlib.torch_module.__file__ +def extension_source_path(): + """ + :returns: Path to the source code of the C++ extension. + """ + return str(Path(__file__).parent / "extension") + torch.serialization.add_safe_globals( [ TensorProduct, diff --git a/openequivariance/pyproject.toml b/openequivariance/pyproject.toml index 56c49b54..5598dde5 100644 --- a/openequivariance/pyproject.toml +++ b/openequivariance/pyproject.toml @@ -17,10 +17,9 @@ dependencies = [ "setuptools", "ninja", "jinja2", - "numpy", - "torch >= 2.4", + "numpy" ] -readme = "README.md" +readme = "../README.md" license = "BSD-3-Clause" license-files = ["LICENSE"] @@ -49,6 +48,12 @@ bench = [ "cuequivariance-ops-torch-cu12", ] +jax = [ + "jax[cuda12]", + "nanobind", + "scikit-build-core" +] + dev = [ "e3nn", "pre-commit", diff --git a/openequivariance/openequivariance/extension/CMakeLists.txt b/openequivariance_extjax/CMakeLists.txt similarity index 61% rename from openequivariance/openequivariance/extension/CMakeLists.txt rename to openequivariance_extjax/CMakeLists.txt index 15f7fa48..f27cb350 100644 --- a/openequivariance/openequivariance/extension/CMakeLists.txt +++ b/openequivariance_extjax/CMakeLists.txt @@ -8,31 +8,38 @@ execute_process( "from jax import ffi; print(ffi.include_dir())" OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE XLA_DIR ) - message(STATUS "XLA include directory: ${XLA_DIR}") execute_process( COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE nanobind_ROOT ) +message(STATUS "nanobind cmake directory: ${nanobind_ROOT}") + +execute_process( + COMMAND "${Python_EXECUTABLE}" "-c" + "import openequivariance; print(openequivariance.extension_source_path())" + OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE HEADER_DIR +) +message(STATUS "OpenEquivariance extension source directory: ${HEADER_DIR}") find_package(nanobind CONFIG REQUIRED) set(OEQ_JAX_SOURCES - libjax_tp_jit.cpp + src/libjax_tp_jit.cpp ) set(OEQ_JAX_HEADERS - convolution.hpp - tensorproducts.hpp - util/backend_cuda.hpp - util/backend_hip.hpp - util/buffer.hpp + ${HEADER_DIR}/convolution.hpp + ${HEADER_DIR}/tensorproducts.hpp + ${HEADER_DIR}/util/backend_cuda.hpp + ${HEADER_DIR}/util/backend_hip.hpp + ${HEADER_DIR}/util/buffer.hpp ) nanobind_add_module(oeq_jax_extension NB_STATIC ${OEQ_JAX_SOURCES} ${OEQ_JAX_HEADERS}) -target_include_directories(oeq_jax_extension PUBLIC ${XLA_DIR}) +target_include_directories(oeq_jax_extension PUBLIC ${XLA_DIR} ${HEADER_DIR}) set_target_properties(oeq_jax_extension PROPERTIES CUDA_STANDARD 17 POSITION_INDEPENDENT_CODE ON) target_compile_options(oeq_jax_extension PRIVATE -Wno-attributes -Wno-return-type) diff --git a/openequivariance_extjax/pyproject.toml b/openequivariance_extjax/pyproject.toml new file mode 100644 index 00000000..17165af7 --- /dev/null +++ b/openequivariance_extjax/pyproject.toml @@ -0,0 +1,50 @@ +[build-system] +requires = [ + "setuptools-scm", + "scikit-build-core", + "nanobind" +] +build-backend = "scikit_build_core.build" + +[project] +name = "openequivariance_extjax" +dynamic = ["version"] +authors = [ + { name="Austin Glover" }, + { name="Vivek Bharadwaj" }, + { name="Aydin Buluc" }, + { name="James Demmel" } +] +description = "JAX C++ Extension for OpenEquivariance" +requires-python = ">=3.10" + +dependencies = [] +readme = "../README.md" + +#license = "BSD-3-Clause" +#license-files = ["../LICENSE"] + +classifiers = [ + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", +] + +[project.urls] +homepage = "https://passionlab.github.io/OpenEquivariance/" +source = "https://github.com/PASSIONLab/OpenEquivariance" +issues = "https://github.com/PASSIONLab/OpenEquivariance/issues" + + +[tool.setuptools_scm] +root = ".." + +[tool.pytest.ini_options] +addopts = [ + "--import-mode=importlib", +] + +[tool.ruff] +lint.ignore = ["E741"] \ No newline at end of file diff --git a/openequivariance/openequivariance/extension/libjax_tp_jit.cpp b/openequivariance_extjax/src/libjax_tp_jit.cpp similarity index 100% rename from openequivariance/openequivariance/extension/libjax_tp_jit.cpp rename to openequivariance_extjax/src/libjax_tp_jit.cpp From d7aa1ea7f32bee95ef3acaeb8af6aeb69dfaeada Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 24 Nov 2025 23:39:39 -0800 Subject: [PATCH 014/116] More things working. --- openequivariance_extjax/CMakeLists.txt | 13 ++++++------- openequivariance_extjax/src/libjax_tp_jit.cpp | 2 +- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/openequivariance_extjax/CMakeLists.txt b/openequivariance_extjax/CMakeLists.txt index f27cb350..e1252f3d 100644 --- a/openequivariance_extjax/CMakeLists.txt +++ b/openequivariance_extjax/CMakeLists.txt @@ -1,5 +1,5 @@ cmake_minimum_required(VERSION 3.15...3.30) -project(oeq_jax_extension LANGUAGES CXX CUDA) # TODO: Add HIP support +project(${SKBUILD_PROJECT_NAME} LANGUAGES CXX CUDA) # TODO: Add HIP support find_package(Python 3.11 REQUIRED COMPONENTS Interpreter Development.Module) @@ -37,10 +37,9 @@ set(OEQ_JAX_HEADERS ${HEADER_DIR}/util/buffer.hpp ) -nanobind_add_module(oeq_jax_extension NB_STATIC ${OEQ_JAX_SOURCES} ${OEQ_JAX_HEADERS}) +nanobind_add_module(openequivariance_extjax NB_STATIC ${OEQ_JAX_SOURCES} ${OEQ_JAX_HEADERS}) -target_include_directories(oeq_jax_extension PUBLIC ${XLA_DIR} ${HEADER_DIR}) -set_target_properties(oeq_jax_extension PROPERTIES CUDA_STANDARD 17 POSITION_INDEPENDENT_CODE ON) -target_compile_options(oeq_jax_extension PRIVATE -Wno-attributes -Wno-return-type) - -install(TARGETS oeq_jax_extension LIBRARY DESTINATION lib) \ No newline at end of file +target_include_directories(openequivariance_extjax PUBLIC ${XLA_DIR} ${HEADER_DIR}) +set_target_properties(openequivariance_extjax PROPERTIES CUDA_STANDARD 17 POSITION_INDEPENDENT_CODE ON) +target_compile_options(openequivariance_extjax PRIVATE -Wno-attributes -Wno-return-type) +install(TARGETS openequivariance_extjax LIBRARY DESTINATION .) \ No newline at end of file diff --git a/openequivariance_extjax/src/libjax_tp_jit.cpp b/openequivariance_extjax/src/libjax_tp_jit.cpp index 3cf922ad..b48e3e77 100644 --- a/openequivariance_extjax/src/libjax_tp_jit.cpp +++ b/openequivariance_extjax/src/libjax_tp_jit.cpp @@ -104,7 +104,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Ret>(), {xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled -NB_MODULE(oeq_jax_extension, m) { +NB_MODULE(openequivariance_extjax, m) { m.def("registrations", []() { nb::dict registrations; registrations["tp_forward"] = nb::capsule(reinterpret_cast(tp_forward)); From d2cec505db96227fd79f7a610a92e7e4794daf2b Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Tue, 25 Nov 2025 00:09:29 -0800 Subject: [PATCH 015/116] Began putting together a test rig. --- .../impl_jax/TensorProduct.py | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 openequivariance/openequivariance/impl_jax/TensorProduct.py diff --git a/openequivariance/openequivariance/impl_jax/TensorProduct.py b/openequivariance/openequivariance/impl_jax/TensorProduct.py new file mode 100644 index 00000000..f572e71e --- /dev/null +++ b/openequivariance/openequivariance/impl_jax/TensorProduct.py @@ -0,0 +1,47 @@ +import numpy as np + +import jax +import openequivariance_extjax as oeq_extjax +import hashlib +from openequivariance.core.e3nn_lite import TPProblem + +for name, target in oeq_extjax.registrations().items(): + jax.ffi.register_ffi_target(name, target, platform="CUDA") + +def hash_attributes(attrs): + m = hashlib.sha256() + for key in sorted(attrs.keys()): + m.update(attrs[key].__repr__().encode("utf-8")) + + hash = int(m.hexdigest()[-16:], 16) + attrs["hash"] = hash + +class TensorProduct: + def __init__(self, problem: TPProblem): + self.problem = problem + + self.kernel = "BLAH" + self.forward_config = {"example_key": 42} + self.backward_config = {} + self.double_backward_config = {} + self.kernel_prop = {} + self.attrs = { + "kernel": self.kernel, + "forward_config": self.forward_config, + "backward_config": self.backward_config, + "double_backward_config": self.double_backward_config, + "kernel_prop": self.kernel_prop + } + hash_attributes(self.attrs) + + self.forward_call = jax.ffi.ffi_call( + "tp_forward", + jax.ShapeDtypeStruct((), jax.numpy.int32)) + + self.forward_call(**self.attrs) + + +if __name__ == "__main__": + tp_problem = None + tensor_product = TensorProduct(tp_problem) + print("COMPLETE!") \ No newline at end of file From 6a7571135ab9d78abab2f19945e7ca7f6bdb1fc2 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Tue, 25 Nov 2025 02:04:58 -0800 Subject: [PATCH 016/116] Things starting to work. --- .../openequivariance/impl_jax/TensorProduct.py | 10 ++++++---- openequivariance_extjax/CMakeLists.txt | 17 ++++++++++++++++- openequivariance_extjax/src/libjax_tp_jit.cpp | 7 +++++-- 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/openequivariance/openequivariance/impl_jax/TensorProduct.py b/openequivariance/openequivariance/impl_jax/TensorProduct.py index f572e71e..67eb2a1d 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProduct.py +++ b/openequivariance/openequivariance/impl_jax/TensorProduct.py @@ -3,25 +3,27 @@ import jax import openequivariance_extjax as oeq_extjax import hashlib -from openequivariance.core.e3nn_lite import TPProblem +#from openequivariance.core.e3nn_lite import TPProblem for name, target in oeq_extjax.registrations().items(): + print(name, target) jax.ffi.register_ffi_target(name, target, platform="CUDA") def hash_attributes(attrs): m = hashlib.sha256() + for key in sorted(attrs.keys()): m.update(attrs[key].__repr__().encode("utf-8")) - hash = int(m.hexdigest()[-16:], 16) + hash = int(m.hexdigest()[:16], 16) >> 1 attrs["hash"] = hash class TensorProduct: - def __init__(self, problem: TPProblem): + def __init__(self, problem): self.problem = problem self.kernel = "BLAH" - self.forward_config = {"example_key": 42} + self.forward_config = {"num_blocks": 42, "num_threads": 256, "smem": 8192 } self.backward_config = {} self.double_backward_config = {} self.kernel_prop = {} diff --git a/openequivariance_extjax/CMakeLists.txt b/openequivariance_extjax/CMakeLists.txt index e1252f3d..f40272f2 100644 --- a/openequivariance_extjax/CMakeLists.txt +++ b/openequivariance_extjax/CMakeLists.txt @@ -1,7 +1,8 @@ cmake_minimum_required(VERSION 3.15...3.30) -project(${SKBUILD_PROJECT_NAME} LANGUAGES CXX CUDA) # TODO: Add HIP support +project(${SKBUILD_PROJECT_NAME} LANGUAGES CXX) # TODO: Add HIP support find_package(Python 3.11 REQUIRED COMPONENTS Interpreter Development.Module) +find_package(CUDAToolkit REQUIRED) execute_process( COMMAND "${Python_EXECUTABLE}" "-c" @@ -42,4 +43,18 @@ nanobind_add_module(openequivariance_extjax NB_STATIC ${OEQ_JAX_SOURCES} ${OEQ_J target_include_directories(openequivariance_extjax PUBLIC ${XLA_DIR} ${HEADER_DIR}) set_target_properties(openequivariance_extjax PROPERTIES CUDA_STANDARD 17 POSITION_INDEPENDENT_CODE ON) target_compile_options(openequivariance_extjax PRIVATE -Wno-attributes -Wno-return-type) + +get_target_property(CUDA_LIB_DIR CUDA::nvrtc IMPORTED_LOCATION) +get_filename_component(CUDA_LIB_DIR ${CUDA_LIB_DIR} DIRECTORY) + +set_target_properties(openequivariance_extjax PROPERTIES + BUILD_RPATH "${CUDA_LIB_DIR}" + INSTALL_RPATH "${CUDA_LIB_DIR}" +) + +target_link_libraries(openequivariance_extjax PRIVATE + CUDA::cudart + CUDA::cuda_driver + CUDA::nvrtc) + install(TARGETS openequivariance_extjax LIBRARY DESTINATION .) \ No newline at end of file diff --git a/openequivariance_extjax/src/libjax_tp_jit.cpp b/openequivariance_extjax/src/libjax_tp_jit.cpp index b48e3e77..ef48aec2 100644 --- a/openequivariance_extjax/src/libjax_tp_jit.cpp +++ b/openequivariance_extjax/src/libjax_tp_jit.cpp @@ -68,14 +68,17 @@ JITTPImpl* compile_kernel_with_caching(std::string_view kernel, result = it->second.get(); } else { - auto jit_tp_impl = std::make_unique>( + auto result = parse_ffi_dict(forward_config, launch_config_keys); + + cout << result["smem"] << endl; + /*auto jit_tp_impl = std::make_unique>( std::string(kernel), parse_ffi_dict(forward_config, launch_config_keys), parse_ffi_dict(backward_config, launch_config_keys), parse_ffi_dict(double_backward_config, launch_config_keys), parse_ffi_dict(kernel_prop, kernel_prop_keys)); result = jit_tp_impl.get(); - kernel_cache.insert({hash, std::move(jit_tp_impl)}); + kernel_cache.insert({hash, std::move(jit_tp_impl)});*/ } } return result; From f07592fbaaf7e0f94414c964512408191f583420 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Tue, 25 Nov 2025 12:58:05 -0800 Subject: [PATCH 017/116] Made LoopUnrollTP generic. --- .../openequivariance/core/LoopUnrollTP.py | 186 ++---------------- .../impl_torch/TensorProduct.py | 153 +++++++++++++- 2 files changed, 163 insertions(+), 176 deletions(-) diff --git a/openequivariance/openequivariance/core/LoopUnrollTP.py b/openequivariance/openequivariance/core/LoopUnrollTP.py index 84bebc42..c04c6846 100644 --- a/openequivariance/openequivariance/core/LoopUnrollTP.py +++ b/openequivariance/openequivariance/core/LoopUnrollTP.py @@ -1,27 +1,21 @@ import numpy as np -import openequivariance.impl_torch.extlib as extlib from openequivariance.templates.jinja_utils import get_jinja_environment from openequivariance.core.ComputationSchedule import ComputationSchedule - -from openequivariance.core.dtype_enum import dtype_to_enum from openequivariance.core.TensorProductBase import TensorProductBase -from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.core.dtype_enum import dtype_to_enum + from openequivariance.core.utils import ( filter_and_analyze_problem, count_cg_non_zero, ) -logger = getLogger() - - class LoopUnrollTP(TensorProductBase): - def __init__(self, config, torch_op=True): + def __init__(self, config, dp, postprocess_kernel, torch_op=True): super().__init__(config, torch_op=torch_op) env = get_jinja_environment() template = env.get_template("loop_unroll_batch.cuh") - dp = extlib.DeviceProp(0) analysis = filter_and_analyze_problem(config) self.is_uvw = analysis["is_uvw"] @@ -88,7 +82,7 @@ def generate_double_backward_schedule(warps_per_block): "Tensor product schedule generation failed, shared memory inadequate!" ) - self.jit_kernel = extlib.postprocess_kernel( + self.jit_kernel = postprocess_kernel( template.render( forward_schedule=self.forward_schedule, backward_schedule=self.backward_schedule, @@ -96,163 +90,17 @@ def generate_double_backward_schedule(warps_per_block): ) ) - # with open("scratch.txt", "w") as f: - # f.write(self.jit_kernel) - - internal_cls = None - if self.torch_op and extlib.TORCH_COMPILE: - global torch - import torch - - internal_cls = torch.classes.libtorch_tp_jit.TorchJITProduct - else: - internal_cls = extlib.JITTPImpl - - logger.info("Starting kernel compiler...") - self.internal = internal_cls( - self.jit_kernel, - vars(self.forward_schedule.launch_config), - vars(self.backward_schedule.launch_config), - vars(self.double_backward_schedule.launch_config), - { - "L1_dim": self.L1.dim, - "L2_dim": self.L2.dim, - "L3_dim": self.L3.dim, - "weight_numel": self.config.weight_numel, - "shared_weights": int(self.config.shared_weights), - "opt_level": 3, - "irrep_dtype": dtype_to_enum[self.config.irrep_dtype], - "weight_dtype": dtype_to_enum[self.config.weight_dtype], - }, - ) - logger.info("Kernel compiled!") - logger.info(f"Kernel File Size: {len(self.jit_kernel) // 1024} KB") - - def reorder_weights_from_e3nn(self, weights, has_batch_dim=True): - return self.forward_schedule.reorder_weights_from_e3nn(weights, has_batch_dim) - - def reorder_weights_to_e3nn(self, weights, has_batch_dim=True): - return self.forward_schedule.reorder_weights_to_e3nn(weights, has_batch_dim) - - @classmethod - def register_torch_fakes(cls): - global torch - import torch - - @torch._library.register_fake_class("libtorch_tp_jit::TorchJITProduct") - class TorchJITProduct: - def __init__( - self, - kernel_plaintext: str, - fwd_config: dict[str, int], - bwd_config: dict[str, int], - dbl_bwd_config: dict[str, int], - kernel_dims: dict[str, int], - ) -> None: - ( - self.kernel_plaintext, - self.fwd_config, - self.bwd_config, - self.dbl_bwd_config, - self.kernel_dims, - ) = ( - kernel_plaintext, - fwd_config, - bwd_config, - dbl_bwd_config, - kernel_dims, - ) - - @classmethod - def __obj_unflatten__(cls, flattened_product): - return cls(**dict(flattened_product)) - - def __len__(self): - return 0 - - def __setstate__(self, state): - self.kernel_plaintext = state["kernel_plaintext"] - self.fwd_config = state["fwd_config"] - self.bwd_config = state["bwd_config"] - self.dbl_bwd_config = state["dbl_bwd_config"] - self.kernel_dims = state["kernel_dims"] - - def exec_tensor_product_rawptr(*args, **kwargs): - pass - - def backward_rawptr(*args, **kwargs): - pass - - def L3_dim_getter(self): - return self.kernel_dims["L3_dim"] - - def irrep_dtype_getter(self): - return self.kernel_dims["irrep_dtype"] - - @torch.library.register_fake("libtorch_tp_jit::jit_tp_forward") - def fake_forward(jit, L1_in, L2_in, W): - L3_dim = None - if hasattr(jit, "wrapped_obj"): - L3_dim = jit.wrapped_obj.kernel_dims["L3_dim"] - else: - L3_dim = jit.L3_dim - - return L1_in.new_empty(L1_in.shape[0], L3_dim) - - @torch.library.register_fake("libtorch_tp_jit::jit_tp_backward") - def fake_backward(jit, L1_in, L2_in, W, L3_grad): - return torch.empty_like(L1_in), torch.empty_like(L2_in), torch.empty_like(W) - - @classmethod - def register_autograd(cls): - backward_op = torch.ops.libtorch_tp_jit.jit_tp_backward - - def setup_context(ctx, inputs, output): - ctx.jit, ctx.L1_in, ctx.L2_in, ctx.weights = inputs - - def backward(ctx, grad_output): - L1_grad, L2_grad, W_grad = backward_op( - ctx.jit, ctx.L1_in, ctx.L2_in, ctx.weights, grad_output - ) - return None, L1_grad, L2_grad, W_grad - - torch.library.register_autograd( - "libtorch_tp_jit::jit_tp_forward", backward, setup_context=setup_context - ) - - def setup_context_double_backward(ctx, inputs, output): - ctx.jit, ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad = inputs - - def double_backward(ctx, E, F, G): - result = torch.ops.libtorch_tp_jit.jit_tp_double_backward( - ctx.jit, ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad, E, F, G - ) - return None, result[0], result[1], result[2], result[3] - - torch.library.register_autograd( - "libtorch_tp_jit::jit_tp_backward", - double_backward, - setup_context=setup_context_double_backward, - ) - - @classmethod - def register_autocast(cls): - global torch - import torch + self.kernelProp = { + "L1_dim": self.L1.dim, + "L2_dim": self.L2.dim, + "L3_dim": self.L3.dim, + "weight_numel": self.config.weight_numel, + "shared_weights": int(self.config.shared_weights), + "opt_level": 3, + "irrep_dtype": dtype_to_enum[self.config.irrep_dtype], + "weight_dtype": dtype_to_enum[self.config.weight_dtype], + } - torch.library.register_autocast( - "libtorch_tp_jit::jit_tp_forward", "cuda", torch.float32 - ) - torch.library.register_autocast( - "libtorch_tp_jit::jit_tp_backward", "cuda", torch.float32 - ) - torch.library.register_autocast( - "libtorch_tp_jit::jit_tp_double_backward", "cuda", torch.float32 - ) - - @staticmethod - def name(): - return "LoopUnrollTP" def calculate_flops_forward(self, batch_size: int) -> dict: if self.is_uvw: @@ -303,9 +151,3 @@ def calculate_flops_backward(self, batch_size: int) -> dict: flop_count["backward"] *= 9 * batch_size flop_count["total"] = sum(flop_count.values()) return flop_count - - -if extlib.TORCH_COMPILE: - LoopUnrollTP.register_torch_fakes() - LoopUnrollTP.register_autograd() - LoopUnrollTP.register_autocast() diff --git a/openequivariance/openequivariance/impl_torch/TensorProduct.py b/openequivariance/openequivariance/impl_torch/TensorProduct.py index 08ce6cc8..ab98dc2b 100644 --- a/openequivariance/openequivariance/impl_torch/TensorProduct.py +++ b/openequivariance/openequivariance/impl_torch/TensorProduct.py @@ -4,7 +4,9 @@ import torch import typing from openequivariance.core.utils import torch_to_oeq_dtype +from openequivariance.benchmark.logging_utils import getLogger +logger = getLogger() class TensorProduct(torch.nn.Module, LoopUnrollTP): r""" @@ -29,9 +31,28 @@ def __init__(self, problem: TPProblem, torch_op=True, use_opaque=False): self._init_class() def _init_class(self): + dp = extlib.DeviceProp(0) LoopUnrollTP.__init__( - self, self.input_args["problem"], self.input_args["torch_op"] + self, self.input_args["problem"], dp, extlib.postprocess_kernel, self.input_args["torch_op"] ) + + internal_cls = None + if extlib.TORCH_COMPILE: + internal_cls = torch.classes.libtorch_tp_jit.TorchJITProduct + else: + internal_cls = extlib.JITTPImpl + + logger.info("Starting kernel compiler...") + self.internal = internal_cls( + self.jit_kernel, + vars(self.forward_schedule.launch_config), + vars(self.backward_schedule.launch_config), + vars(self.double_backward_schedule.launch_config), + self.kernelProp + ) + logger.info("Kernel compiled!") + logger.info(f"Kernel File Size: {len(self.jit_kernel) // 1024} KB") + self.weight_numel = self.input_args["problem"].weight_numel self._setup_notorchbind() if (not extlib.TORCH_COMPILE) or self.input_args["use_opaque"]: @@ -63,9 +84,11 @@ def __setstate__(self, state): self.input_args = state self._init_class() - @staticmethod - def name(): - return LoopUnrollTP.name() + def reorder_weights_from_e3nn(self, weights, has_batch_dim=True): + return self.forward_schedule.reorder_weights_from_e3nn(weights, has_batch_dim) + + def reorder_weights_to_e3nn(self, weights, has_batch_dim=True): + return self.forward_schedule.reorder_weights_to_e3nn(weights, has_batch_dim) def forward( self, x: torch.Tensor, y: torch.Tensor, W: torch.Tensor @@ -198,3 +221,125 @@ def double_backward(ctx, grad_output): backward_helper.register_autograd( double_backward, setup_context=setup_context_double_backward ) + + @classmethod + def register_torch_fakes(cls): + @torch._library.register_fake_class("libtorch_tp_jit::TorchJITProduct") + class TorchJITProduct: + def __init__( + self, + kernel_plaintext: str, + fwd_config: dict[str, int], + bwd_config: dict[str, int], + dbl_bwd_config: dict[str, int], + kernel_dims: dict[str, int], + ) -> None: + ( + self.kernel_plaintext, + self.fwd_config, + self.bwd_config, + self.dbl_bwd_config, + self.kernel_dims, + ) = ( + kernel_plaintext, + fwd_config, + bwd_config, + dbl_bwd_config, + kernel_dims, + ) + + @classmethod + def __obj_unflatten__(cls, flattened_product): + return cls(**dict(flattened_product)) + + def __len__(self): + return 0 + + def __setstate__(self, state): + self.kernel_plaintext = state["kernel_plaintext"] + self.fwd_config = state["fwd_config"] + self.bwd_config = state["bwd_config"] + self.dbl_bwd_config = state["dbl_bwd_config"] + self.kernel_dims = state["kernel_dims"] + + def exec_tensor_product_rawptr(*args, **kwargs): + pass + + def backward_rawptr(*args, **kwargs): + pass + + def L3_dim_getter(self): + return self.kernel_dims["L3_dim"] + + def irrep_dtype_getter(self): + return self.kernel_dims["irrep_dtype"] + + @torch.library.register_fake("libtorch_tp_jit::jit_tp_forward") + def fake_forward(jit, L1_in, L2_in, W): + L3_dim = None + if hasattr(jit, "wrapped_obj"): + L3_dim = jit.wrapped_obj.kernel_dims["L3_dim"] + else: + L3_dim = jit.L3_dim + + return L1_in.new_empty(L1_in.shape[0], L3_dim) + + @torch.library.register_fake("libtorch_tp_jit::jit_tp_backward") + def fake_backward(jit, L1_in, L2_in, W, L3_grad): + return torch.empty_like(L1_in), torch.empty_like(L2_in), torch.empty_like(W) + + @classmethod + def register_autograd(cls): + backward_op = torch.ops.libtorch_tp_jit.jit_tp_backward + + def setup_context(ctx, inputs, output): + ctx.jit, ctx.L1_in, ctx.L2_in, ctx.weights = inputs + + def backward(ctx, grad_output): + L1_grad, L2_grad, W_grad = backward_op( + ctx.jit, ctx.L1_in, ctx.L2_in, ctx.weights, grad_output + ) + return None, L1_grad, L2_grad, W_grad + + torch.library.register_autograd( + "libtorch_tp_jit::jit_tp_forward", backward, setup_context=setup_context + ) + + def setup_context_double_backward(ctx, inputs, output): + ctx.jit, ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad = inputs + + def double_backward(ctx, E, F, G): + result = torch.ops.libtorch_tp_jit.jit_tp_double_backward( + ctx.jit, ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad, E, F, G + ) + return None, result[0], result[1], result[2], result[3] + + torch.library.register_autograd( + "libtorch_tp_jit::jit_tp_backward", + double_backward, + setup_context=setup_context_double_backward, + ) + + @classmethod + def register_autocast(cls): + global torch + import torch + + torch.library.register_autocast( + "libtorch_tp_jit::jit_tp_forward", "cuda", torch.float32 + ) + torch.library.register_autocast( + "libtorch_tp_jit::jit_tp_backward", "cuda", torch.float32 + ) + torch.library.register_autocast( + "libtorch_tp_jit::jit_tp_double_backward", "cuda", torch.float32 + ) + + @staticmethod + def name(): + return "LoopUnrollTP" + +if extlib.TORCH_COMPILE: + TensorProduct.register_torch_fakes() + TensorProduct.register_autograd() + TensorProduct.register_autocast() From a4506040e923fd28373069f1c306e323da781de4 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Tue, 25 Nov 2025 13:15:57 -0800 Subject: [PATCH 018/116] More things are working. --- openequivariance/README.md | 6 ++++++ openequivariance/pyproject.toml | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) create mode 100644 openequivariance/README.md diff --git a/openequivariance/README.md b/openequivariance/README.md new file mode 100644 index 00000000..45a0ae38 --- /dev/null +++ b/openequivariance/README.md @@ -0,0 +1,6 @@ +# OpenEquivariance + +This package contains the core implementation of OpenEquivariance, which is fully +sufficient to run the package from PyTorch. For JAX support, see instructions +on installing `openequivariance_extjax` along with this package. + diff --git a/openequivariance/pyproject.toml b/openequivariance/pyproject.toml index 5598dde5..db392c7c 100644 --- a/openequivariance/pyproject.toml +++ b/openequivariance/pyproject.toml @@ -19,7 +19,7 @@ dependencies = [ "jinja2", "numpy" ] -readme = "../README.md" +readme = "README.md" license = "BSD-3-Clause" license-files = ["LICENSE"] From 59bc57c1a25fe877a4120dc8219f15ab51944a80 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Tue, 25 Nov 2025 13:53:50 -0800 Subject: [PATCH 019/116] More plumbing. --- .../openequivariance/core/LoopUnrollTP.py | 2 +- .../impl_jax/TensorProduct.py | 38 +++++++++---------- .../impl_jax/extlib/__init__.py | 17 +++++++++ openequivariance_extjax/src/libjax_tp_jit.cpp | 31 ++++++++++++--- 4 files changed, 61 insertions(+), 27 deletions(-) create mode 100644 openequivariance/openequivariance/impl_jax/extlib/__init__.py diff --git a/openequivariance/openequivariance/core/LoopUnrollTP.py b/openequivariance/openequivariance/core/LoopUnrollTP.py index c04c6846..82a4641e 100644 --- a/openequivariance/openequivariance/core/LoopUnrollTP.py +++ b/openequivariance/openequivariance/core/LoopUnrollTP.py @@ -11,7 +11,7 @@ ) class LoopUnrollTP(TensorProductBase): - def __init__(self, config, dp, postprocess_kernel, torch_op=True): + def __init__(self, config, dp, postprocess_kernel, torch_op): super().__init__(config, torch_op=torch_op) env = get_jinja_environment() diff --git a/openequivariance/openequivariance/impl_jax/TensorProduct.py b/openequivariance/openequivariance/impl_jax/TensorProduct.py index 67eb2a1d..5f174efc 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProduct.py +++ b/openequivariance/openequivariance/impl_jax/TensorProduct.py @@ -1,13 +1,10 @@ import numpy as np import jax -import openequivariance_extjax as oeq_extjax +from openequivariance.impl_jax import extlib import hashlib -#from openequivariance.core.e3nn_lite import TPProblem - -for name, target in oeq_extjax.registrations().items(): - print(name, target) - jax.ffi.register_ffi_target(name, target, platform="CUDA") +from openequivariance.core.e3nn_lite import TPProblem, Irreps +from openequivariance.core.LoopUnrollTP import LoopUnrollTP def hash_attributes(attrs): m = hashlib.sha256() @@ -18,21 +15,17 @@ def hash_attributes(attrs): hash = int(m.hexdigest()[:16], 16) >> 1 attrs["hash"] = hash -class TensorProduct: - def __init__(self, problem): - self.problem = problem +class TensorProduct(LoopUnrollTP): + def __init__(self, config): + dp = extlib.DeviceProp(0) + super().__init__(config, dp, extlib.postprocess_kernel, torch_op=False) - self.kernel = "BLAH" - self.forward_config = {"num_blocks": 42, "num_threads": 256, "smem": 8192 } - self.backward_config = {} - self.double_backward_config = {} - self.kernel_prop = {} self.attrs = { - "kernel": self.kernel, - "forward_config": self.forward_config, - "backward_config": self.backward_config, - "double_backward_config": self.double_backward_config, - "kernel_prop": self.kernel_prop + "kernel": self.jit_kernel, + "forward_config": vars(self.forward_schedule.launch_config), + "backward_config": vars(self.backward_schedule.launch_config), + "double_backward_config": vars(self.double_backward_schedule.launch_config), + "kernel_prop": self.kernelProp } hash_attributes(self.attrs) @@ -44,6 +37,9 @@ def __init__(self, problem): if __name__ == "__main__": - tp_problem = None - tensor_product = TensorProduct(tp_problem) + tp_problem = None + X_ir, Y_ir, Z_ir = Irreps("1x2e"), Irreps("1x3e"), Irreps("1x2e") + instructions=[(0, 0, 0, "uvu", True)] + problem = TPProblem(X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False) + tensor_product = TensorProduct(problem) print("COMPLETE!") \ No newline at end of file diff --git a/openequivariance/openequivariance/impl_jax/extlib/__init__.py b/openequivariance/openequivariance/impl_jax/extlib/__init__.py new file mode 100644 index 00000000..94e24454 --- /dev/null +++ b/openequivariance/openequivariance/impl_jax/extlib/__init__.py @@ -0,0 +1,17 @@ +import jax + +def postprocess_kernel(kernel): + return kernel + +import openequivariance_extjax as oeq_extjax +for name, target in oeq_extjax.registrations().items(): + print(name, target) + jax.ffi.register_ffi_target(name, target, platform="CUDA") + +GPUTimer = oeq_extjax.GPUTimer +DeviceProp = oeq_extjax.DeviceProp + +__all__ = [ + "GPUTimer", + "DeviceProp", +] \ No newline at end of file diff --git a/openequivariance_extjax/src/libjax_tp_jit.cpp b/openequivariance_extjax/src/libjax_tp_jit.cpp index ef48aec2..e584b45c 100644 --- a/openequivariance_extjax/src/libjax_tp_jit.cpp +++ b/openequivariance_extjax/src/libjax_tp_jit.cpp @@ -108,9 +108,30 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( {xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled NB_MODULE(openequivariance_extjax, m) { - m.def("registrations", []() { - nb::dict registrations; - registrations["tp_forward"] = nb::capsule(reinterpret_cast(tp_forward)); - return registrations; - }); + m.def("registrations", []() { + nb::dict registrations; + registrations["tp_forward"] = nb::capsule(reinterpret_cast(tp_forward)); + return registrations; + }); + + nb::class_(m, "DeviceProp") + .def(nb::init()) + .def_ro("name", &DeviceProp::name) + .def_ro("warpsize", &DeviceProp::warpsize) + .def_ro("major", &DeviceProp::major) + .def_ro("minor", &DeviceProp::minor) + .def_ro("multiprocessorCount", &DeviceProp::multiprocessorCount) + .def_ro("maxSharedMemPerBlock", &DeviceProp::maxSharedMemPerBlock); + + nb::class_(m, "GPUTimer") + .def(nb::init<>()) + .def("start", &GPUTimer::start) + .def("stop_clock_get_elapsed", &GPUTimer::stop_clock_get_elapsed) + .def("clear_L2_cache", &GPUTimer::clear_L2_cache); + + /*nb::class_>(m, "DeviceBuffer") + .def(nb::init()) + .def(nb::init()) + .def("copy_to_host", &PyDeviceBuffer::copy_to_host) + .def("data_ptr", &PyDeviceBuffer::data_ptr);*/ } From 8f7c675b22e79e6b27673aea8aa89b9e63cd22e3 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Wed, 26 Nov 2025 19:24:32 -0800 Subject: [PATCH 020/116] More progress. --- .../impl_jax/TensorProduct.py | 6 +- .../impl_jax/extlib/__init__.py | 1 - openequivariance_extjax/src/libjax_tp_jit.cpp | 150 +++++++++++++++--- 3 files changed, 128 insertions(+), 29 deletions(-) diff --git a/openequivariance/openequivariance/impl_jax/TensorProduct.py b/openequivariance/openequivariance/impl_jax/TensorProduct.py index 5f174efc..65bc5919 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProduct.py +++ b/openequivariance/openequivariance/impl_jax/TensorProduct.py @@ -32,8 +32,9 @@ def __init__(self, config): self.forward_call = jax.ffi.ffi_call( "tp_forward", jax.ShapeDtypeStruct((), jax.numpy.int32)) - - self.forward_call(**self.attrs) + + def forward(self, X, Y, W): + self.forward_call(X, Y, W, **self.attrs) if __name__ == "__main__": @@ -42,4 +43,5 @@ def __init__(self, config): instructions=[(0, 0, 0, "uvu", True)] problem = TPProblem(X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False) tensor_product = TensorProduct(problem) + print("COMPLETE!") \ No newline at end of file diff --git a/openequivariance/openequivariance/impl_jax/extlib/__init__.py b/openequivariance/openequivariance/impl_jax/extlib/__init__.py index 94e24454..9dff4696 100644 --- a/openequivariance/openequivariance/impl_jax/extlib/__init__.py +++ b/openequivariance/openequivariance/impl_jax/extlib/__init__.py @@ -5,7 +5,6 @@ def postprocess_kernel(kernel): import openequivariance_extjax as oeq_extjax for name, target in oeq_extjax.registrations().items(): - print(name, target) jax.ffi.register_ffi_target(name, target, platform="CUDA") GPUTimer = oeq_extjax.GPUTimer diff --git a/openequivariance_extjax/src/libjax_tp_jit.cpp b/openequivariance_extjax/src/libjax_tp_jit.cpp index e584b45c..b1daa11b 100644 --- a/openequivariance_extjax/src/libjax_tp_jit.cpp +++ b/openequivariance_extjax/src/libjax_tp_jit.cpp @@ -28,7 +28,57 @@ namespace nb = nanobind; namespace ffi = xla::ffi; -std::unordered_map>> kernel_cache; +xla::ffi::DataType enum_to_xla_dtype(int64_t i){ + switch(i) { + case 1: + return xla::ffi::DataType::F32; + case 2: + return xla::ffi::DataType::F64; + case 3: + return xla::ffi::DataType::S32; + case 4: + return xla::ffi::DataType::S64; + case 5: + return xla::ffi::DataType::U8; + } + throw logic_error("Unsupported tensor datatype!"); +} + +struct KernelProp { + int64_t L1_dim, L2_dim, L3_dim, weight_numel; + bool shared_weights; + xla::ffi::DataType irrep_dtype; + xla::ffi::DataType weight_dtype; + + int64_t workspace_size; // Convolution only + bool deterministic; + xla::ffi::DataType idx_dtype; + xla::ffi::DataType workspace_dtype; + + KernelProp() {} + + KernelProp(Map_t &kernel_dims, bool is_convolution): + L1_dim(kernel_dims.at("L1_dim")), + L2_dim(kernel_dims.at("L2_dim")), + L3_dim(kernel_dims.at("L3_dim")), + weight_numel(kernel_dims.at("weight_numel")), + shared_weights(kernel_dims.at("shared_weights")), + irrep_dtype(enum_to_xla_dtype(kernel_dims.at("irrep_dtype"))), + weight_dtype(enum_to_xla_dtype(kernel_dims.at("weight_dtype"))), + workspace_dtype(xla::ffi::DataType::U8) { + if(is_convolution) { + workspace_size = kernel_dims.at("workspace_size"); + deterministic = kernel_dims.at("deterministic"); + idx_dtype = enum_to_xla_dtype(kernel_dims.at("idx_dtype")); + } + } +}; + +std::unordered_map>, + KernelProp + > kernel_cache; std::mutex mut; std::vector launch_config_keys = { @@ -53,59 +103,107 @@ std::unordered_map parse_ffi_dict(ffi::Dictionary &dict, const return result; } -JITTPImpl* compile_kernel_with_caching(std::string_view kernel, +std::pair*, KernelProp> + compile_kernel_with_caching(std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, - int64_t hash) { + int64_t hash, + bool is_convolution) { - JITTPImpl* result = nullptr; { const std::lock_guard lock(mut); auto it = kernel_cache.find(hash); - if (it != kernel_cache.end()) { - result = it->second.get(); - } - else { - auto result = parse_ffi_dict(forward_config, launch_config_keys); - - cout << result["smem"] << endl; - /*auto jit_tp_impl = std::make_unique>( + if (it == kernel_cache.end()) { + auto kernel_prop = parse_ffi_dict(kernel_prop, kernel_prop_keys); + auto jit_tp_impl = std::make_unique>( std::string(kernel), parse_ffi_dict(forward_config, launch_config_keys), parse_ffi_dict(backward_config, launch_config_keys), parse_ffi_dict(double_backward_config, launch_config_keys), - parse_ffi_dict(kernel_prop, kernel_prop_keys)); - result = jit_tp_impl.get(); - kernel_cache.insert({hash, std::move(jit_tp_impl)});*/ + kernel_prop); + kernel_cache.insert({hash, + std::make_pair(std::move(jit_tp_impl), + KernelProp(kernel_prop, is_convolution))}); + it = kernel_cache.find(hash); } } - return result; + return {it->second.first.get(), it->second.second}; +} + + +inline void check_tensor(const ffi::AnyBuffer &buffer, + std::initializer_list expected_shape, + xla::ffi::DataType expected_dtype, + std::string tensor_name) { + const ffi::AnyBuffer::Dimensions dims = buffer.dimensions(); + if (dims.size() != expected_shape.size()) { + throw std::logic_error("Rank mismatch for tensor '" + + tensor_name + + "'. Expected rank " + + std::to_string(expected_shape.size()) + + ", got rank " + + std::to_string(dims.size())); + } + + for (size_t i = 0; i < dims.size(); i++) { + if (dims[i] != expected_shape[i]) { + throw std::logic_error("Shape mismatch for tensor '" + + tensor_name + + "'. Expected dimension " + + std::to_string(expected_shape[i]) + + " at index " + + std::to_string(i) + + ", got " + + std::to_string(dims[i])); + } + } + + if (buffer.element_type() != expected_dtype) { + throw std::logic_error("Datatype mismatch."); + } } ffi::Error tp_forward_impl( + ffi::AnyBuffer L1_in, + ffi::AnyBuffer L2_in, + ffi::AnyBuffer W, + ffi::Result L3_out, cudaStream_t stream, std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, - int64_t hash, ffi::ResultBufferR0 out) { - - auto jit_kernel = compile_kernel_with_caching( - kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash); + int64_t hash) { + + auto [jit_kernel, k] = compile_kernel_with_caching( + kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, false); + const int64_t num_batch = L1_in.dimensions[0]; + + check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); + check_tensor(L2_in, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_in"); + + if (k.shared_weights) + check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); + else + check_tensor(W, {num_batch, k.weight_numel}, k.weight_dtype, "W"); - std::cout << "SUCCESSFULLY COMPILED KERNEL!" << std::endl; // TODO: Launch the forward kernel here + return ffi::Error::Success(); } XLA_FFI_DEFINE_HANDLER_SYMBOL( tp_forward, tp_forward_impl, ffi::Ffi::Bind() - .Ctx>() - .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") - .Attr("hash") - .Ret>(), - {xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled + .Arg() + .Arg() + .Arg() + .Arg>() + .Ctx>() + .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") + .Attr("hash") + .Ret>(), + {xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled NB_MODULE(openequivariance_extjax, m) { m.def("registrations", []() { From 41e7cd7f2853a0a4d522cde375de9b64f10bb22c Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Wed, 26 Nov 2025 21:04:38 -0800 Subject: [PATCH 021/116] Dispatch complete. --- .../impl_jax/TensorProduct.py | 30 ++++++++++++++----- openequivariance_extjax/src/libjax_tp_jit.cpp | 25 ++++++++-------- 2 files changed, 35 insertions(+), 20 deletions(-) diff --git a/openequivariance/openequivariance/impl_jax/TensorProduct.py b/openequivariance/openequivariance/impl_jax/TensorProduct.py index 65bc5919..7d3aa5b5 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProduct.py +++ b/openequivariance/openequivariance/impl_jax/TensorProduct.py @@ -28,20 +28,36 @@ def __init__(self, config): "kernel_prop": self.kernelProp } hash_attributes(self.attrs) - - self.forward_call = jax.ffi.ffi_call( - "tp_forward", - jax.ShapeDtypeStruct((), jax.numpy.int32)) + + self.weight_numel = config.weight_numel + self.L3_dim = self.config.irreps_out.dim def forward(self, X, Y, W): - self.forward_call(X, Y, W, **self.attrs) + forward_call = jax.ffi.ffi_call("tp_forward", + jax.ShapeDtypeStruct((X.shape[0], self.L3_dim), self.config.irrep_dtype)) + return forward_call(X, Y, W, **self.attrs) if __name__ == "__main__": tp_problem = None X_ir, Y_ir, Z_ir = Irreps("1x2e"), Irreps("1x3e"), Irreps("1x2e") instructions=[(0, 0, 0, "uvu", True)] - problem = TPProblem(X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False) + problem = TPProblem(X_ir, Y_ir, Z_ir, + instructions, + shared_weights=False, + internal_weights=False) tensor_product = TensorProduct(problem) - print("COMPLETE!") \ No newline at end of file + batch_size = 1000 + #X = torch.rand(batch_size, X_ir.dim, device='cuda', generator=gen) + #Y = torch.rand(batch_size, Y_ir.dim, device='cuda', generator=gen) + #W = torch.rand(batch_size, tp_e3nn.weight_numel, device='cuda', generator=gen) + + # Convert the above to JAX Arrays + X = jax.random.uniform(jax.random.PRNGKey(0), (batch_size, X_ir.dim), dtype=jax.numpy.float32) + Y = jax.random.uniform(jax.random.PRNGKey(1), (batch_size, Y_ir.dim), dtype=jax.numpy.float32) + W = jax.random.uniform(jax.random.PRNGKey(2), (batch_size, tensor_product.weight_numel), dtype=jax.numpy.float32) + + Z = tensor_product.forward(X, Y, W) + print("COMPLETE!") + print(Z) \ No newline at end of file diff --git a/openequivariance_extjax/src/libjax_tp_jit.cpp b/openequivariance_extjax/src/libjax_tp_jit.cpp index b1daa11b..9b775ada 100644 --- a/openequivariance_extjax/src/libjax_tp_jit.cpp +++ b/openequivariance_extjax/src/libjax_tp_jit.cpp @@ -57,7 +57,8 @@ struct KernelProp { KernelProp() {} - KernelProp(Map_t &kernel_dims, bool is_convolution): + KernelProp( + std::unordered_map &kernel_dims, bool is_convolution): L1_dim(kernel_dims.at("L1_dim")), L2_dim(kernel_dims.at("L2_dim")), L3_dim(kernel_dims.at("L3_dim")), @@ -78,7 +79,7 @@ std::unordered_map>, KernelProp - > kernel_cache; + >> kernel_cache; std::mutex mut; std::vector launch_config_keys = { @@ -116,20 +117,20 @@ std::pair*, KernelProp> const std::lock_guard lock(mut); auto it = kernel_cache.find(hash); if (it == kernel_cache.end()) { - auto kernel_prop = parse_ffi_dict(kernel_prop, kernel_prop_keys); + auto kernel_prop_map = parse_ffi_dict(kernel_prop, kernel_prop_keys); auto jit_tp_impl = std::make_unique>( std::string(kernel), parse_ffi_dict(forward_config, launch_config_keys), parse_ffi_dict(backward_config, launch_config_keys), parse_ffi_dict(double_backward_config, launch_config_keys), - kernel_prop); + kernel_prop_map); kernel_cache.insert({hash, std::make_pair(std::move(jit_tp_impl), - KernelProp(kernel_prop, is_convolution))}); + KernelProp(kernel_prop_map, is_convolution))}); it = kernel_cache.find(hash); } + return {it->second.first.get(), it->second.second}; } - return {it->second.first.get(), it->second.second}; } @@ -148,11 +149,11 @@ inline void check_tensor(const ffi::AnyBuffer &buffer, } for (size_t i = 0; i < dims.size(); i++) { - if (dims[i] != expected_shape[i]) { + if (dims[i] != expected_shape.begin()[i]) { throw std::logic_error("Shape mismatch for tensor '" + tensor_name + "'. Expected dimension " - + std::to_string(expected_shape[i]) + + std::to_string(expected_shape.begin()[i]) + " at index " + std::to_string(i) + ", got " @@ -176,7 +177,7 @@ ffi::Error tp_forward_impl( auto [jit_kernel, k] = compile_kernel_with_caching( kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, false); - const int64_t num_batch = L1_in.dimensions[0]; + const int64_t num_batch = L1_in.dimensions()[0]; check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); check_tensor(L2_in, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_in"); @@ -188,7 +189,6 @@ ffi::Error tp_forward_impl( // TODO: Launch the forward kernel here - return ffi::Error::Success(); } @@ -198,11 +198,10 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Arg() .Arg() .Arg() - .Arg>() + .Ret() .Ctx>() .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") - .Attr("hash") - .Ret>(), + .Attr("hash"), {xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled NB_MODULE(openequivariance_extjax, m) { From 5c7a828ac58221261fd7b21c09d3b4005c2982f8 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Wed, 26 Nov 2025 22:17:17 -0800 Subject: [PATCH 022/116] Forward call is working. --- openequivariance_extjax/src/libjax_tp_jit.cpp | 27 +++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/openequivariance_extjax/src/libjax_tp_jit.cpp b/openequivariance_extjax/src/libjax_tp_jit.cpp index 9b775ada..5a667876 100644 --- a/openequivariance_extjax/src/libjax_tp_jit.cpp +++ b/openequivariance_extjax/src/libjax_tp_jit.cpp @@ -42,7 +42,24 @@ xla::ffi::DataType enum_to_xla_dtype(int64_t i){ return xla::ffi::DataType::U8; } throw logic_error("Unsupported tensor datatype!"); -} +} + +inline void* data_ptr(ffi::AnyBuffer &buffer) { + if(buffer.element_type() == xla::ffi::DataType::F32) + return reinterpret_cast(buffer.typed_data()); + else if(buffer.element_type() == xla::ffi::DataType::F64) + return reinterpret_cast(buffer.typed_data()); + else if(buffer.element_type() == xla::ffi::DataType::S64) + return reinterpret_cast(buffer.typed_data()); + else if(buffer.element_type() == xla::ffi::DataType::U8) + return reinterpret_cast(buffer.typed_data()); + else + throw logic_error("Unsupported tensor datatype!"); +} + +inline void* data_ptr(ffi::Result &buffer) { + return data_ptr(*buffer); +} struct KernelProp { int64_t L1_dim, L2_dim, L3_dim, weight_numel; @@ -187,7 +204,13 @@ ffi::Error tp_forward_impl( else check_tensor(W, {num_batch, k.weight_numel}, k.weight_dtype, "W"); - // TODO: Launch the forward kernel here + jit_kernel->exec_tensor_product( + num_batch, + data_ptr(L1_in), + data_ptr(L2_in), + data_ptr(L3_out), + data_ptr(W), + stream); return ffi::Error::Success(); } From 6790bd0047d4f9ab42d09464caa49a5f056d09e9 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Wed, 26 Nov 2025 22:52:04 -0800 Subject: [PATCH 023/116] Added the backward pass. --- openequivariance_extjax/src/libjax_tp_jit.cpp | 61 +++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/openequivariance_extjax/src/libjax_tp_jit.cpp b/openequivariance_extjax/src/libjax_tp_jit.cpp index 5a667876..a47f4e5c 100644 --- a/openequivariance_extjax/src/libjax_tp_jit.cpp +++ b/openequivariance_extjax/src/libjax_tp_jit.cpp @@ -215,6 +215,51 @@ ffi::Error tp_forward_impl( return ffi::Error::Success(); } +ffi::Error tp_backward_impl( + ffi::AnyBuffer L1_in, + ffi::AnyBuffer L2_in, + ffi::AnyBuffer W, + ffi::AnyBuffer L3_grad, + ffi::Result L1_grad, + ffi::Result L2_grad, + ffi::Result W_grad, + cudaStream_t stream, + std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, + int64_t hash) { + + auto [jit_kernel, k] = compile_kernel_with_caching( + kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, false); + const int64_t num_batch = L1_in.dimensions()[0]; + check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); + check_tensor(L2_in, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_in"); + check_tensor(L3_grad, {num_batch, k.L3_dim}, k.irrep_dtype, "L3_grad"); + + if (k.shared_weights) { + check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); + check_tensor(*W_grad, {k.weight_numel}, k.weight_dtype, "W_grad"); + } + else { + check_tensor(W, {num_batch, k.weight_numel}, k.weight_dtype, "W"); + check_tensor(*W_grad, {num_batch, k.weight_numel}, k.weight_dtype, "W_grad"); + } + + if (k.shared_weights) { + // Need to zero out W_grad + } + + jit_kernel->exec_tensor_product_backward( + num_batch, + data_ptr(L1_in), + data_ptr(L1_grad), + data_ptr(L2_in), + data_ptr(L2_grad), + data_ptr(W), + data_ptr(W_grad), + data_ptr(L3_grad), + stream); + return ffi::Error::Success(); +} + XLA_FFI_DEFINE_HANDLER_SYMBOL( tp_forward, tp_forward_impl, ffi::Ffi::Bind() @@ -227,10 +272,26 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Attr("hash"), {xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled +XLA_FFI_DEFINE_HANDLER_SYMBOL( + tp_backward, tp_backward_impl, + ffi::Ffi::Bind() + .Arg() + .Arg() + .Arg() + .Arg() + .Ret() + .Ret() + .Ret() + .Ctx>() + .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") + .Attr("hash"), + {xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled + NB_MODULE(openequivariance_extjax, m) { m.def("registrations", []() { nb::dict registrations; registrations["tp_forward"] = nb::capsule(reinterpret_cast(tp_forward)); + registrations["tp_backward"] = nb::capsule(reinterpret_cast(tp_backward)); return registrations; }); From c15f4f70574d115a9993fc23660bf203c28d1281 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Wed, 26 Nov 2025 23:12:28 -0800 Subject: [PATCH 024/116] Encapsulated the forward call. --- .../openequivariance/impl_jax/TensorProduct.py | 17 +++++++++-------- openequivariance_extjax/src/libjax_tp_jit.cpp | 2 +- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/openequivariance/openequivariance/impl_jax/TensorProduct.py b/openequivariance/openequivariance/impl_jax/TensorProduct.py index 7d3aa5b5..ba9247f5 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProduct.py +++ b/openequivariance/openequivariance/impl_jax/TensorProduct.py @@ -15,6 +15,14 @@ def hash_attributes(attrs): hash = int(m.hexdigest()[:16], 16) >> 1 attrs["hash"] = hash + +def forward(X, Y, W, L3_dim, irrep_dtype, attrs): + forward_call = jax.ffi.ffi_call("tp_forward", + jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype)) + return forward_call(X, Y, W, **attrs) + +#def backward() + class TensorProduct(LoopUnrollTP): def __init__(self, config): dp = extlib.DeviceProp(0) @@ -33,10 +41,7 @@ def __init__(self, config): self.L3_dim = self.config.irreps_out.dim def forward(self, X, Y, W): - forward_call = jax.ffi.ffi_call("tp_forward", - jax.ShapeDtypeStruct((X.shape[0], self.L3_dim), self.config.irrep_dtype)) - return forward_call(X, Y, W, **self.attrs) - + return forward(X, Y, W, self.L3_dim, self.config.irrep_dtype, self.attrs) if __name__ == "__main__": tp_problem = None @@ -47,11 +52,7 @@ def forward(self, X, Y, W): shared_weights=False, internal_weights=False) tensor_product = TensorProduct(problem) - batch_size = 1000 - #X = torch.rand(batch_size, X_ir.dim, device='cuda', generator=gen) - #Y = torch.rand(batch_size, Y_ir.dim, device='cuda', generator=gen) - #W = torch.rand(batch_size, tp_e3nn.weight_numel, device='cuda', generator=gen) # Convert the above to JAX Arrays X = jax.random.uniform(jax.random.PRNGKey(0), (batch_size, X_ir.dim), dtype=jax.numpy.float32) diff --git a/openequivariance_extjax/src/libjax_tp_jit.cpp b/openequivariance_extjax/src/libjax_tp_jit.cpp index a47f4e5c..25e4974e 100644 --- a/openequivariance_extjax/src/libjax_tp_jit.cpp +++ b/openequivariance_extjax/src/libjax_tp_jit.cpp @@ -247,7 +247,7 @@ ffi::Error tp_backward_impl( // Need to zero out W_grad } - jit_kernel->exec_tensor_product_backward( + jit_kernel->backward( num_batch, data_ptr(L1_in), data_ptr(L1_grad), From bfa52a58b24417904fe03b79ff2a9d4a424afd49 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Thu, 27 Nov 2025 13:55:15 -0800 Subject: [PATCH 025/116] Skeleton of rule implemented. --- .../openequivariance/impl_jax/TensorProduct.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/openequivariance/openequivariance/impl_jax/TensorProduct.py b/openequivariance/openequivariance/impl_jax/TensorProduct.py index ba9247f5..daf3d115 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProduct.py +++ b/openequivariance/openequivariance/impl_jax/TensorProduct.py @@ -1,6 +1,8 @@ import numpy as np import jax + +from functools import partial from openequivariance.impl_jax import extlib import hashlib from openequivariance.core.e3nn_lite import TPProblem, Irreps @@ -15,13 +17,25 @@ def hash_attributes(attrs): hash = int(m.hexdigest()[:16], 16) >> 1 attrs["hash"] = hash - +@partial(jax.custom_vjp, nondiff_argnums=(3,4,5)) def forward(X, Y, W, L3_dim, irrep_dtype, attrs): forward_call = jax.ffi.ffi_call("tp_forward", jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype)) return forward_call(X, Y, W, **attrs) -#def backward() +def forward_with_inputs(X, Y, W, L3_dim, irrep_dtype, attrs): + return forward(X, Y, W, L3_dim, irrep_dtype, attrs), (X, Y, W) + +def backward(attrs, irrep_dtype, L3_dim, inputs, dZ): + backward_call = jax.ffi.ffi_call("tp_backward", + ( + jax.ShapeDtypeStruct(inputs[0].shape, irrep_dtype), + jax.ShapeDtypeStruct(inputs[1].shape, irrep_dtype), + jax.ShapeDtypeStruct(inputs[2].shape, irrep_dtype), + )) + return backward_call(*inputs, dZ, **attrs) + +forward.defvjp(forward_with_inputs, backward) class TensorProduct(LoopUnrollTP): def __init__(self, config): From d1131fa22f12729fe30612aae6f319b11dcb7ced Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 30 Nov 2025 20:26:22 -0800 Subject: [PATCH 026/116] Backward call is working. --- .../openequivariance/impl_jax/TensorProduct.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/openequivariance/openequivariance/impl_jax/TensorProduct.py b/openequivariance/openequivariance/impl_jax/TensorProduct.py index daf3d115..54ce517e 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProduct.py +++ b/openequivariance/openequivariance/impl_jax/TensorProduct.py @@ -7,6 +7,7 @@ import hashlib from openequivariance.core.e3nn_lite import TPProblem, Irreps from openequivariance.core.LoopUnrollTP import LoopUnrollTP +import jax.numpy as jnp def hash_attributes(attrs): m = hashlib.sha256() @@ -26,13 +27,14 @@ def forward(X, Y, W, L3_dim, irrep_dtype, attrs): def forward_with_inputs(X, Y, W, L3_dim, irrep_dtype, attrs): return forward(X, Y, W, L3_dim, irrep_dtype, attrs), (X, Y, W) -def backward(attrs, irrep_dtype, L3_dim, inputs, dZ): +def backward(L3_dim, irrep_dtype, attrs, inputs, dZ): backward_call = jax.ffi.ffi_call("tp_backward", ( jax.ShapeDtypeStruct(inputs[0].shape, irrep_dtype), jax.ShapeDtypeStruct(inputs[1].shape, irrep_dtype), jax.ShapeDtypeStruct(inputs[2].shape, irrep_dtype), )) + return backward_call(*inputs, dZ, **attrs) forward.defvjp(forward_with_inputs, backward) @@ -66,7 +68,7 @@ def forward(self, X, Y, W): shared_weights=False, internal_weights=False) tensor_product = TensorProduct(problem) - batch_size = 1000 + batch_size = 1 # Convert the above to JAX Arrays X = jax.random.uniform(jax.random.PRNGKey(0), (batch_size, X_ir.dim), dtype=jax.numpy.float32) @@ -74,5 +76,11 @@ def forward(self, X, Y, W): W = jax.random.uniform(jax.random.PRNGKey(2), (batch_size, tensor_product.weight_numel), dtype=jax.numpy.float32) Z = tensor_product.forward(X, Y, W) + + # Test via jax vjp + + ctZ = jnp.ones_like(Z) + result = jax.vjp(lambda x, y, w: tensor_product.forward(x, y, w), X, Y, W)[1](ctZ) + + print(result) print("COMPLETE!") - print(Z) \ No newline at end of file From 136c9f6bc81bfb76c7414d4e562732b3cec5718b Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 30 Nov 2025 20:44:29 -0800 Subject: [PATCH 027/116] Zero'd buffer. --- openequivariance_extjax/src/libjax_tp_jit.cpp | 56 ++++++++++++++----- 1 file changed, 41 insertions(+), 15 deletions(-) diff --git a/openequivariance_extjax/src/libjax_tp_jit.cpp b/openequivariance_extjax/src/libjax_tp_jit.cpp index 25e4974e..2199c366 100644 --- a/openequivariance_extjax/src/libjax_tp_jit.cpp +++ b/openequivariance_extjax/src/libjax_tp_jit.cpp @@ -11,6 +11,9 @@ #include "nanobind/nanobind.h" #include "xla/ffi/api/ffi.h" +namespace nb = nanobind; +namespace ffi = xla::ffi; + #define CUDA_BACKEND // Stick to CUDA for now #ifdef CUDA_BACKEND @@ -20,14 +23,11 @@ using GPU_Allocator = CUDA_Allocator; template - using GroupMM = GroupMMCUDA; + using GroupMM = GroupMMCUDA; #endif #include "tensorproducts.hpp" -namespace nb = nanobind; -namespace ffi = xla::ffi; - xla::ffi::DataType enum_to_xla_dtype(int64_t i){ switch(i) { case 1: @@ -45,22 +45,48 @@ xla::ffi::DataType enum_to_xla_dtype(int64_t i){ } inline void* data_ptr(ffi::AnyBuffer &buffer) { - if(buffer.element_type() == xla::ffi::DataType::F32) - return reinterpret_cast(buffer.typed_data()); - else if(buffer.element_type() == xla::ffi::DataType::F64) - return reinterpret_cast(buffer.typed_data()); - else if(buffer.element_type() == xla::ffi::DataType::S64) - return reinterpret_cast(buffer.typed_data()); - else if(buffer.element_type() == xla::ffi::DataType::U8) - return reinterpret_cast(buffer.typed_data()); - else - throw logic_error("Unsupported tensor datatype!"); + switch (buffer.element_type()) { + case xla::ffi::DataType::F32: + return reinterpret_cast(buffer.typed_data()); + case xla::ffi::DataType::F64: + return reinterpret_cast(buffer.typed_data()); + case xla::ffi::DataType::S64: + return reinterpret_cast(buffer.typed_data()); + case xla::ffi::DataType::U8: + return reinterpret_cast(buffer.typed_data()); + default: + throw logic_error("Unsupported tensor datatype!"); + } +} + +inline int byte_count(ffi::AnyBuffer &buffer) { + switch (buffer.element_type()) { + case xla::ffi::DataType::F32: + return 4; + case xla::ffi::DataType::F64: + return 8; + case xla::ffi::DataType::S64: + return 8; + case xla::ffi::DataType::U8: + return 1; + default: + throw logic_error("Unsupported tensor datatype!"); + } } inline void* data_ptr(ffi::Result &buffer) { return data_ptr(*buffer); } +#ifdef CUDA_BACKEND +void zero_buffer(ffi::AnyBuffer &buffer) { + cudaMemset( + data_ptr(buffer), + 0, + buffer.element_count() * byte_count(buffer)); +} +#endif + struct KernelProp { int64_t L1_dim, L2_dim, L3_dim, weight_numel; bool shared_weights; @@ -244,7 +270,7 @@ ffi::Error tp_backward_impl( } if (k.shared_weights) { - // Need to zero out W_grad + zero_buffer(*W_grad); } jit_kernel->backward( From 2dadb0fa65c3f90a3553a12ff12aa90f525159a1 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 30 Nov 2025 20:54:02 -0800 Subject: [PATCH 028/116] Wrapped the double-backward pass. --- openequivariance_extjax/src/libjax_tp_jit.cpp | 77 ++++++++++++++++++- 1 file changed, 76 insertions(+), 1 deletion(-) diff --git a/openequivariance_extjax/src/libjax_tp_jit.cpp b/openequivariance_extjax/src/libjax_tp_jit.cpp index 2199c366..c24fc4de 100644 --- a/openequivariance_extjax/src/libjax_tp_jit.cpp +++ b/openequivariance_extjax/src/libjax_tp_jit.cpp @@ -286,6 +286,61 @@ ffi::Error tp_backward_impl( return ffi::Error::Success(); } + +ffi::Error tp_double_backward_impl( + ffi::AnyBuffer L1_in, + ffi::AnyBuffer L2_in, + ffi::AnyBuffer W, + ffi::AnyBuffer L3_grad, + ffi::AnyBuffer L1_dgrad, + ffi::AnyBuffer L2_dgrad, + ffi::AnyBuffer W_dgrad, + ffi::Result L1_grad, + ffi::Result L2_grad, + ffi::Result W_grad, + ffi::Result L3_dgrad, + cudaStream_t stream, + std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, + int64_t hash) { + + auto [jit_kernel, k] = compile_kernel_with_caching( + kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, false); + const int64_t num_batch = L1_in.dimensions()[0]; + check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); + check_tensor(L2_in, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_in"); + check_tensor(L3_grad, {num_batch, k.L3_dim}, k.irrep_dtype, "L3_grad"); + check_tensor(L1_dgrad, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_dgrad"); + check_tensor(L2_dgrad, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_dgrad"); + + if (k.shared_weights){ + check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); + check_tensor(W_dgrad, {k.weight_numel}, k.weight_dtype, "W_dgrad"); + } else { + check_tensor(W, {num_batch, k.weight_numel}, k.weight_dtype, "W"); + check_tensor(W_dgrad, {num_batch, k.weight_numel}, k.weight_dtype, "W_dgrad"); + } + + if (k.shared_weights) { + zero_buffer(*W_grad); + } + + jit_kernel->double_backward( + num_batch, + data_ptr(L1_in), + data_ptr(L2_in), + data_ptr(W), + data_ptr(L3_grad), + data_ptr(L1_dgrad), + data_ptr(L2_dgrad), + data_ptr(W_dgrad), + data_ptr(L1_grad), + data_ptr(L2_grad), + data_ptr(W_grad), + data_ptr(L3_dgrad), + stream); + return ffi::Error::Success(); +} + XLA_FFI_DEFINE_HANDLER_SYMBOL( tp_forward, tp_forward_impl, ffi::Ffi::Bind() @@ -311,13 +366,33 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Ctx>() .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") .Attr("hash"), - {xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled + {xla::ffi::Traits::kCmdBufferCompatible}); + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + tp_double_backward, tp_double_backward_impl, + ffi::Ffi::Bind() + .Arg() + .Arg() + .Arg() + .Arg() + .Arg() + .Arg() + .Arg() + .Ret() + .Ret() + .Ret() + .Ret() + .Ctx>() + .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") + .Attr("hash"), + {xla::ffi::Traits::kCmdBufferCompatible}); NB_MODULE(openequivariance_extjax, m) { m.def("registrations", []() { nb::dict registrations; registrations["tp_forward"] = nb::capsule(reinterpret_cast(tp_forward)); registrations["tp_backward"] = nb::capsule(reinterpret_cast(tp_backward)); + registrations["tp_double_backward"] = nb::capsule(reinterpret_cast(tp_double_backward)); return registrations; }); From e78f70552b3f0a44713e21ffc221b7ddced87c19 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 30 Nov 2025 21:20:17 -0800 Subject: [PATCH 029/116] Added the forward convolution implementation. --- .../extension/convolution.hpp | 2 - openequivariance_extjax/src/libjax_tp_jit.cpp | 120 ++++++++++++++++-- 2 files changed, 111 insertions(+), 11 deletions(-) diff --git a/openequivariance/openequivariance/extension/convolution.hpp b/openequivariance/openequivariance/extension/convolution.hpp index 75b4f879..3b2ce1e6 100644 --- a/openequivariance/openequivariance/extension/convolution.hpp +++ b/openequivariance/openequivariance/extension/convolution.hpp @@ -2,8 +2,6 @@ #include #include -#include -#include #include struct ConvData { diff --git a/openequivariance_extjax/src/libjax_tp_jit.cpp b/openequivariance_extjax/src/libjax_tp_jit.cpp index c24fc4de..856ea5cd 100644 --- a/openequivariance_extjax/src/libjax_tp_jit.cpp +++ b/openequivariance_extjax/src/libjax_tp_jit.cpp @@ -27,6 +27,7 @@ namespace ffi = xla::ffi; #endif #include "tensorproducts.hpp" +#include "convolution.hpp" xla::ffi::DataType enum_to_xla_dtype(int64_t i){ switch(i) { @@ -122,7 +123,13 @@ std::unordered_map>, KernelProp - >> kernel_cache; + >> tp_cache; + +std::unordered_map>, + KernelProp + >> conv_cache; std::mutex mut; std::vector launch_config_keys = { @@ -148,7 +155,7 @@ std::unordered_map parse_ffi_dict(ffi::Dictionary &dict, const } std::pair*, KernelProp> - compile_kernel_with_caching(std::string_view kernel, + compile_tp_with_caching(std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, @@ -158,8 +165,8 @@ std::pair*, KernelProp> { const std::lock_guard lock(mut); - auto it = kernel_cache.find(hash); - if (it == kernel_cache.end()) { + auto it = tp_cache.find(hash); + if (it == tp_cache.end()) { auto kernel_prop_map = parse_ffi_dict(kernel_prop, kernel_prop_keys); auto jit_tp_impl = std::make_unique>( std::string(kernel), @@ -167,15 +174,43 @@ std::pair*, KernelProp> parse_ffi_dict(backward_config, launch_config_keys), parse_ffi_dict(double_backward_config, launch_config_keys), kernel_prop_map); - kernel_cache.insert({hash, + tp_cache.insert({hash, std::make_pair(std::move(jit_tp_impl), KernelProp(kernel_prop_map, is_convolution))}); - it = kernel_cache.find(hash); + it = tp_cache.find(hash); } return {it->second.first.get(), it->second.second}; } } +std::pair*, KernelProp> + compile_conv_with_caching(std::string_view kernel, + ffi::Dictionary forward_config, + ffi::Dictionary backward_config, + ffi::Dictionary double_backward_config, + ffi::Dictionary kernel_prop, + int64_t hash, + bool is_convolution) { + + { + const std::lock_guard lock(mut); + auto it = conv_cache.find(hash); + if (it == conv_cache.end()) { + auto kernel_prop_map = parse_ffi_dict(kernel_prop, kernel_prop_keys); + auto jit_conv_impl = std::make_unique>( + std::string(kernel), + parse_ffi_dict(forward_config, launch_config_keys), + parse_ffi_dict(backward_config, launch_config_keys), + parse_ffi_dict(double_backward_config, launch_config_keys), + kernel_prop_map); + conv_cache.insert({hash, + std::make_pair(std::move(jit_conv_impl), + KernelProp(kernel_prop_map, is_convolution))}); + it = conv_cache.find(hash); + } + return {it->second.first.get(), it->second.second}; + } +} inline void check_tensor(const ffi::AnyBuffer &buffer, std::initializer_list expected_shape, @@ -209,6 +244,7 @@ inline void check_tensor(const ffi::AnyBuffer &buffer, } } +// --------------------- Tensor Products -------------------------- ffi::Error tp_forward_impl( ffi::AnyBuffer L1_in, ffi::AnyBuffer L2_in, @@ -218,7 +254,7 @@ ffi::Error tp_forward_impl( std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, int64_t hash) { - auto [jit_kernel, k] = compile_kernel_with_caching( + auto [jit_kernel, k] = compile_tp_with_caching( kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, false); const int64_t num_batch = L1_in.dimensions()[0]; @@ -253,7 +289,7 @@ ffi::Error tp_backward_impl( std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, int64_t hash) { - auto [jit_kernel, k] = compile_kernel_with_caching( + auto [jit_kernel, k] = compile_tp_with_caching( kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, false); const int64_t num_batch = L1_in.dimensions()[0]; check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); @@ -303,7 +339,7 @@ ffi::Error tp_double_backward_impl( std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, int64_t hash) { - auto [jit_kernel, k] = compile_kernel_with_caching( + auto [jit_kernel, k] = compile_tp_with_caching( kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, false); const int64_t num_batch = L1_in.dimensions()[0]; check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); @@ -387,12 +423,78 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Attr("hash"), {xla::ffi::Traits::kCmdBufferCompatible}); +// --------------------- Convolution -------------------------- +ffi::Error conv_forward_impl( + ffi::AnyBuffer L1_in, + ffi::AnyBuffer L2_in, + ffi::AnyBuffer W, + ffi::AnyBuffer rows, + ffi::AnyBuffer cols, + ffi::AnyBuffer workspace, + ffi::AnyBuffer transpose_perm, + ffi::Result L3_out, + cudaStream_t stream, + std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, + int64_t hash) { + + auto [jit_kernel, k] = compile_conv_with_caching( + kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, true); + const int64_t nnz = rows.dimensions()[0]; + const int64_t node_count = L1_in.dimensions()[0]; + + check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, "L1_in"); + check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in"); + check_tensor(workspace, {k.workspace_size}, k.workspace_dtype, "workspace"); + check_tensor(rows, {nnz}, k.idx_dtype, "rows"); + check_tensor(cols, {nnz}, k.idx_dtype, "cols"); + + if (k.deterministic){ + check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm"); + } + + if (k.shared_weights) + check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); + else + check_tensor(W, {nnz, k.weight_numel}, k.weight_dtype, "W"); + + jit_kernel->exec_conv( + data_ptr(L1_in), + data_ptr(L2_in), + data_ptr(W), + data_ptr(L3_out), + data_ptr(rows), + data_ptr(cols), + nnz, node_count, + data_ptr(workspace), + stream); + + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + conv_forward, conv_forward_impl, + ffi::Ffi::Bind() + .Arg() + .Arg() + .Arg() + .Arg() + .Arg() + .Arg() + .Arg() + .Ret() + .Ctx>() + .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") + .Attr("hash"), + {xla::ffi::Traits::kCmdBufferCompatible}); + NB_MODULE(openequivariance_extjax, m) { m.def("registrations", []() { nb::dict registrations; registrations["tp_forward"] = nb::capsule(reinterpret_cast(tp_forward)); registrations["tp_backward"] = nb::capsule(reinterpret_cast(tp_backward)); registrations["tp_double_backward"] = nb::capsule(reinterpret_cast(tp_double_backward)); + + registrations["conv_forward"] = nb::capsule(reinterpret_cast(conv_forward)); return registrations; }); From 8784dd4c8ff5e8d4a87e4b5fd39dc6fa48ae388f Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 30 Nov 2025 21:34:01 -0800 Subject: [PATCH 030/116] Backward convolution implemented. --- openequivariance_extjax/src/libjax_tp_jit.cpp | 77 +++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/openequivariance_extjax/src/libjax_tp_jit.cpp b/openequivariance_extjax/src/libjax_tp_jit.cpp index 856ea5cd..2e6f03d8 100644 --- a/openequivariance_extjax/src/libjax_tp_jit.cpp +++ b/openequivariance_extjax/src/libjax_tp_jit.cpp @@ -471,6 +471,64 @@ ffi::Error conv_forward_impl( return ffi::Error::Success(); } +ffi::Error conv_backward_impl( + ffi::AnyBuffer L1_in, + ffi::AnyBuffer L2_in, + ffi::AnyBuffer W, + ffi::AnyBuffer L3_grad, + ffi::Result L1_grad, + ffi::Result L2_grad, + ffi::Result W_grad, + ffi::AnyBuffer rows, + ffi::AnyBuffer cols, + ffi::AnyBuffer workspace, + ffi::AnyBuffer transpose_perm, + cudaStream_t stream, + std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, + int64_t hash) { + + auto [jit_kernel, k] = compile_conv_with_caching( + kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, true); + const int64_t nnz = rows.dimensions()[0]; + const int64_t node_count = L1_in.dimensions()[0]; + check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, "L1_in"); + check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in"); + check_tensor(L3_grad, {node_count, k.L3_dim}, k.irrep_dtype, "L3_grad"); + check_tensor(workspace, {k.workspace_size}, k.workspace_dtype, "workspace"); + check_tensor(rows, {nnz}, k.idx_dtype, "rows"); + check_tensor(cols, {nnz}, k.idx_dtype, "cols"); + + if (k.deterministic) + check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm"); + + if (k.shared_weights) { + check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); + check_tensor(*W_grad, {k.weight_numel}, k.weight_dtype, "W_grad"); + } + else { + check_tensor(W, {nnz, k.weight_numel}, k.weight_dtype, "W"); + check_tensor(*W_grad, {nnz, k.weight_numel}, k.weight_dtype, "W_grad"); + } + if(k.shared_weights) + zero_buffer(*W_grad); + + jit_kernel->backward( + data_ptr(L1_in), + data_ptr(L1_grad), + data_ptr(L2_in), + data_ptr(L2_grad), + data_ptr(W), + data_ptr(W_grad), + data_ptr(L3_grad), + data_ptr(rows), + data_ptr(cols), + nnz, node_count, + data_ptr(workspace), + data_ptr(transpose_perm), + stream); + return ffi::Error::Success(); +} + XLA_FFI_DEFINE_HANDLER_SYMBOL( conv_forward, conv_forward_impl, ffi::Ffi::Bind() @@ -487,6 +545,25 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Attr("hash"), {xla::ffi::Traits::kCmdBufferCompatible}); +XLA_FFI_DEFINE_HANDLER_SYMBOL( + conv_backward, conv_backward_impl, + ffi::Ffi::Bind() + .Arg() + .Arg() + .Arg() + .Arg() + .Ret() + .Ret() + .Ret() + .Arg() + .Arg() + .Arg() + .Arg() + .Ctx>() + .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") + .Attr("hash"), + {xla::ffi::Traits::kCmdBufferCompatible}); + NB_MODULE(openequivariance_extjax, m) { m.def("registrations", []() { nb::dict registrations; From ac9b3db6b822ac936c87d4ad1b8f63c2dc7533d7 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 30 Nov 2025 21:46:56 -0800 Subject: [PATCH 031/116] Convolution double backward registered. --- openequivariance_extjax/src/libjax_tp_jit.cpp | 92 +++++++++++++++++++ 1 file changed, 92 insertions(+) diff --git a/openequivariance_extjax/src/libjax_tp_jit.cpp b/openequivariance_extjax/src/libjax_tp_jit.cpp index 2e6f03d8..46c717fc 100644 --- a/openequivariance_extjax/src/libjax_tp_jit.cpp +++ b/openequivariance_extjax/src/libjax_tp_jit.cpp @@ -529,6 +529,72 @@ ffi::Error conv_backward_impl( return ffi::Error::Success(); } +ffi::Error conv_double_backward_impl( + ffi::AnyBuffer L1_in, + ffi::AnyBuffer L2_in, + ffi::AnyBuffer W, + ffi::AnyBuffer L3_grad, + ffi::AnyBuffer L1_dgrad, + ffi::AnyBuffer L2_dgrad, + ffi::AnyBuffer W_dgrad, + ffi::Result L1_grad, + ffi::Result L2_grad, + ffi::Result W_grad, + ffi::Result L3_dgrad, + ffi::AnyBuffer rows, + ffi::AnyBuffer cols, + ffi::AnyBuffer workspace, + ffi::AnyBuffer transpose_perm, + cudaStream_t stream, + std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, + int64_t hash) { + + auto [jit_kernel, k] = compile_conv_with_caching( + kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, true); + const int64_t nnz = rows.dimensions()[0]; + const int64_t node_count = L1_in.dimensions()[0]; + check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, "L1_in"); + check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in"); + check_tensor(L3_grad, {node_count, k.L3_dim}, k.irrep_dtype, "L3_grad"); + check_tensor(L1_dgrad, {node_count, k.L1_dim}, k.irrep_dtype, "L1_dgrad"); + check_tensor(L2_dgrad, {nnz, k.L2_dim}, k.irrep_dtype, "L2_dgrad"); + check_tensor(workspace, {k.workspace_size}, k.workspace_dtype, "workspace"); + check_tensor(rows, {nnz}, k.idx_dtype, "rows"); + check_tensor(cols, {nnz}, k.idx_dtype, "cols"); + + if (k.deterministic) + check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm"); + + if (k.shared_weights) { + check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); + check_tensor(W_dgrad, {k.weight_numel}, k.weight_dtype, "W_dgrad"); + } else { + check_tensor(W, {nnz, k.weight_numel}, k.weight_dtype, "W"); + check_tensor(W_dgrad, {nnz, k.weight_numel}, k.weight_dtype, "W_dgrad"); + } + if(k.shared_weights) + zero_buffer(*W_grad); + jit_kernel->double_backward( + data_ptr(L1_in), + data_ptr(L2_in), + data_ptr(W), + data_ptr(L3_grad), + data_ptr(L1_dgrad), + data_ptr(L2_dgrad), + data_ptr(W_dgrad), + data_ptr(L1_grad), + data_ptr(L2_grad), + data_ptr(W_grad), + data_ptr(L3_dgrad), + data_ptr(rows), + data_ptr(cols), + nnz, node_count, + data_ptr(workspace), + data_ptr(transpose_perm), + stream); + return ffi::Error::Success(); +} + XLA_FFI_DEFINE_HANDLER_SYMBOL( conv_forward, conv_forward_impl, ffi::Ffi::Bind() @@ -564,6 +630,30 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Attr("hash"), {xla::ffi::Traits::kCmdBufferCompatible}); +XLA_FFI_DEFINE_HANDLER_SYMBOL( + conv_double_backward, conv_double_backward_impl, + ffi::Ffi::Bind() + .Arg() + .Arg() + .Arg() + .Arg() + .Arg() + .Arg() + .Arg() + .Ret() + .Ret() + .Ret() + .Ret() + .Arg() + .Arg() + .Arg() + .Arg() + .Ctx>() + .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") + .Attr("hash"), + {xla::ffi::Traits::kCmdBufferCompatible}); + +// --------------------- NB Module -------------------------- NB_MODULE(openequivariance_extjax, m) { m.def("registrations", []() { nb::dict registrations; @@ -572,6 +662,8 @@ NB_MODULE(openequivariance_extjax, m) { registrations["tp_double_backward"] = nb::capsule(reinterpret_cast(tp_double_backward)); registrations["conv_forward"] = nb::capsule(reinterpret_cast(conv_forward)); + registrations["conv_backward"] = nb::capsule(reinterpret_cast(conv_backward)); + registrations["conv_double_backward"] = nb::capsule(reinterpret_cast(conv_double_backward)); return registrations; }); From 63ed1c049eaca6493ee6d1bc7fd67073edb9db6d Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 30 Nov 2025 22:19:13 -0800 Subject: [PATCH 032/116] Finished the double backward VJP registration. --- .../impl_jax/TensorProduct.py | 27 ++++++++++++++++--- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/openequivariance/openequivariance/impl_jax/TensorProduct.py b/openequivariance/openequivariance/impl_jax/TensorProduct.py index 54ce517e..e4d8f679 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProduct.py +++ b/openequivariance/openequivariance/impl_jax/TensorProduct.py @@ -25,19 +25,38 @@ def forward(X, Y, W, L3_dim, irrep_dtype, attrs): return forward_call(X, Y, W, **attrs) def forward_with_inputs(X, Y, W, L3_dim, irrep_dtype, attrs): - return forward(X, Y, W, L3_dim, irrep_dtype, attrs), (X, Y, W) + return forward(X, Y, W, L3_dim, irrep_dtype, attrs), (X, Y, W) -def backward(L3_dim, irrep_dtype, attrs, inputs, dZ): +@partial(jax.custom_vjp, nondiff_argnums=(4,5)) +def backward(X, Y, W, dZ, irrep_dtype, attrs): backward_call = jax.ffi.ffi_call("tp_backward", + ( + jax.ShapeDtypeStruct(X.shape, irrep_dtype), + jax.ShapeDtypeStruct(Y.shape, irrep_dtype), + jax.ShapeDtypeStruct(W.shape, irrep_dtype), + )) + + return backward_call(X, Y, W, dZ, **attrs) + +def backward_with_inputs(X, Y, W, dZ, irrep_dtype, attrs): + return backward(X, Y, W, dZ, irrep_dtype, attrs), (X, Y, W, dZ) + +def double_backward(irrep_dtype, attrs, inputs, ddX, ddY, ddW): + double_backward_call = jax.ffi.ffi_call("tp_double_backward", ( jax.ShapeDtypeStruct(inputs[0].shape, irrep_dtype), jax.ShapeDtypeStruct(inputs[1].shape, irrep_dtype), jax.ShapeDtypeStruct(inputs[2].shape, irrep_dtype), + jax.ShapeDtypeStruct(inputs[3].shape, irrep_dtype), )) - return backward_call(*inputs, dZ, **attrs) + return double_backward_call(*inputs, ddX, ddY, ddW, **attrs) + +def backward_autograd(L3_dim, irrep_dtype, attrs, inputs, dZ): + return backward(inputs[0], inputs[1], inputs[2], dZ, irrep_dtype, attrs) -forward.defvjp(forward_with_inputs, backward) +forward.defvjp(forward_with_inputs, backward_autograd) +backward.defvjp(backward_with_inputs, backward_autograd) class TensorProduct(LoopUnrollTP): def __init__(self, config): From 673b5ee8909888a481454fa305048aa21fcd9401 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 30 Nov 2025 23:22:17 -0800 Subject: [PATCH 033/116] Double backward pass seems to work. --- .../impl_jax/TensorProduct.py | 26 ++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/openequivariance/openequivariance/impl_jax/TensorProduct.py b/openequivariance/openequivariance/impl_jax/TensorProduct.py index e4d8f679..9c9ec1e8 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProduct.py +++ b/openequivariance/openequivariance/impl_jax/TensorProduct.py @@ -41,7 +41,7 @@ def backward(X, Y, W, dZ, irrep_dtype, attrs): def backward_with_inputs(X, Y, W, dZ, irrep_dtype, attrs): return backward(X, Y, W, dZ, irrep_dtype, attrs), (X, Y, W, dZ) -def double_backward(irrep_dtype, attrs, inputs, ddX, ddY, ddW): +def double_backward(irrep_dtype, attrs, inputs, derivatives): double_backward_call = jax.ffi.ffi_call("tp_double_backward", ( jax.ShapeDtypeStruct(inputs[0].shape, irrep_dtype), @@ -49,14 +49,13 @@ def double_backward(irrep_dtype, attrs, inputs, ddX, ddY, ddW): jax.ShapeDtypeStruct(inputs[2].shape, irrep_dtype), jax.ShapeDtypeStruct(inputs[3].shape, irrep_dtype), )) - - return double_backward_call(*inputs, ddX, ddY, ddW, **attrs) + return double_backward_call(*inputs, *derivatives, **attrs) def backward_autograd(L3_dim, irrep_dtype, attrs, inputs, dZ): return backward(inputs[0], inputs[1], inputs[2], dZ, irrep_dtype, attrs) forward.defvjp(forward_with_inputs, backward_autograd) -backward.defvjp(backward_with_inputs, backward_autograd) +backward.defvjp(backward_with_inputs, double_backward) class TensorProduct(LoopUnrollTP): def __init__(self, config): @@ -89,17 +88,26 @@ def forward(self, X, Y, W): tensor_product = TensorProduct(problem) batch_size = 1 - # Convert the above to JAX Arrays X = jax.random.uniform(jax.random.PRNGKey(0), (batch_size, X_ir.dim), dtype=jax.numpy.float32) Y = jax.random.uniform(jax.random.PRNGKey(1), (batch_size, Y_ir.dim), dtype=jax.numpy.float32) W = jax.random.uniform(jax.random.PRNGKey(2), (batch_size, tensor_product.weight_numel), dtype=jax.numpy.float32) - Z = tensor_product.forward(X, Y, W) - # Test via jax vjp - + # Test forward jax vjp ctZ = jnp.ones_like(Z) result = jax.vjp(lambda x, y, w: tensor_product.forward(x, y, w), X, Y, W)[1](ctZ) print(result) - print("COMPLETE!") + print("COMPLETED FORWARD PASS!") + + # Test the double backward pass + ddX = jnp.ones_like(X) + ddY = jnp.ones_like(Y) + ddW = jnp.ones_like(W) + result_double_backward = jax.vjp( + lambda x, y, w: jax.vjp(lambda a, b, c: tensor_product.forward(a, b, c), x, y, w)[1](ctZ), + X, Y, W + )[1]((ddX, ddY, ddW)) + + print(result_double_backward) + print("COMPLETED DOUBLE BACKWARD PASS!") \ No newline at end of file From 7b1ce90bf012e791516fc5780ebee43f9d1c00a9 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 1 Dec 2025 00:05:40 -0800 Subject: [PATCH 034/116] Did some extra testing. --- .../impl_jax/TensorProduct.py | 39 ++++++++++++++----- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/openequivariance/openequivariance/impl_jax/TensorProduct.py b/openequivariance/openequivariance/impl_jax/TensorProduct.py index 9c9ec1e8..1265fc8a 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProduct.py +++ b/openequivariance/openequivariance/impl_jax/TensorProduct.py @@ -77,6 +77,12 @@ def __init__(self, config): def forward(self, X, Y, W): return forward(X, Y, W, self.L3_dim, self.config.irrep_dtype, self.attrs) + +def jax_to_torch(x): + import numpy as np + import torch + return torch.tensor(np.asarray(x), requires_grad=True) + if __name__ == "__main__": tp_problem = None X_ir, Y_ir, Z_ir = Irreps("1x2e"), Irreps("1x3e"), Irreps("1x2e") @@ -86,7 +92,7 @@ def forward(self, X, Y, W): shared_weights=False, internal_weights=False) tensor_product = TensorProduct(problem) - batch_size = 1 + batch_size = 100 X = jax.random.uniform(jax.random.PRNGKey(0), (batch_size, X_ir.dim), dtype=jax.numpy.float32) Y = jax.random.uniform(jax.random.PRNGKey(1), (batch_size, Y_ir.dim), dtype=jax.numpy.float32) @@ -94,20 +100,35 @@ def forward(self, X, Y, W): Z = tensor_product.forward(X, Y, W) # Test forward jax vjp - ctZ = jnp.ones_like(Z) + ctZ = jax.random.uniform(jax.random.PRNGKey(3), Z.shape, dtype=jax.numpy.float32) result = jax.vjp(lambda x, y, w: tensor_product.forward(x, y, w), X, Y, W)[1](ctZ) - print(result) print("COMPLETED FORWARD PASS!") - # Test the double backward pass - ddX = jnp.ones_like(X) - ddY = jnp.ones_like(Y) - ddW = jnp.ones_like(W) + ddX = jax.random.uniform(jax.random.PRNGKey(4), X.shape, dtype=jax.numpy.float32) + ddY = jax.random.uniform(jax.random.PRNGKey(5), Y.shape, dtype=jax.numpy.float32) + ddW = jax.random.uniform(jax.random.PRNGKey(6), W.shape, dtype=jax.numpy.float32) + result_double_backward = jax.vjp( lambda x, y, w: jax.vjp(lambda a, b, c: tensor_product.forward(a, b, c), x, y, w)[1](ctZ), X, Y, W )[1]((ddX, ddY, ddW)) - print(result_double_backward) - print("COMPLETED DOUBLE BACKWARD PASS!") \ No newline at end of file + print("COMPLETED DOUBLE BACKWARD PASS!") + + from e3nn import o3 + e3nn_tp = o3.TensorProduct(X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False) + print(jax_to_torch(W).shape) + + X_t = jax_to_torch(X) + Y_t = jax_to_torch(Y) + W_t = jax_to_torch(W) + Z_t = jax_to_torch(Z) + Z_e3nn = e3nn_tp(X_t, Y_t, W_t) + print("E3NN RESULT:", (Z_e3nn - Z_t).norm()) + + Z_e3nn.backward(jax_to_torch(ctZ)) + #^^^ Print the norms of the differences in gradients instead + print("E3NN GRADS NORM:", (jax_to_torch(result[0]) - X_t.grad).norm(), + (jax_to_torch(result[1]) - Y_t.grad).norm(), + (jax_to_torch(result[2]) - W_t.grad).norm()) \ No newline at end of file From 38017b03cc42dfa36a17b06f15055cea05180cc6 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 1 Dec 2025 00:30:58 -0800 Subject: [PATCH 035/116] Reorg of LoopUnrollConv.py --- .../openequivariance/core/ConvolutionBase.py | 12 - .../openequivariance/core/LoopUnrollConv.py | 282 ++---------------- .../impl_torch/TensorProductConv.py | 258 +++++++++++++++- 3 files changed, 271 insertions(+), 281 deletions(-) diff --git a/openequivariance/openequivariance/core/ConvolutionBase.py b/openequivariance/openequivariance/core/ConvolutionBase.py index 35c82307..bde3a91e 100644 --- a/openequivariance/openequivariance/core/ConvolutionBase.py +++ b/openequivariance/openequivariance/core/ConvolutionBase.py @@ -1,6 +1,5 @@ import copy import numpy as np -from openequivariance.impl_torch.extlib import DeviceBuffer from openequivariance.benchmark.random_buffer_utils import ( get_random_buffers_forward_conv, get_random_buffers_backward_conv, @@ -130,17 +129,6 @@ def reorder_weights_to_e3nn(self, weights, has_batch_dim=True): """ return weights - def allocate_workspace(self, size_bytes): - self.workspace_size = size_bytes - if self.torch_op: - self.workspace_buffer = torch.zeros( - size_bytes, dtype=torch.uint8, device="cuda" - ) - else: - self.workspace_buffer = DeviceBuffer(size_bytes) - self.workspace_ptr = self.workspace_buffer.data_ptr() - logger.info(f"Convolution requires {size_bytes // 1000000}MB of workspace.") - @staticmethod def name(): raise NotImplementedError() diff --git a/openequivariance/openequivariance/core/LoopUnrollConv.py b/openequivariance/openequivariance/core/LoopUnrollConv.py index 104230fc..f5aedeb2 100644 --- a/openequivariance/openequivariance/core/LoopUnrollConv.py +++ b/openequivariance/openequivariance/core/LoopUnrollConv.py @@ -6,24 +6,15 @@ SMEMCapacityException, ) -from openequivariance.core.dtype_enum import ( - dtype_to_enum, - enum_to_torch_dtype, -) +from openequivariance.core.dtype_enum import dtype_to_enum from openequivariance.templates.jinja_utils import get_jinja_environment -import openequivariance.impl_torch.extlib as extlib -from openequivariance.impl_torch.extlib import JITConvImpl, postprocess_kernel, DeviceProp - from openequivariance.core.utils import filter_and_analyze_problem -from openequivariance.benchmark.logging_utils import getLogger - -logger = getLogger() - class LoopUnrollConv(ConvolutionBase): def __init__( self, config, + dp, postprocess_kernel, *, idx_dtype: type[np.generic] = np.int64, torch_op: bool = False, @@ -39,7 +30,6 @@ def __init__( env = get_jinja_environment() template = env.get_template("loop_unroll_conv_atomic.cuh") - dp = DeviceProp(0) analysis = filter_and_analyze_problem(config) self.is_uvw = analysis["is_uvw"] @@ -141,10 +131,10 @@ def generate_double_backward_schedule(warps_per_block): self.backward_workspace_offset = None self.double_backwardB_offset = None - workspace_size = 1 + self.workspace_size = 1 if deterministic: destination_index_bytes = 32 # Add extra to account for padding - workspace_size = max( + self.workspace_size = max( ( self.forward_schedule.L3.dim * np.dtype(config.irrep_dtype).itemsize + destination_index_bytes @@ -186,7 +176,19 @@ def generate_double_backward_schedule(warps_per_block): ) self.double_backwardB_offset = (self.double_backwardB_offset + 7) // 8 * 8 - self.allocate_workspace(workspace_size) + self.kernel_prop = { + "L1_dim": self.L1.dim, + "L2_dim": self.L2.dim, + "L3_dim": self.L3.dim, + "weight_numel": self.config.weight_numel, + "workspace_size": self.workspace_size, + "opt_level": 3, + "shared_weights": int(config.shared_weights), + "deterministic": int(self.deterministic), + "irrep_dtype": dtype_to_enum[self.config.irrep_dtype], + "weight_dtype": dtype_to_enum[self.config.weight_dtype], + "idx_dtype": dtype_to_enum[self.idx_dtype], + } self.jit_kernel = template.render( forward_schedule=self.forward_schedule, @@ -199,255 +201,5 @@ def generate_double_backward_schedule(warps_per_block): ) self.jit_kernel = postprocess_kernel(self.jit_kernel) - if self.torch_op and extlib.TORCH_COMPILE: - global torch - import torch - - internal_cls = torch.classes.libtorch_tp_jit.TorchJITConv - else: - internal_cls = JITConvImpl - - logger.info("Starting kernel compiler...") - self.internal = internal_cls( - self.jit_kernel, - vars(self.forward_schedule.launch_config), - vars(self.backward_schedule.launch_config), - vars(self.double_backward_schedule.launch_config), - { - "L1_dim": self.L1.dim, - "L2_dim": self.L2.dim, - "L3_dim": self.L3.dim, - "weight_numel": self.config.weight_numel, - "workspace_size": self.workspace_size, - "opt_level": 3, - "shared_weights": int(config.shared_weights), - "deterministic": int(self.deterministic), - "irrep_dtype": dtype_to_enum[self.config.irrep_dtype], - "weight_dtype": dtype_to_enum[self.config.weight_dtype], - "idx_dtype": dtype_to_enum[self.idx_dtype], - }, - ) - logger.info("Kernel compiled!") - # with open("scratch.txt", "w") as f: # f.write(self.jit_kernel) - - def reorder_weights_from_e3nn(self, weights, has_batch_dim=True): - return self.forward_schedule.reorder_weights_from_e3nn(weights, has_batch_dim) - - def reorder_weights_to_e3nn(self, weights, has_batch_dim=True): - return self.forward_schedule.reorder_weights_to_e3nn(weights, has_batch_dim) - - @staticmethod - def name(): - return "LoopUnrollConv" - - @classmethod - def register_torch_fakes(cls): - global torch - import torch - - @torch._library.register_fake_class("libtorch_tp_jit::TorchJITConv") - class TorchJITConv: - def __init__( - self, - kernel_plaintext: str, - fwd_config: dict[str, int], - bwd_config: dict[str, int], - dbl_bwd_config: dict[str, int], - kernel_dims: dict[str, int], - ) -> None: - ( - self.kernel_plaintext, - self.fwd_config, - self.bwd_config, - self.dbl_bwd_config, - self.kernel_dims, - ) = ( - kernel_plaintext, - fwd_config, - bwd_config, - dbl_bwd_config, - kernel_dims, - ) - - @classmethod - def __obj_unflatten__(cls, flattened_product): - return cls(**dict(flattened_product)) - - def __len__(self): - return 0 - - def __setstate__(self, state): - ( - self.kernel_plaintext, - self.fwd_config, - self.bwd_config, - self.dbl_bwd_config, - self.kernel_dims, - ) = state - - def exec_conv_rawptrs(*args, **kwargs): - pass - - def backward_rawptrs(*args, **kwargs): - pass - - def double_backward_rawptrs(*args, **kwargs): - pass - - def L3_dim_getter(self): - return self.kernel_dims["L3_dim"] - - def irrep_dtype_getter(self): - return self.kernel_dims["irrep_dtype"] - - @torch.library.register_fake("libtorch_tp_jit::jit_conv_forward") - def fake_forward( - jit, L1_in, L2_in, W, rows, cols, workspace_buffer, sender_perm - ): - L3_dim, irrep_dtype = None, None - if hasattr(jit, "wrapped_obj"): - L3_dim = jit.wrapped_obj.kernel_dims["L3_dim"] - irrep_dtype = jit.wrapped_obj.kernel_dims["irrep_dtype"] - else: - L3_dim = jit.L3_dim - irrep_dtype = jit.irrep_dtype - - return torch.empty( - L1_in.shape[0], - L3_dim, - device="cuda", - dtype=enum_to_torch_dtype[irrep_dtype], - ) - - @torch.library.register_fake("libtorch_tp_jit::jit_conv_backward") - def fake_backward( - jit, L1_in, L2_in, W, L3_grad, rows, cols, workspace_buffer, sender_perm - ): - return torch.empty_like(L1_in), torch.empty_like(L2_in), torch.empty_like(W) - - @torch.library.register_fake("libtorch_tp_jit::jit_conv_double_backward") - def fake_double_backward( - jit, - L1_in, - L2_in, - W, - L3_grad, - L1_dgrad, - L2_dgrad, - w_dgrad, - rows, - cols, - workspace_buffer, - transpose_perm=None, - ): - return [ - L1_in.new_empty(*L1_in.shape), - L2_in.new_empty(*L2_in.shape), - W.new_empty(*W.shape), - L3_grad.new_empty(*L3_grad.shape), - ] - - @classmethod - def register_autograd(cls): - backward_op = torch.ops.libtorch_tp_jit.jit_conv_backward - double_backward_op = torch.ops.libtorch_tp_jit.jit_conv_double_backward - - def setup_context(ctx, inputs, output): - ( - ctx.jit, - ctx.L1_in, - ctx.L2_in, - ctx.W, - ctx.rows, - ctx.cols, - ctx.workspace_buffer, - ctx.sender_perm, - ) = inputs - - def backward(ctx, grad_output): - L1_grad, L2_grad, W_grad = backward_op( - ctx.jit, - ctx.L1_in, - ctx.L2_in, - ctx.W, - grad_output, - ctx.rows, - ctx.cols, - ctx.workspace_buffer, - ctx.sender_perm, - ) - return None, L1_grad, L2_grad, W_grad, None, None, None, None - - torch.library.register_autograd( - "libtorch_tp_jit::jit_conv_forward", backward, setup_context=setup_context - ) - - def setup_context_double_backward(ctx, inputs, output): - ( - ctx.jit, - ctx.L1_in, - ctx.L2_in, - ctx.W, - ctx.grad_output, - ctx.rows, - ctx.cols, - ctx.workspace_buffer, - ctx.sender_perm, - ) = inputs - ctx.inputs = inputs - - def double_backward(ctx, E, F, G): - result = double_backward_op( - ctx.jit, - ctx.L1_in, - ctx.L2_in, - ctx.W, - ctx.grad_output, - E, - F, - G, - ctx.rows, - ctx.cols, - ctx.workspace_buffer, - ctx.sender_perm, - ) - return ( - None, - result[0], - result[1], - result[2], - result[3], - None, - None, - None, - None, - ) - - torch.library.register_autograd( - "libtorch_tp_jit::jit_conv_backward", - double_backward, - setup_context=setup_context_double_backward, - ) - - @classmethod - def register_autocast(cls): - global torch - import torch - - torch.library.register_autocast( - "libtorch_tp_jit::jit_conv_forward", "cuda", torch.float32 - ) - torch.library.register_autocast( - "libtorch_tp_jit::jit_conv_backward", "cuda", torch.float32 - ) - torch.library.register_autocast( - "libtorch_tp_jit::jit_conv_double_backward", "cuda", torch.float32 - ) - - -if extlib.TORCH_COMPILE: - LoopUnrollConv.register_torch_fakes() - LoopUnrollConv.register_autograd() - LoopUnrollConv.register_autocast() diff --git a/openequivariance/openequivariance/impl_torch/TensorProductConv.py b/openequivariance/openequivariance/impl_torch/TensorProductConv.py index 3ee71108..a10f8178 100644 --- a/openequivariance/openequivariance/impl_torch/TensorProductConv.py +++ b/openequivariance/openequivariance/impl_torch/TensorProductConv.py @@ -3,7 +3,9 @@ import numpy as np import torch -from openequivariance.impl_torch import extlib +import openequivariance.impl_torch.extlib as extlib +from openequivariance.impl_torch.extlib import JITConvImpl, postprocess_kernel, DeviceProp + from openequivariance.core.ConvolutionBase import ( ConvolutionBase, scatter_add_wrapper, @@ -13,6 +15,8 @@ from openequivariance import TPProblem from openequivariance.core.utils import torch_to_oeq_dtype +from openequivariance.benchmark.logging_utils import getLogger +logger = getLogger() class TensorProductConv(torch.nn.Module, LoopUnrollConv): r""" @@ -58,14 +62,33 @@ def __init__( self._init_class() def _init_class(self): + dp = extlib.DeviceProp(0) LoopUnrollConv.__init__( self, self.input_args["problem"], + dp, postprocess_kernel, idx_dtype=np.int64, torch_op=self.input_args["torch_op"], deterministic=self.input_args["deterministic"], kahan=self.input_args["kahan"], ) + + self.allocate_workspace(self.workspace_size) + if extlib.TORCH_COMPILE: + internal_cls = torch.classes.libtorch_tp_jit.TorchJITConv + else: + internal_cls = JITConvImpl + + logger.info("Starting kernel compiler...") + self.internal = internal_cls( + self.jit_kernel, + vars(self.forward_schedule.launch_config), + vars(self.backward_schedule.launch_config), + vars(self.double_backward_schedule.launch_config), + self.kernel_prop + ) + logger.info("Kernel compiled!") + self.dummy_transpose_perm = torch.zeros(1, dtype=torch.int64, device="cuda") self.weight_numel = self.config.weight_numel self._setup_notorchbind() @@ -151,9 +174,16 @@ def forward( sender_perm, ) - @staticmethod - def name(): - return LoopUnrollConv.name() + def allocate_workspace(self, size_bytes): + self.workspace_size = size_bytes + if self.torch_op: + self.workspace_buffer = torch.zeros( + size_bytes, dtype=torch.uint8, device="cuda" + ) + else: + self.workspace_buffer = extlib.DeviceBuffer(size_bytes) + self.workspace_ptr = self.workspace_buffer.data_ptr() + logger.info(f"Convolution requires {size_bytes // 1000000}MB of workspace.") def _setup_notorchbind(self): @torch.library.custom_op( @@ -379,6 +409,226 @@ def double_backward(ctx, grad_output): double_backward, setup_context=setup_context_double_backward ) + def reorder_weights_from_e3nn(self, weights, has_batch_dim=True): + return self.forward_schedule.reorder_weights_from_e3nn(weights, has_batch_dim) + + def reorder_weights_to_e3nn(self, weights, has_batch_dim=True): + return self.forward_schedule.reorder_weights_to_e3nn(weights, has_batch_dim) + + @staticmethod + def name(): + return "LoopUnrollConv" + + @classmethod + def register_torch_fakes(cls): + global torch + import torch + + @torch._library.register_fake_class("libtorch_tp_jit::TorchJITConv") + class TorchJITConv: + def __init__( + self, + kernel_plaintext: str, + fwd_config: dict[str, int], + bwd_config: dict[str, int], + dbl_bwd_config: dict[str, int], + kernel_dims: dict[str, int], + ) -> None: + ( + self.kernel_plaintext, + self.fwd_config, + self.bwd_config, + self.dbl_bwd_config, + self.kernel_dims, + ) = ( + kernel_plaintext, + fwd_config, + bwd_config, + dbl_bwd_config, + kernel_dims, + ) + + @classmethod + def __obj_unflatten__(cls, flattened_product): + return cls(**dict(flattened_product)) + + def __len__(self): + return 0 + + def __setstate__(self, state): + ( + self.kernel_plaintext, + self.fwd_config, + self.bwd_config, + self.dbl_bwd_config, + self.kernel_dims, + ) = state + + def exec_conv_rawptrs(*args, **kwargs): + pass + + def backward_rawptrs(*args, **kwargs): + pass + + def double_backward_rawptrs(*args, **kwargs): + pass + + def L3_dim_getter(self): + return self.kernel_dims["L3_dim"] + + def irrep_dtype_getter(self): + return self.kernel_dims["irrep_dtype"] + + @torch.library.register_fake("libtorch_tp_jit::jit_conv_forward") + def fake_forward( + jit, L1_in, L2_in, W, rows, cols, workspace_buffer, sender_perm + ): + L3_dim, irrep_dtype = None, None + if hasattr(jit, "wrapped_obj"): + L3_dim = jit.wrapped_obj.kernel_dims["L3_dim"] + irrep_dtype = jit.wrapped_obj.kernel_dims["irrep_dtype"] + else: + L3_dim = jit.L3_dim + irrep_dtype = jit.irrep_dtype + + return torch.empty( + L1_in.shape[0], + L3_dim, + device="cuda", + dtype=enum_to_torch_dtype[irrep_dtype], + ) + + @torch.library.register_fake("libtorch_tp_jit::jit_conv_backward") + def fake_backward( + jit, L1_in, L2_in, W, L3_grad, rows, cols, workspace_buffer, sender_perm + ): + return torch.empty_like(L1_in), torch.empty_like(L2_in), torch.empty_like(W) + + @torch.library.register_fake("libtorch_tp_jit::jit_conv_double_backward") + def fake_double_backward( + jit, + L1_in, + L2_in, + W, + L3_grad, + L1_dgrad, + L2_dgrad, + w_dgrad, + rows, + cols, + workspace_buffer, + transpose_perm=None, + ): + return [ + L1_in.new_empty(*L1_in.shape), + L2_in.new_empty(*L2_in.shape), + W.new_empty(*W.shape), + L3_grad.new_empty(*L3_grad.shape), + ] + + @classmethod + def register_autograd(cls): + backward_op = torch.ops.libtorch_tp_jit.jit_conv_backward + double_backward_op = torch.ops.libtorch_tp_jit.jit_conv_double_backward + + def setup_context(ctx, inputs, output): + ( + ctx.jit, + ctx.L1_in, + ctx.L2_in, + ctx.W, + ctx.rows, + ctx.cols, + ctx.workspace_buffer, + ctx.sender_perm, + ) = inputs + + def backward(ctx, grad_output): + L1_grad, L2_grad, W_grad = backward_op( + ctx.jit, + ctx.L1_in, + ctx.L2_in, + ctx.W, + grad_output, + ctx.rows, + ctx.cols, + ctx.workspace_buffer, + ctx.sender_perm, + ) + return None, L1_grad, L2_grad, W_grad, None, None, None, None + + torch.library.register_autograd( + "libtorch_tp_jit::jit_conv_forward", backward, setup_context=setup_context + ) + + def setup_context_double_backward(ctx, inputs, output): + ( + ctx.jit, + ctx.L1_in, + ctx.L2_in, + ctx.W, + ctx.grad_output, + ctx.rows, + ctx.cols, + ctx.workspace_buffer, + ctx.sender_perm, + ) = inputs + ctx.inputs = inputs + + def double_backward(ctx, E, F, G): + result = double_backward_op( + ctx.jit, + ctx.L1_in, + ctx.L2_in, + ctx.W, + ctx.grad_output, + E, + F, + G, + ctx.rows, + ctx.cols, + ctx.workspace_buffer, + ctx.sender_perm, + ) + return ( + None, + result[0], + result[1], + result[2], + result[3], + None, + None, + None, + None, + ) + + torch.library.register_autograd( + "libtorch_tp_jit::jit_conv_backward", + double_backward, + setup_context=setup_context_double_backward, + ) + + @classmethod + def register_autocast(cls): + global torch + import torch + + torch.library.register_autocast( + "libtorch_tp_jit::jit_conv_forward", "cuda", torch.float32 + ) + torch.library.register_autocast( + "libtorch_tp_jit::jit_conv_backward", "cuda", torch.float32 + ) + torch.library.register_autocast( + "libtorch_tp_jit::jit_conv_double_backward", "cuda", torch.float32 + ) + + +if extlib.TORCH_COMPILE: + TensorProductConv.register_torch_fakes() + TensorProductConv.register_autograd() + TensorProductConv.register_autocast() + # ================================================================== # Reference implementations for benchmarking From 865ca133aefe294694474d7726b4f905417f0138 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 1 Dec 2025 00:36:11 -0800 Subject: [PATCH 036/116] Convolution changed. --- openequivariance/openequivariance/core/ConvolutionBase.py | 1 + 1 file changed, 1 insertion(+) diff --git a/openequivariance/openequivariance/core/ConvolutionBase.py b/openequivariance/openequivariance/core/ConvolutionBase.py index bde3a91e..5f51469a 100644 --- a/openequivariance/openequivariance/core/ConvolutionBase.py +++ b/openequivariance/openequivariance/core/ConvolutionBase.py @@ -9,6 +9,7 @@ from openequivariance.benchmark.correctness_utils import check_similiarity from openequivariance.core.e3nn_lite import wigner_3j from openequivariance.core.utils import benchmark +from openequivariance.impl_torch.extlib import DeviceBuffer logger = getLogger() From b9c9135a7dbfaf43ae201d184a142a632ca17463 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 1 Dec 2025 20:51:44 -0800 Subject: [PATCH 037/116] Finished prototype of TensorProductConv. --- .../openequivariance/core/utils.py | 12 ++++- .../impl_jax/TensorProduct.py | 11 +--- .../impl_jax/TensorProductConv.py | 54 +++++++++++++++++++ .../impl_jax/extlib/__init__.py | 4 ++ 4 files changed, 70 insertions(+), 11 deletions(-) create mode 100644 openequivariance/openequivariance/impl_jax/TensorProductConv.py diff --git a/openequivariance/openequivariance/core/utils.py b/openequivariance/openequivariance/core/utils.py index cf9d038e..677f86d3 100644 --- a/openequivariance/openequivariance/core/utils.py +++ b/openequivariance/openequivariance/core/utils.py @@ -7,9 +7,9 @@ import json import tempfile +import hashlib from openequivariance.impl_torch.extlib import GPUTimer - def sparse_outer_product_work(cg: np.ndarray) -> int: return np.sum(np.max(cg != 0, axis=2)) @@ -170,3 +170,13 @@ def benchmark(func, num_warmup, num_iter, mode="gpu_time", kernel_names=[]): time_millis[i] = kernel_time return time_millis + + +def hash_attributes(attrs): + m = hashlib.sha256() + + for key in sorted(attrs.keys()): + m.update(attrs[key].__repr__().encode("utf-8")) + + hash = int(m.hexdigest()[:16], 16) >> 1 + attrs["hash"] = hash diff --git a/openequivariance/openequivariance/impl_jax/TensorProduct.py b/openequivariance/openequivariance/impl_jax/TensorProduct.py index 1265fc8a..2333ec75 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProduct.py +++ b/openequivariance/openequivariance/impl_jax/TensorProduct.py @@ -7,17 +7,9 @@ import hashlib from openequivariance.core.e3nn_lite import TPProblem, Irreps from openequivariance.core.LoopUnrollTP import LoopUnrollTP +from openequivariance.core.utils import hash_attributes import jax.numpy as jnp -def hash_attributes(attrs): - m = hashlib.sha256() - - for key in sorted(attrs.keys()): - m.update(attrs[key].__repr__().encode("utf-8")) - - hash = int(m.hexdigest()[:16], 16) >> 1 - attrs["hash"] = hash - @partial(jax.custom_vjp, nondiff_argnums=(3,4,5)) def forward(X, Y, W, L3_dim, irrep_dtype, attrs): forward_call = jax.ffi.ffi_call("tp_forward", @@ -84,7 +76,6 @@ def jax_to_torch(x): return torch.tensor(np.asarray(x), requires_grad=True) if __name__ == "__main__": - tp_problem = None X_ir, Y_ir, Z_ir = Irreps("1x2e"), Irreps("1x3e"), Irreps("1x2e") instructions=[(0, 0, 0, "uvu", True)] problem = TPProblem(X_ir, Y_ir, Z_ir, diff --git a/openequivariance/openequivariance/impl_jax/TensorProductConv.py b/openequivariance/openequivariance/impl_jax/TensorProductConv.py new file mode 100644 index 00000000..b64e05b2 --- /dev/null +++ b/openequivariance/openequivariance/impl_jax/TensorProductConv.py @@ -0,0 +1,54 @@ +import numpy as np +from functools import partial +from openequivariance.impl_jax import extlib + +from openequivariance.core.e3nn_lite import TPProblem, Irreps +from openequivariance.core.LoopUnrollConv import LoopUnrollConv +from openequivariance.core.utils import hash_attributes + +import jax +import jax.numpy as jnp + +from openequivariance.benchmark.logging_utils import getLogger +logger = getLogger() + +class TensorProductConv(LoopUnrollConv): + def __init__(self, config, deterministic=False, kahan=False): + dp = extlib.DeviceProp(0) + super().__init__( + self, + config, + dp, extlib.postprocess_kernel, + idx_dtype=np.int64, + torch_op=False, + deterministic=deterministic, + kahan=kahan + ) + + self.attrs = { + "kernel": self.jit_kernel, + "forward_config": vars(self.forward_schedule.launch_config), + "backward_config": vars(self.backward_schedule.launch_config), + "double_backward_config": vars(self.double_backward_schedule.launch_config), + "kernel_prop": self.kernelProp + } + hash_attributes(self.attrs) + + self.weight_numel = config.weight_numel + self.L3_dim = self.config.irreps_out.dim + + self.workspace = jnp.zeros((self.workspace_size,), dtype=jnp.uint8) + logger.info(f"Convolution requires {self.workspace_size // (2 ** 20)}MB of workspace.") + self.dummy_transpose_perm = jnp.zeros((1,), dtype=jnp.int64) + + +if __name__=="__main__": + X_ir, Y_ir, Z_ir = Irreps("1x2e"), Irreps("1x3e"), Irreps("1x2e") + instructions=[(0, 0, 0, "uvu", True)] + problem = TPProblem(X_ir, Y_ir, Z_ir, + instructions, + shared_weights=False, + internal_weights=False) + + conv = TensorProductConv(problem, deterministic=False, kahan=False) + print("COMPLETE!") \ No newline at end of file diff --git a/openequivariance/openequivariance/impl_jax/extlib/__init__.py b/openequivariance/openequivariance/impl_jax/extlib/__init__.py index 9dff4696..29c9283f 100644 --- a/openequivariance/openequivariance/impl_jax/extlib/__init__.py +++ b/openequivariance/openequivariance/impl_jax/extlib/__init__.py @@ -1,6 +1,10 @@ import jax +import hashlib def postprocess_kernel(kernel): + ''' + Only CUDA for now, so no postprocessing. + ''' return kernel import openequivariance_extjax as oeq_extjax From 4ea49dc4b3f45b5b85d96f7187ed49c973be6e5a Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 1 Dec 2025 21:16:17 -0800 Subject: [PATCH 038/116] Added some type annotations. --- .../impl_jax/TensorProduct.py | 4 ++-- .../impl_jax/TensorProductConv.py | 22 ++++++++++++++----- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/openequivariance/openequivariance/impl_jax/TensorProduct.py b/openequivariance/openequivariance/impl_jax/TensorProduct.py index 2333ec75..d1e7759f 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProduct.py +++ b/openequivariance/openequivariance/impl_jax/TensorProduct.py @@ -50,7 +50,7 @@ def backward_autograd(L3_dim, irrep_dtype, attrs, inputs, dZ): backward.defvjp(backward_with_inputs, double_backward) class TensorProduct(LoopUnrollTP): - def __init__(self, config): + def __init__(self, config: TPProblem): dp = extlib.DeviceProp(0) super().__init__(config, dp, extlib.postprocess_kernel, torch_op=False) @@ -66,7 +66,7 @@ def __init__(self, config): self.weight_numel = config.weight_numel self.L3_dim = self.config.irreps_out.dim - def forward(self, X, Y, W): + def forward(self, X: jax.ndarray, Y: jax.ndarray, W: jax.ndarray) -> jax.ndarray: return forward(X, Y, W, self.L3_dim, self.config.irrep_dtype, self.attrs) diff --git a/openequivariance/openequivariance/impl_jax/TensorProductConv.py b/openequivariance/openequivariance/impl_jax/TensorProductConv.py index b64e05b2..1156aedc 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProductConv.py +++ b/openequivariance/openequivariance/impl_jax/TensorProductConv.py @@ -1,5 +1,6 @@ import numpy as np from functools import partial +from typing import Optional from openequivariance.impl_jax import extlib from openequivariance.core.e3nn_lite import TPProblem, Irreps @@ -13,13 +14,12 @@ logger = getLogger() class TensorProductConv(LoopUnrollConv): - def __init__(self, config, deterministic=False, kahan=False): + def __init__(self, config: TPProblem, deterministic: bool = False, kahan: bool = False): dp = extlib.DeviceProp(0) super().__init__( - self, - config, + config, dp, extlib.postprocess_kernel, - idx_dtype=np.int64, + idx_dtype=np.int32, # Note: this is distinct from PyTorch torch_op=False, deterministic=deterministic, kahan=kahan @@ -30,7 +30,7 @@ def __init__(self, config, deterministic=False, kahan=False): "forward_config": vars(self.forward_schedule.launch_config), "backward_config": vars(self.backward_schedule.launch_config), "double_backward_config": vars(self.double_backward_schedule.launch_config), - "kernel_prop": self.kernelProp + "kernel_prop": self.kernel_prop } hash_attributes(self.attrs) @@ -39,8 +39,18 @@ def __init__(self, config, deterministic=False, kahan=False): self.workspace = jnp.zeros((self.workspace_size,), dtype=jnp.uint8) logger.info(f"Convolution requires {self.workspace_size // (2 ** 20)}MB of workspace.") - self.dummy_transpose_perm = jnp.zeros((1,), dtype=jnp.int64) + self.dummy_transpose_perm = jnp.zeros((1,), dtype=jnp.int32) + + def forward( + self, + X: jax.ndarray, + Y: jax.ndarray, + W: jax.ndarray, + rows: jax.ndarray, + cols: jax.ndarray, + sender_perm: Optional[jax.ndarray] = None) -> jax.ndarray: + pass if __name__=="__main__": X_ir, Y_ir, Z_ir = Irreps("1x2e"), Irreps("1x3e"), Irreps("1x2e") From 745e4e0dc824f7c6cde914d16c82ca3a7158ae09 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 1 Dec 2025 21:36:37 -0800 Subject: [PATCH 039/116] Finished the forward call. --- .../impl_jax/TensorProductConv.py | 23 ++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/openequivariance/openequivariance/impl_jax/TensorProductConv.py b/openequivariance/openequivariance/impl_jax/TensorProductConv.py index 1156aedc..7551ba1d 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProductConv.py +++ b/openequivariance/openequivariance/impl_jax/TensorProductConv.py @@ -13,6 +13,15 @@ from openequivariance.benchmark.logging_utils import getLogger logger = getLogger() +@partial(jax.custom_vjp, nondiff_argnums=(3,4,5,6,7,8,9)) +def forward(X, Y, W, rows, cols, sender_perm, workspace, L3_dim, irrep_dtype, attrs): + forward_call = jax.ffi.ffi_call("conv_forward", + jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype)) + return forward_call(X, Y, W, rows, cols, sender_perm, workspace, **attrs) + +def forward_with_inputs(X, Y, W, rows, cols, sender_perm, workspace, L3_dim, irrep_dtype, attrs): + return forward(X, Y, W, rows, cols, sender_perm, workspace, L3_dim, irrep_dtype, attrs), (X, Y, W, rows, cols, sender_perm, workspace) + class TensorProductConv(LoopUnrollConv): def __init__(self, config: TPProblem, deterministic: bool = False, kahan: bool = False): dp = extlib.DeviceProp(0) @@ -50,7 +59,19 @@ def forward( rows: jax.ndarray, cols: jax.ndarray, sender_perm: Optional[jax.ndarray] = None) -> jax.ndarray: - pass + + if self.deterministic: + sender_perm = self.dummy_transpose_perm + else: + assert sender_perm is not None, "Must provide sender_perm for non-deterministic convolutions." + + return forward( + X, Y, W, + rows, cols, sender_perm, + self.workspace, + self.L3_dim, + self.config.irrep_dtype, + self.attrs) if __name__=="__main__": X_ir, Y_ir, Z_ir = Irreps("1x2e"), Irreps("1x3e"), Irreps("1x2e") From 19f284b69887b51dbb00979edf59cd2f912bb450 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 1 Dec 2025 21:43:35 -0800 Subject: [PATCH 040/116] Ready to start JAX support. --- .../openequivariance/impl_jax/TensorProductConv.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/openequivariance/openequivariance/impl_jax/TensorProductConv.py b/openequivariance/openequivariance/impl_jax/TensorProductConv.py index 7551ba1d..6637106e 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProductConv.py +++ b/openequivariance/openequivariance/impl_jax/TensorProductConv.py @@ -82,4 +82,14 @@ def forward( internal_weights=False) conv = TensorProductConv(problem, deterministic=False, kahan=False) + + node_ct, nonzero_ct = 3, 4 + X = jax.random.uniform(jax.random.PRNGKey(0), (node_ct, X_ir.dim), dtype=jax.numpy.float32) + Y = jax.random.uniform(jax.random.PRNGKey(1), (nonzero_ct, Y_ir.dim), dtype=jax.numpy.float32) + W = jax.random.uniform(jax.random.PRNGKey(2), (nonzero_ct, conv.weight_numel), dtype=jax.numpy.float32) + rows = jnp.array([0, 1, 1, 2], dtype=jnp.int32) + cols = jnp.array([1, 0, 2, 1], dtype=jnp.int32) + Z = conv.forward(X, Y, W, rows, cols) + print("Z:", Z) + print("COMPLETE!") \ No newline at end of file From 0d07cd90d5b6687ec48169c87d36fb1e90d390a1 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 1 Dec 2025 22:20:34 -0800 Subject: [PATCH 041/116] More plumbing. --- .../openequivariance/core/LoopUnrollTP.py | 5 +++++ .../impl_jax/TensorProductConv.py | 18 +++++++++--------- openequivariance_extjax/src/libjax_tp_jit.cpp | 9 +++++++-- 3 files changed, 21 insertions(+), 11 deletions(-) diff --git a/openequivariance/openequivariance/core/LoopUnrollTP.py b/openequivariance/openequivariance/core/LoopUnrollTP.py index 82a4641e..607c246a 100644 --- a/openequivariance/openequivariance/core/LoopUnrollTP.py +++ b/openequivariance/openequivariance/core/LoopUnrollTP.py @@ -99,6 +99,11 @@ def generate_double_backward_schedule(warps_per_block): "opt_level": 3, "irrep_dtype": dtype_to_enum[self.config.irrep_dtype], "weight_dtype": dtype_to_enum[self.config.weight_dtype], + + # Not relevant, included for compatibility with convolution + "workspace_size": 0, + "deterministic": 1, + "idx_dtype": 0 } diff --git a/openequivariance/openequivariance/impl_jax/TensorProductConv.py b/openequivariance/openequivariance/impl_jax/TensorProductConv.py index 6637106e..ec1f04a8 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProductConv.py +++ b/openequivariance/openequivariance/impl_jax/TensorProductConv.py @@ -13,7 +13,7 @@ from openequivariance.benchmark.logging_utils import getLogger logger = getLogger() -@partial(jax.custom_vjp, nondiff_argnums=(3,4,5,6,7,8,9)) +#@partial(jax.custom_vjp, nondiff_argnums=(3,4,5,6,7,8,9)) def forward(X, Y, W, rows, cols, sender_perm, workspace, L3_dim, irrep_dtype, attrs): forward_call = jax.ffi.ffi_call("conv_forward", jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype)) @@ -53,17 +53,17 @@ def __init__(self, config: TPProblem, deterministic: bool = False, kahan: bool = def forward( self, - X: jax.ndarray, - Y: jax.ndarray, - W: jax.ndarray, - rows: jax.ndarray, - cols: jax.ndarray, - sender_perm: Optional[jax.ndarray] = None) -> jax.ndarray: + X: jax.numpy.ndarray, + Y: jax.numpy.ndarray, + W: jax.numpy.ndarray, + rows: jax.numpy.ndarray, + cols: jax.numpy.ndarray, + sender_perm: Optional[jax.numpy.ndarray] = None) -> jax.numpy.ndarray: - if self.deterministic: + if not self.deterministic: sender_perm = self.dummy_transpose_perm else: - assert sender_perm is not None, "Must provide sender_perm for non-deterministic convolutions." + assert sender_perm is not None, "Must provide sender_perm for deterministic convolutions." return forward( X, Y, W, diff --git a/openequivariance_extjax/src/libjax_tp_jit.cpp b/openequivariance_extjax/src/libjax_tp_jit.cpp index 46c717fc..df494ccf 100644 --- a/openequivariance_extjax/src/libjax_tp_jit.cpp +++ b/openequivariance_extjax/src/libjax_tp_jit.cpp @@ -144,7 +144,12 @@ std::vector kernel_prop_keys = { "shared_weights", "opt_level", "irrep_dtype", - "weight_dtype"}; + "weight_dtype", + + // Convolution only + "workspace_size", + "deterministic", + "idx_dtype"}; std::unordered_map parse_ffi_dict(ffi::Dictionary &dict, const std::vector &keys) { std::unordered_map result; @@ -240,7 +245,7 @@ inline void check_tensor(const ffi::AnyBuffer &buffer, } if (buffer.element_type() != expected_dtype) { - throw std::logic_error("Datatype mismatch."); + throw std::logic_error("Datatype mismatch for tensor " + tensor_name); } } From ce68f69a33cd2b994b5a69975765b143b87c935d Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 1 Dec 2025 22:38:23 -0800 Subject: [PATCH 042/116] Forward call is working. --- .../impl_jax/TensorProductConv.py | 11 ++++---- openequivariance_extjax/src/libjax_tp_jit.cpp | 28 +++++++------------ 2 files changed, 16 insertions(+), 23 deletions(-) diff --git a/openequivariance/openequivariance/impl_jax/TensorProductConv.py b/openequivariance/openequivariance/impl_jax/TensorProductConv.py index ec1f04a8..9d845839 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProductConv.py +++ b/openequivariance/openequivariance/impl_jax/TensorProductConv.py @@ -14,13 +14,13 @@ logger = getLogger() #@partial(jax.custom_vjp, nondiff_argnums=(3,4,5,6,7,8,9)) -def forward(X, Y, W, rows, cols, sender_perm, workspace, L3_dim, irrep_dtype, attrs): +def forward(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs): forward_call = jax.ffi.ffi_call("conv_forward", jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype)) - return forward_call(X, Y, W, rows, cols, sender_perm, workspace, **attrs) + return forward_call(X, Y, W, rows, cols, workspace, sender_perm, **attrs) -def forward_with_inputs(X, Y, W, rows, cols, sender_perm, workspace, L3_dim, irrep_dtype, attrs): - return forward(X, Y, W, rows, cols, sender_perm, workspace, L3_dim, irrep_dtype, attrs), (X, Y, W, rows, cols, sender_perm, workspace) +def forward_with_inputs(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs): + return forward(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs), (X, Y, W, rows, cols, sender_perm, workspace) class TensorProductConv(LoopUnrollConv): def __init__(self, config: TPProblem, deterministic: bool = False, kahan: bool = False): @@ -67,8 +67,9 @@ def forward( return forward( X, Y, W, - rows, cols, sender_perm, + rows, cols, self.workspace, + sender_perm, self.L3_dim, self.config.irrep_dtype, self.attrs) diff --git a/openequivariance_extjax/src/libjax_tp_jit.cpp b/openequivariance_extjax/src/libjax_tp_jit.cpp index df494ccf..d05ccff8 100644 --- a/openequivariance_extjax/src/libjax_tp_jit.cpp +++ b/openequivariance_extjax/src/libjax_tp_jit.cpp @@ -46,26 +46,20 @@ xla::ffi::DataType enum_to_xla_dtype(int64_t i){ } inline void* data_ptr(ffi::AnyBuffer &buffer) { - switch (buffer.element_type()) { - case xla::ffi::DataType::F32: - return reinterpret_cast(buffer.typed_data()); - case xla::ffi::DataType::F64: - return reinterpret_cast(buffer.typed_data()); - case xla::ffi::DataType::S64: - return reinterpret_cast(buffer.typed_data()); - case xla::ffi::DataType::U8: - return reinterpret_cast(buffer.typed_data()); - default: - throw logic_error("Unsupported tensor datatype!"); - } + return buffer.untyped_data(); +} + +inline void* data_ptr(ffi::Result &buffer) { + return data_ptr(*buffer); } inline int byte_count(ffi::AnyBuffer &buffer) { switch (buffer.element_type()) { + case xla::ffi::DataType::U32: + case xla::ffi::DataType::S32: case xla::ffi::DataType::F32: return 4; case xla::ffi::DataType::F64: - return 8; case xla::ffi::DataType::S64: return 8; case xla::ffi::DataType::U8: @@ -75,10 +69,6 @@ inline int byte_count(ffi::AnyBuffer &buffer) { } } -inline void* data_ptr(ffi::Result &buffer) { - return data_ptr(*buffer); -} - #ifdef CUDA_BACKEND void zero_buffer(ffi::AnyBuffer &buffer) { cudaMemset( @@ -245,7 +235,9 @@ inline void check_tensor(const ffi::AnyBuffer &buffer, } if (buffer.element_type() != expected_dtype) { - throw std::logic_error("Datatype mismatch for tensor " + tensor_name); + throw std::logic_error("Datatype mismatch for tensor " + tensor_name + + ". Expected datatype " + std::to_string(static_cast(expected_dtype)) + + ", got " + std::to_string(static_cast(buffer.element_type()))); } } From d94db28cdda457b8448b1e4eff7afb3c9a0f5e73 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 1 Dec 2025 22:53:27 -0800 Subject: [PATCH 043/116] Registered the VJP rules for backward and double-backward. --- .../impl_jax/TensorProductConv.py | 30 +++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/openequivariance/openequivariance/impl_jax/TensorProductConv.py b/openequivariance/openequivariance/impl_jax/TensorProductConv.py index 9d845839..ae6e754d 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProductConv.py +++ b/openequivariance/openequivariance/impl_jax/TensorProductConv.py @@ -13,7 +13,7 @@ from openequivariance.benchmark.logging_utils import getLogger logger = getLogger() -#@partial(jax.custom_vjp, nondiff_argnums=(3,4,5,6,7,8,9)) +@partial(jax.custom_vjp, nondiff_argnums=(3,4,5,6,7,8,9)) def forward(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs): forward_call = jax.ffi.ffi_call("conv_forward", jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype)) @@ -22,6 +22,33 @@ def forward(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, at def forward_with_inputs(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs): return forward(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs), (X, Y, W, rows, cols, sender_perm, workspace) +@partial(jax.custom_vjp, nondiff_argnums=(4,5,6,7,8,9)) +def backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs): + backward_call = jax.ffi.ffi_call("conv_backward", + (jax.ShapeDtypeStruct(X.shape, irrep_dtype), + jax.ShapeDtypeStruct(Y.shape, irrep_dtype), + jax.ShapeDtypeStruct(W.shape, irrep_dtype))) + return backward_call(X, Y, W, dZ, rows, cols, workspace, sender_perm, **attrs) + +def backward_with_inputs(X, Y, W, dZ, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs): + return backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs), (X, Y, W, dZ, rows, cols, sender_perm, workspace) + +def double_backward(rows, cols, workspace, sender_perm, irrep_dtype, attrs, inputs, derivatives): + double_backward_call = jax.ffi.ffi_call("conv_double_backward", + ( + jax.ShapeDtypeStruct(inputs[0].shape, irrep_dtype), + jax.ShapeDtypeStruct(inputs[1].shape, irrep_dtype), + jax.ShapeDtypeStruct(inputs[2].shape, irrep_dtype), + jax.ShapeDtypeStruct(inputs[3].shape, irrep_dtype), + )) + return double_backward_call(*inputs, *derivatives, rows, cols, workspace, sender_perm, **attrs) + +def backward_autograd(rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs, inputs, dZ): + return backward(inputs[0], inputs[1], inputs[2], dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs) + +forward.defvjp(forward_with_inputs, backward_autograd) +backward.defvjp(backward_with_inputs, double_backward) + class TensorProductConv(LoopUnrollConv): def __init__(self, config: TPProblem, deterministic: bool = False, kahan: bool = False): dp = extlib.DeviceProp(0) @@ -50,7 +77,6 @@ def __init__(self, config: TPProblem, deterministic: bool = False, kahan: bool = logger.info(f"Convolution requires {self.workspace_size // (2 ** 20)}MB of workspace.") self.dummy_transpose_perm = jnp.zeros((1,), dtype=jnp.int32) - def forward( self, X: jax.numpy.ndarray, From 2524f2a7a4a6a8130a8e8d7c51a605f116c0d060 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 1 Dec 2025 23:07:55 -0800 Subject: [PATCH 044/116] Added __call__ functions. --- .../openequivariance/impl_jax/TensorProduct.py | 6 ++++++ .../openequivariance/impl_jax/TensorProductConv.py | 11 ++++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/openequivariance/openequivariance/impl_jax/TensorProduct.py b/openequivariance/openequivariance/impl_jax/TensorProduct.py index d1e7759f..72b9bee3 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProduct.py +++ b/openequivariance/openequivariance/impl_jax/TensorProduct.py @@ -69,6 +69,12 @@ def __init__(self, config: TPProblem): def forward(self, X: jax.ndarray, Y: jax.ndarray, W: jax.ndarray) -> jax.ndarray: return forward(X, Y, W, self.L3_dim, self.config.irrep_dtype, self.attrs) + def __call__(self, + X: jax.numpy.ndarray, + Y: jax.numpy.ndarray, + W: jax.numpy.ndarray) -> jax.numpy.ndarray: + return self.forward(X, Y, W) + def jax_to_torch(x): import numpy as np diff --git a/openequivariance/openequivariance/impl_jax/TensorProductConv.py b/openequivariance/openequivariance/impl_jax/TensorProductConv.py index ae6e754d..a6e8489f 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProductConv.py +++ b/openequivariance/openequivariance/impl_jax/TensorProductConv.py @@ -55,7 +55,7 @@ def __init__(self, config: TPProblem, deterministic: bool = False, kahan: bool = super().__init__( config, dp, extlib.postprocess_kernel, - idx_dtype=np.int32, # Note: this is distinct from PyTorch + idx_dtype=np.int32, # N.B. this is distinct from the PyTorch version torch_op=False, deterministic=deterministic, kahan=kahan @@ -99,6 +99,15 @@ def forward( self.L3_dim, self.config.irrep_dtype, self.attrs) + + def __call__(self, + X: jax.numpy.ndarray, + Y: jax.numpy.ndarray, + W: jax.numpy.ndarray, + rows: jax.numpy.ndarray, + cols: jax.numpy.ndarray, + sender_perm: Optional[jax.numpy.ndarray] = None) -> jax.numpy.ndarray: + return self.forward(X, Y, W, rows, cols, sender_perm) if __name__=="__main__": X_ir, Y_ir, Z_ir = Irreps("1x2e"), Irreps("1x3e"), Irreps("1x2e") From a1b6248046e4a0c25ab11225da769b5e55acec2c Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 1 Dec 2025 23:40:54 -0800 Subject: [PATCH 045/116] Prepping to add tests. --- .../impl_jax/TensorProduct.py | 57 +------------------ .../impl_jax/TensorProductConv.py | 24 +------- 2 files changed, 3 insertions(+), 78 deletions(-) diff --git a/openequivariance/openequivariance/impl_jax/TensorProduct.py b/openequivariance/openequivariance/impl_jax/TensorProduct.py index 72b9bee3..e48461ea 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProduct.py +++ b/openequivariance/openequivariance/impl_jax/TensorProduct.py @@ -73,59 +73,4 @@ def __call__(self, X: jax.numpy.ndarray, Y: jax.numpy.ndarray, W: jax.numpy.ndarray) -> jax.numpy.ndarray: - return self.forward(X, Y, W) - - -def jax_to_torch(x): - import numpy as np - import torch - return torch.tensor(np.asarray(x), requires_grad=True) - -if __name__ == "__main__": - X_ir, Y_ir, Z_ir = Irreps("1x2e"), Irreps("1x3e"), Irreps("1x2e") - instructions=[(0, 0, 0, "uvu", True)] - problem = TPProblem(X_ir, Y_ir, Z_ir, - instructions, - shared_weights=False, - internal_weights=False) - tensor_product = TensorProduct(problem) - batch_size = 100 - - X = jax.random.uniform(jax.random.PRNGKey(0), (batch_size, X_ir.dim), dtype=jax.numpy.float32) - Y = jax.random.uniform(jax.random.PRNGKey(1), (batch_size, Y_ir.dim), dtype=jax.numpy.float32) - W = jax.random.uniform(jax.random.PRNGKey(2), (batch_size, tensor_product.weight_numel), dtype=jax.numpy.float32) - Z = tensor_product.forward(X, Y, W) - - # Test forward jax vjp - ctZ = jax.random.uniform(jax.random.PRNGKey(3), Z.shape, dtype=jax.numpy.float32) - result = jax.vjp(lambda x, y, w: tensor_product.forward(x, y, w), X, Y, W)[1](ctZ) - - print("COMPLETED FORWARD PASS!") - - ddX = jax.random.uniform(jax.random.PRNGKey(4), X.shape, dtype=jax.numpy.float32) - ddY = jax.random.uniform(jax.random.PRNGKey(5), Y.shape, dtype=jax.numpy.float32) - ddW = jax.random.uniform(jax.random.PRNGKey(6), W.shape, dtype=jax.numpy.float32) - - result_double_backward = jax.vjp( - lambda x, y, w: jax.vjp(lambda a, b, c: tensor_product.forward(a, b, c), x, y, w)[1](ctZ), - X, Y, W - )[1]((ddX, ddY, ddW)) - - print("COMPLETED DOUBLE BACKWARD PASS!") - - from e3nn import o3 - e3nn_tp = o3.TensorProduct(X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False) - print(jax_to_torch(W).shape) - - X_t = jax_to_torch(X) - Y_t = jax_to_torch(Y) - W_t = jax_to_torch(W) - Z_t = jax_to_torch(Z) - Z_e3nn = e3nn_tp(X_t, Y_t, W_t) - print("E3NN RESULT:", (Z_e3nn - Z_t).norm()) - - Z_e3nn.backward(jax_to_torch(ctZ)) - #^^^ Print the norms of the differences in gradients instead - print("E3NN GRADS NORM:", (jax_to_torch(result[0]) - X_t.grad).norm(), - (jax_to_torch(result[1]) - Y_t.grad).norm(), - (jax_to_torch(result[2]) - W_t.grad).norm()) \ No newline at end of file + return self.forward(X, Y, W) \ No newline at end of file diff --git a/openequivariance/openequivariance/impl_jax/TensorProductConv.py b/openequivariance/openequivariance/impl_jax/TensorProductConv.py index a6e8489f..7745e1bf 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProductConv.py +++ b/openequivariance/openequivariance/impl_jax/TensorProductConv.py @@ -30,8 +30,8 @@ def backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs jax.ShapeDtypeStruct(W.shape, irrep_dtype))) return backward_call(X, Y, W, dZ, rows, cols, workspace, sender_perm, **attrs) -def backward_with_inputs(X, Y, W, dZ, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs): - return backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs), (X, Y, W, dZ, rows, cols, sender_perm, workspace) +def backward_with_inputs(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs): + return backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs), (X, Y, W, dZ) #rows, cols, sender_perm, workspace) def double_backward(rows, cols, workspace, sender_perm, irrep_dtype, attrs, inputs, derivatives): double_backward_call = jax.ffi.ffi_call("conv_double_backward", @@ -109,23 +109,3 @@ def __call__(self, sender_perm: Optional[jax.numpy.ndarray] = None) -> jax.numpy.ndarray: return self.forward(X, Y, W, rows, cols, sender_perm) -if __name__=="__main__": - X_ir, Y_ir, Z_ir = Irreps("1x2e"), Irreps("1x3e"), Irreps("1x2e") - instructions=[(0, 0, 0, "uvu", True)] - problem = TPProblem(X_ir, Y_ir, Z_ir, - instructions, - shared_weights=False, - internal_weights=False) - - conv = TensorProductConv(problem, deterministic=False, kahan=False) - - node_ct, nonzero_ct = 3, 4 - X = jax.random.uniform(jax.random.PRNGKey(0), (node_ct, X_ir.dim), dtype=jax.numpy.float32) - Y = jax.random.uniform(jax.random.PRNGKey(1), (nonzero_ct, Y_ir.dim), dtype=jax.numpy.float32) - W = jax.random.uniform(jax.random.PRNGKey(2), (nonzero_ct, conv.weight_numel), dtype=jax.numpy.float32) - rows = jnp.array([0, 1, 1, 2], dtype=jnp.int32) - cols = jnp.array([1, 0, 2, 1], dtype=jnp.int32) - Z = conv.forward(X, Y, W, rows, cols) - print("Z:", Z) - - print("COMPLETE!") \ No newline at end of file From 427fdcbeff16351e670e879077e0d9b3c90945ef Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 1 Dec 2025 23:51:53 -0800 Subject: [PATCH 046/116] Ran ruff. --- docs/conf.py | 7 +- openequivariance/openequivariance/__init__.py | 10 +- .../openequivariance/core/ConvolutionBase.py | 1 + .../openequivariance/core/LoopUnrollConv.py | 4 +- .../openequivariance/core/LoopUnrollTP.py | 7 +- .../openequivariance/core/utils.py | 1 + .../impl_jax/TensorProduct.py | 51 +++--- .../impl_jax/TensorProductConv.py | 153 ++++++++++++------ .../impl_jax/extlib/__init__.py | 11 +- .../openequivariance/impl_torch/CUEConv.py | 1 + .../impl_torch/TensorProduct.py | 10 +- .../impl_torch/TensorProductConv.py | 16 +- .../impl_torch/extlib/__init__.py | 9 ++ 13 files changed, 184 insertions(+), 97 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index a360fb6d..57a3f2c4 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -34,5 +34,10 @@ sys.path.insert(0, str(Path("..").resolve())) -autodoc_mock_imports = ["torch", "openequivariance.impl_torch.extlib", "jinja2", "numpy"] +autodoc_mock_imports = [ + "torch", + "openequivariance.impl_torch.extlib", + "jinja2", + "numpy", +] autodoc_typehints = "description" diff --git a/openequivariance/openequivariance/__init__.py b/openequivariance/openequivariance/__init__.py index 2097a33b..f7422572 100644 --- a/openequivariance/openequivariance/__init__.py +++ b/openequivariance/openequivariance/__init__.py @@ -17,6 +17,7 @@ from openequivariance.impl_torch.TensorProductConv import ( TensorProductConv, ) +from openequivariance.impl_torch.extlib import torch_ext_so_path from openequivariance.core.utils import torch_to_oeq_dtype __version__ = None @@ -37,20 +38,13 @@ def _check_package_editable(): _editable_install_output_path = Path(__file__).parent.parent.parent / "outputs" -def torch_ext_so_path(): - """ - :returns: Path to a ``.so`` file that must be linked to use OpenEquivariance - from the PyTorch C++ Interface. - """ - return openequivariance.impl_torch.extlib.torch_module.__file__ - - def extension_source_path(): """ :returns: Path to the source code of the C++ extension. """ return str(Path(__file__).parent / "extension") + torch.serialization.add_safe_globals( [ TensorProduct, diff --git a/openequivariance/openequivariance/core/ConvolutionBase.py b/openequivariance/openequivariance/core/ConvolutionBase.py index 5f51469a..f9b33eeb 100644 --- a/openequivariance/openequivariance/core/ConvolutionBase.py +++ b/openequivariance/openequivariance/core/ConvolutionBase.py @@ -562,6 +562,7 @@ def test_correctness_double_backward( if reference_implementation is None: from openequivariance.impl_torch.E3NNConv import E3NNConv + reference_implementation = E3NNConv reference_problem = self.config diff --git a/openequivariance/openequivariance/core/LoopUnrollConv.py b/openequivariance/openequivariance/core/LoopUnrollConv.py index f5aedeb2..0763d69c 100644 --- a/openequivariance/openequivariance/core/LoopUnrollConv.py +++ b/openequivariance/openequivariance/core/LoopUnrollConv.py @@ -10,11 +10,13 @@ from openequivariance.templates.jinja_utils import get_jinja_environment from openequivariance.core.utils import filter_and_analyze_problem + class LoopUnrollConv(ConvolutionBase): def __init__( self, config, - dp, postprocess_kernel, + dp, + postprocess_kernel, *, idx_dtype: type[np.generic] = np.int64, torch_op: bool = False, diff --git a/openequivariance/openequivariance/core/LoopUnrollTP.py b/openequivariance/openequivariance/core/LoopUnrollTP.py index 607c246a..1705a8dd 100644 --- a/openequivariance/openequivariance/core/LoopUnrollTP.py +++ b/openequivariance/openequivariance/core/LoopUnrollTP.py @@ -10,6 +10,7 @@ count_cg_non_zero, ) + class LoopUnrollTP(TensorProductBase): def __init__(self, config, dp, postprocess_kernel, torch_op): super().__init__(config, torch_op=torch_op) @@ -99,14 +100,12 @@ def generate_double_backward_schedule(warps_per_block): "opt_level": 3, "irrep_dtype": dtype_to_enum[self.config.irrep_dtype], "weight_dtype": dtype_to_enum[self.config.weight_dtype], - - # Not relevant, included for compatibility with convolution + # Not relevant, included for compatibility with convolution "workspace_size": 0, "deterministic": 1, - "idx_dtype": 0 + "idx_dtype": 0, } - def calculate_flops_forward(self, batch_size: int) -> dict: if self.is_uvw: return super().calculate_flops_forward(batch_size) diff --git a/openequivariance/openequivariance/core/utils.py b/openequivariance/openequivariance/core/utils.py index 677f86d3..442ef6c7 100644 --- a/openequivariance/openequivariance/core/utils.py +++ b/openequivariance/openequivariance/core/utils.py @@ -10,6 +10,7 @@ import hashlib from openequivariance.impl_torch.extlib import GPUTimer + def sparse_outer_product_work(cg: np.ndarray) -> int: return np.sum(np.max(cg != 0, axis=2)) diff --git a/openequivariance/openequivariance/impl_jax/TensorProduct.py b/openequivariance/openequivariance/impl_jax/TensorProduct.py index e48461ea..f2dd6c38 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProduct.py +++ b/openequivariance/openequivariance/impl_jax/TensorProduct.py @@ -1,54 +1,62 @@ -import numpy as np - import jax - from functools import partial from openequivariance.impl_jax import extlib -import hashlib -from openequivariance.core.e3nn_lite import TPProblem, Irreps +from openequivariance.core.e3nn_lite import TPProblem from openequivariance.core.LoopUnrollTP import LoopUnrollTP from openequivariance.core.utils import hash_attributes -import jax.numpy as jnp -@partial(jax.custom_vjp, nondiff_argnums=(3,4,5)) + +@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5)) def forward(X, Y, W, L3_dim, irrep_dtype, attrs): - forward_call = jax.ffi.ffi_call("tp_forward", - jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype)) + forward_call = jax.ffi.ffi_call( + "tp_forward", jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype) + ) return forward_call(X, Y, W, **attrs) + def forward_with_inputs(X, Y, W, L3_dim, irrep_dtype, attrs): return forward(X, Y, W, L3_dim, irrep_dtype, attrs), (X, Y, W) -@partial(jax.custom_vjp, nondiff_argnums=(4,5)) + +@partial(jax.custom_vjp, nondiff_argnums=(4, 5)) def backward(X, Y, W, dZ, irrep_dtype, attrs): - backward_call = jax.ffi.ffi_call("tp_backward", + backward_call = jax.ffi.ffi_call( + "tp_backward", ( jax.ShapeDtypeStruct(X.shape, irrep_dtype), jax.ShapeDtypeStruct(Y.shape, irrep_dtype), jax.ShapeDtypeStruct(W.shape, irrep_dtype), - )) + ), + ) return backward_call(X, Y, W, dZ, **attrs) + def backward_with_inputs(X, Y, W, dZ, irrep_dtype, attrs): return backward(X, Y, W, dZ, irrep_dtype, attrs), (X, Y, W, dZ) + def double_backward(irrep_dtype, attrs, inputs, derivatives): - double_backward_call = jax.ffi.ffi_call("tp_double_backward", + double_backward_call = jax.ffi.ffi_call( + "tp_double_backward", ( jax.ShapeDtypeStruct(inputs[0].shape, irrep_dtype), jax.ShapeDtypeStruct(inputs[1].shape, irrep_dtype), jax.ShapeDtypeStruct(inputs[2].shape, irrep_dtype), jax.ShapeDtypeStruct(inputs[3].shape, irrep_dtype), - )) + ), + ) return double_backward_call(*inputs, *derivatives, **attrs) + def backward_autograd(L3_dim, irrep_dtype, attrs, inputs, dZ): - return backward(inputs[0], inputs[1], inputs[2], dZ, irrep_dtype, attrs) + return backward(inputs[0], inputs[1], inputs[2], dZ, irrep_dtype, attrs) + forward.defvjp(forward_with_inputs, backward_autograd) backward.defvjp(backward_with_inputs, double_backward) + class TensorProduct(LoopUnrollTP): def __init__(self, config: TPProblem): dp = extlib.DeviceProp(0) @@ -59,18 +67,17 @@ def __init__(self, config: TPProblem): "forward_config": vars(self.forward_schedule.launch_config), "backward_config": vars(self.backward_schedule.launch_config), "double_backward_config": vars(self.double_backward_schedule.launch_config), - "kernel_prop": self.kernelProp + "kernel_prop": self.kernelProp, } hash_attributes(self.attrs) - + self.weight_numel = config.weight_numel self.L3_dim = self.config.irreps_out.dim def forward(self, X: jax.ndarray, Y: jax.ndarray, W: jax.ndarray) -> jax.ndarray: return forward(X, Y, W, self.L3_dim, self.config.irrep_dtype, self.attrs) - def __call__(self, - X: jax.numpy.ndarray, - Y: jax.numpy.ndarray, - W: jax.numpy.ndarray) -> jax.numpy.ndarray: - return self.forward(X, Y, W) \ No newline at end of file + def __call__( + self, X: jax.numpy.ndarray, Y: jax.numpy.ndarray, W: jax.numpy.ndarray + ) -> jax.numpy.ndarray: + return self.forward(X, Y, W) diff --git a/openequivariance/openequivariance/impl_jax/TensorProductConv.py b/openequivariance/openequivariance/impl_jax/TensorProductConv.py index 7745e1bf..419c5a1b 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProductConv.py +++ b/openequivariance/openequivariance/impl_jax/TensorProductConv.py @@ -3,62 +3,106 @@ from typing import Optional from openequivariance.impl_jax import extlib -from openequivariance.core.e3nn_lite import TPProblem, Irreps -from openequivariance.core.LoopUnrollConv import LoopUnrollConv +from openequivariance.core.e3nn_lite import TPProblem +from openequivariance.core.LoopUnrollConv import LoopUnrollConv from openequivariance.core.utils import hash_attributes import jax import jax.numpy as jnp from openequivariance.benchmark.logging_utils import getLogger + logger = getLogger() -@partial(jax.custom_vjp, nondiff_argnums=(3,4,5,6,7,8,9)) + +@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8, 9)) def forward(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs): - forward_call = jax.ffi.ffi_call("conv_forward", - jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype)) + forward_call = jax.ffi.ffi_call( + "conv_forward", jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype) + ) return forward_call(X, Y, W, rows, cols, workspace, sender_perm, **attrs) -def forward_with_inputs(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs): - return forward(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs), (X, Y, W, rows, cols, sender_perm, workspace) -@partial(jax.custom_vjp, nondiff_argnums=(4,5,6,7,8,9)) +def forward_with_inputs( + X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs +): + return forward( + X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs + ), (X, Y, W, rows, cols, sender_perm, workspace) + + +@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8, 9)) def backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs): - backward_call = jax.ffi.ffi_call("conv_backward", - (jax.ShapeDtypeStruct(X.shape, irrep_dtype), - jax.ShapeDtypeStruct(Y.shape, irrep_dtype), - jax.ShapeDtypeStruct(W.shape, irrep_dtype))) + backward_call = jax.ffi.ffi_call( + "conv_backward", + ( + jax.ShapeDtypeStruct(X.shape, irrep_dtype), + jax.ShapeDtypeStruct(Y.shape, irrep_dtype), + jax.ShapeDtypeStruct(W.shape, irrep_dtype), + ), + ) return backward_call(X, Y, W, dZ, rows, cols, workspace, sender_perm, **attrs) -def backward_with_inputs(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs): - return backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs), (X, Y, W, dZ) #rows, cols, sender_perm, workspace) -def double_backward(rows, cols, workspace, sender_perm, irrep_dtype, attrs, inputs, derivatives): - double_backward_call = jax.ffi.ffi_call("conv_double_backward", +def backward_with_inputs( + X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs +): + return backward( + X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs + ), (X, Y, W, dZ) # rows, cols, sender_perm, workspace) + + +def double_backward( + rows, cols, workspace, sender_perm, irrep_dtype, attrs, inputs, derivatives +): + double_backward_call = jax.ffi.ffi_call( + "conv_double_backward", ( jax.ShapeDtypeStruct(inputs[0].shape, irrep_dtype), jax.ShapeDtypeStruct(inputs[1].shape, irrep_dtype), jax.ShapeDtypeStruct(inputs[2].shape, irrep_dtype), jax.ShapeDtypeStruct(inputs[3].shape, irrep_dtype), - )) - return double_backward_call(*inputs, *derivatives, rows, cols, workspace, sender_perm, **attrs) + ), + ) + return double_backward_call( + *inputs, *derivatives, rows, cols, workspace, sender_perm, **attrs + ) + + +def backward_autograd( + rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs, inputs, dZ +): + return backward( + inputs[0], + inputs[1], + inputs[2], + dZ, + rows, + cols, + workspace, + sender_perm, + irrep_dtype, + attrs, + ) -def backward_autograd(rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs, inputs, dZ): - return backward(inputs[0], inputs[1], inputs[2], dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs) forward.defvjp(forward_with_inputs, backward_autograd) backward.defvjp(backward_with_inputs, double_backward) + class TensorProductConv(LoopUnrollConv): - def __init__(self, config: TPProblem, deterministic: bool = False, kahan: bool = False): + def __init__( + self, config: TPProblem, deterministic: bool = False, kahan: bool = False + ): dp = extlib.DeviceProp(0) super().__init__( config, - dp, extlib.postprocess_kernel, - idx_dtype=np.int32, # N.B. this is distinct from the PyTorch version + dp, + extlib.postprocess_kernel, + idx_dtype=np.int32, # N.B. this is distinct from the PyTorch version torch_op=False, deterministic=deterministic, - kahan=kahan + kahan=kahan, ) self.attrs = { @@ -66,46 +110,55 @@ def __init__(self, config: TPProblem, deterministic: bool = False, kahan: bool = "forward_config": vars(self.forward_schedule.launch_config), "backward_config": vars(self.backward_schedule.launch_config), "double_backward_config": vars(self.double_backward_schedule.launch_config), - "kernel_prop": self.kernel_prop + "kernel_prop": self.kernel_prop, } hash_attributes(self.attrs) - + self.weight_numel = config.weight_numel self.L3_dim = self.config.irreps_out.dim self.workspace = jnp.zeros((self.workspace_size,), dtype=jnp.uint8) - logger.info(f"Convolution requires {self.workspace_size // (2 ** 20)}MB of workspace.") + logger.info( + f"Convolution requires {self.workspace_size // (2**20)}MB of workspace." + ) self.dummy_transpose_perm = jnp.zeros((1,), dtype=jnp.int32) def forward( - self, - X: jax.numpy.ndarray, - Y: jax.numpy.ndarray, - W: jax.numpy.ndarray, - rows: jax.numpy.ndarray, - cols: jax.numpy.ndarray, - sender_perm: Optional[jax.numpy.ndarray] = None) -> jax.numpy.ndarray: - + self, + X: jax.numpy.ndarray, + Y: jax.numpy.ndarray, + W: jax.numpy.ndarray, + rows: jax.numpy.ndarray, + cols: jax.numpy.ndarray, + sender_perm: Optional[jax.numpy.ndarray] = None, + ) -> jax.numpy.ndarray: if not self.deterministic: sender_perm = self.dummy_transpose_perm else: - assert sender_perm is not None, "Must provide sender_perm for deterministic convolutions." + assert sender_perm is not None, ( + "Must provide sender_perm for deterministic convolutions." + ) return forward( - X, Y, W, - rows, cols, + X, + Y, + W, + rows, + cols, self.workspace, sender_perm, - self.L3_dim, - self.config.irrep_dtype, - self.attrs) - - def __call__(self, - X: jax.numpy.ndarray, - Y: jax.numpy.ndarray, - W: jax.numpy.ndarray, - rows: jax.numpy.ndarray, - cols: jax.numpy.ndarray, - sender_perm: Optional[jax.numpy.ndarray] = None) -> jax.numpy.ndarray: - return self.forward(X, Y, W, rows, cols, sender_perm) + self.L3_dim, + self.config.irrep_dtype, + self.attrs, + ) + def __call__( + self, + X: jax.numpy.ndarray, + Y: jax.numpy.ndarray, + W: jax.numpy.ndarray, + rows: jax.numpy.ndarray, + cols: jax.numpy.ndarray, + sender_perm: Optional[jax.numpy.ndarray] = None, + ) -> jax.numpy.ndarray: + return self.forward(X, Y, W, rows, cols, sender_perm) diff --git a/openequivariance/openequivariance/impl_jax/extlib/__init__.py b/openequivariance/openequivariance/impl_jax/extlib/__init__.py index 29c9283f..8719e848 100644 --- a/openequivariance/openequivariance/impl_jax/extlib/__init__.py +++ b/openequivariance/openequivariance/impl_jax/extlib/__init__.py @@ -1,13 +1,14 @@ import jax -import hashlib +import openequivariance_extjax as oeq_extjax + def postprocess_kernel(kernel): - ''' + """ Only CUDA for now, so no postprocessing. - ''' + """ return kernel -import openequivariance_extjax as oeq_extjax + for name, target in oeq_extjax.registrations().items(): jax.ffi.register_ffi_target(name, target, platform="CUDA") @@ -17,4 +18,4 @@ def postprocess_kernel(kernel): __all__ = [ "GPUTimer", "DeviceProp", -] \ No newline at end of file +] diff --git a/openequivariance/openequivariance/impl_torch/CUEConv.py b/openequivariance/openequivariance/impl_torch/CUEConv.py index a8877ab8..00e345f2 100644 --- a/openequivariance/openequivariance/impl_torch/CUEConv.py +++ b/openequivariance/openequivariance/impl_torch/CUEConv.py @@ -8,6 +8,7 @@ scatter_add_wrapper, ) + class CUEConv(ConvolutionBase): def __init__(self, config, *, idx_dtype=np.int64, torch_op=True): super().__init__(config, idx_dtype=idx_dtype, torch_op=torch_op) diff --git a/openequivariance/openequivariance/impl_torch/TensorProduct.py b/openequivariance/openequivariance/impl_torch/TensorProduct.py index ab98dc2b..e4858947 100644 --- a/openequivariance/openequivariance/impl_torch/TensorProduct.py +++ b/openequivariance/openequivariance/impl_torch/TensorProduct.py @@ -8,6 +8,7 @@ logger = getLogger() + class TensorProduct(torch.nn.Module, LoopUnrollTP): r""" Drop-in replacement for ``o3.TensorProduct`` from e3nn. Supports forward, @@ -33,7 +34,11 @@ def __init__(self, problem: TPProblem, torch_op=True, use_opaque=False): def _init_class(self): dp = extlib.DeviceProp(0) LoopUnrollTP.__init__( - self, self.input_args["problem"], dp, extlib.postprocess_kernel, self.input_args["torch_op"] + self, + self.input_args["problem"], + dp, + extlib.postprocess_kernel, + self.input_args["torch_op"], ) internal_cls = None @@ -48,7 +53,7 @@ def _init_class(self): vars(self.forward_schedule.launch_config), vars(self.backward_schedule.launch_config), vars(self.double_backward_schedule.launch_config), - self.kernelProp + self.kernelProp, ) logger.info("Kernel compiled!") logger.info(f"Kernel File Size: {len(self.jit_kernel) // 1024} KB") @@ -339,6 +344,7 @@ def register_autocast(cls): def name(): return "LoopUnrollTP" + if extlib.TORCH_COMPILE: TensorProduct.register_torch_fakes() TensorProduct.register_autograd() diff --git a/openequivariance/openequivariance/impl_torch/TensorProductConv.py b/openequivariance/openequivariance/impl_torch/TensorProductConv.py index a10f8178..c7997b9a 100644 --- a/openequivariance/openequivariance/impl_torch/TensorProductConv.py +++ b/openequivariance/openequivariance/impl_torch/TensorProductConv.py @@ -4,7 +4,11 @@ import torch import openequivariance.impl_torch.extlib as extlib -from openequivariance.impl_torch.extlib import JITConvImpl, postprocess_kernel, DeviceProp +from openequivariance.impl_torch.extlib import ( + JITConvImpl, + postprocess_kernel, + DeviceProp, +) from openequivariance.core.ConvolutionBase import ( ConvolutionBase, @@ -14,10 +18,13 @@ from openequivariance.impl_torch.TensorProduct import TensorProduct from openequivariance import TPProblem from openequivariance.core.utils import torch_to_oeq_dtype +from openequivariance.core.dtype_enum import enum_to_torch_dtype from openequivariance.benchmark.logging_utils import getLogger + logger = getLogger() + class TensorProductConv(torch.nn.Module, LoopUnrollConv): r""" Given a **symmetric, directed** graph :math:`G = (V, E)`, inputs :math:`x_1...x_{|V|}`, @@ -62,11 +69,12 @@ def __init__( self._init_class() def _init_class(self): - dp = extlib.DeviceProp(0) + dp = DeviceProp(0) LoopUnrollConv.__init__( self, self.input_args["problem"], - dp, postprocess_kernel, + dp, + postprocess_kernel, idx_dtype=np.int64, torch_op=self.input_args["torch_op"], deterministic=self.input_args["deterministic"], @@ -85,7 +93,7 @@ def _init_class(self): vars(self.forward_schedule.launch_config), vars(self.backward_schedule.launch_config), vars(self.double_backward_schedule.launch_config), - self.kernel_prop + self.kernel_prop, ) logger.info("Kernel compiled!") diff --git a/openequivariance/openequivariance/impl_torch/extlib/__init__.py b/openequivariance/openequivariance/impl_torch/extlib/__init__.py index 3d8fd085..a96b5b1a 100644 --- a/openequivariance/openequivariance/impl_torch/extlib/__init__.py +++ b/openequivariance/openequivariance/impl_torch/extlib/__init__.py @@ -40,6 +40,7 @@ generic_module = None if not build_ext: import openequivariance.impl_torch.extlib.generic_module + generic_module = openequivariance.impl_torch.extlib.generic_module elif TORCH_VERSION_CUDA_OR_HIP: @@ -140,6 +141,14 @@ def _raise_import_error_helper(import_target: str): ) +def torch_ext_so_path(): + """ + :returns: Path to a ``.so`` file that must be linked to use OpenEquivariance + from the PyTorch C++ Interface. + """ + return torch_module.__file__ + + if TORCH_VERSION_CUDA_OR_HIP: from generic_module import ( JITTPImpl, From f16c62237ec045401b36e0d2df2df0f4c2a5a91e Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Tue, 2 Dec 2025 18:55:56 -0800 Subject: [PATCH 047/116] Moved tests back. --- {openequivariance/tests => tests}/batch_test.py | 0 {openequivariance/tests => tests}/benchmark.py | 0 {openequivariance/tests => tests}/conv_test.py | 0 {openequivariance/tests => tests}/examples_test.py | 0 {openequivariance/tests => tests}/export_test.py | 0 {openequivariance/tests => tests}/import_test.py | 0 {openequivariance/tests => tests}/input_validation_test.py | 0 {openequivariance/tests => tests}/mace_driver.py | 0 {openequivariance/tests => tests}/multidevice_test.py | 0 {openequivariance/tests => tests}/stream_test.py | 0 {openequivariance/tests => tests}/torch_determinism_test.py | 0 11 files changed, 0 insertions(+), 0 deletions(-) rename {openequivariance/tests => tests}/batch_test.py (100%) rename {openequivariance/tests => tests}/benchmark.py (100%) rename {openequivariance/tests => tests}/conv_test.py (100%) rename {openequivariance/tests => tests}/examples_test.py (100%) rename {openequivariance/tests => tests}/export_test.py (100%) rename {openequivariance/tests => tests}/import_test.py (100%) rename {openequivariance/tests => tests}/input_validation_test.py (100%) rename {openequivariance/tests => tests}/mace_driver.py (100%) rename {openequivariance/tests => tests}/multidevice_test.py (100%) rename {openequivariance/tests => tests}/stream_test.py (100%) rename {openequivariance/tests => tests}/torch_determinism_test.py (100%) diff --git a/openequivariance/tests/batch_test.py b/tests/batch_test.py similarity index 100% rename from openequivariance/tests/batch_test.py rename to tests/batch_test.py diff --git a/openequivariance/tests/benchmark.py b/tests/benchmark.py similarity index 100% rename from openequivariance/tests/benchmark.py rename to tests/benchmark.py diff --git a/openequivariance/tests/conv_test.py b/tests/conv_test.py similarity index 100% rename from openequivariance/tests/conv_test.py rename to tests/conv_test.py diff --git a/openequivariance/tests/examples_test.py b/tests/examples_test.py similarity index 100% rename from openequivariance/tests/examples_test.py rename to tests/examples_test.py diff --git a/openequivariance/tests/export_test.py b/tests/export_test.py similarity index 100% rename from openequivariance/tests/export_test.py rename to tests/export_test.py diff --git a/openequivariance/tests/import_test.py b/tests/import_test.py similarity index 100% rename from openequivariance/tests/import_test.py rename to tests/import_test.py diff --git a/openequivariance/tests/input_validation_test.py b/tests/input_validation_test.py similarity index 100% rename from openequivariance/tests/input_validation_test.py rename to tests/input_validation_test.py diff --git a/openequivariance/tests/mace_driver.py b/tests/mace_driver.py similarity index 100% rename from openequivariance/tests/mace_driver.py rename to tests/mace_driver.py diff --git a/openequivariance/tests/multidevice_test.py b/tests/multidevice_test.py similarity index 100% rename from openequivariance/tests/multidevice_test.py rename to tests/multidevice_test.py diff --git a/openequivariance/tests/stream_test.py b/tests/stream_test.py similarity index 100% rename from openequivariance/tests/stream_test.py rename to tests/stream_test.py diff --git a/openequivariance/tests/torch_determinism_test.py b/tests/torch_determinism_test.py similarity index 100% rename from openequivariance/tests/torch_determinism_test.py rename to tests/torch_determinism_test.py From fa4265488bf96bfd5774fcd67d0d6ecfa0d8caa9 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Tue, 2 Dec 2025 21:23:39 -0800 Subject: [PATCH 048/116] 1/3 tests is passing. --- .../openequivariance/impl_jax/TensorProduct.py | 17 ++++++++++++++++- .../openequivariance/impl_jax/__init__.py | 4 ++++ tests/batch_test.py | 15 +++++++++++---- tests/conftest.py | 8 ++++++++ 4 files changed, 39 insertions(+), 5 deletions(-) create mode 100644 openequivariance/openequivariance/impl_jax/__init__.py create mode 100644 tests/conftest.py diff --git a/openequivariance/openequivariance/impl_jax/TensorProduct.py b/openequivariance/openequivariance/impl_jax/TensorProduct.py index f2dd6c38..1a6e601f 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProduct.py +++ b/openequivariance/openequivariance/impl_jax/TensorProduct.py @@ -1,4 +1,5 @@ import jax +import numpy as np from functools import partial from openequivariance.impl_jax import extlib from openequivariance.core.e3nn_lite import TPProblem @@ -74,10 +75,24 @@ def __init__(self, config: TPProblem): self.weight_numel = config.weight_numel self.L3_dim = self.config.irreps_out.dim - def forward(self, X: jax.ndarray, Y: jax.ndarray, W: jax.ndarray) -> jax.ndarray: + def forward(self, X: jax.numpy.ndarray, Y: jax.numpy.ndarray, W: jax.numpy.ndarray) -> jax.numpy.ndarray: return forward(X, Y, W, self.L3_dim, self.config.irrep_dtype, self.attrs) def __call__( self, X: jax.numpy.ndarray, Y: jax.numpy.ndarray, W: jax.numpy.ndarray ) -> jax.numpy.ndarray: return self.forward(X, Y, W) + + def forward_cpu( + self, + L1_in: np.ndarray, + L2_in: np.ndarray, + L3_out: np.ndarray, + weights: np.ndarray, + ) -> None: + result = self.forward( + jax.numpy.asarray(L1_in), + jax.numpy.asarray(L2_in), + jax.numpy.asarray(weights), + ) + L3_out[:] = np.asarray(result) \ No newline at end of file diff --git a/openequivariance/openequivariance/impl_jax/__init__.py b/openequivariance/openequivariance/impl_jax/__init__.py new file mode 100644 index 00000000..b2ec0994 --- /dev/null +++ b/openequivariance/openequivariance/impl_jax/__init__.py @@ -0,0 +1,4 @@ +from openequivariance.impl_jax.TensorProduct import TensorProduct as TensorProduct +from openequivariance.impl_jax.TensorProductConv import TensorProductConv as TensorProductConv + +__all__ = ["TensorProduct", "TensorProductConv"] \ No newline at end of file diff --git a/tests/batch_test.py b/tests/batch_test.py index 9cd032a3..73531e45 100644 --- a/tests/batch_test.py +++ b/tests/batch_test.py @@ -2,8 +2,8 @@ from pytest_check import check import numpy as np +import openequivariance import openequivariance as oeq -from openequivariance.impl_torch.TensorProduct import TensorProduct from openequivariance.benchmark.correctness_utils import ( correctness_forward, correctness_backward, @@ -19,7 +19,6 @@ from itertools import product import torch - class TPCorrectness: def thresh(self, direction): return {"fwd": 1e-5, "bwd": 3e-4, "double_bwd": 3e-4}[direction] @@ -41,8 +40,16 @@ def extra_tp_constructor_args(self): return {} @pytest.fixture(scope="class") - def tp_and_problem(self, problem, extra_tp_constructor_args): - tp = TensorProduct(problem, **extra_tp_constructor_args) + def test_jax(self, request): + return request.config.getoption("--jax") + + @pytest.fixture(scope="class") + def tp_and_problem(self, problem, extra_tp_constructor_args, test_jax): + cls = oeq.TensorProduct + if test_jax: + import openequivariance.impl_jax.TensorProduct as jax_tp + cls = jax_tp + tp = cls(problem, **extra_tp_constructor_args) return tp, problem def test_tp_fwd(self, tp_and_problem): diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..ba3285fb --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,8 @@ +import pytest +import os + +os.environ["JAX_ENABLE_X64"] = "True" +def pytest_addoption(parser): + parser.addoption( + "--jax", action="store", default=False, help="Test the JAX frontend instead of PyTorch" + ) \ No newline at end of file From 7f4ac06d22a1379528517e04fa044dff4cbdd7b5 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Tue, 2 Dec 2025 21:41:56 -0800 Subject: [PATCH 049/116] Backward test is passing. --- .../impl_jax/TensorProduct.py | 29 ++++++++++++++----- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/openequivariance/openequivariance/impl_jax/TensorProduct.py b/openequivariance/openequivariance/impl_jax/TensorProduct.py index 1a6e601f..e49a9793 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProduct.py +++ b/openequivariance/openequivariance/impl_jax/TensorProduct.py @@ -83,16 +83,29 @@ def __call__( ) -> jax.numpy.ndarray: return self.forward(X, Y, W) - def forward_cpu( - self, - L1_in: np.ndarray, - L2_in: np.ndarray, - L3_out: np.ndarray, - weights: np.ndarray, - ) -> None: + def forward_cpu(self, L1_in, L2_in, L3_out, weights) -> None: result = self.forward( jax.numpy.asarray(L1_in), jax.numpy.asarray(L2_in), jax.numpy.asarray(weights), ) - L3_out[:] = np.asarray(result) \ No newline at end of file + L3_out[:] = np.asarray(result) + + def backward_cpu( + self, L1_in, L1_grad, L2_in, L2_grad, L3_grad, weights, weights_grad + ) -> None: + backward_fn = jax.vjp( + lambda X, Y, W: self.forward(X, Y, W), + jax.numpy.asarray(L1_in), + jax.numpy.asarray(L2_in), + jax.numpy.asarray(weights), + )[1] + L1_grad_jax, L2_grad_jax, weights_grad_jax = backward_fn( + jax.numpy.asarray(L3_grad) + ) + L1_grad[:] = np.asarray(L1_grad_jax) + L2_grad[:] = np.asarray(L2_grad_jax) + weights_grad[:] = np.asarray(weights_grad_jax) + + + From 617d99666fee38aeba7cdde3b9900fa536007036 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Tue, 2 Dec 2025 22:45:07 -0800 Subject: [PATCH 050/116] Backward convolution is failing, need to figure out why. --- .../impl_jax/TensorProduct.py | 2 - .../impl_jax/TensorProductConv.py | 49 +++++++++++++++++++ tests/conftest.py | 3 +- tests/conv_test.py | 18 +++++-- 4 files changed, 65 insertions(+), 7 deletions(-) diff --git a/openequivariance/openequivariance/impl_jax/TensorProduct.py b/openequivariance/openequivariance/impl_jax/TensorProduct.py index e49a9793..35e82796 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProduct.py +++ b/openequivariance/openequivariance/impl_jax/TensorProduct.py @@ -107,5 +107,3 @@ def backward_cpu( L2_grad[:] = np.asarray(L2_grad_jax) weights_grad[:] = np.asarray(weights_grad_jax) - - diff --git a/openequivariance/openequivariance/impl_jax/TensorProductConv.py b/openequivariance/openequivariance/impl_jax/TensorProductConv.py index 419c5a1b..855c57b0 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProductConv.py +++ b/openequivariance/openequivariance/impl_jax/TensorProductConv.py @@ -162,3 +162,52 @@ def __call__( sender_perm: Optional[jax.numpy.ndarray] = None, ) -> jax.numpy.ndarray: return self.forward(X, Y, W, rows, cols, sender_perm) + + def forward_cpu(self, L1_in, L2_in, weights, L3_out, graph): + rows = graph.rows.astype(np.int32) + cols = graph.cols.astype(np.int32) + sender_perm = graph.transpose_perm.astype(np.int32) + result = self.forward( + jax.numpy.asarray(L1_in), + jax.numpy.asarray(L2_in), + jax.numpy.asarray(weights), + jax.numpy.asarray(rows), + jax.numpy.asarray(cols), + jax.numpy.asarray(sender_perm), + ) + L3_out[:] = np.asarray(result) + + def backward_cpu( + self, + L1_in, + L1_grad, + L2_in, + L2_grad, + L3_grad, + weights, + weights_grad, + graph, + ): + rows = graph.rows.astype(np.int32) + cols = graph.cols.astype(np.int32) + sender_perm = graph.transpose_perm.astype(np.int32) + + backward_fn = jax.vjp( + lambda X, Y, W: self.forward( + X, + Y, + W, + jax.numpy.asarray(rows), + jax.numpy.asarray(cols), + jax.numpy.asarray(sender_perm), + ), + jax.numpy.asarray(L1_in), + jax.numpy.asarray(L2_in), + jax.numpy.asarray(weights), + )[1] + L1_grad_jax, L2_grad_jax, weights_grad_jax = backward_fn( + jax.numpy.asarray(L3_grad) + ) + L1_grad[:] = np.asarray(L1_grad_jax) + L2_grad[:] = np.asarray(L2_grad_jax) + weights_grad[:] = np.asarray(weights_grad_jax) \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index ba3285fb..d5e9f008 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,7 +2,8 @@ import os os.environ["JAX_ENABLE_X64"] = "True" +os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False" def pytest_addoption(parser): parser.addoption( - "--jax", action="store", default=False, help="Test the JAX frontend instead of PyTorch" + "--jax", action="store_true", default=False, help="Test the JAX frontend instead of PyTorch" ) \ No newline at end of file diff --git a/tests/conv_test.py b/tests/conv_test.py index 14c8c3d2..556609f4 100644 --- a/tests/conv_test.py +++ b/tests/conv_test.py @@ -4,6 +4,7 @@ from pytest_check import check import numpy as np +import openequivariance import openequivariance as oeq from openequivariance.benchmark.ConvBenchmarkSuite import load_graph from itertools import product @@ -51,15 +52,24 @@ def graph(self, request): def extra_conv_constructor_args(self): return {} + @pytest.fixture(scope="class") + def test_jax(self, request): + return request.config.getoption("--jax") + @pytest.fixture(params=["atomic", "deterministic", "kahan"], scope="class") - def conv_object(self, request, problem, extra_conv_constructor_args): + def conv_object(self, request, problem, extra_conv_constructor_args, test_jax): + cls = oeq.TensorProductConv + if test_jax: + from openequivariance.impl_jax import TensorProductConv as jax_conv + cls = jax_conv + if request.param == "atomic": - return oeq.TensorProductConv( + return cls( problem, deterministic=False, **extra_conv_constructor_args ) elif request.param == "deterministic": if not problem.shared_weights: - return oeq.TensorProductConv( + return cls( problem, deterministic=True, **extra_conv_constructor_args ) else: @@ -67,7 +77,7 @@ def conv_object(self, request, problem, extra_conv_constructor_args): elif request.param == "kahan": if problem.irrep_dtype == np.float32: if not problem.shared_weights: - return oeq.TensorProductConv( + return cls( problem, deterministic=True, kahan=True, From 1b0deb062ceadcca127ba5747a8b24008d0d1106 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Wed, 3 Dec 2025 23:08:43 -0800 Subject: [PATCH 051/116] Zerod gradient buffer. --- openequivariance_extjax/src/libjax_tp_jit.cpp | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/openequivariance_extjax/src/libjax_tp_jit.cpp b/openequivariance_extjax/src/libjax_tp_jit.cpp index d05ccff8..5c484ade 100644 --- a/openequivariance_extjax/src/libjax_tp_jit.cpp +++ b/openequivariance_extjax/src/libjax_tp_jit.cpp @@ -448,6 +448,9 @@ ffi::Error conv_forward_impl( if (k.deterministic){ check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm"); } + else { + zero_buffer(*L3_out); + } if (k.shared_weights) check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); @@ -495,9 +498,13 @@ ffi::Error conv_backward_impl( check_tensor(rows, {nnz}, k.idx_dtype, "rows"); check_tensor(cols, {nnz}, k.idx_dtype, "cols"); - if (k.deterministic) + if (k.deterministic) { check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm"); - + } + else { + zero_buffer(*L1_grad); + } + if (k.shared_weights) { check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); check_tensor(*W_grad, {k.weight_numel}, k.weight_dtype, "W_grad"); @@ -559,8 +566,13 @@ ffi::Error conv_double_backward_impl( check_tensor(rows, {nnz}, k.idx_dtype, "rows"); check_tensor(cols, {nnz}, k.idx_dtype, "cols"); - if (k.deterministic) + if (k.deterministic) { check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm"); + } + else { + zero_buffer(*L1_grad); + zero_buffer(*L3_dgrad); + } if (k.shared_weights) { check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); @@ -571,6 +583,7 @@ ffi::Error conv_double_backward_impl( } if(k.shared_weights) zero_buffer(*W_grad); + jit_kernel->double_backward( data_ptr(L1_in), data_ptr(L2_in), From d924503ce5c3318ae9b59de872e353df99927628 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 7 Dec 2025 12:41:59 -0800 Subject: [PATCH 052/116] Abstracted away reordering. --- .../core/ComputationSchedule.py | 92 +++++++------------ .../impl_torch/TensorProduct.py | 6 +- .../impl_torch/TensorProductConv.py | 5 +- .../openequivariance/impl_torch/utils.py | 52 +++++++++++ 4 files changed, 90 insertions(+), 65 deletions(-) create mode 100644 openequivariance/openequivariance/impl_torch/utils.py diff --git a/openequivariance/openequivariance/core/ComputationSchedule.py b/openequivariance/openequivariance/core/ComputationSchedule.py index e52cb7d2..135a0f25 100644 --- a/openequivariance/openequivariance/core/ComputationSchedule.py +++ b/openequivariance/openequivariance/core/ComputationSchedule.py @@ -619,40 +619,33 @@ def calculate_backward_smem( smem=self.memory_per_warp * warps_per_block, ) - def reorder_weights(self, weights_in, direction, has_batch_dim): + def weight_reordering_info(self, weights_in, has_batch_dim): """ - Reorders weights from the canonical e3nn form to the - form that LoopUnrollTP can ingest. Can also reorder the parameters - of a dense neural network layer that produces the weight matrix. - - If has_batch_dim is true, the first dimension of the input weight matrix - is treated as the batch dimension. + Calculates all shapes, slices, and permutation info to reorder + weights. """ - import torch # TODO-someday: no need to specialize this to PyTorch + batch_dim = weights_in.shape[0] + reorder_specs = [] - weights_out = torch.zeros_like(weights_in) - assert direction in ["forward", "backward"] for i, child_inst in enumerate(self.problem_splitter.new_instructions): parent_start, parent_end = ( child_inst.parent_weights_start, child_inst.parent_weights_end, ) parent_shape = list(child_inst.parent_weights_shape) + parent_range = [slice(parent_start, parent_end)] child_start, child_end, child_shape = ( self.updated_config.weight_range_and_shape_for_instruction(i) ) - - parent_range, child_range = ( - [slice(parent_start, parent_end)], - [slice(child_start, child_end)], - ) + child_range = [slice(child_start, child_end)] + weights_subrange = child_inst.weights_subrange - batch_dim = weights_in.shape[0] + reshape_size = [-1] transpose_perm = None - connection_mode = self.updated_config.instructions[i].connection_mode + if connection_mode == "uvu": transpose_perm = [1, 0] elif connection_mode == "uvw": @@ -662,50 +655,27 @@ def reorder_weights(self, weights_in, direction, has_batch_dim): child_range = [slice(0, batch_dim)] + child_range parent_range = [slice(0, batch_dim)] + parent_range parent_shape = [batch_dim] + parent_shape + child_shape = [batch_dim] + list(child_shape) weights_subrange = [slice(0, batch_dim)] + child_inst.weights_subrange reshape_size = [batch_dim] + reshape_size - transpose_perm = [0] + [i + 1 for i in transpose_perm] - - if direction == "forward": - sliced_weights = weights_in[tuple(parent_range)].reshape(parent_shape)[ - tuple(weights_subrange) - ] - weights_out[tuple(child_range)] = sliced_weights.permute( - transpose_perm - ).reshape(reshape_size) - elif direction == "backward": - transpose_child_shape = [child_shape[i] for i in transpose_perm] - sliced_weights = ( - weights_in[tuple(child_range)] - .reshape(transpose_child_shape) - .permute(transpose_perm) - ) - weights_out[tuple(parent_range)].reshape(parent_shape)[ - tuple(weights_subrange) - ] = sliced_weights.flatten().reshape(child_shape) - - return weights_out - - def reorder_weights_numpy(self, weights_in, direction, has_batch_dim): - import torch - - weights_in = torch.from_numpy(weights_in.copy()) - result = self.reorder_weights(weights_in, direction, has_batch_dim) - return result.detach().cpu().numpy().copy() - - def reorder_weights_from_e3nn(self, weights_in, has_batch_dim): - import torch - - if isinstance(weights_in, np.ndarray): - return self.reorder_weights_numpy(weights_in, "forward", has_batch_dim) - elif isinstance(weights_in, torch.Tensor): - return self.reorder_weights(weights_in, "forward", has_batch_dim) - - def reorder_weights_to_e3nn(self, weights_in, has_batch_dim): - import torch - - if isinstance(weights_in, np.ndarray): - return self.reorder_weights_numpy(weights_in, "backward", has_batch_dim) - elif isinstance(weights_in, torch.Tensor): - return self.reorder_weights(weights_in, "backward", has_batch_dim) + + if transpose_perm is not None: + transpose_perm = [0] + [k + 1 for k in transpose_perm] + + transpose_child_shape = None + if transpose_perm is not None: + transpose_child_shape = [child_shape[k] for k in transpose_perm] + + reorder_specs.append({ + "parent_range": tuple(parent_range), + "parent_shape": parent_shape, + "weights_subrange": tuple(weights_subrange), + "child_range": tuple(child_range), + "child_shape": child_shape, + "transpose_perm": transpose_perm, + "reshape_size": reshape_size, + "transpose_child_shape": transpose_child_shape, + }) + + return reorder_specs diff --git a/openequivariance/openequivariance/impl_torch/TensorProduct.py b/openequivariance/openequivariance/impl_torch/TensorProduct.py index e4858947..c0f3a19b 100644 --- a/openequivariance/openequivariance/impl_torch/TensorProduct.py +++ b/openequivariance/openequivariance/impl_torch/TensorProduct.py @@ -5,6 +5,8 @@ import typing from openequivariance.core.utils import torch_to_oeq_dtype from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.impl_torch.utils import reorder_torch + logger = getLogger() @@ -90,10 +92,10 @@ def __setstate__(self, state): self._init_class() def reorder_weights_from_e3nn(self, weights, has_batch_dim=True): - return self.forward_schedule.reorder_weights_from_e3nn(weights, has_batch_dim) + return reorder_torch(self.forward_schedule, weights, "forward", not self.config.shared_weights) def reorder_weights_to_e3nn(self, weights, has_batch_dim=True): - return self.forward_schedule.reorder_weights_to_e3nn(weights, has_batch_dim) + return reorder_torch(self.forward_schedule, weights, "backward", not self.config.shared_weights) def forward( self, x: torch.Tensor, y: torch.Tensor, W: torch.Tensor diff --git a/openequivariance/openequivariance/impl_torch/TensorProductConv.py b/openequivariance/openequivariance/impl_torch/TensorProductConv.py index c7997b9a..44e1775f 100644 --- a/openequivariance/openequivariance/impl_torch/TensorProductConv.py +++ b/openequivariance/openequivariance/impl_torch/TensorProductConv.py @@ -19,6 +19,7 @@ from openequivariance import TPProblem from openequivariance.core.utils import torch_to_oeq_dtype from openequivariance.core.dtype_enum import enum_to_torch_dtype +from openequivariance.impl_torch.utils import reorder_torch from openequivariance.benchmark.logging_utils import getLogger @@ -418,10 +419,10 @@ def double_backward(ctx, grad_output): ) def reorder_weights_from_e3nn(self, weights, has_batch_dim=True): - return self.forward_schedule.reorder_weights_from_e3nn(weights, has_batch_dim) + return reorder_torch(self.forward_schedule, weights, "forward", not self.config.shared_weights) def reorder_weights_to_e3nn(self, weights, has_batch_dim=True): - return self.forward_schedule.reorder_weights_to_e3nn(weights, has_batch_dim) + return reorder_torch(self.forward_schedule, weights, "backward", not self.config.shared_weights) @staticmethod def name(): diff --git a/openequivariance/openequivariance/impl_torch/utils.py b/openequivariance/openequivariance/impl_torch/utils.py new file mode 100644 index 00000000..8911b27c --- /dev/null +++ b/openequivariance/openequivariance/impl_torch/utils.py @@ -0,0 +1,52 @@ +import torch + +def reorder_helper(schedule, weights_in, direction, has_batch_dim): + assert direction in ["forward", "backward"] + + specs = schedule.weight_reordering_info(weights_in, has_batch_dim) + weights_out = torch.zeros_like(weights_in) + + for spec in specs: + parent_range = spec["parent_range"] + parent_shape = spec["parent_shape"] + weights_subrange = spec["weights_subrange"] + child_range = spec["child_range"] + transpose_perm = spec["transpose_perm"] + + if direction == "forward": + reshape_size = spec["reshape_size"] + + sliced_weights = weights_in[parent_range].reshape(parent_shape)[ + weights_subrange + ] + + weights_out[child_range] = sliced_weights.permute( + transpose_perm + ).reshape(reshape_size) + + elif direction == "backward": + transpose_child_shape = spec["transpose_child_shape"] + child_shape = spec["child_shape"] + + sliced_weights = ( + weights_in[child_range] + .reshape(transpose_child_shape) + .permute(transpose_perm) + ) + + weights_out[parent_range].reshape(parent_shape)[ + weights_subrange + ] = sliced_weights.flatten().reshape(child_shape) + + return weights_out + +def reorder_numpy_helper(schedule, weights_in, direction, has_batch_dim): + weights_in = torch.from_numpy(weights_in.copy()) + result = reorder_helper(schedule, weights_in, direction, has_batch_dim) + return result.detach().cpu().numpy().copy() + +def reorder_torch(schedule, weights_in, direction, has_batch_dim): + if isinstance(weights_in, torch.Tensor): + return reorder_helper(schedule, weights_in, direction, has_batch_dim) + else: + return reorder_numpy_helper(schedule, weights_in, direction, has_batch_dim) From 6452140f05198a018302524aef130c71a1525774 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 7 Dec 2025 13:05:38 -0800 Subject: [PATCH 053/116] Added JAX reordering function. --- .../impl_jax/TensorProduct.py | 13 +++-- .../impl_jax/TensorProductConv.py | 7 +++ .../openequivariance/impl_jax/utils.py | 56 +++++++++++++++++++ 3 files changed, 72 insertions(+), 4 deletions(-) create mode 100644 openequivariance/openequivariance/impl_jax/utils.py diff --git a/openequivariance/openequivariance/impl_jax/TensorProduct.py b/openequivariance/openequivariance/impl_jax/TensorProduct.py index 35e82796..92d4fc07 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProduct.py +++ b/openequivariance/openequivariance/impl_jax/TensorProduct.py @@ -5,7 +5,7 @@ from openequivariance.core.e3nn_lite import TPProblem from openequivariance.core.LoopUnrollTP import LoopUnrollTP from openequivariance.core.utils import hash_attributes - +from openequivariance.impl_jax.utils import reorder_jax @partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5)) def forward(X, Y, W, L3_dim, irrep_dtype, attrs): @@ -82,7 +82,13 @@ def __call__( self, X: jax.numpy.ndarray, Y: jax.numpy.ndarray, W: jax.numpy.ndarray ) -> jax.numpy.ndarray: return self.forward(X, Y, W) - + + def reorder_weights_from_e3nn(self, weights, has_batch_dim=True): + return reorder_jax(self.forward_schedule, weights, "forward", not self.config.shared_weights) + + def reorder_weights_to_e3nn(self, weights, has_batch_dim=True): + return reorder_jax(self.forward_schedule, weights, "backward", not self.config.shared_weights) + def forward_cpu(self, L1_in, L2_in, L3_out, weights) -> None: result = self.forward( jax.numpy.asarray(L1_in), @@ -105,5 +111,4 @@ def backward_cpu( ) L1_grad[:] = np.asarray(L1_grad_jax) L2_grad[:] = np.asarray(L2_grad_jax) - weights_grad[:] = np.asarray(weights_grad_jax) - + weights_grad[:] = np.asarray(weights_grad_jax) \ No newline at end of file diff --git a/openequivariance/openequivariance/impl_jax/TensorProductConv.py b/openequivariance/openequivariance/impl_jax/TensorProductConv.py index 855c57b0..62e3385a 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProductConv.py +++ b/openequivariance/openequivariance/impl_jax/TensorProductConv.py @@ -6,6 +6,7 @@ from openequivariance.core.e3nn_lite import TPProblem from openequivariance.core.LoopUnrollConv import LoopUnrollConv from openequivariance.core.utils import hash_attributes +from openequivariance.impl_jax.utils import reorder_jax import jax import jax.numpy as jnp @@ -163,6 +164,12 @@ def __call__( ) -> jax.numpy.ndarray: return self.forward(X, Y, W, rows, cols, sender_perm) + def reorder_weights_from_e3nn(self, weights, has_batch_dim=True): + return reorder_jax(self.forward_schedule, weights, "forward", not self.config.shared_weights) + + def reorder_weights_to_e3nn(self, weights, has_batch_dim=True): + return reorder_jax(self.forward_schedule, weights, "backward", not self.config.shared_weights) + def forward_cpu(self, L1_in, L2_in, weights, L3_out, graph): rows = graph.rows.astype(np.int32) cols = graph.cols.astype(np.int32) diff --git a/openequivariance/openequivariance/impl_jax/utils.py b/openequivariance/openequivariance/impl_jax/utils.py new file mode 100644 index 00000000..14cc8394 --- /dev/null +++ b/openequivariance/openequivariance/impl_jax/utils.py @@ -0,0 +1,56 @@ +import jax +import jax.numpy as jnp +import numpy as np + +def reorder_jax_helper(schedule, weights_in, direction, has_batch_dim): + assert direction in ["forward", "backward"] + + specs = schedule.weight_reordering_info(weights_in, has_batch_dim) + weights_out = jnp.zeros_like(weights_in) + + for spec in specs: + parent_range = spec["parent_range"] + parent_shape = spec["parent_shape"] + weights_subrange = spec["weights_subrange"] + child_range = spec["child_range"] + transpose_perm = spec["transpose_perm"] + + if direction == "forward": + reshape_size = spec["reshape_size"] + + sliced_weights = weights_in[parent_range].reshape(parent_shape)[ + weights_subrange + ] + + value_to_assign = sliced_weights.transpose(transpose_perm).reshape(reshape_size) + weights_out = weights_out.at[child_range].set(value_to_assign) + + elif direction == "backward": + transpose_child_shape = spec["transpose_child_shape"] + child_shape = spec["child_shape"] + + sliced_weights = ( + weights_in[child_range] + .reshape(transpose_child_shape) + .transpose(transpose_perm) + ) + + value_to_insert = sliced_weights.flatten().reshape(child_shape) + + slab = weights_out[parent_range] + slab_reshaped = slab.reshape(parent_shape) + slab_reshaped = slab_reshaped.at[weights_subrange].set(value_to_insert) + weights_out = weights_out.at[parent_range].set(slab_reshaped.reshape(slab.shape)) + + return weights_out + +def reorder_numpy_jax_helper(schedule, weights_in, direction, has_batch_dim): + weights_in_jax = jnp.array(weights_in) + result = reorder_jax_helper(schedule, weights_in_jax, direction, has_batch_dim) + return np.array(result) + +def reorder_jax(schedule, weights_in, direction, has_batch_dim): + if isinstance(weights_in, (jnp.ndarray, jax.Array)): + return reorder_jax_helper(schedule, weights_in, direction, has_batch_dim) + else: + return reorder_numpy_jax_helper(schedule, weights_in, direction, has_batch_dim) \ No newline at end of file From 64c5c563ceeac3cb1ef5265ac2fc7dbfc979f39f Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 7 Dec 2025 13:13:20 -0800 Subject: [PATCH 054/116] Reordering starting to work... --- .../openequivariance/impl_jax/TensorProductConv.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/openequivariance/openequivariance/impl_jax/TensorProductConv.py b/openequivariance/openequivariance/impl_jax/TensorProductConv.py index 62e3385a..c1640548 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProductConv.py +++ b/openequivariance/openequivariance/impl_jax/TensorProductConv.py @@ -174,6 +174,7 @@ def forward_cpu(self, L1_in, L2_in, weights, L3_out, graph): rows = graph.rows.astype(np.int32) cols = graph.cols.astype(np.int32) sender_perm = graph.transpose_perm.astype(np.int32) + weights = self.reorder_weights_from_e3nn(weights, has_batch_dim=not self.config.shared_weights) result = self.forward( jax.numpy.asarray(L1_in), jax.numpy.asarray(L2_in), @@ -198,6 +199,7 @@ def backward_cpu( rows = graph.rows.astype(np.int32) cols = graph.cols.astype(np.int32) sender_perm = graph.transpose_perm.astype(np.int32) + weights = self.reorder_weights_from_e3nn(weights, has_batch_dim=not self.config.shared_weights) backward_fn = jax.vjp( lambda X, Y, W: self.forward( @@ -217,4 +219,5 @@ def backward_cpu( ) L1_grad[:] = np.asarray(L1_grad_jax) L2_grad[:] = np.asarray(L2_grad_jax) - weights_grad[:] = np.asarray(weights_grad_jax) \ No newline at end of file + weights_grad[:] = np.asarray(weights_grad_jax) + weights_grad[:] = self.reorder_weights_to_e3nn(weights_grad, has_batch_dim=not self.config.shared_weights) \ No newline at end of file From d815424cdfc3f8360cb2e78394b208927adcd68f Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 7 Dec 2025 13:59:38 -0800 Subject: [PATCH 055/116] Forward and backward are working. --- tests/conv_test.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/conv_test.py b/tests/conv_test.py index 556609f4..a325c8d2 100644 --- a/tests/conv_test.py +++ b/tests/conv_test.py @@ -255,7 +255,9 @@ def conv_object(self, request, problem): class TestTorchbindDisable(TestProductionModels): @pytest.fixture(scope="class") - def extra_conv_constructor_args(self): + def extra_conv_constructor_args(self, test_jax): + if test_jax: + pytest.skip("N/A for JAX") return {"use_opaque": True} @@ -263,7 +265,10 @@ class TestTorchTo(ConvCorrectness): problems = [mace_problems()[0]] @pytest.fixture(params=problems, ids=lambda x: x.label, scope="class") - def problem(self, request, dtype): + def problem(self, request, dtype, test_jax): + if test_jax: + pytest.skip("N/A for JAX") + problem = request.param problem.irrep_dtype, problem.weight_dtype = dtype, dtype return problem From c3f83ea2c9ca842a7fb720360468524b385a6a40 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 7 Dec 2025 14:31:09 -0800 Subject: [PATCH 056/116] Batch test is working. --- .../impl_jax/TensorProduct.py | 5 +++- tests/batch_test.py | 23 +++++++++++-------- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/openequivariance/openequivariance/impl_jax/TensorProduct.py b/openequivariance/openequivariance/impl_jax/TensorProduct.py index 92d4fc07..e3a03505 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProduct.py +++ b/openequivariance/openequivariance/impl_jax/TensorProduct.py @@ -90,6 +90,7 @@ def reorder_weights_to_e3nn(self, weights, has_batch_dim=True): return reorder_jax(self.forward_schedule, weights, "backward", not self.config.shared_weights) def forward_cpu(self, L1_in, L2_in, L3_out, weights) -> None: + weights = self.reorder_weights_from_e3nn(weights, has_batch_dim=not self.config.shared_weights) result = self.forward( jax.numpy.asarray(L1_in), jax.numpy.asarray(L2_in), @@ -100,6 +101,7 @@ def forward_cpu(self, L1_in, L2_in, L3_out, weights) -> None: def backward_cpu( self, L1_in, L1_grad, L2_in, L2_grad, L3_grad, weights, weights_grad ) -> None: + weights = self.reorder_weights_from_e3nn(weights, has_batch_dim=not self.config.shared_weights) backward_fn = jax.vjp( lambda X, Y, W: self.forward(X, Y, W), jax.numpy.asarray(L1_in), @@ -111,4 +113,5 @@ def backward_cpu( ) L1_grad[:] = np.asarray(L1_grad_jax) L2_grad[:] = np.asarray(L2_grad_jax) - weights_grad[:] = np.asarray(weights_grad_jax) \ No newline at end of file + weights_grad[:] = np.asarray(weights_grad_jax) + weights_grad[:] = self.reorder_weights_to_e3nn(weights_grad, has_batch_dim=not self.config.shared_weights) \ No newline at end of file diff --git a/tests/batch_test.py b/tests/batch_test.py index 73531e45..6f61fe6f 100644 --- a/tests/batch_test.py +++ b/tests/batch_test.py @@ -254,7 +254,9 @@ def problem(self, request, dtype): class TestTorchbindDisable(TestProductionModels): @pytest.fixture(scope="class") - def extra_tp_constructor_args(self): + def extra_tp_constructor_args(self, test_jax): + if test_jax: + pytest.skip("N/A for JAX") return {"use_opaque": True} @@ -268,11 +270,14 @@ def problem(self, request, dtype): return problem @pytest.fixture(scope="class") - def tp_and_problem(self, problem, extra_tp_constructor_args): - tp = TensorProduct(problem, **extra_tp_constructor_args) - switch_map = { - np.float32: torch.float64, - np.float64: torch.float32, - } - tp.to(switch_map[problem.irrep_dtype]) - return tp, tp.config + def tp_and_problem(self, problem, extra_tp_constructor_args, test_jax): + if test_jax: + pytest.skip("N/A for JAX") + else: + tp = oeq.TensorProduct(problem, **extra_tp_constructor_args) + switch_map = { + np.float32: torch.float64, + np.float64: torch.float32, + } + tp.to(switch_map[problem.irrep_dtype]) + return tp, tp.config From 58b7957ec1e5ce3cc0dfba51508e51f94c915fb9 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 7 Dec 2025 18:54:15 -0800 Subject: [PATCH 057/116] Ready to modify the double backward correctness function. --- .../benchmark/random_buffer_utils.py | 13 ++++--- .../impl_torch/NPDoubleBackwardMixin.py | 36 +++++++++++++++++++ .../impl_torch/TensorProduct.py | 7 ++-- 3 files changed, 49 insertions(+), 7 deletions(-) create mode 100644 openequivariance/openequivariance/impl_torch/NPDoubleBackwardMixin.py diff --git a/openequivariance/openequivariance/benchmark/random_buffer_utils.py b/openequivariance/openequivariance/benchmark/random_buffer_utils.py index 20e9ac72..bdd85750 100644 --- a/openequivariance/openequivariance/benchmark/random_buffer_utils.py +++ b/openequivariance/openequivariance/benchmark/random_buffer_utils.py @@ -104,10 +104,13 @@ def get_random_buffers_double_backward( ) weights = np.array(rng.uniform(size=weights_size), dtype=tpp.irrep_dtype) - weights_grad = np.zeros_like(weights) - in1_grad = np.zeros_like(in1) - in2_grad = np.zeros_like(in2) - out_double_grad = np.zeros_like(out_grad) + weights_grad = np.array(rng.uniform(size=weights_size), dtype=tpp.irrep_dtype) + in1_grad = np.array( + rng.uniform(size=(batch_size, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype) + in2_grad = np.array( + rng.uniform(size=(batch_size, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype) + out_double_grad = np.array( + rng.uniform(size=(batch_size, tpp.irreps_out.dim)), dtype=tpp.irrep_dtype) return ( in1, @@ -176,3 +179,5 @@ def get_random_buffers_backward_conv( in2_grad = np.zeros_like(in2) return in1, in2, out_grad, weights, weights_grad, in1_grad, in2_grad + + diff --git a/openequivariance/openequivariance/impl_torch/NPDoubleBackwardMixin.py b/openequivariance/openequivariance/impl_torch/NPDoubleBackwardMixin.py new file mode 100644 index 00000000..53a12dae --- /dev/null +++ b/openequivariance/openequivariance/impl_torch/NPDoubleBackwardMixin.py @@ -0,0 +1,36 @@ +import torch + +class NumpyDoubleBackwardMixin: + ''' + Adds a Numpy double backward method to any TensorProduct + with the forward pass defined in PyTorch and the relevant + derivatives registered. + ''' + def double_backward_cpu(self, in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad): + assert self.torch_op + + in1_torch = torch.tensor(in1).to('cuda').requires_grad_(True) + in2_torch = torch.tensor(in2).to('cuda').requires_grad_(True) + weights_torch = torch.tensor(weights).to('cuda').requires_grad_(True) + out_grad_torch = torch.tensor(out_grad).to('cuda').requires_grad_(True) + in1_dgrad_torch = torch.tensor(in1_dgrad).to('cuda') + in2_dgrad_torch = torch.tensor(in2_dgrad).to('cuda') + weights_dgrad_torch = torch.tensor(weights_dgrad).to('cuda') + out_torch = self.forward(in1_torch, in2_torch, weights_torch) + + in1_grad, in2_grad, weights_grad = torch.autograd.grad( + outputs=out_torch, + inputs=[in1_torch, in2_torch, weights_torch], + grad_outputs=out_grad_torch, + create_graph=True, + retain_graph=True + ) + + a, b, c, d = torch.autograd.grad( + outputs=[in1_grad, in2_grad, weights_grad], + inputs=[in1_torch, in2_torch, weights_torch, out_grad_torch], + grad_outputs=[in1_dgrad_torch, in2_dgrad_torch, weights_dgrad_torch] + ) + + return a.detach().cpu().numpy(), b.detach().cpu().numpy(), c.detach().cpu().numpy(), d.detach().cpu().numpy() + diff --git a/openequivariance/openequivariance/impl_torch/TensorProduct.py b/openequivariance/openequivariance/impl_torch/TensorProduct.py index c0f3a19b..173d984e 100644 --- a/openequivariance/openequivariance/impl_torch/TensorProduct.py +++ b/openequivariance/openequivariance/impl_torch/TensorProduct.py @@ -6,12 +6,13 @@ from openequivariance.core.utils import torch_to_oeq_dtype from openequivariance.benchmark.logging_utils import getLogger from openequivariance.impl_torch.utils import reorder_torch - +from openequivariance.impl_torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin +from openequivariance.core.e3nn_lite import Irreps logger = getLogger() -class TensorProduct(torch.nn.Module, LoopUnrollTP): +class TensorProduct(torch.nn.Module, LoopUnrollTP, NumpyDoubleBackwardMixin): r""" Drop-in replacement for ``o3.TensorProduct`` from e3nn. Supports forward, backward, and double-backward passes using JIT-compiled kernels. Initialization @@ -347,7 +348,7 @@ def name(): return "LoopUnrollTP" -if extlib.TORCH_COMPILE: +if extlib.TORCH_COMPILE and __name__ != "__main__": TensorProduct.register_torch_fakes() TensorProduct.register_autograd() TensorProduct.register_autocast() From 4dc31dc853e217ccefc8be67dad9a19ebb0b271b Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 7 Dec 2025 20:07:50 -0800 Subject: [PATCH 058/116] Correctness double backward works for existing code, need to extend to JAX. --- .../benchmark/correctness_utils.py | 60 ++++--------------- .../impl_torch/E3NNTensorProduct.py | 3 +- 2 files changed, 13 insertions(+), 50 deletions(-) diff --git a/openequivariance/openequivariance/benchmark/correctness_utils.py b/openequivariance/openequivariance/benchmark/correctness_utils.py index 5a3ad87c..daa99c92 100644 --- a/openequivariance/openequivariance/benchmark/correctness_utils.py +++ b/openequivariance/openequivariance/benchmark/correctness_utils.py @@ -6,7 +6,8 @@ from openequivariance.benchmark.random_buffer_utils import ( get_random_buffers_forward, get_random_buffers_backward, -) + get_random_buffers_double_backward) + from openequivariance.benchmark.logging_utils import getLogger, bcolors import numpy as np import numpy.linalg as la @@ -194,68 +195,29 @@ def correctness_double_backward( global torch import torch - in1, in2, out_grad, weights, _, _, _ = get_random_buffers_backward( - problem, batch_size, prng_seed - ) - rng = np.random.default_rng(seed=prng_seed * 2) - dummy_grad = rng.standard_normal(1)[0] + in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad, _ = \ + get_random_buffers_double_backward(problem, batch_size=batch_size, prng_seed=prng_seed) if reference_implementation is None: from openequivariance.impl_torch.E3NNTensorProduct import E3NNTensorProduct - reference_implementation = E3NNTensorProduct result = {"thresh": correctness_threshold, "batch_size": batch_size} tensors = [] - for i, impl in enumerate([test_implementation, reference_implementation]): + for _, impl in enumerate([test_implementation, reference_implementation]): tp = instantiate_implementation(impl, problem) if impl == CUETensorProduct and problem.shared_weights: weights = weights[np.newaxis, :] - weights_reordered = tp.reorder_weights_from_e3nn( - weights, not tp.config.shared_weights - ) - - in1_torch = torch.tensor(in1, device="cuda", requires_grad=True) - in2_torch = torch.tensor(in2, device="cuda", requires_grad=True) - weights_torch = torch.tensor( - weights_reordered, device="cuda", requires_grad=True - ) - - out_torch = tp.forward(in1_torch, in2_torch, weights_torch) - out_grad = out_torch.clone().detach().to(device="cuda").requires_grad_(True) - - in1_grad, in2_grad, w_grad = torch.autograd.grad( - outputs=[out_torch], - inputs=[in1_torch, in2_torch, weights_torch], - grad_outputs=[out_grad], - create_graph=True, - ) - - dummy = torch.norm(in1_grad) + torch.norm(in2_grad) + torch.norm(w_grad) - dummy_grad = torch.tensor(float(dummy_grad), device="cuda", requires_grad=True) - - dummy.backward( - dummy_grad, - retain_graph=True, - inputs=[out_grad, in1_torch, in2_torch, weights_torch], - ) - - weights_grad = weights_torch.grad.detach().cpu().numpy() - weights_grad = tp.reorder_weights_to_e3nn( - weights_grad, not tp.config.shared_weights - ) - + in1_grad, in2_grad, weights_grad, out_dgrad = tp.double_backward_cpu(in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad) tensors.append( - ( - out_grad.grad.detach().cpu().numpy(), - in1_torch.grad.detach().cpu().numpy(), - in2_torch.grad.detach().cpu().numpy(), - weights_grad, - ) - ) + ( out_dgrad, + in1_grad, + in2_grad, + weights_grad + )) for name, to_check, ground_truth in [ ("output_double_grad", tensors[0][0], tensors[1][0]), diff --git a/openequivariance/openequivariance/impl_torch/E3NNTensorProduct.py b/openequivariance/openequivariance/impl_torch/E3NNTensorProduct.py index 32196235..c0416d67 100644 --- a/openequivariance/openequivariance/impl_torch/E3NNTensorProduct.py +++ b/openequivariance/openequivariance/impl_torch/E3NNTensorProduct.py @@ -12,13 +12,14 @@ from openequivariance.core.TensorProductBase import TensorProductBase from openequivariance.core.e3nn_lite import TPProblem from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.impl_torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin TORCH_COMPILE_AUTOTUNING_DIR = pathlib.Path("triton_autotuning") logger = getLogger() -class E3NNTensorProduct(TensorProductBase): +class E3NNTensorProduct(TensorProductBase, NumpyDoubleBackwardMixin): def __init__(self, config: TPProblem, torch_op=True): super().__init__(config, torch_op=torch_op) assert self.torch_op From 79522a1230e4067c67317b68ac936ffa0441fde8 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 7 Dec 2025 20:32:41 -0800 Subject: [PATCH 059/116] Wrote double backward function for JAX. --- .../benchmark/correctness_utils.py | 7 ++++--- .../impl_jax/TensorProduct.py | 19 ++++++++++++++++++- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/openequivariance/openequivariance/benchmark/correctness_utils.py b/openequivariance/openequivariance/benchmark/correctness_utils.py index daa99c92..ce1bbf43 100644 --- a/openequivariance/openequivariance/benchmark/correctness_utils.py +++ b/openequivariance/openequivariance/benchmark/correctness_utils.py @@ -207,16 +207,17 @@ def correctness_double_backward( tensors = [] for _, impl in enumerate([test_implementation, reference_implementation]): tp = instantiate_implementation(impl, problem) + weights_reordered = tp.reorder_weights_from_e3nn(weights, has_batch_dim=not problem.shared_weights) if impl == CUETensorProduct and problem.shared_weights: - weights = weights[np.newaxis, :] + weights_reordered = weights_reordered[np.newaxis, :] - in1_grad, in2_grad, weights_grad, out_dgrad = tp.double_backward_cpu(in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad) + in1_grad, in2_grad, weights_grad, out_dgrad = tp.double_backward_cpu(in1, in2, out_grad, weights_reordered, weights_dgrad, in1_dgrad, in2_dgrad) tensors.append( ( out_dgrad, in1_grad, in2_grad, - weights_grad + tp.reorder_weights_to_e3nn(weights_grad, has_batch_dim=not problem.shared_weights) )) for name, to_check, ground_truth in [ diff --git a/openequivariance/openequivariance/impl_jax/TensorProduct.py b/openequivariance/openequivariance/impl_jax/TensorProduct.py index e3a03505..20f6a518 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProduct.py +++ b/openequivariance/openequivariance/impl_jax/TensorProduct.py @@ -114,4 +114,21 @@ def backward_cpu( L1_grad[:] = np.asarray(L1_grad_jax) L2_grad[:] = np.asarray(L2_grad_jax) weights_grad[:] = np.asarray(weights_grad_jax) - weights_grad[:] = self.reorder_weights_to_e3nn(weights_grad, has_batch_dim=not self.config.shared_weights) \ No newline at end of file + weights_grad[:] = self.reorder_weights_to_e3nn(weights_grad, has_batch_dim=not self.config.shared_weights) + + + def double_backward_cpu(self, in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad): + in1_jax = jax.numpy.asarray(in1) + in2_jax = jax.numpy.asarray(in2) + weights_jax = jax.numpy.asarray(weights) + out_grad_jax = jax.numpy.asarray(out_grad) + in1_dgrad_jax = jax.numpy.asarray(in1_dgrad) + in2_dgrad_jax = jax.numpy.asarray(in2_dgrad) + weights_dgrad_jax = jax.numpy.asarray(weights_dgrad) + + in1_grad, in2_grad, weights_grad, out_dgrad = jax.vjp( + lambda x, y, w: jax.vjp(lambda a, b, c: self.forward(a, b, c), x, y, w)[1](out_grad_jax), + in1_jax, in2_jax, weights_jax + )[1]((in1_dgrad_jax, in2_dgrad_jax, weights_dgrad_jax)) + + return in1_grad, in2_grad, weights_grad, out_dgrad \ No newline at end of file From 71ca862e4093e32e8d85f5699e5f2a76fdf51c51 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 7 Dec 2025 21:59:51 -0800 Subject: [PATCH 060/116] All double backward tests passing. --- .../openequivariance/benchmark/correctness_utils.py | 3 ++- openequivariance/openequivariance/impl_jax/TensorProduct.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/openequivariance/openequivariance/benchmark/correctness_utils.py b/openequivariance/openequivariance/benchmark/correctness_utils.py index ce1bbf43..01931c99 100644 --- a/openequivariance/openequivariance/benchmark/correctness_utils.py +++ b/openequivariance/openequivariance/benchmark/correctness_utils.py @@ -208,11 +208,12 @@ def correctness_double_backward( for _, impl in enumerate([test_implementation, reference_implementation]): tp = instantiate_implementation(impl, problem) weights_reordered = tp.reorder_weights_from_e3nn(weights, has_batch_dim=not problem.shared_weights) + weights_dgrad_reordered = tp.reorder_weights_from_e3nn(weights_dgrad, has_batch_dim=not problem.shared_weights) if impl == CUETensorProduct and problem.shared_weights: weights_reordered = weights_reordered[np.newaxis, :] - in1_grad, in2_grad, weights_grad, out_dgrad = tp.double_backward_cpu(in1, in2, out_grad, weights_reordered, weights_dgrad, in1_dgrad, in2_dgrad) + in1_grad, in2_grad, weights_grad, out_dgrad = tp.double_backward_cpu(in1, in2, out_grad, weights_reordered, weights_dgrad_reordered, in1_dgrad, in2_dgrad) tensors.append( ( out_dgrad, in1_grad, diff --git a/openequivariance/openequivariance/impl_jax/TensorProduct.py b/openequivariance/openequivariance/impl_jax/TensorProduct.py index 20f6a518..5d3df8e6 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProduct.py +++ b/openequivariance/openequivariance/impl_jax/TensorProduct.py @@ -127,8 +127,8 @@ def double_backward_cpu(self, in1, in2, out_grad, weights, weights_dgrad, in1_dg weights_dgrad_jax = jax.numpy.asarray(weights_dgrad) in1_grad, in2_grad, weights_grad, out_dgrad = jax.vjp( - lambda x, y, w: jax.vjp(lambda a, b, c: self.forward(a, b, c), x, y, w)[1](out_grad_jax), - in1_jax, in2_jax, weights_jax + lambda x, y, w, o: jax.vjp(lambda a, b, c: self.forward(a, b, c), x, y, w)[1](o), + in1_jax, in2_jax, weights_jax, out_grad_jax )[1]((in1_dgrad_jax, in2_dgrad_jax, weights_dgrad_jax)) return in1_grad, in2_grad, weights_grad, out_dgrad \ No newline at end of file From e140c072f51bfc1d07abcf550fdb6fcc60c1ae7e Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 7 Dec 2025 23:15:30 -0800 Subject: [PATCH 061/116] Added the mixins. --- .../openequivariance/impl_torch/E3NNConv.py | 4 +- .../impl_torch/NPDoubleBackwardMixin.py | 39 +++++++++++++++++++ .../impl_torch/TensorProduct.py | 3 +- .../impl_torch/TensorProductConv.py | 5 ++- 4 files changed, 45 insertions(+), 6 deletions(-) diff --git a/openequivariance/openequivariance/impl_torch/E3NNConv.py b/openequivariance/openequivariance/impl_torch/E3NNConv.py index 29137819..b4975ace 100644 --- a/openequivariance/openequivariance/impl_torch/E3NNConv.py +++ b/openequivariance/openequivariance/impl_torch/E3NNConv.py @@ -5,9 +5,9 @@ scatter_add_wrapper, ) from openequivariance.impl_torch.E3NNTensorProduct import E3NNTensorProduct +from openequivariance.impl_torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixinConv - -class E3NNConv(ConvolutionBase): +class E3NNConv(ConvolutionBase, NumpyDoubleBackwardMixinConv): def __init__(self, config, *, idx_dtype=np.int64, torch_op=True): assert torch_op super().__init__(config, idx_dtype=idx_dtype, torch_op=torch_op) diff --git a/openequivariance/openequivariance/impl_torch/NPDoubleBackwardMixin.py b/openequivariance/openequivariance/impl_torch/NPDoubleBackwardMixin.py index 53a12dae..7e623429 100644 --- a/openequivariance/openequivariance/impl_torch/NPDoubleBackwardMixin.py +++ b/openequivariance/openequivariance/impl_torch/NPDoubleBackwardMixin.py @@ -34,3 +34,42 @@ def double_backward_cpu(self, in1, in2, out_grad, weights, weights_dgrad, in1_dg return a.detach().cpu().numpy(), b.detach().cpu().numpy(), c.detach().cpu().numpy(), d.detach().cpu().numpy() + +class NumpyDoubleBackwardMixinConv: + ''' + Similar, but for fused graph convolution. + ''' + def double_backward_cpu(self, in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad, graph): + assert self.torch_op + + in1_torch = torch.tensor(in1).to('cuda').requires_grad_(True) + in2_torch = torch.tensor(in2).to('cuda').requires_grad_(True) + weights_torch = torch.tensor(weights).to('cuda').requires_grad_(True) + out_grad_torch = torch.tensor(out_grad).to('cuda').requires_grad_(True) + in1_dgrad_torch = torch.tensor(in1_dgrad).to('cuda') + in2_dgrad_torch = torch.tensor(in2_dgrad).to('cuda') + weights_dgrad_torch = torch.tensor(weights_dgrad).to('cuda') + + torch_rows = torch.tensor(graph.rows, device="cuda") + torch_cols = torch.tensor(graph.cols, device="cuda") + torch_transpose_perm = torch.tensor(graph.transpose_perm, device="cuda") + + out_torch = self.forward(in1_torch, in2_torch, weights_torch, torch_rows, torch_cols, torch_transpose_perm) + + in1_grad, in2_grad, weights_grad = torch.autograd.grad( + outputs=out_torch, + inputs=[in1_torch, in2_torch, weights_torch], + grad_outputs=out_grad_torch, + create_graph=True, + retain_graph=True + ) + + a, b, c, d = torch.autograd.grad( + outputs=[in1_grad, in2_grad, weights_grad], + inputs=[in1_torch, in2_torch, weights_torch, out_grad_torch], + grad_outputs=[in1_dgrad_torch, in2_dgrad_torch, weights_dgrad_torch] + ) + + return a.detach().cpu().numpy(), b.detach().cpu().numpy(), c.detach().cpu().numpy(), d.detach().cpu().numpy() + + diff --git a/openequivariance/openequivariance/impl_torch/TensorProduct.py b/openequivariance/openequivariance/impl_torch/TensorProduct.py index 173d984e..f7bb8ff6 100644 --- a/openequivariance/openequivariance/impl_torch/TensorProduct.py +++ b/openequivariance/openequivariance/impl_torch/TensorProduct.py @@ -7,7 +7,6 @@ from openequivariance.benchmark.logging_utils import getLogger from openequivariance.impl_torch.utils import reorder_torch from openequivariance.impl_torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin -from openequivariance.core.e3nn_lite import Irreps logger = getLogger() @@ -348,7 +347,7 @@ def name(): return "LoopUnrollTP" -if extlib.TORCH_COMPILE and __name__ != "__main__": +if extlib.TORCH_COMPILE: TensorProduct.register_torch_fakes() TensorProduct.register_autograd() TensorProduct.register_autocast() diff --git a/openequivariance/openequivariance/impl_torch/TensorProductConv.py b/openequivariance/openequivariance/impl_torch/TensorProductConv.py index 44e1775f..771c72e2 100644 --- a/openequivariance/openequivariance/impl_torch/TensorProductConv.py +++ b/openequivariance/openequivariance/impl_torch/TensorProductConv.py @@ -22,11 +22,12 @@ from openequivariance.impl_torch.utils import reorder_torch from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.impl_torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixinConv -logger = getLogger() +logger = getLogger() -class TensorProductConv(torch.nn.Module, LoopUnrollConv): +class TensorProductConv(torch.nn.Module, LoopUnrollConv, NumpyDoubleBackwardMixinConv): r""" Given a **symmetric, directed** graph :math:`G = (V, E)`, inputs :math:`x_1...x_{|V|}`, :math:`y_1...y_{|E|}`, and weights :math:`W_1...W_{|E|}`, computes From 8a0094a0c7e422fe81dc46e600fd5732f372457b Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 7 Dec 2025 23:40:36 -0800 Subject: [PATCH 062/116] Added double backward CPU function to jax TP conv. --- .../benchmark/random_buffer_utils.py | 41 +++++++++++ .../openequivariance/core/ConvolutionBase.py | 71 ++++--------------- .../impl_jax/TensorProductConv.py | 22 +++++- 3 files changed, 77 insertions(+), 57 deletions(-) diff --git a/openequivariance/openequivariance/benchmark/random_buffer_utils.py b/openequivariance/openequivariance/benchmark/random_buffer_utils.py index bdd85750..a403962d 100644 --- a/openequivariance/openequivariance/benchmark/random_buffer_utils.py +++ b/openequivariance/openequivariance/benchmark/random_buffer_utils.py @@ -181,3 +181,44 @@ def get_random_buffers_backward_conv( return in1, in2, out_grad, weights, weights_grad, in1_grad, in2_grad +def get_random_buffers_double_backward_conv( + tpp: TPProblem, node_count: int, edge_count: int, prng_seed: int +): + rng = np.random.default_rng(prng_seed) + in1 = np.array( + rng.uniform(size=(node_count, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype + ) + in2 = np.array( + rng.uniform(size=(edge_count, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype + ) + out_grad = np.array( + rng.uniform(size=(node_count, tpp.irreps_out.dim)), dtype=tpp.irrep_dtype + ) + + weights_size = ( + tuple([tpp.weight_numel]) + if tpp.shared_weights + else tuple([edge_count, tpp.weight_numel]) + ) + + weights = np.array(rng.uniform(size=weights_size), dtype=tpp.irrep_dtype) + weights_grad = np.array(rng.uniform(size=weights_size), dtype=tpp.irrep_dtype) + in1_grad = np.array( + rng.uniform(size=(node_count, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype + ) + in2_grad = np.array( + rng.uniform(size=(edge_count, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype + ) + out_double_grad = np.array( + rng.uniform(size=(node_count, tpp.irreps_out.dim)), dtype=tpp.irrep_dtype + ) + return ( + in1, + in2, + out_grad, + weights, + weights_grad, + in1_grad, + in2_grad, + out_double_grad, + ) \ No newline at end of file diff --git a/openequivariance/openequivariance/core/ConvolutionBase.py b/openequivariance/openequivariance/core/ConvolutionBase.py index f9b33eeb..1301482c 100644 --- a/openequivariance/openequivariance/core/ConvolutionBase.py +++ b/openequivariance/openequivariance/core/ConvolutionBase.py @@ -3,6 +3,7 @@ from openequivariance.benchmark.random_buffer_utils import ( get_random_buffers_forward_conv, get_random_buffers_backward_conv, + get_random_buffers_double_backward_conv, ) from openequivariance.benchmark.logging_utils import getLogger, bcolors @@ -13,7 +14,6 @@ logger = getLogger() - def flops_data_per_tp(config, direction): """ Assumes all interactions are "uvu" for now @@ -549,20 +549,12 @@ def test_correctness_double_backward( reference_implementation=None, high_precision_ref=False, ): - global torch - import torch - - assert self.torch_op - buffers = get_random_buffers_backward_conv( - self.config, graph.node_count, graph.nnz, prng_seed - ) - - rng = np.random.default_rng(seed=prng_seed * 2) - dummy_grad_value = rng.standard_normal(1)[0] + buffers = get_random_buffers_double_backward_conv( + self.config, graph.node_count, graph.nnz, prng_seed + ) if reference_implementation is None: from openequivariance.impl_torch.E3NNConv import E3NNConv - reference_implementation = E3NNConv reference_problem = self.config @@ -576,63 +568,30 @@ def test_correctness_double_backward( result = {"thresh": thresh} tensors = [] for i, tp in enumerate([self, reference_tp]): - in1, in2, out_grad, weights, _, _, _ = [buf.copy() for buf in buffers] + buffers_copy = [buf.copy() for buf in buffers] if i == 1 and high_precision_ref: - in1, in2, out_grad, weights, _, _, _ = [ + buffers_copy = [ np.array(el, dtype=np.float64) for el in buffers ] - in1_torch = torch.tensor(in1, device="cuda", requires_grad=True) - in2_torch = torch.tensor(in2, device="cuda", requires_grad=True) + in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad, _ = buffers_copy weights_reordered = tp.reorder_weights_from_e3nn( weights, not self.config.shared_weights ) - - weights_torch = torch.tensor( - weights_reordered, device="cuda", requires_grad=True + weights_dgrad_reordered = tp.reorder_weights_from_e3nn( + weights_dgrad, not self.config.shared_weights ) - torch_rows = torch.tensor(graph.rows, device="cuda") - torch_cols = torch.tensor(graph.cols, device="cuda") - torch_transpose_perm = torch.tensor(graph.transpose_perm, device="cuda") - - fwd_args = [in1_torch, in2_torch, weights_torch, torch_rows, torch_cols] - if tp.deterministic: - fwd_args.append(torch_transpose_perm) - - out_torch = tp.forward(*fwd_args) - out_grad_torch = torch.tensor(out_grad, device="cuda", requires_grad=True) - - in1_grad, in2_grad, w_grad = torch.autograd.grad( - outputs=[out_torch], - inputs=[in1_torch, in2_torch, weights_torch], - grad_outputs=[out_grad_torch], - create_graph=True, - ) - - dummy = torch.norm(in1_grad) + torch.norm(in2_grad) + torch.norm(w_grad) - dummy_grad = torch.tensor( - float(dummy_grad_value), device="cuda", requires_grad=True - ) - dummy.backward( - dummy_grad, inputs=[out_grad_torch, in1_torch, in2_torch, weights_torch] - ) - - weights_grad = weights_torch.grad.detach().cpu().numpy() - weights_grad = tp.reorder_weights_to_e3nn( - weights_grad, not self.config.shared_weights - ) + in1_grad, in2_grad, weights_grad, out_dgrad = self.double_backward_cpu(in1, in2, out_grad, weights_reordered, weights_dgrad_reordered, in1_dgrad, in2_dgrad, graph) tensors.append( - ( - out_grad_torch.grad.detach().cpu().numpy().copy(), - in1_torch.grad.detach().cpu().numpy().copy(), - in2_torch.grad.detach().cpu().numpy().copy(), - weights_grad.copy(), - ) - ) + ( out_dgrad, + in1_grad, + in2_grad, + self.reorder_weights_to_e3nn(weights_grad, has_batch_dim=not self.config.shared_weights) + )) for name, to_check, ground_truth in [ ("output_grad", tensors[0][0], tensors[1][0]), diff --git a/openequivariance/openequivariance/impl_jax/TensorProductConv.py b/openequivariance/openequivariance/impl_jax/TensorProductConv.py index c1640548..53568739 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProductConv.py +++ b/openequivariance/openequivariance/impl_jax/TensorProductConv.py @@ -220,4 +220,24 @@ def backward_cpu( L1_grad[:] = np.asarray(L1_grad_jax) L2_grad[:] = np.asarray(L2_grad_jax) weights_grad[:] = np.asarray(weights_grad_jax) - weights_grad[:] = self.reorder_weights_to_e3nn(weights_grad, has_batch_dim=not self.config.shared_weights) \ No newline at end of file + weights_grad[:] = self.reorder_weights_to_e3nn(weights_grad, has_batch_dim=not self.config.shared_weights) + + def double_backward_cpu(self, in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad, graph): + in1_jax = jax.numpy.asarray(in1) + in2_jax = jax.numpy.asarray(in2) + weights_jax = jax.numpy.asarray(weights) + out_grad_jax = jax.numpy.asarray(out_grad) + in1_dgrad_jax = jax.numpy.asarray(in1_dgrad) + in2_dgrad_jax = jax.numpy.asarray(in2_dgrad) + weights_dgrad_jax = jax.numpy.asarray(weights_dgrad) + + rows_jax = jax.numpy.asarray(graph.rows.astype(self.idx_dtype)) + cols_jax = jax.numpy.asarray(graph.cols.astype(self.idx_dtype)) + sender_perm_jax = jax.numpy.asarray(graph.transpose_perm.astype(self.idx_dtype)) + + in1_grad, in2_grad, weights_grad, out_dgrad = jax.vjp( + lambda x, y, w, o: jax.vjp(lambda a, b, c: self.forward(a, b, c, rows_jax, cols_jax, sender_perm_jax), x, y, w)[1](o), + in1_jax, in2_jax, weights_jax, out_grad_jax + )[1]((in1_dgrad_jax, in2_dgrad_jax, weights_dgrad_jax)) + + return in1_grad, in2_grad, weights_grad, out_dgrad From 61e0566d92138b5404ba5046a987b36a3374901e Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 8 Dec 2025 00:24:30 -0800 Subject: [PATCH 063/116] Almost there, need to get TensorProductConv working. --- openequivariance/openequivariance/core/ConvolutionBase.py | 8 ++++---- .../openequivariance/impl_jax/TensorProductConv.py | 6 +++--- openequivariance/openequivariance/impl_torch/E3NNConv.py | 2 +- openequivariance/pyproject.toml | 3 +-- 4 files changed, 9 insertions(+), 10 deletions(-) diff --git a/openequivariance/openequivariance/core/ConvolutionBase.py b/openequivariance/openequivariance/core/ConvolutionBase.py index 1301482c..f81b0f7d 100644 --- a/openequivariance/openequivariance/core/ConvolutionBase.py +++ b/openequivariance/openequivariance/core/ConvolutionBase.py @@ -578,19 +578,19 @@ def test_correctness_double_backward( in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad, _ = buffers_copy weights_reordered = tp.reorder_weights_from_e3nn( - weights, not self.config.shared_weights + weights, not tp.config.shared_weights ) weights_dgrad_reordered = tp.reorder_weights_from_e3nn( - weights_dgrad, not self.config.shared_weights + weights_dgrad, not tp.config.shared_weights ) - in1_grad, in2_grad, weights_grad, out_dgrad = self.double_backward_cpu(in1, in2, out_grad, weights_reordered, weights_dgrad_reordered, in1_dgrad, in2_dgrad, graph) + in1_grad, in2_grad, weights_grad, out_dgrad = tp.double_backward_cpu(in1, in2, out_grad, weights_reordered, weights_dgrad_reordered, in1_dgrad, in2_dgrad, graph) tensors.append( ( out_dgrad, in1_grad, in2_grad, - self.reorder_weights_to_e3nn(weights_grad, has_batch_dim=not self.config.shared_weights) + tp.reorder_weights_to_e3nn(weights_grad, has_batch_dim=not self.config.shared_weights) )) for name, to_check, ground_truth in [ diff --git a/openequivariance/openequivariance/impl_jax/TensorProductConv.py b/openequivariance/openequivariance/impl_jax/TensorProductConv.py index 53568739..fd200a15 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProductConv.py +++ b/openequivariance/openequivariance/impl_jax/TensorProductConv.py @@ -165,10 +165,10 @@ def __call__( return self.forward(X, Y, W, rows, cols, sender_perm) def reorder_weights_from_e3nn(self, weights, has_batch_dim=True): - return reorder_jax(self.forward_schedule, weights, "forward", not self.config.shared_weights) + return reorder_jax(self.forward_schedule, weights, "forward", has_batch_dim) def reorder_weights_to_e3nn(self, weights, has_batch_dim=True): - return reorder_jax(self.forward_schedule, weights, "backward", not self.config.shared_weights) + return reorder_jax(self.forward_schedule, weights, "backward", has_batch_dim) def forward_cpu(self, L1_in, L2_in, weights, L3_out, graph): rows = graph.rows.astype(np.int32) @@ -240,4 +240,4 @@ def double_backward_cpu(self, in1, in2, out_grad, weights, weights_dgrad, in1_dg in1_jax, in2_jax, weights_jax, out_grad_jax )[1]((in1_dgrad_jax, in2_dgrad_jax, weights_dgrad_jax)) - return in1_grad, in2_grad, weights_grad, out_dgrad + return np.asarray(in1_grad), np.asarray(in2_grad), np.asarray(weights_grad), np.asarray(out_dgrad) diff --git a/openequivariance/openequivariance/impl_torch/E3NNConv.py b/openequivariance/openequivariance/impl_torch/E3NNConv.py index b4975ace..18b1329b 100644 --- a/openequivariance/openequivariance/impl_torch/E3NNConv.py +++ b/openequivariance/openequivariance/impl_torch/E3NNConv.py @@ -37,7 +37,7 @@ def __init__(self, config, *, idx_dtype=np.int64, torch_op=True): if config.irrep_dtype == np.float64: torch.set_default_dtype(torch.float32) # Reset to default - def forward(self, L1_in, L2_in, weights, rows, cols): + def forward(self, L1_in, L2_in, weights, rows, cols, transpose_perm=None): messages = self.reference_tp(L1_in[cols], L2_in, weights) return scatter_add_wrapper(messages, rows, L1_in.size(0)) diff --git a/openequivariance/pyproject.toml b/openequivariance/pyproject.toml index db392c7c..aaee76da 100644 --- a/openequivariance/pyproject.toml +++ b/openequivariance/pyproject.toml @@ -48,8 +48,7 @@ bench = [ "cuequivariance-ops-torch-cu12", ] -jax = [ - "jax[cuda12]", +jax = [ "nanobind", "scikit-build-core" ] From ab83aef4e11389538e3047bc2bbbb5375822b358 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Wed, 24 Dec 2025 17:18:11 -0800 Subject: [PATCH 064/116] Double backward tests are passing. --- .github/workflows/verify_extension_build.yml | 10 ++++- README.md | 10 +++++ openequivariance/README.md | 4 +- openequivariance_extjax/README.md | 3 ++ openequivariance_extjax/pyproject.toml | 2 +- openequivariance_extjax/src/libjax_tp_jit.cpp | 40 ++++++++++++------- tests/conv_test.py | 2 +- 7 files changed, 50 insertions(+), 21 deletions(-) create mode 100644 openequivariance_extjax/README.md diff --git a/.github/workflows/verify_extension_build.yml b/.github/workflows/verify_extension_build.yml index db48af7b..296db3f1 100644 --- a/.github/workflows/verify_extension_build.yml +++ b/.github/workflows/verify_extension_build.yml @@ -29,8 +29,14 @@ jobs: sudo apt-get update sudo apt install nvidia-cuda-toolkit pip install -r .github/workflows/requirements_cuda_ci.txt - pip install -e . + pip install -e ./openequivariance - name: Test extension build via import run: | - pytest tests/import_test.py -k test_import \ No newline at end of file + pytest tests/import_test.py -k test_import + + - name: Install dependencies to test JAX extension build + run: | + pip install "jax[cuda12]" + pip install -e ./openequivariance[jax] + pip install -e ./openequivariance_extjax --no-build-isolation \ No newline at end of file diff --git a/README.md b/README.md index 288e1daf..25f6b83b 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,16 @@ computation and memory consumption significantly. For detailed instructions on tests, benchmarks, MACE / Nequip, and our API, check out the [documentation](https://passionlab.github.io/OpenEquivariance). +⭐️ **JAX Support**: Our latest update brings +support for JAX. You need to execute the following +commands in order: + +``` +pip install openequivariance[jax] +pip install openequivariance_extjax --no-build-isolation +``` + + 📣 📣 OpenEquivariance was accepted to the 2025 SIAM Conference on Applied and Computational Discrete Algorithms (Proceedings Track)! Catch the talk in Montréal and check out the [camera-ready copy on Arxiv](https://arxiv.org/abs/2501.13986) (available May 12, 2025). diff --git a/openequivariance/README.md b/openequivariance/README.md index 45a0ae38..976ab6c6 100644 --- a/openequivariance/README.md +++ b/openequivariance/README.md @@ -1,6 +1,6 @@ # OpenEquivariance -This package contains the core implementation of OpenEquivariance, which is fully -sufficient to run the package from PyTorch. For JAX support, see instructions +The core implementation of OpenEquivariance with +PyTorch support. For JAX, see instructions on installing `openequivariance_extjax` along with this package. diff --git a/openequivariance_extjax/README.md b/openequivariance_extjax/README.md new file mode 100644 index 00000000..64958567 --- /dev/null +++ b/openequivariance_extjax/README.md @@ -0,0 +1,3 @@ +# OpenEquivariance JAX Extension + +The JAX extension module for OpenEquivariance. diff --git a/openequivariance_extjax/pyproject.toml b/openequivariance_extjax/pyproject.toml index 17165af7..cc6a4738 100644 --- a/openequivariance_extjax/pyproject.toml +++ b/openequivariance_extjax/pyproject.toml @@ -19,7 +19,7 @@ description = "JAX C++ Extension for OpenEquivariance" requires-python = ">=3.10" dependencies = [] -readme = "../README.md" +readme = "README.md" #license = "BSD-3-Clause" #license-files = ["../LICENSE"] diff --git a/openequivariance_extjax/src/libjax_tp_jit.cpp b/openequivariance_extjax/src/libjax_tp_jit.cpp index 5c484ade..257a737b 100644 --- a/openequivariance_extjax/src/libjax_tp_jit.cpp +++ b/openequivariance_extjax/src/libjax_tp_jit.cpp @@ -70,11 +70,12 @@ inline int byte_count(ffi::AnyBuffer &buffer) { } #ifdef CUDA_BACKEND -void zero_buffer(ffi::AnyBuffer &buffer) { - cudaMemset( +void zero_buffer(ffi::AnyBuffer &buffer, cudaStream_t stream) { + cudaMemsetAsync( data_ptr(buffer), 0, - buffer.element_count() * byte_count(buffer)); + buffer.element_count() * byte_count(buffer), + stream); } #endif @@ -303,7 +304,7 @@ ffi::Error tp_backward_impl( } if (k.shared_weights) { - zero_buffer(*W_grad); + zero_buffer(*W_grad, stream); } jit_kernel->backward( @@ -354,7 +355,7 @@ ffi::Error tp_double_backward_impl( } if (k.shared_weights) { - zero_buffer(*W_grad); + zero_buffer(*W_grad, stream); } jit_kernel->double_backward( @@ -438,6 +439,7 @@ ffi::Error conv_forward_impl( kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, true); const int64_t nnz = rows.dimensions()[0]; const int64_t node_count = L1_in.dimensions()[0]; + void* workspace_ptr = data_ptr(workspace); check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, "L1_in"); check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in"); @@ -449,8 +451,9 @@ ffi::Error conv_forward_impl( check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm"); } else { - zero_buffer(*L3_out); + workspace_ptr = nullptr; } + zero_buffer(*L3_out, stream); if (k.shared_weights) check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); @@ -465,7 +468,7 @@ ffi::Error conv_forward_impl( data_ptr(rows), data_ptr(cols), nnz, node_count, - data_ptr(workspace), + workspace_ptr, stream); return ffi::Error::Success(); @@ -491,6 +494,8 @@ ffi::Error conv_backward_impl( kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, true); const int64_t nnz = rows.dimensions()[0]; const int64_t node_count = L1_in.dimensions()[0]; + void* workspace_ptr = data_ptr(workspace); + check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, "L1_in"); check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in"); check_tensor(L3_grad, {node_count, k.L3_dim}, k.irrep_dtype, "L3_grad"); @@ -502,8 +507,9 @@ ffi::Error conv_backward_impl( check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm"); } else { - zero_buffer(*L1_grad); - } + workspace_ptr = nullptr; + } + zero_buffer(*L1_grad, stream); if (k.shared_weights) { check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); @@ -514,7 +520,7 @@ ffi::Error conv_backward_impl( check_tensor(*W_grad, {nnz, k.weight_numel}, k.weight_dtype, "W_grad"); } if(k.shared_weights) - zero_buffer(*W_grad); + zero_buffer(*W_grad, stream); jit_kernel->backward( data_ptr(L1_in), @@ -527,7 +533,7 @@ ffi::Error conv_backward_impl( data_ptr(rows), data_ptr(cols), nnz, node_count, - data_ptr(workspace), + workspace_ptr, data_ptr(transpose_perm), stream); return ffi::Error::Success(); @@ -557,6 +563,8 @@ ffi::Error conv_double_backward_impl( kernel, forward_config, backward_config, double_backward_config, kernel_prop, hash, true); const int64_t nnz = rows.dimensions()[0]; const int64_t node_count = L1_in.dimensions()[0]; + void* workspace_ptr = data_ptr(workspace); + check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, "L1_in"); check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in"); check_tensor(L3_grad, {node_count, k.L3_dim}, k.irrep_dtype, "L3_grad"); @@ -570,9 +578,11 @@ ffi::Error conv_double_backward_impl( check_tensor(transpose_perm, {nnz}, k.idx_dtype, "transpose perm"); } else { - zero_buffer(*L1_grad); - zero_buffer(*L3_dgrad); + workspace_ptr = nullptr; } + zero_buffer(*L1_grad, stream); + zero_buffer(*L3_dgrad, stream); + if (k.shared_weights) { check_tensor(W, {k.weight_numel}, k.weight_dtype, "W"); @@ -582,7 +592,7 @@ ffi::Error conv_double_backward_impl( check_tensor(W_dgrad, {nnz, k.weight_numel}, k.weight_dtype, "W_dgrad"); } if(k.shared_weights) - zero_buffer(*W_grad); + zero_buffer(*W_grad, stream); jit_kernel->double_backward( data_ptr(L1_in), @@ -599,7 +609,7 @@ ffi::Error conv_double_backward_impl( data_ptr(rows), data_ptr(cols), nnz, node_count, - data_ptr(workspace), + workspace_ptr, data_ptr(transpose_perm), stream); return ffi::Error::Success(); diff --git a/tests/conv_test.py b/tests/conv_test.py index a325c8d2..dd2cc9b9 100644 --- a/tests/conv_test.py +++ b/tests/conv_test.py @@ -238,7 +238,7 @@ def thresh(self, direction): return { "fwd": 1e-5, "bwd": 7.5e-2, # Expect higher errors for shared weights - "double_bwd": 5e-2, + "double_bwd": 5e-1, }[direction] @pytest.fixture(params=problems, ids=lambda x: x.label, scope="class") From 50e0fcc71f9c8d5796a903da0b35129c3f9ac61b Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Wed, 24 Dec 2025 17:46:14 -0800 Subject: [PATCH 065/116] Updated documentation. --- README.md | 3 +- .../impl_jax/TensorProduct.py | 6 ++++ .../impl_jax/TensorProductConv.py | 34 ++++++++++++++++++- openequivariance_extjax/README.md | 2 +- 4 files changed, 41 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 25f6b83b..27610c7e 100644 --- a/README.md +++ b/README.md @@ -30,8 +30,7 @@ For detailed instructions on tests, benchmarks, MACE / Nequip, and our API, check out the [documentation](https://passionlab.github.io/OpenEquivariance). ⭐️ **JAX Support**: Our latest update brings -support for JAX. You need to execute the following -commands in order: +support for JAX. To install, execute the following commands in order: ``` pip install openequivariance[jax] diff --git a/openequivariance/openequivariance/impl_jax/TensorProduct.py b/openequivariance/openequivariance/impl_jax/TensorProduct.py index 5d3df8e6..7aa4c7db 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProduct.py +++ b/openequivariance/openequivariance/impl_jax/TensorProduct.py @@ -59,6 +59,12 @@ def backward_autograd(L3_dim, irrep_dtype, attrs, inputs, dZ): class TensorProduct(LoopUnrollTP): + r""" + Identical to ``oeq.torch.TensorProduct`` with functionality in JAX. + + :param problem: Specification of the tensor product. + """ + def __init__(self, config: TPProblem): dp = extlib.DeviceProp(0) super().__init__(config, dp, extlib.postprocess_kernel, torch_op=False) diff --git a/openequivariance/openequivariance/impl_jax/TensorProductConv.py b/openequivariance/openequivariance/impl_jax/TensorProductConv.py index fd200a15..f6c1b66f 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProductConv.py +++ b/openequivariance/openequivariance/impl_jax/TensorProductConv.py @@ -92,6 +92,18 @@ def backward_autograd( class TensorProductConv(LoopUnrollConv): + r""" + Identical to ``oeq.torch.TensorProductConv`` with functionality in JAX, with one + key difference: integer arrays passed to this function must have dtype + ``np.int32`` (as opposed to ``np.int64`` in the PyTorch version). + + :param problem: Specification of the tensor product. + :param deterministic: if ``False``, uses atomics for the convolution. If ``True``, uses a deterministic + fixup-based algorithm. `Default`: ``False``. + :param kahan: If ``True``, uses Kahan summation to improve accuracy during aggregation. To use this option, + the input tensors must be in float32 precision AND you must set ``deterministic=True``. *Default*: ``False``. + """ + def __init__( self, config: TPProblem, deterministic: bool = False, kahan: bool = False ): @@ -132,7 +144,27 @@ def forward( rows: jax.numpy.ndarray, cols: jax.numpy.ndarray, sender_perm: Optional[jax.numpy.ndarray] = None, - ) -> jax.numpy.ndarray: + ) -> jax.numpy.ndarray: + r""" + Computes the fused CG tensor product + convolution. + + :param X: Tensor of shape ``[|V|, problem.irreps_in1.dim()]``, datatype ``problem.irrep_dtype``. + :param Y: Tensor of shape ``[|E|, problem.irreps_in1.dim()]``, datatype ``problem.irrep_dtype``. + :param W: Tensor of datatype ``problem.weight_dtype`` and shape + + * ``[|E|, problem.weight_numel]`` if ``problem.shared_weights=False`` + * ``[problem.weight_numel]`` if ``problem.shared_weights=True`` + + :param rows: Tensor of shape ``[|E|]`` with row indices for each nonzero in the adjacency matrix, + datatype ``np.int32``. Must be row-major sorted along with ``cols`` when ``deterministic=True``. + :param cols: Tensor of shape ``[|E|]`` with column indices for each nonzero in the adjacency matrix, + datatype ``np.int32``. + :param sender_perm: Tensor of shape ``[|E|]`` and ``np.int32`` datatype containing a + permutation that transposes the adjacency matrix nonzeros from row-major to column-major order. + Must be provided when ``deterministic=True``. + + :return: Tensor of shape ``[|V|, problem.irreps_out.dim()]``, datatype ``problem.irrep_dtype``. + """ if not self.deterministic: sender_perm = self.dummy_transpose_perm else: diff --git a/openequivariance_extjax/README.md b/openequivariance_extjax/README.md index 64958567..ad7455ef 100644 --- a/openequivariance_extjax/README.md +++ b/openequivariance_extjax/README.md @@ -1,3 +1,3 @@ # OpenEquivariance JAX Extension -The JAX extension module for OpenEquivariance. +The JAX extension module for OpenEquivariance. \ No newline at end of file From 44af17b3fd555e69c0903c42d13857e965358bd0 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Wed, 24 Dec 2025 18:49:26 -0800 Subject: [PATCH 066/116] Modified documentation. --- docs/api.rst | 21 +++++++- docs/conf.py | 5 +- openequivariance/openequivariance/__init__.py | 51 +++++++++++-------- .../core/TensorProductBase.py | 6 +-- .../impl_jax/TensorProduct.py | 6 +-- 5 files changed, 60 insertions(+), 29 deletions(-) diff --git a/docs/api.rst b/docs/api.rst index b8345747..c03204f7 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -31,9 +31,26 @@ trying our code. OpenEquivariance cannot accelerate all tensor products; see :members: :undoc-members: -.. autofunction:: openequivariance.impl_torch_to_oeq_dtype +.. autofunction:: openequivariance.torch_to_oeq_dtype + +.. autofunction:: openequivariance.torch_ext_so_path + +OpenEquivariance JAX API +------------------------ +The JAX API consists of ``TensorProduct`` and ``TensorProductConv`` +classes that behave identically to their PyTorch counterparts. These classes +do not conform exactly to the e3nn-jax API, but perform the same computation. + +.. autoclass:: openequivariance.jax.TensorProduct + :members: forward, reorder_weights_from_e3nn, reorder_weights_to_e3nn + :undoc-members: + :exclude-members: + +.. autoclass:: openequivariance.jax.TensorProductConv + :members: forward, reorder_weights_from_e3nn, reorder_weights_to_e3nn + :undoc-members: + :exclude-members: -.. autofunction:: openequivariance.impl_torch_ext_so_path API Identical to e3nn --------------------- diff --git a/docs/conf.py b/docs/conf.py index 57a3f2c4..d87eb496 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -32,11 +32,14 @@ "sphinx.ext.autodoc", ] -sys.path.insert(0, str(Path("..").resolve())) +sys.path.insert(0, str(Path("../openequivariance").resolve())) autodoc_mock_imports = [ "torch", + "jax", "openequivariance.impl_torch.extlib", + "openequivariance.impl_jax.extlib", + "openequivariance_extjax", "jinja2", "numpy", ] diff --git a/openequivariance/openequivariance/__init__.py b/openequivariance/openequivariance/__init__.py index f7422572..13522cf0 100644 --- a/openequivariance/openequivariance/__init__.py +++ b/openequivariance/openequivariance/__init__.py @@ -1,6 +1,6 @@ # ruff: noqa: F401 import sys -import torch +import os import numpy as np from pathlib import Path @@ -13,12 +13,6 @@ _MulIr, Instruction, ) -from openequivariance.impl_torch.TensorProduct import TensorProduct -from openequivariance.impl_torch.TensorProductConv import ( - TensorProductConv, -) -from openequivariance.impl_torch.extlib import torch_ext_so_path -from openequivariance.core.utils import torch_to_oeq_dtype __version__ = None try: @@ -44,20 +38,36 @@ def extension_source_path(): """ return str(Path(__file__).parent / "extension") +TensorProduct, TensorProductConv, torch_ext_so_path, torch_to_oeq_dtype = None, None, None, None -torch.serialization.add_safe_globals( - [ - TensorProduct, - TensorProductConv, - TPProblem, - Irrep, - Irreps, - _MulIr, - Instruction, - np.float32, - np.float64, - ] -) +if "OEQ_NOTORCH" not in os.environ or os.environ["OEQ_NOTORCH"] != "1": + import torch + from openequivariance.impl_torch.TensorProduct import TensorProduct + from openequivariance.impl_torch.TensorProductConv import TensorProductConv + + from openequivariance.impl_torch.extlib import torch_ext_so_path + from openequivariance.core.utils import torch_to_oeq_dtype + + torch.serialization.add_safe_globals( + [ + TensorProduct, + TensorProductConv, + TPProblem, + Irrep, + Irreps, + _MulIr, + Instruction, + np.float32, + np.float64, + ] + ) + +jax = None +try: + import openequivariance_extjax + import openequivariance.impl_jax as jax +except ImportError: + pass __all__ = [ "TPProblem", @@ -67,4 +77,5 @@ def extension_source_path(): "torch_to_oeq_dtype", "_check_package_editable", "torch_ext_so_path", + "jax" ] diff --git a/openequivariance/openequivariance/core/TensorProductBase.py b/openequivariance/openequivariance/core/TensorProductBase.py index f00dcc7c..f360cce4 100644 --- a/openequivariance/openequivariance/core/TensorProductBase.py +++ b/openequivariance/openequivariance/core/TensorProductBase.py @@ -44,7 +44,7 @@ def reorder_weights_from_e3nn(self, weights, has_batch_dim: bool = True): Reorders weights from ``e3nn`` canonical order to the order used by ``oeq``. :param weights: Weights in ``e3nn`` canonical order, either an - np.ndarray or a torch.Tensor. Tensor of dimensions ``[B, problem.weight_numel]`` + np.ndarray, torch.Tensor or JAX array. Tensor of dimensions ``[B, problem.weight_numel]`` when ``has_batch_dim=True``, otherwise of dimensions ``[problem.weight_numel]``. :param has_batch_dim: If ``True``, treats the first dimension of weights as a batch dimension. Default: ``True``. @@ -57,8 +57,8 @@ def reorder_weights_to_e3nn(self, weights, has_batch_dim: bool = True): r""" Reorders weights from ``oeq`` canonical order to the order used by ``e3nn``. - :param weights: Weights in ``oeq`` canonical order, either an - np.ndarray or a torch.Tensor. Tensor of dimensions ``[B, problem.weight_numel]`` + :param weights: Weights in ``oeq`` canonical order, either a + np.ndarray, torch.Tensor or JAX array. Tensor of dimensions ``[B, problem.weight_numel]`` when ``has_batch_dim=True``, otherwise of dimensions ``[problem.weight_numel]``. :param has_batch_dim: If ``True``, treats the first dimension of wieghts as a batch dimension. Default: ``True``. diff --git a/openequivariance/openequivariance/impl_jax/TensorProduct.py b/openequivariance/openequivariance/impl_jax/TensorProduct.py index 7aa4c7db..dc9d8bdc 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProduct.py +++ b/openequivariance/openequivariance/impl_jax/TensorProduct.py @@ -65,9 +65,9 @@ class TensorProduct(LoopUnrollTP): :param problem: Specification of the tensor product. """ - def __init__(self, config: TPProblem): + def __init__(self, problem: TPProblem): dp = extlib.DeviceProp(0) - super().__init__(config, dp, extlib.postprocess_kernel, torch_op=False) + super().__init__(problem, dp, extlib.postprocess_kernel, torch_op=False) self.attrs = { "kernel": self.jit_kernel, @@ -78,7 +78,7 @@ def __init__(self, config: TPProblem): } hash_attributes(self.attrs) - self.weight_numel = config.weight_numel + self.weight_numel = problem.weight_numel self.L3_dim = self.config.irreps_out.dim def forward(self, X: jax.numpy.ndarray, Y: jax.numpy.ndarray, W: jax.numpy.ndarray) -> jax.numpy.ndarray: From 8caa93ea256f89eae52547aa7fc18e8c5e7227c9 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Wed, 24 Dec 2025 19:16:41 -0800 Subject: [PATCH 067/116] Updated documentation. --- docs/api.rst | 25 ++++++++---- docs/installation.rst | 39 +++++++++++++++---- openequivariance/openequivariance/__init__.py | 15 +++++-- .../impl_torch/extlib/__init__.py | 4 -- 4 files changed, 61 insertions(+), 22 deletions(-) diff --git a/docs/api.rst b/docs/api.rst index c03204f7..c21b918f 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -8,7 +8,7 @@ OpenEquivariance API OpenEquivariance exposes two key classes: :py:class:`openequivariance.TensorProduct`, which replaces ``o3.TensorProduct`` from e3nn, and :py:class:`openequivariance.TensorProductConv`, which fuses the CG tensor product with a subsequent graph convolution. Initializing either class triggers -JIT compilation of a custom kernel, which can take a few seconds. +JIT compilation of a custom kernel, which can take a few seconds. Both classes require a configuration object specified by :py:class:`openequivariance.TPProblem`, which has a constructor @@ -17,6 +17,9 @@ We recommend reading the `e3nn documentation ` trying our code. OpenEquivariance cannot accelerate all tensor products; see :doc:`this page ` for a list of supported configurations. +PyTorch API +------------------------ + .. autoclass:: openequivariance.TensorProduct :members: forward, reorder_weights_from_e3nn, reorder_weights_to_e3nn, to :undoc-members: @@ -27,19 +30,21 @@ trying our code. OpenEquivariance cannot accelerate all tensor products; see :undoc-members: :exclude-members: name -.. autoclass:: openequivariance.TPProblem - :members: - :undoc-members: - .. autofunction:: openequivariance.torch_to_oeq_dtype .. autofunction:: openequivariance.torch_ext_so_path -OpenEquivariance JAX API +JAX API ------------------------ The JAX API consists of ``TensorProduct`` and ``TensorProductConv`` classes that behave identically to their PyTorch counterparts. These classes -do not conform exactly to the e3nn-jax API, but perform the same computation. +do not conform exactly to the e3nn-jax API, but perform the same computation. + +If you plan to use ``oeq.jax`` without PyTorch installed, +you need to set ``OEQ_NOTORCH=1`` in your local environment (within Python, +``os.environ["OEQ_NOTORCH"] = 1``). For the moment, we require this to avoid +breaking the PyTorch version of OpenEquivariance. + .. autoclass:: openequivariance.jax.TensorProduct :members: forward, reorder_weights_from_e3nn, reorder_weights_to_e3nn @@ -51,6 +56,12 @@ do not conform exactly to the e3nn-jax API, but perform the same computation. :undoc-members: :exclude-members: +Common API +--------------------- + +.. autoclass:: openequivariance.TPProblem + :members: + :undoc-members: API Identical to e3nn --------------------- diff --git a/docs/installation.rst b/docs/installation.rst index 9c3588cb..3949457b 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -1,4 +1,4 @@ -Installation +Installation (Torch and JAX) ============================== .. toctree:: @@ -8,11 +8,15 @@ Installation You need the following to install OpenEquivariance: - A Linux system equipped with an NVIDIA / AMD graphics card. -- PyTorch >= 2.4 (>= 2.8 for AOTI and export). +- Either PyTorch >= 2.4 (>= 2.8 for AOTI and export), or JAX with CUDA 12 support + or higher. - GCC 9+ and the CUDA / HIP toolkit. The command ``c++ --version`` should return >= 9.0; see below for details on setting an alternate compiler. +PyTorch +------------------------------------------ + Installation is one easy command, followed by import verification: .. code-block:: bash @@ -28,11 +32,8 @@ To get the nightly build, run .. code-block:: bash - pip install git+https://github.com/PASSIONLab/OpenEquivariance - + pip install git+https://github.com/PASSIONLab/OpenEquivariance#subdirectory=openequivariance -Compiling the Integrated PyTorch Extension ------------------------------------------- To support ``torch.compile``, ``torch.export``, and JITScript, OpenEquivariance needs to compile a C++ extension tightly integrated with PyTorch. If you see a warning that @@ -48,13 +49,37 @@ environment variable and retry the import: .. code-block:: bash - export CCC=/path/to/your/gcc + export CC=/path/to/your/gcc export CXX=/path/to/your/g++ python -c "import openequivariance" These configuration steps are required only ONCE after installation (or upgrade) with pip. +JAX +------------------------------------------ +JAX support is currently limited to NVIDIA GPUs. You need to execute +the following two commands strictly in order: + +.. code-block:: bash + + pip install openequivariance[jax] + pip install openequivariance_extjax --no-build-isolation + +From there, set ``OEQ_NOTORCH=1`` to avoid a PyTorch import and test the package: + +.. code-block:: bash + + OEQ_NOTORCH=1 + python -c "import openequivariance.jax" + +You can get the nightly build as follows: + +.. code-block:: bash + + pip install git+https://github.com/PASSIONLab/OpenEquivariance#subdirectory=openequivariance[jax] + pip install git+https://github.com/PASSIONLab/OpenEquivariance#subdirectory=openequivariance_extjax + Configurations on Major Platforms --------------------------------- OpenEquivariance has been tested on both supercomputers and lab clusters. diff --git a/openequivariance/openequivariance/__init__.py b/openequivariance/openequivariance/__init__.py index 13522cf0..f53a9b9c 100644 --- a/openequivariance/openequivariance/__init__.py +++ b/openequivariance/openequivariance/__init__.py @@ -31,21 +31,18 @@ def _check_package_editable(): _editable_install_output_path = Path(__file__).parent.parent.parent / "outputs" - def extension_source_path(): """ :returns: Path to the source code of the C++ extension. """ return str(Path(__file__).parent / "extension") -TensorProduct, TensorProductConv, torch_ext_so_path, torch_to_oeq_dtype = None, None, None, None - if "OEQ_NOTORCH" not in os.environ or os.environ["OEQ_NOTORCH"] != "1": import torch from openequivariance.impl_torch.TensorProduct import TensorProduct from openequivariance.impl_torch.TensorProductConv import TensorProductConv - from openequivariance.impl_torch.extlib import torch_ext_so_path + from openequivariance.impl_torch.extlib import torch_ext_so_path as torch_ext_so_path_internal from openequivariance.core.utils import torch_to_oeq_dtype torch.serialization.add_safe_globals( @@ -62,6 +59,16 @@ def extension_source_path(): ] ) +def torch_ext_so_path(): + """ + :returns: Path to a ``.so`` file that must be linked to use OpenEquivariance + from the PyTorch C++ Interface. + """ + try: + return torch_ext_so_path_internal() + except NameError: + return None + jax = None try: import openequivariance_extjax diff --git a/openequivariance/openequivariance/impl_torch/extlib/__init__.py b/openequivariance/openequivariance/impl_torch/extlib/__init__.py index a96b5b1a..8f812409 100644 --- a/openequivariance/openequivariance/impl_torch/extlib/__init__.py +++ b/openequivariance/openequivariance/impl_torch/extlib/__init__.py @@ -142,10 +142,6 @@ def _raise_import_error_helper(import_target: str): def torch_ext_so_path(): - """ - :returns: Path to a ``.so`` file that must be linked to use OpenEquivariance - from the PyTorch C++ Interface. - """ return torch_module.__file__ From 9d6e30e9e0f4fd133f8e9e4c09d7124aecf7f3da Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Wed, 24 Dec 2025 19:23:27 -0800 Subject: [PATCH 068/116] More documentation progress. --- docs/tests_and_benchmarks.rst | 8 ++++++++ openequivariance/pyproject.toml | 2 ++ 2 files changed, 10 insertions(+) diff --git a/docs/tests_and_benchmarks.rst b/docs/tests_and_benchmarks.rst index 7bc11b26..fa24e765 100644 --- a/docs/tests_and_benchmarks.rst +++ b/docs/tests_and_benchmarks.rst @@ -32,6 +32,14 @@ To set up the editable install and run the entire testsuite, use: Browse the ``tests`` directory to run specific components. +To test the JAX wrappers, follow the same steps above and make sure that +``openequivariance_extjax`` is installed without build isolation. Then run + +.. code-block:: bash + + pytest --jax tests/batch_test.py + pytest --jax tests/conv_test.py + Replicating our Benchmarks ------------------------------ diff --git a/openequivariance/pyproject.toml b/openequivariance/pyproject.toml index aaee76da..24a196b2 100644 --- a/openequivariance/pyproject.toml +++ b/openequivariance/pyproject.toml @@ -65,6 +65,8 @@ dev = [ "furo", "sphinx", "sphinx-autobuild" + "nanobind", + "scikit-build-core" ] [tool.setuptools.packages.find] From b7af425de832c85180e81af869bab6984b154a63 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Wed, 24 Dec 2025 19:41:25 -0800 Subject: [PATCH 069/116] Renamed. --- docs/conf.py | 2 +- openequivariance/openequivariance/__init__.py | 2 +- openequivariance/openequivariance/impl_jax/__init__.py | 4 ---- .../openequivariance/{impl_jax => jax}/TensorProduct.py | 4 ++-- .../openequivariance/{impl_jax => jax}/TensorProductConv.py | 4 ++-- openequivariance/openequivariance/jax/__init__.py | 4 ++++ .../openequivariance/{impl_jax => jax}/extlib/__init__.py | 0 openequivariance/openequivariance/{impl_jax => jax}/utils.py | 0 tests/batch_test.py | 2 +- tests/conv_test.py | 3 +-- 10 files changed, 12 insertions(+), 13 deletions(-) delete mode 100644 openequivariance/openequivariance/impl_jax/__init__.py rename openequivariance/openequivariance/{impl_jax => jax}/TensorProduct.py (98%) rename openequivariance/openequivariance/{impl_jax => jax}/TensorProductConv.py (99%) create mode 100644 openequivariance/openequivariance/jax/__init__.py rename openequivariance/openequivariance/{impl_jax => jax}/extlib/__init__.py (100%) rename openequivariance/openequivariance/{impl_jax => jax}/utils.py (100%) diff --git a/docs/conf.py b/docs/conf.py index d87eb496..cc398261 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -38,7 +38,7 @@ "torch", "jax", "openequivariance.impl_torch.extlib", - "openequivariance.impl_jax.extlib", + "openequivariance.jax.extlib", "openequivariance_extjax", "jinja2", "numpy", diff --git a/openequivariance/openequivariance/__init__.py b/openequivariance/openequivariance/__init__.py index f53a9b9c..e7de3728 100644 --- a/openequivariance/openequivariance/__init__.py +++ b/openequivariance/openequivariance/__init__.py @@ -72,7 +72,7 @@ def torch_ext_so_path(): jax = None try: import openequivariance_extjax - import openequivariance.impl_jax as jax + import openequivariance.jax as jax except ImportError: pass diff --git a/openequivariance/openequivariance/impl_jax/__init__.py b/openequivariance/openequivariance/impl_jax/__init__.py deleted file mode 100644 index b2ec0994..00000000 --- a/openequivariance/openequivariance/impl_jax/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from openequivariance.impl_jax.TensorProduct import TensorProduct as TensorProduct -from openequivariance.impl_jax.TensorProductConv import TensorProductConv as TensorProductConv - -__all__ = ["TensorProduct", "TensorProductConv"] \ No newline at end of file diff --git a/openequivariance/openequivariance/impl_jax/TensorProduct.py b/openequivariance/openequivariance/jax/TensorProduct.py similarity index 98% rename from openequivariance/openequivariance/impl_jax/TensorProduct.py rename to openequivariance/openequivariance/jax/TensorProduct.py index dc9d8bdc..15275e39 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProduct.py +++ b/openequivariance/openequivariance/jax/TensorProduct.py @@ -1,11 +1,11 @@ import jax import numpy as np from functools import partial -from openequivariance.impl_jax import extlib +from openequivariance.jax import extlib from openequivariance.core.e3nn_lite import TPProblem from openequivariance.core.LoopUnrollTP import LoopUnrollTP from openequivariance.core.utils import hash_attributes -from openequivariance.impl_jax.utils import reorder_jax +from openequivariance.jax.utils import reorder_jax @partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5)) def forward(X, Y, W, L3_dim, irrep_dtype, attrs): diff --git a/openequivariance/openequivariance/impl_jax/TensorProductConv.py b/openequivariance/openequivariance/jax/TensorProductConv.py similarity index 99% rename from openequivariance/openequivariance/impl_jax/TensorProductConv.py rename to openequivariance/openequivariance/jax/TensorProductConv.py index f6c1b66f..c24a1add 100644 --- a/openequivariance/openequivariance/impl_jax/TensorProductConv.py +++ b/openequivariance/openequivariance/jax/TensorProductConv.py @@ -1,12 +1,12 @@ import numpy as np from functools import partial from typing import Optional -from openequivariance.impl_jax import extlib +from openequivariance.jax import extlib from openequivariance.core.e3nn_lite import TPProblem from openequivariance.core.LoopUnrollConv import LoopUnrollConv from openequivariance.core.utils import hash_attributes -from openequivariance.impl_jax.utils import reorder_jax +from openequivariance.jax.utils import reorder_jax import jax import jax.numpy as jnp diff --git a/openequivariance/openequivariance/jax/__init__.py b/openequivariance/openequivariance/jax/__init__.py new file mode 100644 index 00000000..5313c325 --- /dev/null +++ b/openequivariance/openequivariance/jax/__init__.py @@ -0,0 +1,4 @@ +from openequivariance.jax.TensorProduct import TensorProduct as TensorProduct +from openequivariance.jax.TensorProductConv import TensorProductConv as TensorProductConv + +__all__ = ["TensorProduct", "TensorProductConv"] \ No newline at end of file diff --git a/openequivariance/openequivariance/impl_jax/extlib/__init__.py b/openequivariance/openequivariance/jax/extlib/__init__.py similarity index 100% rename from openequivariance/openequivariance/impl_jax/extlib/__init__.py rename to openequivariance/openequivariance/jax/extlib/__init__.py diff --git a/openequivariance/openequivariance/impl_jax/utils.py b/openequivariance/openequivariance/jax/utils.py similarity index 100% rename from openequivariance/openequivariance/impl_jax/utils.py rename to openequivariance/openequivariance/jax/utils.py diff --git a/tests/batch_test.py b/tests/batch_test.py index 6f61fe6f..4c0e6334 100644 --- a/tests/batch_test.py +++ b/tests/batch_test.py @@ -47,7 +47,7 @@ def test_jax(self, request): def tp_and_problem(self, problem, extra_tp_constructor_args, test_jax): cls = oeq.TensorProduct if test_jax: - import openequivariance.impl_jax.TensorProduct as jax_tp + import openequivariance.jax.TensorProduct as jax_tp cls = jax_tp tp = cls(problem, **extra_tp_constructor_args) return tp, problem diff --git a/tests/conv_test.py b/tests/conv_test.py index dd2cc9b9..7e5f78a0 100644 --- a/tests/conv_test.py +++ b/tests/conv_test.py @@ -4,7 +4,6 @@ from pytest_check import check import numpy as np -import openequivariance import openequivariance as oeq from openequivariance.benchmark.ConvBenchmarkSuite import load_graph from itertools import product @@ -60,7 +59,7 @@ def test_jax(self, request): def conv_object(self, request, problem, extra_conv_constructor_args, test_jax): cls = oeq.TensorProductConv if test_jax: - from openequivariance.impl_jax import TensorProductConv as jax_conv + from openequivariance.jax import TensorProductConv as jax_conv cls = jax_conv if request.param == "atomic": From 1bcea33712109206305d79e53b9f3e1f69300a8c Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Wed, 24 Dec 2025 20:44:48 -0800 Subject: [PATCH 070/116] Renaming + added JAX example. --- docs/api.rst | 4 +- docs/conf.py | 2 +- docs/supported_ops.rst | 2 +- openequivariance/openequivariance/__init__.py | 7 +- .../{impl_torch => _torch}/CUEConv.py | 2 +- .../CUETensorProduct.py | 0 .../{impl_torch => _torch}/E3NNConv.py | 4 +- .../E3NNTensorProduct.py | 2 +- .../{impl_torch => _torch}/FlashTPConv.py | 0 .../NPDoubleBackwardMixin.py | 0 .../{impl_torch => _torch}/TensorProduct.py | 6 +- .../TensorProductConv.py | 10 +-- .../{impl_torch => _torch}/extlib/.empty | 0 .../{impl_torch => _torch}/extlib/__init__.py | 4 +- .../_torch/symmetric_contraction/__init__.py | 5 ++ .../symmetric_contraction.py | 2 +- .../{impl_torch => _torch}/utils.py | 0 .../benchmark/TestBenchmarkSuite.py | 2 +- .../benchmark/benchmark_utils.py | 2 +- .../benchmark/correctness_utils.py | 8 +-- .../openequivariance/core/ConvolutionBase.py | 8 +-- .../core/TensorProductBase.py | 2 +- .../openequivariance/core/utils.py | 2 +- .../symmetric_contraction/__init__.py | 5 -- openequivariance/pyproject.toml | 2 +- tests/batch_test.py | 14 ++-- tests/benchmark.py | 14 ++-- tests/conv_test.py | 15 ++-- tests/examples_test.py | 68 +++++++++++++++++-- tests/export_test.py | 2 +- 30 files changed, 127 insertions(+), 67 deletions(-) rename openequivariance/openequivariance/{impl_torch => _torch}/CUEConv.py (97%) rename openequivariance/openequivariance/{impl_torch => _torch}/CUETensorProduct.py (100%) rename openequivariance/openequivariance/{impl_torch => _torch}/E3NNConv.py (93%) rename openequivariance/openequivariance/{impl_torch => _torch}/E3NNTensorProduct.py (98%) rename openequivariance/openequivariance/{impl_torch => _torch}/FlashTPConv.py (100%) rename openequivariance/openequivariance/{impl_torch => _torch}/NPDoubleBackwardMixin.py (100%) rename openequivariance/openequivariance/{impl_torch => _torch}/TensorProduct.py (98%) rename openequivariance/openequivariance/{impl_torch => _torch}/TensorProductConv.py (98%) rename openequivariance/openequivariance/{impl_torch => _torch}/extlib/.empty (100%) rename openequivariance/openequivariance/{impl_torch => _torch}/extlib/__init__.py (97%) create mode 100644 openequivariance/openequivariance/_torch/symmetric_contraction/__init__.py rename openequivariance/openequivariance/{impl_torch => _torch}/symmetric_contraction/symmetric_contraction.py (99%) rename openequivariance/openequivariance/{impl_torch => _torch}/utils.py (100%) delete mode 100644 openequivariance/openequivariance/impl_torch/symmetric_contraction/__init__.py diff --git a/docs/api.rst b/docs/api.rst index c21b918f..e268b160 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -30,9 +30,9 @@ PyTorch API :undoc-members: :exclude-members: name -.. autofunction:: openequivariance.torch_to_oeq_dtype +.. autofunction:: openequivariance._torch_to_oeq_dtype -.. autofunction:: openequivariance.torch_ext_so_path +.. autofunction:: openequivariance._torch_ext_so_path JAX API ------------------------ diff --git a/docs/conf.py b/docs/conf.py index cc398261..df3b7636 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -37,7 +37,7 @@ autodoc_mock_imports = [ "torch", "jax", - "openequivariance.impl_torch.extlib", + "openequivariance._torch.extlib", "openequivariance.jax.extlib", "openequivariance_extjax", "jinja2", diff --git a/docs/supported_ops.rst b/docs/supported_ops.rst index 02a98282..bcc11955 100644 --- a/docs/supported_ops.rst +++ b/docs/supported_ops.rst @@ -117,7 +117,7 @@ toplevel. You can use our implementation by running .. code-block:: - from openequivariance.impl_torch.symmetric_contraction import SymmetricContraction as OEQSymmetricContraction + from openequivariance._torch.symmetric_contraction import SymmetricContraction as OEQSymmetricContraction Some Github users report weak performance for the symmetric contraction backward pass; your mileage may vary. diff --git a/openequivariance/openequivariance/__init__.py b/openequivariance/openequivariance/__init__.py index e7de3728..de5a8af3 100644 --- a/openequivariance/openequivariance/__init__.py +++ b/openequivariance/openequivariance/__init__.py @@ -39,10 +39,11 @@ def extension_source_path(): if "OEQ_NOTORCH" not in os.environ or os.environ["OEQ_NOTORCH"] != "1": import torch - from openequivariance.impl_torch.TensorProduct import TensorProduct - from openequivariance.impl_torch.TensorProductConv import TensorProductConv - from openequivariance.impl_torch.extlib import torch_ext_so_path as torch_ext_so_path_internal + from openequivariance._torch.TensorProduct import TensorProduct + from openequivariance._torch.TensorProductConv import TensorProductConv + + from openequivariance._torch.extlib import torch_ext_so_path as torch_ext_so_path_internal from openequivariance.core.utils import torch_to_oeq_dtype torch.serialization.add_safe_globals( diff --git a/openequivariance/openequivariance/impl_torch/CUEConv.py b/openequivariance/openequivariance/_torch/CUEConv.py similarity index 97% rename from openequivariance/openequivariance/impl_torch/CUEConv.py rename to openequivariance/openequivariance/_torch/CUEConv.py index 00e345f2..8500e39c 100644 --- a/openequivariance/openequivariance/impl_torch/CUEConv.py +++ b/openequivariance/openequivariance/_torch/CUEConv.py @@ -2,7 +2,7 @@ import itertools from typing import Iterator -from openequivariance.impl_torch.CUETensorProduct import CUETensorProduct +from openequivariance._torch.CUETensorProduct import CUETensorProduct from openequivariance.core.ConvolutionBase import ( ConvolutionBase, scatter_add_wrapper, diff --git a/openequivariance/openequivariance/impl_torch/CUETensorProduct.py b/openequivariance/openequivariance/_torch/CUETensorProduct.py similarity index 100% rename from openequivariance/openequivariance/impl_torch/CUETensorProduct.py rename to openequivariance/openequivariance/_torch/CUETensorProduct.py diff --git a/openequivariance/openequivariance/impl_torch/E3NNConv.py b/openequivariance/openequivariance/_torch/E3NNConv.py similarity index 93% rename from openequivariance/openequivariance/impl_torch/E3NNConv.py rename to openequivariance/openequivariance/_torch/E3NNConv.py index 18b1329b..811509dc 100644 --- a/openequivariance/openequivariance/impl_torch/E3NNConv.py +++ b/openequivariance/openequivariance/_torch/E3NNConv.py @@ -4,8 +4,8 @@ ConvolutionBase, scatter_add_wrapper, ) -from openequivariance.impl_torch.E3NNTensorProduct import E3NNTensorProduct -from openequivariance.impl_torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixinConv +from openequivariance._torch.E3NNTensorProduct import E3NNTensorProduct +from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixinConv class E3NNConv(ConvolutionBase, NumpyDoubleBackwardMixinConv): def __init__(self, config, *, idx_dtype=np.int64, torch_op=True): diff --git a/openequivariance/openequivariance/impl_torch/E3NNTensorProduct.py b/openequivariance/openequivariance/_torch/E3NNTensorProduct.py similarity index 98% rename from openequivariance/openequivariance/impl_torch/E3NNTensorProduct.py rename to openequivariance/openequivariance/_torch/E3NNTensorProduct.py index c0416d67..067a7e6b 100644 --- a/openequivariance/openequivariance/impl_torch/E3NNTensorProduct.py +++ b/openequivariance/openequivariance/_torch/E3NNTensorProduct.py @@ -12,7 +12,7 @@ from openequivariance.core.TensorProductBase import TensorProductBase from openequivariance.core.e3nn_lite import TPProblem from openequivariance.benchmark.logging_utils import getLogger -from openequivariance.impl_torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin +from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin TORCH_COMPILE_AUTOTUNING_DIR = pathlib.Path("triton_autotuning") diff --git a/openequivariance/openequivariance/impl_torch/FlashTPConv.py b/openequivariance/openequivariance/_torch/FlashTPConv.py similarity index 100% rename from openequivariance/openequivariance/impl_torch/FlashTPConv.py rename to openequivariance/openequivariance/_torch/FlashTPConv.py diff --git a/openequivariance/openequivariance/impl_torch/NPDoubleBackwardMixin.py b/openequivariance/openequivariance/_torch/NPDoubleBackwardMixin.py similarity index 100% rename from openequivariance/openequivariance/impl_torch/NPDoubleBackwardMixin.py rename to openequivariance/openequivariance/_torch/NPDoubleBackwardMixin.py diff --git a/openequivariance/openequivariance/impl_torch/TensorProduct.py b/openequivariance/openequivariance/_torch/TensorProduct.py similarity index 98% rename from openequivariance/openequivariance/impl_torch/TensorProduct.py rename to openequivariance/openequivariance/_torch/TensorProduct.py index f7bb8ff6..e4197583 100644 --- a/openequivariance/openequivariance/impl_torch/TensorProduct.py +++ b/openequivariance/openequivariance/_torch/TensorProduct.py @@ -1,12 +1,12 @@ from openequivariance.core.LoopUnrollTP import LoopUnrollTP from openequivariance import TPProblem -from openequivariance.impl_torch import extlib +from openequivariance._torch import extlib import torch import typing from openequivariance.core.utils import torch_to_oeq_dtype from openequivariance.benchmark.logging_utils import getLogger -from openequivariance.impl_torch.utils import reorder_torch -from openequivariance.impl_torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin +from openequivariance._torch.utils import reorder_torch +from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin logger = getLogger() diff --git a/openequivariance/openequivariance/impl_torch/TensorProductConv.py b/openequivariance/openequivariance/_torch/TensorProductConv.py similarity index 98% rename from openequivariance/openequivariance/impl_torch/TensorProductConv.py rename to openequivariance/openequivariance/_torch/TensorProductConv.py index 771c72e2..13b2c757 100644 --- a/openequivariance/openequivariance/impl_torch/TensorProductConv.py +++ b/openequivariance/openequivariance/_torch/TensorProductConv.py @@ -3,8 +3,8 @@ import numpy as np import torch -import openequivariance.impl_torch.extlib as extlib -from openequivariance.impl_torch.extlib import ( +import openequivariance._torch.extlib as extlib +from openequivariance._torch.extlib import ( JITConvImpl, postprocess_kernel, DeviceProp, @@ -15,14 +15,14 @@ scatter_add_wrapper, ) from openequivariance.core.LoopUnrollConv import LoopUnrollConv -from openequivariance.impl_torch.TensorProduct import TensorProduct +from openequivariance._torch.TensorProduct import TensorProduct from openequivariance import TPProblem from openequivariance.core.utils import torch_to_oeq_dtype from openequivariance.core.dtype_enum import enum_to_torch_dtype -from openequivariance.impl_torch.utils import reorder_torch +from openequivariance._torch.utils import reorder_torch from openequivariance.benchmark.logging_utils import getLogger -from openequivariance.impl_torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixinConv +from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixinConv logger = getLogger() diff --git a/openequivariance/openequivariance/impl_torch/extlib/.empty b/openequivariance/openequivariance/_torch/extlib/.empty similarity index 100% rename from openequivariance/openequivariance/impl_torch/extlib/.empty rename to openequivariance/openequivariance/_torch/extlib/.empty diff --git a/openequivariance/openequivariance/impl_torch/extlib/__init__.py b/openequivariance/openequivariance/_torch/extlib/__init__.py similarity index 97% rename from openequivariance/openequivariance/impl_torch/extlib/__init__.py rename to openequivariance/openequivariance/_torch/extlib/__init__.py index 8f812409..9995a076 100644 --- a/openequivariance/openequivariance/impl_torch/extlib/__init__.py +++ b/openequivariance/openequivariance/_torch/extlib/__init__.py @@ -39,9 +39,9 @@ generic_module = None if not build_ext: - import openequivariance.impl_torch.extlib.generic_module + import openequivariance._torch.extlib.generic_module - generic_module = openequivariance.impl_torch.extlib.generic_module + generic_module = openequivariance._torch.extlib.generic_module elif TORCH_VERSION_CUDA_OR_HIP: from torch.utils.cpp_extension import library_paths, include_paths diff --git a/openequivariance/openequivariance/_torch/symmetric_contraction/__init__.py b/openequivariance/openequivariance/_torch/symmetric_contraction/__init__.py new file mode 100644 index 00000000..00edefcb --- /dev/null +++ b/openequivariance/openequivariance/_torch/symmetric_contraction/__init__.py @@ -0,0 +1,5 @@ +from openequivariance._torch.symmetric_contraction.symmetric_contraction import ( + SymmetricContraction, +) + +__all__ = ["SymmetricContraction"] diff --git a/openequivariance/openequivariance/impl_torch/symmetric_contraction/symmetric_contraction.py b/openequivariance/openequivariance/_torch/symmetric_contraction/symmetric_contraction.py similarity index 99% rename from openequivariance/openequivariance/impl_torch/symmetric_contraction/symmetric_contraction.py rename to openequivariance/openequivariance/_torch/symmetric_contraction/symmetric_contraction.py index d1f409d0..504e788e 100644 --- a/openequivariance/openequivariance/impl_torch/symmetric_contraction/symmetric_contraction.py +++ b/openequivariance/openequivariance/_torch/symmetric_contraction/symmetric_contraction.py @@ -1,7 +1,7 @@ # ruff: noqa : E402 import torch -from openequivariance.impl_torch.extlib import GroupMM_F32, GroupMM_F64 +from openequivariance._torch.extlib import GroupMM_F32, GroupMM_F64 class GroupMM: diff --git a/openequivariance/openequivariance/impl_torch/utils.py b/openequivariance/openequivariance/_torch/utils.py similarity index 100% rename from openequivariance/openequivariance/impl_torch/utils.py rename to openequivariance/openequivariance/_torch/utils.py diff --git a/openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py b/openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py index 12cbee5e..119c866c 100644 --- a/openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py +++ b/openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py @@ -7,7 +7,7 @@ from dataclasses import dataclass import openequivariance as oeq -from openequivariance.impl_torch.extlib import DeviceProp +from openequivariance._torch.extlib import DeviceProp from openequivariance.core.TensorProductBase import TensorProductBase from openequivariance.benchmark.logging_utils import getLogger, bcolors diff --git a/openequivariance/openequivariance/benchmark/benchmark_utils.py b/openequivariance/openequivariance/benchmark/benchmark_utils.py index b7abaf77..377df3d6 100644 --- a/openequivariance/openequivariance/benchmark/benchmark_utils.py +++ b/openequivariance/openequivariance/benchmark/benchmark_utils.py @@ -13,7 +13,7 @@ from openequivariance.core.utils import calculate_total_nnz from openequivariance.core.TensorProductBase import TensorProductBase from openequivariance.core.e3nn_lite import TPProblem -from openequivariance.impl_torch.CUETensorProduct import CUETensorProduct +from openequivariance._torch.CUETensorProduct import CUETensorProduct from openequivariance.benchmark.logging_utils import getLogger, bcolors logger = getLogger() diff --git a/openequivariance/openequivariance/benchmark/correctness_utils.py b/openequivariance/openequivariance/benchmark/correctness_utils.py index 01931c99..3f743332 100644 --- a/openequivariance/openequivariance/benchmark/correctness_utils.py +++ b/openequivariance/openequivariance/benchmark/correctness_utils.py @@ -2,7 +2,7 @@ from openequivariance.core.TensorProductBase import TensorProductBase from openequivariance.core.e3nn_lite import TPProblem -from openequivariance.impl_torch.CUETensorProduct import CUETensorProduct +from openequivariance._torch.CUETensorProduct import CUETensorProduct from openequivariance.benchmark.random_buffer_utils import ( get_random_buffers_forward, get_random_buffers_backward, @@ -72,7 +72,7 @@ def correctness_forward( prng_seed: int, ) -> dict: if reference_implementation is None: - from openequivariance.impl_torch.E3NNTensorProduct import E3NNTensorProduct + from openequivariance._torch.E3NNTensorProduct import E3NNTensorProduct reference_implementation = E3NNTensorProduct @@ -116,7 +116,7 @@ def correctness_backward( prng_seed: int, ) -> dict: if reference_implementation is None: - from openequivariance.impl_torch.E3NNTensorProduct import E3NNTensorProduct + from openequivariance._torch.E3NNTensorProduct import E3NNTensorProduct reference_implementation = E3NNTensorProduct @@ -199,7 +199,7 @@ def correctness_double_backward( get_random_buffers_double_backward(problem, batch_size=batch_size, prng_seed=prng_seed) if reference_implementation is None: - from openequivariance.impl_torch.E3NNTensorProduct import E3NNTensorProduct + from openequivariance._torch.E3NNTensorProduct import E3NNTensorProduct reference_implementation = E3NNTensorProduct result = {"thresh": correctness_threshold, "batch_size": batch_size} diff --git a/openequivariance/openequivariance/core/ConvolutionBase.py b/openequivariance/openequivariance/core/ConvolutionBase.py index f81b0f7d..a450d03d 100644 --- a/openequivariance/openequivariance/core/ConvolutionBase.py +++ b/openequivariance/openequivariance/core/ConvolutionBase.py @@ -10,7 +10,7 @@ from openequivariance.benchmark.correctness_utils import check_similiarity from openequivariance.core.e3nn_lite import wigner_3j from openequivariance.core.utils import benchmark -from openequivariance.impl_torch.extlib import DeviceBuffer +from openequivariance._torch.extlib import DeviceBuffer logger = getLogger() @@ -229,7 +229,7 @@ def test_correctness_forward( high_precision_ref=False, ): if reference_implementation is None: - from openequivariance.impl_torch.E3NNConv import E3NNConv + from openequivariance._torch.E3NNConv import E3NNConv reference_implementation = E3NNConv @@ -473,7 +473,7 @@ def test_correctness_backward( high_precision_ref=False, ): if reference_implementation is None: - from openequivariance.impl_torch.E3NNConv import E3NNConv + from openequivariance._torch.E3NNConv import E3NNConv reference_implementation = E3NNConv @@ -554,7 +554,7 @@ def test_correctness_double_backward( ) if reference_implementation is None: - from openequivariance.impl_torch.E3NNConv import E3NNConv + from openequivariance._torch.E3NNConv import E3NNConv reference_implementation = E3NNConv reference_problem = self.config diff --git a/openequivariance/openequivariance/core/TensorProductBase.py b/openequivariance/openequivariance/core/TensorProductBase.py index f360cce4..5af538bb 100644 --- a/openequivariance/openequivariance/core/TensorProductBase.py +++ b/openequivariance/openequivariance/core/TensorProductBase.py @@ -3,7 +3,7 @@ from openequivariance.core.e3nn_lite import TPProblem from openequivariance.benchmark.logging_utils import getLogger from openequivariance.core.utils import benchmark -from openequivariance.impl_torch.extlib import DeviceBuffer +from openequivariance._torch.extlib import DeviceBuffer logger = getLogger() diff --git a/openequivariance/openequivariance/core/utils.py b/openequivariance/openequivariance/core/utils.py index 442ef6c7..472c00f2 100644 --- a/openequivariance/openequivariance/core/utils.py +++ b/openequivariance/openequivariance/core/utils.py @@ -8,7 +8,7 @@ import json import tempfile import hashlib -from openequivariance.impl_torch.extlib import GPUTimer +from openequivariance._torch.extlib import GPUTimer def sparse_outer_product_work(cg: np.ndarray) -> int: diff --git a/openequivariance/openequivariance/impl_torch/symmetric_contraction/__init__.py b/openequivariance/openequivariance/impl_torch/symmetric_contraction/__init__.py deleted file mode 100644 index 23d4b030..00000000 --- a/openequivariance/openequivariance/impl_torch/symmetric_contraction/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from openequivariance.impl_torch.symmetric_contraction.symmetric_contraction import ( - SymmetricContraction, -) - -__all__ = ["SymmetricContraction"] diff --git a/openequivariance/pyproject.toml b/openequivariance/pyproject.toml index 24a196b2..30df0e60 100644 --- a/openequivariance/pyproject.toml +++ b/openequivariance/pyproject.toml @@ -64,7 +64,7 @@ dev = [ "cmake", "furo", "sphinx", - "sphinx-autobuild" + "sphinx-autobuild", "nanobind", "scikit-build-core" ] diff --git a/tests/batch_test.py b/tests/batch_test.py index 4c0e6334..5a65611d 100644 --- a/tests/batch_test.py +++ b/tests/batch_test.py @@ -40,13 +40,13 @@ def extra_tp_constructor_args(self): return {} @pytest.fixture(scope="class") - def test_jax(self, request): + def with_jax(self, request): return request.config.getoption("--jax") @pytest.fixture(scope="class") - def tp_and_problem(self, problem, extra_tp_constructor_args, test_jax): + def tp_and_problem(self, problem, extra_tp_constructor_args, with_jax): cls = oeq.TensorProduct - if test_jax: + if with_jax: import openequivariance.jax.TensorProduct as jax_tp cls = jax_tp tp = cls(problem, **extra_tp_constructor_args) @@ -254,8 +254,8 @@ def problem(self, request, dtype): class TestTorchbindDisable(TestProductionModels): @pytest.fixture(scope="class") - def extra_tp_constructor_args(self, test_jax): - if test_jax: + def extra_tp_constructor_args(self, with_jax): + if with_jax: pytest.skip("N/A for JAX") return {"use_opaque": True} @@ -270,8 +270,8 @@ def problem(self, request, dtype): return problem @pytest.fixture(scope="class") - def tp_and_problem(self, problem, extra_tp_constructor_args, test_jax): - if test_jax: + def tp_and_problem(self, problem, extra_tp_constructor_args, with_jax): + if with_jax: pytest.skip("N/A for JAX") else: tp = oeq.TensorProduct(problem, **extra_tp_constructor_args) diff --git a/tests/benchmark.py b/tests/benchmark.py index 7ef63b9c..829cc46c 100644 --- a/tests/benchmark.py +++ b/tests/benchmark.py @@ -11,14 +11,14 @@ import numpy as np from openequivariance.benchmark.logging_utils import getLogger -from openequivariance.impl_torch.extlib import DeviceProp -from openequivariance.impl_torch.E3NNTensorProduct import ( +from openequivariance._torch.extlib import DeviceProp +from openequivariance._torch.E3NNTensorProduct import ( E3NNTensorProduct, E3NNTensorProductCompiledCUDAGraphs, E3NNTensorProductCompiledMaxAutotuneCUDAGraphs, ) -from openequivariance.impl_torch.TensorProduct import TensorProduct -from openequivariance.impl_torch.CUETensorProduct import CUETensorProduct +from openequivariance._torch.TensorProduct import TensorProduct +from openequivariance._torch.CUETensorProduct import CUETensorProduct from openequivariance.benchmark.TestBenchmarkSuite import ( TestBenchmarkSuite, TestDefinition, @@ -30,15 +30,15 @@ SingleInstruction, ) -from openequivariance.impl_torch.TensorProductConv import ( +from openequivariance._torch.TensorProductConv import ( TensorProductConvAtomic, TensorProductConvDeterministic, TensorProductConvKahan, TensorProductConvScatterSum, ) -from openequivariance.impl_torch.CUEConv import CUEConv, CUEConvFused -from openequivariance.impl_torch.FlashTPConv import FlashTPConv +from openequivariance._torch.CUEConv import CUEConv, CUEConvFused +from openequivariance._torch.FlashTPConv import FlashTPConv from openequivariance.benchmark.ConvBenchmarkSuite import ConvBenchmarkSuite, load_graph from openequivariance.benchmark.problems import ( diff --git a/tests/conv_test.py b/tests/conv_test.py index 7e5f78a0..e12503e8 100644 --- a/tests/conv_test.py +++ b/tests/conv_test.py @@ -15,7 +15,6 @@ e3tools_problems, ) - class ConvCorrectness: def thresh(self, direction): return {"fwd": 3e-4, "bwd": 3e-4, "double_bwd": 3e-4}[direction] @@ -52,13 +51,13 @@ def extra_conv_constructor_args(self): return {} @pytest.fixture(scope="class") - def test_jax(self, request): + def with_jax(self, request): return request.config.getoption("--jax") @pytest.fixture(params=["atomic", "deterministic", "kahan"], scope="class") - def conv_object(self, request, problem, extra_conv_constructor_args, test_jax): + def conv_object(self, request, problem, extra_conv_constructor_args, with_jax): cls = oeq.TensorProductConv - if test_jax: + if with_jax: from openequivariance.jax import TensorProductConv as jax_conv cls = jax_conv @@ -254,8 +253,8 @@ def conv_object(self, request, problem): class TestTorchbindDisable(TestProductionModels): @pytest.fixture(scope="class") - def extra_conv_constructor_args(self, test_jax): - if test_jax: + def extra_conv_constructor_args(self, with_jax): + if with_jax: pytest.skip("N/A for JAX") return {"use_opaque": True} @@ -264,8 +263,8 @@ class TestTorchTo(ConvCorrectness): problems = [mace_problems()[0]] @pytest.fixture(params=problems, ids=lambda x: x.label, scope="class") - def problem(self, request, dtype, test_jax): - if test_jax: + def problem(self, request, dtype, with_jax): + if with_jax: pytest.skip("N/A for JAX") problem = request.param diff --git a/tests/examples_test.py b/tests/examples_test.py index 3beaabb1..7d2832cf 100644 --- a/tests/examples_test.py +++ b/tests/examples_test.py @@ -1,4 +1,14 @@ -def test_tutorial(): +import pytest +import os + +@pytest.fixture +def with_jax(request): + return request.config.getoption("--jax") + +def test_tutorial_torch(with_jax): + if with_jax: + pytest.skip("Skipping PyTorch tutorial when testing JAX") + import torch import e3nn.o3 as o3 @@ -26,7 +36,7 @@ def test_tutorial(): problem = oeq.TPProblem( X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False ) - tp_fast = oeq.TensorProduct(problem, torch_op=True) + tp_fast = oeq.TensorProduct(problem) Z = tp_fast(X, Y, W) # Reuse X, Y, W from earlier print(torch.norm(Z)) @@ -53,7 +63,7 @@ def test_tutorial(): W = torch.rand(nonzero_ct, problem.weight_numel, device="cuda", generator=gen) tp_conv = oeq.TensorProductConv( - problem, torch_op=True, deterministic=False + problem, deterministic=False ) # Reuse problem from earlier Z = tp_conv.forward( X, Y, W, edge_index[0], edge_index[1] @@ -66,10 +76,60 @@ def test_tutorial(): edge_index, receiver_perm = edge_index.sort_by("row") # Sort by receiver index # Now we can use the faster deterministic algorithm - tp_conv = oeq.TensorProductConv(problem, torch_op=True, deterministic=True) + tp_conv = oeq.TensorProductConv(problem, deterministic=True) Z = tp_conv.forward( X, Y[receiver_perm], W[receiver_perm], edge_index[0], edge_index[1], sender_perm ) print(torch.norm(Z)) # =============================== assert True + + +def test_tutorial_jax(with_jax): + if not with_jax: + pytest.skip("Skipping JAX tutorial when testing PyTorch") + + os.environ.OEQ_NOTORCH = "1" + import openequivariance as oeq + import jax + + seed = 42 + key = jax.random.PRNGKey(seed) + + batch_size = 1000 + X_ir, Y_ir, Z_ir = oeq.Irreps("1x2e"), oeq.Irreps("1x3e"), oeq.Irreps("1x2e") + instructions = [(0, 0, 0, "uvu", True)] + + problem = oeq.TPProblem( + X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False + ) + tp_fast = oeq.jax.TensorProduct(problem) + + X = jax.random.uniform(key, shape=(batch_size, X_ir.dim), minval=0.0, maxval=1.0, dtype=jax.numpy.float32) + Y = jax.random.uniform(key, shape=(batch_size, Y_ir.dim), minval=0.0, maxval=1.0, dtype=jax.numpy.float32) + W = jax.random.uniform(key, shape=(batch_size, tp_fast.weight_numel), minval=0.0, maxval=1.0, dtype=jax.numpy.float32) + + Z = tp_fast(X, Y, W) + print(jax.numpy.linalg.norm(Z)) + + edge_index = jax.numpy.array( + [ + [0, 1, 1, 2], + [1, 0, 2, 1], + ], + dtype=jax.numpy.int32, # NOTE: This int32, not int64 + ) + + node_ct, nonzero_ct = 3, 4 + X = jax.random.uniform(key, shape=(node_ct, X_ir.dim), minval=0.0, maxval=1.0, dtype=jax.numpy.float32) + Y = jax.random.uniform(key, shape=(nonzero_ct, Y_ir.dim), + minval=0.0, maxval=1.0, dtype=jax.numpy.float32) + W = jax.random.uniform(key, shape=(nonzero_ct, problem.weight_numel), + minval=0.0, maxval=1.0, dtype=jax.numpy.float32) + tp_conv = oeq.jax.TensorProductConv(problem, deterministic=False) + Z = tp_conv.forward( + X, Y, W, edge_index[0], edge_index[1] + ) + print(jax.numpy.linalg.norm(Z)) + + diff --git a/tests/export_test.py b/tests/export_test.py index 9b64e2fe..0fd23b2b 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -11,7 +11,7 @@ from torch_geometric import EdgeIndex import importlib.resources -from openequivariance.impl_torch.E3NNTensorProduct import E3NNTensorProduct +from openequivariance._torch.E3NNTensorProduct import E3NNTensorProduct @pytest.fixture(scope="session") From 259ea20a8cb87d38efdc88de7ca7f922d2fab97e Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Wed, 24 Dec 2025 21:02:51 -0800 Subject: [PATCH 071/116] JAX example. --- openequivariance_extjax/src/libjax_tp_jit.cpp | 42 ++++++++++++++++++- 1 file changed, 40 insertions(+), 2 deletions(-) diff --git a/openequivariance_extjax/src/libjax_tp_jit.cpp b/openequivariance_extjax/src/libjax_tp_jit.cpp index 257a737b..7342eb01 100644 --- a/openequivariance_extjax/src/libjax_tp_jit.cpp +++ b/openequivariance_extjax/src/libjax_tp_jit.cpp @@ -45,6 +45,44 @@ xla::ffi::DataType enum_to_xla_dtype(int64_t i){ throw logic_error("Unsupported tensor datatype!"); } +std::string xla_dtype_to_string(xla::ffi::DataType dtype) { + const std::unordered_map map = { + {xla::ffi::DataType::INVALID, "INVALID"}, + {xla::ffi::DataType::PRED, "PRED"}, + {xla::ffi::DataType::S1, "S1"}, + {xla::ffi::DataType::S2, "S2"}, + {xla::ffi::DataType::S4, "S4"}, + {xla::ffi::DataType::S8, "S8"}, + {xla::ffi::DataType::S16, "S16"}, + {xla::ffi::DataType::S32, "S32"}, + {xla::ffi::DataType::S64, "S64"}, + {xla::ffi::DataType::U1, "U1"}, + {xla::ffi::DataType::U2, "U2"}, + {xla::ffi::DataType::U4, "U4"}, + {xla::ffi::DataType::U8, "U8"}, + {xla::ffi::DataType::U16, "U16"}, + {xla::ffi::DataType::U32, "U32"}, + {xla::ffi::DataType::U64, "U64"}, + {xla::ffi::DataType::F16, "F16"}, + {xla::ffi::DataType::F32, "F32"}, + {xla::ffi::DataType::F64, "F64"}, + {xla::ffi::DataType::BF16, "BF16"}, + {xla::ffi::DataType::C64, "C64"}, + {xla::ffi::DataType::C128, "C128"}, + {xla::ffi::DataType::TOKEN, "TOKEN"}, + {xla::ffi::DataType::F8E5M2, "F8E5M2"}, + {xla::ffi::DataType::F8E4M3, "F8E4M3"}, + {xla::ffi::DataType::F8E4M3FN, "F8E4M3FN"}, + {xla::ffi::DataType::F8E4M3B11FNUZ, "F8E4M3B11FNUZ"}, + {xla::ffi::DataType::F8E5M2FNUZ, "F8E5M2FNUZ"}, + {xla::ffi::DataType::F8E4M3FNUZ, "F8E4M3FNUZ"}, + {xla::ffi::DataType::F8E3M4, "F8E3M4"}, + {xla::ffi::DataType::F4E2M1FN, "F4E2M1FN"}, + {xla::ffi::DataType::F8E8M0FNU, "F8E8M0FNU"}, + }; + return map.at(dtype); +} + inline void* data_ptr(ffi::AnyBuffer &buffer) { return buffer.untyped_data(); } @@ -237,8 +275,8 @@ inline void check_tensor(const ffi::AnyBuffer &buffer, if (buffer.element_type() != expected_dtype) { throw std::logic_error("Datatype mismatch for tensor " + tensor_name + - ". Expected datatype " + std::to_string(static_cast(expected_dtype)) + - ", got " + std::to_string(static_cast(buffer.element_type()))); + ". Expected datatype " + xla_dtype_to_string(expected_dtype) + + ", got " + xla_dtype_to_string(buffer.element_type())); } } From ab871851735c56395dc36f1cc257ada1d8ceca6d Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Wed, 24 Dec 2025 21:07:58 -0800 Subject: [PATCH 072/116] Added examples. --- README.md | 30 ++++++++++++++++++++++++++++++ tests/examples_test.py | 2 -- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 27610c7e..2691e570 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,36 @@ pip install openequivariance[jax] pip install openequivariance_extjax --no-build-isolation ``` +```python +os.environ["OEQ_NOTORCH"] = "1" +import openequivariance as oeq +import jax + +seed = 42 +key = jax.random.PRNGKey(seed) + +node_ct, nonzero_ct = 3, 4 +edge_index = jax.numpy.array( + [ + [0, 1, 1, 2], + [1, 0, 2, 1], + ], + dtype=jax.numpy.int32, # NOTE: This int32, not int64 +) + +X = jax.random.uniform(key, shape=(node_ct, X_ir.dim), minval=0.0, maxval=1.0, dtype=jax.numpy.float32) +Y = jax.random.uniform(key, shape=(nonzero_ct, Y_ir.dim), + minval=0.0, maxval=1.0, dtype=jax.numpy.float32) +W = jax.random.uniform(key, shape=(nonzero_ct, problem.weight_numel), + minval=0.0, maxval=1.0, dtype=jax.numpy.float32) + +# Reuse problem from earlier +tp_conv = oeq.jax.TensorProductConv(problem, deterministic=False) +Z = tp_conv.forward( + X, Y, W, edge_index[0], edge_index[1] +) +print(jax.numpy.linalg.norm(Z)) +``` 📣 📣 OpenEquivariance was accepted to the 2025 SIAM Conference on Applied and Computational Discrete Algorithms (Proceedings Track)! Catch the talk in diff --git a/tests/examples_test.py b/tests/examples_test.py index 7d2832cf..61d42416 100644 --- a/tests/examples_test.py +++ b/tests/examples_test.py @@ -131,5 +131,3 @@ def test_tutorial_jax(with_jax): X, Y, W, edge_index[0], edge_index[1] ) print(jax.numpy.linalg.norm(Z)) - - From 9d8f5d882f0945599bacc6daae8d212d3ec682c0 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Wed, 24 Dec 2025 21:21:06 -0800 Subject: [PATCH 073/116] Updated README. --- README.md | 81 ++++++++++++++++++++++++++++++------------------------- 1 file changed, 44 insertions(+), 37 deletions(-) diff --git a/README.md b/README.md index 2691e570..176ba5cc 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,7 @@ [![License](https://img.shields.io/badge/License-BSD_3--Clause-blue.svg)](https://opensource.org/licenses/BSD-3-Clause) [[Examples]](#show-me-some-examples) +[[JAX Examples]](#jax-examples) [[Citation and Acknowledgements]](#citation-and-acknowledgements) OpenEquivariance is a CUDA and HIP kernel generator for the Clebsch-Gordon tensor product, @@ -29,48 +30,15 @@ computation and memory consumption significantly. For detailed instructions on tests, benchmarks, MACE / Nequip, and our API, check out the [documentation](https://passionlab.github.io/OpenEquivariance). -⭐️ **JAX Support**: Our latest update brings -support for JAX. To install, execute the following commands in order: +⭐️ **JAX**: Our latest update brings +support for JAX. To install, execute the following +commands in order: ``` pip install openequivariance[jax] pip install openequivariance_extjax --no-build-isolation ``` - -```python -os.environ["OEQ_NOTORCH"] = "1" -import openequivariance as oeq -import jax - -seed = 42 -key = jax.random.PRNGKey(seed) - -node_ct, nonzero_ct = 3, 4 -edge_index = jax.numpy.array( - [ - [0, 1, 1, 2], - [1, 0, 2, 1], - ], - dtype=jax.numpy.int32, # NOTE: This int32, not int64 -) - -X = jax.random.uniform(key, shape=(node_ct, X_ir.dim), minval=0.0, maxval=1.0, dtype=jax.numpy.float32) -Y = jax.random.uniform(key, shape=(nonzero_ct, Y_ir.dim), - minval=0.0, maxval=1.0, dtype=jax.numpy.float32) -W = jax.random.uniform(key, shape=(nonzero_ct, problem.weight_numel), - minval=0.0, maxval=1.0, dtype=jax.numpy.float32) - -# Reuse problem from earlier -tp_conv = oeq.jax.TensorProductConv(problem, deterministic=False) -Z = tp_conv.forward( - X, Y, W, edge_index[0], edge_index[1] -) -print(jax.numpy.linalg.norm(Z)) -``` - -📣 📣 OpenEquivariance was accepted to the 2025 SIAM Conference on Applied and -Computational Discrete Algorithms (Proceedings Track)! Catch the talk in -Montréal and check out the [camera-ready copy on Arxiv](https://arxiv.org/abs/2501.13986) (available May 12, 2025). +See below for example usage. ## Show me some examples Here's a CG tensor product implemented by e3nn: @@ -166,6 +134,45 @@ print(torch.norm(Z)) `deterministic=False`, the `sender` and `receiver` indices can have arbitrary order. +## JAX Examples +After installation, use the library +as follows. Set `OEQ_NOTORCH=1` +in your environment to avoid the PyTorch import in +the regular `openequivariance` package. +```python +import jax +import os + +os.environ["OEQ_NOTORCH"] = "1" +import openequivariance as oeq + +seed = 42 +key = jax.random.PRNGKey(seed) + +node_ct, nonzero_ct = 3, 4 +edge_index = jax.numpy.array( + [ + [0, 1, 1, 2], + [1, 0, 2, 1], + ], + dtype=jax.numpy.int32, # NOTE: This int32, not int64 +) + +X = jax.random.uniform(key, shape=(node_ct, X_ir.dim), minval=0.0, maxval=1.0, dtype=jax.numpy.float32) +Y = jax.random.uniform(key, shape=(nonzero_ct, Y_ir.dim), + minval=0.0, maxval=1.0, dtype=jax.numpy.float32) +W = jax.random.uniform(key, shape=(nonzero_ct, problem.weight_numel), + minval=0.0, maxval=1.0, dtype=jax.numpy.float32) + +# Reuse problem from earlier +# ... +tp_conv = oeq.jax.TensorProductConv(problem, deterministic=False) +Z = tp_conv.forward( + X, Y, W, edge_index[0], edge_index[1] +) +print(jax.numpy.linalg.norm(Z)) +``` + ## Citation and Acknowledgements If you find this code useful, please cite our paper: From 7acad2e55e937fdadcdf3deae061208479f9c033 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Wed, 24 Dec 2025 21:24:57 -0800 Subject: [PATCH 074/116] Updated release file. --- .github/workflows/release.yaml | 54 ++++++++++++++++++++++++++++++---- 1 file changed, 49 insertions(+), 5 deletions(-) diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 588eebce..d1407fc0 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -6,7 +6,7 @@ on: # ref: https://packaging.python.org/en/latest/guides/publishing-package-distribution-releases-using-github-actions-ci-cd-workflows/ jobs: - build: + build-oeq: name: Build distribution runs-on: ubuntu-latest steps: @@ -16,20 +16,21 @@ jobs: with: python-version: '3.10' - name: install dependencies, then build source tarball - run: | + run: | + cd openequivariance python3 -m pip install build --user python3 -m build --sdist - name: store the distribution packages uses: actions/upload-artifact@v4 with: name: python-package-distributions - path: dist/ + path: openequivariance/dist/ pypi-publish: name: Upload release to PyPI runs-on: ubuntu-latest # build task to be completed first - needs: build + needs: build-oeq # Specifying a GitHub environment is optional, but strongly encouraged environment: name: pypi @@ -42,6 +43,49 @@ jobs: uses: actions/download-artifact@v4 with: name: python-package-distributions - path: dist/ + path: openequivariance/dist/ + - name: publish package distributions to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + + # ------------------------------------ + + build-oeq-extjax: + name: Build distribution + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + - name: install dependencies, then build source tarball + run: | + cd openequivariance_extjax + python3 -m pip install build --user + python3 -m build --sdist + - name: store the distribution packages + uses: actions/upload-artifact@v4 + with: + name: python-package-distributions + path: openequivariance_extjax/dist/ + + pypi-publish: + name: Upload release to PyPI + runs-on: ubuntu-latest + # build task to be completed first + needs: build-oeq-extjax + # Specifying a GitHub environment is optional, but strongly encouraged + environment: + name: pypi + url: https://pypi.org/p/openequivariance_extjax + permissions: + # IMPORTANT: this permission is mandatory for Trusted Publishing + id-token: write + steps: + - name: download the distributions + uses: actions/download-artifact@v4 + with: + name: python-package-distributions + path: openequivariance_extjax/dist/ - name: publish package distributions to PyPI uses: pypa/gh-action-pypi-publish@release/v1 \ No newline at end of file From f65911543c4bf687a1094c4707934669619299aa Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Wed, 24 Dec 2025 21:26:03 -0800 Subject: [PATCH 075/116] Linted. --- openequivariance/openequivariance/__init__.py | 12 ++- .../openequivariance/_torch/E3NNConv.py | 1 + .../_torch/NPDoubleBackwardMixin.py | 84 ++++++++++++------- .../openequivariance/_torch/TensorProduct.py | 10 ++- .../_torch/TensorProductConv.py | 9 +- .../openequivariance/_torch/utils.py | 79 ++++++++--------- .../benchmark/correctness_utils.py | 39 +++++++-- .../benchmark/random_buffer_utils.py | 11 ++- .../core/ComputationSchedule.py | 34 ++++---- .../openequivariance/core/ConvolutionBase.py | 35 +++++--- .../openequivariance/jax/TensorProduct.py | 41 ++++++--- .../openequivariance/jax/TensorProductConv.py | 43 +++++++--- .../openequivariance/jax/__init__.py | 6 +- .../openequivariance/jax/utils.py | 21 +++-- tests/batch_test.py | 5 +- tests/conftest.py | 10 ++- tests/conv_test.py | 10 +-- tests/examples_test.py | 62 ++++++++++---- 18 files changed, 339 insertions(+), 173 deletions(-) diff --git a/openequivariance/openequivariance/__init__.py b/openequivariance/openequivariance/__init__.py index de5a8af3..84cf9f92 100644 --- a/openequivariance/openequivariance/__init__.py +++ b/openequivariance/openequivariance/__init__.py @@ -31,19 +31,23 @@ def _check_package_editable(): _editable_install_output_path = Path(__file__).parent.parent.parent / "outputs" + def extension_source_path(): """ :returns: Path to the source code of the C++ extension. """ return str(Path(__file__).parent / "extension") + if "OEQ_NOTORCH" not in os.environ or os.environ["OEQ_NOTORCH"] != "1": import torch - from openequivariance._torch.TensorProduct import TensorProduct + from openequivariance._torch.TensorProduct import TensorProduct from openequivariance._torch.TensorProductConv import TensorProductConv - from openequivariance._torch.extlib import torch_ext_so_path as torch_ext_so_path_internal + from openequivariance._torch.extlib import ( + torch_ext_so_path as torch_ext_so_path_internal, + ) from openequivariance.core.utils import torch_to_oeq_dtype torch.serialization.add_safe_globals( @@ -60,6 +64,7 @@ def extension_source_path(): ] ) + def torch_ext_so_path(): """ :returns: Path to a ``.so`` file that must be linked to use OpenEquivariance @@ -70,6 +75,7 @@ def torch_ext_so_path(): except NameError: return None + jax = None try: import openequivariance_extjax @@ -85,5 +91,5 @@ def torch_ext_so_path(): "torch_to_oeq_dtype", "_check_package_editable", "torch_ext_so_path", - "jax" + "jax", ] diff --git a/openequivariance/openequivariance/_torch/E3NNConv.py b/openequivariance/openequivariance/_torch/E3NNConv.py index 811509dc..4cc20662 100644 --- a/openequivariance/openequivariance/_torch/E3NNConv.py +++ b/openequivariance/openequivariance/_torch/E3NNConv.py @@ -7,6 +7,7 @@ from openequivariance._torch.E3NNTensorProduct import E3NNTensorProduct from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixinConv + class E3NNConv(ConvolutionBase, NumpyDoubleBackwardMixinConv): def __init__(self, config, *, idx_dtype=np.int64, torch_op=True): assert torch_op diff --git a/openequivariance/openequivariance/_torch/NPDoubleBackwardMixin.py b/openequivariance/openequivariance/_torch/NPDoubleBackwardMixin.py index 7e623429..caf94268 100644 --- a/openequivariance/openequivariance/_torch/NPDoubleBackwardMixin.py +++ b/openequivariance/openequivariance/_torch/NPDoubleBackwardMixin.py @@ -1,21 +1,25 @@ import torch + class NumpyDoubleBackwardMixin: - ''' - Adds a Numpy double backward method to any TensorProduct + """ + Adds a Numpy double backward method to any TensorProduct with the forward pass defined in PyTorch and the relevant - derivatives registered. - ''' - def double_backward_cpu(self, in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad): + derivatives registered. + """ + + def double_backward_cpu( + self, in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad + ): assert self.torch_op - in1_torch = torch.tensor(in1).to('cuda').requires_grad_(True) - in2_torch = torch.tensor(in2).to('cuda').requires_grad_(True) - weights_torch = torch.tensor(weights).to('cuda').requires_grad_(True) - out_grad_torch = torch.tensor(out_grad).to('cuda').requires_grad_(True) - in1_dgrad_torch = torch.tensor(in1_dgrad).to('cuda') - in2_dgrad_torch = torch.tensor(in2_dgrad).to('cuda') - weights_dgrad_torch = torch.tensor(weights_dgrad).to('cuda') + in1_torch = torch.tensor(in1).to("cuda").requires_grad_(True) + in2_torch = torch.tensor(in2).to("cuda").requires_grad_(True) + weights_torch = torch.tensor(weights).to("cuda").requires_grad_(True) + out_grad_torch = torch.tensor(out_grad).to("cuda").requires_grad_(True) + in1_dgrad_torch = torch.tensor(in1_dgrad).to("cuda") + in2_dgrad_torch = torch.tensor(in2_dgrad).to("cuda") + weights_dgrad_torch = torch.tensor(weights_dgrad).to("cuda") out_torch = self.forward(in1_torch, in2_torch, weights_torch) in1_grad, in2_grad, weights_grad = torch.autograd.grad( @@ -23,53 +27,71 @@ def double_backward_cpu(self, in1, in2, out_grad, weights, weights_dgrad, in1_dg inputs=[in1_torch, in2_torch, weights_torch], grad_outputs=out_grad_torch, create_graph=True, - retain_graph=True + retain_graph=True, ) a, b, c, d = torch.autograd.grad( outputs=[in1_grad, in2_grad, weights_grad], inputs=[in1_torch, in2_torch, weights_torch, out_grad_torch], - grad_outputs=[in1_dgrad_torch, in2_dgrad_torch, weights_dgrad_torch] + grad_outputs=[in1_dgrad_torch, in2_dgrad_torch, weights_dgrad_torch], ) - return a.detach().cpu().numpy(), b.detach().cpu().numpy(), c.detach().cpu().numpy(), d.detach().cpu().numpy() + return ( + a.detach().cpu().numpy(), + b.detach().cpu().numpy(), + c.detach().cpu().numpy(), + d.detach().cpu().numpy(), + ) class NumpyDoubleBackwardMixinConv: - ''' + """ Similar, but for fused graph convolution. - ''' - def double_backward_cpu(self, in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad, graph): + """ + + def double_backward_cpu( + self, in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad, graph + ): assert self.torch_op - in1_torch = torch.tensor(in1).to('cuda').requires_grad_(True) - in2_torch = torch.tensor(in2).to('cuda').requires_grad_(True) - weights_torch = torch.tensor(weights).to('cuda').requires_grad_(True) - out_grad_torch = torch.tensor(out_grad).to('cuda').requires_grad_(True) - in1_dgrad_torch = torch.tensor(in1_dgrad).to('cuda') - in2_dgrad_torch = torch.tensor(in2_dgrad).to('cuda') - weights_dgrad_torch = torch.tensor(weights_dgrad).to('cuda') + in1_torch = torch.tensor(in1).to("cuda").requires_grad_(True) + in2_torch = torch.tensor(in2).to("cuda").requires_grad_(True) + weights_torch = torch.tensor(weights).to("cuda").requires_grad_(True) + out_grad_torch = torch.tensor(out_grad).to("cuda").requires_grad_(True) + in1_dgrad_torch = torch.tensor(in1_dgrad).to("cuda") + in2_dgrad_torch = torch.tensor(in2_dgrad).to("cuda") + weights_dgrad_torch = torch.tensor(weights_dgrad).to("cuda") torch_rows = torch.tensor(graph.rows, device="cuda") torch_cols = torch.tensor(graph.cols, device="cuda") torch_transpose_perm = torch.tensor(graph.transpose_perm, device="cuda") - out_torch = self.forward(in1_torch, in2_torch, weights_torch, torch_rows, torch_cols, torch_transpose_perm) + out_torch = self.forward( + in1_torch, + in2_torch, + weights_torch, + torch_rows, + torch_cols, + torch_transpose_perm, + ) in1_grad, in2_grad, weights_grad = torch.autograd.grad( outputs=out_torch, inputs=[in1_torch, in2_torch, weights_torch], grad_outputs=out_grad_torch, create_graph=True, - retain_graph=True + retain_graph=True, ) a, b, c, d = torch.autograd.grad( outputs=[in1_grad, in2_grad, weights_grad], inputs=[in1_torch, in2_torch, weights_torch, out_grad_torch], - grad_outputs=[in1_dgrad_torch, in2_dgrad_torch, weights_dgrad_torch] + grad_outputs=[in1_dgrad_torch, in2_dgrad_torch, weights_dgrad_torch], ) - return a.detach().cpu().numpy(), b.detach().cpu().numpy(), c.detach().cpu().numpy(), d.detach().cpu().numpy() - - + return ( + a.detach().cpu().numpy(), + b.detach().cpu().numpy(), + c.detach().cpu().numpy(), + d.detach().cpu().numpy(), + ) diff --git a/openequivariance/openequivariance/_torch/TensorProduct.py b/openequivariance/openequivariance/_torch/TensorProduct.py index e4197583..2590207a 100644 --- a/openequivariance/openequivariance/_torch/TensorProduct.py +++ b/openequivariance/openequivariance/_torch/TensorProduct.py @@ -92,10 +92,14 @@ def __setstate__(self, state): self._init_class() def reorder_weights_from_e3nn(self, weights, has_batch_dim=True): - return reorder_torch(self.forward_schedule, weights, "forward", not self.config.shared_weights) + return reorder_torch( + self.forward_schedule, weights, "forward", not self.config.shared_weights + ) def reorder_weights_to_e3nn(self, weights, has_batch_dim=True): - return reorder_torch(self.forward_schedule, weights, "backward", not self.config.shared_weights) + return reorder_torch( + self.forward_schedule, weights, "backward", not self.config.shared_weights + ) def forward( self, x: torch.Tensor, y: torch.Tensor, W: torch.Tensor @@ -347,7 +351,7 @@ def name(): return "LoopUnrollTP" -if extlib.TORCH_COMPILE: +if extlib.TORCH_COMPILE: TensorProduct.register_torch_fakes() TensorProduct.register_autograd() TensorProduct.register_autocast() diff --git a/openequivariance/openequivariance/_torch/TensorProductConv.py b/openequivariance/openequivariance/_torch/TensorProductConv.py index 13b2c757..d7880ec9 100644 --- a/openequivariance/openequivariance/_torch/TensorProductConv.py +++ b/openequivariance/openequivariance/_torch/TensorProductConv.py @@ -27,6 +27,7 @@ logger = getLogger() + class TensorProductConv(torch.nn.Module, LoopUnrollConv, NumpyDoubleBackwardMixinConv): r""" Given a **symmetric, directed** graph :math:`G = (V, E)`, inputs :math:`x_1...x_{|V|}`, @@ -420,10 +421,14 @@ def double_backward(ctx, grad_output): ) def reorder_weights_from_e3nn(self, weights, has_batch_dim=True): - return reorder_torch(self.forward_schedule, weights, "forward", not self.config.shared_weights) + return reorder_torch( + self.forward_schedule, weights, "forward", not self.config.shared_weights + ) def reorder_weights_to_e3nn(self, weights, has_batch_dim=True): - return reorder_torch(self.forward_schedule, weights, "backward", not self.config.shared_weights) + return reorder_torch( + self.forward_schedule, weights, "backward", not self.config.shared_weights + ) @staticmethod def name(): diff --git a/openequivariance/openequivariance/_torch/utils.py b/openequivariance/openequivariance/_torch/utils.py index 8911b27c..e03bb3ea 100644 --- a/openequivariance/openequivariance/_torch/utils.py +++ b/openequivariance/openequivariance/_torch/utils.py @@ -1,50 +1,53 @@ import torch + def reorder_helper(schedule, weights_in, direction, has_batch_dim): - assert direction in ["forward", "backward"] - - specs = schedule.weight_reordering_info(weights_in, has_batch_dim) - weights_out = torch.zeros_like(weights_in) - - for spec in specs: - parent_range = spec["parent_range"] - parent_shape = spec["parent_shape"] - weights_subrange = spec["weights_subrange"] - child_range = spec["child_range"] - transpose_perm = spec["transpose_perm"] - - if direction == "forward": - reshape_size = spec["reshape_size"] - - sliced_weights = weights_in[parent_range].reshape(parent_shape)[ - weights_subrange - ] - - weights_out[child_range] = sliced_weights.permute( - transpose_perm - ).reshape(reshape_size) - - elif direction == "backward": - transpose_child_shape = spec["transpose_child_shape"] - child_shape = spec["child_shape"] - - sliced_weights = ( - weights_in[child_range] - .reshape(transpose_child_shape) - .permute(transpose_perm) - ) - - weights_out[parent_range].reshape(parent_shape)[ - weights_subrange - ] = sliced_weights.flatten().reshape(child_shape) - - return weights_out + assert direction in ["forward", "backward"] + + specs = schedule.weight_reordering_info(weights_in, has_batch_dim) + weights_out = torch.zeros_like(weights_in) + + for spec in specs: + parent_range = spec["parent_range"] + parent_shape = spec["parent_shape"] + weights_subrange = spec["weights_subrange"] + child_range = spec["child_range"] + transpose_perm = spec["transpose_perm"] + + if direction == "forward": + reshape_size = spec["reshape_size"] + + sliced_weights = weights_in[parent_range].reshape(parent_shape)[ + weights_subrange + ] + + weights_out[child_range] = sliced_weights.permute(transpose_perm).reshape( + reshape_size + ) + + elif direction == "backward": + transpose_child_shape = spec["transpose_child_shape"] + child_shape = spec["child_shape"] + + sliced_weights = ( + weights_in[child_range] + .reshape(transpose_child_shape) + .permute(transpose_perm) + ) + + weights_out[parent_range].reshape(parent_shape)[weights_subrange] = ( + sliced_weights.flatten().reshape(child_shape) + ) + + return weights_out + def reorder_numpy_helper(schedule, weights_in, direction, has_batch_dim): weights_in = torch.from_numpy(weights_in.copy()) result = reorder_helper(schedule, weights_in, direction, has_batch_dim) return result.detach().cpu().numpy().copy() + def reorder_torch(schedule, weights_in, direction, has_batch_dim): if isinstance(weights_in, torch.Tensor): return reorder_helper(schedule, weights_in, direction, has_batch_dim) diff --git a/openequivariance/openequivariance/benchmark/correctness_utils.py b/openequivariance/openequivariance/benchmark/correctness_utils.py index 3f743332..788d209e 100644 --- a/openequivariance/openequivariance/benchmark/correctness_utils.py +++ b/openequivariance/openequivariance/benchmark/correctness_utils.py @@ -6,7 +6,8 @@ from openequivariance.benchmark.random_buffer_utils import ( get_random_buffers_forward, get_random_buffers_backward, - get_random_buffers_double_backward) + get_random_buffers_double_backward, +) from openequivariance.benchmark.logging_utils import getLogger, bcolors import numpy as np @@ -195,11 +196,15 @@ def correctness_double_backward( global torch import torch - in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad, _ = \ - get_random_buffers_double_backward(problem, batch_size=batch_size, prng_seed=prng_seed) + in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad, _ = ( + get_random_buffers_double_backward( + problem, batch_size=batch_size, prng_seed=prng_seed + ) + ) if reference_implementation is None: from openequivariance._torch.E3NNTensorProduct import E3NNTensorProduct + reference_implementation = E3NNTensorProduct result = {"thresh": correctness_threshold, "batch_size": batch_size} @@ -207,19 +212,35 @@ def correctness_double_backward( tensors = [] for _, impl in enumerate([test_implementation, reference_implementation]): tp = instantiate_implementation(impl, problem) - weights_reordered = tp.reorder_weights_from_e3nn(weights, has_batch_dim=not problem.shared_weights) - weights_dgrad_reordered = tp.reorder_weights_from_e3nn(weights_dgrad, has_batch_dim=not problem.shared_weights) + weights_reordered = tp.reorder_weights_from_e3nn( + weights, has_batch_dim=not problem.shared_weights + ) + weights_dgrad_reordered = tp.reorder_weights_from_e3nn( + weights_dgrad, has_batch_dim=not problem.shared_weights + ) if impl == CUETensorProduct and problem.shared_weights: weights_reordered = weights_reordered[np.newaxis, :] - in1_grad, in2_grad, weights_grad, out_dgrad = tp.double_backward_cpu(in1, in2, out_grad, weights_reordered, weights_dgrad_reordered, in1_dgrad, in2_dgrad) + in1_grad, in2_grad, weights_grad, out_dgrad = tp.double_backward_cpu( + in1, + in2, + out_grad, + weights_reordered, + weights_dgrad_reordered, + in1_dgrad, + in2_dgrad, + ) tensors.append( - ( out_dgrad, + ( + out_dgrad, in1_grad, in2_grad, - tp.reorder_weights_to_e3nn(weights_grad, has_batch_dim=not problem.shared_weights) - )) + tp.reorder_weights_to_e3nn( + weights_grad, has_batch_dim=not problem.shared_weights + ), + ) + ) for name, to_check, ground_truth in [ ("output_double_grad", tensors[0][0], tensors[1][0]), diff --git a/openequivariance/openequivariance/benchmark/random_buffer_utils.py b/openequivariance/openequivariance/benchmark/random_buffer_utils.py index a403962d..c657d5bc 100644 --- a/openequivariance/openequivariance/benchmark/random_buffer_utils.py +++ b/openequivariance/openequivariance/benchmark/random_buffer_utils.py @@ -106,11 +106,14 @@ def get_random_buffers_double_backward( weights_grad = np.array(rng.uniform(size=weights_size), dtype=tpp.irrep_dtype) in1_grad = np.array( - rng.uniform(size=(batch_size, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype) + rng.uniform(size=(batch_size, tpp.irreps_in1.dim)), dtype=tpp.irrep_dtype + ) in2_grad = np.array( - rng.uniform(size=(batch_size, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype) + rng.uniform(size=(batch_size, tpp.irreps_in2.dim)), dtype=tpp.irrep_dtype + ) out_double_grad = np.array( - rng.uniform(size=(batch_size, tpp.irreps_out.dim)), dtype=tpp.irrep_dtype) + rng.uniform(size=(batch_size, tpp.irreps_out.dim)), dtype=tpp.irrep_dtype + ) return ( in1, @@ -221,4 +224,4 @@ def get_random_buffers_double_backward_conv( in1_grad, in2_grad, out_double_grad, - ) \ No newline at end of file + ) diff --git a/openequivariance/openequivariance/core/ComputationSchedule.py b/openequivariance/openequivariance/core/ComputationSchedule.py index 135a0f25..9c3884c9 100644 --- a/openequivariance/openequivariance/core/ComputationSchedule.py +++ b/openequivariance/openequivariance/core/ComputationSchedule.py @@ -622,7 +622,7 @@ def calculate_backward_smem( def weight_reordering_info(self, weights_in, has_batch_dim): """ Calculates all shapes, slices, and permutation info to reorder - weights. + weights. """ batch_dim = weights_in.shape[0] reorder_specs = [] @@ -639,13 +639,13 @@ def weight_reordering_info(self, weights_in, has_batch_dim): self.updated_config.weight_range_and_shape_for_instruction(i) ) child_range = [slice(child_start, child_end)] - + weights_subrange = child_inst.weights_subrange - + reshape_size = [-1] transpose_perm = None connection_mode = self.updated_config.instructions[i].connection_mode - + if connection_mode == "uvu": transpose_perm = [1, 0] elif connection_mode == "uvw": @@ -655,11 +655,11 @@ def weight_reordering_info(self, weights_in, has_batch_dim): child_range = [slice(0, batch_dim)] + child_range parent_range = [slice(0, batch_dim)] + parent_range parent_shape = [batch_dim] + parent_shape - + child_shape = [batch_dim] + list(child_shape) weights_subrange = [slice(0, batch_dim)] + child_inst.weights_subrange reshape_size = [batch_dim] + reshape_size - + if transpose_perm is not None: transpose_perm = [0] + [k + 1 for k in transpose_perm] @@ -667,15 +667,17 @@ def weight_reordering_info(self, weights_in, has_batch_dim): if transpose_perm is not None: transpose_child_shape = [child_shape[k] for k in transpose_perm] - reorder_specs.append({ - "parent_range": tuple(parent_range), - "parent_shape": parent_shape, - "weights_subrange": tuple(weights_subrange), - "child_range": tuple(child_range), - "child_shape": child_shape, - "transpose_perm": transpose_perm, - "reshape_size": reshape_size, - "transpose_child_shape": transpose_child_shape, - }) + reorder_specs.append( + { + "parent_range": tuple(parent_range), + "parent_shape": parent_shape, + "weights_subrange": tuple(weights_subrange), + "child_range": tuple(child_range), + "child_shape": child_shape, + "transpose_perm": transpose_perm, + "reshape_size": reshape_size, + "transpose_child_shape": transpose_child_shape, + } + ) return reorder_specs diff --git a/openequivariance/openequivariance/core/ConvolutionBase.py b/openequivariance/openequivariance/core/ConvolutionBase.py index a450d03d..dbfeb5ff 100644 --- a/openequivariance/openequivariance/core/ConvolutionBase.py +++ b/openequivariance/openequivariance/core/ConvolutionBase.py @@ -14,6 +14,7 @@ logger = getLogger() + def flops_data_per_tp(config, direction): """ Assumes all interactions are "uvu" for now @@ -550,11 +551,12 @@ def test_correctness_double_backward( high_precision_ref=False, ): buffers = get_random_buffers_double_backward_conv( - self.config, graph.node_count, graph.nnz, prng_seed - ) + self.config, graph.node_count, graph.nnz, prng_seed + ) if reference_implementation is None: from openequivariance._torch.E3NNConv import E3NNConv + reference_implementation = E3NNConv reference_problem = self.config @@ -571,11 +573,11 @@ def test_correctness_double_backward( buffers_copy = [buf.copy() for buf in buffers] if i == 1 and high_precision_ref: - buffers_copy = [ - np.array(el, dtype=np.float64) for el in buffers - ] + buffers_copy = [np.array(el, dtype=np.float64) for el in buffers] - in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad, _ = buffers_copy + in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad, _ = ( + buffers_copy + ) weights_reordered = tp.reorder_weights_from_e3nn( weights, not tp.config.shared_weights @@ -584,14 +586,27 @@ def test_correctness_double_backward( weights_dgrad, not tp.config.shared_weights ) - in1_grad, in2_grad, weights_grad, out_dgrad = tp.double_backward_cpu(in1, in2, out_grad, weights_reordered, weights_dgrad_reordered, in1_dgrad, in2_dgrad, graph) + in1_grad, in2_grad, weights_grad, out_dgrad = tp.double_backward_cpu( + in1, + in2, + out_grad, + weights_reordered, + weights_dgrad_reordered, + in1_dgrad, + in2_dgrad, + graph, + ) tensors.append( - ( out_dgrad, + ( + out_dgrad, in1_grad, in2_grad, - tp.reorder_weights_to_e3nn(weights_grad, has_batch_dim=not self.config.shared_weights) - )) + tp.reorder_weights_to_e3nn( + weights_grad, has_batch_dim=not self.config.shared_weights + ), + ) + ) for name, to_check, ground_truth in [ ("output_grad", tensors[0][0], tensors[1][0]), diff --git a/openequivariance/openequivariance/jax/TensorProduct.py b/openequivariance/openequivariance/jax/TensorProduct.py index 15275e39..452e7bb7 100644 --- a/openequivariance/openequivariance/jax/TensorProduct.py +++ b/openequivariance/openequivariance/jax/TensorProduct.py @@ -7,6 +7,7 @@ from openequivariance.core.utils import hash_attributes from openequivariance.jax.utils import reorder_jax + @partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5)) def forward(X, Y, W, L3_dim, irrep_dtype, attrs): forward_call = jax.ffi.ffi_call( @@ -81,7 +82,9 @@ def __init__(self, problem: TPProblem): self.weight_numel = problem.weight_numel self.L3_dim = self.config.irreps_out.dim - def forward(self, X: jax.numpy.ndarray, Y: jax.numpy.ndarray, W: jax.numpy.ndarray) -> jax.numpy.ndarray: + def forward( + self, X: jax.numpy.ndarray, Y: jax.numpy.ndarray, W: jax.numpy.ndarray + ) -> jax.numpy.ndarray: return forward(X, Y, W, self.L3_dim, self.config.irrep_dtype, self.attrs) def __call__( @@ -90,13 +93,19 @@ def __call__( return self.forward(X, Y, W) def reorder_weights_from_e3nn(self, weights, has_batch_dim=True): - return reorder_jax(self.forward_schedule, weights, "forward", not self.config.shared_weights) + return reorder_jax( + self.forward_schedule, weights, "forward", not self.config.shared_weights + ) def reorder_weights_to_e3nn(self, weights, has_batch_dim=True): - return reorder_jax(self.forward_schedule, weights, "backward", not self.config.shared_weights) + return reorder_jax( + self.forward_schedule, weights, "backward", not self.config.shared_weights + ) def forward_cpu(self, L1_in, L2_in, L3_out, weights) -> None: - weights = self.reorder_weights_from_e3nn(weights, has_batch_dim=not self.config.shared_weights) + weights = self.reorder_weights_from_e3nn( + weights, has_batch_dim=not self.config.shared_weights + ) result = self.forward( jax.numpy.asarray(L1_in), jax.numpy.asarray(L2_in), @@ -107,7 +116,9 @@ def forward_cpu(self, L1_in, L2_in, L3_out, weights) -> None: def backward_cpu( self, L1_in, L1_grad, L2_in, L2_grad, L3_grad, weights, weights_grad ) -> None: - weights = self.reorder_weights_from_e3nn(weights, has_batch_dim=not self.config.shared_weights) + weights = self.reorder_weights_from_e3nn( + weights, has_batch_dim=not self.config.shared_weights + ) backward_fn = jax.vjp( lambda X, Y, W: self.forward(X, Y, W), jax.numpy.asarray(L1_in), @@ -120,10 +131,13 @@ def backward_cpu( L1_grad[:] = np.asarray(L1_grad_jax) L2_grad[:] = np.asarray(L2_grad_jax) weights_grad[:] = np.asarray(weights_grad_jax) - weights_grad[:] = self.reorder_weights_to_e3nn(weights_grad, has_batch_dim=not self.config.shared_weights) - + weights_grad[:] = self.reorder_weights_to_e3nn( + weights_grad, has_batch_dim=not self.config.shared_weights + ) - def double_backward_cpu(self, in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad): + def double_backward_cpu( + self, in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad + ): in1_jax = jax.numpy.asarray(in1) in2_jax = jax.numpy.asarray(in2) weights_jax = jax.numpy.asarray(weights) @@ -133,8 +147,13 @@ def double_backward_cpu(self, in1, in2, out_grad, weights, weights_dgrad, in1_dg weights_dgrad_jax = jax.numpy.asarray(weights_dgrad) in1_grad, in2_grad, weights_grad, out_dgrad = jax.vjp( - lambda x, y, w, o: jax.vjp(lambda a, b, c: self.forward(a, b, c), x, y, w)[1](o), - in1_jax, in2_jax, weights_jax, out_grad_jax + lambda x, y, w, o: jax.vjp(lambda a, b, c: self.forward(a, b, c), x, y, w)[ + 1 + ](o), + in1_jax, + in2_jax, + weights_jax, + out_grad_jax, )[1]((in1_dgrad_jax, in2_dgrad_jax, weights_dgrad_jax)) - return in1_grad, in2_grad, weights_grad, out_dgrad \ No newline at end of file + return in1_grad, in2_grad, weights_grad, out_dgrad diff --git a/openequivariance/openequivariance/jax/TensorProductConv.py b/openequivariance/openequivariance/jax/TensorProductConv.py index c24a1add..3aaee28a 100644 --- a/openequivariance/openequivariance/jax/TensorProductConv.py +++ b/openequivariance/openequivariance/jax/TensorProductConv.py @@ -95,7 +95,7 @@ class TensorProductConv(LoopUnrollConv): r""" Identical to ``oeq.torch.TensorProductConv`` with functionality in JAX, with one key difference: integer arrays passed to this function must have dtype - ``np.int32`` (as opposed to ``np.int64`` in the PyTorch version). + ``np.int32`` (as opposed to ``np.int64`` in the PyTorch version). :param problem: Specification of the tensor product. :param deterministic: if ``False``, uses atomics for the convolution. If ``True``, uses a deterministic @@ -144,7 +144,7 @@ def forward( rows: jax.numpy.ndarray, cols: jax.numpy.ndarray, sender_perm: Optional[jax.numpy.ndarray] = None, - ) -> jax.numpy.ndarray: + ) -> jax.numpy.ndarray: r""" Computes the fused CG tensor product + convolution. @@ -197,7 +197,7 @@ def __call__( return self.forward(X, Y, W, rows, cols, sender_perm) def reorder_weights_from_e3nn(self, weights, has_batch_dim=True): - return reorder_jax(self.forward_schedule, weights, "forward", has_batch_dim) + return reorder_jax(self.forward_schedule, weights, "forward", has_batch_dim) def reorder_weights_to_e3nn(self, weights, has_batch_dim=True): return reorder_jax(self.forward_schedule, weights, "backward", has_batch_dim) @@ -206,7 +206,9 @@ def forward_cpu(self, L1_in, L2_in, weights, L3_out, graph): rows = graph.rows.astype(np.int32) cols = graph.cols.astype(np.int32) sender_perm = graph.transpose_perm.astype(np.int32) - weights = self.reorder_weights_from_e3nn(weights, has_batch_dim=not self.config.shared_weights) + weights = self.reorder_weights_from_e3nn( + weights, has_batch_dim=not self.config.shared_weights + ) result = self.forward( jax.numpy.asarray(L1_in), jax.numpy.asarray(L2_in), @@ -231,7 +233,9 @@ def backward_cpu( rows = graph.rows.astype(np.int32) cols = graph.cols.astype(np.int32) sender_perm = graph.transpose_perm.astype(np.int32) - weights = self.reorder_weights_from_e3nn(weights, has_batch_dim=not self.config.shared_weights) + weights = self.reorder_weights_from_e3nn( + weights, has_batch_dim=not self.config.shared_weights + ) backward_fn = jax.vjp( lambda X, Y, W: self.forward( @@ -252,9 +256,13 @@ def backward_cpu( L1_grad[:] = np.asarray(L1_grad_jax) L2_grad[:] = np.asarray(L2_grad_jax) weights_grad[:] = np.asarray(weights_grad_jax) - weights_grad[:] = self.reorder_weights_to_e3nn(weights_grad, has_batch_dim=not self.config.shared_weights) + weights_grad[:] = self.reorder_weights_to_e3nn( + weights_grad, has_batch_dim=not self.config.shared_weights + ) - def double_backward_cpu(self, in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad, graph): + def double_backward_cpu( + self, in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad, graph + ): in1_jax = jax.numpy.asarray(in1) in2_jax = jax.numpy.asarray(in2) weights_jax = jax.numpy.asarray(weights) @@ -268,8 +276,23 @@ def double_backward_cpu(self, in1, in2, out_grad, weights, weights_dgrad, in1_dg sender_perm_jax = jax.numpy.asarray(graph.transpose_perm.astype(self.idx_dtype)) in1_grad, in2_grad, weights_grad, out_dgrad = jax.vjp( - lambda x, y, w, o: jax.vjp(lambda a, b, c: self.forward(a, b, c, rows_jax, cols_jax, sender_perm_jax), x, y, w)[1](o), - in1_jax, in2_jax, weights_jax, out_grad_jax + lambda x, y, w, o: jax.vjp( + lambda a, b, c: self.forward( + a, b, c, rows_jax, cols_jax, sender_perm_jax + ), + x, + y, + w, + )[1](o), + in1_jax, + in2_jax, + weights_jax, + out_grad_jax, )[1]((in1_dgrad_jax, in2_dgrad_jax, weights_dgrad_jax)) - return np.asarray(in1_grad), np.asarray(in2_grad), np.asarray(weights_grad), np.asarray(out_dgrad) + return ( + np.asarray(in1_grad), + np.asarray(in2_grad), + np.asarray(weights_grad), + np.asarray(out_dgrad), + ) diff --git a/openequivariance/openequivariance/jax/__init__.py b/openequivariance/openequivariance/jax/__init__.py index 5313c325..410e5dbf 100644 --- a/openequivariance/openequivariance/jax/__init__.py +++ b/openequivariance/openequivariance/jax/__init__.py @@ -1,4 +1,6 @@ from openequivariance.jax.TensorProduct import TensorProduct as TensorProduct -from openequivariance.jax.TensorProductConv import TensorProductConv as TensorProductConv +from openequivariance.jax.TensorProductConv import ( + TensorProductConv as TensorProductConv, +) -__all__ = ["TensorProduct", "TensorProductConv"] \ No newline at end of file +__all__ = ["TensorProduct", "TensorProductConv"] diff --git a/openequivariance/openequivariance/jax/utils.py b/openequivariance/openequivariance/jax/utils.py index 14cc8394..ae15d1a6 100644 --- a/openequivariance/openequivariance/jax/utils.py +++ b/openequivariance/openequivariance/jax/utils.py @@ -2,6 +2,7 @@ import jax.numpy as jnp import numpy as np + def reorder_jax_helper(schedule, weights_in, direction, has_batch_dim): assert direction in ["forward", "backward"] @@ -14,15 +15,17 @@ def reorder_jax_helper(schedule, weights_in, direction, has_batch_dim): weights_subrange = spec["weights_subrange"] child_range = spec["child_range"] transpose_perm = spec["transpose_perm"] - + if direction == "forward": reshape_size = spec["reshape_size"] - + sliced_weights = weights_in[parent_range].reshape(parent_shape)[ weights_subrange ] - - value_to_assign = sliced_weights.transpose(transpose_perm).reshape(reshape_size) + + value_to_assign = sliced_weights.transpose(transpose_perm).reshape( + reshape_size + ) weights_out = weights_out.at[child_range].set(value_to_assign) elif direction == "backward": @@ -34,23 +37,27 @@ def reorder_jax_helper(schedule, weights_in, direction, has_batch_dim): .reshape(transpose_child_shape) .transpose(transpose_perm) ) - + value_to_insert = sliced_weights.flatten().reshape(child_shape) slab = weights_out[parent_range] slab_reshaped = slab.reshape(parent_shape) slab_reshaped = slab_reshaped.at[weights_subrange].set(value_to_insert) - weights_out = weights_out.at[parent_range].set(slab_reshaped.reshape(slab.shape)) + weights_out = weights_out.at[parent_range].set( + slab_reshaped.reshape(slab.shape) + ) return weights_out + def reorder_numpy_jax_helper(schedule, weights_in, direction, has_batch_dim): weights_in_jax = jnp.array(weights_in) result = reorder_jax_helper(schedule, weights_in_jax, direction, has_batch_dim) return np.array(result) + def reorder_jax(schedule, weights_in, direction, has_batch_dim): if isinstance(weights_in, (jnp.ndarray, jax.Array)): return reorder_jax_helper(schedule, weights_in, direction, has_batch_dim) else: - return reorder_numpy_jax_helper(schedule, weights_in, direction, has_batch_dim) \ No newline at end of file + return reorder_numpy_jax_helper(schedule, weights_in, direction, has_batch_dim) diff --git a/tests/batch_test.py b/tests/batch_test.py index 5a65611d..f32f7b51 100644 --- a/tests/batch_test.py +++ b/tests/batch_test.py @@ -2,7 +2,6 @@ from pytest_check import check import numpy as np -import openequivariance import openequivariance as oeq from openequivariance.benchmark.correctness_utils import ( correctness_forward, @@ -19,6 +18,7 @@ from itertools import product import torch + class TPCorrectness: def thresh(self, direction): return {"fwd": 1e-5, "bwd": 3e-4, "double_bwd": 3e-4}[direction] @@ -45,9 +45,10 @@ def with_jax(self, request): @pytest.fixture(scope="class") def tp_and_problem(self, problem, extra_tp_constructor_args, with_jax): - cls = oeq.TensorProduct + cls = oeq.TensorProduct if with_jax: import openequivariance.jax.TensorProduct as jax_tp + cls = jax_tp tp = cls(problem, **extra_tp_constructor_args) return tp, problem diff --git a/tests/conftest.py b/tests/conftest.py index d5e9f008..0e7098e0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,13 @@ -import pytest import os os.environ["JAX_ENABLE_X64"] = "True" os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False" + + def pytest_addoption(parser): parser.addoption( - "--jax", action="store_true", default=False, help="Test the JAX frontend instead of PyTorch" - ) \ No newline at end of file + "--jax", + action="store_true", + default=False, + help="Test the JAX frontend instead of PyTorch", + ) diff --git a/tests/conv_test.py b/tests/conv_test.py index e12503e8..9c6bb4c8 100644 --- a/tests/conv_test.py +++ b/tests/conv_test.py @@ -15,6 +15,7 @@ e3tools_problems, ) + class ConvCorrectness: def thresh(self, direction): return {"fwd": 3e-4, "bwd": 3e-4, "double_bwd": 3e-4}[direction] @@ -59,17 +60,14 @@ def conv_object(self, request, problem, extra_conv_constructor_args, with_jax): cls = oeq.TensorProductConv if with_jax: from openequivariance.jax import TensorProductConv as jax_conv + cls = jax_conv if request.param == "atomic": - return cls( - problem, deterministic=False, **extra_conv_constructor_args - ) + return cls(problem, deterministic=False, **extra_conv_constructor_args) elif request.param == "deterministic": if not problem.shared_weights: - return cls( - problem, deterministic=True, **extra_conv_constructor_args - ) + return cls(problem, deterministic=True, **extra_conv_constructor_args) else: pytest.skip("Shared weights not supported with deterministic") elif request.param == "kahan": diff --git a/tests/examples_test.py b/tests/examples_test.py index 61d42416..bad7c220 100644 --- a/tests/examples_test.py +++ b/tests/examples_test.py @@ -1,10 +1,12 @@ import pytest import os + @pytest.fixture def with_jax(request): return request.config.getoption("--jax") + def test_tutorial_torch(with_jax): if with_jax: pytest.skip("Skipping PyTorch tutorial when testing JAX") @@ -36,7 +38,7 @@ def test_tutorial_torch(with_jax): problem = oeq.TPProblem( X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False ) - tp_fast = oeq.TensorProduct(problem) + tp_fast = oeq.TensorProduct(problem) Z = tp_fast(X, Y, W) # Reuse X, Y, W from earlier print(torch.norm(Z)) @@ -103,31 +105,59 @@ def test_tutorial_jax(with_jax): problem = oeq.TPProblem( X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False ) - tp_fast = oeq.jax.TensorProduct(problem) - - X = jax.random.uniform(key, shape=(batch_size, X_ir.dim), minval=0.0, maxval=1.0, dtype=jax.numpy.float32) - Y = jax.random.uniform(key, shape=(batch_size, Y_ir.dim), minval=0.0, maxval=1.0, dtype=jax.numpy.float32) - W = jax.random.uniform(key, shape=(batch_size, tp_fast.weight_numel), minval=0.0, maxval=1.0, dtype=jax.numpy.float32) + tp_fast = oeq.jax.TensorProduct(problem) + + X = jax.random.uniform( + key, + shape=(batch_size, X_ir.dim), + minval=0.0, + maxval=1.0, + dtype=jax.numpy.float32, + ) + Y = jax.random.uniform( + key, + shape=(batch_size, Y_ir.dim), + minval=0.0, + maxval=1.0, + dtype=jax.numpy.float32, + ) + W = jax.random.uniform( + key, + shape=(batch_size, tp_fast.weight_numel), + minval=0.0, + maxval=1.0, + dtype=jax.numpy.float32, + ) Z = tp_fast(X, Y, W) print(jax.numpy.linalg.norm(Z)) - + edge_index = jax.numpy.array( [ [0, 1, 1, 2], [1, 0, 2, 1], ], - dtype=jax.numpy.int32, # NOTE: This int32, not int64 + dtype=jax.numpy.int32, # NOTE: This int32, not int64 ) node_ct, nonzero_ct = 3, 4 - X = jax.random.uniform(key, shape=(node_ct, X_ir.dim), minval=0.0, maxval=1.0, dtype=jax.numpy.float32) - Y = jax.random.uniform(key, shape=(nonzero_ct, Y_ir.dim), - minval=0.0, maxval=1.0, dtype=jax.numpy.float32) - W = jax.random.uniform(key, shape=(nonzero_ct, problem.weight_numel), - minval=0.0, maxval=1.0, dtype=jax.numpy.float32) - tp_conv = oeq.jax.TensorProductConv(problem, deterministic=False) - Z = tp_conv.forward( - X, Y, W, edge_index[0], edge_index[1] + X = jax.random.uniform( + key, shape=(node_ct, X_ir.dim), minval=0.0, maxval=1.0, dtype=jax.numpy.float32 ) + Y = jax.random.uniform( + key, + shape=(nonzero_ct, Y_ir.dim), + minval=0.0, + maxval=1.0, + dtype=jax.numpy.float32, + ) + W = jax.random.uniform( + key, + shape=(nonzero_ct, problem.weight_numel), + minval=0.0, + maxval=1.0, + dtype=jax.numpy.float32, + ) + tp_conv = oeq.jax.TensorProductConv(problem, deterministic=False) + Z = tp_conv.forward(X, Y, W, edge_index[0], edge_index[1]) print(jax.numpy.linalg.norm(Z)) From dc45a8f924fe474cebcb034d7738170ed463f715 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Wed, 24 Dec 2025 21:29:56 -0800 Subject: [PATCH 076/116] Updated the build verification. --- .github/workflows/requirements_cuda_ci.txt | 4 +++- .github/workflows/verify_extension_build.yml | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/requirements_cuda_ci.txt b/.github/workflows/requirements_cuda_ci.txt index da9fde80..0e04348c 100644 --- a/.github/workflows/requirements_cuda_ci.txt +++ b/.github/workflows/requirements_cuda_ci.txt @@ -1,4 +1,6 @@ numpy==2.2.5 torch==2.7.0 --index-url https://download.pytorch.org/whl/cu128 pytest==8.3.5 -ninja==1.11.1.4 \ No newline at end of file +ninja==1.11.1.4 +nanobind==2.10.2 +scikit-build-core==0.11.6 \ No newline at end of file diff --git a/.github/workflows/verify_extension_build.yml b/.github/workflows/verify_extension_build.yml index 296db3f1..25cf966f 100644 --- a/.github/workflows/verify_extension_build.yml +++ b/.github/workflows/verify_extension_build.yml @@ -35,7 +35,7 @@ jobs: run: | pytest tests/import_test.py -k test_import - - name: Install dependencies to test JAX extension build + - name: Test JAX extension build run: | pip install "jax[cuda12]" pip install -e ./openequivariance[jax] From 1817a2a0dbfb95331a2220c5da2305eb7d452283 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Wed, 24 Dec 2025 21:51:11 -0800 Subject: [PATCH 077/116] Merge complete. --- openequivariance/openequivariance/_torch/extlib/__init__.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/openequivariance/openequivariance/_torch/extlib/__init__.py b/openequivariance/openequivariance/_torch/extlib/__init__.py index 0863f6ee..26b5b52f 100644 --- a/openequivariance/openequivariance/_torch/extlib/__init__.py +++ b/openequivariance/openequivariance/_torch/extlib/__init__.py @@ -23,7 +23,6 @@ torch_module, generic_module = None, None postprocess_kernel = lambda kernel: kernel # noqa : E731 - try: python_lib_dir = sysconfig.get_config_var("LIBDIR") major, minor = sys.version_info.major, sys.version_info.minor @@ -44,6 +43,7 @@ if BUILT_EXTENSION: import openequivariance._torch.extlib.generic_module generic_module = openequivariance._torch.extlib.generic_module + elif torch.version.cuda or torch.version.hip: try: from torch.utils.cpp_extension import library_paths, include_paths @@ -141,12 +141,10 @@ def _raise_import_error_helper(import_target: str): if not BUILT_EXTENSION: raise ImportError(f"Could not import {import_target}: {BUILT_EXTENSION_ERROR}") - def torch_ext_so_path(): return torch_module.__file__ - -if TORCH_VERSION_CUDA_OR_HIP: +if BUILT_EXTENSION: from generic_module import ( JITTPImpl, JITConvImpl, From a030fb588b1b79928afa9110d8c4c64865184aab Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Wed, 24 Dec 2025 22:40:30 -0800 Subject: [PATCH 078/116] Updated README. --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 176ba5cc..e6188c8b 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,8 @@ check out the [documentation](https://passionlab.github.io/OpenEquivariance). support for JAX. To install, execute the following commands in order: -``` +``` bash +pip install jax[cuda12] # Not needed if you already have JAX pip install openequivariance[jax] pip install openequivariance_extjax --no-build-isolation ``` From 6edad897dc9293a61dde46e8bcf9296e6cdbabcd Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Thu, 25 Dec 2025 11:58:24 -0800 Subject: [PATCH 079/116] Fixed some minor issues. --- .github/workflows/release.yaml | 2 +- .github/workflows/verify_extension_build.yml | 10 +++------- openequivariance/MANIFEST.in | 6 ++++-- openequivariance_extjax/MANIFEST.in | 2 ++ 4 files changed, 10 insertions(+), 10 deletions(-) create mode 100644 openequivariance_extjax/MANIFEST.in diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index d1407fc0..2881158a 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -69,7 +69,7 @@ jobs: name: python-package-distributions path: openequivariance_extjax/dist/ - pypi-publish: + pypi-publish-extjax: name: Upload release to PyPI runs-on: ubuntu-latest # build task to be completed first diff --git a/.github/workflows/verify_extension_build.yml b/.github/workflows/verify_extension_build.yml index 2cb032cb..6c08d673 100644 --- a/.github/workflows/verify_extension_build.yml +++ b/.github/workflows/verify_extension_build.yml @@ -33,16 +33,12 @@ jobs: - name: Test extension build via import run: | -<<<<<<< HEAD - pytest tests/import_test.py -k test_import + pytest \ + tests/import_test.py::test_extension_built \ + tests/import_test.py::test_torch_extension_built - name: Test JAX extension build run: | pip install "jax[cuda12]" pip install -e ./openequivariance[jax] pip install -e ./openequivariance_extjax --no-build-isolation -======= - pytest \ - tests/import_test.py::test_extension_built \ - tests/import_test.py::test_torch_extension_built ->>>>>>> main diff --git a/openequivariance/MANIFEST.in b/openequivariance/MANIFEST.in index ab5b72e7..1d4a8cce 100644 --- a/openequivariance/MANIFEST.in +++ b/openequivariance/MANIFEST.in @@ -1,2 +1,4 @@ -include templates/*.cuh -include openequivariance/templates/*.jinja \ No newline at end of file +include openequivariance/templates/*.cuh +include openequivariance/templates/*.jinja + +include openequivariance/extension/* \ No newline at end of file diff --git a/openequivariance_extjax/MANIFEST.in b/openequivariance_extjax/MANIFEST.in new file mode 100644 index 00000000..fcd76e59 --- /dev/null +++ b/openequivariance_extjax/MANIFEST.in @@ -0,0 +1,2 @@ +include CMakeLists.txt +include src/* \ No newline at end of file From 895ad78e322949071045ee81e7dac99365a12d1b Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Thu, 25 Dec 2025 12:12:10 -0800 Subject: [PATCH 080/116] Added symlinks. --- README.md | 8 ++++---- openequivariance/LICENSE | 1 + openequivariance/README.md | 7 +------ openequivariance_extjax/LICENSE | 1 + openequivariance_extjax/pyproject.toml | 4 ++-- 5 files changed, 9 insertions(+), 12 deletions(-) create mode 120000 openequivariance/LICENSE mode change 100644 => 120000 openequivariance/README.md create mode 120000 openequivariance_extjax/LICENSE diff --git a/README.md b/README.md index e6188c8b..a7184cb8 100644 --- a/README.md +++ b/README.md @@ -13,8 +13,8 @@ that [e3nn](https://e3nn.org/) supports commonly found in graph neural networks (e.g. [Nequip](https://github.com/mir-group/nequip) or [MACE](https://github.com/ACEsuit/mace)). To get -started, ensure that you have GCC 9+ on your system -and install our package via +started with PyTorch, ensure that you have PyTorch +and GCC 9+ available before installing our package via ```bash pip install openequivariance @@ -35,11 +35,11 @@ support for JAX. To install, execute the following commands in order: ``` bash -pip install jax[cuda12] # Not needed if you already have JAX +pip install jax[cuda12] # Skip if JAX installed pip install openequivariance[jax] pip install openequivariance_extjax --no-build-isolation ``` -See below for example usage. +See the section below for example usage. ## Show me some examples Here's a CG tensor product implemented by e3nn: diff --git a/openequivariance/LICENSE b/openequivariance/LICENSE new file mode 120000 index 00000000..ea5b6064 --- /dev/null +++ b/openequivariance/LICENSE @@ -0,0 +1 @@ +../LICENSE \ No newline at end of file diff --git a/openequivariance/README.md b/openequivariance/README.md deleted file mode 100644 index 976ab6c6..00000000 --- a/openequivariance/README.md +++ /dev/null @@ -1,6 +0,0 @@ -# OpenEquivariance - -The core implementation of OpenEquivariance with -PyTorch support. For JAX, see instructions -on installing `openequivariance_extjax` along with this package. - diff --git a/openequivariance/README.md b/openequivariance/README.md new file mode 120000 index 00000000..32d46ee8 --- /dev/null +++ b/openequivariance/README.md @@ -0,0 +1 @@ +../README.md \ No newline at end of file diff --git a/openequivariance_extjax/LICENSE b/openequivariance_extjax/LICENSE new file mode 120000 index 00000000..ea5b6064 --- /dev/null +++ b/openequivariance_extjax/LICENSE @@ -0,0 +1 @@ +../LICENSE \ No newline at end of file diff --git a/openequivariance_extjax/pyproject.toml b/openequivariance_extjax/pyproject.toml index cc6a4738..41137b61 100644 --- a/openequivariance_extjax/pyproject.toml +++ b/openequivariance_extjax/pyproject.toml @@ -21,8 +21,8 @@ requires-python = ">=3.10" dependencies = [] readme = "README.md" -#license = "BSD-3-Clause" -#license-files = ["../LICENSE"] +license = "BSD-3-Clause" +license-files = ["LICENSE"] classifiers = [ "Programming Language :: Python :: 3", From 2b2c156f0c00073d060e40b7c9032594b0d4de0e Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Thu, 25 Dec 2025 14:52:24 -0800 Subject: [PATCH 081/116] Cleaning up the core. --- .../openequivariance/_torch/TensorProduct.py | 91 +++++++++++++++++++ .../_torch/TensorProductConv.py | 89 +++++++++++++++++- .../openequivariance/core/ConvolutionBase.py | 86 ------------------ .../core/TensorProductBase.py | 89 ------------------ openequivariance_extjax/CMakeLists.txt | 1 + tests/examples_test.py | 2 +- 6 files changed, 181 insertions(+), 177 deletions(-) diff --git a/openequivariance/openequivariance/_torch/TensorProduct.py b/openequivariance/openequivariance/_torch/TensorProduct.py index 2590207a..05ea54b5 100644 --- a/openequivariance/openequivariance/_torch/TensorProduct.py +++ b/openequivariance/openequivariance/_torch/TensorProduct.py @@ -8,6 +8,9 @@ from openequivariance._torch.utils import reorder_torch from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin +import numpy as np +from openequivariance._torch.extlib import DeviceBuffer + logger = getLogger() @@ -350,6 +353,94 @@ def register_autocast(cls): def name(): return "LoopUnrollTP" + def forward_raw( + self, + batch: np.uint64, + L1_in: np.uint64, + L2_in: np.uint64, + L3_out: np.uint64, + weights: np.uint64, + ) -> None: + self.internal.exec_tensor_product_rawptr(batch, L1_in, L2_in, L3_out, weights) + + def backward_raw( + self, + batch_size: np.uint64, + L1_in: np.uint64, + L1_grad: np.uint64, + L2_in: np.uint64, + L2_grad: np.uint64, + weights: np.uint64, + weights_grad: np.uint64, + L3_grad: np.uint64, + ): + self.internal.backward_rawptr( + batch_size, L1_in, L1_grad, L2_in, L2_grad, weights, weights_grad, L3_grad + ) + + def forward_cpu( + self, + L1_in: np.ndarray, + L2_in: np.ndarray, + L3_out: np.ndarray, + weights: np.ndarray, + ) -> None: + weights_chunked = self.reorder_weights_from_e3nn( + weights, not self.config.shared_weights + ) + + batch = L1_in.shape[0] + L1_d = DeviceBuffer(L1_in) + L2_d = DeviceBuffer(L2_in) + L3_d = DeviceBuffer(L3_out) + weights_d = DeviceBuffer(weights_chunked) + self.internal.exec_tensor_product_rawptr( + batch, + L1_d.data_ptr(), + L2_d.data_ptr(), + L3_d.data_ptr(), + weights_d.data_ptr(), + ) + L3_d.copy_to_host() + + def backward_cpu( + self, L1_in, L1_grad, L2_in, L2_grad, L3_grad, weights, weights_grad + ) -> None: + weights_chunked = self.reorder_weights_from_e3nn( + weights, not self.config.shared_weights + ) + + batch = L1_in.shape[0] + L1_d, L2_d, L3_d = ( + DeviceBuffer(L1_in), + DeviceBuffer(L2_in), + DeviceBuffer(L3_grad), + ) + L1_grad_d, L2_grad_d = DeviceBuffer(L1_grad), DeviceBuffer(L2_grad) + weights_d, weights_grad_d = ( + DeviceBuffer(weights_chunked), + DeviceBuffer(weights_grad), + ) + + self.internal.backward_rawptr( + batch, + L1_d.data_ptr(), + L1_grad_d.data_ptr(), + L2_d.data_ptr(), + L2_grad_d.data_ptr(), + weights_d.data_ptr(), + weights_grad_d.data_ptr(), + L3_d.data_ptr(), + ) + + L1_grad_d.copy_to_host() + L2_grad_d.copy_to_host() + weights_grad_d.copy_to_host() + + weights_grad[:] = self.reorder_weights_to_e3nn( + weights_grad, not self.config.shared_weights + ) + if extlib.TORCH_COMPILE: TensorProduct.register_torch_fakes() diff --git a/openequivariance/openequivariance/_torch/TensorProductConv.py b/openequivariance/openequivariance/_torch/TensorProductConv.py index d7880ec9..be572732 100644 --- a/openequivariance/openequivariance/_torch/TensorProductConv.py +++ b/openequivariance/openequivariance/_torch/TensorProductConv.py @@ -23,7 +23,7 @@ from openequivariance.benchmark.logging_utils import getLogger from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixinConv - +from openequivariance._torch.extlib import DeviceBuffer logger = getLogger() @@ -639,6 +639,93 @@ def register_autocast(cls): ) + def forward_cpu(self, L1_in, L2_in, weights, L3_out, graph): + assert graph.rows.dtype == self.idx_dtype + assert graph.cols.dtype == self.idx_dtype + + weights_chunked = self.reorder_weights_from_e3nn( + weights, not self.config.shared_weights + ) + + L1_d, L2_d, weights_d = ( + DeviceBuffer(L1_in), + DeviceBuffer(L2_in), + DeviceBuffer(weights_chunked), + ) + L3_d = DeviceBuffer(L3_out) + + rows_d = DeviceBuffer(graph.rows) + cols_d = DeviceBuffer(graph.cols) + + self.internal.exec_conv_rawptrs( + L1_d.data_ptr(), + L2_d.data_ptr(), + weights_d.data_ptr(), + L3_d.data_ptr(), + rows_d.data_ptr(), + cols_d.data_ptr(), + graph.nnz, + graph.node_count, + self.workspace_ptr, + ) + + L3_d.copy_to_host() + + def backward_cpu( + self, L1_in, L1_grad, L2_in, L2_grad, weights, weights_grad, L3_grad, graph + ): + assert graph.rows.dtype == self.idx_dtype + assert graph.cols.dtype == self.idx_dtype + + weights_chunked = self.reorder_weights_from_e3nn( + weights, not self.config.shared_weights + ) + + L1_d = DeviceBuffer(L1_in) + L2_d = DeviceBuffer(L2_in) + weights_d = DeviceBuffer(weights_chunked) + L3_d = DeviceBuffer(L3_grad) + rows_d = DeviceBuffer(graph.rows) + cols_d = DeviceBuffer(graph.cols) + + L1_grad_d = DeviceBuffer(L1_grad) + L2_grad_d = DeviceBuffer(L2_grad) + weights_grad_d = DeviceBuffer(weights_grad) + + transpose_perm_d = None + transpose_perm_ptr = 0 + if self.deterministic: + transpose_perm_d = DeviceBuffer(graph.transpose_perm) + transpose_perm_ptr = transpose_perm_d.data_ptr() + + self.internal.backward_rawptrs( + L1_d.data_ptr(), + L1_grad_d.data_ptr(), + L2_d.data_ptr(), + L2_grad_d.data_ptr(), + weights_d.data_ptr(), + weights_grad_d.data_ptr(), + L3_d.data_ptr(), + rows_d.data_ptr(), + cols_d.data_ptr(), + graph.nnz, + graph.node_count, + self.workspace_ptr, + transpose_perm_ptr, + ) + + L1_grad_d.copy_to_host() + L2_grad_d.copy_to_host() + weights_grad_d.copy_to_host() + + weights_grad[:] = self.reorder_weights_to_e3nn( + weights_grad, not self.config.shared_weights + ) + + return L1_grad, L2_grad, weights_grad + + + if extlib.TORCH_COMPILE: TensorProductConv.register_torch_fakes() TensorProductConv.register_autograd() diff --git a/openequivariance/openequivariance/core/ConvolutionBase.py b/openequivariance/openequivariance/core/ConvolutionBase.py index dbfeb5ff..a06b2c79 100644 --- a/openequivariance/openequivariance/core/ConvolutionBase.py +++ b/openequivariance/openequivariance/core/ConvolutionBase.py @@ -10,7 +10,6 @@ from openequivariance.benchmark.correctness_utils import check_similiarity from openequivariance.core.e3nn_lite import wigner_3j from openequivariance.core.utils import benchmark -from openequivariance._torch.extlib import DeviceBuffer logger = getLogger() @@ -135,91 +134,6 @@ def reorder_weights_to_e3nn(self, weights, has_batch_dim=True): def name(): raise NotImplementedError() - def forward_cpu(self, L1_in, L2_in, weights, L3_out, graph): - assert graph.rows.dtype == self.idx_dtype - assert graph.cols.dtype == self.idx_dtype - - weights_chunked = self.reorder_weights_from_e3nn( - weights, not self.config.shared_weights - ) - - L1_d, L2_d, weights_d = ( - DeviceBuffer(L1_in), - DeviceBuffer(L2_in), - DeviceBuffer(weights_chunked), - ) - L3_d = DeviceBuffer(L3_out) - - rows_d = DeviceBuffer(graph.rows) - cols_d = DeviceBuffer(graph.cols) - - self.internal.exec_conv_rawptrs( - L1_d.data_ptr(), - L2_d.data_ptr(), - weights_d.data_ptr(), - L3_d.data_ptr(), - rows_d.data_ptr(), - cols_d.data_ptr(), - graph.nnz, - graph.node_count, - self.workspace_ptr, - ) - - L3_d.copy_to_host() - - def backward_cpu( - self, L1_in, L1_grad, L2_in, L2_grad, weights, weights_grad, L3_grad, graph - ): - assert graph.rows.dtype == self.idx_dtype - assert graph.cols.dtype == self.idx_dtype - - weights_chunked = self.reorder_weights_from_e3nn( - weights, not self.config.shared_weights - ) - - L1_d = DeviceBuffer(L1_in) - L2_d = DeviceBuffer(L2_in) - weights_d = DeviceBuffer(weights_chunked) - L3_d = DeviceBuffer(L3_grad) - rows_d = DeviceBuffer(graph.rows) - cols_d = DeviceBuffer(graph.cols) - - L1_grad_d = DeviceBuffer(L1_grad) - L2_grad_d = DeviceBuffer(L2_grad) - weights_grad_d = DeviceBuffer(weights_grad) - - transpose_perm_d = None - transpose_perm_ptr = 0 - if self.deterministic: - transpose_perm_d = DeviceBuffer(graph.transpose_perm) - transpose_perm_ptr = transpose_perm_d.data_ptr() - - self.internal.backward_rawptrs( - L1_d.data_ptr(), - L1_grad_d.data_ptr(), - L2_d.data_ptr(), - L2_grad_d.data_ptr(), - weights_d.data_ptr(), - weights_grad_d.data_ptr(), - L3_d.data_ptr(), - rows_d.data_ptr(), - cols_d.data_ptr(), - graph.nnz, - graph.node_count, - self.workspace_ptr, - transpose_perm_ptr, - ) - - L1_grad_d.copy_to_host() - L2_grad_d.copy_to_host() - weights_grad_d.copy_to_host() - - weights_grad[:] = self.reorder_weights_to_e3nn( - weights_grad, not self.config.shared_weights - ) - - return L1_grad, L2_grad, weights_grad - def test_correctness_forward( self, graph, diff --git a/openequivariance/openequivariance/core/TensorProductBase.py b/openequivariance/openequivariance/core/TensorProductBase.py index 5af538bb..b5d3831f 100644 --- a/openequivariance/openequivariance/core/TensorProductBase.py +++ b/openequivariance/openequivariance/core/TensorProductBase.py @@ -3,7 +3,6 @@ from openequivariance.core.e3nn_lite import TPProblem from openequivariance.benchmark.logging_utils import getLogger from openequivariance.core.utils import benchmark -from openequivariance._torch.extlib import DeviceBuffer logger = getLogger() @@ -67,94 +66,6 @@ def reorder_weights_to_e3nn(self, weights, has_batch_dim: bool = True): """ return weights - def forward_raw( - self, - batch: np.uint64, - L1_in: np.uint64, - L2_in: np.uint64, - L3_out: np.uint64, - weights: np.uint64, - ) -> None: - self.internal.exec_tensor_product_rawptr(batch, L1_in, L2_in, L3_out, weights) - - def backward_raw( - self, - batch_size: np.uint64, - L1_in: np.uint64, - L1_grad: np.uint64, - L2_in: np.uint64, - L2_grad: np.uint64, - weights: np.uint64, - weights_grad: np.uint64, - L3_grad: np.uint64, - ): - self.internal.backward_rawptr( - batch_size, L1_in, L1_grad, L2_in, L2_grad, weights, weights_grad, L3_grad - ) - - def forward_cpu( - self, - L1_in: np.ndarray, - L2_in: np.ndarray, - L3_out: np.ndarray, - weights: np.ndarray, - ) -> None: - weights_chunked = self.reorder_weights_from_e3nn( - weights, not self.config.shared_weights - ) - - batch = L1_in.shape[0] - L1_d = DeviceBuffer(L1_in) - L2_d = DeviceBuffer(L2_in) - L3_d = DeviceBuffer(L3_out) - weights_d = DeviceBuffer(weights_chunked) - self.internal.exec_tensor_product_rawptr( - batch, - L1_d.data_ptr(), - L2_d.data_ptr(), - L3_d.data_ptr(), - weights_d.data_ptr(), - ) - L3_d.copy_to_host() - - def backward_cpu( - self, L1_in, L1_grad, L2_in, L2_grad, L3_grad, weights, weights_grad - ) -> None: - weights_chunked = self.reorder_weights_from_e3nn( - weights, not self.config.shared_weights - ) - - batch = L1_in.shape[0] - L1_d, L2_d, L3_d = ( - DeviceBuffer(L1_in), - DeviceBuffer(L2_in), - DeviceBuffer(L3_grad), - ) - L1_grad_d, L2_grad_d = DeviceBuffer(L1_grad), DeviceBuffer(L2_grad) - weights_d, weights_grad_d = ( - DeviceBuffer(weights_chunked), - DeviceBuffer(weights_grad), - ) - - self.internal.backward_rawptr( - batch, - L1_d.data_ptr(), - L1_grad_d.data_ptr(), - L2_d.data_ptr(), - L2_grad_d.data_ptr(), - weights_d.data_ptr(), - weights_grad_d.data_ptr(), - L3_d.data_ptr(), - ) - - L1_grad_d.copy_to_host() - L2_grad_d.copy_to_host() - weights_grad_d.copy_to_host() - - weights_grad[:] = self.reorder_weights_to_e3nn( - weights_grad, not self.config.shared_weights - ) - def benchmark_forward( self, num_warmup: int, diff --git a/openequivariance_extjax/CMakeLists.txt b/openequivariance_extjax/CMakeLists.txt index f40272f2..6dbb5052 100644 --- a/openequivariance_extjax/CMakeLists.txt +++ b/openequivariance_extjax/CMakeLists.txt @@ -17,6 +17,7 @@ execute_process( ) message(STATUS "nanobind cmake directory: ${nanobind_ROOT}") +set(ENV{OEQ_NOTORCH} "1") execute_process( COMMAND "${Python_EXECUTABLE}" "-c" "import openequivariance; print(openequivariance.extension_source_path())" diff --git a/tests/examples_test.py b/tests/examples_test.py index bad7c220..ae19f77e 100644 --- a/tests/examples_test.py +++ b/tests/examples_test.py @@ -91,7 +91,7 @@ def test_tutorial_jax(with_jax): if not with_jax: pytest.skip("Skipping JAX tutorial when testing PyTorch") - os.environ.OEQ_NOTORCH = "1" + os.environ["OEQ_NOTORCH"] = "1" import openequivariance as oeq import jax From c28f2b9d7765321f0812dc6978f32d7b33ee8382 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Thu, 25 Dec 2025 15:54:18 -0800 Subject: [PATCH 082/116] More core cleanup. --- .../openequivariance/_torch/utils.py | 14 +++++- .../openequivariance/core/LoopUnrollConv.py | 2 +- .../openequivariance/core/LoopUnrollTP.py | 2 +- .../openequivariance/core/dtype_enum.py | 47 ------------------- .../openequivariance/core/utils.py | 31 ++++++++++-- 5 files changed, 43 insertions(+), 53 deletions(-) delete mode 100644 openequivariance/openequivariance/core/dtype_enum.py diff --git a/openequivariance/openequivariance/_torch/utils.py b/openequivariance/openequivariance/_torch/utils.py index e03bb3ea..0fe98309 100644 --- a/openequivariance/openequivariance/_torch/utils.py +++ b/openequivariance/openequivariance/_torch/utils.py @@ -1,5 +1,7 @@ import torch - +from types import MappingProxyType +from openequivariance.core.utils import DTypeEnum +from openequivariance.core.utils import dtype_to_enum as dtype_to_enum_core def reorder_helper(schedule, weights_in, direction, has_batch_dim): assert direction in ["forward", "backward"] @@ -53,3 +55,13 @@ def reorder_torch(schedule, weights_in, direction, has_batch_dim): return reorder_helper(schedule, weights_in, direction, has_batch_dim) else: return reorder_numpy_helper(schedule, weights_in, direction, has_batch_dim) + +enum_to_torch_dtype = MappingProxyType( + { + DTypeEnum.FLOAT32: torch.float32, + DTypeEnum.FLOAT64: torch.float64, + DTypeEnum.INT32: torch.int32, + DTypeEnum.INT64: torch.int64, + DTypeEnum.UINT8: torch.uint8, + } +) \ No newline at end of file diff --git a/openequivariance/openequivariance/core/LoopUnrollConv.py b/openequivariance/openequivariance/core/LoopUnrollConv.py index 0763d69c..35a9bc3e 100644 --- a/openequivariance/openequivariance/core/LoopUnrollConv.py +++ b/openequivariance/openequivariance/core/LoopUnrollConv.py @@ -6,7 +6,7 @@ SMEMCapacityException, ) -from openequivariance.core.dtype_enum import dtype_to_enum +from openequivariance.core.utils import dtype_to_enum from openequivariance.templates.jinja_utils import get_jinja_environment from openequivariance.core.utils import filter_and_analyze_problem diff --git a/openequivariance/openequivariance/core/LoopUnrollTP.py b/openequivariance/openequivariance/core/LoopUnrollTP.py index 1705a8dd..12ad4536 100644 --- a/openequivariance/openequivariance/core/LoopUnrollTP.py +++ b/openequivariance/openequivariance/core/LoopUnrollTP.py @@ -3,7 +3,7 @@ from openequivariance.templates.jinja_utils import get_jinja_environment from openequivariance.core.ComputationSchedule import ComputationSchedule from openequivariance.core.TensorProductBase import TensorProductBase -from openequivariance.core.dtype_enum import dtype_to_enum +from openequivariance.core.utils import dtype_to_enum from openequivariance.core.utils import ( filter_and_analyze_problem, diff --git a/openequivariance/openequivariance/core/dtype_enum.py b/openequivariance/openequivariance/core/dtype_enum.py deleted file mode 100644 index 292b7e4f..00000000 --- a/openequivariance/openequivariance/core/dtype_enum.py +++ /dev/null @@ -1,47 +0,0 @@ -from enum import IntEnum -from types import MappingProxyType -import numpy as np -import torch - - -class DTypeEnum(IntEnum): - FLOAT32 = 1 - FLOAT64 = 2 - INT32 = 3 - INT64 = 4 - UINT8 = 5 - - -dtype_to_enum = MappingProxyType( - { - torch.float32: DTypeEnum.FLOAT32, - torch.float64: DTypeEnum.FLOAT64, - torch.int32: DTypeEnum.INT32, - torch.int64: DTypeEnum.INT64, - torch.uint8: DTypeEnum.UINT8, - # torch - np.float32: DTypeEnum.FLOAT32, - np.float64: DTypeEnum.FLOAT64, - np.int32: DTypeEnum.INT32, - np.int64: DTypeEnum.INT64, - np.uint8: DTypeEnum.UINT8, - # numpy generic - np.dtype(np.float32): DTypeEnum.FLOAT32, - np.dtype(np.float64): DTypeEnum.FLOAT64, - np.dtype(np.int32): DTypeEnum.INT32, - np.dtype(np.int64): DTypeEnum.INT64, - np.dtype(np.uint8): DTypeEnum.UINT8, - # numpy dtype - } -) - - -enum_to_torch_dtype = MappingProxyType( - { - DTypeEnum.FLOAT32: torch.float32, - DTypeEnum.FLOAT64: torch.float64, - DTypeEnum.INT32: torch.int32, - DTypeEnum.INT64: torch.int64, - DTypeEnum.UINT8: torch.uint8, - } -) diff --git a/openequivariance/openequivariance/core/utils.py b/openequivariance/openequivariance/core/utils.py index 472c00f2..960b5826 100644 --- a/openequivariance/openequivariance/core/utils.py +++ b/openequivariance/openequivariance/core/utils.py @@ -8,13 +8,38 @@ import json import tempfile import hashlib -from openequivariance._torch.extlib import GPUTimer + +from enum import IntEnum +import numpy as np + +class DTypeEnum(IntEnum): + ''' + The C++ layer storess a copy of this map. + ''' + FLOAT32 = 1 + FLOAT64 = 2 + INT32 = 3 + INT64 = 4 + UINT8 = 5 + +dtype_to_enum = { + np.float32: DTypeEnum.FLOAT32, + np.float64: DTypeEnum.FLOAT64, + np.int32: DTypeEnum.INT32, + np.int64: DTypeEnum.INT64, + np.uint8: DTypeEnum.UINT8, + + np.dtype(np.float32): DTypeEnum.FLOAT32, + np.dtype(np.float64): DTypeEnum.FLOAT64, + np.dtype(np.int32): DTypeEnum.INT32, + np.dtype(np.int64): DTypeEnum.INT64, + np.dtype(np.uint8): DTypeEnum.UINT8, + } def sparse_outer_product_work(cg: np.ndarray) -> int: return np.sum(np.max(cg != 0, axis=2)) - # Nonzeros @functools.lru_cache(typed=True) def count_cg_non_zero(l1, l2, l3) -> int: @@ -86,7 +111,6 @@ def filter_and_analyze_problem(problem): } return result - def torch_to_oeq_dtype(torch_dtype) -> type[np.generic]: """ Convenience function; converts a torch datatype to the corresponding @@ -124,6 +148,7 @@ def benchmark(func, num_warmup, num_iter, mode="gpu_time", kernel_names=[]): mode=gpu_time may include PyTorch overhead mode=kernel_time measures runtime for only the specified kernels """ + from openequivariance._torch.extlib import GPUTimer assert mode in ["gpu_time", "torch_kernel_time"] time_millis = np.zeros(num_iter, dtype=np.float32) timer = GPUTimer() From 52cf2ce70a29651c34cd0ca61df2d25fb973b836 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Thu, 25 Dec 2025 15:57:21 -0800 Subject: [PATCH 083/116] Rename. --- tests/{examples_test.py => example_test.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{examples_test.py => example_test.py} (100%) diff --git a/tests/examples_test.py b/tests/example_test.py similarity index 100% rename from tests/examples_test.py rename to tests/example_test.py From b2145cdb321b8656d00fa66c3f856f86ad4f8b5d Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Thu, 25 Dec 2025 16:12:10 -0800 Subject: [PATCH 084/116] Example test is working. --- .../openequivariance/extension/util/backend_cuda.hpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/openequivariance/openequivariance/extension/util/backend_cuda.hpp b/openequivariance/openequivariance/extension/util/backend_cuda.hpp index 364186fc..4c79faed 100644 --- a/openequivariance/openequivariance/extension/util/backend_cuda.hpp +++ b/openequivariance/openequivariance/extension/util/backend_cuda.hpp @@ -349,7 +349,11 @@ class __attribute__((visibility("default"))) CUJITKernel { ~CUJITKernel() { if(compiled) { - CUDA_SAFE_CALL(cuLibraryUnload(library)); + auto result = cuLibraryUnload(library); + if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) { + std::cout << "Failed to unload CUDA library, error code: " << ((int) result) << std::endl; + } + delete[] code; } NVRTC_SAFE_CALL(nvrtcDestroyProgram(&prog)); From 3f5953249eaf84046ffafc775c2f2903f45f67a7 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Thu, 25 Dec 2025 16:21:19 -0800 Subject: [PATCH 085/116] Sanded away some more issues. --- openequivariance/openequivariance/_torch/TensorProductConv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openequivariance/openequivariance/_torch/TensorProductConv.py b/openequivariance/openequivariance/_torch/TensorProductConv.py index be572732..ed55c936 100644 --- a/openequivariance/openequivariance/_torch/TensorProductConv.py +++ b/openequivariance/openequivariance/_torch/TensorProductConv.py @@ -18,7 +18,7 @@ from openequivariance._torch.TensorProduct import TensorProduct from openequivariance import TPProblem from openequivariance.core.utils import torch_to_oeq_dtype -from openequivariance.core.dtype_enum import enum_to_torch_dtype +from openequivariance._torch.utils import enum_to_torch_dtype from openequivariance._torch.utils import reorder_torch from openequivariance.benchmark.logging_utils import getLogger From e49ad884053242d024ce31d20886a5bca2ca774f Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Thu, 25 Dec 2025 18:40:51 -0800 Subject: [PATCH 086/116] Updated changelog. --- CHANGELOG.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index a120656b..49aef9d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,16 @@ ## Latest Changes +### v0.5.0 (2025-12-25) +JAX support is now available in +OpenEquivariance for NVIDIA GPUs. See the +[documentation](https://passionlab.github.io/OpenEquivariance/) +and README.md for instructions on installation +and usage. + +Minor changes: +- Defer error reporting when CUDA is not available + to the first library usage in code, not library load. + ### v0.4.1 (2025-09-04) Minor update, fixes a bug loading JIT-compiled modules with PyTorch 2.9. From d5650b47ab592b488051189029ef7840eeda4b85 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Thu, 25 Dec 2025 18:44:53 -0800 Subject: [PATCH 087/116] Pre-commit. --- .../_torch/TensorProductConv.py | 2 -- .../_torch/extlib/__init__.py | 3 ++ .../openequivariance/_torch/utils.py | 5 +-- .../openequivariance/core/utils.py | 34 +++++++++++-------- 4 files changed, 25 insertions(+), 19 deletions(-) diff --git a/openequivariance/openequivariance/_torch/TensorProductConv.py b/openequivariance/openequivariance/_torch/TensorProductConv.py index ed55c936..f30c943c 100644 --- a/openequivariance/openequivariance/_torch/TensorProductConv.py +++ b/openequivariance/openequivariance/_torch/TensorProductConv.py @@ -638,7 +638,6 @@ def register_autocast(cls): "libtorch_tp_jit::jit_conv_double_backward", "cuda", torch.float32 ) - def forward_cpu(self, L1_in, L2_in, weights, L3_out, graph): assert graph.rows.dtype == self.idx_dtype assert graph.cols.dtype == self.idx_dtype @@ -725,7 +724,6 @@ def backward_cpu( return L1_grad, L2_grad, weights_grad - if extlib.TORCH_COMPILE: TensorProductConv.register_torch_fakes() TensorProductConv.register_autograd() diff --git a/openequivariance/openequivariance/_torch/extlib/__init__.py b/openequivariance/openequivariance/_torch/extlib/__init__.py index 26b5b52f..72440872 100644 --- a/openequivariance/openequivariance/_torch/extlib/__init__.py +++ b/openequivariance/openequivariance/_torch/extlib/__init__.py @@ -42,6 +42,7 @@ if BUILT_EXTENSION: import openequivariance._torch.extlib.generic_module + generic_module = openequivariance._torch.extlib.generic_module elif torch.version.cuda or torch.version.hip: @@ -141,9 +142,11 @@ def _raise_import_error_helper(import_target: str): if not BUILT_EXTENSION: raise ImportError(f"Could not import {import_target}: {BUILT_EXTENSION_ERROR}") + def torch_ext_so_path(): return torch_module.__file__ + if BUILT_EXTENSION: from generic_module import ( JITTPImpl, diff --git a/openequivariance/openequivariance/_torch/utils.py b/openequivariance/openequivariance/_torch/utils.py index 0fe98309..7538fb27 100644 --- a/openequivariance/openequivariance/_torch/utils.py +++ b/openequivariance/openequivariance/_torch/utils.py @@ -1,7 +1,7 @@ import torch from types import MappingProxyType from openequivariance.core.utils import DTypeEnum -from openequivariance.core.utils import dtype_to_enum as dtype_to_enum_core + def reorder_helper(schedule, weights_in, direction, has_batch_dim): assert direction in ["forward", "backward"] @@ -56,6 +56,7 @@ def reorder_torch(schedule, weights_in, direction, has_batch_dim): else: return reorder_numpy_helper(schedule, weights_in, direction, has_batch_dim) + enum_to_torch_dtype = MappingProxyType( { DTypeEnum.FLOAT32: torch.float32, @@ -64,4 +65,4 @@ def reorder_torch(schedule, weights_in, direction, has_batch_dim): DTypeEnum.INT64: torch.int64, DTypeEnum.UINT8: torch.uint8, } -) \ No newline at end of file +) diff --git a/openequivariance/openequivariance/core/utils.py b/openequivariance/openequivariance/core/utils.py index 960b5826..5fd8f81d 100644 --- a/openequivariance/openequivariance/core/utils.py +++ b/openequivariance/openequivariance/core/utils.py @@ -10,36 +10,38 @@ import hashlib from enum import IntEnum -import numpy as np + class DTypeEnum(IntEnum): - ''' + """ The C++ layer storess a copy of this map. - ''' + """ + FLOAT32 = 1 FLOAT64 = 2 INT32 = 3 INT64 = 4 UINT8 = 5 + dtype_to_enum = { - np.float32: DTypeEnum.FLOAT32, - np.float64: DTypeEnum.FLOAT64, - np.int32: DTypeEnum.INT32, - np.int64: DTypeEnum.INT64, - np.uint8: DTypeEnum.UINT8, - - np.dtype(np.float32): DTypeEnum.FLOAT32, - np.dtype(np.float64): DTypeEnum.FLOAT64, - np.dtype(np.int32): DTypeEnum.INT32, - np.dtype(np.int64): DTypeEnum.INT64, - np.dtype(np.uint8): DTypeEnum.UINT8, - } + np.float32: DTypeEnum.FLOAT32, + np.float64: DTypeEnum.FLOAT64, + np.int32: DTypeEnum.INT32, + np.int64: DTypeEnum.INT64, + np.uint8: DTypeEnum.UINT8, + np.dtype(np.float32): DTypeEnum.FLOAT32, + np.dtype(np.float64): DTypeEnum.FLOAT64, + np.dtype(np.int32): DTypeEnum.INT32, + np.dtype(np.int64): DTypeEnum.INT64, + np.dtype(np.uint8): DTypeEnum.UINT8, +} def sparse_outer_product_work(cg: np.ndarray) -> int: return np.sum(np.max(cg != 0, axis=2)) + # Nonzeros @functools.lru_cache(typed=True) def count_cg_non_zero(l1, l2, l3) -> int: @@ -111,6 +113,7 @@ def filter_and_analyze_problem(problem): } return result + def torch_to_oeq_dtype(torch_dtype) -> type[np.generic]: """ Convenience function; converts a torch datatype to the corresponding @@ -149,6 +152,7 @@ def benchmark(func, num_warmup, num_iter, mode="gpu_time", kernel_names=[]): mode=kernel_time measures runtime for only the specified kernels """ from openequivariance._torch.extlib import GPUTimer + assert mode in ["gpu_time", "torch_kernel_time"] time_millis = np.zeros(num_iter, dtype=np.float32) timer = GPUTimer() From 14f642d05f3e13f9afc30a22f6c3cba9ba5b756f Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Fri, 26 Dec 2025 13:34:30 -0800 Subject: [PATCH 088/116] Download XLA directly. --- openequivariance_extjax/CMakeLists.txt | 30 +++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/openequivariance_extjax/CMakeLists.txt b/openequivariance_extjax/CMakeLists.txt index 6dbb5052..6e1b8ccd 100644 --- a/openequivariance_extjax/CMakeLists.txt +++ b/openequivariance_extjax/CMakeLists.txt @@ -4,11 +4,29 @@ project(${SKBUILD_PROJECT_NAME} LANGUAGES CXX) # TODO: Add HIP support find_package(Python 3.11 REQUIRED COMPONENTS Interpreter Development.Module) find_package(CUDAToolkit REQUIRED) -execute_process( - COMMAND "${Python_EXECUTABLE}" "-c" - "from jax import ffi; print(ffi.include_dir())" - OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE XLA_DIR -) +# Includes XLA headers based on a JAX installation +#execute_process( +# COMMAND "${Python_EXECUTABLE}" "-c" +# "from jax import ffi; print(ffi.include_dir())" +# OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE XLA_DIR +#) +#message(STATUS "XLA include directory: ${XLA_DIR}") + +include(ExternalProject) +ExternalProject_Add( + xla + PREFIX ${CMAKE_BINARY_DIR}/xla + GIT_REPOSITORY https://github.com/openxla/xla.git + GIT_TAG d56e645fe72988e2b4464119300ca6f894f82598 + GIT_SHALLOW TRUE + GIT_PROGRESS TRUE + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + LOG_DOWNLOAD ON + ) +ExternalProject_Get_Property(xla source_dir) +set(XLA_DIR ${source_dir}) message(STATUS "XLA include directory: ${XLA_DIR}") execute_process( @@ -48,6 +66,8 @@ target_compile_options(openequivariance_extjax PRIVATE -Wno-attributes -Wno-retu get_target_property(CUDA_LIB_DIR CUDA::nvrtc IMPORTED_LOCATION) get_filename_component(CUDA_LIB_DIR ${CUDA_LIB_DIR} DIRECTORY) +add_dependencies(openequivariance_extjax xla) + set_target_properties(openequivariance_extjax PROPERTIES BUILD_RPATH "${CUDA_LIB_DIR}" INSTALL_RPATH "${CUDA_LIB_DIR}" From 0e71c548caf1713d99c9970d3c4cf60b0418714f Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Fri, 26 Dec 2025 13:54:17 -0800 Subject: [PATCH 089/116] Removed need for build isolation. --- openequivariance_extjax/CMakeLists.txt | 29 +++++++++++--------------- openequivariance_extjax/src/extension | 1 + 2 files changed, 13 insertions(+), 17 deletions(-) create mode 120000 openequivariance_extjax/src/extension diff --git a/openequivariance_extjax/CMakeLists.txt b/openequivariance_extjax/CMakeLists.txt index 6e1b8ccd..122b4bcb 100644 --- a/openequivariance_extjax/CMakeLists.txt +++ b/openequivariance_extjax/CMakeLists.txt @@ -4,14 +4,6 @@ project(${SKBUILD_PROJECT_NAME} LANGUAGES CXX) # TODO: Add HIP support find_package(Python 3.11 REQUIRED COMPONENTS Interpreter Development.Module) find_package(CUDAToolkit REQUIRED) -# Includes XLA headers based on a JAX installation -#execute_process( -# COMMAND "${Python_EXECUTABLE}" "-c" -# "from jax import ffi; print(ffi.include_dir())" -# OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE XLA_DIR -#) -#message(STATUS "XLA include directory: ${XLA_DIR}") - include(ExternalProject) ExternalProject_Add( xla @@ -34,14 +26,7 @@ execute_process( OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE nanobind_ROOT ) message(STATUS "nanobind cmake directory: ${nanobind_ROOT}") - -set(ENV{OEQ_NOTORCH} "1") -execute_process( - COMMAND "${Python_EXECUTABLE}" "-c" - "import openequivariance; print(openequivariance.extension_source_path())" - OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE HEADER_DIR -) -message(STATUS "OpenEquivariance extension source directory: ${HEADER_DIR}") +set(HEADER_DIR "src/extension") find_package(nanobind CONFIG REQUIRED) @@ -78,4 +63,14 @@ target_link_libraries(openequivariance_extjax PRIVATE CUDA::cuda_driver CUDA::nvrtc) -install(TARGETS openequivariance_extjax LIBRARY DESTINATION .) \ No newline at end of file +install(TARGETS openequivariance_extjax LIBRARY DESTINATION .) + +# ------------------------------------------------------------- + +# Uncomment to include XLA from JAX installation +#execute_process( +# COMMAND "${Python_EXECUTABLE}" "-c" +# "from jax import ffi; print(ffi.include_dir())" +# OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE XLA_DIR +#) +#message(STATUS "XLA include directory: ${XLA_DIR}") \ No newline at end of file diff --git a/openequivariance_extjax/src/extension b/openequivariance_extjax/src/extension new file mode 120000 index 00000000..8370a418 --- /dev/null +++ b/openequivariance_extjax/src/extension @@ -0,0 +1 @@ +../../openequivariance/openequivariance/extension \ No newline at end of file From 85c2e2fe927d2ba721d4314952267680134f6d22 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Fri, 26 Dec 2025 14:05:04 -0800 Subject: [PATCH 090/116] Removed need for build isolation. --- .github/workflows/verify_extension_build.yml | 4 +--- README.md | 12 ++++++------ docs/installation.rst | 11 +++++------ openequivariance/pyproject.toml | 9 +-------- 4 files changed, 13 insertions(+), 23 deletions(-) diff --git a/.github/workflows/verify_extension_build.yml b/.github/workflows/verify_extension_build.yml index 6c08d673..9bbef17b 100644 --- a/.github/workflows/verify_extension_build.yml +++ b/.github/workflows/verify_extension_build.yml @@ -39,6 +39,4 @@ jobs: - name: Test JAX extension build run: | - pip install "jax[cuda12]" - pip install -e ./openequivariance[jax] - pip install -e ./openequivariance_extjax --no-build-isolation + pip install -e ./openequivariance_extjax diff --git a/README.md b/README.md index a7184cb8..d97795d7 100644 --- a/README.md +++ b/README.md @@ -31,15 +31,15 @@ For detailed instructions on tests, benchmarks, MACE / Nequip, and our API, check out the [documentation](https://passionlab.github.io/OpenEquivariance). ⭐️ **JAX**: Our latest update brings -support for JAX. To install, execute the following -commands in order: +support for JAX. Install it with the following: ``` bash -pip install jax[cuda12] # Skip if JAX installed -pip install openequivariance[jax] -pip install openequivariance_extjax --no-build-isolation +pip install openequivariance +pip install openequivariance_extjax ``` -See the section below for example usage. + +See the section below for example usage and +our [API page](https://passionlab.github.io/OpenEquivariance/api/) for more details. ## Show me some examples Here's a CG tensor product implemented by e3nn: diff --git a/docs/installation.rst b/docs/installation.rst index 3949457b..376b4863 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -58,13 +58,13 @@ installation (or upgrade) with pip. JAX ------------------------------------------ -JAX support is currently limited to NVIDIA GPUs. You need to execute -the following two commands strictly in order: +JAX support is currently limited to NVIDIA GPUs. After +installing the main OpenEquivariance package, install +our JAX extension as follows: .. code-block:: bash - pip install openequivariance[jax] - pip install openequivariance_extjax --no-build-isolation + pip install openequivariance_extjax From there, set ``OEQ_NOTORCH=1`` to avoid a PyTorch import and test the package: @@ -73,11 +73,10 @@ From there, set ``OEQ_NOTORCH=1`` to avoid a PyTorch import and test the package OEQ_NOTORCH=1 python -c "import openequivariance.jax" -You can get the nightly build as follows: +You can get the nightly build with teh following command: .. code-block:: bash - pip install git+https://github.com/PASSIONLab/OpenEquivariance#subdirectory=openequivariance[jax] pip install git+https://github.com/PASSIONLab/OpenEquivariance#subdirectory=openequivariance_extjax Configurations on Major Platforms diff --git a/openequivariance/pyproject.toml b/openequivariance/pyproject.toml index 30df0e60..c55cffea 100644 --- a/openequivariance/pyproject.toml +++ b/openequivariance/pyproject.toml @@ -48,11 +48,6 @@ bench = [ "cuequivariance-ops-torch-cu12", ] -jax = [ - "nanobind", - "scikit-build-core" -] - dev = [ "e3nn", "pre-commit", @@ -64,9 +59,7 @@ dev = [ "cmake", "furo", "sphinx", - "sphinx-autobuild", - "nanobind", - "scikit-build-core" + "sphinx-autobuild" ] [tool.setuptools.packages.find] From 0efef5bf79a30e81e9693095fc8ddfb273ae8d29 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Fri, 26 Dec 2025 14:06:00 -0800 Subject: [PATCH 091/116] Updated README. --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index d97795d7..166a7e36 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ [![OEQ CUDA C++ Extension Build Verification](https://github.com/PASSIONLab/OpenEquivariance/actions/workflows/verify_extension_build.yml/badge.svg?event=push)](https://github.com/PASSIONLab/OpenEquivariance/actions/workflows/verify_extension_build.yml) [![License](https://img.shields.io/badge/License-BSD_3--Clause-blue.svg)](https://opensource.org/licenses/BSD-3-Clause) -[[Examples]](#show-me-some-examples) +[[PyTorch Examples]](#pytorch-examples) [[JAX Examples]](#jax-examples) [[Citation and Acknowledgements]](#citation-and-acknowledgements) @@ -41,7 +41,7 @@ pip install openequivariance_extjax See the section below for example usage and our [API page](https://passionlab.github.io/OpenEquivariance/api/) for more details. -## Show me some examples +## PyTorch Examples Here's a CG tensor product implemented by e3nn: ```python From 37f4891a0dcba4a9516427a85582c25306bc2cd5 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Fri, 26 Dec 2025 14:13:31 -0800 Subject: [PATCH 092/116] Updated documentation slightly. --- docs/tests_and_benchmarks.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/tests_and_benchmarks.rst b/docs/tests_and_benchmarks.rst index fa24e765..0799a033 100644 --- a/docs/tests_and_benchmarks.rst +++ b/docs/tests_and_benchmarks.rst @@ -21,7 +21,7 @@ download the test folder and install the non-editable package and the dependenci Correctness ------------------------------ -To set up the editable install and run the entire testsuite, use: +To set up the editable install and run the entire PyTorch testsuite, use: .. code-block:: bash @@ -36,7 +36,7 @@ To test the JAX wrappers, follow the same steps above and make sure that ``openequivariance_extjax`` is installed without build isolation. Then run .. code-block:: bash - + pytest --jax tests/example_test.py pytest --jax tests/batch_test.py pytest --jax tests/conv_test.py From 5d8e42fdd0554362002a9fcc524a3dfbe6946130 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Fri, 26 Dec 2025 14:22:07 -0800 Subject: [PATCH 093/116] Don't need extension source path anymore. --- openequivariance/openequivariance/__init__.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/openequivariance/openequivariance/__init__.py b/openequivariance/openequivariance/__init__.py index 84cf9f92..76f9c2bc 100644 --- a/openequivariance/openequivariance/__init__.py +++ b/openequivariance/openequivariance/__init__.py @@ -31,14 +31,6 @@ def _check_package_editable(): _editable_install_output_path = Path(__file__).parent.parent.parent / "outputs" - -def extension_source_path(): - """ - :returns: Path to the source code of the C++ extension. - """ - return str(Path(__file__).parent / "extension") - - if "OEQ_NOTORCH" not in os.environ or os.environ["OEQ_NOTORCH"] != "1": import torch From 01e124d6432ef5f409c45429417b1a5a54bc4638 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sat, 27 Dec 2025 13:05:11 -0800 Subject: [PATCH 094/116] Removed a spurious import. --- openequivariance/__init__.py | 80 ------------------- openequivariance/openequivariance/__init__.py | 2 +- 2 files changed, 1 insertion(+), 81 deletions(-) delete mode 100644 openequivariance/__init__.py diff --git a/openequivariance/__init__.py b/openequivariance/__init__.py deleted file mode 100644 index 5a8ab812..00000000 --- a/openequivariance/__init__.py +++ /dev/null @@ -1,80 +0,0 @@ -# ruff: noqa: F401 -import sys -import torch -import numpy as np -from pathlib import Path -from importlib.metadata import version - -import openequivariance.extlib - -from openequivariance.extlib import ( - LINKED_LIBPYTHON, - LINKED_LIBPYTHON_ERROR, - BUILT_EXTENSION, - BUILT_EXTENSION_ERROR, - TORCH_COMPILE, - TORCH_COMPILE_ERROR, -) - -from openequivariance.implementations.e3nn_lite import ( - TPProblem, - Irrep, - Irreps, - _MulIr, - Instruction, -) -from openequivariance.implementations.TensorProduct import TensorProduct -from openequivariance.implementations.convolution.TensorProductConv import ( - TensorProductConv, -) -from openequivariance.implementations.utils import torch_to_oeq_dtype - -__version__ = None -try: - __version__ = version("openequivariance") -except Exception as e: - print(f"Warning: Could not determine oeq version: {e}", file=sys.stderr) - - -def _check_package_editable(): - import json - from importlib.metadata import Distribution - - direct_url = Distribution.from_name("openequivariance").read_text("direct_url.json") - return json.loads(direct_url).get("dir_info", {}).get("editable", False) - - -_editable_install_output_path = Path(__file__).parent.parent / "outputs" - - -def torch_ext_so_path(): - """ - :returns: Path to a ``.so`` file that must be linked to use OpenEquivariance - from the PyTorch C++ Interface. - """ - return openequivariance.extlib.torch_module.__file__ - - -torch.serialization.add_safe_globals( - [ - TensorProduct, - TensorProductConv, - TPProblem, - Irrep, - Irreps, - _MulIr, - Instruction, - np.float32, - np.float64, - ] -) - -__all__ = [ - "TPProblem", - "Irreps", - "TensorProduct", - "TensorProductConv", - "torch_to_oeq_dtype", - "_check_package_editable", - "torch_ext_so_path", -] diff --git a/openequivariance/openequivariance/__init__.py b/openequivariance/openequivariance/__init__.py index 76f9c2bc..b1a13325 100644 --- a/openequivariance/openequivariance/__init__.py +++ b/openequivariance/openequivariance/__init__.py @@ -72,7 +72,7 @@ def torch_ext_so_path(): try: import openequivariance_extjax import openequivariance.jax as jax -except ImportError: +except Exception: pass __all__ = [ From d21b2487784980175a4de77768efd237599eaa20 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj <67718556+vbharadwaj-bk@users.noreply.github.com> Date: Sun, 4 Jan 2026 02:33:10 -0800 Subject: [PATCH 095/116] Update Python version and XLA Git tag in CMakeLists --- openequivariance_extjax/CMakeLists.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/openequivariance_extjax/CMakeLists.txt b/openequivariance_extjax/CMakeLists.txt index 122b4bcb..e2b527aa 100644 --- a/openequivariance_extjax/CMakeLists.txt +++ b/openequivariance_extjax/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.15...3.30) project(${SKBUILD_PROJECT_NAME} LANGUAGES CXX) # TODO: Add HIP support -find_package(Python 3.11 REQUIRED COMPONENTS Interpreter Development.Module) +find_package(Python 3.10 REQUIRED COMPONENTS Interpreter Development.Module) find_package(CUDAToolkit REQUIRED) include(ExternalProject) @@ -9,7 +9,7 @@ ExternalProject_Add( xla PREFIX ${CMAKE_BINARY_DIR}/xla GIT_REPOSITORY https://github.com/openxla/xla.git - GIT_TAG d56e645fe72988e2b4464119300ca6f894f82598 + GIT_TAG main GIT_SHALLOW TRUE GIT_PROGRESS TRUE CONFIGURE_COMMAND "" @@ -73,4 +73,4 @@ install(TARGETS openequivariance_extjax LIBRARY DESTINATION .) # "from jax import ffi; print(ffi.include_dir())" # OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE XLA_DIR #) -#message(STATUS "XLA include directory: ${XLA_DIR}") \ No newline at end of file +#message(STATUS "XLA include directory: ${XLA_DIR}") From c774f0ed3042ff3603f57c51c80bd11faeb5b204 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 4 Jan 2026 12:51:37 -0800 Subject: [PATCH 096/116] Update XLA dir. --- openequivariance_extjax/CMakeLists.txt | 41 +++++++++++++------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/openequivariance_extjax/CMakeLists.txt b/openequivariance_extjax/CMakeLists.txt index 122b4bcb..626943ee 100644 --- a/openequivariance_extjax/CMakeLists.txt +++ b/openequivariance_extjax/CMakeLists.txt @@ -4,21 +4,27 @@ project(${SKBUILD_PROJECT_NAME} LANGUAGES CXX) # TODO: Add HIP support find_package(Python 3.11 REQUIRED COMPONENTS Interpreter Development.Module) find_package(CUDAToolkit REQUIRED) -include(ExternalProject) -ExternalProject_Add( - xla - PREFIX ${CMAKE_BINARY_DIR}/xla - GIT_REPOSITORY https://github.com/openxla/xla.git - GIT_TAG d56e645fe72988e2b4464119300ca6f894f82598 - GIT_SHALLOW TRUE - GIT_PROGRESS TRUE - CONFIGURE_COMMAND "" - BUILD_COMMAND "" - INSTALL_COMMAND "" - LOG_DOWNLOAD ON - ) -ExternalProject_Get_Property(xla source_dir) -set(XLA_DIR ${source_dir}) +#include(ExternalProject) +#ExternalProject_Add( +# xla +# PREFIX ${CMAKE_BINARY_DIR}/xla +# GIT_REPOSITORY https://github.com/openxla/xla.git +# GIT_TAG d56e645fe72988e2b4464119300ca6f894f82598 +# GIT_SHALLOW TRUE +# GIT_PROGRESS TRUE +# CONFIGURE_COMMAND "" +# BUILD_COMMAND "" +# INSTALL_COMMAND "" +# LOG_DOWNLOAD ON +# ) +#ExternalProject_Get_Property(xla source_dir) +#set(XLA_DIR ${source_dir}) + +execute_process( + COMMAND "${Python_EXECUTABLE}" "-c" + "from jax import ffi; print(ffi.include_dir())" + OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE XLA_DIR +) message(STATUS "XLA include directory: ${XLA_DIR}") execute_process( @@ -68,9 +74,4 @@ install(TARGETS openequivariance_extjax LIBRARY DESTINATION .) # ------------------------------------------------------------- # Uncomment to include XLA from JAX installation -#execute_process( -# COMMAND "${Python_EXECUTABLE}" "-c" -# "from jax import ffi; print(ffi.include_dir())" -# OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE XLA_DIR -#) #message(STATUS "XLA include directory: ${XLA_DIR}") \ No newline at end of file From aa040f1de8b9837319461fcd70cffdb9aea55272 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 4 Jan 2026 12:57:38 -0800 Subject: [PATCH 097/116] Removed dependency. --- openequivariance_extjax/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openequivariance_extjax/CMakeLists.txt b/openequivariance_extjax/CMakeLists.txt index 626943ee..7c7a9ab5 100644 --- a/openequivariance_extjax/CMakeLists.txt +++ b/openequivariance_extjax/CMakeLists.txt @@ -57,7 +57,7 @@ target_compile_options(openequivariance_extjax PRIVATE -Wno-attributes -Wno-retu get_target_property(CUDA_LIB_DIR CUDA::nvrtc IMPORTED_LOCATION) get_filename_component(CUDA_LIB_DIR ${CUDA_LIB_DIR} DIRECTORY) -add_dependencies(openequivariance_extjax xla) +#add_dependencies(openequivariance_extjax xla) set_target_properties(openequivariance_extjax PROPERTIES BUILD_RPATH "${CUDA_LIB_DIR}" From fa5236fa9b4a13608817749ea910312034440609 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 4 Jan 2026 13:21:41 -0800 Subject: [PATCH 098/116] Went back to version that disables build isolation. --- .github/workflows/verify_extension_build.yml | 5 +++-- README.md | 7 ++++--- docs/installation.rst | 11 ++++++----- openequivariance/pyproject.toml | 6 ++++++ openequivariance_extjax/CMakeLists.txt | 11 +++-------- 5 files changed, 22 insertions(+), 18 deletions(-) diff --git a/.github/workflows/verify_extension_build.yml b/.github/workflows/verify_extension_build.yml index 9bbef17b..73bd2eb1 100644 --- a/.github/workflows/verify_extension_build.yml +++ b/.github/workflows/verify_extension_build.yml @@ -29,7 +29,7 @@ jobs: sudo apt-get update sudo apt install nvidia-cuda-toolkit pip install -r .github/workflows/requirements_cuda_ci.txt - pip install -e ./openequivariance + pip install -e "./openequivariance" - name: Test extension build via import run: | @@ -39,4 +39,5 @@ jobs: - name: Test JAX extension build run: | - pip install -e ./openequivariance_extjax + pip install -e "./openequivariance[jax]" + pip install -e ""./openequivariance_extjax" --no-build-isolation diff --git a/README.md b/README.md index 166a7e36..d085bb86 100644 --- a/README.md +++ b/README.md @@ -31,11 +31,12 @@ For detailed instructions on tests, benchmarks, MACE / Nequip, and our API, check out the [documentation](https://passionlab.github.io/OpenEquivariance). ⭐️ **JAX**: Our latest update brings -support for JAX. Install it with the following: +support for JAX. Install it with the following two commands +(strictly in order): ``` bash -pip install openequivariance -pip install openequivariance_extjax +pip install openequivariance[jax] +pip install openequivariance_extjax --no-build-isolation ``` See the section below for example usage and diff --git a/docs/installation.rst b/docs/installation.rst index 376b4863..d65a0995 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -58,13 +58,13 @@ installation (or upgrade) with pip. JAX ------------------------------------------ -JAX support is currently limited to NVIDIA GPUs. After -installing the main OpenEquivariance package, install -our JAX extension as follows: +JAX support is currently limited to NVIDIA GPUs. Install it by +executing the following two commands in order: .. code-block:: bash - pip install openequivariance_extjax + pip install openequivariance[jax] + pip install openequivariance_extjax --no-build-isolation From there, set ``OEQ_NOTORCH=1`` to avoid a PyTorch import and test the package: @@ -77,7 +77,8 @@ You can get the nightly build with teh following command: .. code-block:: bash - pip install git+https://github.com/PASSIONLab/OpenEquivariance#subdirectory=openequivariance_extjax + pip install "git+https://github.com/PASSIONLab/OpenEquivariance#subdirectory=openequivariance[jax]" + pip install "git+https://github.com/PASSIONLab/OpenEquivariance#subdirectory=openequivariance_extjax" Configurations on Major Platforms --------------------------------- diff --git a/openequivariance/pyproject.toml b/openequivariance/pyproject.toml index c55cffea..e2dd21de 100644 --- a/openequivariance/pyproject.toml +++ b/openequivariance/pyproject.toml @@ -62,6 +62,12 @@ dev = [ "sphinx-autobuild" ] +jax = [ + "nanobind", + "scikit-build-core", + "setuptools-scm" +] + [tool.setuptools.packages.find] include = ["openequivariance*"] diff --git a/openequivariance_extjax/CMakeLists.txt b/openequivariance_extjax/CMakeLists.txt index 7c7a9ab5..ec6d6f61 100644 --- a/openequivariance_extjax/CMakeLists.txt +++ b/openequivariance_extjax/CMakeLists.txt @@ -4,6 +4,7 @@ project(${SKBUILD_PROJECT_NAME} LANGUAGES CXX) # TODO: Add HIP support find_package(Python 3.11 REQUIRED COMPONENTS Interpreter Development.Module) find_package(CUDAToolkit REQUIRED) +# Would like this to avoid build isolation, but breaks ABI compat. #include(ExternalProject) #ExternalProject_Add( # xla @@ -19,6 +20,7 @@ find_package(CUDAToolkit REQUIRED) # ) #ExternalProject_Get_Property(xla source_dir) #set(XLA_DIR ${source_dir}) +#add_dependencies(openequivariance_extjax xla) execute_process( COMMAND "${Python_EXECUTABLE}" "-c" @@ -57,8 +59,6 @@ target_compile_options(openequivariance_extjax PRIVATE -Wno-attributes -Wno-retu get_target_property(CUDA_LIB_DIR CUDA::nvrtc IMPORTED_LOCATION) get_filename_component(CUDA_LIB_DIR ${CUDA_LIB_DIR} DIRECTORY) -#add_dependencies(openequivariance_extjax xla) - set_target_properties(openequivariance_extjax PROPERTIES BUILD_RPATH "${CUDA_LIB_DIR}" INSTALL_RPATH "${CUDA_LIB_DIR}" @@ -69,9 +69,4 @@ target_link_libraries(openequivariance_extjax PRIVATE CUDA::cuda_driver CUDA::nvrtc) -install(TARGETS openequivariance_extjax LIBRARY DESTINATION .) - -# ------------------------------------------------------------- - -# Uncomment to include XLA from JAX installation -#message(STATUS "XLA include directory: ${XLA_DIR}") \ No newline at end of file +install(TARGETS openequivariance_extjax LIBRARY DESTINATION .) \ No newline at end of file From 94490dd2c9ced0a89ef258255a2747088ff80ddf Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 4 Jan 2026 13:24:55 -0800 Subject: [PATCH 099/116] Updated README. --- README.md | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index d085bb86..f701905a 100644 --- a/README.md +++ b/README.md @@ -151,6 +151,11 @@ import openequivariance as oeq seed = 42 key = jax.random.PRNGKey(seed) +batch_size = 1000 +X_ir, Y_ir, Z_ir = oeq.Irreps("1x2e"), oeq.Irreps("1x3e"), oeq.Irreps("1x2e") +problem = oeq.TPProblem(X_ir, Y_ir, Z_ir, [(0, 0, 0, "uvu", True)], shared_weights=False, internal_weights=False) + + node_ct, nonzero_ct = 3, 4 edge_index = jax.numpy.array( [ @@ -166,8 +171,6 @@ Y = jax.random.uniform(key, shape=(nonzero_ct, Y_ir.dim), W = jax.random.uniform(key, shape=(nonzero_ct, problem.weight_numel), minval=0.0, maxval=1.0, dtype=jax.numpy.float32) -# Reuse problem from earlier -# ... tp_conv = oeq.jax.TensorProductConv(problem, deterministic=False) Z = tp_conv.forward( X, Y, W, edge_index[0], edge_index[1] From aea13514d363a9991811e3c0403fab8938af2bbb Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 4 Jan 2026 13:38:58 -0800 Subject: [PATCH 100/116] Updated error handling --- openequivariance/openequivariance/__init__.py | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/openequivariance/openequivariance/__init__.py b/openequivariance/openequivariance/__init__.py index b1a13325..1603ec95 100644 --- a/openequivariance/openequivariance/__init__.py +++ b/openequivariance/openequivariance/__init__.py @@ -56,6 +56,14 @@ def _check_package_editable(): ] ) + from openequivariance._torch.extlib import ( + LINKED_LIBPYTHON, + LINKED_LIBPYTHON_ERROR, + BUILT_EXTENSION, + BUILT_EXTENSION_ERROR, + TORCH_COMPILE, + TORCH_COMPILE_ERROR, + ) def torch_ext_so_path(): """ @@ -72,8 +80,16 @@ def torch_ext_so_path(): try: import openequivariance_extjax import openequivariance.jax as jax -except Exception: - pass +except Exception as e: + error = e + class JAX_ERR: + def TensorProduct(*args, **kwargs): + raise error + + def TensorProductConv(*args, **kwargs): + raise error + + jax = JAX_ERR() __all__ = [ "TPProblem", From 903d597393bc8a6fd308d7d49662b3510d69d149 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 4 Jan 2026 13:45:26 -0800 Subject: [PATCH 101/116] Last bit of cleanup. --- .github/workflows/verify_extension_build.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/verify_extension_build.yml b/.github/workflows/verify_extension_build.yml index 73bd2eb1..f9367a32 100644 --- a/.github/workflows/verify_extension_build.yml +++ b/.github/workflows/verify_extension_build.yml @@ -37,7 +37,7 @@ jobs: tests/import_test.py::test_extension_built \ tests/import_test.py::test_torch_extension_built - - name: Test JAX extension build - run: | - pip install -e "./openequivariance[jax]" - pip install -e ""./openequivariance_extjax" --no-build-isolation + #- name: Test JAX extension build + # run: | + # pip install -e "./openequivariance[jax]" + # pip install -e ""./openequivariance_extjax" --no-build-isolation From 656887618775d902fedd58c312467ec37b2d9a5d Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 4 Jan 2026 13:53:55 -0800 Subject: [PATCH 102/116] Ruff. --- openequivariance/openequivariance/__init__.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/openequivariance/openequivariance/__init__.py b/openequivariance/openequivariance/__init__.py index 1603ec95..a842a7c9 100644 --- a/openequivariance/openequivariance/__init__.py +++ b/openequivariance/openequivariance/__init__.py @@ -65,6 +65,7 @@ def _check_package_editable(): TORCH_COMPILE_ERROR, ) + def torch_ext_so_path(): """ :returns: Path to a ``.so`` file that must be linked to use OpenEquivariance @@ -82,13 +83,14 @@ def torch_ext_so_path(): import openequivariance.jax as jax except Exception as e: error = e + class JAX_ERR: def TensorProduct(*args, **kwargs): raise error - + def TensorProductConv(*args, **kwargs): raise error - + jax = JAX_ERR() __all__ = [ From 24dbb1171f7ee2243cf31ea8a187936e87d3726d Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 5 Jan 2026 02:17:43 -0500 Subject: [PATCH 103/116] Things working for HIP, just need to branch. --- .../openequivariance/jax/extlib/__init__.py | 15 +++-- openequivariance_extjax/CMakeLists.txt | 10 ++-- openequivariance_extjax/src/libjax_tp_jit.cpp | 60 ++++++++++++------- 3 files changed, 52 insertions(+), 33 deletions(-) diff --git a/openequivariance/openequivariance/jax/extlib/__init__.py b/openequivariance/openequivariance/jax/extlib/__init__.py index 8719e848..899dab8d 100644 --- a/openequivariance/openequivariance/jax/extlib/__init__.py +++ b/openequivariance/openequivariance/jax/extlib/__init__.py @@ -2,15 +2,20 @@ import openequivariance_extjax as oeq_extjax +#def postprocess_kernel(kernel): +# """ +# Only CUDA for now, so no postprocessing. +# """ +# return kernel + def postprocess_kernel(kernel): - """ - Only CUDA for now, so no postprocessing. - """ + kernel = kernel.replace("__syncwarp();", "__threadfence_block();") + kernel = kernel.replace("__shfl_down_sync(FULL_MASK,", "__shfl_down(") + kernel = kernel.replace("atomicAdd", "unsafeAtomicAdd") return kernel - for name, target in oeq_extjax.registrations().items(): - jax.ffi.register_ffi_target(name, target, platform="CUDA") + jax.ffi.register_ffi_target(name, target, platform="ROCM") GPUTimer = oeq_extjax.GPUTimer DeviceProp = oeq_extjax.DeviceProp diff --git a/openequivariance_extjax/CMakeLists.txt b/openequivariance_extjax/CMakeLists.txt index 21f0e9d4..5f821247 100644 --- a/openequivariance_extjax/CMakeLists.txt +++ b/openequivariance_extjax/CMakeLists.txt @@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.15...3.30) project(${SKBUILD_PROJECT_NAME} LANGUAGES CXX) # TODO: Add HIP support find_package(Python 3.10 REQUIRED COMPONENTS Interpreter Development.Module) -find_package(CUDAToolkit REQUIRED) +find_package(hip REQUIRED) # Would like this to avoid build isolation, but breaks ABI compat. #include(ExternalProject) @@ -56,8 +56,8 @@ target_include_directories(openequivariance_extjax PUBLIC ${XLA_DIR} ${HEADER_DI set_target_properties(openequivariance_extjax PROPERTIES CUDA_STANDARD 17 POSITION_INDEPENDENT_CODE ON) target_compile_options(openequivariance_extjax PRIVATE -Wno-attributes -Wno-return-type) -get_target_property(CUDA_LIB_DIR CUDA::nvrtc IMPORTED_LOCATION) -get_filename_component(CUDA_LIB_DIR ${CUDA_LIB_DIR} DIRECTORY) +#get_target_property(CUDA_LIB_DIR CUDA::nvrtc IMPORTED_LOCATION) +#get_filename_component(CUDA_LIB_DIR ${CUDA_LIB_DIR} DIRECTORY) set_target_properties(openequivariance_extjax PROPERTIES BUILD_RPATH "${CUDA_LIB_DIR}" @@ -65,8 +65,6 @@ set_target_properties(openequivariance_extjax PROPERTIES ) target_link_libraries(openequivariance_extjax PRIVATE - CUDA::cudart - CUDA::cuda_driver - CUDA::nvrtc) + hiprtc) install(TARGETS openequivariance_extjax LIBRARY DESTINATION .) diff --git a/openequivariance_extjax/src/libjax_tp_jit.cpp b/openequivariance_extjax/src/libjax_tp_jit.cpp index 7342eb01..af83f442 100644 --- a/openequivariance_extjax/src/libjax_tp_jit.cpp +++ b/openequivariance_extjax/src/libjax_tp_jit.cpp @@ -5,8 +5,6 @@ #include #include #include -#include -#include #include "nanobind/nanobind.h" #include "xla/ffi/api/ffi.h" @@ -14,9 +12,12 @@ namespace nb = nanobind; namespace ffi = xla::ffi; -#define CUDA_BACKEND // Stick to CUDA for now +#define HIP_BACKEND #ifdef CUDA_BACKEND + #include + #include + #include "util/backend_cuda.hpp" #include "group_mm_cuda.hpp" using JITKernel = CUJITKernel; @@ -24,6 +25,18 @@ namespace ffi = xla::ffi; template using GroupMM = GroupMMCUDA; + using stream_t = cudaStream_t; +#endif + +#ifdef HIP_BACKEND + #include "util/backend_hip.hpp" + #include "group_mm_hip.hpp" + using JITKernel = HIPJITKernel; + using GPU_Allocator = HIP_Allocator; + + template + using GroupMM = GroupMMHIP; + using stream_t = hipStream_t; #endif #include "tensorproducts.hpp" @@ -49,16 +62,10 @@ std::string xla_dtype_to_string(xla::ffi::DataType dtype) { const std::unordered_map map = { {xla::ffi::DataType::INVALID, "INVALID"}, {xla::ffi::DataType::PRED, "PRED"}, - {xla::ffi::DataType::S1, "S1"}, - {xla::ffi::DataType::S2, "S2"}, - {xla::ffi::DataType::S4, "S4"}, {xla::ffi::DataType::S8, "S8"}, {xla::ffi::DataType::S16, "S16"}, {xla::ffi::DataType::S32, "S32"}, {xla::ffi::DataType::S64, "S64"}, - {xla::ffi::DataType::U1, "U1"}, - {xla::ffi::DataType::U2, "U2"}, - {xla::ffi::DataType::U4, "U4"}, {xla::ffi::DataType::U8, "U8"}, {xla::ffi::DataType::U16, "U16"}, {xla::ffi::DataType::U32, "U32"}, @@ -108,7 +115,7 @@ inline int byte_count(ffi::AnyBuffer &buffer) { } #ifdef CUDA_BACKEND -void zero_buffer(ffi::AnyBuffer &buffer, cudaStream_t stream) { +void zero_buffer(ffi::AnyBuffer &buffer, stream_t stream) { cudaMemsetAsync( data_ptr(buffer), 0, @@ -116,6 +123,15 @@ void zero_buffer(ffi::AnyBuffer &buffer, cudaStream_t stream) { stream); } #endif +#ifdef HIP_BACKEND +void zero_buffer(ffi::AnyBuffer &buffer, stream_t stream) { + std::ignore = hipMemsetAsync( + data_ptr(buffer), + 0, + buffer.element_count() * byte_count(buffer), + stream); +} +#endif struct KernelProp { int64_t L1_dim, L2_dim, L3_dim, weight_numel; @@ -286,7 +302,7 @@ ffi::Error tp_forward_impl( ffi::AnyBuffer L2_in, ffi::AnyBuffer W, ffi::Result L3_out, - cudaStream_t stream, + stream_t stream, std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, int64_t hash) { @@ -321,7 +337,7 @@ ffi::Error tp_backward_impl( ffi::Result L1_grad, ffi::Result L2_grad, ffi::Result W_grad, - cudaStream_t stream, + stream_t stream, std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, int64_t hash) { @@ -371,7 +387,7 @@ ffi::Error tp_double_backward_impl( ffi::Result L2_grad, ffi::Result W_grad, ffi::Result L3_dgrad, - cudaStream_t stream, + stream_t stream, std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, int64_t hash) { @@ -420,7 +436,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Arg() .Arg() .Ret() - .Ctx>() + .Ctx>() .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") .Attr("hash"), {xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled @@ -435,7 +451,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Ret() .Ret() .Ret() - .Ctx>() + .Ctx>() .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") .Attr("hash"), {xla::ffi::Traits::kCmdBufferCompatible}); @@ -454,7 +470,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Ret() .Ret() .Ret() - .Ctx>() + .Ctx>() .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") .Attr("hash"), {xla::ffi::Traits::kCmdBufferCompatible}); @@ -469,7 +485,7 @@ ffi::Error conv_forward_impl( ffi::AnyBuffer workspace, ffi::AnyBuffer transpose_perm, ffi::Result L3_out, - cudaStream_t stream, + stream_t stream, std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, int64_t hash) { @@ -524,7 +540,7 @@ ffi::Error conv_backward_impl( ffi::AnyBuffer cols, ffi::AnyBuffer workspace, ffi::AnyBuffer transpose_perm, - cudaStream_t stream, + stream_t stream, std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, int64_t hash) { @@ -593,7 +609,7 @@ ffi::Error conv_double_backward_impl( ffi::AnyBuffer cols, ffi::AnyBuffer workspace, ffi::AnyBuffer transpose_perm, - cudaStream_t stream, + stream_t stream, std::string_view kernel, ffi::Dictionary forward_config, ffi::Dictionary backward_config, ffi::Dictionary double_backward_config, ffi::Dictionary kernel_prop, int64_t hash) { @@ -664,7 +680,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Arg() .Arg() .Ret() - .Ctx>() + .Ctx>() .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") .Attr("hash"), {xla::ffi::Traits::kCmdBufferCompatible}); @@ -683,7 +699,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Arg() .Arg() .Arg() - .Ctx>() + .Ctx>() .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") .Attr("hash"), {xla::ffi::Traits::kCmdBufferCompatible}); @@ -706,7 +722,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( .Arg() .Arg() .Arg() - .Ctx>() + .Ctx>() .Attr("kernel").Attr("forward_config").Attr("backward_config").Attr("double_backward_config").Attr("kernel_prop") .Attr("hash"), {xla::ffi::Traits::kCmdBufferCompatible}); From 8c5eadf6905539a834d8b20353ea9b109c6379d0 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 5 Jan 2026 02:51:24 -0500 Subject: [PATCH 104/116] Updated CMakeLists. --- openequivariance_extjax/CMakeLists.txt | 62 +++++++++++++------------- 1 file changed, 30 insertions(+), 32 deletions(-) diff --git a/openequivariance_extjax/CMakeLists.txt b/openequivariance_extjax/CMakeLists.txt index 5f821247..fe9ece0d 100644 --- a/openequivariance_extjax/CMakeLists.txt +++ b/openequivariance_extjax/CMakeLists.txt @@ -1,26 +1,7 @@ cmake_minimum_required(VERSION 3.15...3.30) -project(${SKBUILD_PROJECT_NAME} LANGUAGES CXX) # TODO: Add HIP support +project(${SKBUILD_PROJECT_NAME} LANGUAGES CXX) find_package(Python 3.10 REQUIRED COMPONENTS Interpreter Development.Module) -find_package(hip REQUIRED) - -# Would like this to avoid build isolation, but breaks ABI compat. -#include(ExternalProject) -#ExternalProject_Add( -# xla -# PREFIX ${CMAKE_BINARY_DIR}/xla -# GIT_REPOSITORY https://github.com/openxla/xla.git -# GIT_TAG d56e645fe72988e2b4464119300ca6f894f82598 -# GIT_SHALLOW TRUE -# GIT_PROGRESS TRUE -# CONFIGURE_COMMAND "" -# BUILD_COMMAND "" -# INSTALL_COMMAND "" -# LOG_DOWNLOAD ON -# ) -#ExternalProject_Get_Property(xla source_dir) -#set(XLA_DIR ${source_dir}) -#add_dependencies(openequivariance_extjax xla) execute_process( COMMAND "${Python_EXECUTABLE}" "-c" @@ -34,10 +15,18 @@ execute_process( OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE nanobind_ROOT ) message(STATUS "nanobind cmake directory: ${nanobind_ROOT}") -set(HEADER_DIR "src/extension") find_package(nanobind CONFIG REQUIRED) +if(DJAX_HIP) + message(STATUS "DJAX_HIP is set. Building with HIP backend.") + find_package(hip REQUIRED) +else() + message(STATUS "DJAX_HIP is not set (or zero). Building with CUDA backend.") + find_package(CUDAToolkit REQUIRED) +endif() + +set(HEADER_DIR "src/extension") set(OEQ_JAX_SOURCES src/libjax_tp_jit.cpp ) @@ -51,20 +40,29 @@ set(OEQ_JAX_HEADERS ) nanobind_add_module(openequivariance_extjax NB_STATIC ${OEQ_JAX_SOURCES} ${OEQ_JAX_HEADERS}) - target_include_directories(openequivariance_extjax PUBLIC ${XLA_DIR} ${HEADER_DIR}) -set_target_properties(openequivariance_extjax PROPERTIES CUDA_STANDARD 17 POSITION_INDEPENDENT_CODE ON) +set_target_properties(openequivariance_extjax PROPERTIES POSITION_INDEPENDENT_CODE ON) target_compile_options(openequivariance_extjax PRIVATE -Wno-attributes -Wno-return-type) -#get_target_property(CUDA_LIB_DIR CUDA::nvrtc IMPORTED_LOCATION) -#get_filename_component(CUDA_LIB_DIR ${CUDA_LIB_DIR} DIRECTORY) +if(DJAX_HIP) + target_link_libraries(openequivariance_extjax PRIVATE hiprtc) + target_compile_definitions(openequivariance_extjax PRIVATE DJAX_HIP=1) -set_target_properties(openequivariance_extjax PROPERTIES - BUILD_RPATH "${CUDA_LIB_DIR}" - INSTALL_RPATH "${CUDA_LIB_DIR}" -) +else() + set_target_properties(openequivariance_extjax PROPERTIES CUDA_STANDARD 17) + + get_target_property(CUDA_LIB_DIR CUDA::nvrtc IMPORTED_LOCATION) + get_filename_component(CUDA_LIB_DIR ${CUDA_LIB_DIR} DIRECTORY) + + set_target_properties(openequivariance_extjax PROPERTIES + BUILD_RPATH "${CUDA_LIB_DIR}" + INSTALL_RPATH "${CUDA_LIB_DIR}" + ) -target_link_libraries(openequivariance_extjax PRIVATE - hiprtc) + target_link_libraries(openequivariance_extjax PRIVATE + CUDA::cudart + CUDA::cuda_driver + CUDA::nvrtc) +endif() -install(TARGETS openequivariance_extjax LIBRARY DESTINATION .) +install(TARGETS openequivariance_extjax LIBRARY DESTINATION .) \ No newline at end of file From 39ac3be65edbe6745529abc6be7dc2c9151237da Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 5 Jan 2026 02:52:43 -0500 Subject: [PATCH 105/116] Added pyproject.toml define. --- openequivariance_extjax/pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/openequivariance_extjax/pyproject.toml b/openequivariance_extjax/pyproject.toml index 41137b61..702040d9 100644 --- a/openequivariance_extjax/pyproject.toml +++ b/openequivariance_extjax/pyproject.toml @@ -37,6 +37,8 @@ homepage = "https://passionlab.github.io/OpenEquivariance/" source = "https://github.com/PASSIONLab/OpenEquivariance" issues = "https://github.com/PASSIONLab/OpenEquivariance/issues" +[tool.scikit-build.cmake.define] +JAX_HIP = {env="JAX_HIP", default="0"} [tool.setuptools_scm] root = ".." From c3d4666374ac11658f05d9a865ebee1f291c869a Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 5 Jan 2026 03:00:29 -0500 Subject: [PATCH 106/116] Plumbed logic. --- .../openequivariance/jax/extlib/__init__.py | 24 +++++++++---------- openequivariance_extjax/CMakeLists.txt | 3 ++- openequivariance_extjax/src/libjax_tp_jit.cpp | 13 ++++++++-- 3 files changed, 25 insertions(+), 15 deletions(-) diff --git a/openequivariance/openequivariance/jax/extlib/__init__.py b/openequivariance/openequivariance/jax/extlib/__init__.py index 899dab8d..66a2af42 100644 --- a/openequivariance/openequivariance/jax/extlib/__init__.py +++ b/openequivariance/openequivariance/jax/extlib/__init__.py @@ -1,21 +1,21 @@ import jax import openequivariance_extjax as oeq_extjax - -#def postprocess_kernel(kernel): -# """ -# Only CUDA for now, so no postprocessing. -# """ -# return kernel - def postprocess_kernel(kernel): - kernel = kernel.replace("__syncwarp();", "__threadfence_block();") - kernel = kernel.replace("__shfl_down_sync(FULL_MASK,", "__shfl_down(") - kernel = kernel.replace("atomicAdd", "unsafeAtomicAdd") - return kernel + if oeq_extjax.is_hip(): + kernel = kernel.replace("__syncwarp();", "__threadfence_block();") + kernel = kernel.replace("__shfl_down_sync(FULL_MASK,", "__shfl_down(") + kernel = kernel.replace("atomicAdd", "unsafeAtomicAdd") + return kernel + else: + return kernel + +platform = "CUDA" +if oeq_extjax.is_hip(): + platform = "ROCM" for name, target in oeq_extjax.registrations().items(): - jax.ffi.register_ffi_target(name, target, platform="ROCM") + jax.ffi.register_ffi_target(name, target, platform=platform) GPUTimer = oeq_extjax.GPUTimer DeviceProp = oeq_extjax.DeviceProp diff --git a/openequivariance_extjax/CMakeLists.txt b/openequivariance_extjax/CMakeLists.txt index fe9ece0d..d33275d8 100644 --- a/openequivariance_extjax/CMakeLists.txt +++ b/openequivariance_extjax/CMakeLists.txt @@ -46,7 +46,7 @@ target_compile_options(openequivariance_extjax PRIVATE -Wno-attributes -Wno-retu if(DJAX_HIP) target_link_libraries(openequivariance_extjax PRIVATE hiprtc) - target_compile_definitions(openequivariance_extjax PRIVATE DJAX_HIP=1) + target_compile_definitions(openequivariance_extjax PRIVATE HIP_BACKEND=1) else() set_target_properties(openequivariance_extjax PROPERTIES CUDA_STANDARD 17) @@ -63,6 +63,7 @@ else() CUDA::cudart CUDA::cuda_driver CUDA::nvrtc) + target_compile_definitions(openequivariance_extjax PRIVATE CUDA_BACKEND=1) endif() install(TARGETS openequivariance_extjax LIBRARY DESTINATION .) \ No newline at end of file diff --git a/openequivariance_extjax/src/libjax_tp_jit.cpp b/openequivariance_extjax/src/libjax_tp_jit.cpp index af83f442..ae2035e8 100644 --- a/openequivariance_extjax/src/libjax_tp_jit.cpp +++ b/openequivariance_extjax/src/libjax_tp_jit.cpp @@ -12,8 +12,6 @@ namespace nb = nanobind; namespace ffi = xla::ffi; -#define HIP_BACKEND - #ifdef CUDA_BACKEND #include #include @@ -669,6 +667,16 @@ ffi::Error conv_double_backward_impl( return ffi::Error::Success(); } +bool is_hip() { +#ifdef HIP_BACKEND + return true; +#else + return false; +#endif +} + +// --------------------- FFI Bindings -------------------------- + XLA_FFI_DEFINE_HANDLER_SYMBOL( conv_forward, conv_forward_impl, ffi::Ffi::Bind() @@ -740,6 +748,7 @@ NB_MODULE(openequivariance_extjax, m) { registrations["conv_double_backward"] = nb::capsule(reinterpret_cast(conv_double_backward)); return registrations; }); + m.def("is_hip", &is_hip); nb::class_(m, "DeviceProp") .def(nb::init()) From 9e6c56f9c628e29016acd886e9a13720860771f3 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 5 Jan 2026 03:06:22 -0500 Subject: [PATCH 107/116] Made things compile with HIP. --- openequivariance_extjax/CMakeLists.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/openequivariance_extjax/CMakeLists.txt b/openequivariance_extjax/CMakeLists.txt index d33275d8..f7e97e44 100644 --- a/openequivariance_extjax/CMakeLists.txt +++ b/openequivariance_extjax/CMakeLists.txt @@ -18,11 +18,11 @@ message(STATUS "nanobind cmake directory: ${nanobind_ROOT}") find_package(nanobind CONFIG REQUIRED) -if(DJAX_HIP) - message(STATUS "DJAX_HIP is set. Building with HIP backend.") +if(JAX_HIP) + message(STATUS "JAX_HIP is set. Building with HIP backend.") find_package(hip REQUIRED) else() - message(STATUS "DJAX_HIP is not set (or zero). Building with CUDA backend.") + message(STATUS "JAX_HIP is not set (or zero). Building with CUDA backend.") find_package(CUDAToolkit REQUIRED) endif() @@ -44,7 +44,7 @@ target_include_directories(openequivariance_extjax PUBLIC ${XLA_DIR} ${HEADER_DI set_target_properties(openequivariance_extjax PROPERTIES POSITION_INDEPENDENT_CODE ON) target_compile_options(openequivariance_extjax PRIVATE -Wno-attributes -Wno-return-type) -if(DJAX_HIP) +if(JAX_HIP) target_link_libraries(openequivariance_extjax PRIVATE hiprtc) target_compile_definitions(openequivariance_extjax PRIVATE HIP_BACKEND=1) From bf4bb89c90f25837498a4cb55e75a11f60e2fa39 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 5 Jan 2026 03:09:21 -0500 Subject: [PATCH 108/116] Updated READMEs. --- README.md | 9 ++++++++- docs/installation.rst | 12 ++++++++++-- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index f701905a..3e4cf6bc 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,8 @@ For detailed instructions on tests, benchmarks, MACE / Nequip, and our API, check out the [documentation](https://passionlab.github.io/OpenEquivariance). ⭐️ **JAX**: Our latest update brings -support for JAX. Install it with the following two commands +support for JAX. For NVIDIA GPUs, +install it with the following two commands (strictly in order): ``` bash @@ -39,6 +40,12 @@ pip install openequivariance[jax] pip install openequivariance_extjax --no-build-isolation ``` +For AMD GPUs: +``` bash +pip install openequivariance[jax] +JAX_HIP=1 pip install openequivariance_extjax --no-build-isolation +``` + See the section below for example usage and our [API page](https://passionlab.github.io/OpenEquivariance/api/) for more details. diff --git a/docs/installation.rst b/docs/installation.rst index d65a0995..4ab008d8 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -58,7 +58,7 @@ installation (or upgrade) with pip. JAX ------------------------------------------ -JAX support is currently limited to NVIDIA GPUs. Install it by +Install OpenEquivariance for NVIDIA GPUs by executing the following two commands in order: .. code-block:: bash @@ -66,6 +66,14 @@ executing the following two commands in order: pip install openequivariance[jax] pip install openequivariance_extjax --no-build-isolation + +For AMD GPUs, use + +.. code-block:: bash + + pip install openequivariance[jax] + JAX_HIP=1 pip install openequivariance_extjax --no-build-isolation + From there, set ``OEQ_NOTORCH=1`` to avoid a PyTorch import and test the package: .. code-block:: bash @@ -73,7 +81,7 @@ From there, set ``OEQ_NOTORCH=1`` to avoid a PyTorch import and test the package OEQ_NOTORCH=1 python -c "import openequivariance.jax" -You can get the nightly build with teh following command: +Likewise, you can get the nightly build with the following commands: .. code-block:: bash From abc409f3b0c6fda96e9be9beb58ae06b1eb53607 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 5 Jan 2026 03:18:14 -0500 Subject: [PATCH 109/116] Highlight AMD support in changelog. --- CHANGELOG.md | 3 ++- README.md | 4 ++-- docs/installation.rst | 9 +++++---- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 49aef9d5..e9bd29e1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,8 @@ ### v0.5.0 (2025-12-25) JAX support is now available in -OpenEquivariance for NVIDIA GPUs. See the +OpenEquivariance for BOTH NVIDIA and +AMD GPUs! See the [documentation](https://passionlab.github.io/OpenEquivariance/) and README.md for instructions on installation and usage. diff --git a/README.md b/README.md index 3e4cf6bc..e3fb7885 100644 --- a/README.md +++ b/README.md @@ -32,8 +32,8 @@ check out the [documentation](https://passionlab.github.io/OpenEquivariance). ⭐️ **JAX**: Our latest update brings support for JAX. For NVIDIA GPUs, -install it with the following two commands -(strictly in order): +install it (after installing JAX) +with the following two commands strictly in order: ``` bash pip install openequivariance[jax] diff --git a/docs/installation.rst b/docs/installation.rst index 4ab008d8..9b4d65d5 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -8,8 +8,8 @@ Installation (Torch and JAX) You need the following to install OpenEquivariance: - A Linux system equipped with an NVIDIA / AMD graphics card. -- Either PyTorch >= 2.4 (>= 2.8 for AOTI and export), or JAX with CUDA 12 support - or higher. +- Either PyTorch >= 2.4 (>= 2.8 for AOTI and export), or JAX>0.5.0 + with CUDA 12 support or higher. - GCC 9+ and the CUDA / HIP toolkit. The command ``c++ --version`` should return >= 9.0; see below for details on setting an alternate compiler. @@ -58,8 +58,9 @@ installation (or upgrade) with pip. JAX ------------------------------------------ -Install OpenEquivariance for NVIDIA GPUs by -executing the following two commands in order: +Before starting, ensure the appropriate JAX Python +package is installed in your environment. Then +run the following two commands stricly in order: .. code-block:: bash From 77ef6a9261adf1da777be2684ce8b15d90fabd81 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 5 Jan 2026 03:19:14 -0500 Subject: [PATCH 110/116] Ruff. --- openequivariance/openequivariance/jax/extlib/__init__.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/openequivariance/openequivariance/jax/extlib/__init__.py b/openequivariance/openequivariance/jax/extlib/__init__.py index 66a2af42..d0965502 100644 --- a/openequivariance/openequivariance/jax/extlib/__init__.py +++ b/openequivariance/openequivariance/jax/extlib/__init__.py @@ -1,6 +1,7 @@ import jax import openequivariance_extjax as oeq_extjax + def postprocess_kernel(kernel): if oeq_extjax.is_hip(): kernel = kernel.replace("__syncwarp();", "__threadfence_block();") @@ -10,9 +11,10 @@ def postprocess_kernel(kernel): else: return kernel -platform = "CUDA" + +platform = "CUDA" if oeq_extjax.is_hip(): - platform = "ROCM" + platform = "ROCM" for name, target in oeq_extjax.registrations().items(): jax.ffi.register_ffi_target(name, target, platform=platform) From 0eeec5847ab41d35dde6523793fb43bb2809bc27 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 5 Jan 2026 20:59:33 -0800 Subject: [PATCH 111/116] Updated documentation. --- .github/workflows/docs.yaml | 2 +- docs/api.rst | 4 +- docs/conf.py | 1 + docs/installation.rst | 167 +++++++++++++++++--------------- docs/requirements.txt | 4 + docs/tests_and_benchmarks.rst | 47 +++++---- openequivariance/pyproject.toml | 5 +- 7 files changed, 125 insertions(+), 105 deletions(-) create mode 100644 docs/requirements.txt diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yaml index 49e1cb04..fa6c7363 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yaml @@ -23,7 +23,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install sphinx furo + pip install -r docs/requirements.txt - name: Build website run: | sphinx-build -M dirhtml docs docs/_build diff --git a/docs/api.rst b/docs/api.rst index e268b160..c21b918f 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -30,9 +30,9 @@ PyTorch API :undoc-members: :exclude-members: name -.. autofunction:: openequivariance._torch_to_oeq_dtype +.. autofunction:: openequivariance.torch_to_oeq_dtype -.. autofunction:: openequivariance._torch_ext_so_path +.. autofunction:: openequivariance.torch_ext_so_path JAX API ------------------------ diff --git a/docs/conf.py b/docs/conf.py index df3b7636..adaeb2db 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -30,6 +30,7 @@ extensions = [ "sphinx.ext.autodoc", + "sphinx_inline_tabs" ] sys.path.insert(0, str(Path("../openequivariance").resolve())) diff --git a/docs/installation.rst b/docs/installation.rst index 9b4d65d5..d7188b77 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -1,4 +1,4 @@ -Installation (Torch and JAX) +Installation ============================== .. toctree:: @@ -14,122 +14,133 @@ You need the following to install OpenEquivariance: ``c++ --version`` should return >= 9.0; see below for details on setting an alternate compiler. -PyTorch ------------------------------------------- +.. tab:: PyTorch -Installation is one easy command, followed by import verification: + Installation is one easy command, followed by import verification: -.. code-block:: bash + .. code-block:: bash - pip install openequivariance - python -c "import openequivariance" + pip install openequivariance + python -c "import openequivariance" -The second line triggers a build of the C++ extension we use to compile -kernels, which can take a couple of minutes. Subsequent imports are -much faster since this extension is cached. + The second line triggers a build of the C++ extension we use to compile + kernels, which can take a couple of minutes. Subsequent imports are + much faster since this extension is cached. -To get the nightly build, run + To support ``torch.compile``, ``torch.export``, and + JITScript, OpenEquivariance needs to compile a C++ extension + tightly integrated with PyTorch. If you see a warning that + this extension could not be compiled, first check: -.. code-block:: bash + .. code-block:: bash - pip install git+https://github.com/PASSIONLab/OpenEquivariance#subdirectory=openequivariance + c++ --version + + To build the extension with an alternate compiler, set the + ``CC`` and ``CXX`` + environment variable and retry the import: -To support ``torch.compile``, ``torch.export``, and -JITScript, OpenEquivariance needs to compile a C++ extension -tightly integrated with PyTorch. If you see a warning that -this extension could not be compiled, first check: + .. code-block:: bash -.. code-block:: bash + export CC=/path/to/your/gcc + export CXX=/path/to/your/g++ + python -c "import openequivariance" - c++ --version - -To build the extension with an alternate compiler, set the -``CC`` and ``CXX`` -environment variable and retry the import: -.. code-block:: bash + These configuration steps are required only ONCE after + installation (or upgrade) with pip. - export CC=/path/to/your/gcc - export CXX=/path/to/your/g++ - python -c "import openequivariance" -These configuration steps are required only ONCE after -installation (or upgrade) with pip. +.. tab:: JAX NVIDIA GPUs -JAX ------------------------------------------- -Before starting, ensure the appropriate JAX Python -package is installed in your environment. Then -run the following two commands stricly in order: + First ensure the appropriate JAX Python + package is installed in your environment. Then + run the following two commands stricly in order: -.. code-block:: bash + .. code-block:: bash - pip install openequivariance[jax] - pip install openequivariance_extjax --no-build-isolation + pip install openequivariance[jax] + pip install openequivariance_extjax --no-build-isolation +.. tab:: JAX AMD GPUs -For AMD GPUs, use + Ensure that JAX is installed correctly with RocM support + before running, in order, -.. code-block:: bash + .. code-block:: bash - pip install openequivariance[jax] - JAX_HIP=1 pip install openequivariance_extjax --no-build-isolation + pip install openequivariance[jax] + JAX_HIP=1 pip install openequivariance_extjax --no-build-isolation -From there, set ``OEQ_NOTORCH=1`` to avoid a PyTorch import and test the package: -.. code-block:: bash +.. tab:: Nightly (PT) + + .. code-block:: bash + + pip install "git+https://github.com/PASSIONLab/OpenEquivariance#subdirectory=openequivariance" - OEQ_NOTORCH=1 - python -c "import openequivariance.jax" -Likewise, you can get the nightly build with the following commands: +.. tab:: Nightly (JAX) -.. code-block:: bash + .. code-block:: bash + + pip install "git+https://github.com/PASSIONLab/OpenEquivariance#subdirectory=openequivariance[jax]" + pip install "git+https://github.com/PASSIONLab/OpenEquivariance#subdirectory=openequivariance_extjax --no-build-isolation" + + # Use the command below for JAX+AMD + # JAX_HIP=1 pip install "git+https://github.com/PASSIONLab/OpenEquivariance#subdirectory=openequivariance_extjax --no-build-isolation" + + +If you're using JAX, set the environment variable +``OEQ_NOTORCH=1`` to avoid a PyTorch import: + +.. code-block:: bash + + export OEQ_NOTORCH=1 + python -c "import openequivariance.jax" - pip install "git+https://github.com/PASSIONLab/OpenEquivariance#subdirectory=openequivariance[jax]" - pip install "git+https://github.com/PASSIONLab/OpenEquivariance#subdirectory=openequivariance_extjax" Configurations on Major Platforms --------------------------------- OpenEquivariance has been tested on both supercomputers and lab clusters. -Here are some tested environment configuration files. If use OpenEquivariance -on a widely-used platform, send us a pull request to add your configuration! +Here are some tested environment configuration files. If you use OpenEquivariance +on a major cluster, send us a pull request to add your configuration! -NERSC Perlmutter (NVIDIA A100) -"""""""""""""""""""""""""""""" -.. code-block:: bash - :caption: env.sh (last updated June 2025) +.. tab:: NERSC Perlmutter (NVIDIA A100) - module load gcc - module load conda + .. code-block:: bash + :caption: env.sh (last updated June 2025) - # Deactivate any base environments - for i in $(seq ${CONDA_SHLVL}); do - conda deactivate - done + module load gcc + module load conda - conda activate + # Deactivate any base environments + for i in $(seq ${CONDA_SHLVL}); do + conda deactivate + done + conda activate -OLCF Frontier (AMD MI250x) -"""""""""""""""""""""""""" -You need to install a HIP-enabled verison of PyTorch to use our package. -To do this, follow the steps `here `_. +.. tab:: OLCF Frontier (AMD MI250x) -.. code-block:: bash - :caption: env.sh (last updated June 2025) + You need to install a HIP-enabled verison of PyTorch to use our package. + Follow the steps `here `_. + + + .. code-block:: bash + :caption: env.sh (last updated June 2025) - module load PrgEnv-gnu/8.6.0 - module load miniforge3/23.11.0-0 - module load rocm/6.4.0 - module load craype-accel-amd-gfx90a + module load PrgEnv-gnu/8.6.0 + module load miniforge3/23.11.0-0 + module load rocm/6.4.0 + module load craype-accel-amd-gfx90a - for i in $(seq ${CONDA_SHLVL}); do - conda deactivate - done + for i in $(seq ${CONDA_SHLVL}); do + conda deactivate + done - conda activate - export CC=cc - export CXX=CC \ No newline at end of file + conda activate + export CC=cc + export CXX=CC \ No newline at end of file diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 00000000..1cc76517 --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,4 @@ +furo +sphinx +sphinx-inline-tabs +sphinx-autobuild \ No newline at end of file diff --git a/docs/tests_and_benchmarks.rst b/docs/tests_and_benchmarks.rst index 0799a033..f602ab44 100644 --- a/docs/tests_and_benchmarks.rst +++ b/docs/tests_and_benchmarks.rst @@ -12,33 +12,39 @@ these; we provide instructions below. We recommend you clone our repository and use an editable install to run tests and benchmarks. -You can still test our code with a non-editable install; just -download the test folder and install the non-editable package and the dependencies with: +Correctness +------------------------------ +To set up an editable install and run our tests, use the following code: -.. code-block:: bash +.. tab:: PyTorch - pip install openequivariance[dev,bench] + .. code-block:: bash -Correctness ------------------------------- -To set up the editable install and run the entire PyTorch testsuite, use: + git clone https://github.com/PASSIONLab/OpenEquivariance + cd OpenEquivariance + pip install -e "./openequivariance[dev]" + pytest tests/ -.. code-block:: bash +.. tab:: JAX - git clone https://github.com/PASSIONLab/OpenEquivariance - cd OpenEquivariance - pip install -e .[dev] - pytest + Note: To test correctness in JAX, we still require + an installation of PyTorch and e3nn in your environment. -Browse the ``tests`` directory to run specific components. + .. code-block:: bash -To test the JAX wrappers, follow the same steps above and make sure that -``openequivariance_extjax`` is installed without build isolation. Then run + git clone https://github.com/PASSIONLab/OpenEquivariance + cd OpenEquivariance + + pip install "./openequivariance[jax]" + pip install "./openequivariance[dev]" + pip install "./openequivariance_extjax" --no-build-isolation + + pytest --jax tests/example_test.py + pytest --jax tests/batch_test.py + pytest --jax tests/conv_test.py + +Browse the ``tests`` directory to run specific components. -.. code-block:: bash - pytest --jax tests/example_test.py - pytest --jax tests/batch_test.py - pytest --jax tests/conv_test.py Replicating our Benchmarks @@ -87,12 +93,13 @@ OpenEquivariance exhibits up to 2x speedup over FlashTP's fused kernels. List of GPUs Tested -------------------------------- -OpenEquivariance has been tested successfully the following GPUs. Submit a pull +OpenEquivariance runs successfully the following GPUs. Submit a pull request if you'd like to add your own! - NVIDIA V100 (V. Bharadwaj, LBNL Einsteinium, June 2025) - NVIDIA A100-SXM-40GB and A100-SXM-80GB (A. Glover, NERSC Perlmutter, June 2025) - NVIDIA A5000 (V. Bharadwaj, UCB SLICE, June 2025) +- NVIDIA T4 (V. Bharadwaj, Google Colab, Jan 2026) - NVIDIA H100 (L. Larsen, P1 DTU HPC, June 2025) - AMD MI250x (V. Bharadwaj, OLCF Frontier, June 2025) - AMD MI300x (V. Bharadwaj, AMD Cloud, February 2025) \ No newline at end of file diff --git a/openequivariance/pyproject.toml b/openequivariance/pyproject.toml index e2dd21de..a0ddd618 100644 --- a/openequivariance/pyproject.toml +++ b/openequivariance/pyproject.toml @@ -56,10 +56,7 @@ dev = [ "pytest-check", "pytest-subtests", "torch_geometric", - "cmake", - "furo", - "sphinx", - "sphinx-autobuild" + "cmake" ] jax = [ From 90b34055c183a061ec55e168f7914f643fa1e359 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 5 Jan 2026 21:00:53 -0800 Subject: [PATCH 112/116] Updated installation instructions. --- docs/installation.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/installation.rst b/docs/installation.rst index d7188b77..5ade5c0c 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -9,7 +9,7 @@ You need the following to install OpenEquivariance: - A Linux system equipped with an NVIDIA / AMD graphics card. - Either PyTorch >= 2.4 (>= 2.8 for AOTI and export), or JAX>0.5.0 - with CUDA 12 support or higher. + with CUDA or RocM support. - GCC 9+ and the CUDA / HIP toolkit. The command ``c++ --version`` should return >= 9.0; see below for details on setting an alternate compiler. From 4acd545055e3ae0fdd6b92a90507fd0d78eea027 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 5 Jan 2026 21:04:49 -0800 Subject: [PATCH 113/116] More ruff. --- docs/conf.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index adaeb2db..540cf37e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -28,10 +28,7 @@ html_theme = "furo" # html_static_path = ["_static"] -extensions = [ - "sphinx.ext.autodoc", - "sphinx_inline_tabs" -] +extensions = ["sphinx.ext.autodoc", "sphinx_inline_tabs"] sys.path.insert(0, str(Path("../openequivariance").resolve())) From 6f61ee90c5c3bb74ca1fe98c1f5409aef335c67a Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Thu, 15 Jan 2026 18:50:25 -0800 Subject: [PATCH 114/116] Added option for CI. --- openequivariance_extjax/CMakeLists.txt | 38 ++++++++++++++++++++++---- openequivariance_extjax/pyproject.toml | 1 + 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/openequivariance_extjax/CMakeLists.txt b/openequivariance_extjax/CMakeLists.txt index f7e97e44..25fec285 100644 --- a/openequivariance_extjax/CMakeLists.txt +++ b/openequivariance_extjax/CMakeLists.txt @@ -3,12 +3,35 @@ project(${SKBUILD_PROJECT_NAME} LANGUAGES CXX) find_package(Python 3.10 REQUIRED COMPONENTS Interpreter Development.Module) -execute_process( - COMMAND "${Python_EXECUTABLE}" "-c" - "from jax import ffi; print(ffi.include_dir())" - OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE XLA_DIR -) +# --- XLA CONFIGURATION --- +if(XLA_DIRECT_DOWNLOAD) + message(STATUS "XLA_DIRECT_DOWNLOAD is ON. Fetching XLA source...") + include(ExternalProject) + ExternalProject_Add( + xla + PREFIX ${CMAKE_BINARY_DIR}/xla + GIT_REPOSITORY https://github.com/openxla/xla.git + GIT_TAG main + GIT_SHALLOW TRUE + GIT_PROGRESS TRUE + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + LOG_DOWNLOAD ON + ) + ExternalProject_Get_Property(xla source_dir) + set(XLA_DIR ${source_dir}) +else() + message(STATUS "XLA_DIRECT_DOWNLOAD is OFF. Locating XLA via installed JAX...") + execute_process( + COMMAND "${Python_EXECUTABLE}" "-c" + "from jax import ffi; print(ffi.include_dir())" + OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE XLA_DIR + ) +endif() + message(STATUS "XLA include directory: ${XLA_DIR}") +# ------------------------- execute_process( COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir @@ -44,6 +67,11 @@ target_include_directories(openequivariance_extjax PUBLIC ${XLA_DIR} ${HEADER_DI set_target_properties(openequivariance_extjax PROPERTIES POSITION_INDEPENDENT_CODE ON) target_compile_options(openequivariance_extjax PRIVATE -Wno-attributes -Wno-return-type) +# Ensure the module waits for XLA download if we are in direct download mode +if(XLA_DIRECT_DOWNLOAD) + add_dependencies(openequivariance_extjax xla) +endif() + if(JAX_HIP) target_link_libraries(openequivariance_extjax PRIVATE hiprtc) target_compile_definitions(openequivariance_extjax PRIVATE HIP_BACKEND=1) diff --git a/openequivariance_extjax/pyproject.toml b/openequivariance_extjax/pyproject.toml index 702040d9..47a21179 100644 --- a/openequivariance_extjax/pyproject.toml +++ b/openequivariance_extjax/pyproject.toml @@ -39,6 +39,7 @@ issues = "https://github.com/PASSIONLab/OpenEquivariance/issues" [tool.scikit-build.cmake.define] JAX_HIP = {env="JAX_HIP", default="0"} +XLA_DIRECT_DOWNLOAD = {env="XLA_DIRECT_DOWNLOAD", default="0"} [tool.setuptools_scm] root = ".." From cfe9b674dd6bdb9973948ad57e8bc898c2668513 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Thu, 15 Jan 2026 19:06:04 -0800 Subject: [PATCH 115/116] Ready to go. --- .github/workflows/release.yaml | 4 ---- .github/workflows/verify_extension_build.yml | 11 +++++------ README.md | 2 +- 3 files changed, 6 insertions(+), 11 deletions(-) diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 2881158a..2f351dba 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -29,9 +29,7 @@ jobs: pypi-publish: name: Upload release to PyPI runs-on: ubuntu-latest - # build task to be completed first needs: build-oeq - # Specifying a GitHub environment is optional, but strongly encouraged environment: name: pypi url: https://pypi.org/p/openequivariance @@ -72,9 +70,7 @@ jobs: pypi-publish-extjax: name: Upload release to PyPI runs-on: ubuntu-latest - # build task to be completed first needs: build-oeq-extjax - # Specifying a GitHub environment is optional, but strongly encouraged environment: name: pypi url: https://pypi.org/p/openequivariance_extjax diff --git a/.github/workflows/verify_extension_build.yml b/.github/workflows/verify_extension_build.yml index f9367a32..8a097281 100644 --- a/.github/workflows/verify_extension_build.yml +++ b/.github/workflows/verify_extension_build.yml @@ -1,4 +1,4 @@ -name: OEQ CUDA C++ Extension Build Verification +name: OEQ C++ Extension Build Verification on: push: @@ -31,13 +31,12 @@ jobs: pip install -r .github/workflows/requirements_cuda_ci.txt pip install -e "./openequivariance" - - name: Test extension build via import + - name: Test CUDA extension build via import run: | pytest \ tests/import_test.py::test_extension_built \ tests/import_test.py::test_torch_extension_built - #- name: Test JAX extension build - # run: | - # pip install -e "./openequivariance[jax]" - # pip install -e ""./openequivariance_extjax" --no-build-isolation + - name: Test JAX extension build + run: | + pip install -e "./openequivariance_extjax" --no-build-isolation \ No newline at end of file diff --git a/README.md b/README.md index e3fb7885..c68e4fa9 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ # OpenEquivariance -[![OEQ CUDA C++ Extension Build Verification](https://github.com/PASSIONLab/OpenEquivariance/actions/workflows/verify_extension_build.yml/badge.svg?event=push)](https://github.com/PASSIONLab/OpenEquivariance/actions/workflows/verify_extension_build.yml) +[![OEQ C++ Extension Build Verification](https://github.com/PASSIONLab/OpenEquivariance/actions/workflows/verify_extension_build.yml/badge.svg?event=push)](https://github.com/PASSIONLab/OpenEquivariance/actions/workflows/verify_extension_build.yml) [![License](https://img.shields.io/badge/License-BSD_3--Clause-blue.svg)](https://opensource.org/licenses/BSD-3-Clause) [[PyTorch Examples]](#pytorch-examples) From 09d62c5ee1de52541b614e13b59d6f7fb074f7b7 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Thu, 15 Jan 2026 19:09:31 -0800 Subject: [PATCH 116/116] Enabled direct download XLA. --- .github/workflows/verify_extension_build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/verify_extension_build.yml b/.github/workflows/verify_extension_build.yml index 8a097281..39888491 100644 --- a/.github/workflows/verify_extension_build.yml +++ b/.github/workflows/verify_extension_build.yml @@ -39,4 +39,4 @@ jobs: - name: Test JAX extension build run: | - pip install -e "./openequivariance_extjax" --no-build-isolation \ No newline at end of file + XLA_DIRECT_DOWNLOAD=1 pip install -e "./openequivariance_extjax" --no-build-isolation \ No newline at end of file