diff --git a/openequivariance/openequivariance/_torch/TensorProduct.py b/openequivariance/openequivariance/_torch/TensorProduct.py index 05ea54b5..3885604f 100644 --- a/openequivariance/openequivariance/_torch/TensorProduct.py +++ b/openequivariance/openequivariance/_torch/TensorProduct.py @@ -2,14 +2,12 @@ from openequivariance import TPProblem from openequivariance._torch import extlib import torch -import typing from openequivariance.core.utils import torch_to_oeq_dtype from openequivariance.benchmark.logging_utils import getLogger -from openequivariance._torch.utils import reorder_torch +from openequivariance._torch.utils import reorder_torch, string_to_tensor from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin import numpy as np -from openequivariance._torch.extlib import DeviceBuffer logger = getLogger() @@ -24,7 +22,7 @@ class TensorProduct(torch.nn.Module, LoopUnrollTP, NumpyDoubleBackwardMixin): * The provided tensor product specification is unsupported. :param problem: Specification of the tensor product. - :param use_opaque: If ``True``, uses an opaque forward pass that cannot be symbolically traced. *Default*: ``False``. + :param use_opaque: This parameter is deprecated. """ def __init__(self, problem: TPProblem, torch_op=True, use_opaque=False): @@ -46,27 +44,9 @@ def _init_class(self): self.input_args["torch_op"], ) - internal_cls = None - if extlib.TORCH_COMPILE: - internal_cls = torch.classes.libtorch_tp_jit.TorchJITProduct - else: - internal_cls = extlib.JITTPImpl - - logger.info("Starting kernel compiler...") - self.internal = internal_cls( - self.jit_kernel, - vars(self.forward_schedule.launch_config), - vars(self.backward_schedule.launch_config), - vars(self.double_backward_schedule.launch_config), - self.kernelProp, - ) - logger.info("Kernel compiled!") - logger.info(f"Kernel File Size: {len(self.jit_kernel) // 1024} KB") - + self.L3_dim = self.kernel_prop["L3_dim"] + self.kernel = string_to_tensor(self.kernel_string) self.weight_numel = self.input_args["problem"].weight_numel - self._setup_notorchbind() - if (not extlib.TORCH_COMPILE) or self.input_args["use_opaque"]: - self.forward = self.forward_opaque def to(self, *args, **kwargs): r""" @@ -122,262 +102,14 @@ def forward( :return: Tensor of shape ``[batch_size, problem.irreps_out.dim()]``, datatype ``problem.irrep_dtype``. """ - return torch.ops.libtorch_tp_jit.jit_tp_forward(self.internal, x, y, W) - - def _setup_notorchbind(self): - """ - In case TorchBind is not available (e.g. for torch.compile below PT2.8, etc.), - set up operations using custom ops. 
- """ - - @torch.library.custom_op( - f"openequivariance::tp_forward{self.tp_id}", - mutates_args=(), - device_types="cuda", - ) - def forward( - L1_in: torch.Tensor, L2_in: torch.Tensor, weights: torch.Tensor - ) -> torch.Tensor: - L1_in_c, L2_in_c, weights_c = ( - L1_in.contiguous(), - L2_in.contiguous(), - weights.contiguous(), - ) - L3_out = torch.empty( - (L1_in_c.shape[0], self.L3.dim), dtype=L1_in.dtype, device=L1_in.device - ) - self.forward_raw( - L1_in_c.shape[0], - L1_in_c.data_ptr(), - L2_in_c.data_ptr(), - L3_out.data_ptr(), - weights_c.data_ptr(), - ) - return L3_out - - @forward.register_fake - def _(L1_in, L2_in, weights): - return L1_in.new_empty(L1_in.shape[0], self.L3.dim) - - self.forward_opaque = forward - - # ---------------- Backward pass ----------------- - @torch.library.custom_op( - f"openequivariance::tp_grad_helper{self.tp_id}", - mutates_args=(), - device_types="cuda", - ) - def backward_helper( - L1_in: torch.Tensor, - L2_in: torch.Tensor, - weights: torch.Tensor, - L3_grad: torch.Tensor, - ) -> typing.List[torch.Tensor]: - L1_grad = torch.zeros_like(L1_in) - L2_grad = torch.zeros_like(L2_in) - weights_grad = torch.empty_like(weights) - - if self.config.shared_weights: - weights_grad[:] = 0.0 - - self.backward_raw( - L1_in.shape[0], - L1_in.contiguous().data_ptr(), - L1_grad.data_ptr(), - L2_in.contiguous().data_ptr(), - L2_grad.data_ptr(), - weights.contiguous().data_ptr(), - weights_grad.data_ptr(), - L3_grad.contiguous().data_ptr(), - ) - - return [L1_grad, L2_grad, weights_grad] - - @backward_helper.register_fake - def _(L1_in, L2_in, weights, L3_grad): - return [ - L1_in.new_empty(*L1_in.shape), - L2_in.new_empty(*L2_in.shape), - weights.new_empty(*weights.shape), - ] - - def setup_context(ctx, inputs, output): - ctx.L1_in, ctx.L2_in, ctx.weights = inputs - - def backward(ctx, grad_output): - result = backward_helper(ctx.L1_in, ctx.L2_in, ctx.weights, grad_output) - return result[0], result[1], result[2] - - self.forward_opaque.register_autograd(backward, setup_context=setup_context) - - def setup_context_double_backward(ctx, inputs, output): - ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad = inputs - - def double_backward(ctx, grad_output): - A, B, C, D = ctx.L1_in, ctx.L2_in, ctx.L3_grad, ctx.weights - E, F, G = grad_output[0], grad_output[1], grad_output[2] - - op1 = backward_helper(E, F, D, C) - op2 = backward_helper(A, B, G, C) - op3 = forward(E, B, D) - op4 = backward_helper(E, B, D, C) - op5 = backward_helper(A, F, D, C) - op6 = forward(A, F, D) - op7 = forward(A, B, G) - - return ( - op1[0] + op2[0], - op1[1] + op2[1], - (op4[2] + op5[2]), - (op3 + op6 + op7), - ) - - backward_helper.register_autograd( - double_backward, setup_context=setup_context_double_backward - ) - - @classmethod - def register_torch_fakes(cls): - @torch._library.register_fake_class("libtorch_tp_jit::TorchJITProduct") - class TorchJITProduct: - def __init__( - self, - kernel_plaintext: str, - fwd_config: dict[str, int], - bwd_config: dict[str, int], - dbl_bwd_config: dict[str, int], - kernel_dims: dict[str, int], - ) -> None: - ( - self.kernel_plaintext, - self.fwd_config, - self.bwd_config, - self.dbl_bwd_config, - self.kernel_dims, - ) = ( - kernel_plaintext, - fwd_config, - bwd_config, - dbl_bwd_config, - kernel_dims, - ) - - @classmethod - def __obj_unflatten__(cls, flattened_product): - return cls(**dict(flattened_product)) - - def __len__(self): - return 0 - - def __setstate__(self, state): - self.kernel_plaintext = state["kernel_plaintext"] - self.fwd_config = 
state["fwd_config"] - self.bwd_config = state["bwd_config"] - self.dbl_bwd_config = state["dbl_bwd_config"] - self.kernel_dims = state["kernel_dims"] - - def exec_tensor_product_rawptr(*args, **kwargs): - pass - - def backward_rawptr(*args, **kwargs): - pass - - def L3_dim_getter(self): - return self.kernel_dims["L3_dim"] - - def irrep_dtype_getter(self): - return self.kernel_dims["irrep_dtype"] - - @torch.library.register_fake("libtorch_tp_jit::jit_tp_forward") - def fake_forward(jit, L1_in, L2_in, W): - L3_dim = None - if hasattr(jit, "wrapped_obj"): - L3_dim = jit.wrapped_obj.kernel_dims["L3_dim"] - else: - L3_dim = jit.L3_dim - - return L1_in.new_empty(L1_in.shape[0], L3_dim) - - @torch.library.register_fake("libtorch_tp_jit::jit_tp_backward") - def fake_backward(jit, L1_in, L2_in, W, L3_grad): - return torch.empty_like(L1_in), torch.empty_like(L2_in), torch.empty_like(W) - - @classmethod - def register_autograd(cls): - backward_op = torch.ops.libtorch_tp_jit.jit_tp_backward - - def setup_context(ctx, inputs, output): - ctx.jit, ctx.L1_in, ctx.L2_in, ctx.weights = inputs - - def backward(ctx, grad_output): - L1_grad, L2_grad, W_grad = backward_op( - ctx.jit, ctx.L1_in, ctx.L2_in, ctx.weights, grad_output - ) - return None, L1_grad, L2_grad, W_grad - - torch.library.register_autograd( - "libtorch_tp_jit::jit_tp_forward", backward, setup_context=setup_context - ) - - def setup_context_double_backward(ctx, inputs, output): - ctx.jit, ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad = inputs - - def double_backward(ctx, E, F, G): - result = torch.ops.libtorch_tp_jit.jit_tp_double_backward( - ctx.jit, ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad, E, F, G - ) - return None, result[0], result[1], result[2], result[3] - - torch.library.register_autograd( - "libtorch_tp_jit::jit_tp_backward", - double_backward, - setup_context=setup_context_double_backward, - ) - - @classmethod - def register_autocast(cls): - global torch - import torch - - torch.library.register_autocast( - "libtorch_tp_jit::jit_tp_forward", "cuda", torch.float32 - ) - torch.library.register_autocast( - "libtorch_tp_jit::jit_tp_backward", "cuda", torch.float32 - ) - torch.library.register_autocast( - "libtorch_tp_jit::jit_tp_double_backward", "cuda", torch.float32 + return torch.ops.libtorch_tp_jit.jit_tp_forward( + self.kernel, self.hash, x, y, W, self.L3_dim ) @staticmethod def name(): return "LoopUnrollTP" - def forward_raw( - self, - batch: np.uint64, - L1_in: np.uint64, - L2_in: np.uint64, - L3_out: np.uint64, - weights: np.uint64, - ) -> None: - self.internal.exec_tensor_product_rawptr(batch, L1_in, L2_in, L3_out, weights) - - def backward_raw( - self, - batch_size: np.uint64, - L1_in: np.uint64, - L1_grad: np.uint64, - L2_in: np.uint64, - L2_grad: np.uint64, - weights: np.uint64, - weights_grad: np.uint64, - L3_grad: np.uint64, - ): - self.internal.backward_rawptr( - batch_size, L1_in, L1_grad, L2_in, L2_grad, weights, weights_grad, L3_grad - ) - def forward_cpu( self, L1_in: np.ndarray, @@ -389,19 +121,12 @@ def forward_cpu( weights, not self.config.shared_weights ) - batch = L1_in.shape[0] - L1_d = DeviceBuffer(L1_in) - L2_d = DeviceBuffer(L2_in) - L3_d = DeviceBuffer(L3_out) - weights_d = DeviceBuffer(weights_chunked) - self.internal.exec_tensor_product_rawptr( - batch, - L1_d.data_ptr(), - L2_d.data_ptr(), - L3_d.data_ptr(), - weights_d.data_ptr(), - ) - L3_d.copy_to_host() + torch_L1_in = torch.tensor(L1_in, device="cuda") + torch_L2_in = torch.tensor(L2_in, device="cuda") + torch_weights = 
torch.tensor(weights_chunked, device="cuda") + torch_L3_out = self.forward(torch_L1_in, torch_L2_in, torch_weights) + + L3_out[:] = torch_L3_out.numpy(force=True) def backward_cpu( self, L1_in, L1_grad, L2_in, L2_grad, L3_grad, weights, weights_grad @@ -410,39 +135,99 @@ def backward_cpu( weights, not self.config.shared_weights ) - batch = L1_in.shape[0] - L1_d, L2_d, L3_d = ( - DeviceBuffer(L1_in), - DeviceBuffer(L2_in), - DeviceBuffer(L3_grad), + torch_L1_in = torch.tensor(L1_in, requires_grad=True, device="cuda") + torch_L2_in = torch.tensor(L2_in, requires_grad=True, device="cuda") + torch_weights = torch.tensor(weights_chunked, requires_grad=True, device="cuda") + + torch_out = self.forward(torch_L1_in, torch_L2_in, torch_weights) + + torch_L3_grad_in = torch.tensor(L3_grad, device="cuda") + + torch_out.backward(gradient=torch_L3_grad_in) + + L1_grad[:] = torch_L1_in.grad.numpy(force=True) + L2_grad[:] = torch_L2_in.grad.numpy(force=True) + weights_grad[:] = torch_weights.grad.numpy(force=True) + + weights_grad[:] = self.reorder_weights_to_e3nn( + weights_grad, not self.config.shared_weights ) - L1_grad_d, L2_grad_d = DeviceBuffer(L1_grad), DeviceBuffer(L2_grad) - weights_d, weights_grad_d = ( - DeviceBuffer(weights_chunked), - DeviceBuffer(weights_grad), + + +def register_torch_fakes(): + @torch.library.register_fake("libtorch_tp_jit::jit_tp_forward") + def fake_forward(kernel, hash, L1_in, L2_in, W, L3_dim): + return L1_in.new_empty(L1_in.shape[0], L3_dim) + + @torch.library.register_fake("libtorch_tp_jit::jit_tp_backward") + def fake_backward(kernel, hash, L1_in, L2_in, W, L3_grad): + return torch.empty_like(L1_in), torch.empty_like(L2_in), torch.empty_like(W) + + @torch.library.register_fake("libtorch_tp_jit::jit_tp_double_backward") + def fake_double_backward(kernel, hash, L1_in, L2_in, W, L3_grad, E, F, G): + return ( + torch.empty_like(L1_in), + torch.empty_like(L2_in), + torch.empty_like(W), + torch.empty_like(L3_grad), ) - self.internal.backward_rawptr( - batch, - L1_d.data_ptr(), - L1_grad_d.data_ptr(), - L2_d.data_ptr(), - L2_grad_d.data_ptr(), - weights_d.data_ptr(), - weights_grad_d.data_ptr(), - L3_d.data_ptr(), + +def register_autograd(): + backward_op = torch.ops.libtorch_tp_jit.jit_tp_backward + + def setup_context(ctx, inputs, output): + ctx.kernel, ctx.hash, ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_dim = inputs + + def backward(ctx, grad_output): + L1_grad, L2_grad, W_grad = backward_op( + ctx.kernel, ctx.hash, ctx.L1_in, ctx.L2_in, ctx.weights, grad_output ) + return None, None, L1_grad, L2_grad, W_grad, None - L1_grad_d.copy_to_host() - L2_grad_d.copy_to_host() - weights_grad_d.copy_to_host() + torch.library.register_autograd( + "libtorch_tp_jit::jit_tp_forward", backward, setup_context=setup_context + ) - weights_grad[:] = self.reorder_weights_to_e3nn( - weights_grad, not self.config.shared_weights + def setup_context_double_backward(ctx, inputs, output): + ctx.kernel, ctx.hash, ctx.L1_in, ctx.L2_in, ctx.weights, ctx.L3_grad = inputs + + def double_backward(ctx, E, F, G): + result = torch.ops.libtorch_tp_jit.jit_tp_double_backward( + ctx.kernel, + ctx.hash, + ctx.L1_in, + ctx.L2_in, + ctx.weights, + ctx.L3_grad, + E, + F, + G, ) + return None, None, result[0], result[1], result[2], result[3] + + torch.library.register_autograd( + "libtorch_tp_jit::jit_tp_backward", + double_backward, + setup_context=setup_context_double_backward, + ) + + +def register_autocast(): + global torch + import torch + + torch.library.register_autocast( + 
"libtorch_tp_jit::jit_tp_forward", "cuda", torch.float32 + ) + torch.library.register_autocast( + "libtorch_tp_jit::jit_tp_backward", "cuda", torch.float32 + ) + torch.library.register_autocast( + "libtorch_tp_jit::jit_tp_double_backward", "cuda", torch.float32 + ) -if extlib.TORCH_COMPILE: - TensorProduct.register_torch_fakes() - TensorProduct.register_autograd() - TensorProduct.register_autocast() +register_torch_fakes() +register_autograd() +register_autocast() diff --git a/openequivariance/openequivariance/_torch/TensorProductConv.py b/openequivariance/openequivariance/_torch/TensorProductConv.py index f30c943c..5788f2f0 100644 --- a/openequivariance/openequivariance/_torch/TensorProductConv.py +++ b/openequivariance/openequivariance/_torch/TensorProductConv.py @@ -1,11 +1,9 @@ -from typing import Optional, List +from typing import Optional import numpy as np import torch -import openequivariance._torch.extlib as extlib from openequivariance._torch.extlib import ( - JITConvImpl, postprocess_kernel, DeviceProp, ) @@ -18,12 +16,13 @@ from openequivariance._torch.TensorProduct import TensorProduct from openequivariance import TPProblem from openequivariance.core.utils import torch_to_oeq_dtype -from openequivariance._torch.utils import enum_to_torch_dtype -from openequivariance._torch.utils import reorder_torch +from openequivariance._torch.utils import ( + reorder_torch, + string_to_tensor, +) from openequivariance.benchmark.logging_utils import getLogger from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixinConv -from openequivariance._torch.extlib import DeviceBuffer logger = getLogger() @@ -49,7 +48,7 @@ class TensorProductConv(torch.nn.Module, LoopUnrollConv, NumpyDoubleBackwardMixi fixup-based algorithm. `Default`: ``False``. :param kahan: If ``True``, uses Kahan summation to improve accuracy during aggregation. To use this option, the input tensors must be in float32 precision AND you must set ``deterministic=True``. *Default*: ``False``. - :param use_opaque: If ``True``, uses an opaque forward pass that cannot be symbolically traced. *Default*: ``False``. + :param use_opaque: This parameter is deprecated. """ def __init__( @@ -85,27 +84,11 @@ def _init_class(self): ) self.allocate_workspace(self.workspace_size) - if extlib.TORCH_COMPILE: - internal_cls = torch.classes.libtorch_tp_jit.TorchJITConv - else: - internal_cls = JITConvImpl - - logger.info("Starting kernel compiler...") - self.internal = internal_cls( - self.jit_kernel, - vars(self.forward_schedule.launch_config), - vars(self.backward_schedule.launch_config), - vars(self.double_backward_schedule.launch_config), - self.kernel_prop, - ) - logger.info("Kernel compiled!") self.dummy_transpose_perm = torch.zeros(1, dtype=torch.int64, device="cuda") self.weight_numel = self.config.weight_numel - self._setup_notorchbind() - - if (not extlib.TORCH_COMPILE) or self.input_args["use_opaque"]: - self.forward = self.forward_opaque + self.kernel = string_to_tensor(self.kernel_string) + self.L3_dim = self.kernel_prop["L3_dim"] def to(self, *args, **kwargs): r""" @@ -163,263 +146,29 @@ def forward( :return: Tensor of shape ``[|V|, problem.irreps_out.dim()]``, datatype ``problem.irrep_dtype``. 
""" if sender_perm is None: - return torch.ops.libtorch_tp_jit.jit_conv_forward( - self.internal, - X, - Y, - W, - rows, - cols, - self.workspace_buffer, - self.dummy_transpose_perm, - ) - else: - return torch.ops.libtorch_tp_jit.jit_conv_forward( - self.internal, - X, - Y, - W, - rows, - cols, - self.workspace_buffer, - sender_perm, - ) - - def allocate_workspace(self, size_bytes): - self.workspace_size = size_bytes - if self.torch_op: - self.workspace_buffer = torch.zeros( - size_bytes, dtype=torch.uint8, device="cuda" - ) - else: - self.workspace_buffer = extlib.DeviceBuffer(size_bytes) - self.workspace_ptr = self.workspace_buffer.data_ptr() - logger.info(f"Convolution requires {size_bytes // 1000000}MB of workspace.") + sender_perm = self.dummy_transpose_perm - def _setup_notorchbind(self): - @torch.library.custom_op( - f"openequivariance::conv_forward{self.conv_id}", - mutates_args=(), - device_types="cuda", - ) - def forward( - L1_in: torch.Tensor, - L2_in: torch.Tensor, - weights: torch.Tensor, - rows: torch.Tensor, - cols: torch.Tensor, - transpose_perm: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - L1_in_c, L2_in_c, weights_c = ( - L1_in.contiguous(), - L2_in.contiguous(), - weights.contiguous(), - ) - L3_out = torch.zeros( - (L1_in_c.shape[0], self.L3.dim), dtype=L1_in.dtype, device="cuda" - ) - - self.internal.exec_conv_rawptrs( - L1_in_c.data_ptr(), - L2_in_c.data_ptr(), - weights_c.data_ptr(), - L3_out.data_ptr(), - rows.contiguous().data_ptr(), - cols.contiguous().data_ptr(), - rows.shape[0], - L1_in.shape[0], - self.workspace_ptr, - ) - - return L3_out - - @forward.register_fake - def _(L1_in, L2_in, weights, rows, cols, transpose_perm=None): - return L1_in.new_empty(L1_in.shape[0], self.L3.dim) - - self.forward_opaque = forward - - @torch.library.custom_op( - f"openequivariance::conv_backward{self.conv_id}", - mutates_args=(), - device_types="cuda", - ) - def backward_helper( - L1_in: torch.Tensor, - L2_in: torch.Tensor, - weights: torch.Tensor, - L3_grad: torch.Tensor, - rows: torch.Tensor, - cols: torch.Tensor, - transpose_perm: Optional[torch.Tensor] = None, - ) -> List[torch.Tensor]: - L1_grad = torch.zeros_like(L1_in) - L2_grad = torch.zeros_like(L2_in) - weights_grad = torch.empty_like(weights) - - if self.config.shared_weights: - weights_grad[:] = 0.0 - - transpose_perm_ptr = 0 - if transpose_perm is not None: - transpose_perm_ptr = transpose_perm.data_ptr() - - self.internal.backward_rawptrs( - L1_in.contiguous().data_ptr(), - L1_grad.data_ptr(), - L2_in.contiguous().data_ptr(), - L2_grad.data_ptr(), - weights.contiguous().data_ptr(), - weights_grad.data_ptr(), - L3_grad.contiguous().data_ptr(), - rows.contiguous().data_ptr(), - cols.contiguous().data_ptr(), - rows.shape[0], - L1_in.shape[0], - self.workspace_ptr, - transpose_perm_ptr, - ) - - return [L1_grad, L2_grad, weights_grad] - - @backward_helper.register_fake - def _(L1_in, L2_in, weights, L3_grad, rows, cols, transpose_perm=None): - return [ - L1_in.new_empty(*L1_in.shape), - L2_in.new_empty(*L2_in.shape), - weights.new_empty(*weights.shape), - ] - - def setup_context(ctx, inputs, output): - ( - ctx.L1_in, - ctx.L2_in, - ctx.weights, - ctx.rows, - ctx.cols, - ctx.transpose_perm, - ) = inputs - - def backward(ctx, grad_output): - result = backward_helper( - ctx.L1_in, - ctx.L2_in, - ctx.weights, - grad_output, - ctx.rows, - ctx.cols, - ctx.transpose_perm, - ) - return result[0], result[1], result[2], None, None, None - - self.forward_opaque.register_autograd(backward, setup_context=setup_context) - - 
def setup_context_double_backward(ctx, inputs, output): - ( - ctx.L1_in, - ctx.L2_in, - ctx.weights, - ctx.L3_grad, - ctx.rows, - ctx.cols, - ctx.transpose_perm, - ) = inputs - - @torch.library.custom_op( - f"openequivariance::conv_double_backward{self.conv_id}", - mutates_args=(), - device_types="cuda", - ) - def double_backward_helper( - L1_in: torch.Tensor, - L2_in: torch.Tensor, - W: torch.Tensor, - L3_grad: torch.Tensor, - L1_dgrad: torch.Tensor, - L2_dgrad: torch.Tensor, - w_dgrad: torch.Tensor, - rows: torch.Tensor, - cols: torch.Tensor, - transpose_perm: Optional[torch.Tensor] = None, - ) -> List[torch.Tensor]: - L1_grad = torch.zeros_like(L1_in) - L2_grad = torch.zeros_like(L2_in) - W_grad = torch.empty_like(W) - L3_dgrad = torch.zeros_like(L3_grad) - - if self.config.shared_weights: - W_grad[:] = 0.0 - - transpose_perm_ptr = 0 - if transpose_perm is not None: - transpose_perm_ptr = transpose_perm.data_ptr() - - self.internal.double_backward_rawptrs( - L1_in.contiguous().data_ptr(), - L2_in.contiguous().data_ptr(), - W.contiguous().data_ptr(), - L3_grad.contiguous().data_ptr(), - L1_dgrad.contiguous().data_ptr(), - L2_dgrad.contiguous().data_ptr(), - w_dgrad.contiguous().data_ptr(), - L1_grad.data_ptr(), - L2_grad.data_ptr(), - W_grad.data_ptr(), - L3_dgrad.data_ptr(), - rows.contiguous().data_ptr(), - cols.contiguous().data_ptr(), - rows.shape[0], - L1_in.shape[0], - self.workspace_ptr, - transpose_perm_ptr, - ) - return [L1_grad, L2_grad, W_grad, L3_dgrad] - - @double_backward_helper.register_fake - def _( - L1_in, - L2_in, + return torch.ops.libtorch_tp_jit.jit_conv_forward( + self.kernel, + self.hash, + X, + Y, W, - L3_grad, - L1_dgrad, - L2_dgrad, - w_dgrad, + self.L3_dim, rows, cols, - transpose_perm=None, - ): - return [ - L1_in.new_empty(*L1_in.shape), - L2_in.new_empty(*L2_in.shape), - W.new_empty(*W.shape), - L3_grad.new_empty(*L3_grad.shape), - ] - - def double_backward(ctx, grad_output): - L1_dgrad, L2_dgrad, w_dgrad = grad_output[0], grad_output[1], grad_output[2] - - L1_grad, L2_grad, W_grad, L3_dgrad = double_backward_helper( - ctx.L1_in, - ctx.L2_in, - ctx.weights, - ctx.L3_grad, - L1_dgrad, - L2_dgrad, - w_dgrad, - ctx.rows, - ctx.cols, - ctx.transpose_perm, - ) - - if ctx.transpose_perm is None: - return L1_grad, L2_grad, W_grad, L3_dgrad, None, None - else: - return L1_grad, L2_grad, W_grad, L3_dgrad, None, None, None - - backward_helper.register_autograd( - double_backward, setup_context=setup_context_double_backward + self.workspace_buffer, + sender_perm, ) + def allocate_workspace(self, size_bytes): + self.workspace_size = size_bytes + self.workspace_buffer = torch.zeros( + size_bytes, dtype=torch.uint8, device="cuda" + ) + self.workspace_ptr = self.workspace_buffer.data_ptr() + logger.info(f"Convolution requires {size_bytes // 1000000}MB of workspace.") + def reorder_weights_from_e3nn(self, weights, has_batch_dim=True): return reorder_torch( self.forward_schedule, weights, "forward", not self.config.shared_weights @@ -434,210 +183,6 @@ def reorder_weights_to_e3nn(self, weights, has_batch_dim=True): def name(): return "LoopUnrollConv" - @classmethod - def register_torch_fakes(cls): - global torch - import torch - - @torch._library.register_fake_class("libtorch_tp_jit::TorchJITConv") - class TorchJITConv: - def __init__( - self, - kernel_plaintext: str, - fwd_config: dict[str, int], - bwd_config: dict[str, int], - dbl_bwd_config: dict[str, int], - kernel_dims: dict[str, int], - ) -> None: - ( - self.kernel_plaintext, - self.fwd_config, - self.bwd_config, - 
self.dbl_bwd_config, - self.kernel_dims, - ) = ( - kernel_plaintext, - fwd_config, - bwd_config, - dbl_bwd_config, - kernel_dims, - ) - - @classmethod - def __obj_unflatten__(cls, flattened_product): - return cls(**dict(flattened_product)) - - def __len__(self): - return 0 - - def __setstate__(self, state): - ( - self.kernel_plaintext, - self.fwd_config, - self.bwd_config, - self.dbl_bwd_config, - self.kernel_dims, - ) = state - - def exec_conv_rawptrs(*args, **kwargs): - pass - - def backward_rawptrs(*args, **kwargs): - pass - - def double_backward_rawptrs(*args, **kwargs): - pass - - def L3_dim_getter(self): - return self.kernel_dims["L3_dim"] - - def irrep_dtype_getter(self): - return self.kernel_dims["irrep_dtype"] - - @torch.library.register_fake("libtorch_tp_jit::jit_conv_forward") - def fake_forward( - jit, L1_in, L2_in, W, rows, cols, workspace_buffer, sender_perm - ): - L3_dim, irrep_dtype = None, None - if hasattr(jit, "wrapped_obj"): - L3_dim = jit.wrapped_obj.kernel_dims["L3_dim"] - irrep_dtype = jit.wrapped_obj.kernel_dims["irrep_dtype"] - else: - L3_dim = jit.L3_dim - irrep_dtype = jit.irrep_dtype - - return torch.empty( - L1_in.shape[0], - L3_dim, - device="cuda", - dtype=enum_to_torch_dtype[irrep_dtype], - ) - - @torch.library.register_fake("libtorch_tp_jit::jit_conv_backward") - def fake_backward( - jit, L1_in, L2_in, W, L3_grad, rows, cols, workspace_buffer, sender_perm - ): - return torch.empty_like(L1_in), torch.empty_like(L2_in), torch.empty_like(W) - - @torch.library.register_fake("libtorch_tp_jit::jit_conv_double_backward") - def fake_double_backward( - jit, - L1_in, - L2_in, - W, - L3_grad, - L1_dgrad, - L2_dgrad, - w_dgrad, - rows, - cols, - workspace_buffer, - transpose_perm=None, - ): - return [ - L1_in.new_empty(*L1_in.shape), - L2_in.new_empty(*L2_in.shape), - W.new_empty(*W.shape), - L3_grad.new_empty(*L3_grad.shape), - ] - - @classmethod - def register_autograd(cls): - backward_op = torch.ops.libtorch_tp_jit.jit_conv_backward - double_backward_op = torch.ops.libtorch_tp_jit.jit_conv_double_backward - - def setup_context(ctx, inputs, output): - ( - ctx.jit, - ctx.L1_in, - ctx.L2_in, - ctx.W, - ctx.rows, - ctx.cols, - ctx.workspace_buffer, - ctx.sender_perm, - ) = inputs - - def backward(ctx, grad_output): - L1_grad, L2_grad, W_grad = backward_op( - ctx.jit, - ctx.L1_in, - ctx.L2_in, - ctx.W, - grad_output, - ctx.rows, - ctx.cols, - ctx.workspace_buffer, - ctx.sender_perm, - ) - return None, L1_grad, L2_grad, W_grad, None, None, None, None - - torch.library.register_autograd( - "libtorch_tp_jit::jit_conv_forward", backward, setup_context=setup_context - ) - - def setup_context_double_backward(ctx, inputs, output): - ( - ctx.jit, - ctx.L1_in, - ctx.L2_in, - ctx.W, - ctx.grad_output, - ctx.rows, - ctx.cols, - ctx.workspace_buffer, - ctx.sender_perm, - ) = inputs - ctx.inputs = inputs - - def double_backward(ctx, E, F, G): - result = double_backward_op( - ctx.jit, - ctx.L1_in, - ctx.L2_in, - ctx.W, - ctx.grad_output, - E, - F, - G, - ctx.rows, - ctx.cols, - ctx.workspace_buffer, - ctx.sender_perm, - ) - return ( - None, - result[0], - result[1], - result[2], - result[3], - None, - None, - None, - None, - ) - - torch.library.register_autograd( - "libtorch_tp_jit::jit_conv_backward", - double_backward, - setup_context=setup_context_double_backward, - ) - - @classmethod - def register_autocast(cls): - global torch - import torch - - torch.library.register_autocast( - "libtorch_tp_jit::jit_conv_forward", "cuda", torch.float32 - ) - torch.library.register_autocast( - 
"libtorch_tp_jit::jit_conv_backward", "cuda", torch.float32 - ) - torch.library.register_autocast( - "libtorch_tp_jit::jit_conv_double_backward", "cuda", torch.float32 - ) - def forward_cpu(self, L1_in, L2_in, weights, L3_out, graph): assert graph.rows.dtype == self.idx_dtype assert graph.cols.dtype == self.idx_dtype @@ -646,29 +191,26 @@ def forward_cpu(self, L1_in, L2_in, weights, L3_out, graph): weights, not self.config.shared_weights ) - L1_d, L2_d, weights_d = ( - DeviceBuffer(L1_in), - DeviceBuffer(L2_in), - DeviceBuffer(weights_chunked), - ) - L3_d = DeviceBuffer(L3_out) - - rows_d = DeviceBuffer(graph.rows) - cols_d = DeviceBuffer(graph.cols) - - self.internal.exec_conv_rawptrs( - L1_d.data_ptr(), - L2_d.data_ptr(), - weights_d.data_ptr(), - L3_d.data_ptr(), - rows_d.data_ptr(), - cols_d.data_ptr(), - graph.nnz, - graph.node_count, - self.workspace_ptr, - ) + torch_L1_in = torch.tensor(L1_in, device="cuda") + torch_L2_in = torch.tensor(L2_in, device="cuda") + torch_weights = torch.tensor(weights_chunked, device="cuda") + torch_rows = torch.tensor(graph.rows, device="cuda") + torch_cols = torch.tensor(graph.cols, device="cuda") - L3_d.copy_to_host() + if self.deterministic: + torch_sender_perm = torch.tensor(graph.transpose_perm, device="cuda") + else: + torch_sender_perm = None + + result = self.forward( + torch_L1_in, + torch_L2_in, + torch_weights, + torch_rows, + torch_cols, + torch_sender_perm, + ) + L3_out[:] = result.numpy(force=True) def backward_cpu( self, L1_in, L1_grad, L2_in, L2_grad, weights, weights_grad, L3_grad, graph @@ -680,42 +222,30 @@ def backward_cpu( weights, not self.config.shared_weights ) - L1_d = DeviceBuffer(L1_in) - L2_d = DeviceBuffer(L2_in) - weights_d = DeviceBuffer(weights_chunked) - L3_d = DeviceBuffer(L3_grad) - rows_d = DeviceBuffer(graph.rows) - cols_d = DeviceBuffer(graph.cols) - - L1_grad_d = DeviceBuffer(L1_grad) - L2_grad_d = DeviceBuffer(L2_grad) - weights_grad_d = DeviceBuffer(weights_grad) + torch_L1_in = torch.tensor(L1_in, requires_grad=True, device="cuda") + torch_L2_in = torch.tensor(L2_in, requires_grad=True, device="cuda") + torch_weights = torch.tensor(weights_chunked, requires_grad=True, device="cuda") + torch_L3_grad = torch.tensor(L3_grad, device="cuda") + torch_rows = torch.tensor(graph.rows, device="cuda") + torch_cols = torch.tensor(graph.cols, device="cuda") - transpose_perm_d = None - transpose_perm_ptr = 0 if self.deterministic: - transpose_perm_d = DeviceBuffer(graph.transpose_perm) - transpose_perm_ptr = transpose_perm_d.data_ptr() - - self.internal.backward_rawptrs( - L1_d.data_ptr(), - L1_grad_d.data_ptr(), - L2_d.data_ptr(), - L2_grad_d.data_ptr(), - weights_d.data_ptr(), - weights_grad_d.data_ptr(), - L3_d.data_ptr(), - rows_d.data_ptr(), - cols_d.data_ptr(), - graph.nnz, - graph.node_count, - self.workspace_ptr, - transpose_perm_ptr, + torch_sender_perm = torch.tensor(graph.transpose_perm, device="cuda") + else: + torch_sender_perm = None + + torch_out = self.forward( + torch_L1_in, + torch_L2_in, + torch_weights, + torch_rows, + torch_cols, + torch_sender_perm, ) - - L1_grad_d.copy_to_host() - L2_grad_d.copy_to_host() - weights_grad_d.copy_to_host() + torch_out.backward(gradient=torch_L3_grad) + L1_grad[:] = torch_L1_in.grad.numpy(force=True) + L2_grad[:] = torch_L2_in.grad.numpy(force=True) + weights_grad[:] = torch_weights.grad.numpy(force=True) weights_grad[:] = self.reorder_weights_to_e3nn( weights_grad, not self.config.shared_weights @@ -724,10 +254,158 @@ def backward_cpu( return L1_grad, L2_grad, weights_grad 
-if extlib.TORCH_COMPILE: - TensorProductConv.register_torch_fakes() - TensorProductConv.register_autograd() - TensorProductConv.register_autocast() +def register_torch_fakes(): + global torch + import torch + + @torch.library.register_fake("libtorch_tp_jit::jit_conv_forward") + def fake_forward( + kernel, hash, L1_in, L2_in, W, L3_dim, rows, cols, workspace_buffer, sender_perm + ): + return torch.empty(L1_in.shape[0], L3_dim, device="cuda", dtype=L1_in.dtype) + + @torch.library.register_fake("libtorch_tp_jit::jit_conv_backward") + def fake_backward( + kernel, + hash, + L1_in, + L2_in, + W, + L3_grad, + rows, + cols, + workspace_buffer, + sender_perm, + ): + return torch.empty_like(L1_in), torch.empty_like(L2_in), torch.empty_like(W) + + @torch.library.register_fake("libtorch_tp_jit::jit_conv_double_backward") + def fake_double_backward( + kernel, + hash, + L1_in, + L2_in, + W, + L3_grad, + L1_dgrad, + L2_dgrad, + w_dgrad, + rows, + cols, + workspace_buffer, + transpose_perm=None, + ): + return [ + L1_in.new_empty(*L1_in.shape), + L2_in.new_empty(*L2_in.shape), + W.new_empty(*W.shape), + L3_grad.new_empty(*L3_grad.shape), + ] + + +def register_autograd(): + backward_op = torch.ops.libtorch_tp_jit.jit_conv_backward + double_backward_op = torch.ops.libtorch_tp_jit.jit_conv_double_backward + + def setup_context(ctx, inputs, output): + ( + ctx.kernel, + ctx.hash, + ctx.L1_in, + ctx.L2_in, + ctx.W, + ctx.L3_dim, + ctx.rows, + ctx.cols, + ctx.workspace_buffer, + ctx.sender_perm, + ) = inputs + + def backward(ctx, grad_output): + L1_grad, L2_grad, W_grad = backward_op( + ctx.kernel, + ctx.hash, + ctx.L1_in, + ctx.L2_in, + ctx.W, + grad_output, + ctx.rows, + ctx.cols, + ctx.workspace_buffer, + ctx.sender_perm, + ) + return None, None, L1_grad, L2_grad, W_grad, None, None, None, None, None + + torch.library.register_autograd( + "libtorch_tp_jit::jit_conv_forward", backward, setup_context=setup_context + ) + + def setup_context_double_backward(ctx, inputs, output): + ( + ctx.kernel, + ctx.hash, + ctx.L1_in, + ctx.L2_in, + ctx.W, + ctx.grad_output, + ctx.rows, + ctx.cols, + ctx.workspace_buffer, + ctx.sender_perm, + ) = inputs + ctx.inputs = inputs + + def double_backward(ctx, E, F, G): + result = double_backward_op( + ctx.kernel, + ctx.hash, + ctx.L1_in, + ctx.L2_in, + ctx.W, + ctx.grad_output, + E, + F, + G, + ctx.rows, + ctx.cols, + ctx.workspace_buffer, + ctx.sender_perm, + ) + return ( + None, + None, + result[0], + result[1], + result[2], + result[3], + None, + None, + None, + None, + ) + + torch.library.register_autograd( + "libtorch_tp_jit::jit_conv_backward", + double_backward, + setup_context=setup_context_double_backward, + ) + + +def register_autocast(): + torch.library.register_autocast( + "libtorch_tp_jit::jit_conv_forward", "cuda", torch.float32 + ) + torch.library.register_autocast( + "libtorch_tp_jit::jit_conv_backward", "cuda", torch.float32 + ) + torch.library.register_autocast( + "libtorch_tp_jit::jit_conv_double_backward", "cuda", torch.float32 + ) + + +register_torch_fakes() +register_autograd() +register_autocast() # ================================================================== @@ -764,8 +442,6 @@ def name(): class TensorProductConvScatterSum(ConvolutionBase): def __init__(self, config, *, torch_op=True): assert torch_op - global torch - import torch super().__init__(config, torch_op=torch_op, deterministic=False) diff --git a/openequivariance/openequivariance/_torch/extlib/__init__.py b/openequivariance/openequivariance/_torch/extlib/__init__.py index 72440872..a7b4b865 
100644 --- a/openequivariance/openequivariance/_torch/extlib/__init__.py +++ b/openequivariance/openequivariance/_torch/extlib/__init__.py @@ -51,7 +51,7 @@ extra_cflags = ["-O3"] generic_sources = ["generic_module.cpp"] - torch_sources = ["libtorch_tp_jit.cpp"] + torch_sources = ["libtorch_tp_jit.cpp", "json11/json11.cpp"] include_dirs, extra_link_args = (["util"], ["-Wl,--no-as-needed"]) @@ -149,22 +149,13 @@ def torch_ext_so_path(): if BUILT_EXTENSION: from generic_module import ( - JITTPImpl, - JITConvImpl, GroupMM_F32, GroupMM_F64, DeviceProp, - DeviceBuffer, GPUTimer, ) else: - def JITTPImpl(*args, **kwargs): - _raise_import_error_helper("JITTPImpl") - - def JITConvImpl(*args, **kwargs): - _raise_import_error_helper("JITConvImpl") - def GroupMM_F32(*args, **kwargs): _raise_import_error_helper("GroupMM_F32") @@ -174,8 +165,5 @@ def GroupMM_F64(*args, **kwargs): def DeviceProp(*args, **kwargs): _raise_import_error_helper("DeviceProp") - def DeviceBuffer(*args, **kwargs): - _raise_import_error_helper("DeviceBuffer") - def GPUTimer(*args, **kwargs): _raise_import_error_helper("GPUTimer") diff --git a/openequivariance/openequivariance/_torch/utils.py b/openequivariance/openequivariance/_torch/utils.py index 7538fb27..74d5a010 100644 --- a/openequivariance/openequivariance/_torch/utils.py +++ b/openequivariance/openequivariance/_torch/utils.py @@ -1,4 +1,5 @@ import torch +import numpy as np from types import MappingProxyType from openequivariance.core.utils import DTypeEnum @@ -66,3 +67,11 @@ def reorder_torch(schedule, weights_in, direction, has_batch_dim): DTypeEnum.UINT8: torch.uint8, } ) + + +def string_to_tensor(text: str) -> torch.Tensor: + bytes_data = text.encode("utf-8") + np_bytes = np.frombuffer(bytes_data, dtype=np.uint8) + result = torch.tensor(np_bytes, device="cpu") + result.requires_grad = False + return result diff --git a/openequivariance/openequivariance/benchmark/ConvBenchmarkSuite.py b/openequivariance/openequivariance/benchmark/ConvBenchmarkSuite.py index 499a33eb..debcc65b 100644 --- a/openequivariance/openequivariance/benchmark/ConvBenchmarkSuite.py +++ b/openequivariance/openequivariance/benchmark/ConvBenchmarkSuite.py @@ -8,6 +8,7 @@ import openequivariance as oeq from openequivariance.benchmark.logging_utils import getLogger from openequivariance.core.ConvolutionBase import CoordGraph +from openequivariance.benchmark.benchmark_utils import NpEncoder logger = getLogger() @@ -145,7 +146,7 @@ def run( f"{output_folder}/{self.exp_count}_{impl.name()}_{graph.name}.json" ) with open(fname, "w") as f: - json.dump(result, f, indent=2) + json.dump(result, f, indent=2, cls=NpEncoder) self.exp_count += 1 logger.info(f"Finished {tc_name}, graph {graph.name}") diff --git a/openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py b/openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py index 119c866c..37d20c46 100644 --- a/openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py +++ b/openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py @@ -21,6 +21,7 @@ benchmark_forward, benchmark_backward, benchmark_double_backward, + NpEncoder, ) logger = getLogger() @@ -235,10 +236,12 @@ def run( fname = pathlib.Path(f"{output_folder}/{test_ID}_{impl.name()}.json") - pretty_result = json.dumps(obj=result, indent=2).replace("\\n", "\n") + pretty_result = json.dumps(obj=result, indent=2, cls=NpEncoder).replace( + "\\n", "\n" + ) logger.debug(pretty_result) with open(fname, "w") as f: - json.dump(result, f, indent=2) + json.dump(result, f, indent=2, 
cls=NpEncoder) self.results.append(result) logger.info(f"Finished Test ID: {test_ID}") diff --git a/openequivariance/openequivariance/benchmark/benchmark_utils.py b/openequivariance/openequivariance/benchmark/benchmark_utils.py index 377df3d6..68dc6f9f 100644 --- a/openequivariance/openequivariance/benchmark/benchmark_utils.py +++ b/openequivariance/openequivariance/benchmark/benchmark_utils.py @@ -1,3 +1,4 @@ +import json import numpy as np from openequivariance.benchmark.random_buffer_utils import ( @@ -290,3 +291,14 @@ def benchmark_double_backward( ) return result + + +class NpEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, np.integer): + return int(obj) + if isinstance(obj, np.floating): + return float(obj) + if isinstance(obj, np.ndarray): + return obj.tolist() + return super(NpEncoder, self).default(obj) diff --git a/openequivariance/openequivariance/core/LoopUnrollConv.py b/openequivariance/openequivariance/core/LoopUnrollConv.py index 35a9bc3e..ca8b4bdd 100644 --- a/openequivariance/openequivariance/core/LoopUnrollConv.py +++ b/openequivariance/openequivariance/core/LoopUnrollConv.py @@ -1,4 +1,5 @@ import numpy as np +import json from openequivariance.core.ConvolutionBase import ConvolutionBase from openequivariance.core.ComputationSchedule import ( @@ -6,9 +7,12 @@ SMEMCapacityException, ) -from openequivariance.core.utils import dtype_to_enum from openequivariance.templates.jinja_utils import get_jinja_environment -from openequivariance.core.utils import filter_and_analyze_problem +from openequivariance.core.utils import ( + filter_and_analyze_problem, + dtype_to_enum, + hash_str_64, +) class LoopUnrollConv(ConvolutionBase): @@ -203,5 +207,15 @@ def generate_double_backward_schedule(warps_per_block): ) self.jit_kernel = postprocess_kernel(self.jit_kernel) - # with open("scratch.txt", "w") as f: - # f.write(self.jit_kernel) + self.kernel_string = json.dumps( + { + "kernel": self.jit_kernel, + "forward_config": vars(self.forward_schedule.launch_config), + "backward_config": vars(self.backward_schedule.launch_config), + "double_backward_config": vars( + self.double_backward_schedule.launch_config + ), + "kernel_prop": self.kernel_prop, + } + ) + self.hash = hash_str_64(self.kernel_string) diff --git a/openequivariance/openequivariance/core/LoopUnrollTP.py b/openequivariance/openequivariance/core/LoopUnrollTP.py index 12ad4536..41354e5f 100644 --- a/openequivariance/openequivariance/core/LoopUnrollTP.py +++ b/openequivariance/openequivariance/core/LoopUnrollTP.py @@ -1,15 +1,19 @@ import numpy as np +import json from openequivariance.templates.jinja_utils import get_jinja_environment from openequivariance.core.ComputationSchedule import ComputationSchedule from openequivariance.core.TensorProductBase import TensorProductBase -from openequivariance.core.utils import dtype_to_enum +from openequivariance.benchmark.logging_utils import getLogger +from openequivariance.core.utils import dtype_to_enum, hash_str_64 from openequivariance.core.utils import ( filter_and_analyze_problem, count_cg_non_zero, ) +logger = getLogger() + class LoopUnrollTP(TensorProductBase): def __init__(self, config, dp, postprocess_kernel, torch_op): @@ -91,7 +95,7 @@ def generate_double_backward_schedule(warps_per_block): ) ) - self.kernelProp = { + self.kernel_prop = { "L1_dim": self.L1.dim, "L2_dim": self.L2.dim, "L3_dim": self.L3.dim, @@ -106,6 +110,20 @@ def generate_double_backward_schedule(warps_per_block): "idx_dtype": 0, } + self.kernel_string = json.dumps( + { + "kernel": 
self.jit_kernel, + "forward_config": vars(self.forward_schedule.launch_config), + "backward_config": vars(self.backward_schedule.launch_config), + "double_backward_config": vars( + self.double_backward_schedule.launch_config + ), + "kernel_prop": self.kernel_prop, + } + ) + self.hash = hash_str_64(self.kernel_string) + logger.info(f"Kernel File Size: {len(self.jit_kernel) // 1024} KB") + def calculate_flops_forward(self, batch_size: int) -> dict: if self.is_uvw: return super().calculate_flops_forward(batch_size) diff --git a/openequivariance/openequivariance/core/utils.py b/openequivariance/openequivariance/core/utils.py index f3aa466c..1950013d 100644 --- a/openequivariance/openequivariance/core/utils.py +++ b/openequivariance/openequivariance/core/utils.py @@ -7,9 +7,9 @@ import json import tempfile +import hashlib from enum import IntEnum -import hashlib class DTypeEnum(IntEnum): diff --git a/openequivariance/openequivariance/extension/convolution.hpp b/openequivariance/openequivariance/extension/convolution.hpp index 92aa6880..83ad58b4 100644 --- a/openequivariance/openequivariance/extension/convolution.hpp +++ b/openequivariance/openequivariance/extension/convolution.hpp @@ -176,88 +176,4 @@ class __attribute__ ((visibility ("default"))) JITConvImpl { } ~JITConvImpl() = default; - - // Integer pointer versions of the functions above - - void exec_conv_rawptrs( - uint64_t L1_in, - uint64_t L2_in, - uint64_t weights, - uint64_t L3_out, - uint64_t rows, - uint64_t cols, - uint64_t nnz, - uint64_t node_count, - uint64_t workspace) { - - exec_conv( - reinterpret_cast(L1_in), - reinterpret_cast(L2_in), - reinterpret_cast(weights), - reinterpret_cast(L3_out), - reinterpret_cast(rows), - reinterpret_cast(cols), - nnz, - node_count, - reinterpret_cast(workspace), - 0 // Default Stream - ); - } - - void backward_rawptrs( - uint64_t L1_in, uint64_t L1_grad, - uint64_t L2_in, uint64_t L2_grad, - uint64_t weight, uint64_t weight_grad, - uint64_t L3_grad, - uint64_t rows, uint64_t cols, - uint64_t nnz, uint64_t node_count, - uint64_t workspace, uint64_t inverse_perm) { - - backward( - reinterpret_cast(L1_in), - reinterpret_cast(L1_grad), - reinterpret_cast(L2_in), - reinterpret_cast(L2_grad), - reinterpret_cast(weight), - reinterpret_cast(weight_grad), - reinterpret_cast(L3_grad), - reinterpret_cast(rows), - reinterpret_cast(cols), - nnz, - node_count, - reinterpret_cast(workspace), - reinterpret_cast(inverse_perm), - 0 // Default Stream - ); - } - - void double_backward_rawptrs( - uint64_t L1_in, uint64_t L2_in, uint64_t W, uint64_t L3_grad, - uint64_t L1_dgrad, uint64_t L2_dgrad, uint64_t w_dgrad, - uint64_t L1_grad, uint64_t L2_grad, uint64_t W_grad, uint64_t L3_dgrad, - uint64_t rows, uint64_t cols, - uint64_t nnz, uint64_t node_count, - uint64_t wspace, uint64_t transpose_perm) { - - double_backward( - reinterpret_cast(L1_in), - reinterpret_cast(L2_in), - reinterpret_cast(W), - reinterpret_cast(L3_grad), - reinterpret_cast(L1_dgrad), - reinterpret_cast(L2_dgrad), - reinterpret_cast(w_dgrad), - reinterpret_cast(L1_grad), - reinterpret_cast(L2_grad), - reinterpret_cast(W_grad), - reinterpret_cast(L3_dgrad), - reinterpret_cast(rows), - reinterpret_cast(cols), - nnz, - node_count, - reinterpret_cast(wspace), - reinterpret_cast(transpose_perm), - 0 - ); - } }; \ No newline at end of file diff --git a/openequivariance/openequivariance/extension/generic_module.cpp b/openequivariance/openequivariance/extension/generic_module.cpp index fc94eec9..b0996991 100644 --- 
a/openequivariance/openequivariance/extension/generic_module.cpp +++ b/openequivariance/openequivariance/extension/generic_module.cpp @@ -24,7 +24,6 @@ using GroupMM = GroupMMHIP; #endif -#include "buffer.hpp" #include "tensorproducts.hpp" #include "convolution.hpp" @@ -32,26 +31,6 @@ using namespace std; namespace py=pybind11; PYBIND11_MODULE(generic_module, m) { - //=========== Batch tensor products ========= - py::class_>(m, "JITTPImpl") - .def(py::init< std::string, - std::unordered_map, - std::unordered_map, - std::unordered_map, - std::unordered_map>()) - .def("exec_tensor_product_rawptr", &JITTPImpl::exec_tensor_product_device_rawptrs) - .def("backward_rawptr", &JITTPImpl::backward_device_rawptrs); - - py::class_>(m, "JITConvImpl") - .def(py::init< std::string, - std::unordered_map, - std::unordered_map, - std::unordered_map, - std::unordered_map>()) - .def("exec_conv_rawptrs", &JITConvImpl::exec_conv_rawptrs) - .def("backward_rawptrs", &JITConvImpl::backward_rawptrs) - .def("double_backward_rawptrs", &JITConvImpl::double_backward_rawptrs); - py::class_>(m, "GroupMM_F32") .def(py::init()) .def("group_gemm", &GroupMM::group_gemm_intptr); @@ -68,12 +47,6 @@ PYBIND11_MODULE(generic_module, m) { .def_readonly("multiprocessorCount", &DeviceProp::multiprocessorCount) .def_readonly("maxSharedMemPerBlock", &DeviceProp::maxSharedMemPerBlock); - py::class_>(m, "DeviceBuffer") - .def(py::init()) - .def(py::init()) - .def("copy_to_host", &PyDeviceBuffer::copy_to_host) - .def("data_ptr", &PyDeviceBuffer::data_ptr); - py::class_(m, "GPUTimer") .def(py::init<>()) .def("start", &GPUTimer::start) diff --git a/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp b/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp index 18b8f65c..6216909f 100644 --- a/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp +++ b/openequivariance/openequivariance/extension/libtorch_tp_jit.cpp @@ -3,9 +3,11 @@ #include #include #include +#include +#include -#include -#include +#include "json11/json11.hpp" +#include #ifdef CUDA_BACKEND #include @@ -37,12 +39,11 @@ } #endif -#include "buffer.hpp" #include "tensorproducts.hpp" #include "convolution.hpp" using namespace std; -namespace py=pybind11; +using json = json11::Json; #include #include @@ -50,28 +51,13 @@ namespace py=pybind11; #include #include -using Map_t=torch::Dict; - -std::unordered_map to_map(const Map_t &map) { - std::unordered_map result; - for(auto it = map.begin(); it != map.end(); ++it) { - result[it->key()] = it->value(); - } - return result; -} - torch::Dtype enum_to_torch_dtype(int64_t i){ switch(i) { - case 1: - return torch::kFloat; - case 2: - return torch::kDouble; - case 3: - return torch::kInt; - case 4: - return torch::kLong; - case 5: - return torch::kUInt8; + case 1: return torch::kFloat; + case 2: return torch::kDouble; + case 3: return torch::kInt; + case 4: return torch::kLong; + case 5: return torch::kUInt8; } throw logic_error("Unsupported tensor datatype!"); } @@ -96,11 +82,21 @@ inline void* data_ptr(const torch::Tensor &tensor) { else if(tensor.dtype() == torch::kLong) return reinterpret_cast(tensor.data_ptr()); else if(tensor.dtype() == torch::kByte) - return reinterpret_cast(tensor.data_ptr()); + return reinterpret_cast(tensor.data_ptr()); // Replaces kUInt8 + else if(tensor.dtype() == torch::kInt) + return reinterpret_cast(tensor.data_ptr()); else throw logic_error("Unsupported tensor datatype!"); } +std::unordered_map parse_json_config(const json &j_obj) { + std::unordered_map result; + for (const 
auto &kv : j_obj.object_items()) { + result[kv.first] = static_cast(kv.second.number_value()); + } + return result; +} + struct KernelProp { int64_t L1_dim, L2_dim, L3_dim, weight_numel; bool shared_weights; @@ -112,7 +108,15 @@ struct KernelProp { torch::Dtype idx_dtype; torch::Dtype workspace_dtype; - KernelProp(Map_t &kernel_dims, bool is_convolution): + KernelProp() : + L1_dim(0), L2_dim(0), L3_dim(0), weight_numel(0), + shared_weights(false), + irrep_dtype(torch::kFloat), weight_dtype(torch::kFloat), + workspace_size(0), deterministic(false), + idx_dtype(torch::kInt), workspace_dtype(torch::kByte) {} + + KernelProp( + std::unordered_map &kernel_dims, bool is_convolution): L1_dim(kernel_dims.at("L1_dim")), L2_dim(kernel_dims.at("L2_dim")), L3_dim(kernel_dims.at("L3_dim")), @@ -126,81 +130,114 @@ struct KernelProp { deterministic = kernel_dims.at("deterministic"); idx_dtype = enum_to_torch_dtype(kernel_dims.at("idx_dtype")); } - } + } }; -class __attribute__ ((visibility ("default"))) TorchJITProduct : public torch::CustomClassHolder { -public: - Map_t fwd_dict, bwd_dict, dbl_bwd_dict, kernel_dims; - JITTPImpl internal; - KernelProp kernelProp; - int64_t L3_dim, irrep_dtype; - - TorchJITProduct(string kernel_plaintext, Map_t fwd_dict_i, Map_t bwd_dict_i, Map_t dbl_bwd_dict_i, Map_t kernel_dims_i) : - fwd_dict(fwd_dict_i.copy()), - bwd_dict(bwd_dict_i.copy()), - dbl_bwd_dict(dbl_bwd_dict_i.copy()), - kernel_dims(kernel_dims_i.copy()), - internal(kernel_plaintext, - to_map(fwd_dict_i), - to_map(bwd_dict_i), - to_map(dbl_bwd_dict_i), - to_map(kernel_dims_i) - ), - kernelProp(kernel_dims, false), - L3_dim(kernelProp.L3_dim), - irrep_dtype(kernel_dims_i.at("irrep_dtype")) - { } - - tuple< tuple, - tuple, - tuple, - tuple, - tuple> __obj_flatten__() { - return tuple(tuple("kernel_plaintext", internal.jit.kernel_plaintext), - tuple("fwd_config", fwd_dict), - tuple("bwd_config", bwd_dict), - tuple("dbl_bwd_config", dbl_bwd_dict), - tuple("kernel_dims", kernel_dims)); +std::unordered_map>, + KernelProp + >> tp_cache; + +std::unordered_map>, + KernelProp + >> conv_cache; + +std::mutex mut; + +std::pair*, KernelProp> + compile_tp_with_caching(const torch::Tensor &json_bytes, + int64_t hash) { + { + const std::lock_guard lock(mut); + auto it = tp_cache.find(hash); + if (it == tp_cache.end()) { + torch::Tensor cpu_tensor = json_bytes.to(torch::kCPU).contiguous(); + std::string json_payload( + reinterpret_cast(cpu_tensor.data_ptr()), + cpu_tensor.numel() + ); + + std::string err; + json root = json::parse(json_payload, err); + if (!err.empty()) throw std::runtime_error("JSON Parse Error: " + err); + + std::string kernel_src = root["kernel"].string_value(); + auto forward_cfg = parse_json_config(root["forward_config"]); + auto backward_cfg = parse_json_config(root["backward_config"]); + auto dbackward_cfg = parse_json_config(root["double_backward_config"]); + auto kernel_prop_map = parse_json_config(root["kernel_prop"]); + + auto jit_tp_impl = std::make_unique>( + kernel_src, + forward_cfg, + backward_cfg, + dbackward_cfg, + kernel_prop_map); + + tp_cache.insert({hash, + std::make_pair(std::move(jit_tp_impl), + KernelProp(kernel_prop_map, false))}); + it = tp_cache.find(hash); + } + return {it->second.first.get(), it->second.second}; } +} - void exec_tensor_product_device_rawptrs(int64_t num_batch, int64_t L1_in, int64_t L2_in, int64_t L3_out, int64_t weights) { - Stream stream = get_current_stream(); - internal.exec_tensor_product( - num_batch, - reinterpret_cast(L1_in), - reinterpret_cast(L2_in), - 
reinterpret_cast(L3_out), - reinterpret_cast(weights), - stream - ); - } - - void backward_device_rawptrs(int64_t num_batch, - int64_t L1_in, int64_t L1_grad, - int64_t L2_in, int64_t L2_grad, - int64_t weight, int64_t weight_grad, - int64_t L3_grad) { - Stream stream = get_current_stream(); - internal.backward(num_batch, - reinterpret_cast(L1_in), reinterpret_cast(L1_grad), - reinterpret_cast(L2_in), reinterpret_cast(L2_grad), - reinterpret_cast(weight), reinterpret_cast(weight_grad), - reinterpret_cast(L3_grad), stream - ); +std::pair*, KernelProp> + compile_conv_with_caching(const torch::Tensor &json_bytes, + int64_t hash) { + { + const std::lock_guard lock(mut); + auto it = conv_cache.find(hash); + if (it == conv_cache.end()) { + torch::Tensor cpu_tensor = json_bytes.to(torch::kCPU).contiguous(); + std::string json_payload( + reinterpret_cast(cpu_tensor.data_ptr()), + cpu_tensor.numel() + ); + + std::string err; + json root = json::parse(json_payload, err); + if (!err.empty()) throw std::runtime_error("JSON Parse Error: " + err); + + std::string kernel_src = root["kernel"].string_value(); + auto forward_cfg = parse_json_config(root["forward_config"]); + auto backward_cfg = parse_json_config(root["backward_config"]); + auto dbackward_cfg = parse_json_config(root["double_backward_config"]); + auto kernel_prop_map = parse_json_config(root["kernel_prop"]); + + auto jit_conv_impl = std::make_unique>( + kernel_src, + forward_cfg, + backward_cfg, + dbackward_cfg, + kernel_prop_map); + + conv_cache.insert({hash, + std::make_pair(std::move(jit_conv_impl), + KernelProp(kernel_prop_map, true))}); + it = conv_cache.find(hash); + } + return {it->second.first.get(), it->second.second}; } -}; +} + +// --------------------- Tensor Products -------------------------- torch::Tensor jit_tp_forward( - const c10::intrusive_ptr &jit_instance, - const torch::Tensor &L1_in, - const torch::Tensor &L2_in, - const torch::Tensor &W) { + torch::Tensor json_bytes, int64_t hash, + torch::Tensor L1_in, + torch::Tensor L2_in, + torch::Tensor W, + int64_t L3_dim) { + auto [jit_kernel, k] = compile_tp_with_caching(json_bytes, hash); Stream stream = get_current_stream(); const int64_t num_batch = L1_in.size(0); - const KernelProp &k = jit_instance->kernelProp; check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); check_tensor(L2_in, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_in"); @@ -216,7 +253,7 @@ torch::Tensor jit_tp_forward( at::Tensor L2_contig = L2_in.contiguous(); at::Tensor W_contig = W.contiguous(); - jit_instance->internal.exec_tensor_product( + jit_kernel->exec_tensor_product( num_batch, data_ptr(L1_contig), data_ptr(L2_contig), @@ -229,16 +266,16 @@ torch::Tensor jit_tp_forward( } tuple jit_tp_backward( - const c10::intrusive_ptr &jit_instance, - const torch::Tensor &L1_in, - const torch::Tensor &L2_in, - const torch::Tensor &W, - const torch::Tensor &L3_grad) { + torch::Tensor json_bytes, int64_t hash, + torch::Tensor L1_in, + torch::Tensor L2_in, + torch::Tensor W, + torch::Tensor L3_grad) { + auto [jit_kernel, k] = compile_tp_with_caching(json_bytes, hash); Stream stream = get_current_stream(); const int64_t num_batch = L1_in.size(0); - const KernelProp &k = jit_instance->kernelProp; check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); check_tensor(L2_in, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_in"); @@ -261,7 +298,7 @@ tuple jit_tp_backward( torch::Tensor W_contig = W.contiguous(); torch::Tensor L3_grad_contig = L3_grad.contiguous(); - jit_instance->internal.backward( + 
jit_kernel->backward( num_batch, data_ptr(L1_in_contig), data_ptr(L1_grad), data_ptr(L2_in_contig), data_ptr(L2_grad), @@ -274,19 +311,19 @@ tuple jit_tp_backward( } tuple jit_tp_double_backward( - const c10::intrusive_ptr &jit_instance, - const torch::Tensor &L1_in, - const torch::Tensor &L2_in, - const torch::Tensor &W, - const torch::Tensor &L3_grad, - const torch::Tensor &L1_dgrad, - const torch::Tensor &L2_dgrad, - const torch::Tensor &W_dgrad) { + torch::Tensor json_bytes, int64_t hash, + torch::Tensor L1_in, + torch::Tensor L2_in, + torch::Tensor W, + torch::Tensor L3_grad, + torch::Tensor L1_dgrad, + torch::Tensor L2_dgrad, + torch::Tensor W_dgrad) { + auto [jit_kernel, k] = compile_tp_with_caching(json_bytes, hash); Stream stream = get_current_stream(); const int64_t num_batch = L1_in.size(0); - const KernelProp &k = jit_instance->kernelProp; check_tensor(L1_in, {num_batch, k.L1_dim}, k.irrep_dtype, "L1_in"); check_tensor(L2_in, {num_batch, k.L2_dim}, k.irrep_dtype, "L2_in"); @@ -321,7 +358,7 @@ tuple jit_tp_double_ TORCH_CHECK(W.dim() == 1); } - jit_instance->internal.double_backward( + jit_kernel->double_backward( num_batch, data_ptr(L1_in_contig), data_ptr(L2_in_contig), data_ptr(W_contig), data_ptr(L3_grad_contig), @@ -336,127 +373,24 @@ tuple jit_tp_double_ } -// =========================================================== - -class TorchJITConv : public torch::CustomClassHolder { -public: - Map_t fwd_dict, bwd_dict, dbl_bwd_dict, kernel_dims; - JITConvImpl internal; - KernelProp kernelProp; - int64_t L3_dim, irrep_dtype; - - TorchJITConv(string kernel_plaintext, Map_t fwd_dict_i, Map_t bwd_dict_i, Map_t dbl_bwd_dict_i, Map_t kernel_dims_i) : - fwd_dict(fwd_dict_i.copy()), - bwd_dict(bwd_dict_i.copy()), - dbl_bwd_dict(bwd_dict_i.copy()), - kernel_dims(kernel_dims_i.copy()), - internal(kernel_plaintext, - to_map(fwd_dict_i), - to_map(bwd_dict_i), - to_map(dbl_bwd_dict_i), - to_map(kernel_dims_i) - ), - kernelProp(kernel_dims, true), - L3_dim(kernelProp.L3_dim), - irrep_dtype(kernel_dims_i.at("irrep_dtype")) - { } - - tuple, - tuple, - tuple, - tuple, - tuple> __obj_flatten__() { - return tuple(tuple("kernel_plaintext", internal.jit.kernel_plaintext), - tuple("fwd_config", fwd_dict), - tuple("bwd_config", bwd_dict), - tuple("dbl_bwd_config", dbl_bwd_dict), - tuple("kernel_dims", kernel_dims)); - } - - void exec_conv_rawptrs( - int64_t L1_in, int64_t L2_in, int64_t weights, int64_t L3_out, - int64_t rows, int64_t cols, - int64_t nnz, int64_t node_count, - int64_t workspace) { - Stream stream = get_current_stream(); - internal.exec_conv( - reinterpret_cast(L1_in), - reinterpret_cast(L2_in), - reinterpret_cast(weights), - reinterpret_cast(L3_out), - reinterpret_cast(rows), - reinterpret_cast(cols), - nnz, node_count, - reinterpret_cast(workspace), - stream); - } - void backward_rawptrs( - int64_t L1_in, int64_t L1_grad, - int64_t L2_in, int64_t L2_grad, - int64_t weight, int64_t weight_grad, - int64_t L3_grad, - int64_t rows, int64_t cols, - int64_t nnz, int64_t node_count, - int64_t workspace, - int64_t transpose_perm) { - Stream stream = get_current_stream(); - internal.backward( - reinterpret_cast(L1_in), reinterpret_cast(L1_grad), - reinterpret_cast(L2_in), reinterpret_cast(L2_grad), - reinterpret_cast(weight), reinterpret_cast(weight_grad), - reinterpret_cast(L3_grad), - reinterpret_cast(rows), - reinterpret_cast(cols), - nnz, node_count, - reinterpret_cast(workspace), - reinterpret_cast(transpose_perm), - stream); - } - void double_backward_rawptrs( - int64_t L1_in, int64_t 
L2_in, int64_t W, int64_t L3_grad, - int64_t L1_dgrad, int64_t L2_dgrad, int64_t w_dgrad, - int64_t L1_grad, int64_t L2_grad, int64_t W_grad, int64_t L3_dgrad, - int64_t rows, int64_t cols, - int64_t nnz, int64_t node_count, - int64_t wspace, int64_t transpose_perm) { - - Stream stream = get_current_stream(); - internal.double_backward( - reinterpret_cast(L1_in), - reinterpret_cast(L2_in), - reinterpret_cast(W), - reinterpret_cast(L3_grad), - reinterpret_cast(L1_dgrad), - reinterpret_cast(L2_dgrad), - reinterpret_cast(w_dgrad), - reinterpret_cast(L1_grad), - reinterpret_cast(L2_grad), - reinterpret_cast(W_grad), - reinterpret_cast(L3_dgrad), - reinterpret_cast(rows), - reinterpret_cast(cols), - nnz, node_count, - reinterpret_cast(wspace), - reinterpret_cast(transpose_perm), - stream); - } -}; +// ========================= Convolution ================================== torch::Tensor jit_conv_forward( - const c10::intrusive_ptr &jit_instance, - const torch::Tensor &L1_in, - const torch::Tensor &L2_in, - const torch::Tensor &W, - const torch::Tensor &rows, - const torch::Tensor &cols, - const torch::Tensor &workspace, - const torch::Tensor &transpose_perm) { - + torch::Tensor json_bytes, int64_t hash, + torch::Tensor L1_in, + torch::Tensor L2_in, + torch::Tensor W, + int64_t L3_dim, + torch::Tensor rows, + torch::Tensor cols, + torch::Tensor workspace, + torch::Tensor transpose_perm) { + + auto [jit_kernel, k] = compile_conv_with_caching(json_bytes, hash); Stream stream = get_current_stream(); const int64_t nnz = rows.size(0); const int64_t node_count = L1_in.size(0); - const KernelProp &k = jit_instance->kernelProp; check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, "L1_in"); check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in"); @@ -483,7 +417,7 @@ torch::Tensor jit_conv_forward( torch::Tensor cols_contig = cols.contiguous(); torch::Tensor workspace_contig = workspace.contiguous(); - jit_instance->internal.exec_conv( + jit_kernel->exec_conv( data_ptr(L1_contig), data_ptr(L2_contig), data_ptr(W_contig), @@ -498,21 +432,21 @@ torch::Tensor jit_conv_forward( } tuple jit_conv_backward( - const c10::intrusive_ptr &jit_instance, - const torch::Tensor &L1_in, - const torch::Tensor &L2_in, - const torch::Tensor &W, - const torch::Tensor &L3_grad, - const torch::Tensor &rows, - const torch::Tensor &cols, - const torch::Tensor &workspace, - const torch::Tensor &transpose_perm) { + torch::Tensor json_bytes, int64_t hash, + torch::Tensor L1_in, + torch::Tensor L2_in, + torch::Tensor W, + torch::Tensor L3_grad, + torch::Tensor rows, + torch::Tensor cols, + torch::Tensor workspace, + torch::Tensor transpose_perm) { + auto [jit_kernel, k] = compile_conv_with_caching(json_bytes, hash); Stream stream = get_current_stream(); const int64_t nnz = rows.size(0); const int64_t node_count = L1_in.size(0); - const KernelProp &k = jit_instance->kernelProp; check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, "L1_in"); check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in"); @@ -549,7 +483,7 @@ tuple jit_conv_backward( if(k.shared_weights) W_grad.zero_(); - jit_instance->internal.backward( + jit_kernel->backward( data_ptr(L1_in_contig), data_ptr(L1_grad), data_ptr(L2_in_contig), data_ptr(L2_grad), data_ptr(W_contig), data_ptr(W_grad), @@ -564,24 +498,24 @@ tuple jit_conv_backward( } tuple jit_conv_double_backward( - const c10::intrusive_ptr &jit_instance, - const torch::Tensor &L1_in, - const torch::Tensor &L2_in, - const torch::Tensor &W, - const torch::Tensor &L3_grad, - const 
torch::Tensor &L1_dgrad, - const torch::Tensor &L2_dgrad, - const torch::Tensor &W_dgrad, - const torch::Tensor &rows, - const torch::Tensor &cols, - const torch::Tensor &workspace, - const torch::Tensor &transpose_perm) { + torch::Tensor json_bytes, int64_t hash, + torch::Tensor L1_in, + torch::Tensor L2_in, + torch::Tensor W, + torch::Tensor L3_grad, + torch::Tensor L1_dgrad, + torch::Tensor L2_dgrad, + torch::Tensor W_dgrad, + torch::Tensor rows, + torch::Tensor cols, + torch::Tensor workspace, + torch::Tensor transpose_perm) { + auto [jit_kernel, k] = compile_conv_with_caching(json_bytes, hash); Stream stream = get_current_stream(); const int64_t nnz = rows.size(0); const int64_t node_count = L1_in.size(0); - const KernelProp &k = jit_instance->kernelProp; check_tensor(L1_in, {node_count, k.L1_dim}, k.irrep_dtype, "L1_in"); check_tensor(L2_in, {nnz, k.L2_dim}, k.irrep_dtype, "L2_in"); @@ -628,7 +562,7 @@ tuple jit_conv_doubl if(k.shared_weights) W_grad.zero_(); - jit_instance->internal.double_backward( + jit_kernel->double_backward( data_ptr(L1_in_contig), data_ptr(L2_in_contig), data_ptr(W_contig), data_ptr(L3_grad_contig), data_ptr(L1_dgrad_contig), data_ptr(L2_dgrad_contig), @@ -646,68 +580,6 @@ tuple jit_conv_doubl // =========================================================== -TORCH_LIBRARY_FRAGMENT(libtorch_tp_jit, m) { - m.class_("TorchJITProduct") - .def(torch::init()) - .def("__obj_flatten__", &TorchJITProduct::__obj_flatten__) - .def("exec_tensor_product_rawptr", &TorchJITProduct::exec_tensor_product_device_rawptrs) - .def("backward_rawptr", &TorchJITProduct::backward_device_rawptrs) - .def("__len__", [](const c10::intrusive_ptr& test) -> int64_t { - return 0; - }) - .def_readonly("L3_dim", &TorchJITProduct::L3_dim) - .def_readonly("irrep_dtype", &TorchJITProduct::irrep_dtype) - .def("__eq__", [](const c10::IValue & self, const c10::IValue& other) -> bool { - return self.is(other); - }) - .def_pickle( - // __getstate__ - [](const c10::intrusive_ptr& self) - -> tuple { - return tuple(self->internal.jit.kernel_plaintext, self->fwd_dict, self->bwd_dict, self->dbl_bwd_dict, self->kernel_dims); - }, - // __setstate__ - [](tuple state) - -> c10::intrusive_ptr { - return c10::make_intrusive(get<0>(state), get<1>(state), get<2>(state), get<3>(state), get<4>(state)); - }); - - m.def("jit_tp_forward(__torch__.torch.classes.libtorch_tp_jit.TorchJITProduct jit, Tensor L1_in, Tensor L2_in, Tensor W) -> Tensor"); - m.def("jit_tp_backward(__torch__.torch.classes.libtorch_tp_jit.TorchJITProduct jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad) -> (Tensor, Tensor, Tensor)"); - m.def("jit_tp_double_backward(__torch__.torch.classes.libtorch_tp_jit.TorchJITProduct jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor L1_dgrad, Tensor L2_dgrad, Tensor W_dgrad) -> (Tensor, Tensor, Tensor, Tensor)"); - - - m.class_("TorchJITConv") - .def(torch::init()) - .def("__obj_flatten__", &TorchJITConv::__obj_flatten__) - .def("exec_conv_rawptrs", &TorchJITConv::exec_conv_rawptrs) - .def("backward_rawptrs", &TorchJITConv::backward_rawptrs) - .def("double_backward_rawptrs", &TorchJITConv::double_backward_rawptrs) - .def("__len__", [](const c10::intrusive_ptr& test) -> int64_t { - return 0; - }) - .def_readonly("L3_dim", &TorchJITConv::L3_dim) - .def_readonly("irrep_dtype", &TorchJITConv::irrep_dtype) - .def("__eq__", [](const c10::IValue & self, const c10::IValue& other) -> bool { - return self.is(other); - }) - .def_pickle( - // __getstate__ - [](const c10::intrusive_ptr& self) - -> tuple 
{
-                return tuple(self->internal.jit.kernel_plaintext, self->fwd_dict, self->bwd_dict, self->dbl_bwd_dict, self->kernel_dims);
-            },
-            // __setstate__
-            [](tuple state)
-                -> c10::intrusive_ptr {
-                return c10::make_intrusive(get<0>(state), get<1>(state), get<2>(state), get<3>(state), get<4>(state));
-            });
-
-    m.def("jit_conv_forward(__torch__.torch.classes.libtorch_tp_jit.TorchJITConv jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> Tensor");
-    m.def("jit_conv_backward(__torch__.torch.classes.libtorch_tp_jit.TorchJITConv jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor)");
-    m.def("jit_conv_double_backward(__torch__.torch.classes.libtorch_tp_jit.TorchJITConv jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor L1_dgrad, Tensor L2_dgrad, Tensor W_dgrad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor, Tensor)");
-};
-
 TORCH_LIBRARY_IMPL(libtorch_tp_jit, CUDA, m) {
     m.impl("jit_tp_forward", &jit_tp_forward);
     m.impl("jit_tp_backward", &jit_tp_backward);
@@ -718,4 +590,14 @@ TORCH_LIBRARY_IMPL(libtorch_tp_jit, CUDA, m) {
     m.impl("jit_conv_double_backward", &jit_conv_double_backward);
 };
 
+TORCH_LIBRARY(libtorch_tp_jit, m) {
+    m.def("jit_tp_forward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, int L3_dim) -> Tensor");
+    m.def("jit_tp_backward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad) -> (Tensor, Tensor, Tensor)");
+    m.def("jit_tp_double_backward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor L1_dgrad, Tensor L2_dgrad, Tensor W_dgrad) -> (Tensor, Tensor, Tensor, Tensor)");
+
+    m.def("jit_conv_forward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, int L3_dim, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> Tensor");
+    m.def("jit_conv_backward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor)");
+    m.def("jit_conv_double_backward(Tensor json_bytes, int hash, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor L1_dgrad, Tensor L2_dgrad, Tensor W_dgrad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor, Tensor)");
+};
+
 PYBIND11_MODULE(libtorch_tp_jit, m) {}
\ No newline at end of file
diff --git a/openequivariance/openequivariance/extension/tensorproducts.hpp b/openequivariance/openequivariance/extension/tensorproducts.hpp
index ee8def66..b4b4d84b 100644
--- a/openequivariance/openequivariance/extension/tensorproducts.hpp
+++ b/openequivariance/openequivariance/extension/tensorproducts.hpp
@@ -109,31 +109,5 @@ class __attribute__ ((visibility ("default"))) JITTPImpl {
         jit.execute(3, args, with_stream(double_backward_config_ref, stream));
     }
 
-    ~JITTPImpl() = default;
-
-    // Integer pointer versions of the functions above
-    void exec_tensor_product_device_rawptrs(uint64_t num_products,
-            uint64_t L1_in, uint64_t L2_in, uint64_t L3_out, uint64_t weights) {
-        exec_tensor_product(num_products,
-            reinterpret_cast(L1_in),
-            reinterpret_cast(L2_in),
-            reinterpret_cast(L3_out),
-            reinterpret_cast(weights),
-            0 // Default Stream
-        );
-    }
-
-    void backward_device_rawptrs(uint64_t num_products,
-            uint64_t L1_in, uint64_t L1_grad,
-            uint64_t L2_in, uint64_t L2_grad,
-
uint64_t weight, uint64_t weight_grad, - uint64_t L3_grad) { - - backward(num_products, - reinterpret_cast(L1_in), reinterpret_cast(L1_grad), - reinterpret_cast(L2_in), reinterpret_cast(L2_grad), - reinterpret_cast(weight), reinterpret_cast(weight_grad), - reinterpret_cast(L3_grad), 0 // Null = Default Stream - ); - } + ~JITTPImpl() = default; }; \ No newline at end of file diff --git a/openequivariance/openequivariance/extension/util/buffer.hpp b/openequivariance/openequivariance/extension/util/buffer.hpp deleted file mode 100644 index 95dc8319..00000000 --- a/openequivariance/openequivariance/extension/util/buffer.hpp +++ /dev/null @@ -1,45 +0,0 @@ -#pragma once -#include -#include - -using namespace std; -namespace py = pybind11; - -template -class PyDeviceBuffer { -public: - char* host_ptr; - char* device_ptr; - size_t size; - - PyDeviceBuffer(uint64_t size) { - this->size = size; - device_ptr = static_cast(ALLOC_T::gpu_alloc(size)); - host_ptr = nullptr; - } - - PyDeviceBuffer(py::buffer host_data) { - const py::buffer_info &info = host_data.request(); - host_ptr = static_cast(info.ptr); - size = 1; - for(int64_t i = 0; i < info.ndim; i++) { - size *= info.shape[i]; - } - size *= info.itemsize; - - device_ptr = static_cast(ALLOC_T::gpu_alloc(size)); - ALLOC_T::copy_host_to_device(host_ptr, device_ptr, size); - } - - ~PyDeviceBuffer() { - ALLOC_T::gpu_free(static_cast(device_ptr)); - } - - void copy_to_host() { - ALLOC_T::copy_device_to_host(host_ptr, device_ptr, size); - } - - uint64_t data_ptr() { - return reinterpret_cast(device_ptr); - } -}; \ No newline at end of file diff --git a/openequivariance/openequivariance/jax/TensorProduct.py b/openequivariance/openequivariance/jax/TensorProduct.py index a140674b..84d75e10 100644 --- a/openequivariance/openequivariance/jax/TensorProduct.py +++ b/openequivariance/openequivariance/jax/TensorProduct.py @@ -3,10 +3,8 @@ from openequivariance.jax import extlib from openequivariance.core.e3nn_lite import TPProblem from openequivariance.core.LoopUnrollTP import LoopUnrollTP -from openequivariance.core.utils import hash_str_64 from openequivariance.jax.utils import reorder_jax from openequivariance.jax.jvp.tp_prim import tp_fwd_p -import json class TensorProduct(LoopUnrollTP): @@ -20,19 +18,7 @@ def __init__(self, problem: TPProblem): dp = extlib.DeviceProp(0) super().__init__(problem, dp, extlib.postprocess_kernel, torch_op=False) - self.kernel = json.dumps( - { - "kernel": self.jit_kernel, - "forward_config": vars(self.forward_schedule.launch_config), - "backward_config": vars(self.backward_schedule.launch_config), - "double_backward_config": vars( - self.double_backward_schedule.launch_config - ), - "kernel_prop": self.kernelProp, - } - ) - self.hash = hash_str_64(self.kernel) - + self.kernel = self.kernel_string self.weight_numel = problem.weight_numel self.L3_dim = self.config.irreps_out.dim diff --git a/openequivariance/openequivariance/jax/TensorProductConv.py b/openequivariance/openequivariance/jax/TensorProductConv.py index 2be0fe7a..c14637a1 100644 --- a/openequivariance/openequivariance/jax/TensorProductConv.py +++ b/openequivariance/openequivariance/jax/TensorProductConv.py @@ -1,12 +1,10 @@ import jax -import json import jax.numpy as jnp import numpy as np from typing import Optional from openequivariance.jax import extlib -from openequivariance.core.utils import hash_str_64 from openequivariance.core.e3nn_lite import TPProblem from openequivariance.core.LoopUnrollConv import LoopUnrollConv from openequivariance.jax.utils import 
reorder_jax @@ -51,19 +49,7 @@ def __init__( kahan=kahan, ) - self.kernel = json.dumps( - { - "kernel": self.jit_kernel, - "forward_config": vars(self.forward_schedule.launch_config), - "backward_config": vars(self.backward_schedule.launch_config), - "double_backward_config": vars( - self.double_backward_schedule.launch_config - ), - "kernel_prop": self.kernel_prop, - } - ) - self.hash = hash_str_64(self.kernel) - + self.kernel = self.kernel_string self.weight_numel = config.weight_numel self.L3_dim = self.config.irreps_out.dim diff --git a/openequivariance_extjax/CMakeLists.txt b/openequivariance_extjax/CMakeLists.txt index 91617d94..90eafe6c 100644 --- a/openequivariance_extjax/CMakeLists.txt +++ b/openequivariance_extjax/CMakeLists.txt @@ -60,7 +60,6 @@ set(OEQ_JAX_HEADERS ${HEADER_DIR}/tensorproducts.hpp ${HEADER_DIR}/util/backend_cuda.hpp ${HEADER_DIR}/util/backend_hip.hpp - ${HEADER_DIR}/util/buffer.hpp ${HEADER_DIR}/json11/json11.hpp ) diff --git a/openequivariance_extjax/pyproject.toml b/openequivariance_extjax/pyproject.toml index def67e12..74f9627d 100644 --- a/openequivariance_extjax/pyproject.toml +++ b/openequivariance_extjax/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "scikit_build_core.build" [project] name = "openequivariance_extjax" -version = "0.2.0" +version = "0.2.1" authors = [ { name="Austin Glover" }, { name="Vivek Bharadwaj" }, diff --git a/tests/batch_test.py b/tests/batch_test.py index f32f7b51..788950ab 100644 --- a/tests/batch_test.py +++ b/tests/batch_test.py @@ -253,14 +253,6 @@ def problem(self, request, dtype): return problem -class TestTorchbindDisable(TestProductionModels): - @pytest.fixture(scope="class") - def extra_tp_constructor_args(self, with_jax): - if with_jax: - pytest.skip("N/A for JAX") - return {"use_opaque": True} - - class TestTorchTo(TPCorrectness): problems = [mace_problems()[0]] diff --git a/tests/export_test.py b/tests/export_test.py index 0fd23b2b..efdaf865 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -1,17 +1,10 @@ -import shutil import torch import pytest import tempfile -import subprocess -import os -import sys import numpy as np import openequivariance as oeq from torch_geometric import EdgeIndex -import importlib.resources - -from openequivariance._torch.E3NNTensorProduct import E3NNTensorProduct @pytest.fixture(scope="session") @@ -98,20 +91,6 @@ def test_torch_load(tp_and_inputs): assert torch.allclose(original_result, reloaded_result, atol=1e-5) -def test_jitscript(tp_and_inputs): - tp, inputs = tp_and_inputs - uncompiled_result = tp.forward(*inputs) - - scripted_tp = torch.jit.script(tp) - loaded_tp = None - with tempfile.NamedTemporaryFile(suffix=".pt") as tmp_file: - scripted_tp.save(tmp_file.name) - loaded_tp = torch.jit.load(tmp_file.name) - - compiled_result = loaded_tp(*inputs) - assert torch.allclose(uncompiled_result, compiled_result, atol=1e-5) - - def test_compile(tp_and_inputs): tp, inputs = tp_and_inputs uncompiled_result = tp.forward(*inputs) @@ -142,88 +121,10 @@ def test_aoti(tp_and_inputs): exported_tp, package_path=tmp_file.name ) except Exception as e: - err_msg = ( - "AOTI compile_and_package failed. NOTE: OpenEquivariance only supports AOTI for " - + "PyTorch version >= 2.8.0.dev20250410+cu126 due to incomplete TorchBind support " - + "in prior versions. " - + f"{e}" - ) + err_msg = f"AOTI compile_and_package failed. 
Error: {e}" assert False, err_msg aoti_model = torch._inductor.aoti_load_package(output_path) aoti_result = aoti_model(*inputs) assert torch.allclose(uncompiled_result, aoti_result, atol=1e-5) - - -def test_jitscript_cpp_interface(problem_and_irreps): - assert oeq.LINKED_LIBPYTHON, oeq.LINKED_LIBPYTHON_ERROR - problem, X_ir, Y_ir, _ = problem_and_irreps - cmake_prefix_path = torch.utils.cmake_prefix_path - torch_ext_so_path = oeq.torch_ext_so_path() - - oeq_tp = oeq.TensorProduct(problem).to("cuda") - scripted_oeq = torch.jit.script(oeq_tp) - - e3nn_tp = E3NNTensorProduct(problem).e3nn_tp.to("cuda") - scripted_e3nn = torch.jit.script(e3nn_tp) - - batch_size = 1000 - - with ( - tempfile.TemporaryDirectory() as tmpdir, - tempfile.NamedTemporaryFile(suffix=".pt") as oeq_file, - tempfile.NamedTemporaryFile(suffix=".pt") as e3nn_file, - ): - scripted_oeq.save(oeq_file.name) - scripted_e3nn.save(e3nn_file.name) - - test_path = importlib.resources.files("openequivariance") / "extension" / "test" - build_dir = os.path.join(tmpdir, "build") - os.makedirs(build_dir, exist_ok=True) - - for item in test_path.iterdir(): - shutil.copy(item, tmpdir) - - try: - subprocess.run( - [ - "cmake", - "..", - "-DCMAKE_BUILD_TYPE=Release", - "-DCMAKE_PREFIX_PATH=" + cmake_prefix_path, - "-DOEQ_EXTLIB=" + torch_ext_so_path, - ], - cwd=build_dir, - check=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - - subprocess.run( - ["make"], - cwd=build_dir, - check=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - - subprocess.run( - [ - "./load_jitscript", - e3nn_file.name, - oeq_file.name, - str(X_ir.dim), - str(Y_ir.dim), - str(problem.weight_numel), - str(batch_size), - ], - cwd=build_dir, - check=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - except subprocess.CalledProcessError as e: - print(e.stdout.decode(), file=sys.stderr) - print(e.stderr.decode(), file=sys.stderr) - assert False