Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ dev = [
"ruff",
"ipykernel>=6.30.1",
]
openequivariance = [
"openequivariance>=0.4.1",
]

[tool.ruff.lint]
ignore = ["F722"]
Expand Down
16 changes: 12 additions & 4 deletions src/e3tools/nn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
from ._conv import Conv, ConvBlock, ExperimentalConv, SeparableConv, SeparableConvBlock
from ._conv import (
Conv,
ConvBlock,
SeparableConv,
SeparableConvBlock,
FusedConv,
FusedDepthwiseConv,
)
from ._linear import Linear
from ._gate import Gate, Gated, GateWrapper
from ._interaction import LinearSelfInteraction
from ._layer_norm import LayerNorm
from ._mlp import EquivariantMLP, ScalarMLP
from ._axis_to_mul import AxisToMul, MulToAxis
from ._tensor_product import ExperimentalTensorProduct, SeparableTensorProduct
from ._tensor_product import SeparableTensorProduct, DepthwiseTensorProduct
from ._transformer import Attention, MultiheadAttention, TransformerBlock
from ._extract_irreps import ExtractIrreps
from ._scaling import ScaleIrreps
Expand All @@ -15,10 +22,11 @@
"AxisToMul",
"Conv",
"ConvBlock",
"DepthwiseTensorProduct",
"EquivariantMLP",
"ExperimentalConv",
"ExperimentalTensorProduct",
"ExtractIrreps",
"FusedConv",
"FusedDepthwiseConv",
"Gate",
"GateWrapper",
"Gated",
Expand Down
167 changes: 162 additions & 5 deletions src/e3tools/nn/_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,146 @@
from ._gate import Gated
from ._interaction import LinearSelfInteraction
from ._mlp import ScalarMLP
from ._tensor_product import ExperimentalTensorProduct, SeparableTensorProduct
from ._tensor_product import SeparableTensorProduct, DepthwiseTensorProduct

# OpenEquivariance is an optional dependency providing fused tensor-product
# kernels (see the `openequivariance` extra in pyproject.toml).
try:
    import openequivariance as oeq

    openequivariance_available = True
except ImportError as e:
    # Record only the message, not the exception object, so FusedConv.__init__
    # can re-raise a readable ImportError when the fused path is requested.
    error_msg = str(e)
    openequivariance_available = False


