Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ mace = ["mace-torch>=0.3.14"]
mattersim = ["mattersim>=0.1.2"]
metatomic = ["metatomic-torch>=0.1.3", "metatrain[pet]>=2025.12"]
orb = ["orb-models>=0.5.2"]
sevenn = ["sevenn>=0.11.0"]
sevenn = ["sevenn[torchsim] @ git+https://github.com/MDIL-SNU/SevenNet.git"]
graphpes = ["graph-pes>=0.1", "mace-torch>=0.3.12"]
nequip = ["nequip>=0.16.2"]
fairchem = ["fairchem-core>=2.7", "scipy<1.17.0"]
Expand Down
283 changes: 16 additions & 267 deletions torch_sim/models/sevennet.py
Original file line number Diff line number Diff line change
@@ -1,286 +1,35 @@
"""TorchSim wrapper for SevenNet models."""
"""Wrapper for SevenNet models in TorchSim.

from __future__ import annotations
This module re-exports the SevenNet package's torch-sim integration for convenient
importing. The actual implementation is maintained in the `sevenn` package.

References:
- SevenNet Models Package: https://github.com/MDIL-SNU/SevenNet
"""

import traceback
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any

import torch

import torch_sim as ts
from torch_sim.elastic import voigt_6_to_full_3x3_stress
from torch_sim.models.interface import ModelInterface
from torch_sim.neighbors import torchsim_nl


if TYPE_CHECKING:
from collections.abc import Callable

from sevenn.nn.sequential import AtomGraphSequential

from torch_sim.typing import StateDict
from typing import Any


try:
import sevenn._keys as key
import torch
from sevenn.atom_graph_data import AtomGraphData
from sevenn.calculator import torch_script_type
from sevenn.util import load_checkpoint
from torch_geometric.loader.dataloader import Collater
from sevenn.torchsim import SevenNetModel

except ImportError as exc:
warnings.warn(f"SevenNet import failed: {traceback.format_exc()}", stacklevel=2)

class SevenNetModel(ModelInterface):
"""SevenNet model wrapper for torch-sim.
from torch_sim.models.interface import ModelInterface

This class is a placeholder for the SevenNetModel class.
It raises an ImportError if sevenn is not installed.
class SevenNetModel(ModelInterface): # type: ignore[no-redef]
"""Dummy SevenNet model wrapper for torch-sim to enable safe imports.

NOTE: This class is a placeholder when `sevenn` is not installed.
It raises an ImportError if accessed.
"""

def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None:
"""Dummy init for type checking."""
raise err


def _validate(model: AtomGraphSequential, modal: str) -> None:
if not model.type_map:
raise ValueError("type_map is missing")

if model.cutoff == 0.0:
raise ValueError("Model cutoff seems not initialized")

modal_map = model.modal_map
if modal_map:
modal_ava = list(modal_map)
if not modal:
raise ValueError(f"modal argument missing (avail: {modal_ava})")
if modal not in modal_ava:
raise ValueError(f"unknown modal {modal} (not in {modal_ava})")
elif not model.modal_map and modal:
warnings.warn(
f"modal={modal} is ignored as model has no modal_map",
stacklevel=2,
)


class SevenNetModel(ModelInterface):
"""Computes atomistic energies, forces and stresses using an SevenNet model.

This class wraps an SevenNet model to compute energies, forces, and stresses for
atomistic systems. It handles model initialization, configuration, and
provides a forward pass that accepts a SimState object and returns model
predictions.

Examples:
>>> model = SevenNetModel(model=loaded_sevenn_model)
>>> results = model(state)
"""

def __init__(
self,
model: AtomGraphSequential | str | Path,
*, # force remaining arguments to be keyword-only
modal: str | None = None,
neighbor_list_fn: Callable = torchsim_nl,
device: torch.device | str | None = None,
dtype: torch.dtype = torch.float32,
) -> None:
"""Initialize the SevenNetModel with specified configuration.

Loads an SevenNet model from either a model object or a model path.
Sets up the model parameters for subsequent use in energy and force calculations.

Args:
model (str | Path | AtomGraphSequential): The SevenNet model to wrap.
Accepts either 1) a path to a checkpoint file, 2) a model instance,
or 3) a pretrained model name.
modal (str | None): modal (fidelity) if given model is multi-modal model.
for 7net-mf-ompa, it should be one of 'mpa' (MPtrj + sAlex) or 'omat24'
(OMat24).
neighbor_list_fn (Callable): Neighbor list function to use.
Default is torch_nl_linked_cell.
device (torch.device | str | None): Device to run the model on
dtype (torch.dtype): Data type for computation

