diff --git a/pyproject.toml b/pyproject.toml index 1280e577..1f474cf9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,7 @@ mace = ["mace-torch>=0.3.14"] mattersim = ["mattersim>=0.1.2"] metatomic = ["metatomic-torch>=0.1.3", "metatrain[pet]>=2025.12"] orb = ["orb-models>=0.5.2"] -sevenn = ["sevenn>=0.11.0"] +sevenn = ["sevenn[torchsim] @ git+https://github.com/MDIL-SNU/SevenNet.git"] graphpes = ["graph-pes>=0.1", "mace-torch>=0.3.12"] nequip = ["nequip>=0.16.2"] fairchem = ["fairchem-core>=2.7", "scipy<1.17.0"] diff --git a/torch_sim/models/sevennet.py b/torch_sim/models/sevennet.py index 81b76820..0d62fe60 100644 --- a/torch_sim/models/sevennet.py +++ b/torch_sim/models/sevennet.py @@ -1,44 +1,30 @@ -"""TorchSim wrapper for SevenNet models.""" +"""Wrapper for SevenNet models in TorchSim. -from __future__ import annotations +This module re-exports the SevenNet package's torch-sim integration for convenient +importing. The actual implementation is maintained in the `sevenn` package. + +References: + - SevenNet Models Package: https://github.com/MDIL-SNU/SevenNet +""" import traceback import warnings -from pathlib import Path -from typing import TYPE_CHECKING, Any - -import torch - -import torch_sim as ts -from torch_sim.elastic import voigt_6_to_full_3x3_stress -from torch_sim.models.interface import ModelInterface -from torch_sim.neighbors import torchsim_nl - - -if TYPE_CHECKING: - from collections.abc import Callable - - from sevenn.nn.sequential import AtomGraphSequential - - from torch_sim.typing import StateDict +from typing import Any try: - import sevenn._keys as key - import torch - from sevenn.atom_graph_data import AtomGraphData - from sevenn.calculator import torch_script_type - from sevenn.util import load_checkpoint - from torch_geometric.loader.dataloader import Collater + from sevenn.torchsim import SevenNetModel except ImportError as exc: warnings.warn(f"SevenNet import failed: {traceback.format_exc()}", stacklevel=2) - class SevenNetModel(ModelInterface): - """SevenNet model wrapper for torch-sim. + from torch_sim.models.interface import ModelInterface - This class is a placeholder for the SevenNetModel class. - It raises an ImportError if sevenn is not installed. + class SevenNetModel(ModelInterface): # type: ignore[no-redef] + """Dummy SevenNet model wrapper for torch-sim to enable safe imports. + + NOTE: This class is a placeholder when `sevenn` is not installed. + It raises an ImportError if accessed. """ def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: @@ -46,241 +32,4 @@ def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None: raise err -def _validate(model: AtomGraphSequential, modal: str) -> None: - if not model.type_map: - raise ValueError("type_map is missing") - - if model.cutoff == 0.0: - raise ValueError("Model cutoff seems not initialized") - - modal_map = model.modal_map - if modal_map: - modal_ava = list(modal_map) - if not modal: - raise ValueError(f"modal argument missing (avail: {modal_ava})") - if modal not in modal_ava: - raise ValueError(f"unknown modal {modal} (not in {modal_ava})") - elif not model.modal_map and modal: - warnings.warn( - f"modal={modal} is ignored as model has no modal_map", - stacklevel=2, - ) - - -class SevenNetModel(ModelInterface): - """Computes atomistic energies, forces and stresses using an SevenNet model. - - This class wraps an SevenNet model to compute energies, forces, and stresses for - atomistic systems. It handles model initialization, configuration, and - provides a forward pass that accepts a SimState object and returns model - predictions. - - Examples: - >>> model = SevenNetModel(model=loaded_sevenn_model) - >>> results = model(state) - """ - - def __init__( - self, - model: AtomGraphSequential | str | Path, - *, # force remaining arguments to be keyword-only - modal: str | None = None, - neighbor_list_fn: Callable = torchsim_nl, - device: torch.device | str | None = None, - dtype: torch.dtype = torch.float32, - ) -> None: - """Initialize the SevenNetModel with specified configuration. - - Loads an SevenNet model from either a model object or a model path. - Sets up the model parameters for subsequent use in energy and force calculations. - - Args: - model (str | Path | AtomGraphSequential): The SevenNet model to wrap. - Accepts either 1) a path to a checkpoint file, 2) a model instance, - or 3) a pretrained model name. - modal (str | None): modal (fidelity) if given model is multi-modal model. - for 7net-mf-ompa, it should be one of 'mpa' (MPtrj + sAlex) or 'omat24' - (OMat24). - neighbor_list_fn (Callable): Neighbor list function to use. - Default is torch_nl_linked_cell. - device (torch.device | str | None): Device to run the model on - dtype (torch.dtype): Data type for computation - - Raises: - ValueError: the model doesn't have a cutoff - ValueError: the model has a modal_map but modal is not given - ValueError: the modal given is not in the modal_map - ValueError: the model doesn't have a type_map - """ - super().__init__() - - self._device = device or torch.device( - "cuda" if torch.cuda.is_available() else "cpu" - ) - if isinstance(self._device, str): - self._device = torch.device(self._device) - - if dtype is not torch.float32: - warnings.warn( - "SevenNetModel currently only supports" - "float32, but received different dtype", - UserWarning, - stacklevel=2, - ) - - if isinstance(model, (str, Path)): - cp = load_checkpoint(model) - model = cp.build_model() - - _validate(model, modal) - - model.eval_type_map = torch.tensor(data=True) - - self._dtype = dtype - self._memory_scales_with = "n_atoms_x_density" - self._compute_stress = True - self._compute_forces = True - - model.set_is_batch_data(True) - model_loaded = model - self.cutoff = torch.tensor(model.cutoff) - self.neighbor_list_fn = neighbor_list_fn - - self.model = model_loaded - self.modal = modal - - self.model = model.to(self._device) - self.model = self.model.eval() - - if self.dtype is not None: - self.model = self.model.to(dtype=self.dtype) - - self.implemented_properties = ["energy", "forces", "stress"] - - def forward(self, state: ts.SimState | StateDict) -> dict[str, torch.Tensor]: - """Perform forward pass to compute energies, forces, and other properties. - - Takes a simulation state and computes the properties implemented by the model, - such as energy, forces, and stresses. - - Args: - state (SimState | StateDict): State object containing positions, cells, - atomic numbers, and other system information. If a dictionary is provided, - it will be converted to a SimState. - - Returns: - dict: Model predictions, which may include: - - energy (torch.Tensor): Energy with shape [batch_size] - - forces (torch.Tensor): Forces with shape [n_atoms, 3] - - stress (torch.Tensor): Stress tensor with shape [batch_size, 3, 3], - if compute_stress is True - - Notes: - The state is automatically transferred to the model's device if needed. - All output tensors are detached from the computation graph. - """ - sim_state = ( - state - if isinstance(state, ts.SimState) - else ts.SimState(**state, masses=torch.ones_like(state["positions"])) - ) - - if sim_state.device != self._device: - sim_state = sim_state.to(self._device) - - # TODO: is this clone necessary? - sim_state = sim_state.clone() - - # Batched neighbor list using linked-cell algorithm with row-vector cell - n_systems = sim_state.system_idx.max().item() + 1 - edge_index, mapping_system, unit_shifts = self.neighbor_list_fn( - sim_state.positions, - sim_state.row_vector_cell, - sim_state.pbc, - self.cutoff, - sim_state.system_idx, - ) - - # Build per-system SevenNet AtomGraphData by slicing the global NL - n_atoms_per_system = sim_state.system_idx.bincount() - stride = torch.cat( - ( - torch.tensor([0], device=self.device, dtype=torch.long), - n_atoms_per_system.cumsum(0), - ) - ) - - data_list = [] - for sys_idx in range(n_systems): - sys_start = stride[sys_idx].item() - sys_end = stride[sys_idx + 1].item() - - pos = sim_state.positions[sys_start:sys_end] - row_vector_cell = sim_state.row_vector_cell[sys_idx] - atomic_nums = sim_state.atomic_numbers[sys_start:sys_end] - - mask = mapping_system == sys_idx - edge_idx_sys_global = edge_index[:, mask] - unit_shifts_sys = unit_shifts[mask] - - # Convert global indices to local indices - edge_idx = edge_idx_sys_global - sys_start - shifts = torch.mm(unit_shifts_sys, row_vector_cell) - edge_vec = pos[edge_idx[1]] - pos[edge_idx[0]] + shifts - vol = torch.det(row_vector_cell) - - data = { - key.NODE_FEATURE: atomic_nums, - key.ATOMIC_NUMBERS: atomic_nums.to(dtype=torch.int64, device=self.device), - key.POS: pos, - key.EDGE_IDX: edge_idx, - key.EDGE_VEC: edge_vec, - key.CELL: row_vector_cell, - key.CELL_SHIFT: unit_shifts_sys, - key.CELL_VOLUME: vol, - key.NUM_ATOMS: torch.tensor(len(atomic_nums), device=self.device), - key.DATA_MODALITY: self.modal, - } - data[key.INFO] = {} - - data = AtomGraphData(**data) - data_list.append(data) - - batched_data = Collater([], follow_batch=None, exclude_keys=None)(data_list) - batched_data.to(self.device) - - if isinstance(self.model, torch_script_type): - batched_data[key.NODE_FEATURE] = torch.tensor( - [self.type_map[z.item()] for z in data[key.NODE_FEATURE]], - dtype=torch.int64, - device=self.device, - ) - batched_data[key.POS].requires_grad_( - requires_grad=True - ) # backward compatibility - batched_data[key.EDGE_VEC].requires_grad_(requires_grad=True) - batched_data = batched_data.to_dict() - del batched_data["data_info"] - - output = self.model(batched_data) - - results: dict[str, torch.Tensor] = {} - energy = output[key.PRED_TOTAL_ENERGY] - if energy is not None: - results["energy"] = energy.detach() - else: - results["energy"] = torch.zeros( - sim_state.system_idx.max().item() + 1, device=self.device - ) - - forces = output[key.PRED_FORCE] - if forces is not None: - results["forces"] = forces.detach() - - stress = output[key.PRED_STRESS] - if stress is not None: - results["stress"] = -voigt_6_to_full_3x3_stress( - stress.detach()[..., [0, 1, 2, 4, 5, 3]] - ) - - return results +__all__ = ["SevenNetModel"]