class FusedConv(nn.Module):
    """
    Fused version of equivariant convolution layer with OpenEquivariance kernels.

    ref: https://arxiv.org/abs/1802.08219
    ref: https://arxiv.org/abs/2501.13986
    """

    def __init__(
        self,
        irreps_in: Union[str, e3nn.o3.Irreps],
        irreps_out: Union[str, e3nn.o3.Irreps],
        irreps_sh: Union[str, e3nn.o3.Irreps],
        edge_attr_dim: int,
        radial_nn: Optional[Callable[..., nn.Module]] = None,
        tensor_product: Optional[Callable[..., nn.Module]] = None,
    ):
        """
        Parameters
        ----------
        irreps_in: e3nn.o3.Irreps
            Input node feature irreps
        irreps_out: e3nn.o3.Irreps
            Output node feature irreps
        irreps_sh: e3nn.o3.Irreps
            Edge spherical harmonic irreps
        edge_attr_dim: int
            Dimension of scalar edge attributes to be passed to radial_nn
        radial_nn: Optional[Callable[..., nn.Module]]
            Factory function for radial nn used to generate tensor product weights.
            Should be callable as radial_nn(in_features, out_features)
            if `None` then
            ```
            functools.partial(
                e3tools.nn.ScalarMLP,
                hidden_features=[edge_attr_dim],
                activation_layer=nn.SiLU,
            )
            ```
            is used.
        tensor_product: Optional[Callable[..., nn.Module]]
            Factory function for tensor product used to mix input node
            representations with edge spherical harmonics.
            Should be callable as `tensor_product(irreps_in, irreps_sh, irreps_out)`
            and return an object with `weight_numel` property defined
            If `None` then
            ```
            functools.partial(
                e3nn.o3.FullyConnectedTensorProduct,
                shared_weights=False,
                internal_weights=False,
            )
            ```
            is used.

        Raises
        ------
        ImportError
            If the optional `openequivariance` package could not be imported.
        """

        super().__init__()

        # Fail fast: do not build the (potentially expensive) tensor product
        # and radial network when the fused kernels are unavailable.  The
        # original check ran only after those submodules were constructed.
        if not openequivariance_available:
            raise ImportError(f"OpenEquivariance could not be imported:\n{error_msg}")

        self.irreps_in = e3nn.o3.Irreps(irreps_in)
        self.irreps_out = e3nn.o3.Irreps(irreps_out)
        self.irreps_sh = e3nn.o3.Irreps(irreps_sh)

        if tensor_product is None:
            tensor_product = functools.partial(
                e3nn.o3.FullyConnectedTensorProduct,
                shared_weights=False,
                internal_weights=False,
            )

        # Pass the normalized Irreps objects (not the raw str/Irreps arguments)
        # so every factory receives e3nn.o3.Irreps consistently.
        self.tp = tensor_product(self.irreps_in, self.irreps_sh, self.irreps_out)

        if radial_nn is None:
            radial_nn = functools.partial(
                ScalarMLP,
                hidden_features=[edge_attr_dim],
                activation_layer=nn.SiLU,
            )

        # Maps scalar edge attributes to the flat tensor-product weight vector.
        self.radial_nn = radial_nn(edge_attr_dim, self.tp.weight_numel)

        # OpenEquivariance instructions are (i_in1, i_in2, i_out, mode,
        # has_weight); strip e3nn's trailing path weight and path shape.
        oeq_instructions = [instruction[:5] for instruction in self.tp.instructions]
        oeq_tpp = oeq.TPProblem(
            self.tp.irreps_in1,
            self.tp.irreps_in2,
            self.tp.irreps_out,
            oeq_instructions,
            shared_weights=False,
            internal_weights=False,
        )
        self.fused_tp = oeq.TensorProductConv(
            oeq_tpp, torch_op=True, deterministic=False, use_opaque=False
        )

    def forward(
        self,
        node_attr: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        edge_sh: torch.Tensor,
    ) -> torch.Tensor:
        """
        Computes the forward pass of the equivariant convolution.

        Let N be the number of nodes, and E be the number of edges

        Parameters
        ----------
        node_attr: [N, irreps_in.dim]
        edge_index: [2, E]
        edge_attr: [E, edge_attr_dim]
        edge_sh: [E, irreps_sh.dim]

        Returns
        -------
        out: [N, irreps_out.dim]
        """
        N = node_attr.shape[0]

        src, dst = edge_index
        radial_attr = self.radial_nn(edge_attr)
        # Fused message computation + scatter-aggregation in a single kernel.
        messages_agg = self.fused_tp(node_attr, edge_sh, radial_attr, dst, src)
        # Mean-normalize by neighbor count; clamp_min(1) keeps isolated nodes
        # from dividing by zero.
        # NOTE(review): counts are scattered over `src` — confirm this matches
        # the aggregation index used by oeq.TensorProductConv above.
        num_neighbors = scatter(
            torch.ones_like(src), src, dim=0, dim_size=N, reduce="sum"
        )
        out = messages_agg / num_neighbors.clamp_min(1).unsqueeze(1)
        return out


