From b35483e689d4338205529b375aef27d41ee5771d Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 2 Feb 2026 23:03:19 -0800 Subject: [PATCH 01/15] Refactoring PyTorch tensor product. --- .../openequivariance/_torch/TensorProduct.py | 277 ++-------- .../extension/libtorch_tp_jit.cpp | 506 +++++++----------- 2 files changed, 231 insertions(+), 552 deletions(-) diff --git a/openequivariance/openequivariance/_torch/TensorProduct.py b/openequivariance/openequivariance/_torch/TensorProduct.py index 05ea54b5..2e421288 100644 --- a/openequivariance/openequivariance/_torch/TensorProduct.py +++ b/openequivariance/openequivariance/_torch/TensorProduct.py @@ -2,14 +2,13 @@ from openequivariance import TPProblem from openequivariance._torch import extlib import torch -import typing from openequivariance.core.utils import torch_to_oeq_dtype from openequivariance.benchmark.logging_utils import getLogger from openequivariance._torch.utils import reorder_torch from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin import numpy as np -from openequivariance._torch.extlib import DeviceBuffer +import json logger = getLogger() @@ -46,25 +45,23 @@ def _init_class(self): self.input_args["torch_op"], ) - internal_cls = None - if extlib.TORCH_COMPILE: - internal_cls = torch.classes.libtorch_tp_jit.TorchJITProduct - else: - internal_cls = extlib.JITTPImpl - - logger.info("Starting kernel compiler...") - self.internal = internal_cls( - self.jit_kernel, - vars(self.forward_schedule.launch_config), - vars(self.backward_schedule.launch_config), - vars(self.double_backward_schedule.launch_config), - self.kernelProp, + self.kernel = json.dumps( + { + "kernel": self.jit_kernel, + "forward_config": vars(self.forward_schedule.launch_config), + "backward_config": vars(self.backward_schedule.launch_config), + "double_backward_config": vars( + self.double_backward_schedule.launch_config + ), + "kernel_prop": self.kernelProp, + } ) - logger.info("Kernel compiled!") + self.hash = self.kernel.__hash__() logger.info(f"Kernel File Size: {len(self.jit_kernel) // 1024} KB") self.weight_numel = self.input_args["problem"].weight_numel self._setup_notorchbind() + if (not extlib.TORCH_COMPILE) or self.input_args["use_opaque"]: self.forward = self.forward_opaque @@ -124,170 +121,9 @@ def forward( """ return torch.ops.libtorch_tp_jit.jit_tp_forward(self.internal, x, y, W) - def _setup_notorchbind(self): - """ - In case TorchBind is not available (e.g. for torch.compile below PT2.8, etc.), - set up operations using custom ops. 
- """ - - @torch.library.custom_op( - f"openequivariance::tp_forward{self.tp_id}", - mutates_args=(), - device_types="cuda", - ) - def forward( - L1_in: torch.Tensor, L2_in: torch.Tensor, weights: torch.Tensor - ) -> torch.Tensor: - L1_in_c, L2_in_c, weights_c = ( - L1_in.contiguous(), - L2_in.contiguous(), - weights.contiguous(), - ) - L3_out = torch.empty( - (L1_in_c.shape[0], self.L3.dim), dtype=L1_in.dtype, device=L1_in.device - ) - self.forward_raw( - L1_in_c.shape[0], - L1_in_c.data_ptr(), - L2_in_c.data_ptr(), - L3_out.data_ptr(), - weights_c.data_ptr(), - ) - return L3_out - - @forward.register_fake - def _(L1_in, L2_in, weights): - return L1_in.new_empty(L1_in.shape[0], self.L3.dim) - - self.forward_opaque = forward - - # ---------------- Backward pass ----------------- - @torch.library.custom_op( - f"openequivariance::tp_grad_helper{self.tp_id}", - mutates_args=(), - device_types="cuda", - ) - def backward_helper( - L1_in: torch.Tensor, - L2_in: torch.Tensor, - weights: torch.Tensor, - L3_grad: torch.Tensor, - ) -> typing.List[torch.Tensor]: - L1_grad = torch.zeros_like(L1_in) - L2_grad = torch.zeros_like(L2_in) - weights_grad = torch.empty_like(weights) - - if self.config.shared_weights: - weights_grad[:] = 0.0 - - self.backward_raw( - L1_in.shape[0], - L1_in.contiguous().data_ptr(), - L1_grad.data_ptr(), - L2_in.contiguous().data_ptr(), - L2_grad.data_ptr(), - weights.contiguous().data_ptr(), - weights_grad.data_ptr(), - L3_grad.contiguous().data_ptr(), - ) - - return [L1_grad, L2_grad, weights_grad] - - @backward_helper.register_fake - def _(L1_in, L2_in, weights, L3_grad): - return [ - L1_in.new_empty(*L1_in.shape), - L2_in.new_empty(*L2_in.shape), - weights.new_empty(*weights.shape), - ] - - def setup_context(ctx, inputs, output): - ctx.L1_in, ctx.L2_in, ctx.weights = inputs - - def backward(ctx, grad_output): - result = backward_helper(ctx.L1_in, ctx.L2_in, ctx.weights, grad_output) - return result[0], result[1], result[2] - - self.forward_opaque.register_autograd(backward, setup_context=setup_context) - - def setup_context_double_backward(ctx, inputs, output): - ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad = inputs - - def double_backward(ctx, grad_output): - A, B, C, D = ctx.L1_in, ctx.L2_in, ctx.L3_grad, ctx.weights - E, F, G = grad_output[0], grad_output[1], grad_output[2] - - op1 = backward_helper(E, F, D, C) - op2 = backward_helper(A, B, G, C) - op3 = forward(E, B, D) - op4 = backward_helper(E, B, D, C) - op5 = backward_helper(A, F, D, C) - op6 = forward(A, F, D) - op7 = forward(A, B, G) - - return ( - op1[0] + op2[0], - op1[1] + op2[1], - (op4[2] + op5[2]), - (op3 + op6 + op7), - ) - - backward_helper.register_autograd( - double_backward, setup_context=setup_context_double_backward - ) @classmethod def register_torch_fakes(cls): - @torch._library.register_fake_class("libtorch_tp_jit::TorchJITProduct") - class TorchJITProduct: - def __init__( - self, - kernel_plaintext: str, - fwd_config: dict[str, int], - bwd_config: dict[str, int], - dbl_bwd_config: dict[str, int], - kernel_dims: dict[str, int], - ) -> None: - ( - self.kernel_plaintext, - self.fwd_config, - self.bwd_config, - self.dbl_bwd_config, - self.kernel_dims, - ) = ( - kernel_plaintext, - fwd_config, - bwd_config, - dbl_bwd_config, - kernel_dims, - ) - - @classmethod - def __obj_unflatten__(cls, flattened_product): - return cls(**dict(flattened_product)) - - def __len__(self): - return 0 - - def __setstate__(self, state): - self.kernel_plaintext = state["kernel_plaintext"] - self.fwd_config = 
state["fwd_config"] - self.bwd_config = state["bwd_config"] - self.dbl_bwd_config = state["dbl_bwd_config"] - self.kernel_dims = state["kernel_dims"] - - def exec_tensor_product_rawptr(*args, **kwargs): - pass - - def backward_rawptr(*args, **kwargs): - pass - - def L3_dim_getter(self): - return self.kernel_dims["L3_dim"] - - def irrep_dtype_getter(self): - return self.kernel_dims["irrep_dtype"] - @torch.library.register_fake("libtorch_tp_jit::jit_tp_forward") def fake_forward(jit, L1_in, L2_in, W): L3_dim = None @@ -307,11 +143,11 @@ def register_autograd(cls): backward_op = torch.ops.libtorch_tp_jit.jit_tp_backward def setup_context(ctx, inputs, output): - ctx.jit, ctx.L1_in, ctx.L2_in, ctx.weights = inputs + ctx.kernel, ctx.hash, ctx.L1_in, ctx.L2_in, ctx.weights = inputs def backward(ctx, grad_output): L1_grad, L2_grad, W_grad = backward_op( - ctx.jit, ctx.L1_in, ctx.L2_in, ctx.weights, grad_output + ctx.kernel, ctx.hash, ctx.L1_in, ctx.L2_in, ctx.weights, grad_output ) return None, L1_grad, L2_grad, W_grad @@ -320,11 +156,11 @@ def backward(ctx, grad_output): ) def setup_context_double_backward(ctx, inputs, output): - ctx.jit, ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad = inputs + ctx.kernel, ctx.hash, ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad = inputs def double_backward(ctx, E, F, G): result = torch.ops.libtorch_tp_jit.jit_tp_double_backward( - ctx.jit, ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad, E, F, G + ctx.kernel, ctx.hash, ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad, E, F, G ) return None, result[0], result[1], result[2], result[3] @@ -353,31 +189,6 @@ def register_autocast(cls): def name(): return "LoopUnrollTP" - def forward_raw( - self, - batch: np.uint64, - L1_in: np.uint64, - L2_in: np.uint64, - L3_out: np.uint64, - weights: np.uint64, - ) -> None: - self.internal.exec_tensor_product_rawptr(batch, L1_in, L2_in, L3_out, weights) - - def backward_raw( - self, - batch_size: np.uint64, - L1_in: np.uint64, - L1_grad: np.uint64, - L2_in: np.uint64, - L2_grad: np.uint64, - weights: np.uint64, - weights_grad: np.uint64, - L3_grad: np.uint64, - ): - self.internal.backward_rawptr( - batch_size, L1_in, L1_grad, L2_in, L2_grad, weights, weights_grad, L3_grad - ) - def forward_cpu( self, L1_in: np.ndarray, @@ -389,19 +200,12 @@ def forward_cpu( weights, not self.config.shared_weights ) - batch = L1_in.shape[0] - L1_d = DeviceBuffer(L1_in) - L2_d = DeviceBuffer(L2_in) - L3_d = DeviceBuffer(L3_out) - weights_d = DeviceBuffer(weights_chunked) - self.internal.exec_tensor_product_rawptr( - batch, - L1_d.data_ptr(), - L2_d.data_ptr(), - L3_d.data_ptr(), - weights_d.data_ptr(), - ) - L3_d.copy_to_host() + torch_L1_in = torch.tensor(L1_in, device="cuda") + torch_L2_in = torch.tensor(L2_in, device="cuda") + torch_weights = torch.tensor(weights_chunked, device="cuda") + torch_L3_out = self.e3nn_tp(torch_L1_in, torch_L2_in, torch_weights) + + L3_out[:] = torch_L3_out.numpy(force=True) def backward_cpu( self, L1_in, L1_grad, L2_in, L2_grad, L3_grad, weights, weights_grad @@ -410,32 +214,19 @@ def backward_cpu( weights, not self.config.shared_weights ) - batch = L1_in.shape[0] - L1_d, L2_d, L3_d = ( - DeviceBuffer(L1_in), - DeviceBuffer(L2_in), - DeviceBuffer(L3_grad), - ) - L1_grad_d, L2_grad_d = DeviceBuffer(L1_grad), DeviceBuffer(L2_grad) - weights_d, weights_grad_d = ( - DeviceBuffer(weights_chunked), - DeviceBuffer(weights_grad), - ) + torch_L1_in = torch.tensor(L1_in, requires_grad=True, device="cuda") + torch_L2_in = torch.tensor(L2_in, requires_grad=True, device="cuda") + 
torch_weights = torch.tensor(weights_chunked, requires_grad=True, device="cuda") - self.internal.backward_rawptr( - batch, - L1_d.data_ptr(), - L1_grad_d.data_ptr(), - L2_d.data_ptr(), - L2_grad_d.data_ptr(), - weights_d.data_ptr(), - weights_grad_d.data_ptr(), - L3_d.data_ptr(), - ) + torch_out = self.e3nn_tp(torch_L1_in, torch_L2_in, torch_weights) + + torch_L3_grad_in = torch.tensor(L3_grad, device="cuda") + + torch_out.backward(gradient=torch_L3_grad_in) - L1_grad_d.copy_to_host() - L2_grad_d.copy_to_host() - weights_grad_d.copy_to_host() + L1_grad[:] = torch_L1_in.grad.numpy(force=True) + L2_grad[:] = torch_L2_in.grad.numpy(force=True) + weights_grad[:] = torch_weights.grad.numpy(force=True) weights_grad[:] = self.reorder_weights_to_e3nn( weights_grad, not self.config.shared_weights diff --git a/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp b/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp index 18b8f65c..bfb978b9 100644 --- a/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp +++ b/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp @@ -3,9 +3,10 @@ #include #include #include +#include +#include -#include -#include +#include "json11/json11.hpp" #ifdef CUDA_BACKEND #include @@ -42,7 +43,7 @@ #include "convolution.hpp" using namespace std; -namespace py=pybind11; +using json = json11::Json; #include #include @@ -50,28 +51,15 @@ namespace py=pybind11; #include #include -using Map_t=torch::Dict; - -std::unordered_map to_map(const Map_t &map) { - std::unordered_map result; - for(auto it = map.begin(); it != map.end(); ++it) { - result[it->key()] = it->value(); - } - return result; -} +// --------------------- Utilities -------------------------- torch::Dtype enum_to_torch_dtype(int64_t i){ switch(i) { - case 1: - return torch::kFloat; - case 2: - return torch::kDouble; - case 3: - return torch::kInt; - case 4: - return torch::kLong; - case 5: - return torch::kUInt8; + case 1: return torch::kFloat; + case 2: return torch::kDouble; + case 3: return torch::kInt; + case 4: return torch::kLong; + case 5: return torch::kUInt8; } throw logic_error("Unsupported tensor datatype!"); } @@ -96,11 +84,23 @@ inline void* data_ptr(const torch::Tensor &tensor) { else if(tensor.dtype() == torch::kLong) return reinterpret_cast(tensor.data_ptr()); else if(tensor.dtype() == torch::kByte) - return reinterpret_cast(tensor.data_ptr()); + return reinterpret_cast(tensor.data_ptr()); // Replaces kUInt8 + else if(tensor.dtype() == torch::kInt) + return reinterpret_cast(tensor.data_ptr()); else throw logic_error("Unsupported tensor datatype!"); } +std::unordered_map parse_json_config(const json &j_obj) { + std::unordered_map result; + for (const auto &kv : j_obj.object_items()) { + result[kv.first] = static_cast(kv.second.number_value()); + } + return result; +} + +// --------------------- Compilation & Caching -------------------------- + struct KernelProp { int64_t L1_dim, L2_dim, L3_dim, weight_numel; bool shared_weights; @@ -112,7 +112,15 @@ struct KernelProp { torch::Dtype idx_dtype; torch::Dtype workspace_dtype; - KernelProp(Map_t &kernel_dims, bool is_convolution): + KernelProp() : + L1_dim(0), L2_dim(0), L3_dim(0), weight_numel(0), + shared_weights(false), + irrep_dtype(torch::kFloat), weight_dtype(torch::kFloat), + workspace_size(0), deterministic(false), + idx_dtype(torch::kInt), workspace_dtype(torch::kByte) {} + + KernelProp( + std::unordered_map &kernel_dims, bool is_convolution): L1_dim(kernel_dims.at("L1_dim")), 
L2_dim(kernel_dims.at("L2_dim")), L3_dim(kernel_dims.at("L3_dim")), @@ -126,81 +134,116 @@ struct KernelProp { deterministic = kernel_dims.at("deterministic"); idx_dtype = enum_to_torch_dtype(kernel_dims.at("idx_dtype")); } - } + } }; -class __attribute__ ((visibility ("default"))) TorchJITProduct : public torch::CustomClassHolder { -public: - Map_t fwd_dict, bwd_dict, dbl_bwd_dict, kernel_dims; - JITTPImpl internal; - KernelProp kernelProp; - int64_t L3_dim, irrep_dtype; - - TorchJITProduct(string kernel_plaintext, Map_t fwd_dict_i, Map_t bwd_dict_i, Map_t dbl_bwd_dict_i, Map_t kernel_dims_i) : - fwd_dict(fwd_dict_i.copy()), - bwd_dict(bwd_dict_i.copy()), - dbl_bwd_dict(dbl_bwd_dict_i.copy()), - kernel_dims(kernel_dims_i.copy()), - internal(kernel_plaintext, - to_map(fwd_dict_i), - to_map(bwd_dict_i), - to_map(dbl_bwd_dict_i), - to_map(kernel_dims_i) - ), - kernelProp(kernel_dims, false), - L3_dim(kernelProp.L3_dim), - irrep_dtype(kernel_dims_i.at("irrep_dtype")) - { } - - tuple< tuple, - tuple, - tuple, - tuple, - tuple> __obj_flatten__() { - return tuple(tuple("kernel_plaintext", internal.jit.kernel_plaintext), - tuple("fwd_config", fwd_dict), - tuple("bwd_config", bwd_dict), - tuple("dbl_bwd_config", dbl_bwd_dict), - tuple("kernel_dims", kernel_dims)); +// Global Caches +std::unordered_map>, + KernelProp + >> tp_cache; + +std::unordered_map>, + KernelProp + >> conv_cache; + +std::mutex mut; + +std::pair*, KernelProp> + compile_tp_with_caching(const torch::Tensor &json_bytes, + int64_t hash) { + { + const std::lock_guard lock(mut); + auto it = tp_cache.find(hash); + if (it == tp_cache.end()) { + // Cache Miss: Extract String + torch::Tensor cpu_tensor = json_bytes.to(torch::kCPU).contiguous(); + std::string json_payload( + reinterpret_cast(cpu_tensor.data_ptr()), + cpu_tensor.numel() + ); + + std::string err; + json root = json::parse(json_payload, err); + if (!err.empty()) throw std::runtime_error("JSON Parse Error: " + err); + + std::string kernel_src = root["kernel"].string_value(); + auto forward_cfg = parse_json_config(root["forward_config"]); + auto backward_cfg = parse_json_config(root["backward_config"]); + auto dbackward_cfg = parse_json_config(root["double_backward_config"]); + auto kernel_prop_map = parse_json_config(root["kernel_prop"]); + + auto jit_tp_impl = std::make_unique>( + kernel_src, + forward_cfg, + backward_cfg, + dbackward_cfg, + kernel_prop_map); + + tp_cache.insert({hash, + std::make_pair(std::move(jit_tp_impl), + KernelProp(kernel_prop_map, false))}); + it = tp_cache.find(hash); + } + return {it->second.first.get(), it->second.second}; } +} - void exec_tensor_product_device_rawptrs(int64_t num_batch, int64_t L1_in, int64_t L2_in, int64_t L3_out, int64_t weights) { - Stream stream = get_current_stream(); - internal.exec_tensor_product( - num_batch, - reinterpret_cast(L1_in), - reinterpret_cast(L2_in), - reinterpret_cast(L3_out), - reinterpret_cast(weights), - stream - ); - } - - void backward_device_rawptrs(int64_t num_batch, - int64_t L1_in, int64_t L1_grad, - int64_t L2_in, int64_t L2_grad, - int64_t weight, int64_t weight_grad, - int64_t L3_grad) { - Stream stream = get_current_stream(); - internal.backward(num_batch, - reinterpret_cast(L1_in), reinterpret_cast(L1_grad), - reinterpret_cast(L2_in), reinterpret_cast(L2_grad), - reinterpret_cast(weight), reinterpret_cast(weight_grad), - reinterpret_cast(L3_grad), stream - ); +std::pair*, KernelProp> + compile_conv_with_caching(const torch::Tensor &json_bytes, + int64_t hash) { + { + const std::lock_guard 
lock(mut); + auto it = conv_cache.find(hash); + if (it == conv_cache.end()) { + // Cache Miss: Extract String + torch::Tensor cpu_tensor = json_bytes.to(torch::kCPU).contiguous(); + std::string json_payload( + reinterpret_cast(cpu_tensor.data_ptr()), + cpu_tensor.numel() + ); + + std::string err; + json root = json::parse(json_payload, err); + if (!err.empty()) throw std::runtime_error("JSON Parse Error: " + err); + + std::string kernel_src = root["kernel"].string_value(); + auto forward_cfg = parse_json_config(root["forward_config"]); + auto backward_cfg = parse_json_config(root["backward_config"]); + auto dbackward_cfg = parse_json_config(root["double_backward_config"]); + auto kernel_prop_map = parse_json_config(root["kernel_prop"]); + + auto jit_conv_impl = std::make_unique>( + kernel_src, + forward_cfg, + backward_cfg, + dbackward_cfg, + kernel_prop_map); + + conv_cache.insert({hash, + std::make_pair(std::move(jit_conv_impl), + KernelProp(kernel_prop_map, true))}); + it = conv_cache.find(hash); + } + return {it->second.first.get(), it->second.second}; } -}; +} + +// --------------------- Tensor Products -------------------------- torch::Tensor jit_tp_forward( - const c10::intrusive_ptr &jit_instance, - const torch::Tensor &L1_in, - const torch::Tensor &L2_in, - const torch::Tensor &W) { + torch::Tensor json_bytes, int64_t hash, + torch::Tensor L1_in, + torch::Tensor L2_in, + torch::Tensor W) { + auto [jit_kernel, k] = compile_tp_with_caching(json_bytes, hash); Stream stream = get_current_stream(); const int64_t num_batch = L1_in.size(0); - const KernelProp &k = jit_instance->kernelProp; check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); check_tensor(L2_in, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_in"); @@ -216,7 +259,7 @@ torch::Tensor jit_tp_forward( at::Tensor L2_contig = L2_in.contiguous(); at::Tensor W_contig = W.contiguous(); - jit_instance->internal.exec_tensor_product( + jit_kernel->exec_tensor_product( num_batch, data_ptr(L1_contig), data_ptr(L2_contig), @@ -229,16 +272,16 @@ torch::Tensor jit_tp_forward( } tuple jit_tp_backward( - const c10::intrusive_ptr &jit_instance, - const torch::Tensor &L1_in, - const torch::Tensor &L2_in, - const torch::Tensor &W, - const torch::Tensor &L3_grad) { + torch::Tensor json_bytes, int64_t hash, + torch::Tensor L1_in, + torch::Tensor L2_in, + torch::Tensor W, + torch::Tensor L3_grad) { + auto [jit_kernel, k] = compile_tp_with_caching(json_bytes, hash); Stream stream = get_current_stream(); const int64_t num_batch = L1_in.size(0); - const KernelProp &k = jit_instance->kernelProp; check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); check_tensor(L2_in, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_in"); @@ -261,7 +304,7 @@ tuple jit_tp_backward( torch::Tensor W_contig = W.contiguous(); torch::Tensor L3_grad_contig = L3_grad.contiguous(); - jit_instance->internal.backward( + jit_kernel->backward( num_batch, data_ptr(L1_in_contig), data_ptr(L1_grad), data_ptr(L2_in_contig), data_ptr(L2_grad), @@ -274,19 +317,19 @@ tuple jit_tp_backward( } tuple jit_tp_double_backward( - const c10::intrusive_ptr &jit_instance, - const torch::Tensor &L1_in, - const torch::Tensor &L2_in, - const torch::Tensor &W, - const torch::Tensor &L3_grad, - const torch::Tensor &L1_dgrad, - const torch::Tensor &L2_dgrad, - const torch::Tensor &W_dgrad) { + torch::Tensor json_bytes, int64_t hash, + torch::Tensor L1_in, + torch::Tensor L2_in, + torch::Tensor W, + torch::Tensor L3_grad, + torch::Tensor L1_dgrad, + torch::Tensor L2_dgrad, + 
torch::Tensor W_dgrad) { + auto [jit_kernel, k] = compile_tp_with_caching(json_bytes, hash); Stream stream = get_current_stream(); const int64_t num_batch = L1_in.size(0); - const KernelProp &k = jit_instance->kernelProp; check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); check_tensor(L2_in, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_in"); @@ -321,7 +364,7 @@ tuple jit_tp_double_ TORCH_CHECK(W.dim() == 1); } - jit_instance->internal.double_backward( + jit_kernel->double_backward( num_batch, data_ptr(L1_in_contig), data_ptr(L2_in_contig), data_ptr(W_contig), data_ptr(L3_grad_contig), @@ -336,127 +379,23 @@ tuple jit_tp_double_ } -// =========================================================== - -class TorchJITConv : public torch::CustomClassHolder { -public: - Map_t fwd_dict, bwd_dict, dbl_bwd_dict, kernel_dims; - JITConvImpl internal; - KernelProp kernelProp; - int64_t L3_dim, irrep_dtype; - - TorchJITConv(string kernel_plaintext, Map_t fwd_dict_i, Map_t bwd_dict_i, Map_t dbl_bwd_dict_i, Map_t kernel_dims_i) : - fwd_dict(fwd_dict_i.copy()), - bwd_dict(bwd_dict_i.copy()), - dbl_bwd_dict(bwd_dict_i.copy()), - kernel_dims(kernel_dims_i.copy()), - internal(kernel_plaintext, - to_map(fwd_dict_i), - to_map(bwd_dict_i), - to_map(dbl_bwd_dict_i), - to_map(kernel_dims_i) - ), - kernelProp(kernel_dims, true), - L3_dim(kernelProp.L3_dim), - irrep_dtype(kernel_dims_i.at("irrep_dtype")) - { } - - tuple, - tuple, - tuple, - tuple, - tuple> __obj_flatten__() { - return tuple(tuple("kernel_plaintext", internal.jit.kernel_plaintext), - tuple("fwd_config", fwd_dict), - tuple("bwd_config", bwd_dict), - tuple("dbl_bwd_config", dbl_bwd_dict), - tuple("kernel_dims", kernel_dims)); - } - - void exec_conv_rawptrs( - int64_t L1_in, int64_t L2_in, int64_t weights, int64_t L3_out, - int64_t rows, int64_t cols, - int64_t nnz, int64_t node_count, - int64_t workspace) { - Stream stream = get_current_stream(); - internal.exec_conv( - reinterpret_cast(L1_in), - reinterpret_cast(L2_in), - reinterpret_cast(weights), - reinterpret_cast(L3_out), - reinterpret_cast(rows), - reinterpret_cast(cols), - nnz, node_count, - reinterpret_cast(workspace), - stream); - } - void backward_rawptrs( - int64_t L1_in, int64_t L1_grad, - int64_t L2_in, int64_t L2_grad, - int64_t weight, int64_t weight_grad, - int64_t L3_grad, - int64_t rows, int64_t cols, - int64_t nnz, int64_t node_count, - int64_t workspace, - int64_t transpose_perm) { - Stream stream = get_current_stream(); - internal.backward( - reinterpret_cast(L1_in), reinterpret_cast(L1_grad), - reinterpret_cast(L2_in), reinterpret_cast(L2_grad), - reinterpret_cast(weight), reinterpret_cast(weight_grad), - reinterpret_cast(L3_grad), - reinterpret_cast(rows), - reinterpret_cast(cols), - nnz, node_count, - reinterpret_cast(workspace), - reinterpret_cast(transpose_perm), - stream); - } - void double_backward_rawptrs( - int64_t L1_in, int64_t L2_in, int64_t W, int64_t L3_grad, - int64_t L1_dgrad, int64_t L2_dgrad, int64_t w_dgrad, - int64_t L1_grad, int64_t L2_grad, int64_t W_grad, int64_t L3_dgrad, - int64_t rows, int64_t cols, - int64_t nnz, int64_t node_count, - int64_t wspace, int64_t transpose_perm) { - - Stream stream = get_current_stream(); - internal.double_backward( - reinterpret_cast(L1_in), - reinterpret_cast(L2_in), - reinterpret_cast(W), - reinterpret_cast(L3_grad), - reinterpret_cast(L1_dgrad), - reinterpret_cast(L2_dgrad), - reinterpret_cast(w_dgrad), - reinterpret_cast(L1_grad), - reinterpret_cast(L2_grad), - reinterpret_cast(W_grad), - 
reinterpret_cast(L3_dgrad), - reinterpret_cast(rows), - reinterpret_cast(cols), - nnz, node_count, - reinterpret_cast(wspace), - reinterpret_cast(transpose_perm), - stream); - } -}; +// ========================= Convolution ================================== torch::Tensor jit_conv_forward( - const c10::intrusive_ptr &jit_instance, - const torch::Tensor &L1_in, - const torch::Tensor &L2_in, - const torch::Tensor &W, - const torch::Tensor &rows, - const torch::Tensor &cols, - const torch::Tensor &workspace, - const torch::Tensor &transpose_perm) { - + torch::Tensor json_bytes, int64_t hash, + torch::Tensor L1_in, + torch::Tensor L2_in, + torch::Tensor W, + torch::Tensor rows, + torch::Tensor cols, + torch::Tensor workspace, + torch::Tensor transpose_perm) { + + auto [jit_kernel, k] = compile_conv_with_caching(json_bytes, hash); Stream stream = get_current_stream(); const int64_t nnz = rows.size(0); const int64_t node_count = L1_in.size(0); - const KernelProp &k = jit_instance->kernelProp; check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, "L1_in"); check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in"); @@ -483,7 +422,7 @@ torch::Tensor jit_conv_forward( torch::Tensor cols_contig = cols.contiguous(); torch::Tensor workspace_contig = workspace.contiguous(); - jit_instance->internal.exec_conv( + jit_kernel->exec_conv( data_ptr(L1_contig), data_ptr(L2_contig), data_ptr(W_contig), @@ -498,21 +437,21 @@ torch::Tensor jit_conv_forward( } tuple jit_conv_backward( - const c10::intrusive_ptr &jit_instance, - const torch::Tensor &L1_in, - const torch::Tensor &L2_in, - const torch::Tensor &W, - const torch::Tensor &L3_grad, - const torch::Tensor &rows, - const torch::Tensor &cols, - const torch::Tensor &workspace, - const torch::Tensor &transpose_perm) { + torch::Tensor json_bytes, int64_t hash, + torch::Tensor L1_in, + torch::Tensor L2_in, + torch::Tensor W, + torch::Tensor L3_grad, + torch::Tensor rows, + torch::Tensor cols, + torch::Tensor workspace, + torch::Tensor transpose_perm) { + auto [jit_kernel, k] = compile_conv_with_caching(json_bytes, hash); Stream stream = get_current_stream(); const int64_t nnz = rows.size(0); const int64_t node_count = L1_in.size(0); - const KernelProp &k = jit_instance->kernelProp; check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, "L1_in"); check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in"); @@ -549,7 +488,7 @@ tuple jit_conv_backward( if(k.shared_weights) W_grad.zero_(); - jit_instance->internal.backward( + jit_kernel->backward( data_ptr(L1_in_contig), data_ptr(L1_grad), data_ptr(L2_in_contig), data_ptr(L2_grad), data_ptr(W_contig), data_ptr(W_grad), @@ -564,24 +503,24 @@ tuple jit_conv_backward( } tuple jit_conv_double_backward( - const c10::intrusive_ptr &jit_instance, - const torch::Tensor &L1_in, - const torch::Tensor &L2_in, - const torch::Tensor &W, - const torch::Tensor &L3_grad, - const torch::Tensor &L1_dgrad, - const torch::Tensor &L2_dgrad, - const torch::Tensor &W_dgrad, - const torch::Tensor &rows, - const torch::Tensor &cols, - const torch::Tensor &workspace, - const torch::Tensor &transpose_perm) { + torch::Tensor json_bytes, int64_t hash, + torch::Tensor L1_in, + torch::Tensor L2_in, + torch::Tensor W, + torch::Tensor L3_grad, + torch::Tensor L1_dgrad, + torch::Tensor L2_dgrad, + torch::Tensor W_dgrad, + torch::Tensor rows, + torch::Tensor cols, + torch::Tensor workspace, + torch::Tensor transpose_perm) { + auto [jit_kernel, k] = compile_conv_with_caching(json_bytes, hash); Stream stream = get_current_stream(); const 
int64_t nnz = rows.size(0); const int64_t node_count = L1_in.size(0); - const KernelProp &k = jit_instance->kernelProp; check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, "L1_in"); check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in"); @@ -628,7 +567,7 @@ tuple jit_conv_doubl if(k.shared_weights) W_grad.zero_(); - jit_instance->internal.double_backward( + jit_kernel->double_backward( data_ptr(L1_in_contig), data_ptr(L2_in_contig), data_ptr(W_contig), data_ptr(L3_grad_contig), data_ptr(L1_dgrad_contig), data_ptr(L2_dgrad_contig), @@ -646,68 +585,6 @@ tuple jit_conv_doubl // =========================================================== -TORCH_LIBRARY_FRAGMENT(libtorch_tp_jit, m) { - m.class_("TorchJITProduct") - .def(torch::init()) - .def("__obj_flatten__", &TorchJITProduct::__obj_flatten__) - .def("exec_tensor_product_rawptr", &TorchJITProduct::exec_tensor_product_device_rawptrs) - .def("backward_rawptr", &TorchJITProduct::backward_device_rawptrs) - .def("__len__", [](const c10::intrusive_ptr& test) -> int64_t { - return 0; - }) - .def_readonly("L3_dim", &TorchJITProduct::L3_dim) - .def_readonly("irrep_dtype", &TorchJITProduct::irrep_dtype) - .def("__eq__", [](const c10::IValue & self, const c10::IValue& other) -> bool { - return self.is(other); - }) - .def_pickle( - // __getstate__ - [](const c10::intrusive_ptr& self) - -> tuple { - return tuple(self->internal.jit.kernel_plaintext, self->fwd_dict, self->bwd_dict, self->dbl_bwd_dict, self->kernel_dims); - }, - // __setstate__ - [](tuple state) - -> c10::intrusive_ptr { - return c10::make_intrusive(get<0>(state), get<1>(state), get<2>(state), get<3>(state), get<4>(state)); - }); - - m.def("jit_tp_forward(__torch__.torch.classes.libtorch_tp_jit.TorchJITProduct jit, Tensor L1_in, Tensor L2_in, Tensor W) -> Tensor"); - m.def("jit_tp_backward(__torch__.torch.classes.libtorch_tp_jit.TorchJITProduct jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad) -> (Tensor, Tensor, Tensor)"); - m.def("jit_tp_double_backward(__torch__.torch.classes.libtorch_tp_jit.TorchJITProduct jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor L1_dgrad, Tensor L2_dgrad, Tensor W_dgrad) -> (Tensor, Tensor, Tensor, Tensor)"); - - - m.class_("TorchJITConv") - .def(torch::init()) - .def("__obj_flatten__", &TorchJITConv::__obj_flatten__) - .def("exec_conv_rawptrs", &TorchJITConv::exec_conv_rawptrs) - .def("backward_rawptrs", &TorchJITConv::backward_rawptrs) - .def("double_backward_rawptrs", &TorchJITConv::double_backward_rawptrs) - .def("__len__", [](const c10::intrusive_ptr& test) -> int64_t { - return 0; - }) - .def_readonly("L3_dim", &TorchJITConv::L3_dim) - .def_readonly("irrep_dtype", &TorchJITConv::irrep_dtype) - .def("__eq__", [](const c10::IValue & self, const c10::IValue& other) -> bool { - return self.is(other); - }) - .def_pickle( - // __getstate__ - [](const c10::intrusive_ptr& self) - -> tuple { - return tuple(self->internal.jit.kernel_plaintext, self->fwd_dict, self->bwd_dict, self->dbl_bwd_dict, self->kernel_dims); - }, - // __setstate__ - [](tuple state) - -> c10::intrusive_ptr { - return c10::make_intrusive(get<0>(state), get<1>(state), get<2>(state), get<3>(state), get<4>(state)); - }); - - m.def("jit_conv_forward(__torch__.torch.classes.libtorch_tp_jit.TorchJITConv jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> Tensor"); - m.def("jit_conv_backward(__torch__.torch.classes.libtorch_tp_jit.TorchJITConv jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, 
Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor)"); - m.def("jit_conv_double_backward(__torch__.torch.classes.libtorch_tp_jit.TorchJITConv jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor L1_dgrad, Tensor L2_dgrad, Tensor W_dgrad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor, Tensor)"); -}; - TORCH_LIBRARY_IMPL(libtorch_tp_jit, CUDA, m) { m.impl("jit_tp_forward", &jit_tp_forward); m.impl("jit_tp_backward", &jit_tp_backward); @@ -718,4 +595,15 @@ TORCH_LIBRARY_IMPL(libtorch_tp_jit, CUDA, m) { m.impl("jit_conv_double_backward", &jit_conv_double_backward); }; +// Define headers for the library (without implementations) +TORCH_LIBRARY(libtorch_tp_jit, m) { + m.def("jit_tp_forward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W) -> Tensor"); + m.def("jit_tp_backward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad) -> (Tensor, Tensor, Tensor)"); + m.def("jit_tp_double_backward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor L1_dgrad, Tensor L2_dgrad, Tensor W_dgrad) -> (Tensor, Tensor, Tensor, Tensor)"); + + m.def("jit_conv_forward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> Tensor"); + m.def("jit_conv_backward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor)"); + m.def("jit_conv_double_backward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor L1_dgrad, Tensor L2_dgrad, Tensor W_dgrad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor, Tensor)"); +} + PYBIND11_MODULE(libtorch_tp_jit, m) {} \ No newline at end of file From fa58732c70a7a6ee59c1d96149f9a3155292ecc2 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Mon, 2 Feb 2026 23:27:24 -0800 Subject: [PATCH 02/15] Tensor product refactored successfully. 
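
This patch moves the Python tensor-product module onto the calling convention
introduced in PATCH 01: instead of holding a TorchBind `TorchJITProduct`
instance, the module serializes the generated kernel and its launch
configurations to JSON, packs the bytes into a uint8 tensor, and passes a
`(json_bytes, hash)` pair to free custom ops; the C++ extension compiles the
kernel lazily and memoizes it in a mutex-guarded cache keyed by the hash.
A minimal caller-side sketch follows, mirroring the `string_to_tensor` helper
added in this patch and the content-based `hash_str_64` helper added later in
the series; the payload contents below are placeholders, not the real generated
kernel or configs:

    # Sketch only: placeholder payload illustrating the (json_bytes, hash) convention.
    import hashlib
    import json

    import numpy as np
    import torch

    def string_to_tensor(text: str) -> torch.Tensor:
        # Pack the UTF-8 bytes of the JSON payload into a uint8 tensor so it can
        # cross the torch custom-op boundary as a plain Tensor argument.
        np_bytes = np.frombuffer(text.encode("utf-8"), dtype=np.uint8)
        return torch.from_numpy(np_bytes).clone()

    def hash_str_64(s: str) -> int:
        # Stable, content-based 56-bit hash used as the C++-side cache key.
        return int.from_bytes(hashlib.sha256(s.encode()).digest()[:7], "big")

    kernel_string = json.dumps(
        {
            "kernel": "/* generated kernel source would go here */",
            "forward_config": {},           # launch config dicts in the real payload
            "backward_config": {},
            "double_backward_config": {},
            "kernel_prop": {},              # L1_dim, L2_dim, L3_dim, dtypes, ...
        }
    )
    json_bytes = string_to_tensor(kernel_string)   # "Tensor json_bytes" argument
    kernel_hash = hash_str_64(kernel_string)       # "int hash" argument / cache key
    # L3_out = torch.ops.libtorch_tp_jit.jit_tp_forward(json_bytes, kernel_hash, x, y, W)
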
--- .../openequivariance/_torch/TensorProduct.py | 10 ++-- .../_torch/TensorProductConv.py | 55 ------------------- .../_torch/extlib/__init__.py | 2 +- .../openequivariance/_torch/utils.py | 6 ++ 4 files changed, 13 insertions(+), 60 deletions(-) diff --git a/openequivariance/openequivariance/_torch/TensorProduct.py b/openequivariance/openequivariance/_torch/TensorProduct.py index 2e421288..cd52b057 100644 --- a/openequivariance/openequivariance/_torch/TensorProduct.py +++ b/openequivariance/openequivariance/_torch/TensorProduct.py @@ -4,7 +4,7 @@ import torch from openequivariance.core.utils import torch_to_oeq_dtype from openequivariance.benchmark.logging_utils import getLogger -from openequivariance._torch.utils import reorder_torch +from openequivariance._torch.utils import reorder_torch, string_to_tensor from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin import numpy as np @@ -45,7 +45,7 @@ def _init_class(self): self.input_args["torch_op"], ) - self.kernel = json.dumps( + kernel_string = json.dumps( { "kernel": self.jit_kernel, "forward_config": vars(self.forward_schedule.launch_config), @@ -56,13 +56,15 @@ def _init_class(self): "kernel_prop": self.kernelProp, } ) + + self.kernel= string_to_tensor(kernel_string) self.hash = self.kernel.__hash__() logger.info(f"Kernel File Size: {len(self.jit_kernel) // 1024} KB") self.weight_numel = self.input_args["problem"].weight_numel - self._setup_notorchbind() if (not extlib.TORCH_COMPILE) or self.input_args["use_opaque"]: + print(extlib.TORCH_COMPILE_ERROR) self.forward = self.forward_opaque def to(self, *args, **kwargs): @@ -119,7 +121,7 @@ def forward( :return: Tensor of shape ``[batch_size, problem.irreps_out.dim()]``, datatype ``problem.irrep_dtype``. """ - return torch.ops.libtorch_tp_jit.jit_tp_forward(self.internal, x, y, W) + return torch.ops.libtorch_tp_jit.jit_tp_forward(self.kernel, self.hash, x, y, W) @classmethod diff --git a/openequivariance/openequivariance/_torch/TensorProductConv.py b/openequivariance/openequivariance/_torch/TensorProductConv.py index f30c943c..ddce986d 100644 --- a/openequivariance/openequivariance/_torch/TensorProductConv.py +++ b/openequivariance/openequivariance/_torch/TensorProductConv.py @@ -439,61 +439,6 @@ def register_torch_fakes(cls): global torch import torch - @torch._library.register_fake_class("libtorch_tp_jit::TorchJITConv") - class TorchJITConv: - def __init__( - self, - kernel_plaintext: str, - fwd_config: dict[str, int], - bwd_config: dict[str, int], - dbl_bwd_config: dict[str, int], - kernel_dims: dict[str, int], - ) -> None: - ( - self.kernel_plaintext, - self.fwd_config, - self.bwd_config, - self.dbl_bwd_config, - self.kernel_dims, - ) = ( - kernel_plaintext, - fwd_config, - bwd_config, - dbl_bwd_config, - kernel_dims, - ) - - @classmethod - def __obj_unflatten__(cls, flattened_product): - return cls(**dict(flattened_product)) - - def __len__(self): - return 0 - - def __setstate__(self, state): - ( - self.kernel_plaintext, - self.fwd_config, - self.bwd_config, - self.dbl_bwd_config, - self.kernel_dims, - ) = state - - def exec_conv_rawptrs(*args, **kwargs): - pass - - def backward_rawptrs(*args, **kwargs): - pass - - def double_backward_rawptrs(*args, **kwargs): - pass - - def L3_dim_getter(self): - return self.kernel_dims["L3_dim"] - - def irrep_dtype_getter(self): - return self.kernel_dims["irrep_dtype"] - @torch.library.register_fake("libtorch_tp_jit::jit_conv_forward") def fake_forward( jit, L1_in, L2_in, W, rows, cols, workspace_buffer, sender_perm 
diff --git a/openequivariance/openequivariance/_torch/extlib/__init__.py b/openequivariance/openequivariance/_torch/extlib/__init__.py index 72440872..97b6c991 100644 --- a/openequivariance/openequivariance/_torch/extlib/__init__.py +++ b/openequivariance/openequivariance/_torch/extlib/__init__.py @@ -51,7 +51,7 @@ extra_cflags = ["-O3"] generic_sources = ["generic_module.cpp"] - torch_sources = ["libtorch_tp_jit.cpp"] + torch_sources = ["libtorch_tp_jit.cpp", "json11/json11.cpp"] include_dirs, extra_link_args = (["util"], ["-Wl,--no-as-needed"]) diff --git a/openequivariance/openequivariance/_torch/utils.py b/openequivariance/openequivariance/_torch/utils.py index 7538fb27..cd5e15c4 100644 --- a/openequivariance/openequivariance/_torch/utils.py +++ b/openequivariance/openequivariance/_torch/utils.py @@ -1,4 +1,5 @@ import torch +import numpy as np from types import MappingProxyType from openequivariance.core.utils import DTypeEnum @@ -66,3 +67,8 @@ def reorder_torch(schedule, weights_in, direction, has_batch_dim): DTypeEnum.UINT8: torch.uint8, } ) + +def string_to_tensor(text: str) -> torch.Tensor: + bytes_data = text.encode('utf-8') + np_bytes = np.frombuffer(bytes_data, dtype=np.uint8) + return torch.from_numpy(np_bytes).clone() From bf431a0af06ad7498f05461462e341737642cbf6 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Tue, 3 Feb 2026 21:40:22 -0800 Subject: [PATCH 03/15] Eliminated opaque bindings, batch tests working. --- .../openequivariance/_torch/TensorProduct.py | 26 ++---- .../openequivariance/core/LoopUnrollConv.py | 18 +++- .../openequivariance/core/LoopUnrollTP.py | 20 ++++- .../openequivariance/core/utils.py | 4 + .../extension/libtorch_tp_jit.cpp | 8 -- .../openequivariance/jax/TensorProduct.py | 14 +-- .../openequivariance/jax/TensorProductConv.py | 14 +-- tests/batch_test.py | 8 -- tests/export_test.py | 87 ------------------- 9 files changed, 44 insertions(+), 155 deletions(-) diff --git a/openequivariance/openequivariance/_torch/TensorProduct.py b/openequivariance/openequivariance/_torch/TensorProduct.py index cd52b057..4b26d550 100644 --- a/openequivariance/openequivariance/_torch/TensorProduct.py +++ b/openequivariance/openequivariance/_torch/TensorProduct.py @@ -8,7 +8,6 @@ from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin import numpy as np -import json logger = getLogger() @@ -45,22 +44,7 @@ def _init_class(self): self.input_args["torch_op"], ) - kernel_string = json.dumps( - { - "kernel": self.jit_kernel, - "forward_config": vars(self.forward_schedule.launch_config), - "backward_config": vars(self.backward_schedule.launch_config), - "double_backward_config": vars( - self.double_backward_schedule.launch_config - ), - "kernel_prop": self.kernelProp, - } - ) - - self.kernel= string_to_tensor(kernel_string) - self.hash = self.kernel.__hash__() - logger.info(f"Kernel File Size: {len(self.jit_kernel) // 1024} KB") - + self.kernel= string_to_tensor(self.kernel_string) self.weight_numel = self.input_args["problem"].weight_numel if (not extlib.TORCH_COMPILE) or self.input_args["use_opaque"]: @@ -151,7 +135,7 @@ def backward(ctx, grad_output): L1_grad, L2_grad, W_grad = backward_op( ctx.kernel, ctx.hash, ctx.L1_in, ctx.L2_in, ctx.weights, grad_output ) - return None, L1_grad, L2_grad, W_grad + return None, None, L1_grad, L2_grad, W_grad torch.library.register_autograd( "libtorch_tp_jit::jit_tp_forward", backward, setup_context=setup_context @@ -164,7 +148,7 @@ def double_backward(ctx, E, F, G): result = 
torch.ops.libtorch_tp_jit.jit_tp_double_backward( ctx.kernel, ctx.hash, ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad, E, F, G ) - return None, result[0], result[1], result[2], result[3] + return None, None, result[0], result[1], result[2], result[3] torch.library.register_autograd( "libtorch_tp_jit::jit_tp_backward", @@ -205,7 +189,7 @@ def forward_cpu( torch_L1_in = torch.tensor(L1_in, device="cuda") torch_L2_in = torch.tensor(L2_in, device="cuda") torch_weights = torch.tensor(weights_chunked, device="cuda") - torch_L3_out = self.e3nn_tp(torch_L1_in, torch_L2_in, torch_weights) + torch_L3_out = self.forward(torch_L1_in, torch_L2_in, torch_weights) L3_out[:] = torch_L3_out.numpy(force=True) @@ -220,7 +204,7 @@ def backward_cpu( torch_L2_in = torch.tensor(L2_in, requires_grad=True, device="cuda") torch_weights = torch.tensor(weights_chunked, requires_grad=True, device="cuda") - torch_out = self.e3nn_tp(torch_L1_in, torch_L2_in, torch_weights) + torch_out = self.forward(torch_L1_in, torch_L2_in, torch_weights) torch_L3_grad_in = torch.tensor(L3_grad, device="cuda") diff --git a/openequivariance/openequivariance/core/LoopUnrollConv.py b/openequivariance/openequivariance/core/LoopUnrollConv.py index 35a9bc3e..e633bd4f 100644 --- a/openequivariance/openequivariance/core/LoopUnrollConv.py +++ b/openequivariance/openequivariance/core/LoopUnrollConv.py @@ -1,4 +1,5 @@ import numpy as np +import json from openequivariance.core.ConvolutionBase import ConvolutionBase from openequivariance.core.ComputationSchedule import ( @@ -8,8 +9,7 @@ from openequivariance.core.utils import dtype_to_enum from openequivariance.templates.jinja_utils import get_jinja_environment -from openequivariance.core.utils import filter_and_analyze_problem - +from openequivariance.core.utils import filter_and_analyze_problem, dtype_to_enum, hash_str_64 class LoopUnrollConv(ConvolutionBase): def __init__( @@ -203,5 +203,15 @@ def generate_double_backward_schedule(warps_per_block): ) self.jit_kernel = postprocess_kernel(self.jit_kernel) - # with open("scratch.txt", "w") as f: - # f.write(self.jit_kernel) + self.kernel_string = json.dumps( + { + "kernel": self.jit_kernel, + "forward_config": vars(self.forward_schedule.launch_config), + "backward_config": vars(self.backward_schedule.launch_config), + "double_backward_config": vars( + self.double_backward_schedule.launch_config + ), + "kernel_prop": self.kernel_prop, + } + ) + self.hash = hash_str_64(self.kernel_string) \ No newline at end of file diff --git a/openequivariance/openequivariance/core/LoopUnrollTP.py b/openequivariance/openequivariance/core/LoopUnrollTP.py index 12ad4536..54aab56c 100644 --- a/openequivariance/openequivariance/core/LoopUnrollTP.py +++ b/openequivariance/openequivariance/core/LoopUnrollTP.py @@ -1,15 +1,18 @@ import numpy as np +import json from openequivariance.templates.jinja_utils import get_jinja_environment from openequivariance.core.ComputationSchedule import ComputationSchedule from openequivariance.core.TensorProductBase import TensorProductBase -from openequivariance.core.utils import dtype_to_enum +from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.core.utils import dtype_to_enum, hash_str_64 from openequivariance.core.utils import ( filter_and_analyze_problem, count_cg_non_zero, ) +logger = getLogger() class LoopUnrollTP(TensorProductBase): def __init__(self, config, dp, postprocess_kernel, torch_op): @@ -106,6 +109,21 @@ def generate_double_backward_schedule(warps_per_block): "idx_dtype": 0, } + 
self.kernel_string = json.dumps( + { + "kernel": self.jit_kernel, + "forward_config": vars(self.forward_schedule.launch_config), + "backward_config": vars(self.backward_schedule.launch_config), + "double_backward_config": vars( + self.double_backward_schedule.launch_config + ), + "kernel_prop": self.kernelProp, + } + ) + self.hash = hash_str_64(self.kernel_string) + logger.info(f"Kernel File Size: {len(self.jit_kernel) // 1024} KB") + + def calculate_flops_forward(self, batch_size: int) -> dict: if self.is_uvw: return super().calculate_flops_forward(batch_size) diff --git a/openequivariance/openequivariance/core/utils.py b/openequivariance/openequivariance/core/utils.py index 50f35bd4..5e2b1901 100644 --- a/openequivariance/openequivariance/core/utils.py +++ b/openequivariance/openequivariance/core/utils.py @@ -7,6 +7,7 @@ import json import tempfile +import hashlib from enum import IntEnum @@ -199,3 +200,6 @@ def benchmark(func, num_warmup, num_iter, mode="gpu_time", kernel_names=[]): time_millis[i] = kernel_time return time_millis + +def hash_str_64(s: str) -> int: + return int.from_bytes(hashlib.sha256(s.encode()).digest()[:7], 'big') \ No newline at end of file diff --git a/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp b/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp index bfb978b9..a312201d 100644 --- a/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp +++ b/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp @@ -51,8 +51,6 @@ using json = json11::Json; #include #include -// --------------------- Utilities -------------------------- - torch::Dtype enum_to_torch_dtype(int64_t i){ switch(i) { case 1: return torch::kFloat; @@ -99,8 +97,6 @@ std::unordered_map parse_json_config(const json &j_obj) { return result; } -// --------------------- Compilation & Caching -------------------------- - struct KernelProp { int64_t L1_dim, L2_dim, L3_dim, weight_numel; bool shared_weights; @@ -137,7 +133,6 @@ struct KernelProp { } }; -// Global Caches std::unordered_map>, @@ -159,7 +154,6 @@ std::pair*, KernelProp> const std::lock_guard lock(mut); auto it = tp_cache.find(hash); if (it == tp_cache.end()) { - // Cache Miss: Extract String torch::Tensor cpu_tensor = json_bytes.to(torch::kCPU).contiguous(); std::string json_payload( reinterpret_cast(cpu_tensor.data_ptr()), @@ -199,7 +193,6 @@ std::pair*, KernelProp> const std::lock_guard lock(mut); auto it = conv_cache.find(hash); if (it == conv_cache.end()) { - // Cache Miss: Extract String torch::Tensor cpu_tensor = json_bytes.to(torch::kCPU).contiguous(); std::string json_payload( reinterpret_cast(cpu_tensor.data_ptr()), @@ -595,7 +588,6 @@ TORCH_LIBRARY_IMPL(libtorch_tp_jit, CUDA, m) { m.impl("jit_conv_double_backward", &jit_conv_double_backward); }; -// Define headers for the library (without implementations) TORCH_LIBRARY(libtorch_tp_jit, m) { m.def("jit_tp_forward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W) -> Tensor"); m.def("jit_tp_backward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad) -> (Tensor, Tensor, Tensor)"); diff --git a/openequivariance/openequivariance/jax/TensorProduct.py b/openequivariance/openequivariance/jax/TensorProduct.py index bf7a1445..4f273b8a 100644 --- a/openequivariance/openequivariance/jax/TensorProduct.py +++ b/openequivariance/openequivariance/jax/TensorProduct.py @@ -19,19 +19,7 @@ def __init__(self, problem: TPProblem): dp = extlib.DeviceProp(0) super().__init__(problem, dp, extlib.postprocess_kernel, 
torch_op=False) - self.kernel = json.dumps( - { - "kernel": self.jit_kernel, - "forward_config": vars(self.forward_schedule.launch_config), - "backward_config": vars(self.backward_schedule.launch_config), - "double_backward_config": vars( - self.double_backward_schedule.launch_config - ), - "kernel_prop": self.kernelProp, - } - ) - self.hash = self.kernel.__hash__() - + self.kernel = self.kernel_string self.weight_numel = problem.weight_numel self.L3_dim = self.config.irreps_out.dim diff --git a/openequivariance/openequivariance/jax/TensorProductConv.py b/openequivariance/openequivariance/jax/TensorProductConv.py index 52102294..84fc8a30 100644 --- a/openequivariance/openequivariance/jax/TensorProductConv.py +++ b/openequivariance/openequivariance/jax/TensorProductConv.py @@ -49,19 +49,7 @@ def __init__( kahan=kahan, ) - self.kernel = json.dumps( - { - "kernel": self.jit_kernel, - "forward_config": vars(self.forward_schedule.launch_config), - "backward_config": vars(self.backward_schedule.launch_config), - "double_backward_config": vars( - self.double_backward_schedule.launch_config - ), - "kernel_prop": self.kernel_prop, - } - ) - self.hash = self.kernel.__hash__() - + self.kernel = self.kernel_string self.weight_numel = config.weight_numel self.L3_dim = self.config.irreps_out.dim diff --git a/tests/batch_test.py b/tests/batch_test.py index f32f7b51..788950ab 100644 --- a/tests/batch_test.py +++ b/tests/batch_test.py @@ -253,14 +253,6 @@ def problem(self, request, dtype): return problem -class TestTorchbindDisable(TestProductionModels): - @pytest.fixture(scope="class") - def extra_tp_constructor_args(self, with_jax): - if with_jax: - pytest.skip("N/A for JAX") - return {"use_opaque": True} - - class TestTorchTo(TPCorrectness): problems = [mace_problems()[0]] diff --git a/tests/export_test.py b/tests/export_test.py index 0fd23b2b..32cf4ab8 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -98,20 +98,6 @@ def test_torch_load(tp_and_inputs): assert torch.allclose(original_result, reloaded_result, atol=1e-5) -def test_jitscript(tp_and_inputs): - tp, inputs = tp_and_inputs - uncompiled_result = tp.forward(*inputs) - - scripted_tp = torch.jit.script(tp) - loaded_tp = None - with tempfile.NamedTemporaryFile(suffix=".pt") as tmp_file: - scripted_tp.save(tmp_file.name) - loaded_tp = torch.jit.load(tmp_file.name) - - compiled_result = loaded_tp(*inputs) - assert torch.allclose(uncompiled_result, compiled_result, atol=1e-5) - - def test_compile(tp_and_inputs): tp, inputs = tp_and_inputs uncompiled_result = tp.forward(*inputs) @@ -154,76 +140,3 @@ def test_aoti(tp_and_inputs): aoti_result = aoti_model(*inputs) assert torch.allclose(uncompiled_result, aoti_result, atol=1e-5) - - -def test_jitscript_cpp_interface(problem_and_irreps): - assert oeq.LINKED_LIBPYTHON, oeq.LINKED_LIBPYTHON_ERROR - problem, X_ir, Y_ir, _ = problem_and_irreps - cmake_prefix_path = torch.utils.cmake_prefix_path - torch_ext_so_path = oeq.torch_ext_so_path() - - oeq_tp = oeq.TensorProduct(problem).to("cuda") - scripted_oeq = torch.jit.script(oeq_tp) - - e3nn_tp = E3NNTensorProduct(problem).e3nn_tp.to("cuda") - scripted_e3nn = torch.jit.script(e3nn_tp) - - batch_size = 1000 - - with ( - tempfile.TemporaryDirectory() as tmpdir, - tempfile.NamedTemporaryFile(suffix=".pt") as oeq_file, - tempfile.NamedTemporaryFile(suffix=".pt") as e3nn_file, - ): - scripted_oeq.save(oeq_file.name) - scripted_e3nn.save(e3nn_file.name) - - test_path = importlib.resources.files("openequivariance") / "extension" / "test" - build_dir = 
os.path.join(tmpdir, "build") - os.makedirs(build_dir, exist_ok=True) - - for item in test_path.iterdir(): - shutil.copy(item, tmpdir) - - try: - subprocess.run( - [ - "cmake", - "..", - "-DCMAKE_BUILD_TYPE=Release", - "-DCMAKE_PREFIX_PATH=" + cmake_prefix_path, - "-DOEQ_EXTLIB=" + torch_ext_so_path, - ], - cwd=build_dir, - check=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - - subprocess.run( - ["make"], - cwd=build_dir, - check=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - - subprocess.run( - [ - "./load_jitscript", - e3nn_file.name, - oeq_file.name, - str(X_ir.dim), - str(Y_ir.dim), - str(problem.weight_numel), - str(batch_size), - ], - cwd=build_dir, - check=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - except subprocess.CalledProcessError as e: - print(e.stdout.decode(), file=sys.stderr) - print(e.stderr.decode(), file=sys.stderr) - assert False From c9381937bbdd43592e23e553c11875d85b11d059 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Tue, 3 Feb 2026 22:16:15 -0800 Subject: [PATCH 04/15] Ready to run convolution tests. --- .../openequivariance/_torch/TensorProduct.py | 2 +- .../_torch/TensorProductConv.py | 409 +++--------------- tests/export_test.py | 5 +- 3 files changed, 74 insertions(+), 342 deletions(-) diff --git a/openequivariance/openequivariance/_torch/TensorProduct.py b/openequivariance/openequivariance/_torch/TensorProduct.py index 4b26d550..88101746 100644 --- a/openequivariance/openequivariance/_torch/TensorProduct.py +++ b/openequivariance/openequivariance/_torch/TensorProduct.py @@ -22,7 +22,7 @@ class TensorProduct(torch.nn.Module, LoopUnrollTP, NumpyDoubleBackwardMixin): * The provided tensor product specification is unsupported. :param problem: Specification of the tensor product. - :param use_opaque: If ``True``, uses an opaque forward pass that cannot be symbolically traced. *Default*: ``False``. + :param use_opaque: This parameter is deprecated. """ def __init__(self, problem: TPProblem, torch_op=True, use_opaque=False): diff --git a/openequivariance/openequivariance/_torch/TensorProductConv.py b/openequivariance/openequivariance/_torch/TensorProductConv.py index ddce986d..71350176 100644 --- a/openequivariance/openequivariance/_torch/TensorProductConv.py +++ b/openequivariance/openequivariance/_torch/TensorProductConv.py @@ -1,11 +1,10 @@ -from typing import Optional, List +from typing import Optional import numpy as np import torch import openequivariance._torch.extlib as extlib from openequivariance._torch.extlib import ( - JITConvImpl, postprocess_kernel, DeviceProp, ) @@ -18,12 +17,10 @@ from openequivariance._torch.TensorProduct import TensorProduct from openequivariance import TPProblem from openequivariance.core.utils import torch_to_oeq_dtype -from openequivariance._torch.utils import enum_to_torch_dtype -from openequivariance._torch.utils import reorder_torch +from openequivariance._torch.utils import enum_to_torch_dtype, reorder_torch, string_to_tensor from openequivariance.benchmark.logging_utils import getLogger from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixinConv -from openequivariance._torch.extlib import DeviceBuffer logger = getLogger() @@ -49,7 +46,7 @@ class TensorProductConv(torch.nn.Module, LoopUnrollConv, NumpyDoubleBackwardMixi fixup-based algorithm. `Default`: ``False``. :param kahan: If ``True``, uses Kahan summation to improve accuracy during aggregation. 
To use this option, the input tensors must be in float32 precision AND you must set ``deterministic=True``. *Default*: ``False``. - :param use_opaque: If ``True``, uses an opaque forward pass that cannot be symbolically traced. *Default*: ``False``. + :param use_opaque: This parameter is deprecated. """ def __init__( @@ -85,27 +82,10 @@ def _init_class(self): ) self.allocate_workspace(self.workspace_size) - if extlib.TORCH_COMPILE: - internal_cls = torch.classes.libtorch_tp_jit.TorchJITConv - else: - internal_cls = JITConvImpl - - logger.info("Starting kernel compiler...") - self.internal = internal_cls( - self.jit_kernel, - vars(self.forward_schedule.launch_config), - vars(self.backward_schedule.launch_config), - vars(self.double_backward_schedule.launch_config), - self.kernel_prop, - ) - logger.info("Kernel compiled!") - self.dummy_transpose_perm = torch.zeros(1, dtype=torch.int64, device="cuda") self.weight_numel = self.config.weight_numel - self._setup_notorchbind() + self.kernel= string_to_tensor(self.kernel_string) - if (not extlib.TORCH_COMPILE) or self.input_args["use_opaque"]: - self.forward = self.forward_opaque def to(self, *args, **kwargs): r""" @@ -163,27 +143,19 @@ def forward( :return: Tensor of shape ``[|V|, problem.irreps_out.dim()]``, datatype ``problem.irrep_dtype``. """ if sender_perm is None: - return torch.ops.libtorch_tp_jit.jit_conv_forward( - self.internal, - X, - Y, - W, - rows, - cols, - self.workspace_buffer, - self.dummy_transpose_perm, - ) - else: - return torch.ops.libtorch_tp_jit.jit_conv_forward( - self.internal, - X, - Y, - W, - rows, - cols, - self.workspace_buffer, - sender_perm, - ) + sender_perm = self.dummy_transpose_perm + + return torch.ops.libtorch_tp_jit.jit_conv_forward( + self.kernel, + self.hash, + X, + Y, + W, + rows, + cols, + self.workspace_buffer, + sender_perm, + ) def allocate_workspace(self, size_bytes): self.workspace_size = size_bytes @@ -196,230 +168,6 @@ def allocate_workspace(self, size_bytes): self.workspace_ptr = self.workspace_buffer.data_ptr() logger.info(f"Convolution requires {size_bytes // 1000000}MB of workspace.") - def _setup_notorchbind(self): - @torch.library.custom_op( - f"openequivariance::conv_forward{self.conv_id}", - mutates_args=(), - device_types="cuda", - ) - def forward( - L1_in: torch.Tensor, - L2_in: torch.Tensor, - weights: torch.Tensor, - rows: torch.Tensor, - cols: torch.Tensor, - transpose_perm: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - L1_in_c, L2_in_c, weights_c = ( - L1_in.contiguous(), - L2_in.contiguous(), - weights.contiguous(), - ) - L3_out = torch.zeros( - (L1_in_c.shape[0], self.L3.dim), dtype=L1_in.dtype, device="cuda" - ) - - self.internal.exec_conv_rawptrs( - L1_in_c.data_ptr(), - L2_in_c.data_ptr(), - weights_c.data_ptr(), - L3_out.data_ptr(), - rows.contiguous().data_ptr(), - cols.contiguous().data_ptr(), - rows.shape[0], - L1_in.shape[0], - self.workspace_ptr, - ) - - return L3_out - - @forward.register_fake - def _(L1_in, L2_in, weights, rows, cols, transpose_perm=None): - return L1_in.new_empty(L1_in.shape[0], self.L3.dim) - - self.forward_opaque = forward - - @torch.library.custom_op( - f"openequivariance::conv_backward{self.conv_id}", - mutates_args=(), - device_types="cuda", - ) - def backward_helper( - L1_in: torch.Tensor, - L2_in: torch.Tensor, - weights: torch.Tensor, - L3_grad: torch.Tensor, - rows: torch.Tensor, - cols: torch.Tensor, - transpose_perm: Optional[torch.Tensor] = None, - ) -> List[torch.Tensor]: - L1_grad = torch.zeros_like(L1_in) - L2_grad = 
torch.zeros_like(L2_in) - weights_grad = torch.empty_like(weights) - - if self.config.shared_weights: - weights_grad[:] = 0.0 - - transpose_perm_ptr = 0 - if transpose_perm is not None: - transpose_perm_ptr = transpose_perm.data_ptr() - - self.internal.backward_rawptrs( - L1_in.contiguous().data_ptr(), - L1_grad.data_ptr(), - L2_in.contiguous().data_ptr(), - L2_grad.data_ptr(), - weights.contiguous().data_ptr(), - weights_grad.data_ptr(), - L3_grad.contiguous().data_ptr(), - rows.contiguous().data_ptr(), - cols.contiguous().data_ptr(), - rows.shape[0], - L1_in.shape[0], - self.workspace_ptr, - transpose_perm_ptr, - ) - - return [L1_grad, L2_grad, weights_grad] - - @backward_helper.register_fake - def _(L1_in, L2_in, weights, L3_grad, rows, cols, transpose_perm=None): - return [ - L1_in.new_empty(*L1_in.shape), - L2_in.new_empty(*L2_in.shape), - weights.new_empty(*weights.shape), - ] - - def setup_context(ctx, inputs, output): - ( - ctx.L1_in, - ctx.L2_in, - ctx.weights, - ctx.rows, - ctx.cols, - ctx.transpose_perm, - ) = inputs - - def backward(ctx, grad_output): - result = backward_helper( - ctx.L1_in, - ctx.L2_in, - ctx.weights, - grad_output, - ctx.rows, - ctx.cols, - ctx.transpose_perm, - ) - return result[0], result[1], result[2], None, None, None - - self.forward_opaque.register_autograd(backward, setup_context=setup_context) - - def setup_context_double_backward(ctx, inputs, output): - ( - ctx.L1_in, - ctx.L2_in, - ctx.weights, - ctx.L3_grad, - ctx.rows, - ctx.cols, - ctx.transpose_perm, - ) = inputs - - @torch.library.custom_op( - f"openequivariance::conv_double_backward{self.conv_id}", - mutates_args=(), - device_types="cuda", - ) - def double_backward_helper( - L1_in: torch.Tensor, - L2_in: torch.Tensor, - W: torch.Tensor, - L3_grad: torch.Tensor, - L1_dgrad: torch.Tensor, - L2_dgrad: torch.Tensor, - w_dgrad: torch.Tensor, - rows: torch.Tensor, - cols: torch.Tensor, - transpose_perm: Optional[torch.Tensor] = None, - ) -> List[torch.Tensor]: - L1_grad = torch.zeros_like(L1_in) - L2_grad = torch.zeros_like(L2_in) - W_grad = torch.empty_like(W) - L3_dgrad = torch.zeros_like(L3_grad) - - if self.config.shared_weights: - W_grad[:] = 0.0 - - transpose_perm_ptr = 0 - if transpose_perm is not None: - transpose_perm_ptr = transpose_perm.data_ptr() - - self.internal.double_backward_rawptrs( - L1_in.contiguous().data_ptr(), - L2_in.contiguous().data_ptr(), - W.contiguous().data_ptr(), - L3_grad.contiguous().data_ptr(), - L1_dgrad.contiguous().data_ptr(), - L2_dgrad.contiguous().data_ptr(), - w_dgrad.contiguous().data_ptr(), - L1_grad.data_ptr(), - L2_grad.data_ptr(), - W_grad.data_ptr(), - L3_dgrad.data_ptr(), - rows.contiguous().data_ptr(), - cols.contiguous().data_ptr(), - rows.shape[0], - L1_in.shape[0], - self.workspace_ptr, - transpose_perm_ptr, - ) - return [L1_grad, L2_grad, W_grad, L3_dgrad] - - @double_backward_helper.register_fake - def _( - L1_in, - L2_in, - W, - L3_grad, - L1_dgrad, - L2_dgrad, - w_dgrad, - rows, - cols, - transpose_perm=None, - ): - return [ - L1_in.new_empty(*L1_in.shape), - L2_in.new_empty(*L2_in.shape), - W.new_empty(*W.shape), - L3_grad.new_empty(*L3_grad.shape), - ] - - def double_backward(ctx, grad_output): - L1_dgrad, L2_dgrad, w_dgrad = grad_output[0], grad_output[1], grad_output[2] - - L1_grad, L2_grad, W_grad, L3_dgrad = double_backward_helper( - ctx.L1_in, - ctx.L2_in, - ctx.weights, - ctx.L3_grad, - L1_dgrad, - L2_dgrad, - w_dgrad, - ctx.rows, - ctx.cols, - ctx.transpose_perm, - ) - - if ctx.transpose_perm is None: - return L1_grad, L2_grad, W_grad, 
L3_dgrad, None, None - else: - return L1_grad, L2_grad, W_grad, L3_dgrad, None, None, None - - backward_helper.register_autograd( - double_backward, setup_context=setup_context_double_backward - ) - def reorder_weights_from_e3nn(self, weights, has_batch_dim=True): return reorder_torch( self.forward_schedule, weights, "forward", not self.config.shared_weights @@ -441,7 +189,7 @@ def register_torch_fakes(cls): @torch.library.register_fake("libtorch_tp_jit::jit_conv_forward") def fake_forward( - jit, L1_in, L2_in, W, rows, cols, workspace_buffer, sender_perm + kernel, hash, L1_in, L2_in, W, rows, cols, workspace_buffer, sender_perm ): L3_dim, irrep_dtype = None, None if hasattr(jit, "wrapped_obj"): @@ -460,13 +208,13 @@ def fake_forward( @torch.library.register_fake("libtorch_tp_jit::jit_conv_backward") def fake_backward( - jit, L1_in, L2_in, W, L3_grad, rows, cols, workspace_buffer, sender_perm + kernel, hash, L1_in, L2_in, W, L3_grad, rows, cols, workspace_buffer, sender_perm ): return torch.empty_like(L1_in), torch.empty_like(L2_in), torch.empty_like(W) @torch.library.register_fake("libtorch_tp_jit::jit_conv_double_backward") def fake_double_backward( - jit, + kernel, hash, L1_in, L2_in, W, @@ -493,7 +241,8 @@ def register_autograd(cls): def setup_context(ctx, inputs, output): ( - ctx.jit, + ctx.kernel, + ctx.hash, ctx.L1_in, ctx.L2_in, ctx.W, @@ -505,7 +254,8 @@ def setup_context(ctx, inputs, output): def backward(ctx, grad_output): L1_grad, L2_grad, W_grad = backward_op( - ctx.jit, + ctx.kernel, + ctx.hash, ctx.L1_in, ctx.L2_in, ctx.W, @@ -515,7 +265,7 @@ def backward(ctx, grad_output): ctx.workspace_buffer, ctx.sender_perm, ) - return None, L1_grad, L2_grad, W_grad, None, None, None, None + return None, None, L1_grad, L2_grad, W_grad, None, None, None, None torch.library.register_autograd( "libtorch_tp_jit::jit_conv_forward", backward, setup_context=setup_context @@ -523,7 +273,8 @@ def backward(ctx, grad_output): def setup_context_double_backward(ctx, inputs, output): ( - ctx.jit, + ctx.kernel, + ctx.hash, ctx.L1_in, ctx.L2_in, ctx.W, @@ -537,7 +288,8 @@ def setup_context_double_backward(ctx, inputs, output): def double_backward(ctx, E, F, G): result = double_backward_op( - ctx.jit, + ctx.kernel, + ctx.hash, ctx.L1_in, ctx.L2_in, ctx.W, @@ -551,6 +303,7 @@ def double_backward(ctx, E, F, G): ctx.sender_perm, ) return ( + None, None, result[0], result[1], @@ -570,9 +323,6 @@ def double_backward(ctx, E, F, G): @classmethod def register_autocast(cls): - global torch - import torch - torch.library.register_autocast( "libtorch_tp_jit::jit_conv_forward", "cuda", torch.float32 ) @@ -591,29 +341,26 @@ def forward_cpu(self, L1_in, L2_in, weights, L3_out, graph): weights, not self.config.shared_weights ) - L1_d, L2_d, weights_d = ( - DeviceBuffer(L1_in), - DeviceBuffer(L2_in), - DeviceBuffer(weights_chunked), - ) - L3_d = DeviceBuffer(L3_out) - - rows_d = DeviceBuffer(graph.rows) - cols_d = DeviceBuffer(graph.cols) - - self.internal.exec_conv_rawptrs( - L1_d.data_ptr(), - L2_d.data_ptr(), - weights_d.data_ptr(), - L3_d.data_ptr(), - rows_d.data_ptr(), - cols_d.data_ptr(), - graph.nnz, - graph.node_count, - self.workspace_ptr, - ) + torch_L1_in = torch.tensor(L1_in, device="cuda") + torch_L2_in = torch.tensor(L2_in, device="cuda") + torch_weights = torch.tensor(weights_chunked, device="cuda") + torch_rows = torch.tensor(graph.rows, device="cuda") + torch_cols = torch.tensor(graph.cols, device="cuda") - L3_d.copy_to_host() + if self.deterministic: + torch_sender_perm = 
torch.tensor(graph.sender_perm, device="cuda") + else: + torch_sender_perm = None + + result = self.forward( + torch_L1_in, + torch_L2_in, + torch_weights, + torch_rows, + torch_cols, + torch_sender_perm, + ) + L3_out[:] = result.numpy(force=True) def backward_cpu( self, L1_in, L1_grad, L2_in, L2_grad, weights, weights_grad, L3_grad, graph @@ -625,42 +372,32 @@ def backward_cpu( weights, not self.config.shared_weights ) - L1_d = DeviceBuffer(L1_in) - L2_d = DeviceBuffer(L2_in) - weights_d = DeviceBuffer(weights_chunked) - L3_d = DeviceBuffer(L3_grad) - rows_d = DeviceBuffer(graph.rows) - cols_d = DeviceBuffer(graph.cols) - - L1_grad_d = DeviceBuffer(L1_grad) - L2_grad_d = DeviceBuffer(L2_grad) - weights_grad_d = DeviceBuffer(weights_grad) + torch_L1_in = torch.tensor(L1_in, requires_grad=True, device="cuda") + torch_L2_in = torch.tensor(L2_in, requires_grad=True, device="cuda") + torch_weights = torch.tensor( + weights_chunked, requires_grad=True, device="cuda" + ) + torch_L3_grad = torch.tensor(L3_grad, device="cuda") + torch_rows = torch.tensor(graph.rows, device="cuda") + torch_cols = torch.tensor(graph.cols, device="cuda") - transpose_perm_d = None - transpose_perm_ptr = 0 if self.deterministic: - transpose_perm_d = DeviceBuffer(graph.transpose_perm) - transpose_perm_ptr = transpose_perm_d.data_ptr() - - self.internal.backward_rawptrs( - L1_d.data_ptr(), - L1_grad_d.data_ptr(), - L2_d.data_ptr(), - L2_grad_d.data_ptr(), - weights_d.data_ptr(), - weights_grad_d.data_ptr(), - L3_d.data_ptr(), - rows_d.data_ptr(), - cols_d.data_ptr(), - graph.nnz, - graph.node_count, - self.workspace_ptr, - transpose_perm_ptr, + torch_sender_perm = torch.tensor(graph.sender_perm, device="cuda") + else: + torch_sender_perm = None + + torch_out = self.forward( + torch_L1_in, + torch_L2_in, + torch_weights, + torch_rows, + torch_cols, + torch_sender_perm, ) - - L1_grad_d.copy_to_host() - L2_grad_d.copy_to_host() - weights_grad_d.copy_to_host() + torch_out.backward(gradient=torch_L3_grad) + L1_grad[:] = torch_L1_in.grad.numpy(force=True) + L2_grad[:] = torch_L2_in.grad.numpy(force=True) + weights_grad[:] = torch_weights.grad.numpy(force=True) weights_grad[:] = self.reorder_weights_to_e3nn( weights_grad, not self.config.shared_weights @@ -709,8 +446,6 @@ def name(): class TensorProductConvScatterSum(ConvolutionBase): def __init__(self, config, *, torch_op=True): assert torch_op - global torch - import torch super().__init__(config, torch_op=torch_op, deterministic=False) diff --git a/tests/export_test.py b/tests/export_test.py index 32cf4ab8..0e868ee3 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -129,10 +129,7 @@ def test_aoti(tp_and_inputs): ) except Exception as e: err_msg = ( - "AOTI compile_and_package failed. NOTE: OpenEquivariance only supports AOTI for " - + "PyTorch version >= 2.8.0.dev20250410+cu126 due to incomplete TorchBind support " - + "in prior versions. " - + f"{e}" + f"AOTI compile_and_package failed. Error: {e}" ) assert False, err_msg From 8d75dfdad4b1e55c364c0d82250403578e440cd8 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Tue, 3 Feb 2026 23:03:24 -0800 Subject: [PATCH 05/15] Ready to fix fakes. 
--- openequivariance/openequivariance/_torch/TensorProduct.py | 3 --- openequivariance/openequivariance/_torch/TensorProductConv.py | 4 ++-- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/openequivariance/openequivariance/_torch/TensorProduct.py b/openequivariance/openequivariance/_torch/TensorProduct.py index 88101746..d1336b1a 100644 --- a/openequivariance/openequivariance/_torch/TensorProduct.py +++ b/openequivariance/openequivariance/_torch/TensorProduct.py @@ -47,9 +47,6 @@ def _init_class(self): self.kernel= string_to_tensor(self.kernel_string) self.weight_numel = self.input_args["problem"].weight_numel - if (not extlib.TORCH_COMPILE) or self.input_args["use_opaque"]: - print(extlib.TORCH_COMPILE_ERROR) - self.forward = self.forward_opaque def to(self, *args, **kwargs): r""" diff --git a/openequivariance/openequivariance/_torch/TensorProductConv.py b/openequivariance/openequivariance/_torch/TensorProductConv.py index 71350176..fd2bacee 100644 --- a/openequivariance/openequivariance/_torch/TensorProductConv.py +++ b/openequivariance/openequivariance/_torch/TensorProductConv.py @@ -348,7 +348,7 @@ def forward_cpu(self, L1_in, L2_in, weights, L3_out, graph): torch_cols = torch.tensor(graph.cols, device="cuda") if self.deterministic: - torch_sender_perm = torch.tensor(graph.sender_perm, device="cuda") + torch_sender_perm = torch.tensor(graph.transpose_perm, device="cuda") else: torch_sender_perm = None @@ -382,7 +382,7 @@ def backward_cpu( torch_cols = torch.tensor(graph.cols, device="cuda") if self.deterministic: - torch_sender_perm = torch.tensor(graph.sender_perm, device="cuda") + torch_sender_perm = torch.tensor(graph.transpose_perm, device="cuda") else: torch_sender_perm = None From 6f3adfae06b3c989ff49bd2b388dee3fee57fe3e Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Tue, 3 Feb 2026 23:06:42 -0800 Subject: [PATCH 06/15] DeviceBuffer is gone. 
--- .../_torch/TensorProductConv.py | 9 ++-- .../_torch/extlib/__init__.py | 4 -- .../extension/generic_module.cpp | 7 --- .../extension/libtorch_tp_jit.cpp | 1 - .../extension/util/buffer.hpp | 45 ------------------- openequivariance_extjax/CMakeLists.txt | 1 - 6 files changed, 3 insertions(+), 64 deletions(-) delete mode 100644 openequivariance/openequivariance/extension/util/buffer.hpp diff --git a/openequivariance/openequivariance/_torch/TensorProductConv.py b/openequivariance/openequivariance/_torch/TensorProductConv.py index fd2bacee..c67ca3e9 100644 --- a/openequivariance/openequivariance/_torch/TensorProductConv.py +++ b/openequivariance/openequivariance/_torch/TensorProductConv.py @@ -159,12 +159,9 @@ def forward( def allocate_workspace(self, size_bytes): self.workspace_size = size_bytes - if self.torch_op: - self.workspace_buffer = torch.zeros( - size_bytes, dtype=torch.uint8, device="cuda" - ) - else: - self.workspace_buffer = extlib.DeviceBuffer(size_bytes) + self.workspace_buffer = torch.zeros( + size_bytes, dtype=torch.uint8, device="cuda" + ) self.workspace_ptr = self.workspace_buffer.data_ptr() logger.info(f"Convolution requires {size_bytes // 1000000}MB of workspace.") diff --git a/openequivariance/openequivariance/_torch/extlib/__init__.py b/openequivariance/openequivariance/_torch/extlib/__init__.py index 97b6c991..38bc211f 100644 --- a/openequivariance/openequivariance/_torch/extlib/__init__.py +++ b/openequivariance/openequivariance/_torch/extlib/__init__.py @@ -154,7 +154,6 @@ def torch_ext_so_path(): GroupMM_F32, GroupMM_F64, DeviceProp, - DeviceBuffer, GPUTimer, ) else: @@ -174,8 +173,5 @@ def GroupMM_F64(*args, **kwargs): def DeviceProp(*args, **kwargs): _raise_import_error_helper("DeviceProp") - def DeviceBuffer(*args, **kwargs): - _raise_import_error_helper("DeviceBuffer") - def GPUTimer(*args, **kwargs): _raise_import_error_helper("GPUTimer") diff --git a/openequivariance/openequivariance/extension/generic_module.cpp b/openequivariance/openequivariance/extension/generic_module.cpp index fc94eec9..2b83414e 100644 --- a/openequivariance/openequivariance/extension/generic_module.cpp +++ b/openequivariance/openequivariance/extension/generic_module.cpp @@ -24,7 +24,6 @@ using GroupMM = GroupMMHIP; #endif -#include "buffer.hpp" #include "tensorproducts.hpp" #include "convolution.hpp" @@ -68,12 +67,6 @@ PYBIND11_MODULE(generic_module, m) { .def_readonly("multiprocessorCount", &DeviceProp::multiprocessorCount) .def_readonly("maxSharedMemPerBlock", &DeviceProp::maxSharedMemPerBlock); - py::class_>(m, "DeviceBuffer") - .def(py::init()) - .def(py::init()) - .def("copy_to_host", &PyDeviceBuffer::copy_to_host) - .def("data_ptr", &PyDeviceBuffer::data_ptr); - py::class_(m, "GPUTimer") .def(py::init<>()) .def("start", &GPUTimer::start) diff --git a/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp b/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp index a312201d..b0edf170 100644 --- a/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp +++ b/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp @@ -38,7 +38,6 @@ } #endif -#include "buffer.hpp" #include "tensorproducts.hpp" #include "convolution.hpp" diff --git a/openequivariance/openequivariance/extension/util/buffer.hpp b/openequivariance/openequivariance/extension/util/buffer.hpp deleted file mode 100644 index 95dc8319..00000000 --- a/openequivariance/openequivariance/extension/util/buffer.hpp +++ /dev/null @@ -1,45 +0,0 @@ -#pragma once -#include -#include - -using 
namespace std; -namespace py = pybind11; - -template -class PyDeviceBuffer { -public: - char* host_ptr; - char* device_ptr; - size_t size; - - PyDeviceBuffer(uint64_t size) { - this->size = size; - device_ptr = static_cast(ALLOC_T::gpu_alloc(size)); - host_ptr = nullptr; - } - - PyDeviceBuffer(py::buffer host_data) { - const py::buffer_info &info = host_data.request(); - host_ptr = static_cast(info.ptr); - size = 1; - for(int64_t i = 0; i < info.ndim; i++) { - size *= info.shape[i]; - } - size *= info.itemsize; - - device_ptr = static_cast(ALLOC_T::gpu_alloc(size)); - ALLOC_T::copy_host_to_device(host_ptr, device_ptr, size); - } - - ~PyDeviceBuffer() { - ALLOC_T::gpu_free(static_cast(device_ptr)); - } - - void copy_to_host() { - ALLOC_T::copy_device_to_host(host_ptr, device_ptr, size); - } - - uint64_t data_ptr() { - return reinterpret_cast(device_ptr); - } -}; \ No newline at end of file diff --git a/openequivariance_extjax/CMakeLists.txt b/openequivariance_extjax/CMakeLists.txt index 91617d94..90eafe6c 100644 --- a/openequivariance_extjax/CMakeLists.txt +++ b/openequivariance_extjax/CMakeLists.txt @@ -60,7 +60,6 @@ set(OEQ_JAX_HEADERS ${HEADER_DIR}/tensorproducts.hpp ${HEADER_DIR}/util/backend_cuda.hpp ${HEADER_DIR}/util/backend_hip.hpp - ${HEADER_DIR}/util/buffer.hpp ${HEADER_DIR}/json11/json11.hpp ) From 1650d652c59d587e51977f516a067e1f45ab33ad Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Tue, 3 Feb 2026 23:29:48 -0800 Subject: [PATCH 07/15] Fixed the fakes. --- .../openequivariance/_torch/TensorProduct.py | 138 ++++---- .../_torch/TensorProductConv.py | 304 +++++++++--------- .../openequivariance/core/LoopUnrollTP.py | 4 +- 3 files changed, 222 insertions(+), 224 deletions(-) diff --git a/openequivariance/openequivariance/_torch/TensorProduct.py b/openequivariance/openequivariance/_torch/TensorProduct.py index d1336b1a..bce42bdd 100644 --- a/openequivariance/openequivariance/_torch/TensorProduct.py +++ b/openequivariance/openequivariance/_torch/TensorProduct.py @@ -8,6 +8,7 @@ from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin import numpy as np +import json logger = getLogger() @@ -105,69 +106,6 @@ def forward( return torch.ops.libtorch_tp_jit.jit_tp_forward(self.kernel, self.hash, x, y, W) - @classmethod - def register_torch_fakes(cls): - @torch.library.register_fake("libtorch_tp_jit::jit_tp_forward") - def fake_forward(jit, L1_in, L2_in, W): - L3_dim = None - if hasattr(jit, "wrapped_obj"): - L3_dim = jit.wrapped_obj.kernel_dims["L3_dim"] - else: - L3_dim = jit.L3_dim - - return L1_in.new_empty(L1_in.shape[0], L3_dim) - - @torch.library.register_fake("libtorch_tp_jit::jit_tp_backward") - def fake_backward(jit, L1_in, L2_in, W, L3_grad): - return torch.empty_like(L1_in), torch.empty_like(L2_in), torch.empty_like(W) - - @classmethod - def register_autograd(cls): - backward_op = torch.ops.libtorch_tp_jit.jit_tp_backward - - def setup_context(ctx, inputs, output): - ctx.kernel, ctx.hash, ctx.L1_in, ctx.L2_in, ctx.weights = inputs - - def backward(ctx, grad_output): - L1_grad, L2_grad, W_grad = backward_op( - ctx.kernel, ctx.hash, ctx.L1_in, ctx.L2_in, ctx.weights, grad_output - ) - return None, None, L1_grad, L2_grad, W_grad - - torch.library.register_autograd( - "libtorch_tp_jit::jit_tp_forward", backward, setup_context=setup_context - ) - - def setup_context_double_backward(ctx, inputs, output): - ctx.kernel, ctx.hash, ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad = inputs - - def double_backward(ctx, E, F, G): - result = 
torch.ops.libtorch_tp_jit.jit_tp_double_backward( - ctx.kernel, ctx.hash, ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad, E, F, G - ) - return None, None, result[0], result[1], result[2], result[3] - - torch.library.register_autograd( - "libtorch_tp_jit::jit_tp_backward", - double_backward, - setup_context=setup_context_double_backward, - ) - - @classmethod - def register_autocast(cls): - global torch - import torch - - torch.library.register_autocast( - "libtorch_tp_jit::jit_tp_forward", "cuda", torch.float32 - ) - torch.library.register_autocast( - "libtorch_tp_jit::jit_tp_backward", "cuda", torch.float32 - ) - torch.library.register_autocast( - "libtorch_tp_jit::jit_tp_double_backward", "cuda", torch.float32 - ) - @staticmethod def name(): return "LoopUnrollTP" @@ -216,7 +154,73 @@ def backward_cpu( ) -if extlib.TORCH_COMPILE: - TensorProduct.register_torch_fakes() - TensorProduct.register_autograd() - TensorProduct.register_autocast() +def register_torch_fakes(): + @torch.library.register_fake("libtorch_tp_jit::jit_tp_forward") + def fake_forward(kernel, hash, L1_in, L2_in, W): + info = json.loads(kernel) + L3_dim = info["kernel_prop"]["L3_dim"] + + return L1_in.new_empty(L1_in.shape[0], L3_dim) + + @torch.library.register_fake("libtorch_tp_jit::jit_tp_backward") + def fake_backward(kernel, hash, L1_in, L2_in, W, L3_grad): + return torch.empty_like(L1_in), torch.empty_like(L2_in), torch.empty_like(W) + + @torch.library.register_fake("libtorch_tp_jit::jit_tp_double_backward") + def fake_double_backward(kernel, hash, L1_in, L2_in, W, L3_grad, E, F, G): + return ( + torch.empty_like(L1_in), + torch.empty_like(L2_in), + torch.empty_like(W), + torch.empty_like(L3_grad), + ) + + +def register_autograd(): + backward_op = torch.ops.libtorch_tp_jit.jit_tp_backward + + def setup_context(ctx, inputs, output): + ctx.kernel, ctx.hash, ctx.L1_in, ctx.L2_in, ctx.weights = inputs + + def backward(ctx, grad_output): + L1_grad, L2_grad, W_grad = backward_op( + ctx.kernel, ctx.hash, ctx.L1_in, ctx.L2_in, ctx.weights, grad_output + ) + return None, None, L1_grad, L2_grad, W_grad + + torch.library.register_autograd( + "libtorch_tp_jit::jit_tp_forward", backward, setup_context=setup_context + ) + + def setup_context_double_backward(ctx, inputs, output): + ctx.kernel, ctx.hash, ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad = inputs + + def double_backward(ctx, E, F, G): + result = torch.ops.libtorch_tp_jit.jit_tp_double_backward( + ctx.kernel, ctx.hash, ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad, E, F, G + ) + return None, None, result[0], result[1], result[2], result[3] + + torch.library.register_autograd( + "libtorch_tp_jit::jit_tp_backward", + double_backward, + setup_context=setup_context_double_backward, + ) + +def register_autocast(): + global torch + import torch + + torch.library.register_autocast( + "libtorch_tp_jit::jit_tp_forward", "cuda", torch.float32 + ) + torch.library.register_autocast( + "libtorch_tp_jit::jit_tp_backward", "cuda", torch.float32 + ) + torch.library.register_autocast( + "libtorch_tp_jit::jit_tp_double_backward", "cuda", torch.float32 + ) + +register_torch_fakes() +register_autograd() +register_autocast() \ No newline at end of file diff --git a/openequivariance/openequivariance/_torch/TensorProductConv.py b/openequivariance/openequivariance/_torch/TensorProductConv.py index c67ca3e9..d5f9cd3c 100644 --- a/openequivariance/openequivariance/_torch/TensorProductConv.py +++ b/openequivariance/openequivariance/_torch/TensorProductConv.py @@ -2,6 +2,7 @@ import numpy as np 
import torch +import json import openequivariance._torch.extlib as extlib from openequivariance._torch.extlib import ( @@ -179,157 +180,6 @@ def reorder_weights_to_e3nn(self, weights, has_batch_dim=True): def name(): return "LoopUnrollConv" - @classmethod - def register_torch_fakes(cls): - global torch - import torch - - @torch.library.register_fake("libtorch_tp_jit::jit_conv_forward") - def fake_forward( - kernel, hash, L1_in, L2_in, W, rows, cols, workspace_buffer, sender_perm - ): - L3_dim, irrep_dtype = None, None - if hasattr(jit, "wrapped_obj"): - L3_dim = jit.wrapped_obj.kernel_dims["L3_dim"] - irrep_dtype = jit.wrapped_obj.kernel_dims["irrep_dtype"] - else: - L3_dim = jit.L3_dim - irrep_dtype = jit.irrep_dtype - - return torch.empty( - L1_in.shape[0], - L3_dim, - device="cuda", - dtype=enum_to_torch_dtype[irrep_dtype], - ) - - @torch.library.register_fake("libtorch_tp_jit::jit_conv_backward") - def fake_backward( - kernel, hash, L1_in, L2_in, W, L3_grad, rows, cols, workspace_buffer, sender_perm - ): - return torch.empty_like(L1_in), torch.empty_like(L2_in), torch.empty_like(W) - - @torch.library.register_fake("libtorch_tp_jit::jit_conv_double_backward") - def fake_double_backward( - kernel, hash, - L1_in, - L2_in, - W, - L3_grad, - L1_dgrad, - L2_dgrad, - w_dgrad, - rows, - cols, - workspace_buffer, - transpose_perm=None, - ): - return [ - L1_in.new_empty(*L1_in.shape), - L2_in.new_empty(*L2_in.shape), - W.new_empty(*W.shape), - L3_grad.new_empty(*L3_grad.shape), - ] - - @classmethod - def register_autograd(cls): - backward_op = torch.ops.libtorch_tp_jit.jit_conv_backward - double_backward_op = torch.ops.libtorch_tp_jit.jit_conv_double_backward - - def setup_context(ctx, inputs, output): - ( - ctx.kernel, - ctx.hash, - ctx.L1_in, - ctx.L2_in, - ctx.W, - ctx.rows, - ctx.cols, - ctx.workspace_buffer, - ctx.sender_perm, - ) = inputs - - def backward(ctx, grad_output): - L1_grad, L2_grad, W_grad = backward_op( - ctx.kernel, - ctx.hash, - ctx.L1_in, - ctx.L2_in, - ctx.W, - grad_output, - ctx.rows, - ctx.cols, - ctx.workspace_buffer, - ctx.sender_perm, - ) - return None, None, L1_grad, L2_grad, W_grad, None, None, None, None - - torch.library.register_autograd( - "libtorch_tp_jit::jit_conv_forward", backward, setup_context=setup_context - ) - - def setup_context_double_backward(ctx, inputs, output): - ( - ctx.kernel, - ctx.hash, - ctx.L1_in, - ctx.L2_in, - ctx.W, - ctx.grad_output, - ctx.rows, - ctx.cols, - ctx.workspace_buffer, - ctx.sender_perm, - ) = inputs - ctx.inputs = inputs - - def double_backward(ctx, E, F, G): - result = double_backward_op( - ctx.kernel, - ctx.hash, - ctx.L1_in, - ctx.L2_in, - ctx.W, - ctx.grad_output, - E, - F, - G, - ctx.rows, - ctx.cols, - ctx.workspace_buffer, - ctx.sender_perm, - ) - return ( - None, - None, - result[0], - result[1], - result[2], - result[3], - None, - None, - None, - None, - ) - - torch.library.register_autograd( - "libtorch_tp_jit::jit_conv_backward", - double_backward, - setup_context=setup_context_double_backward, - ) - - @classmethod - def register_autocast(cls): - torch.library.register_autocast( - "libtorch_tp_jit::jit_conv_forward", "cuda", torch.float32 - ) - torch.library.register_autocast( - "libtorch_tp_jit::jit_conv_backward", "cuda", torch.float32 - ) - torch.library.register_autocast( - "libtorch_tp_jit::jit_conv_double_backward", "cuda", torch.float32 - ) - def forward_cpu(self, L1_in, L2_in, weights, L3_out, graph): assert graph.rows.dtype == self.idx_dtype assert graph.cols.dtype == self.idx_dtype @@ -403,10 +253,154 @@ 
def backward_cpu( return L1_grad, L2_grad, weights_grad -if extlib.TORCH_COMPILE: - TensorProductConv.register_torch_fakes() - TensorProductConv.register_autograd() - TensorProductConv.register_autocast() +def register_torch_fakes(): + global torch + import torch + + @torch.library.register_fake("libtorch_tp_jit::jit_conv_forward") + def fake_forward( + kernel, hash, L1_in, L2_in, W, rows, cols, workspace_buffer, sender_perm + ): + info = json.loads(kernel) + L3_dim = info["kernel_prop"]["L3_dim"] + irrep_dtype = info["kernel_prop"]["irrep_dtype"] + + return torch.empty( + L1_in.shape[0], + L3_dim, + device="cuda", + dtype=enum_to_torch_dtype[irrep_dtype], + ) + + @torch.library.register_fake("libtorch_tp_jit::jit_conv_backward") + def fake_backward( + kernel, hash, L1_in, L2_in, W, L3_grad, rows, cols, workspace_buffer, sender_perm + ): + return torch.empty_like(L1_in), torch.empty_like(L2_in), torch.empty_like(W) + + @torch.library.register_fake("libtorch_tp_jit::jit_conv_double_backward") + def fake_double_backward( + kernel, hash, + L1_in, + L2_in, + W, + L3_grad, + L1_dgrad, + L2_dgrad, + w_dgrad, + rows, + cols, + workspace_buffer, + transpose_perm=None, + ): + return [ + L1_in.new_empty(*L1_in.shape), + L2_in.new_empty(*L2_in.shape), + W.new_empty(*W.shape), + L3_grad.new_empty(*L3_grad.shape), + ] + +def register_autograd(): + backward_op = torch.ops.libtorch_tp_jit.jit_conv_backward + double_backward_op = torch.ops.libtorch_tp_jit.jit_conv_double_backward + + def setup_context(ctx, inputs, output): + ( + ctx.kernel, + ctx.hash, + ctx.L1_in, + ctx.L2_in, + ctx.W, + ctx.rows, + ctx.cols, + ctx.workspace_buffer, + ctx.sender_perm, + ) = inputs + + def backward(ctx, grad_output): + L1_grad, L2_grad, W_grad = backward_op( + ctx.kernel, + ctx.hash, + ctx.L1_in, + ctx.L2_in, + ctx.W, + grad_output, + ctx.rows, + ctx.cols, + ctx.workspace_buffer, + ctx.sender_perm, + ) + return None, None, L1_grad, L2_grad, W_grad, None, None, None, None + + torch.library.register_autograd( + "libtorch_tp_jit::jit_conv_forward", backward, setup_context=setup_context + ) + + def setup_context_double_backward(ctx, inputs, output): + ( + ctx.kernel, + ctx.hash, + ctx.L1_in, + ctx.L2_in, + ctx.W, + ctx.grad_output, + ctx.rows, + ctx.cols, + ctx.workspace_buffer, + ctx.sender_perm, + ) = inputs + ctx.inputs = inputs + + def double_backward(ctx, E, F, G): + result = double_backward_op( + ctx.kernel, + ctx.hash, + ctx.L1_in, + ctx.L2_in, + ctx.W, + ctx.grad_output, + E, + F, + G, + ctx.rows, + ctx.cols, + ctx.workspace_buffer, + ctx.sender_perm, + ) + return ( + None, + None, + result[0], + result[1], + result[2], + result[3], + None, + None, + None, + None, + ) + + torch.library.register_autograd( + "libtorch_tp_jit::jit_conv_backward", + double_backward, + setup_context=setup_context_double_backward, + ) + +def register_autocast(): + torch.library.register_autocast( + "libtorch_tp_jit::jit_conv_forward", "cuda", torch.float32 + ) + torch.library.register_autocast( + "libtorch_tp_jit::jit_conv_backward", "cuda", torch.float32 + ) + torch.library.register_autocast( + "libtorch_tp_jit::jit_conv_double_backward", "cuda", torch.float32 + ) + + +register_torch_fakes() +register_autograd() +register_autocast() # ================================================================== diff --git a/openequivariance/openequivariance/core/LoopUnrollTP.py b/openequivariance/openequivariance/core/LoopUnrollTP.py index 54aab56c..30efcbbb 100644 --- a/openequivariance/openequivariance/core/LoopUnrollTP.py +++ 
b/openequivariance/openequivariance/core/LoopUnrollTP.py @@ -94,7 +94,7 @@ def generate_double_backward_schedule(warps_per_block): ) ) - self.kernelProp = { + self.kernel_prop = { "L1_dim": self.L1.dim, "L2_dim": self.L2.dim, "L3_dim": self.L3.dim, @@ -117,7 +117,7 @@ def generate_double_backward_schedule(warps_per_block): "double_backward_config": vars( self.double_backward_schedule.launch_config ), - "kernel_prop": self.kernelProp, + "kernel_prop": self.kernel_prop, } ) self.hash = hash_str_64(self.kernel_string) From c49877602d0ca8290799ddf3d40552f804d68f23 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Wed, 4 Feb 2026 21:50:41 -0800 Subject: [PATCH 08/15] Fixing compile tests. --- .../openequivariance/_torch/TensorProduct.py | 15 ++++++--------- .../openequivariance/_torch/TensorProductConv.py | 12 +++++------- openequivariance/openequivariance/_torch/utils.py | 8 +++++++- .../extension/libtorch_tp_jit.cpp | 11 +++++++---- 4 files changed, 25 insertions(+), 21 deletions(-) diff --git a/openequivariance/openequivariance/_torch/TensorProduct.py b/openequivariance/openequivariance/_torch/TensorProduct.py index bce42bdd..3b7591b7 100644 --- a/openequivariance/openequivariance/_torch/TensorProduct.py +++ b/openequivariance/openequivariance/_torch/TensorProduct.py @@ -4,7 +4,7 @@ import torch from openequivariance.core.utils import torch_to_oeq_dtype from openequivariance.benchmark.logging_utils import getLogger -from openequivariance._torch.utils import reorder_torch, string_to_tensor +from openequivariance._torch.utils import reorder_torch, string_to_tensor, tensor_to_string from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin import numpy as np @@ -103,7 +103,7 @@ def forward( :return: Tensor of shape ``[batch_size, problem.irreps_out.dim()]``, datatype ``problem.irrep_dtype``. 
""" - return torch.ops.libtorch_tp_jit.jit_tp_forward(self.kernel, self.hash, x, y, W) + return torch.ops.libtorch_tp_jit.jit_tp_forward(self.kernel, self.hash, x, y, W, self.L3.dim) @staticmethod @@ -156,15 +156,12 @@ def backward_cpu( def register_torch_fakes(): @torch.library.register_fake("libtorch_tp_jit::jit_tp_forward") - def fake_forward(kernel, hash, L1_in, L2_in, W): - info = json.loads(kernel) - L3_dim = info["kernel_prop"]["L3_dim"] - + def fake_forward(kernel, hash, L1_in, L2_in, W, L3_dim): return L1_in.new_empty(L1_in.shape[0], L3_dim) @torch.library.register_fake("libtorch_tp_jit::jit_tp_backward") def fake_backward(kernel, hash, L1_in, L2_in, W, L3_grad): - return torch.empty_like(L1_in), torch.empty_like(L2_in), torch.empty_like(W) + return torch.empty_like(L1_in), torch.empty_like(L2_in), torch.empty_like(W) @torch.library.register_fake("libtorch_tp_jit::jit_tp_double_backward") def fake_double_backward(kernel, hash, L1_in, L2_in, W, L3_grad, E, F, G): @@ -180,13 +177,13 @@ def register_autograd(): backward_op = torch.ops.libtorch_tp_jit.jit_tp_backward def setup_context(ctx, inputs, output): - ctx.kernel, ctx.hash, ctx.L1_in, ctx.L2_in, ctx.weights = inputs + ctx.kernel, ctx.hash, ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_dim = inputs def backward(ctx, grad_output): L1_grad, L2_grad, W_grad = backward_op( ctx.kernel, ctx.hash, ctx.L1_in, ctx.L2_in, ctx.weights, grad_output ) - return None, None, L1_grad, L2_grad, W_grad + return None, None, L1_grad, L2_grad, W_grad, None torch.library.register_autograd( "libtorch_tp_jit::jit_tp_forward", backward, setup_context=setup_context diff --git a/openequivariance/openequivariance/_torch/TensorProductConv.py b/openequivariance/openequivariance/_torch/TensorProductConv.py index d5f9cd3c..0082ab68 100644 --- a/openequivariance/openequivariance/_torch/TensorProductConv.py +++ b/openequivariance/openequivariance/_torch/TensorProductConv.py @@ -152,6 +152,7 @@ def forward( X, Y, W, + self.L3.dim, rows, cols, self.workspace_buffer, @@ -259,17 +260,13 @@ def register_torch_fakes(): @torch.library.register_fake("libtorch_tp_jit::jit_conv_forward") def fake_forward( - kernel, hash, L1_in, L2_in, W, rows, cols, workspace_buffer, sender_perm + kernel, hash, L1_in, L2_in, W, L3_dim, rows, cols, workspace_buffer, sender_perm ): - info = json.loads(kernel) - L3_dim = info["kernel_prop"]["L3_dim"] - irrep_dtype = info["kernel_prop"]["irrep_dtype"] - return torch.empty( L1_in.shape[0], L3_dim, device="cuda", - dtype=enum_to_torch_dtype[irrep_dtype], + dtype=L1_in.dtype ) @torch.library.register_fake("libtorch_tp_jit::jit_conv_backward") @@ -311,6 +308,7 @@ def setup_context(ctx, inputs, output): ctx.L1_in, ctx.L2_in, ctx.W, + ctx.L3_dim, ctx.rows, ctx.cols, ctx.workspace_buffer, @@ -330,7 +328,7 @@ def backward(ctx, grad_output): ctx.workspace_buffer, ctx.sender_perm, ) - return None, None, L1_grad, L2_grad, W_grad, None, None, None, None + return None, None, L1_grad, L2_grad, W_grad, None, None, None, None, None torch.library.register_autograd( "libtorch_tp_jit::jit_conv_forward", backward, setup_context=setup_context diff --git a/openequivariance/openequivariance/_torch/utils.py b/openequivariance/openequivariance/_torch/utils.py index cd5e15c4..aa83980e 100644 --- a/openequivariance/openequivariance/_torch/utils.py +++ b/openequivariance/openequivariance/_torch/utils.py @@ -71,4 +71,10 @@ def reorder_torch(schedule, weights_in, direction, has_batch_dim): def string_to_tensor(text: str) -> torch.Tensor: bytes_data = text.encode('utf-8') 
np_bytes = np.frombuffer(bytes_data, dtype=np.uint8) - return torch.from_numpy(np_bytes).clone() + result = torch.from_numpy(np_bytes).clone() + result.requires_grad = False + return result + +def tensor_to_string(tensor: torch.Tensor) -> str: + bytes_data = tensor.numpy().tobytes() + return bytes_data.decode('utf-8') diff --git a/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp b/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp index b0edf170..6216909f 100644 --- a/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp +++ b/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp @@ -7,6 +7,7 @@ #include #include "json11/json11.hpp" +#include #ifdef CUDA_BACKEND #include @@ -230,7 +231,8 @@ torch::Tensor jit_tp_forward( torch::Tensor json_bytes, int64_t hash, torch::Tensor L1_in, torch::Tensor L2_in, - torch::Tensor W) { + torch::Tensor W, + int64_t L3_dim) { auto [jit_kernel, k] = compile_tp_with_caching(json_bytes, hash); Stream stream = get_current_stream(); @@ -378,6 +380,7 @@ torch::Tensor jit_conv_forward( torch::Tensor L1_in, torch::Tensor L2_in, torch::Tensor W, + int64_t L3_dim, torch::Tensor rows, torch::Tensor cols, torch::Tensor workspace, @@ -588,13 +591,13 @@ TORCH_LIBRARY_IMPL(libtorch_tp_jit, CUDA, m) { }; TORCH_LIBRARY(libtorch_tp_jit, m) { - m.def("jit_tp_forward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W) -> Tensor"); + m.def("jit_tp_forward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, int L3_dim) -> Tensor"); m.def("jit_tp_backward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad) -> (Tensor, Tensor, Tensor)"); m.def("jit_tp_double_backward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor L1_dgrad, Tensor L2_dgrad, Tensor W_dgrad) -> (Tensor, Tensor, Tensor, Tensor)"); - m.def("jit_conv_forward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> Tensor"); + m.def("jit_conv_forward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, int L3_dim, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> Tensor"); m.def("jit_conv_backward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor)"); m.def("jit_conv_double_backward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor L1_dgrad, Tensor L2_dgrad, Tensor W_dgrad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor, Tensor)"); -} +}; PYBIND11_MODULE(libtorch_tp_jit, m) {} \ No newline at end of file From 8ee25737059427a25397bd514de323d4cad7291e Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Wed, 4 Feb 2026 23:16:36 -0800 Subject: [PATCH 09/15] Fixed JAX error. 
--- openequivariance/openequivariance/_torch/TensorProduct.py | 5 +++-- .../openequivariance/_torch/TensorProductConv.py | 5 +++-- openequivariance/openequivariance/_torch/utils.py | 8 ++------ 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/openequivariance/openequivariance/_torch/TensorProduct.py b/openequivariance/openequivariance/_torch/TensorProduct.py index 3b7591b7..ca335291 100644 --- a/openequivariance/openequivariance/_torch/TensorProduct.py +++ b/openequivariance/openequivariance/_torch/TensorProduct.py @@ -4,7 +4,7 @@ import torch from openequivariance.core.utils import torch_to_oeq_dtype from openequivariance.benchmark.logging_utils import getLogger -from openequivariance._torch.utils import reorder_torch, string_to_tensor, tensor_to_string +from openequivariance._torch.utils import reorder_torch, string_to_tensor from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin import numpy as np @@ -45,6 +45,7 @@ def _init_class(self): self.input_args["torch_op"], ) + self.L3_dim = self.kernel_prop["L3_dim"] self.kernel= string_to_tensor(self.kernel_string) self.weight_numel = self.input_args["problem"].weight_numel @@ -103,7 +104,7 @@ def forward( :return: Tensor of shape ``[batch_size, problem.irreps_out.dim()]``, datatype ``problem.irrep_dtype``. """ - return torch.ops.libtorch_tp_jit.jit_tp_forward(self.kernel, self.hash, x, y, W, self.L3.dim) + return torch.ops.libtorch_tp_jit.jit_tp_forward(self.kernel, self.hash, x, y, W, self.L3_dim) @staticmethod diff --git a/openequivariance/openequivariance/_torch/TensorProductConv.py b/openequivariance/openequivariance/_torch/TensorProductConv.py index 0082ab68..5257eb22 100644 --- a/openequivariance/openequivariance/_torch/TensorProductConv.py +++ b/openequivariance/openequivariance/_torch/TensorProductConv.py @@ -83,10 +83,11 @@ def _init_class(self): ) self.allocate_workspace(self.workspace_size) + self.dummy_transpose_perm = torch.zeros(1, dtype=torch.int64, device="cuda") self.weight_numel = self.config.weight_numel self.kernel= string_to_tensor(self.kernel_string) - + self.L3_dim = self.kernel_prop["L3_dim"] def to(self, *args, **kwargs): r""" @@ -152,7 +153,7 @@ def forward( X, Y, W, - self.L3.dim, + self.L3_dim, rows, cols, self.workspace_buffer, diff --git a/openequivariance/openequivariance/_torch/utils.py b/openequivariance/openequivariance/_torch/utils.py index aa83980e..1a980206 100644 --- a/openequivariance/openequivariance/_torch/utils.py +++ b/openequivariance/openequivariance/_torch/utils.py @@ -71,10 +71,6 @@ def reorder_torch(schedule, weights_in, direction, has_batch_dim): def string_to_tensor(text: str) -> torch.Tensor: bytes_data = text.encode('utf-8') np_bytes = np.frombuffer(bytes_data, dtype=np.uint8) - result = torch.from_numpy(np_bytes).clone() + result = torch.tensor(np_bytes) result.requires_grad = False - return result - -def tensor_to_string(tensor: torch.Tensor) -> str: - bytes_data = tensor.numpy().tobytes() - return bytes_data.decode('utf-8') + return result \ No newline at end of file From 4fbb518f741934f2e171ca922c9a0d7cf42ce991 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Wed, 4 Feb 2026 23:21:02 -0800 Subject: [PATCH 10/15] Fixed compile issues. 
--- openequivariance/openequivariance/_torch/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openequivariance/openequivariance/_torch/utils.py b/openequivariance/openequivariance/_torch/utils.py index 1a980206..31ffa330 100644 --- a/openequivariance/openequivariance/_torch/utils.py +++ b/openequivariance/openequivariance/_torch/utils.py @@ -71,6 +71,6 @@ def reorder_torch(schedule, weights_in, direction, has_batch_dim): def string_to_tensor(text: str) -> torch.Tensor: bytes_data = text.encode('utf-8') np_bytes = np.frombuffer(bytes_data, dtype=np.uint8) - result = torch.tensor(np_bytes) + result = torch.tensor(np_bytes, device='cpu') result.requires_grad = False return result \ No newline at end of file From 1fb0f85c772b37bdccdf1880f9ea68542382d867 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Wed, 4 Feb 2026 23:27:52 -0800 Subject: [PATCH 11/15] Ruff. --- .../openequivariance/_torch/TensorProduct.py | 29 +++++++++----- .../_torch/TensorProductConv.py | 40 +++++++++++-------- .../openequivariance/_torch/utils.py | 7 ++-- .../openequivariance/core/LoopUnrollConv.py | 10 +++-- .../openequivariance/core/LoopUnrollTP.py | 2 +- .../openequivariance/core/utils.py | 3 +- .../openequivariance/jax/TensorProduct.py | 1 - .../openequivariance/jax/TensorProductConv.py | 1 - tests/export_test.py | 11 +---- 9 files changed, 57 insertions(+), 47 deletions(-) diff --git a/openequivariance/openequivariance/_torch/TensorProduct.py b/openequivariance/openequivariance/_torch/TensorProduct.py index ca335291..3885604f 100644 --- a/openequivariance/openequivariance/_torch/TensorProduct.py +++ b/openequivariance/openequivariance/_torch/TensorProduct.py @@ -4,11 +4,10 @@ import torch from openequivariance.core.utils import torch_to_oeq_dtype from openequivariance.benchmark.logging_utils import getLogger -from openequivariance._torch.utils import reorder_torch, string_to_tensor +from openequivariance._torch.utils import reorder_torch, string_to_tensor from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin import numpy as np -import json logger = getLogger() @@ -23,7 +22,7 @@ class TensorProduct(torch.nn.Module, LoopUnrollTP, NumpyDoubleBackwardMixin): * The provided tensor product specification is unsupported. :param problem: Specification of the tensor product. - :param use_opaque: This parameter is deprecated. + :param use_opaque: This parameter is deprecated. """ def __init__(self, problem: TPProblem, torch_op=True, use_opaque=False): @@ -46,10 +45,9 @@ def _init_class(self): ) self.L3_dim = self.kernel_prop["L3_dim"] - self.kernel= string_to_tensor(self.kernel_string) + self.kernel = string_to_tensor(self.kernel_string) self.weight_numel = self.input_args["problem"].weight_numel - def to(self, *args, **kwargs): r""" See `torch.nn.Module.to() `_. @@ -104,8 +102,9 @@ def forward( :return: Tensor of shape ``[batch_size, problem.irreps_out.dim()]``, datatype ``problem.irrep_dtype``. 
""" - return torch.ops.libtorch_tp_jit.jit_tp_forward(self.kernel, self.hash, x, y, W, self.L3_dim) - + return torch.ops.libtorch_tp_jit.jit_tp_forward( + self.kernel, self.hash, x, y, W, self.L3_dim + ) @staticmethod def name(): @@ -162,7 +161,7 @@ def fake_forward(kernel, hash, L1_in, L2_in, W, L3_dim): @torch.library.register_fake("libtorch_tp_jit::jit_tp_backward") def fake_backward(kernel, hash, L1_in, L2_in, W, L3_grad): - return torch.empty_like(L1_in), torch.empty_like(L2_in), torch.empty_like(W) + return torch.empty_like(L1_in), torch.empty_like(L2_in), torch.empty_like(W) @torch.library.register_fake("libtorch_tp_jit::jit_tp_double_backward") def fake_double_backward(kernel, hash, L1_in, L2_in, W, L3_grad, E, F, G): @@ -195,7 +194,15 @@ def setup_context_double_backward(ctx, inputs, output): def double_backward(ctx, E, F, G): result = torch.ops.libtorch_tp_jit.jit_tp_double_backward( - ctx.kernel, ctx.hash, ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad, E, F, G + ctx.kernel, + ctx.hash, + ctx.L1_in, + ctx.L2_in, + ctx.weights, + ctx.L3_grad, + E, + F, + G, ) return None, None, result[0], result[1], result[2], result[3] @@ -205,6 +212,7 @@ def double_backward(ctx, E, F, G): setup_context=setup_context_double_backward, ) + def register_autocast(): global torch import torch @@ -219,6 +227,7 @@ def register_autocast(): "libtorch_tp_jit::jit_tp_double_backward", "cuda", torch.float32 ) + register_torch_fakes() register_autograd() -register_autocast() \ No newline at end of file +register_autocast() diff --git a/openequivariance/openequivariance/_torch/TensorProductConv.py b/openequivariance/openequivariance/_torch/TensorProductConv.py index 5257eb22..5788f2f0 100644 --- a/openequivariance/openequivariance/_torch/TensorProductConv.py +++ b/openequivariance/openequivariance/_torch/TensorProductConv.py @@ -2,9 +2,7 @@ import numpy as np import torch -import json -import openequivariance._torch.extlib as extlib from openequivariance._torch.extlib import ( postprocess_kernel, DeviceProp, @@ -18,7 +16,10 @@ from openequivariance._torch.TensorProduct import TensorProduct from openequivariance import TPProblem from openequivariance.core.utils import torch_to_oeq_dtype -from openequivariance._torch.utils import enum_to_torch_dtype, reorder_torch, string_to_tensor +from openequivariance._torch.utils import ( + reorder_torch, + string_to_tensor, +) from openequivariance.benchmark.logging_utils import getLogger from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixinConv @@ -47,7 +48,7 @@ class TensorProductConv(torch.nn.Module, LoopUnrollConv, NumpyDoubleBackwardMixi fixup-based algorithm. `Default`: ``False``. :param kahan: If ``True``, uses Kahan summation to improve accuracy during aggregation. To use this option, the input tensors must be in float32 precision AND you must set ``deterministic=True``. *Default*: ``False``. - :param use_opaque: This parameter is deprecated. + :param use_opaque: This parameter is deprecated. 
""" def __init__( @@ -86,7 +87,7 @@ def _init_class(self): self.dummy_transpose_perm = torch.zeros(1, dtype=torch.int64, device="cuda") self.weight_numel = self.config.weight_numel - self.kernel= string_to_tensor(self.kernel_string) + self.kernel = string_to_tensor(self.kernel_string) self.L3_dim = self.kernel_prop["L3_dim"] def to(self, *args, **kwargs): @@ -223,9 +224,7 @@ def backward_cpu( torch_L1_in = torch.tensor(L1_in, requires_grad=True, device="cuda") torch_L2_in = torch.tensor(L2_in, requires_grad=True, device="cuda") - torch_weights = torch.tensor( - weights_chunked, requires_grad=True, device="cuda" - ) + torch_weights = torch.tensor(weights_chunked, requires_grad=True, device="cuda") torch_L3_grad = torch.tensor(L3_grad, device="cuda") torch_rows = torch.tensor(graph.rows, device="cuda") torch_cols = torch.tensor(graph.cols, device="cuda") @@ -234,7 +233,7 @@ def backward_cpu( torch_sender_perm = torch.tensor(graph.transpose_perm, device="cuda") else: torch_sender_perm = None - + torch_out = self.forward( torch_L1_in, torch_L2_in, @@ -263,22 +262,27 @@ def register_torch_fakes(): def fake_forward( kernel, hash, L1_in, L2_in, W, L3_dim, rows, cols, workspace_buffer, sender_perm ): - return torch.empty( - L1_in.shape[0], - L3_dim, - device="cuda", - dtype=L1_in.dtype - ) + return torch.empty(L1_in.shape[0], L3_dim, device="cuda", dtype=L1_in.dtype) @torch.library.register_fake("libtorch_tp_jit::jit_conv_backward") def fake_backward( - kernel, hash, L1_in, L2_in, W, L3_grad, rows, cols, workspace_buffer, sender_perm + kernel, + hash, + L1_in, + L2_in, + W, + L3_grad, + rows, + cols, + workspace_buffer, + sender_perm, ): return torch.empty_like(L1_in), torch.empty_like(L2_in), torch.empty_like(W) @torch.library.register_fake("libtorch_tp_jit::jit_conv_double_backward") def fake_double_backward( - kernel, hash, + kernel, + hash, L1_in, L2_in, W, @@ -298,6 +302,7 @@ def fake_double_backward( L3_grad.new_empty(*L3_grad.shape), ] + def register_autograd(): backward_op = torch.ops.libtorch_tp_jit.jit_conv_backward double_backward_op = torch.ops.libtorch_tp_jit.jit_conv_double_backward @@ -385,6 +390,7 @@ def double_backward(ctx, E, F, G): setup_context=setup_context_double_backward, ) + def register_autocast(): torch.library.register_autocast( "libtorch_tp_jit::jit_conv_forward", "cuda", torch.float32 diff --git a/openequivariance/openequivariance/_torch/utils.py b/openequivariance/openequivariance/_torch/utils.py index 31ffa330..74d5a010 100644 --- a/openequivariance/openequivariance/_torch/utils.py +++ b/openequivariance/openequivariance/_torch/utils.py @@ -68,9 +68,10 @@ def reorder_torch(schedule, weights_in, direction, has_batch_dim): } ) + def string_to_tensor(text: str) -> torch.Tensor: - bytes_data = text.encode('utf-8') + bytes_data = text.encode("utf-8") np_bytes = np.frombuffer(bytes_data, dtype=np.uint8) - result = torch.tensor(np_bytes, device='cpu') + result = torch.tensor(np_bytes, device="cpu") result.requires_grad = False - return result \ No newline at end of file + return result diff --git a/openequivariance/openequivariance/core/LoopUnrollConv.py b/openequivariance/openequivariance/core/LoopUnrollConv.py index e633bd4f..ca8b4bdd 100644 --- a/openequivariance/openequivariance/core/LoopUnrollConv.py +++ b/openequivariance/openequivariance/core/LoopUnrollConv.py @@ -7,9 +7,13 @@ SMEMCapacityException, ) -from openequivariance.core.utils import dtype_to_enum from openequivariance.templates.jinja_utils import get_jinja_environment -from openequivariance.core.utils import 
filter_and_analyze_problem, dtype_to_enum, hash_str_64 +from openequivariance.core.utils import ( + filter_and_analyze_problem, + dtype_to_enum, + hash_str_64, +) + class LoopUnrollConv(ConvolutionBase): def __init__( @@ -214,4 +218,4 @@ def generate_double_backward_schedule(warps_per_block): "kernel_prop": self.kernel_prop, } ) - self.hash = hash_str_64(self.kernel_string) \ No newline at end of file + self.hash = hash_str_64(self.kernel_string) diff --git a/openequivariance/openequivariance/core/LoopUnrollTP.py b/openequivariance/openequivariance/core/LoopUnrollTP.py index 30efcbbb..41354e5f 100644 --- a/openequivariance/openequivariance/core/LoopUnrollTP.py +++ b/openequivariance/openequivariance/core/LoopUnrollTP.py @@ -14,6 +14,7 @@ logger = getLogger() + class LoopUnrollTP(TensorProductBase): def __init__(self, config, dp, postprocess_kernel, torch_op): super().__init__(config, torch_op=torch_op) @@ -123,7 +124,6 @@ def generate_double_backward_schedule(warps_per_block): self.hash = hash_str_64(self.kernel_string) logger.info(f"Kernel File Size: {len(self.jit_kernel) // 1024} KB") - def calculate_flops_forward(self, batch_size: int) -> dict: if self.is_uvw: return super().calculate_flops_forward(batch_size) diff --git a/openequivariance/openequivariance/core/utils.py b/openequivariance/openequivariance/core/utils.py index 5e2b1901..1950013d 100644 --- a/openequivariance/openequivariance/core/utils.py +++ b/openequivariance/openequivariance/core/utils.py @@ -201,5 +201,6 @@ def benchmark(func, num_warmup, num_iter, mode="gpu_time", kernel_names=[]): return time_millis + def hash_str_64(s: str) -> int: - return int.from_bytes(hashlib.sha256(s.encode()).digest()[:7], 'big') \ No newline at end of file + return int.from_bytes(hashlib.sha256(s.encode()).digest()[:7], "big") diff --git a/openequivariance/openequivariance/jax/TensorProduct.py b/openequivariance/openequivariance/jax/TensorProduct.py index 4f273b8a..84d75e10 100644 --- a/openequivariance/openequivariance/jax/TensorProduct.py +++ b/openequivariance/openequivariance/jax/TensorProduct.py @@ -5,7 +5,6 @@ from openequivariance.core.LoopUnrollTP import LoopUnrollTP from openequivariance.jax.utils import reorder_jax from openequivariance.jax.jvp.tp_prim import tp_fwd_p -import json class TensorProduct(LoopUnrollTP): diff --git a/openequivariance/openequivariance/jax/TensorProductConv.py b/openequivariance/openequivariance/jax/TensorProductConv.py index 84fc8a30..ce36e0c2 100644 --- a/openequivariance/openequivariance/jax/TensorProductConv.py +++ b/openequivariance/openequivariance/jax/TensorProductConv.py @@ -1,5 +1,4 @@ import jax -import json import jax.numpy as jnp import numpy as np from typing import Optional diff --git a/tests/export_test.py b/tests/export_test.py index 0e868ee3..efdaf865 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -1,17 +1,10 @@ -import shutil import torch import pytest import tempfile -import subprocess -import os -import sys import numpy as np import openequivariance as oeq from torch_geometric import EdgeIndex -import importlib.resources - -from openequivariance._torch.E3NNTensorProduct import E3NNTensorProduct @pytest.fixture(scope="session") @@ -128,9 +121,7 @@ def test_aoti(tp_and_inputs): exported_tp, package_path=tmp_file.name ) except Exception as e: - err_msg = ( - f"AOTI compile_and_package failed. Error: {e}" - ) + err_msg = f"AOTI compile_and_package failed. 
Error: {e}" assert False, err_msg aoti_model = torch._inductor.aoti_load_package(output_path) From 046c0714b0dc35fd37603b0f623cec7ef1f5a018 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 8 Feb 2026 21:38:02 -0800 Subject: [PATCH 12/15] Eliminated raw pointer functions and generic classes. --- .../_torch/extlib/__init__.py | 8 -- .../extension/convolution.hpp | 84 ------------------- .../extension/generic_module.cpp | 20 ----- .../extension/tensorproducts.hpp | 28 +------ 4 files changed, 1 insertion(+), 139 deletions(-) diff --git a/openequivariance/openequivariance/_torch/extlib/__init__.py b/openequivariance/openequivariance/_torch/extlib/__init__.py index 38bc211f..a7b4b865 100644 --- a/openequivariance/openequivariance/_torch/extlib/__init__.py +++ b/openequivariance/openequivariance/_torch/extlib/__init__.py @@ -149,8 +149,6 @@ def torch_ext_so_path(): if BUILT_EXTENSION: from generic_module import ( - JITTPImpl, - JITConvImpl, GroupMM_F32, GroupMM_F64, DeviceProp, @@ -158,12 +156,6 @@ def torch_ext_so_path(): ) else: - def JITTPImpl(*args, **kwargs): - _raise_import_error_helper("JITTPImpl") - - def JITConvImpl(*args, **kwargs): - _raise_import_error_helper("JITConvImpl") - def GroupMM_F32(*args, **kwargs): _raise_import_error_helper("GroupMM_F32") diff --git a/openequivariance/openequivariance/extension/convolution.hpp b/openequivariance/openequivariance/extension/convolution.hpp index 92aa6880..83ad58b4 100644 --- a/openequivariance/openequivariance/extension/convolution.hpp +++ b/openequivariance/openequivariance/extension/convolution.hpp @@ -176,88 +176,4 @@ class __attribute__ ((visibility ("default"))) JITConvImpl { } ~JITConvImpl() = default; - - // Integer pointer versions of the functions above - - void exec_conv_rawptrs( - uint64_t L1_in, - uint64_t L2_in, - uint64_t weights, - uint64_t L3_out, - uint64_t rows, - uint64_t cols, - uint64_t nnz, - uint64_t node_count, - uint64_t workspace) { - - exec_conv( - reinterpret_cast(L1_in), - reinterpret_cast(L2_in), - reinterpret_cast(weights), - reinterpret_cast(L3_out), - reinterpret_cast(rows), - reinterpret_cast(cols), - nnz, - node_count, - reinterpret_cast(workspace), - 0 // Default Stream - ); - } - - void backward_rawptrs( - uint64_t L1_in, uint64_t L1_grad, - uint64_t L2_in, uint64_t L2_grad, - uint64_t weight, uint64_t weight_grad, - uint64_t L3_grad, - uint64_t rows, uint64_t cols, - uint64_t nnz, uint64_t node_count, - uint64_t workspace, uint64_t inverse_perm) { - - backward( - reinterpret_cast(L1_in), - reinterpret_cast(L1_grad), - reinterpret_cast(L2_in), - reinterpret_cast(L2_grad), - reinterpret_cast(weight), - reinterpret_cast(weight_grad), - reinterpret_cast(L3_grad), - reinterpret_cast(rows), - reinterpret_cast(cols), - nnz, - node_count, - reinterpret_cast(workspace), - reinterpret_cast(inverse_perm), - 0 // Default Stream - ); - } - - void double_backward_rawptrs( - uint64_t L1_in, uint64_t L2_in, uint64_t W, uint64_t L3_grad, - uint64_t L1_dgrad, uint64_t L2_dgrad, uint64_t w_dgrad, - uint64_t L1_grad, uint64_t L2_grad, uint64_t W_grad, uint64_t L3_dgrad, - uint64_t rows, uint64_t cols, - uint64_t nnz, uint64_t node_count, - uint64_t wspace, uint64_t transpose_perm) { - - double_backward( - reinterpret_cast(L1_in), - reinterpret_cast(L2_in), - reinterpret_cast(W), - reinterpret_cast(L3_grad), - reinterpret_cast(L1_dgrad), - reinterpret_cast(L2_dgrad), - reinterpret_cast(w_dgrad), - reinterpret_cast(L1_grad), - reinterpret_cast(L2_grad), - reinterpret_cast(W_grad), - reinterpret_cast(L3_dgrad), - 
reinterpret_cast(rows), - reinterpret_cast(cols), - nnz, - node_count, - reinterpret_cast(wspace), - reinterpret_cast(transpose_perm), - 0 - ); - } }; \ No newline at end of file diff --git a/openequivariance/openequivariance/extension/generic_module.cpp b/openequivariance/openequivariance/extension/generic_module.cpp index 2b83414e..b0996991 100644 --- a/openequivariance/openequivariance/extension/generic_module.cpp +++ b/openequivariance/openequivariance/extension/generic_module.cpp @@ -31,26 +31,6 @@ using namespace std; namespace py=pybind11; PYBIND11_MODULE(generic_module, m) { - //=========== Batch tensor products ========= - py::class_>(m, "JITTPImpl") - .def(py::init< std::string, - std::unordered_map, - std::unordered_map, - std::unordered_map, - std::unordered_map>()) - .def("exec_tensor_product_rawptr", &JITTPImpl::exec_tensor_product_device_rawptrs) - .def("backward_rawptr", &JITTPImpl::backward_device_rawptrs); - - py::class_>(m, "JITConvImpl") - .def(py::init< std::string, - std::unordered_map, - std::unordered_map, - std::unordered_map, - std::unordered_map>()) - .def("exec_conv_rawptrs", &JITConvImpl::exec_conv_rawptrs) - .def("backward_rawptrs", &JITConvImpl::backward_rawptrs) - .def("double_backward_rawptrs", &JITConvImpl::double_backward_rawptrs); - py::class_>(m, "GroupMM_F32") .def(py::init()) .def("group_gemm", &GroupMM::group_gemm_intptr); diff --git a/openequivariance/openequivariance/extension/tensorproducts.hpp b/openequivariance/openequivariance/extension/tensorproducts.hpp index ee8def66..b4b4d84b 100644 --- a/openequivariance/openequivariance/extension/tensorproducts.hpp +++ b/openequivariance/openequivariance/extension/tensorproducts.hpp @@ -109,31 +109,5 @@ class __attribute__ ((visibility ("default"))) JITTPImpl { jit.execute(3, args, with_stream(double_backward_config_ref, stream)); } - ~JITTPImpl() = default; - - // Integer pointer versions of the functions above - void exec_tensor_product_device_rawptrs(uint64_t num_products, - uint64_t L1_in, uint64_t L2_in, uint64_t L3_out, uint64_t weights) { - exec_tensor_product(num_products, - reinterpret_cast(L1_in), - reinterpret_cast(L2_in), - reinterpret_cast(L3_out), - reinterpret_cast(weights), - 0 // Default Stream - ); - } - - void backward_device_rawptrs(uint64_t num_products, - uint64_t L1_in, uint64_t L1_grad, - uint64_t L2_in, uint64_t L2_grad, - uint64_t weight, uint64_t weight_grad, - uint64_t L3_grad) { - - backward(num_products, - reinterpret_cast(L1_in), reinterpret_cast(L1_grad), - reinterpret_cast(L2_in), reinterpret_cast(L2_grad), - reinterpret_cast(weight), reinterpret_cast(weight_grad), - reinterpret_cast(L3_grad), 0 // Null = Default Stream - ); - } + ~JITTPImpl() = default; }; \ No newline at end of file From a930094a0c568784ab958baf0b0069132e6ef47a Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Sun, 8 Feb 2026 21:44:35 -0800 Subject: [PATCH 13/15] Ruff. --- openequivariance/openequivariance/core/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/openequivariance/openequivariance/core/utils.py b/openequivariance/openequivariance/core/utils.py index 5e67a207..1950013d 100644 --- a/openequivariance/openequivariance/core/utils.py +++ b/openequivariance/openequivariance/core/utils.py @@ -10,7 +10,6 @@ import hashlib from enum import IntEnum -import hashlib class DTypeEnum(IntEnum): From 53f80696421da1eb3761608dcfb4313996281e88 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Wed, 11 Feb 2026 20:35:30 -0800 Subject: [PATCH 14/15] Fixed benchmark.py. 
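
The standard json encoder raises TypeError on numpy arrays and on most
numpy scalar types, which can show up in the benchmark result
dictionaries. benchmark_utils.py therefore gains a small NpEncoder
(a json.JSONEncoder subclass), and both benchmark suites pass it to
json.dump via cls=. A minimal usage sketch with made-up values, for
illustration only:

    import json
    import numpy as np
    from openequivariance.benchmark.benchmark_utils import NpEncoder

    # NpEncoder converts np.integer -> int, np.floating -> float, and
    # np.ndarray -> list; plain json.dumps would reject these values.
    result = {
        "batch_size": np.int64(50000),
        "time_millis": np.float32(1.25),
        "per_iter_ms": np.arange(3, dtype=np.float64),
    }
    print(json.dumps(result, indent=2, cls=NpEncoder))
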
--- .../openequivariance/benchmark/ConvBenchmarkSuite.py | 3 ++- .../openequivariance/benchmark/TestBenchmarkSuite.py | 7 +++++-- .../openequivariance/benchmark/benchmark_utils.py | 12 ++++++++++++ 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/openequivariance/openequivariance/benchmark/ConvBenchmarkSuite.py b/openequivariance/openequivariance/benchmark/ConvBenchmarkSuite.py index 499a33eb..debcc65b 100644 --- a/openequivariance/openequivariance/benchmark/ConvBenchmarkSuite.py +++ b/openequivariance/openequivariance/benchmark/ConvBenchmarkSuite.py @@ -8,6 +8,7 @@ import openequivariance as oeq from openequivariance.benchmark.logging_utils import getLogger from openequivariance.core.ConvolutionBase import CoordGraph +from openequivariance.benchmark.benchmark_utils import NpEncoder logger = getLogger() @@ -145,7 +146,7 @@ def run( f"{output_folder}/{self.exp_count}_{impl.name()}_{graph.name}.json" ) with open(fname, "w") as f: - json.dump(result, f, indent=2) + json.dump(result, f, indent=2, cls=NpEncoder) self.exp_count += 1 logger.info(f"Finished {tc_name}, graph {graph.name}") diff --git a/openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py b/openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py index 119c866c..37d20c46 100644 --- a/openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py +++ b/openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py @@ -21,6 +21,7 @@ benchmark_forward, benchmark_backward, benchmark_double_backward, + NpEncoder, ) logger = getLogger() @@ -235,10 +236,12 @@ def run( fname = pathlib.Path(f"{output_folder}/{test_ID}_{impl.name()}.json") - pretty_result = json.dumps(obj=result, indent=2).replace("\\n", "\n") + pretty_result = json.dumps(obj=result, indent=2, cls=NpEncoder).replace( + "\\n", "\n" + ) logger.debug(pretty_result) with open(fname, "w") as f: - json.dump(result, f, indent=2) + json.dump(result, f, indent=2, cls=NpEncoder) self.results.append(result) logger.info(f"Finished Test ID: {test_ID}") diff --git a/openequivariance/openequivariance/benchmark/benchmark_utils.py b/openequivariance/openequivariance/benchmark/benchmark_utils.py index 377df3d6..68dc6f9f 100644 --- a/openequivariance/openequivariance/benchmark/benchmark_utils.py +++ b/openequivariance/openequivariance/benchmark/benchmark_utils.py @@ -1,3 +1,4 @@ +import json import numpy as np from openequivariance.benchmark.random_buffer_utils import ( @@ -290,3 +291,14 @@ def benchmark_double_backward( ) return result + + +class NpEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, np.integer): + return int(obj) + if isinstance(obj, np.floating): + return float(obj) + if isinstance(obj, np.ndarray): + return obj.tolist() + return super(NpEncoder, self).default(obj) From 21b8ae14746eb3c0e1b5e51864062abf5d858e33 Mon Sep 17 00:00:00 2001 From: Vivek Bharadwaj Date: Wed, 11 Feb 2026 20:46:14 -0800 Subject: [PATCH 15/15] Updated JAX extension version due to backend change. --- openequivariance_extjax/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openequivariance_extjax/pyproject.toml b/openequivariance_extjax/pyproject.toml index def67e12..74f9627d 100644 --- a/openequivariance_extjax/pyproject.toml +++ b/openequivariance_extjax/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "scikit_build_core.build" [project] name = "openequivariance_extjax" -version = "0.2.0" +version = "0.2.1" authors = [ { name="Austin Glover" }, { name="Vivek Bharadwaj" },