diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2ccd5f9..8726d86 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,37 +1,41 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v5.0.0 + rev: v6.0.0 hooks: - id: end-of-file-fixer - - id: fix-encoding-pragma - args: [ --remove ] - id: mixed-line-ending - id: trailing-whitespace - id: check-yaml + - repo: https://github.com/asottile/pyupgrade + rev: v3.21.0 + hooks: + - id: pyupgrade + args: [ --py310-plus ] + - repo: https://github.com/ikamensh/flynt/ - rev: '1.0.1' + rev: '1.0.6' hooks: - id: flynt - repo: https://github.com/psf/black - rev: 25.1.0 + rev: 25.9.0 hooks: - id: black exclude: (.*)/migrations - repo: https://github.com/pycqa/flake8 - rev: 7.2.0 + rev: 7.3.0 hooks: - id: flake8 - repo: https://github.com/pycqa/isort - rev: '6.0.1' + rev: '7.0.0' hooks: - id: isort - repo: https://github.com/PyCQA/bandit - rev: 1.8.3 + rev: 1.8.6 hooks: - id: bandit args: [ "-c", "pyproject.toml" ] @@ -40,7 +44,7 @@ repos: - repo: https://github.com/PyCQA/pylint # Configuration help can be found here: # https://pylint.pycqa.org/en/latest/user_guide/installation/pre-commit-integration.html - rev: v3.3.6 + rev: v4.0.2 hooks: - id: pylint alias: pylint-with-spelling @@ -55,7 +59,7 @@ repos: )$ - repo: https://github.com/commitizen-tools/commitizen - rev: v4.6.0 + rev: v4.9.1 hooks: - id: commitizen stages: [ commit-msg ] @@ -66,6 +70,6 @@ repos: - id: nb-clean - repo: https://github.com/pycqa/doc8 - rev: v1.1.2 + rev: v2.0.0 hooks: - id: doc8 diff --git a/pyproject.toml b/pyproject.toml index 80a402d..69ceca8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ dependencies = [ 'orbax-checkpoint', 'e3nn-jax', "equinox", - "reax>=0.2.0", + "reax>=0.2,<0.6", "tensorial>=0.4.2", "pymatgen", ] diff --git a/src/e3response/_logging.py b/src/e3response/_logging.py index d3d4d9a..d8df348 100644 --- a/src/e3response/_logging.py +++ b/src/e3response/_logging.py @@ -2,7 +2,7 @@ import os import shutil import tempfile -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING from typing_extensions import override @@ -12,7 +12,7 @@ class MlflowHandler(logging.Handler): def __init__( - self, client: "mlflow.tracking.MlflowClient", run_id: str, level: Union[int, str] = 0 + self, client: "mlflow.tracking.MlflowClient", run_id: str, level: int | str = 0 ) -> None: super().__init__(level) self._tempfile = None diff --git a/src/e3response/data/barium_titanate.py b/src/e3response/data/barium_titanate.py index 8db1059..beb9c36 100644 --- a/src/e3response/data/barium_titanate.py +++ b/src/e3response/data/barium_titanate.py @@ -1,9 +1,10 @@ +from collections.abc import Callable, Iterable, Sequence import functools import logging import pathlib import re import tarfile -from typing import Any, Callable, Final, Iterable, Optional, Sequence, Union +from typing import Any, Final import ase import ase.io @@ -33,13 +34,13 @@ class BtoDataModule(reax.DataModule): def __init__( self, r_max: float, - data_dir: Union[str, pathlib.Path] = "data/bto/", + data_dir: str | pathlib.Path = "data/bto/", archives: Sequence[str] = ( "BTO_Pm-3m_5atoms_400K_3x3x3_ensemble.tar.gz", "BTO_Pm-3m_5atoms_800K_3x3x3.tar.gz", ), tensors: tuple[str] = ("raman_tensors", "born_charges", "dielectric"), - train_val_test_split: Sequence[Union[int, float]] = (0.8, 0.1, 0.1), + train_val_test_split: Sequence[int | float] = (0.8, 0.1, 0.1), batch_size: int = 64, ) -> None: """Initialize a `SiliconDataModule`. @@ -55,14 +56,14 @@ def __init__( self._data_dir: Final[str] = str(data_dir) self._archives: Final[tuple[str, ...]] = tuple(archives) self._tensors = tensors - self._train_val_test_split: Final[Sequence[Union[int, float]]] = train_val_test_split + self._train_val_test_split: Final[Sequence[int | float]] = train_val_test_split self._batch_size: Final[int] = batch_size # State self.batch_size_per_device = batch_size - self.data_train: Optional[reax.data.Dataset] = None - self.data_val: Optional[reax.data.Dataset] = None - self.data_test: Optional[reax.data.Dataset] = None + self.data_train: reax.data.Dataset | None = None + self.data_val: reax.data.Dataset | None = None + self.data_test: reax.data.Dataset | None = None @override def setup(self, stage: "reax.Stage", /) -> None: @@ -235,7 +236,7 @@ def get_structures(root_dir: pathlib.Path, tensors: Iterable[str]) -> list[ase.A def read_scf(filename) -> ase.Atoms: - with open(filename, "r", encoding="utf-8") as fileobj: + with open(filename, encoding="utf-8") as fileobj: _data, card_lines = espresso.read_fortran_namelist(fileobj) cell, _ = espresso.get_cell_parameters(card_lines) diff --git a/src/e3response/data/qm9_nmr.py b/src/e3response/data/qm9_nmr.py index 519f578..7906356 100644 --- a/src/e3response/data/qm9_nmr.py +++ b/src/e3response/data/qm9_nmr.py @@ -1,4 +1,5 @@ import collections +from collections.abc import Callable, Sequence import functools from functools import lru_cache import logging @@ -6,7 +7,7 @@ import pathlib import re import tempfile -from typing import Any, Callable, Final, Optional, Sequence, Union +from typing import Any, Final import urllib.error import urllib.request import zipfile @@ -14,8 +15,8 @@ import ase import jraph import numpy as np -from pymatgen.io import gaussian # type: ignore -import pymatgen.io.ase # type: ignore +from pymatgen.io import gaussian +import pymatgen.io.ase import reax from tensorial import gcnn import tqdm @@ -58,9 +59,9 @@ def __init__( self, r_max: float = 5, data_dir: str = "data/qm9_nmr/", - dataset: Union[str, Sequence[str]] = "gasphase", - atom_keys: Optional[Union[str, Sequence[str]]] = None, - limit: Optional[int] = None, + dataset: str | Sequence[str] = "gasphase", + atom_keys: str | Sequence[str] | None = None, + limit: int | None = None, ) -> None: """ Initialize the QM9-NMR dataset. @@ -131,7 +132,7 @@ def __init__( try: with zipfile.ZipFile(archive_path, "r") as zip_ref: zip_ref.testzip() - except (zipfile.BadZipFile, zipfile.LargeZipFile, IOError) as e: + except (zipfile.BadZipFile, zipfile.LargeZipFile, OSError) as e: _LOGGER.warning( "%s is corrupted or unreadable: %s, removing corrupted archive ...", archive_name, @@ -189,7 +190,7 @@ def reporthook(_block_num, block_size, total_size): except OSError as e: _LOGGER.error("Filesystem error while writing %s: %s", path, e) - def _extract_archive_zip(self, zip_path: str, limit: Optional[int] = None) -> list: + def _extract_archive_zip(self, zip_path: str, limit: int | None = None) -> list: structures = [] @@ -229,7 +230,7 @@ def _create_molecule_data(log_file): structure = gaussian_output.final_structure # extraction of data from .log file - with open(log_file, "r", encoding="utf-8") as file: + with open(log_file, encoding="utf-8") as file: log_data = file.read() shielding_pattern = ( @@ -294,12 +295,12 @@ def _create_molecule_data(log_file): _LOGGER.error("Error in file %s: %s", log_file, e) raise - except (IOError, OSError) as e: + except OSError as e: _LOGGER.error("File system error while processing %s: %s", log_file, e) raise -def get_structure_and_data_from_log(log_path: pathlib.Path) -> Optional[ase.Atoms]: +def get_structure_and_data_from_log(log_path: pathlib.Path) -> ase.Atoms | None: # _LOGGER.info("Parsing Gaussian .log file: %s", log_path) try: @@ -333,7 +334,7 @@ def get_structure_and_data_from_log(log_path: pathlib.Path) -> Optional[ase.Atom return atoms - except (ValueError, IOError) as e: + except (ValueError, OSError) as e: _LOGGER.error("Parsing error for %s: %s", log_path, e) return None @@ -350,10 +351,10 @@ def __init__( self, r_max: float = 5, data_dir: str = "data/qm9_nmr/", - dataset: Union[str, Sequence[str]] = "gasphase", - atom_keys: Optional[Sequence[str]] = None, - limit: Optional[int] = None, - train_val_test_split: Sequence[Union[int, float]] = (0.85, 0.05, 0.1), + dataset: str | Sequence[str] = "gasphase", + atom_keys: Sequence[str] | None = None, + limit: int | None = None, + train_val_test_split: Sequence[int | float] = (0.85, 0.05, 0.1), batch_size: int = 64, ) -> None: """Initialize a QM9-NMR data module. @@ -371,19 +372,19 @@ def __init__( # Params self._data_dir: Final[str] = data_dir - self._dataset: Union[str, Sequence[str]] = dataset - self.dataset: Optional[Qm9NmrDataset] = None + self._dataset: str | Sequence[str] = dataset + self.dataset: Qm9NmrDataset | None = None self._rmax = r_max self._atom_keys = atom_keys self._limit = limit - self._train_val_test_split: Final[Sequence[Union[int, float]]] = train_val_test_split + self._train_val_test_split: Final[Sequence[int | float]] = train_val_test_split self._batch_size: Final[int] = batch_size # State self.batch_size_per_device = batch_size - self.data_train: Optional[reax.data.Dataset] = None - self.data_val: Optional[reax.data.Dataset] = None - self.data_test: Optional[reax.data.Dataset] = None + self.data_train: reax.data.Dataset | None = None + self.data_val: reax.data.Dataset | None = None + self.data_test: reax.data.Dataset | None = None @override def setup(self, stage: "reax.Stage", /) -> None: diff --git a/src/e3response/train.py b/src/e3response/train.py index df19a5c..3510573 100644 --- a/src/e3response/train.py +++ b/src/e3response/train.py @@ -1,6 +1,5 @@ import logging import pathlib -from typing import Optional import hydra from hydra.core import hydra_config @@ -102,7 +101,7 @@ def train(cfg: omegaconf.DictConfig): return metric_dict, object_dict -def main(cfg: omegaconf.DictConfig) -> Optional[float]: +def main(cfg: omegaconf.DictConfig) -> float | None: """Main entry point for training. :param cfg: DictConfig configuration composed by Hydra. diff --git a/src/e3response/utils/logging_utils.py b/src/e3response/utils/logging_utils.py index 7bc24dd..ae10ed2 100644 --- a/src/e3response/utils/logging_utils.py +++ b/src/e3response/utils/logging_utils.py @@ -1,4 +1,4 @@ -from typing import Any, Dict +from typing import Any import jax from lightning_utilities.core.rank_zero import rank_zero_only @@ -12,7 +12,7 @@ @rank_zero_only -def log_hyperparameters(object_dict: Dict[str, Any]) -> None: +def log_hyperparameters(object_dict: dict[str, Any]) -> None: """Controls which config parts are saved by Lightning loggers. Additionally, it saves: diff --git a/src/e3response/utils/pylogger.py b/src/e3response/utils/pylogger.py index 3e0147c..84e4b85 100644 --- a/src/e3response/utils/pylogger.py +++ b/src/e3response/utils/pylogger.py @@ -1,5 +1,5 @@ +from collections.abc import Mapping import logging -from typing import Mapping, Optional from lightning_utilities.core import rank_zero @@ -13,7 +13,7 @@ def __init__( self, name: str = __name__, rank_zero_only: bool = False, - extra: Optional[Mapping[str, object]] = None, + extra: Mapping[str, object] | None = None, ) -> None: """Initializes a multi-GPU-friendly python command line logger that logs on all processes with their rank prefixed in the log message. @@ -28,7 +28,7 @@ def __init__( super().__init__(logger=logger, extra=extra) self.rank_zero_only = rank_zero_only - def log(self, level: int, msg: str, *args, rank: Optional[int] = None, **kwargs) -> None: + def log(self, level: int, msg: str, *args, rank: int | None = None, **kwargs) -> None: """Delegate a log call to the underlying logger, after prefixing its message with the rank of the process it's being logged from. If `'rank'` is provided, then the log will only occur on that rank/process. diff --git a/src/e3response/utils/rich_utils.py b/src/e3response/utils/rich_utils.py index 6fa89fb..528a647 100644 --- a/src/e3response/utils/rich_utils.py +++ b/src/e3response/utils/rich_utils.py @@ -1,5 +1,5 @@ +from collections.abc import Sequence from pathlib import Path -from typing import Sequence from hydra.core.hydra_config import HydraConfig from lightning_utilities.core.rank_zero import rank_zero_only diff --git a/src/e3response/utils/utils.py b/src/e3response/utils/utils.py index 82e9fdc..707f192 100644 --- a/src/e3response/utils/utils.py +++ b/src/e3response/utils/utils.py @@ -1,5 +1,6 @@ +from collections.abc import Callable from importlib.util import find_spec -from typing import Any, Callable, Optional +from typing import Any import warnings from omegaconf import DictConfig @@ -99,7 +100,7 @@ def wrap(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]: return wrap -def get_metric_value(metric_dict: dict[str, Any], metric_name: Optional[str]) -> Optional[float]: +def get_metric_value(metric_dict: dict[str, Any], metric_name: str | None) -> float | None: """Safely retrieves value of the metric logged in reax.Module. :param metric_dict: A dict containing metric values.