class Conv(nn.Module):
Expand Down Expand Up @@ -92,10 +231,21 @@ def __init__(

self.radial_nn = radial_nn(edge_attr_dim, self.tp.weight_numel)

def apply_per_edge(
    self,
    node_attr_src: torch.Tensor,
    edge_attr: torch.Tensor,
    edge_sh: torch.Tensor,
) -> torch.Tensor:
    """Compute one message per edge: the tensor product of source-node
    features with edge spherical harmonics, weighted by the radial network
    applied to the scalar edge attributes."""
    tp_weights = self.radial_nn(edge_attr)
    return self.tp(node_attr_src, edge_sh, tp_weights)

def forward(self, node_attr, edge_index, edge_attr, edge_sh):
def forward(
self,
node_attr: torch.Tensor,
edge_index: torch.Tensor,
edge_attr: torch.Tensor,
edge_sh: torch.Tensor,
) -> torch.Tensor:
"""
Computes the forward pass of the equivariant convolution.

Expand Down Expand Up @@ -137,12 +287,19 @@ def __init__(self, *args, **kwargs):
)


class ExperimentalConv(Conv):
class FusedDepthwiseConv(FusedConv):
    """
    Fused equivariant convolution layer using a depthwise tensor product
    with OpenEquivariance kernels.

    ref: https://arxiv.org/abs/1802.08219
    ref: https://arxiv.org/abs/2206.11990
    """

    def __init__(self, *args, **kwargs):
        # Force the depthwise ("uvu") tensor product; all other arguments are
        # forwarded unchanged to FusedConv.
        super().__init__(
            *args,
            **kwargs,
            tensor_product=DepthwiseTensorProduct,
        )


Expand Down
55 changes: 28 additions & 27 deletions src/e3tools/nn/_tensor_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@

import e3nn
import e3nn.o3
import torch
from torch import nn

from ._linear import Linear


class SeparableTensorProduct(nn.Module):
class DepthwiseTensorProduct(nn.Module):
"""
Tensor product factored into depthwise and pointwise components
Depthwise tensor product

ref: https://arxiv.org/abs/2206.11990
ref: https://github.com/atomicarchitects/equiformer/blob/a4360ada2d213ba7b4d884335d3dc54a92b7a371/nets/graph_attention_transformer.py#L157
Expand All @@ -24,45 +25,46 @@ def __init__(
super().__init__()
self.irreps_in1 = e3nn.o3.Irreps(irreps_in1)
self.irreps_in2 = e3nn.o3.Irreps(irreps_in2)
self.irreps_out = e3nn.o3.Irreps(irreps_out)
irreps_out = e3nn.o3.Irreps(irreps_out)

irreps_out_dtp = []
instructions_dtp = []

for i, (mul, ir_in1) in enumerate(self.irreps_in1):
for j, (_, ir_in2) in enumerate(self.irreps_in2):
for ir_out in ir_in1 * ir_in2:
if ir_out in self.irreps_out or ir_out == e3nn.o3.Irrep(0, 1):
if ir_out in irreps_out or ir_out == e3nn.o3.Irrep(0, 1):
k = len(irreps_out_dtp)
irreps_out_dtp.append((mul, ir_out))
instructions_dtp.append((i, j, k, "uvu", True))

irreps_out_dtp = e3nn.o3.Irreps(irreps_out_dtp)

# depth wise
self.dtp = e3nn.o3.TensorProduct(
self.tp = e3nn.o3.TensorProduct(
irreps_in1,
irreps_in2,
irreps_out_dtp,
instructions_dtp,
internal_weights=False,
shared_weights=False,
)

# point wise
self.lin = Linear(irreps_out_dtp, self.irreps_out)

self.weight_numel = self.dtp.weight_numel

def forward(self, x, y, weight):
out = self.dtp(x, y, weight)
out = self.lin(out)
self.irreps_out = self.tp.irreps_out
self.weight_numel = self.tp.weight_numel
self.instructions = self.tp.instructions

def forward(
    self, x: torch.Tensor, y: torch.Tensor, weight: torch.Tensor
) -> torch.Tensor:
    """Apply the externally-weighted depthwise tensor product to x and y."""
    return self.tp(x, y, weight)


class ExperimentalTensorProduct(nn.Module):
class SeparableTensorProduct(nn.Module):
"""
Compileable tensor product
Tensor product factored into depthwise and pointwise components

ref: https://arxiv.org/abs/2206.11990
ref: https://github.com/atomicarchitects/equiformer/blob/a4360ada2d213ba7b4d884335d3dc54a92b7a371/nets/graph_attention_transformer.py#L157
"""

def __init__(
Expand All @@ -76,18 +78,17 @@ def __init__(
self.irreps_in2 = e3nn.o3.Irreps(irreps_in2)
self.irreps_out = e3nn.o3.Irreps(irreps_out)

self.tp = e3nn.o3.FullTensorProductv2(self.irreps_in1, self.irreps_in2)

self.lin = Linear(
self.tp.irreps_out,
self.irreps_out,
internal_weights=False,
shared_weights=False,
# Depthwise and pointwise
self.dtp = DepthwiseTensorProduct(
self.irreps_in1, self.irreps_in2, self.irreps_out
)
self.lin = Linear(self.dtp.irreps_out, self.irreps_out)

self.weight_numel = self.lin.weight_numel
# For book-keeping.
self.instructions = self.dtp.instructions
self.weight_numel = self.dtp.weight_numel

def forward(self, x, y, weight):
    """Depthwise tensor product of x and y (weighted externally) followed by
    the pointwise linear mixing layer."""
    mixed = self.dtp(x, y, weight)
    return self.lin(mixed)
3 changes: 3 additions & 0 deletions src/e3tools/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from ._default_dtype_manager import default_dtype_manager

__all__ = ["default_dtype_manager"]
13 changes: 13 additions & 0 deletions src/e3tools/utils/_default_dtype_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from contextlib import contextmanager

import torch


@contextmanager
def default_dtype_manager(dtype):
    """Context manager that temporarily sets torch's default dtype.

    The previously active default dtype is restored on exit, even when the
    managed block raises.
    """
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(previous)
Loading