diff --git a/.gitignore b/.gitignore index 3be3867dd..01cde4042 100644 --- a/.gitignore +++ b/.gitignore @@ -40,3 +40,6 @@ uv.lock # IDE .vscode/ + +# duecredit +.duecredit.p diff --git a/README.md b/README.md index 1fa32906d..88ef33a87 100644 --- a/README.md +++ b/README.md @@ -140,3 +140,27 @@ TorchSim is released under an [MIT license](LICENSE). ## Citation If you use TorchSim in your research, please cite our [publication](https://iopscience.iop.org/article/10.1088/3050-287X/ae1799). + +```bibtex +@article{cohen2025torchsim, + title={TorchSim: An efficient atomistic simulation engine in PyTorch}, + author={Cohen, Orion and Riebesell, Janosh and Goodall, Rhys and Kolluru, Adeesh and Falletta, Stefano and Krause, Joseph and Colindres, Jorge and Ceder, Gerbrand and Gangan, Abhijeet S}, + journal={AI for Science}, + volume={1}, + number={2}, + pages={025003}, + year={2025}, + publisher={IOP Publishing}, + doi={10.1088/3050-287X/ae1799} +} +``` + +## Due Credit + +We aim to recognize all [duecredit](https://github.com/duecredit/duecredit) for the decades of work that TorchSim builds on top of, an automated list of references can be obtained for the package by running `DUECREDIT_ENABLE=yes uv run --with . --extra docs --extra test python -m duecredit <(printf 'import pytest\nraise SystemExit(pytest.main(["-q"]))\n')`. This list is incomplete and we welcome PRs to help improve our citation coverage. + +To collect citations for a specific tutorial run, for example autobatching, use: + +```sh +DUECREDIT_ENABLE=yes uv run --with . --extra docs --extra test python -m duecredit examples/tutorials/autobatching_tutorial.py +``` diff --git a/pyproject.toml b/pyproject.toml index 5fa283a89..1280e577a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,18 +37,14 @@ dependencies = [ [project.optional-dependencies] test = [ - "ase>=3.26", - "phonopy>=2.37.0", + "torch-sim-atomistic[io,symmetry]", "platformdirs>=4.0.0", "psutil>=7.0.0", - "pymatgen>=2025.6.14", "pytest-cov>=6", "pytest>=8", - "moyopy>=0.3", - "spglib>=2.6", ] io = ["ase>=3.26", "phonopy>=2.37.0", "pymatgen>=2025.6.14"] -symmetry = ["moyopy>=0.3"] +symmetry = ["moyopy>=0.3", "spglib>=2.6"] mace = ["mace-torch>=0.3.14"] mattersim = ["mattersim>=0.1.2"] metatomic = ["metatomic-torch>=0.1.3", "metatrain[pet]>=2025.12"] @@ -59,6 +55,7 @@ nequip = ["nequip>=0.16.2"] fairchem = ["fairchem-core>=2.7", "scipy<1.17.0"] docs = [ "autodoc_pydantic==2.2.0", + "duecredit>=0.11", "furo==2024.8.6", "ipython==8.34.0", "ipykernel==6.30.1", diff --git a/tests/models/test_fairchem.py b/tests/models/test_fairchem.py index cd93d05f0..b992f8172 100644 --- a/tests/models/test_fairchem.py +++ b/tests/models/test_fairchem.py @@ -20,7 +20,7 @@ import torch_sim as ts from torch_sim.models.fairchem import FairChemModel -except ImportError: +except (ImportError, OSError, RuntimeError, AttributeError, ValueError): pytest.skip( f"FairChem not installed: {traceback.format_exc()}", allow_module_level=True ) diff --git a/tests/models/test_fairchem_legacy.py b/tests/models/test_fairchem_legacy.py index 02b642ab7..aa13a5d43 100644 --- a/tests/models/test_fairchem_legacy.py +++ b/tests/models/test_fairchem_legacy.py @@ -20,7 +20,7 @@ from torch_sim.models.fairchem_legacy import FairChemV1Model -except ImportError: +except (ImportError, OSError, RuntimeError, AttributeError, ValueError): pytest.skip( f"FairChem not installed: {traceback.format_exc()}", allow_module_level=True ) diff --git a/tests/models/test_graphpes_framework.py b/tests/models/test_graphpes_framework.py index 4c4617afe..5e0379389 100644 --- a/tests/models/test_graphpes_framework.py +++ b/tests/models/test_graphpes_framework.py @@ -11,14 +11,15 @@ make_model_calculator_consistency_test, make_validate_model_outputs_test, ) -from torch_sim.models.graphpes import GraphPESWrapper from torch_sim.testing import CONSISTENCY_SIMSTATES try: from graph_pes.atomic_graph import AtomicGraph, to_batch from graph_pes.models import LennardJones, SchNet, TensorNet -except ImportError: + + from torch_sim.models.graphpes_framework import GraphPESWrapper +except (ImportError, OSError, RuntimeError, AttributeError, ValueError): pytest.skip( f"graph-pes not installed: {traceback.format_exc()}", allow_module_level=True ) diff --git a/tests/models/test_mace.py b/tests/models/test_mace.py index ef7dc60d0..9754a0d46 100644 --- a/tests/models/test_mace.py +++ b/tests/models/test_mace.py @@ -10,7 +10,6 @@ make_model_calculator_consistency_test, make_validate_model_outputs_test, ) -from torch_sim.models.mace import MaceUrls from torch_sim.testing import SIMSTATE_BULK_GENERATORS, SIMSTATE_MOLECULE_GENERATORS @@ -18,8 +17,9 @@ from mace.calculators import MACECalculator from mace.calculators.foundations_models import mace_mp, mace_off - from torch_sim.models.mace import MaceModel -except (ImportError, ValueError): + from torch_sim.models.mace import MaceModel, MaceUrls + +except (ImportError, OSError, RuntimeError, AttributeError, ValueError): pytest.skip(f"MACE not installed: {traceback.format_exc()}", allow_module_level=True) # mace_omol is optional (added in newer MACE versions) @@ -28,7 +28,7 @@ raw_mace_omol = mace_omol(model="extra_large", return_raw_model=True) HAS_MACE_OMOL = True -except ImportError: +except (ImportError, OSError, RuntimeError, AttributeError, ValueError): raw_mace_omol = None HAS_MACE_OMOL = False diff --git a/tests/models/test_mattersim.py b/tests/models/test_mattersim.py index e53a5cdc0..92a2bcbc9 100644 --- a/tests/models/test_mattersim.py +++ b/tests/models/test_mattersim.py @@ -19,7 +19,7 @@ from torch_sim.models.mattersim import MatterSimModel -except ImportError: +except (ImportError, OSError, RuntimeError, AttributeError, ValueError): pytest.skip( f"mattersim not installed: {traceback.format_exc()}", allow_module_level=True ) diff --git a/tests/models/test_orb.py b/tests/models/test_orb.py index 118f6d307..33415da7e 100644 --- a/tests/models/test_orb.py +++ b/tests/models/test_orb.py @@ -10,7 +10,6 @@ make_validate_model_outputs_test, ) from torch_sim import SimState -from torch_sim.models.orb import cell_to_cellpar from torch_sim.testing import SIMSTATE_GENERATORS @@ -18,7 +17,8 @@ from orb_models.forcefield import pretrained from orb_models.forcefield.calculator import ORBCalculator - from torch_sim.models.orb import OrbModel + from torch_sim.models.orb import OrbModel, cell_to_cellpar + except ImportError: pytest.skip(f"ORB not installed: {traceback.format_exc()}", allow_module_level=True) diff --git a/tests/test_elastic.py b/tests/test_elastic.py index d729175dc..6ad6af76f 100644 --- a/tests/test_elastic.py +++ b/tests/test_elastic.py @@ -21,7 +21,7 @@ from mace.calculators.foundations_models import mace_mp from torch_sim.models.mace import MaceModel -except ImportError: +except (ImportError, OSError, RuntimeError, AttributeError, ValueError): pytest.skip(f"MACE not installed: {traceback.format_exc()}", allow_module_level=True) diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py index 6d5deb5d7..17da2fda4 100644 --- a/tests/test_neighbors.py +++ b/tests/test_neighbors.py @@ -467,24 +467,31 @@ def test_vesin_nl_edge_cases() -> None: assert len(mapping[0]) > 0 # Should find neighbors -def test_torchsim_nl_availability() -> None: +def test_vesin_nl_availability() -> None: """Test that availability flags are correctly set.""" assert isinstance(neighbors.VESIN_AVAILABLE, bool) + + assert callable(neighbors.vesin_nl) + assert callable(neighbors.vesin_nl_ts) + + if not neighbors.VESIN_AVAILABLE: + with pytest.raises(ImportError, match="Vesin is not installed"): + neighbors.vesin_nl() + with pytest.raises(ImportError, match="Vesin is not installed"): + neighbors.vesin_nl_ts() + + +def test_alchemiops_nl_availability() -> None: assert isinstance(neighbors.ALCHEMIOPS_AVAILABLE, bool) - if neighbors.VESIN_AVAILABLE: - assert neighbors.VesinNeighborList is not None - assert neighbors.VesinNeighborListTorch is not None - else: - assert neighbors.VesinNeighborList is None - assert neighbors.VesinNeighborListTorch is None + assert callable(neighbors.alchemiops_nl_n2) + assert callable(neighbors.alchemiops_nl_cell_list) - if neighbors.ALCHEMIOPS_AVAILABLE: - assert neighbors.alchemiops_nl_n2 is not None - assert neighbors.alchemiops_nl_cell_list is not None - else: - assert neighbors.alchemiops_nl_n2 is None - assert neighbors.alchemiops_nl_cell_list is None + if not neighbors.ALCHEMIOPS_AVAILABLE: + with pytest.raises(ImportError, match="nvalchemiops is not installed"): + neighbors.alchemiops_nl_n2() + with pytest.raises(ImportError, match="nvalchemiops is not installed"): + neighbors.alchemiops_nl_cell_list() @pytest.mark.skipif( diff --git a/tests/test_optimizers_vs_ase.py b/tests/test_optimizers_vs_ase.py index aa509138a..2c510a4a9 100644 --- a/tests/test_optimizers_vs_ase.py +++ b/tests/test_optimizers_vs_ase.py @@ -11,7 +11,14 @@ import torch_sim as ts from tests.conftest import DTYPE -from torch_sim.models.mace import MaceModel, MaceUrls + + +try: + from mace.calculators.foundations_models import mace_mp + + from torch_sim.models.mace import MaceModel, MaceUrls +except (ImportError, OSError, RuntimeError, AttributeError, ValueError): + pytest.skip(f"MACE not installed: {traceback.format_exc()}", allow_module_level=True) if TYPE_CHECKING: @@ -21,13 +28,6 @@ @pytest.fixture def ts_mace_mpa() -> MaceModel: """Provides a MACE MP model instance for the optimizer tests.""" - try: - from mace.calculators.foundations_models import mace_mp - except ImportError: - pytest.skip( - f"MACE not installed: {traceback.format_exc()}", allow_module_level=True - ) - # Use float64 for potentially higher precision needed in optimization dtype = getattr(torch, dtype_str := "float64") raw_mace = mace_mp( @@ -45,13 +45,6 @@ def ts_mace_mpa() -> MaceModel: @pytest.fixture def ase_mace_mpa() -> "MACECalculator": """Provides an ASE MACECalculator instance using mace_mp.""" - try: - from mace.calculators.foundations_models import mace_mp - except ImportError: - pytest.skip( - f"MACE not installed: {traceback.format_exc()}", allow_module_level=True - ) - # Ensure dtype matches the one used in the torch-sim fixture (float64) return mace_mp(model=MaceUrls.mace_mp_small, default_dtype="float64") diff --git a/torch_sim/__init__.py b/torch_sim/__init__.py index b56973170..fc7e52666 100644 --- a/torch_sim/__init__.py +++ b/torch_sim/__init__.py @@ -5,7 +5,7 @@ from datetime import datetime from importlib.metadata import version -import torch_sim as ts +import torch_sim._duecredit from torch_sim import ( autobatching, constraints, @@ -100,3 +100,5 @@ SCRIPTS_DIR = f"{ROOT}/examples" __version__ = version("torch-sim-atomistic") + +import torch_sim._citations # noqa: E402 diff --git a/torch_sim/_citations.py b/torch_sim/_citations.py new file mode 100644 index 000000000..3cfc15c0a --- /dev/null +++ b/torch_sim/_citations.py @@ -0,0 +1,47 @@ +"""Package-level duecredit citations for TorchSim and its core dependencies. + +This module must be imported at the end of torch_sim.__init__ so that all +packages are fully loaded before citations are registered. +""" + +from torch_sim._duecredit import BibTeX, due + + +if due is not None: + due.cite( + BibTeX( + """@article{cohen2025torchsim, + title={TorchSim: An efficient atomistic simulation engine in PyTorch}, + author={Cohen, Orion and Riebesell, Janosh and Goodall, Rhys and + Kolluru, Adeesh and Falletta, Stefano and Krause, Joseph and + Colindres, Jorge and Ceder, Gerbrand and Gangan, Abhijeet S}, + journal={AI for Science}, + volume={1}, + number={2}, + pages={025003}, + year={2025}, + publisher={IOP Publishing}, + doi={10.1088/3050-287X/ae1799} +}""" + ), + description="TorchSim simulation engine", + path="torch_sim", + cite_module=True, + ) + due.cite( + BibTeX( + """@inproceedings{paszke2019pytorch, + title={PyTorch: An Imperative Style, High-Performance Deep Learning Library}, + author={Paszke, Adam and Gross, Sam and Massa, Francisco and + Lerer, Adam and Bradbury, James and Chanan, Gregory and + Killeen, Trevor and Lin, Zeming and Gimelshein, Natalia and + Antiga, Luca and others}, + booktitle={Advances in Neural Information Processing Systems}, + volume={32}, + year={2019} +}""" + ), + description="PyTorch deep learning framework", + path="torch", + cite_module=True, + ) diff --git a/torch_sim/_duecredit.py b/torch_sim/_duecredit.py new file mode 100644 index 000000000..b82da53ec --- /dev/null +++ b/torch_sim/_duecredit.py @@ -0,0 +1,66 @@ +"""Stub file for a guaranteed safe import of duecredit constructs.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + + +if TYPE_CHECKING: + from collections.abc import Callable + + +class InactiveDueCreditCollector: + """Just a stub at the Collector which would not do anything.""" + + def _donothing(self, *_args: Any, **_kwargs: Any) -> None: + """Perform no good and no bad.""" + + def dcite( + self, *_args: Any, **_kwargs: Any + ) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """If I could cite I would.""" + + def nondecorating_decorator(func: Callable[..., Any]) -> Callable[..., Any]: + return func + + return nondecorating_decorator + + active = False + activate = add = cite = dump = load = _donothing + + def __repr__(self) -> str: + return self.__class__.__name__ + "()" + + +def _donothing_func(*_args: Any, **_kwargs: Any) -> Any: + """Perform no good and no bad.""" + return None + + +def _disable_duecredit(exc: Exception) -> None: + import logging + + logging.getLogger("duecredit").exception( + "Failed to import duecredit despite being installed: %s", exc + ) + + +try: + from duecredit import BibTeX, Doi, Text, Url, due +except Exception as e: # noqa: BLE001 + if not isinstance(e, ImportError): + _disable_duecredit(e) + due = InactiveDueCreditCollector() + BibTeX = Doi = Url = Text = _donothing_func + + +def dcite( + doi: str, description: str | None = None, *, path: str | None = None +) -> Callable[[Callable[..., Any]], Callable[..., Any]]: + """Create a duecredit decorator from a DOI and description.""" + kwargs: dict[str, Any] = ( + {"description": description} if description is not None else {} + ) + if path is not None: + kwargs["path"] = path + return due.dcite(Doi(doi), **kwargs) diff --git a/torch_sim/integrators/md.py b/torch_sim/integrators/md.py index 63f32624a..3185b6cdf 100644 --- a/torch_sim/integrators/md.py +++ b/torch_sim/integrators/md.py @@ -5,6 +5,7 @@ import torch +from torch_sim._duecredit import dcite from torch_sim.models.interface import ModelInterface from torch_sim.quantities import calc_kT from torch_sim.state import SimState @@ -312,6 +313,9 @@ class NoseHooverChainFns: } +@dcite("10.1063/1.463940") +@dcite("10.2183/pjab.69.161") +@dcite("10.1016/0375-9601(90)90092-3") def construct_nose_hoover_chain( # noqa: C901 PLR0915 dt: torch.Tensor, chain_length: int, diff --git a/torch_sim/integrators/npt.py b/torch_sim/integrators/npt.py index 4d6bd752e..a709c4c2a 100644 --- a/torch_sim/integrators/npt.py +++ b/torch_sim/integrators/npt.py @@ -8,6 +8,7 @@ import torch import torch_sim as ts +from torch_sim._duecredit import dcite from torch_sim.integrators.md import ( MDState, NoseHooverChain, @@ -628,6 +629,7 @@ def npt_langevin_init( ) +@dcite("10.1063/1.4901303") def npt_langevin_step( state: NPTLangevinState, model: ModelInterface, @@ -1432,6 +1434,7 @@ def npt_nose_hoover_init( ) +@dcite("10.1080/00268979600100761") def npt_nose_hoover_step( state: NPTNoseHooverState, model: ModelInterface, @@ -1963,6 +1966,8 @@ def _crescale_isotropic_barostat_step( return state +@dcite("10.1063/5.0020514") +@dcite("10.3390/app12031139") def npt_crescale_anisotropic_step( state: NPTCRescaleState, model: ModelInterface, @@ -2031,6 +2036,8 @@ def npt_crescale_anisotropic_step( return _vrescale_update(state, tau, kT, dt / 2) +@dcite("10.1063/5.0020514") +@dcite("10.3390/app12031139") def npt_crescale_independent_lengths_step( state: NPTCRescaleState, model: ModelInterface, @@ -2099,6 +2106,8 @@ def npt_crescale_independent_lengths_step( return _vrescale_update(state, tau, kT, dt / 2) +@dcite("10.1063/5.0020514") +@dcite("10.3390/app12031139") def npt_crescale_average_anisotropic_step( state: NPTCRescaleState, model: ModelInterface, @@ -2168,6 +2177,7 @@ def npt_crescale_average_anisotropic_step( return _vrescale_update(state, tau, kT, dt / 2) +@dcite("10.1063/5.0020514") def npt_crescale_isotropic_step( state: NPTCRescaleState, model: ModelInterface, diff --git a/torch_sim/integrators/nvt.py b/torch_sim/integrators/nvt.py index 61cb84723..957ad89f9 100644 --- a/torch_sim/integrators/nvt.py +++ b/torch_sim/integrators/nvt.py @@ -6,6 +6,7 @@ import torch import torch_sim as ts +from torch_sim._duecredit import dcite from torch_sim.integrators.md import ( MDState, NoseHooverChain, @@ -126,6 +127,7 @@ def nvt_langevin_init( ) +@dcite("10.1098/rspa.2016.0138") def nvt_langevin_step( state: MDState, model: ModelInterface, @@ -321,6 +323,7 @@ def nvt_nose_hoover_init( ) +@dcite("10.1080/00268979600100761") def nvt_nose_hoover_step( state: NVTNoseHooverState, model: ModelInterface, @@ -600,6 +603,7 @@ def nvt_vrescale_init( ) +@dcite("10.1063/1.2408420") def nvt_vrescale_step( model: ModelInterface, state: NVTVRescaleState, diff --git a/torch_sim/io.py b/torch_sim/io.py index 6c4007801..835446ec0 100644 --- a/torch_sim/io.py +++ b/torch_sim/io.py @@ -18,6 +18,7 @@ import torch import torch_sim as ts +from torch_sim._duecredit import dcite if TYPE_CHECKING: @@ -26,6 +27,11 @@ from pymatgen.core import Structure +@dcite( + "10.1088/1361-648X/aa680e", + description="ASE: Atomic Simulation Environment", + path="ase", +) def state_to_atoms(state: "ts.SimState") -> list["Atoms"]: """Convert a SimState to a list of ASE Atoms objects. @@ -85,6 +91,11 @@ def state_to_atoms(state: "ts.SimState") -> list["Atoms"]: return atoms_list +@dcite( + "10.1016/j.commatsci.2012.10.028", + description="pymatgen: Python Materials Genomics", + path="pymatgen", +) def state_to_structures(state: "ts.SimState") -> list["Structure"]: """Convert a SimState to a list of Pymatgen Structure objects. @@ -140,6 +151,16 @@ def state_to_structures(state: "ts.SimState") -> list["Structure"]: return structures +@dcite( + "10.1088/1361-648X/aa680e", + description="ASE: Atomic Simulation Environment", + path="ase", +) +@dcite( + "10.1016/j.scriptamat.2015.07.021", + description="Phonopy: harmonic and quasi-harmonic phonon calculationss", + path="phonopy", +) def state_to_phonopy(state: "ts.SimState") -> list["PhonopyAtoms"]: """Convert a SimState to a list of PhonopyAtoms objects. @@ -188,6 +209,11 @@ def state_to_phonopy(state: "ts.SimState") -> list["PhonopyAtoms"]: return phonopy_atoms_list +@dcite( + "10.1088/1361-648X/aa680e", + description="ASE: Atomic Simulation Environment", + path="ase", +) def atoms_to_state( atoms: "Atoms | list[Atoms]", device: torch.device | None = None, @@ -265,6 +291,11 @@ def atoms_to_state( ) +@dcite( + "10.1016/j.commatsci.2012.10.028", + description="pymatgen: Python Materials Genomics", + path="pymatgen", +) def structures_to_state( structure: "Structure | list[Structure]", device: torch.device | None = None, @@ -337,6 +368,11 @@ def structures_to_state( ) +@dcite( + "10.1016/j.scriptamat.2015.07.021", + description="Phonopy: harmonic and quasi-harmonic phonon calculationss", + path="phonopy", +) def phonopy_to_state( phonopy_atoms: "PhonopyAtoms | list[PhonopyAtoms]", device: torch.device | None = None, diff --git a/torch_sim/math.py b/torch_sim/math.py index 6e5e94c82..5f7604f6e 100644 --- a/torch_sim/math.py +++ b/torch_sim/math.py @@ -6,6 +6,8 @@ import torch +from torch_sim._duecredit import dcite + @torch.jit.script def torch_divmod(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: @@ -38,25 +40,18 @@ def expm_frechet( # noqa: C901 Method notes: - ``SPS`` uses scaling-Pade-squaring for the matrix exponential and its - Frechet derivative. - - ``blockEnlarge`` uses the block matrix identity + Frechet derivative. See :func:`expm_frechet_sps`. + - ``BE`` uses the block matrix identity exp([[A, E], [0, A]]) = [[exp(A), L_exp(A, E)], [0, exp(A)]]. - - References: - - Awad H. Al-Mohy and Nicholas J. Higham (2009), "Computing the Frechet - Derivative of the Matrix Exponential, with an Application to Condition - Number Estimation", SIAM J. Matrix Anal. Appl. 30(4):1639-1657. - https://doi.org/10.1137/080716426 - - Nicholas J. Higham (2008), "Functions of Matrices: Theory and - Computation", SIAM. (See the Frechet derivative block-matrix identity.) + See :func:`expm_frechet_block_enlarge`. Args: A: (B, 3, 3) or (3, 3) tensor. Matrix of which to take the matrix exponential. E: (B, 3, 3) or (3, 3) tensor. Matrix direction in which to take the Frechet derivative. Must have same shape as A. method: str, optional. Choice of algorithm. Should be one of - - `SPS` (default) - - `blockEnlarge` + - `SPS` - Scaling-Pade-squaring (default) + - `BE` - Block-enlarge check_finite: bool, optional. Whether to check that the input matrix contains only finite numbers. Disabling may give a performance gain, but may result in problems (crashes, non-termination) if the inputs do contain @@ -86,7 +81,7 @@ def expm_frechet( # noqa: C901 if method is None: method = "SPS" - if method == "blockEnlarge": + if method in ["BE", "blockEnlarge"]: # "blockEnlarge" is deprecated if A.dim() != 3 or A.shape[1] != A.shape[2]: raise ValueError("expected A to be (B, N, N)") return expm_frechet_block_enlarge(A, E) @@ -108,14 +103,17 @@ def matrix_exp(A: torch.Tensor) -> torch.Tensor: return torch.matrix_exp(A) +@dcite("10.1137/080716426") def expm_frechet_sps( A: torch.Tensor, E: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: - """SPS helper for Frechet derivative of exp(A) on 3x3 matrices. + """Scaling-Pade-squaring helper for Frechet derivative of exp(A) on 3x3 matrices. - SPS = scaling-Pade-squaring. This implementation follows the approach in: - Awad H. Al-Mohy and Nicholas J. Higham (2009), SIAM J. Matrix Anal. - Appl. 30(4):1639-1657. https://doi.org/10.1137/080716426 + References: + - Awad H. Al-Mohy and Nicholas J. Higham (2009), "Computing the Fréchet + Derivative of the Matrix Exponential, with an Application to Condition + Number Estimation", SIAM J. Matrix Anal. Appl. 30(4):1639-1657. + https://doi.org/10.1137/080716426 """ # Handle unbatched 3x3 input by adding batch dimension unbatched = A.dim() == 2 @@ -191,6 +189,7 @@ def expm_frechet_sps( return R, L +@dcite("10.1137/1.9780898717778") def expm_frechet_block_enlarge( A: torch.Tensor, E: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: @@ -612,6 +611,7 @@ def _process_matrix_log_case( return result +@dcite("10.1007/s10659-008-9169-x") def _matrix_log_33(T: torch.Tensor, dtype: torch.dtype = torch.float64) -> torch.Tensor: """Compute the logarithm of 3x3 matrix T based on its eigenvalue structure. diff --git a/torch_sim/optimizers/fire.py b/torch_sim/optimizers/fire.py index 00503a7b9..49ab9f2cf 100644 --- a/torch_sim/optimizers/fire.py +++ b/torch_sim/optimizers/fire.py @@ -6,6 +6,7 @@ import torch_sim as ts import torch_sim.math as fm +from torch_sim._duecredit import dcite from torch_sim.optimizers import cell_filters from torch_sim.state import SimState from torch_sim.typing import StateDict @@ -21,6 +22,7 @@ ) +@dcite("10.1103/PhysRevLett.97.170201") def fire_init( state: SimState | StateDict, model: "ModelInterface",