From 2d80dcfa110d50ce62adb0c9c2a4130328e54bdf Mon Sep 17 00:00:00 2001 From: Ameya Daigavane Date: Fri, 5 Sep 2025 13:28:36 -0400 Subject: [PATCH 1/8] Add fused conv. --- src/e3tools/nn/__init__.py | 2 +- src/e3tools/nn/_conv.py | 126 ++++++++++++++++++++++++++++++++++++- tests/test_fused_conv.py | 47 ++++++++++++++ 3 files changed, 172 insertions(+), 3 deletions(-) create mode 100644 tests/test_fused_conv.py diff --git a/src/e3tools/nn/__init__.py b/src/e3tools/nn/__init__.py index 5612fd5..679300c 100644 --- a/src/e3tools/nn/__init__.py +++ b/src/e3tools/nn/__init__.py @@ -1,4 +1,4 @@ -from ._conv import Conv, ConvBlock, ExperimentalConv, SeparableConv, SeparableConvBlock +from ._conv import Conv, ConvBlock, ExperimentalConv, SeparableConv, SeparableConvBlock, FusedConv from ._linear import Linear from ._gate import Gate, Gated, GateWrapper from ._interaction import LinearSelfInteraction diff --git a/src/e3tools/nn/_conv.py b/src/e3tools/nn/_conv.py index 08aa501..1b0a53f 100644 --- a/src/e3tools/nn/_conv.py +++ b/src/e3tools/nn/_conv.py @@ -13,6 +13,128 @@ from ._mlp import ScalarMLP from ._tensor_product import ExperimentalTensorProduct, SeparableTensorProduct +try: + from openequivariance import ( + TensorProductConv, + TPProblem, + ) + openequivariance_available = True +except ImportError as e: + error_msg = str(e) + openequivariance_available = False + + +class FusedConv(nn.Module): + """ + Fused version of equivariant convolution layer with OpenEquivariance kernels. + + ref: https://arxiv.org/abs/1802.08219 + ref: https://arxiv.org/abs/2501.13986 + """ + + def __init__( + self, + irreps_in: Union[str, e3nn.o3.Irreps], + irreps_out: Union[str, e3nn.o3.Irreps], + irreps_sh: Union[str, e3nn.o3.Irreps], + edge_attr_dim: int, + radial_nn: Optional[Callable[..., nn.Module]] = None, + tensor_product: Optional[Callable[..., nn.Module]] = None, + ): + """ + Parameters + ---------- + irreps_in: e3nn.o3.Irreps + Input node feature irreps + irreps_out: e3nn.o3.Irreps + Ouput node feature irreps + irreps_sh: e3nn.o3.Irreps + Edge spherical harmonic irreps + edge_attr_dim: int + Dimension of scalar edge attributes to be passed to radial_nn + radial_nn: Optional[Callable[..., nn.Module]] + Factory function for radial nn used to generate tensor product weights. + Should be callable as radial_nn(in_features, out_features) + if `None` then + ``` + functools.partial( + e3tools.nn.ScalarMLP, + hidden_features=[edge_attr_dim], + activation_layer=nn.SiLU, + ) + ``` + is used. + tensor_product: Optional[Callable[..., nn.Module]] + Factory function for tensor product used to mix input node + representations with edge spherical harmonics. + Should be callable as `tensor_product(irreps_in, irreps_sh, irreps_out)` + and return an object with `weight_numel` property defined + If `None` then + ``` + functools.partial( + e3nn.o3.FullyConnectedTensorProduct + shared_weights=False, + internal_weights=False, + ) + ``` + is used. + """ + + super().__init__() + + self.irreps_in = e3nn.o3.Irreps(irreps_in) + self.irreps_out = e3nn.o3.Irreps(irreps_out) + self.irreps_sh = e3nn.o3.Irreps(irreps_sh) + + if tensor_product is None: + tensor_product = functools.partial( + e3nn.o3.FullyConnectedTensorProduct, + shared_weights=False, + internal_weights=False, + ) + + self.tp = tensor_product(irreps_in, irreps_sh, irreps_out) + + if not openequivariance_available: + raise ImportError(f"OpenEquivariance could not be imported:\n{error_msg}") + + tpp = TPProblem( + input_irreps, + sh_irreps, + tp_irreps, + instructions_tp, + shared_weights=False, + internal_weights=False, + ) + self.fused_tp = TensorProductConv( + tpp, torch_op=True, deterministic=False, use_opaque=False + ) + + def forward(self, node_attr: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor, edge_sh: torch.Tensor) -> torch.Tensor: + """ + Computes the forward pass of the equivariant convolution. + + Let N be the number of nodes, and E be the number of edges + + Parameters + ---------- + node_attr: [N, irreps_in.dim] + edge_index: [2, E] + edge_attr: [E, edge_attr_dim] + edge_sh: [E, irreps_sh.dim] + + Returns + ------- + out: [N, irreps_out.dim] + """ + N = node_attr.shape[0] + + src, dst = edge_index + messages_agg = self.tp_conv(node_attr, edge_sh, self.radial_nn(edge_attr), src, dst) + num_neighbors = scatter(torch.ones_like(dst), dst, dim=0, dim_size=N, reduce="sum") + out = messages_agg / num_neighbors.clamp_min(1).unsqueeze(-1) + return out + class Conv(nn.Module): """ @@ -92,10 +214,10 @@ def __init__( self.radial_nn = radial_nn(edge_attr_dim, self.tp.weight_numel) - def apply_per_edge(self, node_attr_src, edge_attr, edge_sh): + def apply_per_edge(self, node_attr_src: torch.Tensor, edge_attr: torch.Tensor, edge_sh: torch.Tensor) -> torch.Tensor: return self.tp(node_attr_src, edge_sh, self.radial_nn(edge_attr)) - def forward(self, node_attr, edge_index, edge_attr, edge_sh): + def forward(self, node_attr: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor, edge_sh: torch.Tensor) -> torch.Tensor: """ Computes the forward pass of the equivariant convolution. diff --git a/tests/test_fused_conv.py b/tests/test_fused_conv.py new file mode 100644 index 0000000..3c9915a --- /dev/null +++ b/tests/test_fused_conv.py @@ -0,0 +1,47 @@ +from typing import Tuple +import functools + +import pytest +import torch +import e3nn + +from e3tools.nn import Conv, FusedConv +from e3tools import radius_graph + +torch.set_default_dtype(torch.float64) + + +@pytest.mark.parametrize("seed", [0, 1, 2]) +def test_fused_conv(seed): + torch.manual_seed(seed) + + N = 20 + edge_attr_dim = 10 + max_radius = 1.3 + irreps_in = e3nn.o3.Irreps("10x0e + 10x1o + 10x2e") + irreps_sh = irreps_in.spherical_harmonics(2) + + layer = Conv(irreps_in, irreps_in, irreps_sh, edge_attr_dim=edge_attr_dim) + fused_layer = FusedConv(irreps_in, irreps_in, irreps_sh, edge_attr_dim=edge_attr_dim) + + pos = torch.randn(N, 3) + node_attr = layer.irreps_in.randn(N, -1) + + edge_index = radius_graph(pos, max_radius) + edge_vec = pos[edge_index[0]] - pos[edge_index[1]] + edge_length = (edge_vec).norm(dim=1) + edge_attr = e3nn.math.soft_one_hot_linspace( + edge_length, + start=0.0, + end=max_radius, + number=edge_attr_dim, + basis="smooth_finite", + cutoff=True, + ) + edge_sh = e3nn.o3.spherical_harmonics( + layer.irreps_sh, edge_vec, True, normalization="component" + ) + out = layer(node_attr, edge_index, edge_attr, edge_sh) + out_fused = fused_layer(node_attr, edge_index, edge_attr, edge_sh) + assert torch.allclose(out, out_fused, atol=1e-10) + From bef10046f65f4cb11a9d787639f9cdead1fa885d Mon Sep 17 00:00:00 2001 From: Ameya Daigavane Date: Fri, 5 Sep 2025 14:28:02 -0400 Subject: [PATCH 2/8] Add failing test for FusedConv. --- src/e3tools/nn/__init__.py | 1 + src/e3tools/nn/_conv.py | 37 +++++++++++++++++++++++-------------- tests/test_fused_conv.py | 26 +++++++++++++++++++------- 3 files changed, 43 insertions(+), 21 deletions(-) diff --git a/src/e3tools/nn/__init__.py b/src/e3tools/nn/__init__.py index 679300c..1dca385 100644 --- a/src/e3tools/nn/__init__.py +++ b/src/e3tools/nn/__init__.py @@ -19,6 +19,7 @@ "ExperimentalConv", "ExperimentalTensorProduct", "ExtractIrreps", + "FusedConv", "Gate", "GateWrapper", "Gated", diff --git a/src/e3tools/nn/_conv.py b/src/e3tools/nn/_conv.py index 1b0a53f..d523d48 100644 --- a/src/e3tools/nn/_conv.py +++ b/src/e3tools/nn/_conv.py @@ -14,10 +14,7 @@ from ._tensor_product import ExperimentalTensorProduct, SeparableTensorProduct try: - from openequivariance import ( - TensorProductConv, - TPProblem, - ) + import openequivariance as oeq openequivariance_available = True except ImportError as e: error_msg = str(e) @@ -93,21 +90,32 @@ def __init__( internal_weights=False, ) + self.tp = tensor_product(irreps_in, irreps_sh, irreps_out) + if radial_nn is None: + radial_nn = functools.partial( + ScalarMLP, + hidden_features=[edge_attr_dim], + activation_layer=nn.SiLU, + ) + + self.radial_nn = radial_nn(edge_attr_dim, self.tp.weight_numel) if not openequivariance_available: raise ImportError(f"OpenEquivariance could not be imported:\n{error_msg}") - tpp = TPProblem( - input_irreps, - sh_irreps, - tp_irreps, - instructions_tp, + # Remove path shape from instructions. + oeq_instructions = [instruction[:6] for instruction in self.tp.instructions] + oeq_tpp = oeq.TPProblem( + self.tp.irreps_in1, + self.tp.irreps_in2, + self.tp.irreps_out, + oeq_instructions, shared_weights=False, internal_weights=False, ) - self.fused_tp = TensorProductConv( - tpp, torch_op=True, deterministic=False, use_opaque=False + self.fused_tp = oeq.TensorProductConv( + oeq_tpp, torch_op=True, deterministic=False, use_opaque=False ) def forward(self, node_attr: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor, edge_sh: torch.Tensor) -> torch.Tensor: @@ -130,9 +138,10 @@ def forward(self, node_attr: torch.Tensor, edge_index: torch.Tensor, edge_attr: N = node_attr.shape[0] src, dst = edge_index - messages_agg = self.tp_conv(node_attr, edge_sh, self.radial_nn(edge_attr), src, dst) - num_neighbors = scatter(torch.ones_like(dst), dst, dim=0, dim_size=N, reduce="sum") - out = messages_agg / num_neighbors.clamp_min(1).unsqueeze(-1) + radial_attr = self.radial_nn(edge_attr) + messages_agg = self.fused_tp(node_attr, edge_sh, radial_attr, dst, src) + num_neighbors = scatter(torch.ones_like(src), src, dim=0, dim_size=N, reduce="sum") + out = messages_agg return out diff --git a/tests/test_fused_conv.py b/tests/test_fused_conv.py index 3c9915a..5a02492 100644 --- a/tests/test_fused_conv.py +++ b/tests/test_fused_conv.py @@ -8,21 +8,28 @@ from e3tools.nn import Conv, FusedConv from e3tools import radius_graph -torch.set_default_dtype(torch.float64) - @pytest.mark.parametrize("seed", [0, 1, 2]) def test_fused_conv(seed): + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + torch.manual_seed(seed) + torch.set_default_device("cuda") N = 20 - edge_attr_dim = 10 - max_radius = 1.3 - irreps_in = e3nn.o3.Irreps("10x0e + 10x1o + 10x2e") + edge_attr_dim = 1 + max_radius = 100 + irreps_in = e3nn.o3.Irreps("1x0e + 1x1o + 1x2e + 1x3o") irreps_sh = irreps_in.spherical_harmonics(2) - layer = Conv(irreps_in, irreps_in, irreps_sh, edge_attr_dim=edge_attr_dim) - fused_layer = FusedConv(irreps_in, irreps_in, irreps_sh, edge_attr_dim=edge_attr_dim) + def fake_nn(edge_attr_dim, num_elements): + def fn(edge_attr): + return torch.ones((edge_attr.shape[0], num_elements), device=edge_attr.device) + return fn + + layer = Conv(irreps_in=irreps_in, irreps_out=irreps_in, irreps_sh=irreps_sh, radial_nn=fake_nn, edge_attr_dim=edge_attr_dim) + fused_layer = FusedConv(irreps_in=irreps_in, irreps_out=irreps_in, irreps_sh=irreps_sh, radial_nn=fake_nn, edge_attr_dim=edge_attr_dim) pos = torch.randn(N, 3) node_attr = layer.irreps_in.randn(N, -1) @@ -41,7 +48,12 @@ def test_fused_conv(seed): edge_sh = e3nn.o3.spherical_harmonics( layer.irreps_sh, edge_vec, True, normalization="component" ) + out = layer(node_attr, edge_index, edge_attr, edge_sh) out_fused = fused_layer(node_attr, edge_index, edge_attr, edge_sh) + print(out[0] / out_fused[0]) + print(((out[0, 0] / out_fused[0, 0]) / (out[1, 1] / out_fused[1, 1])) ** 4) + print(((out[0, 0] / out_fused[0, 0]) / (out[4, 4] / out_fused[4, 4])) ** 4) + print(((out[0, 0] / out_fused[0, 0]) / (out[9, 9] / out_fused[9, 9])) ** 4) assert torch.allclose(out, out_fused, atol=1e-10) From 71874939ed3a4680c4644d1be8e362278888aa3a Mon Sep 17 00:00:00 2001 From: Ameya Daigavane Date: Fri, 5 Sep 2025 15:26:09 -0400 Subject: [PATCH 3/8] Tests are working. --- src/e3tools/nn/__init__.py | 15 ++++++--- src/e3tools/nn/_conv.py | 48 ++++++++++++++++++++------- src/e3tools/nn/_tensor_product.py | 55 ++++++++++++++++--------------- tests/test_fused_conv.py | 42 +++++++++++++---------- 4 files changed, 100 insertions(+), 60 deletions(-) diff --git a/src/e3tools/nn/__init__.py b/src/e3tools/nn/__init__.py index 1dca385..2b4f423 100644 --- a/src/e3tools/nn/__init__.py +++ b/src/e3tools/nn/__init__.py @@ -1,11 +1,18 @@ -from ._conv import Conv, ConvBlock, ExperimentalConv, SeparableConv, SeparableConvBlock, FusedConv +from ._conv import ( + Conv, + ConvBlock, + SeparableConv, + SeparableConvBlock, + FusedConv, + FusedDepthwiseConv, +) from ._linear import Linear from ._gate import Gate, Gated, GateWrapper from ._interaction import LinearSelfInteraction from ._layer_norm import LayerNorm from ._mlp import EquivariantMLP, ScalarMLP from ._axis_to_mul import AxisToMul, MulToAxis -from ._tensor_product import ExperimentalTensorProduct, SeparableTensorProduct +from ._tensor_product import SeparableTensorProduct, DepthwiseTensorProduct from ._transformer import Attention, MultiheadAttention, TransformerBlock from ._extract_irreps import ExtractIrreps from ._scaling import ScaleIrreps @@ -15,11 +22,11 @@ "AxisToMul", "Conv", "ConvBlock", + "DepthwiseTensorProduct", "EquivariantMLP", - "ExperimentalConv", - "ExperimentalTensorProduct", "ExtractIrreps", "FusedConv", + "FusedDepthwiseConv", "Gate", "GateWrapper", "Gated", diff --git a/src/e3tools/nn/_conv.py b/src/e3tools/nn/_conv.py index d523d48..6c390fb 100644 --- a/src/e3tools/nn/_conv.py +++ b/src/e3tools/nn/_conv.py @@ -11,10 +11,11 @@ from ._gate import Gated from ._interaction import LinearSelfInteraction from ._mlp import ScalarMLP -from ._tensor_product import ExperimentalTensorProduct, SeparableTensorProduct +from ._tensor_product import SeparableTensorProduct, DepthwiseTensorProduct try: import openequivariance as oeq + openequivariance_available = True except ImportError as e: error_msg = str(e) @@ -90,7 +91,6 @@ def __init__( internal_weights=False, ) - self.tp = tensor_product(irreps_in, irreps_sh, irreps_out) if radial_nn is None: radial_nn = functools.partial( @@ -104,8 +104,8 @@ def __init__( if not openequivariance_available: raise ImportError(f"OpenEquivariance could not be imported:\n{error_msg}") - # Remove path shape from instructions. - oeq_instructions = [instruction[:6] for instruction in self.tp.instructions] + # Remove path weight and path shape from instructions. + oeq_instructions = [instruction[:5] for instruction in self.tp.instructions] oeq_tpp = oeq.TPProblem( self.tp.irreps_in1, self.tp.irreps_in2, @@ -118,7 +118,13 @@ def __init__( oeq_tpp, torch_op=True, deterministic=False, use_opaque=False ) - def forward(self, node_attr: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor, edge_sh: torch.Tensor) -> torch.Tensor: + def forward( + self, + node_attr: torch.Tensor, + edge_index: torch.Tensor, + edge_attr: torch.Tensor, + edge_sh: torch.Tensor, + ) -> torch.Tensor: """ Computes the forward pass of the equivariant convolution. @@ -140,8 +146,10 @@ def forward(self, node_attr: torch.Tensor, edge_index: torch.Tensor, edge_attr: src, dst = edge_index radial_attr = self.radial_nn(edge_attr) messages_agg = self.fused_tp(node_attr, edge_sh, radial_attr, dst, src) - num_neighbors = scatter(torch.ones_like(src), src, dim=0, dim_size=N, reduce="sum") - out = messages_agg + num_neighbors = scatter( + torch.ones_like(src), src, dim=0, dim_size=N, reduce="sum" + ) + out = messages_agg / num_neighbors.clamp_min(1).unsqueeze(1) return out @@ -223,10 +231,21 @@ def __init__( self.radial_nn = radial_nn(edge_attr_dim, self.tp.weight_numel) - def apply_per_edge(self, node_attr_src: torch.Tensor, edge_attr: torch.Tensor, edge_sh: torch.Tensor) -> torch.Tensor: + def apply_per_edge( + self, + node_attr_src: torch.Tensor, + edge_attr: torch.Tensor, + edge_sh: torch.Tensor, + ) -> torch.Tensor: return self.tp(node_attr_src, edge_sh, self.radial_nn(edge_attr)) - def forward(self, node_attr: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor, edge_sh: torch.Tensor) -> torch.Tensor: + def forward( + self, + node_attr: torch.Tensor, + edge_index: torch.Tensor, + edge_attr: torch.Tensor, + edge_sh: torch.Tensor, + ) -> torch.Tensor: """ Computes the forward pass of the equivariant convolution. @@ -268,12 +287,19 @@ def __init__(self, *args, **kwargs): ) -class ExperimentalConv(Conv): +class FusedDepthwiseConv(FusedConv): + """ + Equivariant convolution layer using separable tensor product + + ref: https://arxiv.org/abs/1802.08219 + ref: https://arxiv.org/abs/2206.11990 + """ + def __init__(self, *args, **kwargs): super().__init__( *args, **kwargs, - tensor_product=ExperimentalTensorProduct, + tensor_product=DepthwiseTensorProduct, ) diff --git a/src/e3tools/nn/_tensor_product.py b/src/e3tools/nn/_tensor_product.py index 5d54ebd..60f6e6d 100644 --- a/src/e3tools/nn/_tensor_product.py +++ b/src/e3tools/nn/_tensor_product.py @@ -2,14 +2,15 @@ import e3nn import e3nn.o3 +import torch from torch import nn from ._linear import Linear -class SeparableTensorProduct(nn.Module): +class DepthwiseTensorProduct(nn.Module): """ - Tensor product factored into depthwise and pointwise components + Depthwise tensor product ref: https://arxiv.org/abs/2206.11990 ref: https://github.com/atomicarchitects/equiformer/blob/a4360ada2d213ba7b4d884335d3dc54a92b7a371/nets/graph_attention_transformer.py#L157 @@ -24,7 +25,7 @@ def __init__( super().__init__() self.irreps_in1 = e3nn.o3.Irreps(irreps_in1) self.irreps_in2 = e3nn.o3.Irreps(irreps_in2) - self.irreps_out = e3nn.o3.Irreps(irreps_out) + irreps_out = e3nn.o3.Irreps(irreps_out) irreps_out_dtp = [] instructions_dtp = [] @@ -32,15 +33,14 @@ def __init__( for i, (mul, ir_in1) in enumerate(self.irreps_in1): for j, (_, ir_in2) in enumerate(self.irreps_in2): for ir_out in ir_in1 * ir_in2: - if ir_out in self.irreps_out or ir_out == e3nn.o3.Irrep(0, 1): + if ir_out in irreps_out or ir_out == e3nn.o3.Irrep(0, 1): k = len(irreps_out_dtp) irreps_out_dtp.append((mul, ir_out)) instructions_dtp.append((i, j, k, "uvu", True)) irreps_out_dtp = e3nn.o3.Irreps(irreps_out_dtp) - # depth wise - self.dtp = e3nn.o3.TensorProduct( + self.tp = e3nn.o3.TensorProduct( irreps_in1, irreps_in2, irreps_out_dtp, @@ -48,21 +48,23 @@ def __init__( internal_weights=False, shared_weights=False, ) - - # point wise - self.lin = Linear(irreps_out_dtp, self.irreps_out) - - self.weight_numel = self.dtp.weight_numel - - def forward(self, x, y, weight): - out = self.dtp(x, y, weight) - out = self.lin(out) + self.irreps_out = self.tp.irreps_out + self.weight_numel = self.tp.weight_numel + self.instructions = self.tp.instructions + + def forward( + self, x: torch.Tensor, y: torch.Tensor, weight: torch.Tensor + ) -> torch.Tensor: + out = self.tp(x, y, weight) return out -class ExperimentalTensorProduct(nn.Module): +class SeparableTensorProduct(nn.Module): """ - Compileable tensor product + Tensor product factored into depthwise and pointwise components + + ref: https://arxiv.org/abs/2206.11990 + ref: https://github.com/atomicarchitects/equiformer/blob/a4360ada2d213ba7b4d884335d3dc54a92b7a371/nets/graph_attention_transformer.py#L157 """ def __init__( @@ -76,18 +78,17 @@ def __init__( self.irreps_in2 = e3nn.o3.Irreps(irreps_in2) self.irreps_out = e3nn.o3.Irreps(irreps_out) - self.tp = e3nn.o3.FullTensorProductv2(self.irreps_in1, self.irreps_in2) - - self.lin = Linear( - self.tp.irreps_out, - self.irreps_out, - internal_weights=False, - shared_weights=False, + # Depthwise and pointwise + self.dtp = DepthwiseTensorProduct( + self.irreps_in1, self.irreps_in2, self.irreps_out ) + self.lin = Linear(self.dtp.irreps_out, self.irreps_out) - self.weight_numel = self.lin.weight_numel + # For book-keeping. + self.instructions = self.dtp.instructions + self.weight_numel = self.dtp.weight_numel def forward(self, x, y, weight): - out = self.tp(x, y) - out = self.lin(out, weight) + out = self.dtp(x, y, weight) + out = self.lin(out) return out diff --git a/tests/test_fused_conv.py b/tests/test_fused_conv.py index 5a02492..6493abb 100644 --- a/tests/test_fused_conv.py +++ b/tests/test_fused_conv.py @@ -1,16 +1,25 @@ -from typing import Tuple import functools import pytest import torch +from torch import nn import e3nn -from e3tools.nn import Conv, FusedConv +from e3tools.nn import Conv, FusedConv, DepthwiseTensorProduct, ScalarMLP from e3tools import radius_graph +TENSOR_PRODUCTS = [ + functools.partial( + e3nn.o3.FullyConnectedTensorProduct, + shared_weights=False, + internal_weights=False, + ), + DepthwiseTensorProduct +] -@pytest.mark.parametrize("seed", [0, 1, 2]) -def test_fused_conv(seed): +@pytest.mark.parametrize("tensor_product", TENSOR_PRODUCTS) +@pytest.mark.parametrize("seed", [0, 1]) +def test_fused_conv(tensor_product, seed): if not torch.cuda.is_available(): pytest.skip("CUDA is not available") @@ -18,18 +27,18 @@ def test_fused_conv(seed): torch.set_default_device("cuda") N = 20 - edge_attr_dim = 1 - max_radius = 100 - irreps_in = e3nn.o3.Irreps("1x0e + 1x1o + 1x2e + 1x3o") + edge_attr_dim = 10 + max_radius = 1.0 + irreps_in = e3nn.o3.Irreps("10x0e + 4x1o + 1x2e") irreps_sh = irreps_in.spherical_harmonics(2) - def fake_nn(edge_attr_dim, num_elements): - def fn(edge_attr): - return torch.ones((edge_attr.shape[0], num_elements), device=edge_attr.device) - return fn + tp = tensor_product(irreps_in, irreps_sh, irreps_in) + common_radial_nn = ScalarMLP(in_features=edge_attr_dim, out_features=tp.weight_numel, hidden_features=[edge_attr_dim], activation_layer=nn.SiLU) + def radial_nn(edge_attr_dim: int, num_elements: int) -> nn.Module: + return common_radial_nn - layer = Conv(irreps_in=irreps_in, irreps_out=irreps_in, irreps_sh=irreps_sh, radial_nn=fake_nn, edge_attr_dim=edge_attr_dim) - fused_layer = FusedConv(irreps_in=irreps_in, irreps_out=irreps_in, irreps_sh=irreps_sh, radial_nn=fake_nn, edge_attr_dim=edge_attr_dim) + layer = Conv(irreps_in=irreps_in, irreps_out=irreps_in, irreps_sh=irreps_sh, radial_nn=radial_nn, edge_attr_dim=edge_attr_dim, tensor_product=tensor_product) + fused_layer = FusedConv(irreps_in=irreps_in, irreps_out=irreps_in, irreps_sh=irreps_sh, radial_nn=radial_nn, edge_attr_dim=edge_attr_dim, tensor_product=tensor_product) pos = torch.randn(N, 3) node_attr = layer.irreps_in.randn(N, -1) @@ -51,9 +60,6 @@ def fn(edge_attr): out = layer(node_attr, edge_index, edge_attr, edge_sh) out_fused = fused_layer(node_attr, edge_index, edge_attr, edge_sh) - print(out[0] / out_fused[0]) - print(((out[0, 0] / out_fused[0, 0]) / (out[1, 1] / out_fused[1, 1])) ** 4) - print(((out[0, 0] / out_fused[0, 0]) / (out[4, 4] / out_fused[4, 4])) ** 4) - print(((out[0, 0] / out_fused[0, 0]) / (out[9, 9] / out_fused[9, 9])) ** 4) - assert torch.allclose(out, out_fused, atol=1e-10) + + assert torch.allclose(out, out_fused, rtol=1e-3) From bdf51da865dac5aa4e701e89bd8e91287b9b151e Mon Sep 17 00:00:00 2001 From: Ameya Daigavane Date: Fri, 5 Sep 2025 15:28:53 -0400 Subject: [PATCH 4/8] Formatting. --- tests/test_fused_conv.py | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/tests/test_fused_conv.py b/tests/test_fused_conv.py index 6493abb..bfbe68e 100644 --- a/tests/test_fused_conv.py +++ b/tests/test_fused_conv.py @@ -14,9 +14,10 @@ shared_weights=False, internal_weights=False, ), - DepthwiseTensorProduct + DepthwiseTensorProduct, ] + @pytest.mark.parametrize("tensor_product", TENSOR_PRODUCTS) @pytest.mark.parametrize("seed", [0, 1]) def test_fused_conv(tensor_product, seed): @@ -33,12 +34,32 @@ def test_fused_conv(tensor_product, seed): irreps_sh = irreps_in.spherical_harmonics(2) tp = tensor_product(irreps_in, irreps_sh, irreps_in) - common_radial_nn = ScalarMLP(in_features=edge_attr_dim, out_features=tp.weight_numel, hidden_features=[edge_attr_dim], activation_layer=nn.SiLU) + common_radial_nn = ScalarMLP( + in_features=edge_attr_dim, + out_features=tp.weight_numel, + hidden_features=[edge_attr_dim], + activation_layer=nn.SiLU, + ) + def radial_nn(edge_attr_dim: int, num_elements: int) -> nn.Module: return common_radial_nn - layer = Conv(irreps_in=irreps_in, irreps_out=irreps_in, irreps_sh=irreps_sh, radial_nn=radial_nn, edge_attr_dim=edge_attr_dim, tensor_product=tensor_product) - fused_layer = FusedConv(irreps_in=irreps_in, irreps_out=irreps_in, irreps_sh=irreps_sh, radial_nn=radial_nn, edge_attr_dim=edge_attr_dim, tensor_product=tensor_product) + layer = Conv( + irreps_in=irreps_in, + irreps_out=irreps_in, + irreps_sh=irreps_sh, + radial_nn=radial_nn, + edge_attr_dim=edge_attr_dim, + tensor_product=tensor_product, + ) + fused_layer = FusedConv( + irreps_in=irreps_in, + irreps_out=irreps_in, + irreps_sh=irreps_sh, + radial_nn=radial_nn, + edge_attr_dim=edge_attr_dim, + tensor_product=tensor_product, + ) pos = torch.randn(N, 3) node_attr = layer.irreps_in.randn(N, -1) @@ -62,4 +83,3 @@ def radial_nn(edge_attr_dim: int, num_elements: int) -> nn.Module: out_fused = fused_layer(node_attr, edge_index, edge_attr, edge_sh) assert torch.allclose(out, out_fused, rtol=1e-3) - From e307630f8d3cbba4b7b783c49938e855cae1ad34 Mon Sep 17 00:00:00 2001 From: Ameya Daigavane Date: Fri, 5 Sep 2025 15:31:53 -0400 Subject: [PATCH 5/8] Remove old ExperimentalTensorProduct --- tests/test_equivariance.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_equivariance.py b/tests/test_equivariance.py index cdf597b..7f2b777 100644 --- a/tests/test_equivariance.py +++ b/tests/test_equivariance.py @@ -10,7 +10,6 @@ Conv, ConvBlock, EquivariantMLP, - ExperimentalConv, Gated, LayerNorm, MultiheadAttention, @@ -21,7 +20,7 @@ torch.set_default_dtype(torch.float64) -CONV_LAYERS = [Conv, SeparableConv, ExperimentalConv] +CONV_LAYERS = [Conv, SeparableConv] def apply_layer_rotation(layer: torch.nn.Module) -> Tuple[torch.Tensor, torch.Tensor]: From 1611e899ebb1dd793547f5f2541fd6078b6ddb16 Mon Sep 17 00:00:00 2001 From: Joseph Kleinhenz Date: Fri, 5 Sep 2025 19:47:41 -0700 Subject: [PATCH 6/8] add openequivariance extra --- pyproject.toml | 3 + uv.lock | 165 +++++++++++++++++++++++++++++++++---------------- 2 files changed, 115 insertions(+), 53 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9644b4c..0ff26f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,9 @@ dev = [ "ruff", "ipykernel>=6.30.1", ] +openequivariance = [ + "openequivariance>=0.4.1", +] [tool.ruff.lint] ignore = ["F722"] diff --git a/uv.lock b/uv.lock index 37c4d3d..d575811 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.10" resolution-markers = [ "python_full_version >= '3.12'", @@ -350,6 +350,9 @@ dev = [ { name = "sphinx", version = "8.1.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "sphinx", version = "8.2.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, ] +openequivariance = [ + { name = "openequivariance" }, +] [package.metadata] requires-dist = [ @@ -359,13 +362,14 @@ requires-dist = [ { name = "ipykernel", marker = "extra == 'dev'", specifier = ">=6.30.1" }, { name = "ipython", marker = "extra == 'dev'", specifier = ">=8.34.0" }, { name = "jaxtyping", specifier = ">=0.2.38" }, + { name = "openequivariance", marker = "extra == 'openequivariance'", specifier = ">=0.4.1" }, { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.3.4" }, { name = "ruff", marker = "extra == 'dev'" }, { name = "setuptools", specifier = ">=78.1.0" }, { name = "sphinx", marker = "extra == 'dev'" }, { name = "torch", specifier = ">=2.5.1" }, ] -provides-extras = ["dev"] +provides-extras = ["dev", "openequivariance"] [[package]] name = "einops" @@ -887,6 +891,32 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl", hash = "sha256:df5d4365b724cf81b8c6a7312509d0c22386097011ad1abe274afd5e9d3bbc5f", size = 1723263, upload-time = "2024-10-21T12:39:36.247Z" }, ] +[[package]] +name = "ninja" +version = "1.13.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/43/73/79a0b22fc731989c708068427579e840a6cf4e937fe7ae5c5d0b7356ac22/ninja-1.13.0.tar.gz", hash = "sha256:4a40ce995ded54d9dc24f8ea37ff3bf62ad192b547f6c7126e7e25045e76f978", size = 242558, upload-time = "2025-08-11T15:10:19.421Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3c/74/d02409ed2aa865e051b7edda22ad416a39d81a84980f544f8de717cab133/ninja-1.13.0-py3-none-macosx_10_9_universal2.whl", hash = "sha256:fa2a8bfc62e31b08f83127d1613d10821775a0eb334197154c4d6067b7068ff1", size = 310125, upload-time = "2025-08-11T15:09:50.971Z" }, + { url = "https://files.pythonhosted.org/packages/8e/de/6e1cd6b84b412ac1ef327b76f0641aeb5dcc01e9d3f9eee0286d0c34fd93/ninja-1.13.0-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:3d00c692fb717fd511abeb44b8c5d00340c36938c12d6538ba989fe764e79630", size = 177467, upload-time = "2025-08-11T15:09:52.767Z" }, + { url = "https://files.pythonhosted.org/packages/c8/83/49320fb6e58ae3c079381e333575fdbcf1cca3506ee160a2dcce775046fa/ninja-1.13.0-py3-none-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:be7f478ff9f96a128b599a964fc60a6a87b9fa332ee1bd44fa243ac88d50291c", size = 187834, upload-time = "2025-08-11T15:09:54.115Z" }, + { url = "https://files.pythonhosted.org/packages/56/c7/ba22748fb59f7f896b609cd3e568d28a0a367a6d953c24c461fe04fc4433/ninja-1.13.0-py3-none-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:60056592cf495e9a6a4bea3cd178903056ecb0943e4de45a2ea825edb6dc8d3e", size = 202736, upload-time = "2025-08-11T15:09:55.745Z" }, + { url = "https://files.pythonhosted.org/packages/79/22/d1de07632b78ac8e6b785f41fa9aad7a978ec8c0a1bf15772def36d77aac/ninja-1.13.0-py3-none-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:1c97223cdda0417f414bf864cfb73b72d8777e57ebb279c5f6de368de0062988", size = 179034, upload-time = "2025-08-11T15:09:57.394Z" }, + { url = "https://files.pythonhosted.org/packages/ed/de/0e6edf44d6a04dabd0318a519125ed0415ce437ad5a1ec9b9be03d9048cf/ninja-1.13.0-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fb46acf6b93b8dd0322adc3a4945452a4e774b75b91293bafcc7b7f8e6517dfa", size = 180716, upload-time = "2025-08-11T15:09:58.696Z" }, + { url = "https://files.pythonhosted.org/packages/54/28/938b562f9057aaa4d6bfbeaa05e81899a47aebb3ba6751e36c027a7f5ff7/ninja-1.13.0-py3-none-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:4be9c1b082d244b1ad7ef41eb8ab088aae8c109a9f3f0b3e56a252d3e00f42c1", size = 146843, upload-time = "2025-08-11T15:10:00.046Z" }, + { url = "https://files.pythonhosted.org/packages/2a/fb/d06a3838de4f8ab866e44ee52a797b5491df823901c54943b2adb0389fbb/ninja-1.13.0-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:6739d3352073341ad284246f81339a384eec091d9851a886dfa5b00a6d48b3e2", size = 154402, upload-time = "2025-08-11T15:10:01.657Z" }, + { url = "https://files.pythonhosted.org/packages/31/bf/0d7808af695ceddc763cf251b84a9892cd7f51622dc8b4c89d5012779f06/ninja-1.13.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:11be2d22027bde06f14c343f01d31446747dbb51e72d00decca2eb99be911e2f", size = 552388, upload-time = "2025-08-11T15:10:03.349Z" }, + { url = "https://files.pythonhosted.org/packages/9d/70/c99d0c2c809f992752453cce312848abb3b1607e56d4cd1b6cded317351a/ninja-1.13.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:aa45b4037b313c2f698bc13306239b8b93b4680eb47e287773156ac9e9304714", size = 472501, upload-time = "2025-08-11T15:10:04.735Z" }, + { url = "https://files.pythonhosted.org/packages/9f/43/c217b1153f0e499652f5e0766da8523ce3480f0a951039c7af115e224d55/ninja-1.13.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:5f8e1e8a1a30835eeb51db05cf5a67151ad37542f5a4af2a438e9490915e5b72", size = 638280, upload-time = "2025-08-11T15:10:06.512Z" }, + { url = "https://files.pythonhosted.org/packages/8c/45/9151bba2c8d0ae2b6260f71696330590de5850e5574b7b5694dce6023e20/ninja-1.13.0-py3-none-musllinux_1_2_ppc64le.whl", hash = "sha256:3d7d7779d12cb20c6d054c61b702139fd23a7a964ec8f2c823f1ab1b084150db", size = 642420, upload-time = "2025-08-11T15:10:08.35Z" }, + { url = "https://files.pythonhosted.org/packages/3c/fb/95752eb635bb8ad27d101d71bef15bc63049de23f299e312878fc21cb2da/ninja-1.13.0-py3-none-musllinux_1_2_riscv64.whl", hash = "sha256:d741a5e6754e0bda767e3274a0f0deeef4807f1fec6c0d7921a0244018926ae5", size = 585106, upload-time = "2025-08-11T15:10:09.818Z" }, + { url = "https://files.pythonhosted.org/packages/c1/31/aa56a1a286703800c0cbe39fb4e82811c277772dc8cd084f442dd8e2938a/ninja-1.13.0-py3-none-musllinux_1_2_s390x.whl", hash = "sha256:e8bad11f8a00b64137e9b315b137d8bb6cbf3086fbdc43bf1f90fd33324d2e96", size = 707138, upload-time = "2025-08-11T15:10:11.366Z" }, + { url = "https://files.pythonhosted.org/packages/34/6f/5f5a54a1041af945130abdb2b8529cbef0cdcbbf9bcf3f4195378319d29a/ninja-1.13.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:b4f2a072db3c0f944c32793e91532d8948d20d9ab83da9c0c7c15b5768072200", size = 581758, upload-time = "2025-08-11T15:10:13.295Z" }, + { url = "https://files.pythonhosted.org/packages/95/97/51359c77527d45943fe7a94d00a3843b81162e6c4244b3579fe8fc54cb9c/ninja-1.13.0-py3-none-win32.whl", hash = "sha256:8cfbb80b4a53456ae8a39f90ae3d7a2129f45ea164f43fadfa15dc38c4aef1c9", size = 267201, upload-time = "2025-08-11T15:10:15.158Z" }, + { url = "https://files.pythonhosted.org/packages/29/45/c0adfbfb0b5895aa18cec400c535b4f7ff3e52536e0403602fc1a23f7de9/ninja-1.13.0-py3-none-win_amd64.whl", hash = "sha256:fb8ee8719f8af47fed145cced4a85f0755dd55d45b2bddaf7431fa89803c5f3e", size = 309975, upload-time = "2025-08-11T15:10:16.697Z" }, + { url = "https://files.pythonhosted.org/packages/df/93/a7b983643d1253bb223234b5b226e69de6cda02b76cdca7770f684b795f5/ninja-1.13.0-py3-none-win_arm64.whl", hash = "sha256:3c0b40b1f0bba764644385319028650087b4c1b18cdfa6f45cb39a3669b81aa9", size = 290806, upload-time = "2025-08-11T15:10:18.018Z" }, +] + [[package]] name = "numpy" version = "2.2.3" @@ -951,69 +981,77 @@ wheels = [ [[package]] name = "nvidia-cublas-cu12" -version = "12.4.5.8" +version = "12.8.4.1" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ae/71/1c91302526c45ab494c23f61c7a84aa568b8c1f9d196efa5993957faf906/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:2fc8da60df463fdefa81e323eef2e36489e1c94335b5358bcb38360adf75ac9b", size = 363438805, upload-time = "2024-04-03T20:57:06.025Z" }, + { url = "https://files.pythonhosted.org/packages/dc/61/e24b560ab2e2eaeb3c839129175fb330dfcfc29e5203196e5541a4c44682/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:8ac4e771d5a348c551b2a426eda6193c19aa630236b418086020df5ba9667142", size = 594346921, upload-time = "2025-03-07T01:44:31.254Z" }, ] [[package]] name = "nvidia-cuda-cupti-cu12" -version = "12.4.127" +version = "12.8.90" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/67/42/f4f60238e8194a3106d06a058d494b18e006c10bb2b915655bd9f6ea4cb1/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9dec60f5ac126f7bb551c055072b69d85392b13311fcc1bcda2202d172df30fb", size = 13813957, upload-time = "2024-04-03T20:55:01.564Z" }, + { url = "https://files.pythonhosted.org/packages/f8/02/2adcaa145158bf1a8295d83591d22e4103dbfd821bcaf6f3f53151ca4ffa/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ea0cb07ebda26bb9b29ba82cda34849e73c166c18162d3913575b0c9db9a6182", size = 10248621, upload-time = "2025-03-07T01:40:21.213Z" }, ] [[package]] name = "nvidia-cuda-nvrtc-cu12" -version = "12.4.127" +version = "12.8.93" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/2c/14/91ae57cd4db3f9ef7aa99f4019cfa8d54cb4caa7e00975df6467e9725a9f/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a178759ebb095827bd30ef56598ec182b85547f1508941a3d560eb7ea1fbf338", size = 24640306, upload-time = "2024-04-03T20:56:01.463Z" }, + { url = "https://files.pythonhosted.org/packages/05/6b/32f747947df2da6994e999492ab306a903659555dddc0fbdeb9d71f75e52/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:a7756528852ef889772a84c6cd89d41dfa74667e24cca16bb31f8f061e3e9994", size = 88040029, upload-time = "2025-03-07T01:42:13.562Z" }, ] [[package]] name = "nvidia-cuda-runtime-cu12" -version = "12.4.127" +version = "12.8.90" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ea/27/1795d86fe88ef397885f2e580ac37628ed058a92ed2c39dc8eac3adf0619/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64403288fa2136ee8e467cdc9c9427e0434110899d07c779f25b5c068934faa5", size = 883737, upload-time = "2024-04-03T20:54:51.355Z" }, + { url = "https://files.pythonhosted.org/packages/0d/9b/a997b638fcd068ad6e4d53b8551a7d30fe8b404d6f1804abf1df69838932/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adade8dcbd0edf427b7204d480d6066d33902cab2a4707dcfc48a2d0fd44ab90", size = 954765, upload-time = "2025-03-07T01:40:01.615Z" }, ] [[package]] name = "nvidia-cudnn-cu12" -version = "9.1.0.70" +version = "9.10.2.21" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nvidia-cublas-cu12" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741, upload-time = "2024-04-22T15:24:15.253Z" }, + { url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" }, ] [[package]] name = "nvidia-cufft-cu12" -version = "11.2.1.3" +version = "11.3.3.83" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nvidia-nvjitlink-cu12" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/27/94/3266821f65b92b3138631e9c8e7fe1fb513804ac934485a8d05776e1dd43/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f083fc24912aa410be21fa16d157fed2055dab1cc4b6934a0e03cba69eb242b9", size = 211459117, upload-time = "2024-04-03T20:57:40.402Z" }, + { url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" }, +] + +[[package]] +name = "nvidia-cufile-cu12" +version = "1.13.1.3" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bb/fe/1bcba1dfbfb8d01be8d93f07bfc502c93fa23afa6fd5ab3fc7c1df71038a/nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1d069003be650e131b21c932ec3d8969c1715379251f8d23a1860554b1cb24fc", size = 1197834, upload-time = "2025-03-07T01:45:50.723Z" }, ] [[package]] name = "nvidia-curand-cu12" -version = "10.3.5.147" +version = "10.3.9.90" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/8a/6d/44ad094874c6f1b9c654f8ed939590bdc408349f137f9b98a3a23ccec411/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl", hash = "sha256:a88f583d4e0bb643c49743469964103aa59f7f708d862c3ddb0fc07f851e3b8b", size = 56305206, upload-time = "2024-04-03T20:58:08.722Z" }, + { url = "https://files.pythonhosted.org/packages/fb/aa/6584b56dc84ebe9cf93226a5cde4d99080c8e90ab40f0c27bda7a0f29aa1/nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:b32331d4f4df5d6eefa0554c565b626c7216f87a06a4f56fab27c3b68a830ec9", size = 63619976, upload-time = "2025-03-07T01:46:23.323Z" }, ] [[package]] name = "nvidia-cusolver-cu12" -version = "11.6.1.9" +version = "11.7.3.90" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nvidia-cublas-cu12" }, @@ -1021,51 +1059,63 @@ dependencies = [ { name = "nvidia-nvjitlink-cu12" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/3a/e1/5b9089a4b2a4790dfdea8b3a006052cfecff58139d5a4e34cb1a51df8d6f/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl", hash = "sha256:19e33fa442bcfd085b3086c4ebf7e8debc07cfe01e11513cc6d332fd918ac260", size = 127936057, upload-time = "2024-04-03T20:58:28.735Z" }, + { url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" }, ] [[package]] name = "nvidia-cusparse-cu12" -version = "12.3.1.170" +version = "12.5.8.93" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nvidia-nvjitlink-cu12" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/db/f7/97a9ea26ed4bbbfc2d470994b8b4f338ef663be97b8f677519ac195e113d/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl", hash = "sha256:ea4f11a2904e2a8dc4b1833cc1b5181cde564edd0d5cd33e3c168eff2d1863f1", size = 207454763, upload-time = "2024-04-03T20:58:59.995Z" }, + { url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" }, ] [[package]] name = "nvidia-cusparselt-cu12" -version = "0.6.2" +version = "0.7.1" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/78/a8/bcbb63b53a4b1234feeafb65544ee55495e1bb37ec31b999b963cbccfd1d/nvidia_cusparselt_cu12-0.6.2-py3-none-manylinux2014_x86_64.whl", hash = "sha256:df2c24502fd76ebafe7457dbc4716b2fec071aabaed4fb7691a201cde03704d9", size = 150057751, upload-time = "2024-07-23T02:35:53.074Z" }, + { url = "https://files.pythonhosted.org/packages/56/79/12978b96bd44274fe38b5dde5cfb660b1d114f70a65ef962bcbbed99b549/nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f1bb701d6b930d5a7cea44c19ceb973311500847f81b634d802b7b539dc55623", size = 287193691, upload-time = "2025-02-26T00:15:44.104Z" }, ] [[package]] name = "nvidia-nccl-cu12" -version = "2.21.5" +version = "2.27.3" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/df/99/12cd266d6233f47d00daf3a72739872bdc10267d0383508b0b9c84a18bb6/nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:8579076d30a8c24988834445f8d633c697d42397e92ffc3f63fa26766d25e0a0", size = 188654414, upload-time = "2024-04-03T15:32:57.427Z" }, + { url = "https://files.pythonhosted.org/packages/5c/5b/4e4fff7bad39adf89f735f2bc87248c81db71205b62bcc0d5ca5b606b3c3/nvidia_nccl_cu12-2.27.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adf27ccf4238253e0b826bce3ff5fa532d65fc42322c8bfdfaf28024c0fbe039", size = 322364134, upload-time = "2025-06-03T21:58:04.013Z" }, ] [[package]] name = "nvidia-nvjitlink-cu12" -version = "12.4.127" +version = "12.8.93" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ff/ff/847841bacfbefc97a00036e0fce5a0f086b640756dc38caea5e1bb002655/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57", size = 21066810, upload-time = "2024-04-03T20:59:46.957Z" }, + { url = "https://files.pythonhosted.org/packages/f6/74/86a07f1d0f42998ca31312f998bd3b9a7eff7f52378f4f270c8679c77fb9/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:81ff63371a7ebd6e6451970684f916be2eab07321b73c9d244dc2b4da7f73b88", size = 39254836, upload-time = "2025-03-07T01:49:55.661Z" }, ] [[package]] name = "nvidia-nvtx-cu12" -version = "12.4.127" +version = "12.8.90" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/87/20/199b8713428322a2f22b722c62b8cc278cc53dffa9705d744484b5035ee9/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:781e950d9b9f60d8241ccea575b32f5105a5baf4c2351cab5256a24869f12a1a", size = 99144, upload-time = "2024-04-03T20:56:12.406Z" }, + { url = "https://files.pythonhosted.org/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f", size = 89954, upload-time = "2025-03-07T01:42:44.131Z" }, +] + +[[package]] +name = "openequivariance" +version = "0.4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jinja2" }, + { name = "ninja" }, + { name = "numpy" }, + { name = "torch" }, ] +sdist = { url = "https://files.pythonhosted.org/packages/43/15/b60363bef3df8cb59c295c3ca1cc2ded58874f659d3401b02d328def66c0/openequivariance-0.4.1.tar.gz", hash = "sha256:4122d541312c78170a3cefff1a674fe03f5a78c583f13efebec7b835b7d5d086", size = 113058, upload-time = "2025-09-04T22:20:24.646Z" } [[package]] name = "opt-einsum" @@ -1653,14 +1703,14 @@ wheels = [ [[package]] name = "sympy" -version = "1.13.1" +version = "1.14.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "mpmath" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ca/99/5a5b6f19ff9f083671ddf7b9632028436167cd3d33e11015754e41b249a4/sympy-1.13.1.tar.gz", hash = "sha256:9cebf7e04ff162015ce31c9c6c9144daa34a93bd082f54fd8f12deca4f47515f", size = 7533040, upload-time = "2024-07-19T09:26:51.238Z" } +sdist = { url = "https://files.pythonhosted.org/packages/83/d3/803453b36afefb7c2bb238361cd4ae6125a569b4db67cd9e79846ba2d68c/sympy-1.14.0.tar.gz", hash = "sha256:d3d3fe8df1e5a0b42f0e7bdf50541697dbe7d23746e894990c030e2b05e72517", size = 7793921, upload-time = "2025-04-27T18:05:01.611Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b2/fe/81695a1aa331a842b582453b605175f419fe8540355886031328089d840a/sympy-1.13.1-py3-none-any.whl", hash = "sha256:db36cdc64bf61b9b24578b6f7bab1ecdd2452cf008f34faa33776680c26d66f8", size = 6189177, upload-time = "2024-07-19T09:26:48.863Z" }, + { url = "https://files.pythonhosted.org/packages/a2/09/77d55d46fd61b4a135c444fc97158ef34a095e5681d0a6c10b75bf356191/sympy-1.14.0-py3-none-any.whl", hash = "sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5", size = 6299353, upload-time = "2025-04-27T18:04:59.103Z" }, ] [[package]] @@ -1722,7 +1772,7 @@ wheels = [ [[package]] name = "torch" -version = "2.6.0" +version = "2.8.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, @@ -1735,6 +1785,7 @@ dependencies = [ { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufile-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, @@ -1748,22 +1799,26 @@ dependencies = [ { name = "typing-extensions" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/37/81/aa9ab58ec10264c1abe62c8b73f5086c3c558885d6beecebf699f0dbeaeb/torch-2.6.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:6860df13d9911ac158f4c44031609700e1eba07916fff62e21e6ffa0a9e01961", size = 766685561, upload-time = "2025-01-29T16:19:12.12Z" }, - { url = "https://files.pythonhosted.org/packages/86/86/e661e229df2f5bfc6eab4c97deb1286d598bbeff31ab0cdb99b3c0d53c6f/torch-2.6.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:c4f103a49830ce4c7561ef4434cc7926e5a5fe4e5eb100c19ab36ea1e2b634ab", size = 95751887, upload-time = "2025-01-29T16:27:50.77Z" }, - { url = "https://files.pythonhosted.org/packages/20/e0/5cb2f8493571f0a5a7273cd7078f191ac252a402b5fb9cb6091f14879109/torch-2.6.0-cp310-cp310-win_amd64.whl", hash = "sha256:56eeaf2ecac90da5d9e35f7f35eb286da82673ec3c582e310a8d1631a1c02341", size = 204165139, upload-time = "2025-01-29T16:27:11.63Z" }, - { url = "https://files.pythonhosted.org/packages/e5/16/ea1b7842413a7b8a5aaa5e99e8eaf3da3183cc3ab345ad025a07ff636301/torch-2.6.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:09e06f9949e1a0518c5b09fe95295bc9661f219d9ecb6f9893e5123e10696628", size = 66520221, upload-time = "2025-01-29T16:22:18.862Z" }, - { url = "https://files.pythonhosted.org/packages/78/a9/97cbbc97002fff0de394a2da2cdfa859481fdca36996d7bd845d50aa9d8d/torch-2.6.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:7979834102cd5b7a43cc64e87f2f3b14bd0e1458f06e9f88ffa386d07c7446e1", size = 766715424, upload-time = "2025-01-29T16:25:15.874Z" }, - { url = "https://files.pythonhosted.org/packages/6d/fa/134ce8f8a7ea07f09588c9cc2cea0d69249efab977707cf67669431dcf5c/torch-2.6.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:ccbd0320411fe1a3b3fec7b4d3185aa7d0c52adac94480ab024b5c8f74a0bf1d", size = 95759416, upload-time = "2025-01-29T16:27:38.429Z" }, - { url = "https://files.pythonhosted.org/packages/11/c5/2370d96b31eb1841c3a0883a492c15278a6718ccad61bb6a649c80d1d9eb/torch-2.6.0-cp311-cp311-win_amd64.whl", hash = "sha256:46763dcb051180ce1ed23d1891d9b1598e07d051ce4c9d14307029809c4d64f7", size = 204164970, upload-time = "2025-01-29T16:26:16.182Z" }, - { url = "https://files.pythonhosted.org/packages/0b/fa/f33a4148c6fb46ca2a3f8de39c24d473822d5774d652b66ed9b1214da5f7/torch-2.6.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:94fc63b3b4bedd327af588696559f68c264440e2503cc9e6954019473d74ae21", size = 66530713, upload-time = "2025-01-29T16:26:38.881Z" }, - { url = "https://files.pythonhosted.org/packages/e5/35/0c52d708144c2deb595cd22819a609f78fdd699b95ff6f0ebcd456e3c7c1/torch-2.6.0-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:2bb8987f3bb1ef2675897034402373ddfc8f5ef0e156e2d8cfc47cacafdda4a9", size = 766624563, upload-time = "2025-01-29T16:23:19.084Z" }, - { url = "https://files.pythonhosted.org/packages/01/d6/455ab3fbb2c61c71c8842753b566012e1ed111e7a4c82e0e1c20d0c76b62/torch-2.6.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:b789069020c5588c70d5c2158ac0aa23fd24a028f34a8b4fcb8fcb4d7efcf5fb", size = 95607867, upload-time = "2025-01-29T16:25:55.649Z" }, - { url = "https://files.pythonhosted.org/packages/18/cf/ae99bd066571656185be0d88ee70abc58467b76f2f7c8bfeb48735a71fe6/torch-2.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:7e1448426d0ba3620408218b50aa6ada88aeae34f7a239ba5431f6c8774b1239", size = 204120469, upload-time = "2025-01-29T16:24:01.821Z" }, - { url = "https://files.pythonhosted.org/packages/81/b4/605ae4173aa37fb5aa14605d100ff31f4f5d49f617928c9f486bb3aaec08/torch-2.6.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:9a610afe216a85a8b9bc9f8365ed561535c93e804c2a317ef7fabcc5deda0989", size = 66532538, upload-time = "2025-01-29T16:24:18.976Z" }, - { url = "https://files.pythonhosted.org/packages/24/85/ead1349fc30fe5a32cadd947c91bda4a62fbfd7f8c34ee61f6398d38fb48/torch-2.6.0-cp313-cp313-manylinux1_x86_64.whl", hash = "sha256:4874a73507a300a5d089ceaff616a569e7bb7c613c56f37f63ec3ffac65259cf", size = 766626191, upload-time = "2025-01-29T16:17:26.26Z" }, - { url = "https://files.pythonhosted.org/packages/dd/b0/26f06f9428b250d856f6d512413e9e800b78625f63801cbba13957432036/torch-2.6.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:a0d5e1b9874c1a6c25556840ab8920569a7a4137afa8a63a32cee0bc7d89bd4b", size = 95611439, upload-time = "2025-01-29T16:21:21.061Z" }, - { url = "https://files.pythonhosted.org/packages/c2/9c/fc5224e9770c83faed3a087112d73147cd7c7bfb7557dcf9ad87e1dda163/torch-2.6.0-cp313-cp313-win_amd64.whl", hash = "sha256:510c73251bee9ba02ae1cb6c9d4ee0907b3ce6020e62784e2d7598e0cfa4d6cc", size = 204126475, upload-time = "2025-01-29T16:21:55.394Z" }, - { url = "https://files.pythonhosted.org/packages/88/8b/d60c0491ab63634763be1537ad488694d316ddc4a20eaadd639cedc53971/torch-2.6.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:ff96f4038f8af9f7ec4231710ed4549da1bdebad95923953a25045dcf6fd87e2", size = 66536783, upload-time = "2025-01-29T16:22:08.559Z" }, + { url = "https://files.pythonhosted.org/packages/63/28/110f7274254f1b8476c561dada127173f994afa2b1ffc044efb773c15650/torch-2.8.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:0be92c08b44009d4131d1ff7a8060d10bafdb7ddcb7359ef8d8c5169007ea905", size = 102052793, upload-time = "2025-08-06T14:53:15.852Z" }, + { url = "https://files.pythonhosted.org/packages/70/1c/58da560016f81c339ae14ab16c98153d51c941544ae568da3cb5b1ceb572/torch-2.8.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:89aa9ee820bb39d4d72b794345cccef106b574508dd17dbec457949678c76011", size = 888025420, upload-time = "2025-08-06T14:54:18.014Z" }, + { url = "https://files.pythonhosted.org/packages/70/87/f69752d0dd4ba8218c390f0438130c166fa264a33b7025adb5014b92192c/torch-2.8.0-cp310-cp310-win_amd64.whl", hash = "sha256:e8e5bf982e87e2b59d932769938b698858c64cc53753894be25629bdf5cf2f46", size = 241363614, upload-time = "2025-08-06T14:53:31.496Z" }, + { url = "https://files.pythonhosted.org/packages/ef/d6/e6d4c57e61c2b2175d3aafbfb779926a2cfd7c32eeda7c543925dceec923/torch-2.8.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:a3f16a58a9a800f589b26d47ee15aca3acf065546137fc2af039876135f4c760", size = 73611154, upload-time = "2025-08-06T14:53:10.919Z" }, + { url = "https://files.pythonhosted.org/packages/8f/c4/3e7a3887eba14e815e614db70b3b529112d1513d9dae6f4d43e373360b7f/torch-2.8.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:220a06fd7af8b653c35d359dfe1aaf32f65aa85befa342629f716acb134b9710", size = 102073391, upload-time = "2025-08-06T14:53:20.937Z" }, + { url = "https://files.pythonhosted.org/packages/5a/63/4fdc45a0304536e75a5e1b1bbfb1b56dd0e2743c48ee83ca729f7ce44162/torch-2.8.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:c12fa219f51a933d5f80eeb3a7a5d0cbe9168c0a14bbb4055f1979431660879b", size = 888063640, upload-time = "2025-08-06T14:55:05.325Z" }, + { url = "https://files.pythonhosted.org/packages/84/57/2f64161769610cf6b1c5ed782bd8a780e18a3c9d48931319f2887fa9d0b1/torch-2.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:8c7ef765e27551b2fbfc0f41bcf270e1292d9bf79f8e0724848b1682be6e80aa", size = 241366752, upload-time = "2025-08-06T14:53:38.692Z" }, + { url = "https://files.pythonhosted.org/packages/a4/5e/05a5c46085d9b97e928f3f037081d3d2b87fb4b4195030fc099aaec5effc/torch-2.8.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:5ae0524688fb6707c57a530c2325e13bb0090b745ba7b4a2cd6a3ce262572916", size = 73621174, upload-time = "2025-08-06T14:53:25.44Z" }, + { url = "https://files.pythonhosted.org/packages/49/0c/2fd4df0d83a495bb5e54dca4474c4ec5f9c62db185421563deeb5dabf609/torch-2.8.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:e2fab4153768d433f8ed9279c8133a114a034a61e77a3a104dcdf54388838705", size = 101906089, upload-time = "2025-08-06T14:53:52.631Z" }, + { url = "https://files.pythonhosted.org/packages/99/a8/6acf48d48838fb8fe480597d98a0668c2beb02ee4755cc136de92a0a956f/torch-2.8.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:b2aca0939fb7e4d842561febbd4ffda67a8e958ff725c1c27e244e85e982173c", size = 887913624, upload-time = "2025-08-06T14:56:44.33Z" }, + { url = "https://files.pythonhosted.org/packages/af/8a/5c87f08e3abd825c7dfecef5a0f1d9aa5df5dd0e3fd1fa2f490a8e512402/torch-2.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:2f4ac52f0130275d7517b03a33d2493bab3693c83dcfadf4f81688ea82147d2e", size = 241326087, upload-time = "2025-08-06T14:53:46.503Z" }, + { url = "https://files.pythonhosted.org/packages/be/66/5c9a321b325aaecb92d4d1855421e3a055abd77903b7dab6575ca07796db/torch-2.8.0-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:619c2869db3ada2c0105487ba21b5008defcc472d23f8b80ed91ac4a380283b0", size = 73630478, upload-time = "2025-08-06T14:53:57.144Z" }, + { url = "https://files.pythonhosted.org/packages/10/4e/469ced5a0603245d6a19a556e9053300033f9c5baccf43a3d25ba73e189e/torch-2.8.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:2b2f96814e0345f5a5aed9bf9734efa913678ed19caf6dc2cddb7930672d6128", size = 101936856, upload-time = "2025-08-06T14:54:01.526Z" }, + { url = "https://files.pythonhosted.org/packages/16/82/3948e54c01b2109238357c6f86242e6ecbf0c63a1af46906772902f82057/torch-2.8.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:65616ca8ec6f43245e1f5f296603e33923f4c30f93d65e103d9e50c25b35150b", size = 887922844, upload-time = "2025-08-06T14:55:50.78Z" }, + { url = "https://files.pythonhosted.org/packages/e3/54/941ea0a860f2717d86a811adf0c2cd01b3983bdd460d0803053c4e0b8649/torch-2.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:659df54119ae03e83a800addc125856effda88b016dfc54d9f65215c3975be16", size = 241330968, upload-time = "2025-08-06T14:54:45.293Z" }, + { url = "https://files.pythonhosted.org/packages/de/69/8b7b13bba430f5e21d77708b616f767683629fc4f8037564a177d20f90ed/torch-2.8.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:1a62a1ec4b0498930e2543535cf70b1bef8c777713de7ceb84cd79115f553767", size = 73915128, upload-time = "2025-08-06T14:54:34.769Z" }, + { url = "https://files.pythonhosted.org/packages/15/0e/8a800e093b7f7430dbaefa80075aee9158ec22e4c4fc3c1a66e4fb96cb4f/torch-2.8.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:83c13411a26fac3d101fe8035a6b0476ae606deb8688e904e796a3534c197def", size = 102020139, upload-time = "2025-08-06T14:54:39.047Z" }, + { url = "https://files.pythonhosted.org/packages/4a/15/5e488ca0bc6162c86a33b58642bc577c84ded17c7b72d97e49b5833e2d73/torch-2.8.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:8f0a9d617a66509ded240add3754e462430a6c1fc5589f86c17b433dd808f97a", size = 887990692, upload-time = "2025-08-06T14:56:18.286Z" }, + { url = "https://files.pythonhosted.org/packages/b4/a8/6a04e4b54472fc5dba7ca2341ab219e529f3c07b6941059fbf18dccac31f/torch-2.8.0-cp313-cp313t-win_amd64.whl", hash = "sha256:a7242b86f42be98ac674b88a4988643b9bc6145437ec8f048fea23f72feb5eca", size = 241603453, upload-time = "2025-08-06T14:55:22.945Z" }, + { url = "https://files.pythonhosted.org/packages/04/6e/650bb7f28f771af0cb791b02348db8b7f5f64f40f6829ee82aa6ce99aabe/torch-2.8.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:7b677e17f5a3e69fdef7eb3b9da72622f8d322692930297e4ccb52fefc6c8211", size = 73632395, upload-time = "2025-08-06T14:55:28.645Z" }, ] [[package]] @@ -1796,13 +1851,17 @@ wheels = [ [[package]] name = "triton" -version = "3.2.0" +version = "3.4.0" source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "setuptools" }, +] wheels = [ - { url = "https://files.pythonhosted.org/packages/01/65/3ffa90e158a2c82f0716eee8d26a725d241549b7d7aaf7e4f44ac03ebd89/triton-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b3e54983cd51875855da7c68ec05c05cf8bb08df361b1d5b69e05e40b0c9bd62", size = 253090354, upload-time = "2025-01-22T19:12:21.872Z" }, - { url = "https://files.pythonhosted.org/packages/a7/2e/757d2280d4fefe7d33af7615124e7e298ae7b8e3bc4446cdb8e88b0f9bab/triton-3.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8009a1fb093ee8546495e96731336a33fb8856a38e45bb4ab6affd6dbc3ba220", size = 253157636, upload-time = "2025-01-22T19:12:51.322Z" }, - { url = "https://files.pythonhosted.org/packages/06/00/59500052cb1cf8cf5316be93598946bc451f14072c6ff256904428eaf03c/triton-3.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d9b215efc1c26fa7eefb9a157915c92d52e000d2bf83e5f69704047e63f125c", size = 253159365, upload-time = "2025-01-22T19:13:24.648Z" }, - { url = "https://files.pythonhosted.org/packages/c7/30/37a3384d1e2e9320331baca41e835e90a3767303642c7a80d4510152cbcf/triton-3.2.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e5dfa23ba84541d7c0a531dfce76d8bcd19159d50a4a8b14ad01e91734a5c1b0", size = 253154278, upload-time = "2025-01-22T19:13:54.221Z" }, + { url = "https://files.pythonhosted.org/packages/62/ee/0ee5f64a87eeda19bbad9bc54ae5ca5b98186ed00055281fd40fb4beb10e/triton-3.4.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7ff2785de9bc02f500e085420273bb5cc9c9bb767584a4aa28d6e360cec70128", size = 155430069, upload-time = "2025-07-30T19:58:21.715Z" }, + { url = "https://files.pythonhosted.org/packages/7d/39/43325b3b651d50187e591eefa22e236b2981afcebaefd4f2fc0ea99df191/triton-3.4.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7b70f5e6a41e52e48cfc087436c8a28c17ff98db369447bcaff3b887a3ab4467", size = 155531138, upload-time = "2025-07-30T19:58:29.908Z" }, + { url = "https://files.pythonhosted.org/packages/d0/66/b1eb52839f563623d185f0927eb3530ee4d5ffe9d377cdaf5346b306689e/triton-3.4.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:31c1d84a5c0ec2c0f8e8a072d7fd150cab84a9c239eaddc6706c081bfae4eb04", size = 155560068, upload-time = "2025-07-30T19:58:37.081Z" }, + { url = "https://files.pythonhosted.org/packages/30/7b/0a685684ed5322d2af0bddefed7906674f67974aa88b0fae6e82e3b766f6/triton-3.4.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:00be2964616f4c619193cb0d1b29a99bd4b001d7dc333816073f92cf2a8ccdeb", size = 155569223, upload-time = "2025-07-30T19:58:44.017Z" }, + { url = "https://files.pythonhosted.org/packages/20/63/8cb444ad5cdb25d999b7d647abac25af0ee37d292afc009940c05b82dda0/triton-3.4.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7936b18a3499ed62059414d7df563e6c163c5e16c3773678a3ee3d417865035d", size = 155659780, upload-time = "2025-07-30T19:58:51.171Z" }, ] [[package]] From 4bd64ad6a8a37f0af05661777c62aae7f6bc6a31 Mon Sep 17 00:00:00 2001 From: Joseph Kleinhenz Date: Fri, 5 Sep 2025 19:48:47 -0700 Subject: [PATCH 7/8] don't modify default dtype for all tests --- src/e3tools/utils/__init__.py | 3 + src/e3tools/utils/_default_dtype_manager.py | 13 ++ tests/test_equivariance.py | 213 ++++++++++---------- 3 files changed, 127 insertions(+), 102 deletions(-) create mode 100644 src/e3tools/utils/__init__.py create mode 100644 src/e3tools/utils/_default_dtype_manager.py diff --git a/src/e3tools/utils/__init__.py b/src/e3tools/utils/__init__.py new file mode 100644 index 0000000..445c668 --- /dev/null +++ b/src/e3tools/utils/__init__.py @@ -0,0 +1,3 @@ +from ._default_dtype_manager import default_dtype_manager + +__all__ = ["default_dtype_manager"] diff --git a/src/e3tools/utils/_default_dtype_manager.py b/src/e3tools/utils/_default_dtype_manager.py new file mode 100644 index 0000000..f9c0126 --- /dev/null +++ b/src/e3tools/utils/_default_dtype_manager.py @@ -0,0 +1,13 @@ +from contextlib import contextmanager + +import torch + + +@contextmanager +def default_dtype_manager(dtype): + original_dtype = torch.get_default_dtype() + try: + torch.set_default_dtype(dtype) + yield + finally: + torch.set_default_dtype(original_dtype) diff --git a/tests/test_equivariance.py b/tests/test_equivariance.py index 7f2b777..9ab0d60 100644 --- a/tests/test_equivariance.py +++ b/tests/test_equivariance.py @@ -17,8 +17,7 @@ TransformerBlock, ) from e3tools import radius_graph - -torch.set_default_dtype(torch.float64) +from e3tools.utils import default_dtype_manager CONV_LAYERS = [Conv, SeparableConv] @@ -65,144 +64,154 @@ def apply_layer_rotation(layer: torch.nn.Module) -> Tuple[torch.Tensor, torch.Te @pytest.mark.parametrize("conv", CONV_LAYERS) def test_conv(conv): - irreps_in = e3nn.o3.Irreps("10x0e + 10x1o + 10x2e") - irreps_sh = irreps_in.spherical_harmonics(2) - edge_attr_dim = 10 + with default_dtype_manager(torch.float64): + irreps_in = e3nn.o3.Irreps("10x0e + 10x1o + 10x2e") + irreps_sh = irreps_in.spherical_harmonics(2) + edge_attr_dim = 10 - layer = conv(irreps_in, irreps_in, irreps_sh, edge_attr_dim=edge_attr_dim) + layer = conv(irreps_in, irreps_in, irreps_sh, edge_attr_dim=edge_attr_dim) - out_1, out_2 = apply_layer_rotation(layer) - assert torch.allclose(out_1, out_2, atol=1e-10) + out_1, out_2 = apply_layer_rotation(layer) + assert torch.allclose(out_1, out_2, atol=1e-10) @pytest.mark.parametrize("conv", CONV_LAYERS) def test_gated_conv(conv): - irreps_in = e3nn.o3.Irreps("10x0e + 10x1o + 10x2e") - irreps_sh = irreps_in.spherical_harmonics(2) - edge_attr_dim = 10 + with default_dtype_manager(torch.float64): + irreps_in = e3nn.o3.Irreps("10x0e + 10x1o + 10x2e") + irreps_sh = irreps_in.spherical_harmonics(2) + edge_attr_dim = 10 - wrapped = functools.partial(conv, irreps_sh=irreps_sh, edge_attr_dim=edge_attr_dim) + wrapped = functools.partial( + conv, irreps_sh=irreps_sh, edge_attr_dim=edge_attr_dim + ) - layer = Gated(wrapped, irreps_in=irreps_in, irreps_out=irreps_in) + layer = Gated(wrapped, irreps_in=irreps_in, irreps_out=irreps_in) - out_1, out_2 = apply_layer_rotation(layer) - assert torch.allclose(out_1, out_2, atol=1e-10) + out_1, out_2 = apply_layer_rotation(layer) + assert torch.allclose(out_1, out_2, atol=1e-10) @pytest.mark.parametrize("conv", CONV_LAYERS) def test_conv_block(conv): - irreps_in = e3nn.o3.Irreps("10x0e + 10x1o + 10x2e") - irreps_sh = irreps_in.spherical_harmonics(2) - edge_attr_dim = 10 + with default_dtype_manager(torch.float64): + irreps_in = e3nn.o3.Irreps("10x0e + 10x1o + 10x2e") + irreps_sh = irreps_in.spherical_harmonics(2) + edge_attr_dim = 10 - layer = ConvBlock( - irreps_in=irreps_in, - irreps_out=irreps_in, - irreps_sh=irreps_sh, - edge_attr_dim=edge_attr_dim, - conv=conv, - ) + layer = ConvBlock( + irreps_in=irreps_in, + irreps_out=irreps_in, + irreps_sh=irreps_sh, + edge_attr_dim=edge_attr_dim, + conv=conv, + ) - out_1, out_2 = apply_layer_rotation(layer) - assert torch.allclose(out_1, out_2, atol=1e-10) + out_1, out_2 = apply_layer_rotation(layer) + assert torch.allclose(out_1, out_2, atol=1e-10) @pytest.mark.parametrize("conv", CONV_LAYERS) def test_attention(conv): - irreps_in = e3nn.o3.Irreps("10x0e + 10x1o + 10x2e") - irreps_out = irreps_in - irreps_sh = irreps_in.spherical_harmonics(2) - irreps_key = irreps_in - irreps_query = irreps_in - edge_attr_dim = 10 - - layer = Attention( - irreps_in, - irreps_out, - irreps_sh, - irreps_query, - irreps_key, - edge_attr_dim, - conv=conv, - ) - - out_1, out_2 = apply_layer_rotation(layer) - assert torch.allclose(out_1, out_2, atol=1e-10) + with default_dtype_manager(torch.float64): + irreps_in = e3nn.o3.Irreps("10x0e + 10x1o + 10x2e") + irreps_out = irreps_in + irreps_sh = irreps_in.spherical_harmonics(2) + irreps_key = irreps_in + irreps_query = irreps_in + edge_attr_dim = 10 + + layer = Attention( + irreps_in, + irreps_out, + irreps_sh, + irreps_query, + irreps_key, + edge_attr_dim, + conv=conv, + ) + + out_1, out_2 = apply_layer_rotation(layer) + assert torch.allclose(out_1, out_2, atol=1e-10) @pytest.mark.parametrize("conv", [Conv, SeparableConv]) def test_multihead_attention(conv): - irreps_in = e3nn.o3.Irreps("10x0e + 10x1o + 10x2e") - irreps_out = irreps_in - irreps_sh = irreps_in.spherical_harmonics(2) - irreps_key = irreps_in - irreps_query = irreps_in - edge_attr_dim = 10 - num_heads = 2 - - layer = MultiheadAttention( - irreps_in, - irreps_out, - irreps_sh, - irreps_query, - irreps_key, - edge_attr_dim, - num_heads, - conv=conv, - ) - - out_1, out_2 = apply_layer_rotation(layer) - assert torch.allclose(out_1, out_2, atol=1e-10) + with default_dtype_manager(torch.float64): + irreps_in = e3nn.o3.Irreps("10x0e + 10x1o + 10x2e") + irreps_out = irreps_in + irreps_sh = irreps_in.spherical_harmonics(2) + irreps_key = irreps_in + irreps_query = irreps_in + edge_attr_dim = 10 + num_heads = 2 + + layer = MultiheadAttention( + irreps_in, + irreps_out, + irreps_sh, + irreps_query, + irreps_key, + edge_attr_dim, + num_heads, + conv=conv, + ) + + out_1, out_2 = apply_layer_rotation(layer) + assert torch.allclose(out_1, out_2, atol=1e-10) def test_layer_norm(): - irreps = e3nn.o3.Irreps("10x0e + 10x1o + 10x2e") + with default_dtype_manager(torch.float64): + irreps = e3nn.o3.Irreps("10x0e + 10x1o + 10x2e") - layer = LayerNorm(irreps) - rot = e3nn.o3.rand_matrix() - D = irreps.D_from_matrix(rot) + layer = LayerNorm(irreps) + rot = e3nn.o3.rand_matrix() + D = irreps.D_from_matrix(rot) - x = irreps.randn(10, -1) + x = irreps.randn(10, -1) - out_1 = layer(x @ D.T) - out_2 = layer(x) @ D.T + out_1 = layer(x @ D.T) + out_2 = layer(x) @ D.T - assert torch.allclose(out_1, out_2, atol=1e-10) + assert torch.allclose(out_1, out_2, atol=1e-10) def test_equivariant_mlp(): - irreps = e3nn.o3.Irreps("10x0e + 10x1o + 10x2e") - irreps_hidden = e3nn.o3.Irreps([(4 * mul, ir) for mul, ir in irreps]) + with default_dtype_manager(torch.float64): + irreps = e3nn.o3.Irreps("10x0e + 10x1o + 10x2e") + irreps_hidden = e3nn.o3.Irreps([(4 * mul, ir) for mul, ir in irreps]) - layer = EquivariantMLP( - irreps, irreps, [irreps_hidden, irreps_hidden], norm_layer=LayerNorm - ) + layer = EquivariantMLP( + irreps, irreps, [irreps_hidden, irreps_hidden], norm_layer=LayerNorm + ) - rot = e3nn.o3.rand_matrix() - D = irreps.D_from_matrix(rot) + rot = e3nn.o3.rand_matrix() + D = irreps.D_from_matrix(rot) - x = irreps.randn(10, -1) + x = irreps.randn(10, -1) - out_1 = layer(x @ D.T) - out_2 = layer(x) @ D.T + out_1 = layer(x @ D.T) + out_2 = layer(x) @ D.T - assert torch.allclose(out_1, out_2, atol=1e-10) + assert torch.allclose(out_1, out_2, atol=1e-10) def test_transformer(): - irreps_in = e3nn.o3.Irreps("10x0e + 10x1o + 10x2e") - irreps_out = irreps_in - irreps_sh = irreps_in.spherical_harmonics(2) - edge_attr_dim = 10 - num_heads = 2 - - layer = TransformerBlock( - irreps_in=irreps_in, - irreps_out=irreps_out, - irreps_sh=irreps_sh, - edge_attr_dim=edge_attr_dim, - num_heads=num_heads, - ) - - out_1, out_2 = apply_layer_rotation(layer) - assert torch.allclose(out_1, out_2, atol=1e-10) + with default_dtype_manager(torch.float64): + irreps_in = e3nn.o3.Irreps("10x0e + 10x1o + 10x2e") + irreps_out = irreps_in + irreps_sh = irreps_in.spherical_harmonics(2) + edge_attr_dim = 10 + num_heads = 2 + + layer = TransformerBlock( + irreps_in=irreps_in, + irreps_out=irreps_out, + irreps_sh=irreps_sh, + edge_attr_dim=edge_attr_dim, + num_heads=num_heads, + ) + + out_1, out_2 = apply_layer_rotation(layer) + assert torch.allclose(out_1, out_2, atol=1e-10) From 5b5b8b7a4f760bdd0c89ab4dd8cbdd9f22ea9714 Mon Sep 17 00:00:00 2001 From: Joseph Kleinhenz Date: Fri, 5 Sep 2025 20:21:23 -0700 Subject: [PATCH 8/8] assert torch.allclose -> torch.testing.assert_close --- tests/test_attention.py | 14 ++++++++------ tests/test_equivariance.py | 16 ++++++++-------- tests/test_extract_irreps.py | 10 +++++----- tests/test_fused_conv.py | 2 +- tests/test_layer_norm.py | 13 +++++++++---- tests/test_pack_unpack.py | 6 +++--- tests/test_scaling.py | 4 ++-- 7 files changed, 36 insertions(+), 29 deletions(-) diff --git a/tests/test_attention.py b/tests/test_attention.py index 73ca0cc..cf45624 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -87,7 +87,7 @@ def test_compile(model_name, causal, request, irreps_in, irreps_sh, edge_attr_di out = model(node_attr, edge_index, edge_attr, edge_sh, mask=mask) compiled_out = compiled_model(node_attr, edge_index, edge_attr, edge_sh, mask=mask) - assert torch.allclose(out, compiled_out) + torch.testing.assert_close(out, compiled_out) @pytest.mark.parametrize("model_name", ["singlehead_attention", "multihead_attention"]) @@ -124,20 +124,22 @@ def test_causal_vs_non_causal_attention( causal_mask = edge_index[0] <= edge_index[1] non_causal_out = model(node_attr, edge_index, edge_attr, edge_sh, mask=None) causal_out = model(node_attr, edge_index, edge_attr, edge_sh, mask=causal_mask) - assert torch.allclose(non_causal_out, causal_out) + torch.testing.assert_close(non_causal_out, causal_out) # Check that the outputs are the same for the nodes that do not have any causal edges. edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]]) causal_mask = edge_index[0] <= edge_index[1] non_causal_out = model(node_attr, edge_index, edge_attr, edge_sh, mask=None) causal_out = model(node_attr, edge_index, edge_attr, edge_sh, mask=causal_mask) - assert not torch.allclose(non_causal_out[:1], causal_out[:1]) - assert torch.allclose(non_causal_out[1:], causal_out[1:]) + with pytest.raises(AssertionError): + torch.testing.assert_close(non_causal_out[:1], causal_out[:1]) + torch.testing.assert_close(non_causal_out[1:], causal_out[1:]) # Check that the outputs are the same for the nodes that do not have any causal edges. edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 4]]) causal_mask = edge_index[0] <= edge_index[1] non_causal_out = model(node_attr, edge_index, edge_attr, edge_sh, mask=None) causal_out = model(node_attr, edge_index, edge_attr, edge_sh, mask=causal_mask) - assert not torch.allclose(non_causal_out[:2], causal_out[:2]) - assert torch.allclose(non_causal_out[2:], causal_out[2:]) + with pytest.raises(AssertionError): + torch.testing.assert_close(non_causal_out[:2], causal_out[:2]) + torch.testing.assert_close(non_causal_out[2:], causal_out[2:]) diff --git a/tests/test_equivariance.py b/tests/test_equivariance.py index 9ab0d60..17419e6 100644 --- a/tests/test_equivariance.py +++ b/tests/test_equivariance.py @@ -72,7 +72,7 @@ def test_conv(conv): layer = conv(irreps_in, irreps_in, irreps_sh, edge_attr_dim=edge_attr_dim) out_1, out_2 = apply_layer_rotation(layer) - assert torch.allclose(out_1, out_2, atol=1e-10) + torch.testing.assert_close(out_1, out_2, atol=1e-10, rtol=1e-10) @pytest.mark.parametrize("conv", CONV_LAYERS) @@ -89,7 +89,7 @@ def test_gated_conv(conv): layer = Gated(wrapped, irreps_in=irreps_in, irreps_out=irreps_in) out_1, out_2 = apply_layer_rotation(layer) - assert torch.allclose(out_1, out_2, atol=1e-10) + torch.testing.assert_close(out_1, out_2, atol=1e-10, rtol=1e-10) @pytest.mark.parametrize("conv", CONV_LAYERS) @@ -108,7 +108,7 @@ def test_conv_block(conv): ) out_1, out_2 = apply_layer_rotation(layer) - assert torch.allclose(out_1, out_2, atol=1e-10) + torch.testing.assert_close(out_1, out_2, atol=1e-10, rtol=1e-10) @pytest.mark.parametrize("conv", CONV_LAYERS) @@ -132,7 +132,7 @@ def test_attention(conv): ) out_1, out_2 = apply_layer_rotation(layer) - assert torch.allclose(out_1, out_2, atol=1e-10) + torch.testing.assert_close(out_1, out_2, atol=1e-10, rtol=1e-10) @pytest.mark.parametrize("conv", [Conv, SeparableConv]) @@ -158,7 +158,7 @@ def test_multihead_attention(conv): ) out_1, out_2 = apply_layer_rotation(layer) - assert torch.allclose(out_1, out_2, atol=1e-10) + torch.testing.assert_close(out_1, out_2, atol=1e-10, rtol=1e-10) def test_layer_norm(): @@ -174,7 +174,7 @@ def test_layer_norm(): out_1 = layer(x @ D.T) out_2 = layer(x) @ D.T - assert torch.allclose(out_1, out_2, atol=1e-10) + torch.testing.assert_close(out_1, out_2, atol=1e-10, rtol=1e-10) def test_equivariant_mlp(): @@ -194,7 +194,7 @@ def test_equivariant_mlp(): out_1 = layer(x @ D.T) out_2 = layer(x) @ D.T - assert torch.allclose(out_1, out_2, atol=1e-10) + torch.testing.assert_close(out_1, out_2, atol=1e-10, rtol=1e-10) def test_transformer(): @@ -214,4 +214,4 @@ def test_transformer(): ) out_1, out_2 = apply_layer_rotation(layer) - assert torch.allclose(out_1, out_2, atol=1e-10) + torch.testing.assert_close(out_1, out_2, atol=1e-10, rtol=1e-10) diff --git a/tests/test_extract_irreps.py b/tests/test_extract_irreps.py index 6e57602..ec1705e 100644 --- a/tests/test_extract_irreps.py +++ b/tests/test_extract_irreps.py @@ -32,15 +32,15 @@ def test_extract_irreps_simple(): layer = ExtractIrreps(irreps_in, "0e") output = layer(input) - assert torch.allclose(output, torch.as_tensor([1.0])) + torch.testing.assert_close(output, torch.as_tensor([1.0])) layer = ExtractIrreps(irreps_in, "1o") output = layer(input) - assert torch.allclose(output, torch.as_tensor([2.0, 3.0, 4.0])) + torch.testing.assert_close(output, torch.as_tensor([2.0, 3.0, 4.0])) layer = ExtractIrreps(irreps_in, "2e") output = layer(input) - assert torch.allclose(output, torch.as_tensor([5.0, 6.0, 7.0, 8.0, 9.0])) + torch.testing.assert_close(output, torch.as_tensor([5.0, 6.0, 7.0, 8.0, 9.0])) def test_extract_irreps_multiplicity(): @@ -50,8 +50,8 @@ def test_extract_irreps_multiplicity(): layer = ExtractIrreps(irreps_in, "0e") output = layer(input) - assert torch.allclose(output, torch.as_tensor([1.0, 5.0, 6.0])) + torch.testing.assert_close(output, torch.as_tensor([1.0, 5.0, 6.0])) layer = ExtractIrreps(irreps_in, "1o") output = layer(input) - assert torch.allclose(output, torch.as_tensor([2.0, 3.0, 4.0, 7.0, 8.0, 9.0])) + torch.testing.assert_close(output, torch.as_tensor([2.0, 3.0, 4.0, 7.0, 8.0, 9.0])) diff --git a/tests/test_fused_conv.py b/tests/test_fused_conv.py index bfbe68e..9edf712 100644 --- a/tests/test_fused_conv.py +++ b/tests/test_fused_conv.py @@ -82,4 +82,4 @@ def radial_nn(edge_attr_dim: int, num_elements: int) -> nn.Module: out = layer(node_attr, edge_index, edge_attr, edge_sh) out_fused = fused_layer(node_attr, edge_index, edge_attr, edge_sh) - assert torch.allclose(out, out_fused, rtol=1e-3) + torch.testing.assert_close(out, out_fused, rtol=1e-3, atol=1e-5) diff --git a/tests/test_layer_norm.py b/tests/test_layer_norm.py index 0d47ed3..1edccac 100644 --- a/tests/test_layer_norm.py +++ b/tests/test_layer_norm.py @@ -41,7 +41,7 @@ def test_layer_norm_compiled(irreps_in: str, seed: int, batch_size: int = 8): output = layer(input) output_compiled = layer_compiled(input) - assert torch.allclose(output, output_compiled) + torch.testing.assert_close(output, output_compiled) @pytest.mark.parametrize( @@ -57,8 +57,13 @@ def test_layer_norm(irreps_in: str): output = layer(input) for mul, ir, field in unpack_irreps(output, irreps_in): - sq_norms = field.norm(dim=-1, keepdim=True).pow(2).sum(dim=-1).mean(dim=-1) + sq_norms = ( + field.norm(dim=-1, keepdim=True) + .pow(2) + .sum(dim=-1) + .mean(dim=-1, keepdim=True) + ) if ir.l == 0 and ir.p == 1 and mul == 1: - assert torch.allclose(sq_norms, torch.as_tensor([0.0])) + torch.testing.assert_close(sq_norms, torch.as_tensor([0.0])) else: - assert torch.allclose(sq_norms, torch.as_tensor([1.0])) + torch.testing.assert_close(sq_norms, torch.as_tensor([1.0])) diff --git a/tests/test_pack_unpack.py b/tests/test_pack_unpack.py index 527e147..a771bc2 100644 --- a/tests/test_pack_unpack.py +++ b/tests/test_pack_unpack.py @@ -100,7 +100,7 @@ def test_inverse(irreps_in: str, factor: int, batch_size: int = 5): output = layer(input) recovered = inv_layer(output) - assert torch.allclose(input, recovered) + torch.testing.assert_close(input, recovered) @pytest.mark.parametrize( @@ -116,7 +116,7 @@ def test_axis_to_mul_compiled(irreps_in: str, factor: int, batch_size: int = 5): layer = AxisToMul(irreps_in, factor) layer_compiled = torch.compile(layer, fullgraph=True) - assert torch.allclose(layer(input), layer_compiled(input)) + torch.testing.assert_close(layer(input), layer_compiled(input)) @pytest.mark.parametrize( @@ -132,4 +132,4 @@ def test_mul_to_axis_compiled(irreps_in: str, factor: int, batch_size: int = 5): layer = MulToAxis(irreps_in, factor) layer_compiled = torch.compile(layer, fullgraph=True) - assert torch.allclose(layer(input), layer_compiled(input)) + torch.testing.assert_close(layer(input), layer_compiled(input)) diff --git a/tests/test_scaling.py b/tests/test_scaling.py index e0aba6a..e3ed914 100644 --- a/tests/test_scaling.py +++ b/tests/test_scaling.py @@ -30,7 +30,7 @@ def test_scale_irreps_by_one(irreps_in: str): weight = torch.ones(irreps_in.num_irreps) output = layer(input, weight) - assert torch.allclose(input, output) + torch.testing.assert_close(input, output) @pytest.mark.parametrize("irreps_in", ["0e + 1o", "0e + 1o + 2e", "3x1o + 2x2o"]) @@ -46,4 +46,4 @@ def test_scale_irreps_random(irreps_in: str): norm = e3nn.o3.Norm(irreps_in) factor = norm(output) / norm(input) - assert torch.allclose(factor, torch.abs(weight)) + torch.testing.assert_close(factor, torch.abs(weight))