Raises:
ValueError: the model doesn't have a cutoff
ValueError: the model has a modal_map but modal is not given
ValueError: the modal given is not in the modal_map
ValueError: the model doesn't have a type_map
"""
super().__init__()

self._device = device or torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
)
if isinstance(self._device, str):
self._device = torch.device(self._device)

if dtype is not torch.float32:
warnings.warn(
"SevenNetModel currently only supports"
"float32, but received different dtype",
UserWarning,
stacklevel=2,
)

if isinstance(model, (str, Path)):
cp = load_checkpoint(model)
model = cp.build_model()

_validate(model, modal)

model.eval_type_map = torch.tensor(data=True)

self._dtype = dtype
self._memory_scales_with = "n_atoms_x_density"
self._compute_stress = True
self._compute_forces = True

model.set_is_batch_data(True)
model_loaded = model
self.cutoff = torch.tensor(model.cutoff)
self.neighbor_list_fn = neighbor_list_fn

self.model = model_loaded
self.modal = modal

self.model = model.to(self._device)
self.model = self.model.eval()

if self.dtype is not None:
self.model = self.model.to(dtype=self.dtype)

self.implemented_properties = ["energy", "forces", "stress"]

def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]:
"""Perform forward pass to compute energies, forces, and other properties.

Takes a simulation state and computes the properties implemented by the model,
such as energy, forces, and stresses.

Args:
state (SimState | StateDict): State object containing positions, cells,
atomic numbers, and other system information. If a dictionary is provided,
it will be converted to a SimState.

Returns:
dict: Model predictions, which may include:
- energy (torch.Tensor): Energy with shape [batch_size]
- forces (torch.Tensor): Forces with shape [n_atoms, 3]
- stress (torch.Tensor): Stress tensor with shape [batch_size, 3, 3],
if compute_stress is True

Notes:
The state is automatically transferred to the model's device if needed.
All output tensors are detached from the computation graph.
"""
sim_state = (
state
if isinstance(state, ts.SimState)
else ts.SimState(**state, masses=torch.ones_like(state["positions"]))
)

if sim_state.device != self._device:
sim_state = sim_state.to(self._device)

# TODO: is this clone necessary?
sim_state = sim_state.clone()

# Batched neighbor list using linked-cell algorithm with row-vector cell
n_systems = sim_state.system_idx.max().item() + 1
edge_index, mapping_system, unit_shifts = self.neighbor_list_fn(
sim_state.positions,
sim_state.row_vector_cell,
sim_state.pbc,
self.cutoff,
sim_state.system_idx,
)

# Build per-system SevenNet AtomGraphData by slicing the global NL
n_atoms_per_system = sim_state.system_idx.bincount()
stride = torch.cat(
(
torch.tensor([0], device=self.device, dtype=torch.long),
n_atoms_per_system.cumsum(0),
)
)

data_list = []
for sys_idx in range(n_systems):
sys_start = stride[sys_idx].item()
sys_end = stride[sys_idx + 1].item()

pos = sim_state.positions[sys_start:sys_end]
row_vector_cell = sim_state.row_vector_cell[sys_idx]
atomic_nums = sim_state.atomic_numbers[sys_start:sys_end]

mask = mapping_system == sys_idx
edge_idx_sys_global = edge_index[:, mask]
unit_shifts_sys = unit_shifts[mask]

# Convert global indices to local indices
edge_idx = edge_idx_sys_global - sys_start
shifts = torch.mm(unit_shifts_sys, row_vector_cell)
edge_vec = pos[edge_idx[1]] - pos[edge_idx[0]] + shifts
vol = torch.det(row_vector_cell)

data = {
key.NODE_FEATURE: atomic_nums,
key.ATOMIC_NUMBERS: atomic_nums.to(dtype=torch.int64, device=self.device),
key.POS: pos,
key.EDGE_IDX: edge_idx,
key.EDGE_VEC: edge_vec,
key.CELL: row_vector_cell,
key.CELL_SHIFT: unit_shifts_sys,
key.CELL_VOLUME: vol,
key.NUM_ATOMS: torch.tensor(len(atomic_nums), device=self.device),
key.DATA_MODALITY: self.modal,
}
data[key.INFO] = {}

data = AtomGraphData(**data)
data_list.append(data)

batched_data = Collater([], follow_batch=None, exclude_keys=None)(data_list)
batched_data.to(self.device)

if isinstance(self.model, torch_script_type):
batched_data[key.NODE_FEATURE] = torch.tensor(
[self.type_map[z.item()] for z in data[key.NODE_FEATURE]],
dtype=torch.int64,
device=self.device,
)
batched_data[key.POS].requires_grad_(
requires_grad=True
) # backward compatibility
batched_data[key.EDGE_VEC].requires_grad_(requires_grad=True)
batched_data = batched_data.to_dict()
del batched_data["data_info"]

output = self.model(batched_data)

results: dict[str, torch.Tensor] = {}
energy = output[key.PRED_TOTAL_ENERGY]
if energy is not None:
results["energy"] = energy.detach()
else:
results["energy"] = torch.zeros(
sim_state.system_idx.max().item() + 1, device=self.device
)

forces = output[key.PRED_FORCE]
if forces is not None:
results["forces"] = forces.detach()

stress = output[key.PRED_STRESS]
if stress is not None:
results["stress"] = -voigt_6_to_full_3x3_stress(
stress.detach()[..., [0, 1, 2, 4, 5, 3]]
)

return results
__all__ = ["SevenNetModel"]
Loading