Skip to content

Commit 4fb899c

Browse files
committed
fixed python 3.11 compatibility
fixed torch dependency resolving
1 parent 7d06739 commit 4fb899c

File tree

5 files changed

+39
-42
lines changed

5 files changed

+39
-42
lines changed

chytorch/utils/data/_abc.py

Lines changed: 1 addition & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -20,40 +20,10 @@
2020
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
2121
# SOFTWARE.
2222
#
23-
from typing import NamedTuple, NamedTupleMeta
24-
2523
try:
2624
from torch.utils.data._utils.collate import default_collate_fn_map
2725
except ImportError: # ad-hoc for pytorch<1.13
2826
default_collate_fn_map = {}
2927

3028

31-
# https://stackoverflow.com/a/50369521
32-
if hasattr(NamedTuple, '__mro_entries__'):
33-
# Python 3.9 fixed and broke multiple inheritance in a different way
34-
# see https://github.com/python/cpython/issues/88089
35-
from typing import _NamedTuple
36-
37-
NamedTuple = _NamedTuple
38-
39-
40-
class MultipleInheritanceNamedTupleMeta(NamedTupleMeta):
41-
def __new__(mcls, typename, bases, ns):
42-
if NamedTuple in bases:
43-
base = super().__new__(mcls, '_base_' + typename, bases, ns)
44-
bases = (base, *(b for b in bases if not isinstance(b, NamedTuple)))
45-
return super(NamedTupleMeta, mcls).__new__(mcls, typename, bases, ns)
46-
47-
48-
class DataTypeMixin(metaclass=MultipleInheritanceNamedTupleMeta):
49-
def to(self, *args, **kwargs):
50-
return type(self)(*(x.to(*args, **kwargs) for x in self))
51-
52-
def cpu(self, *args, **kwargs):
53-
return type(self)(*(x.cpu(*args, **kwargs) for x in self))
54-
55-
def cuda(self, *args, **kwargs):
56-
return type(self)(*(x.cuda(*args, **kwargs) for x in self))
57-
58-
59-
__all__ = ['DataTypeMixin', 'NamedTuple', 'default_collate_fn_map']
29+
__all__ = ['default_collate_fn_map']

chytorch/utils/data/molecule/conformer.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@
2828
from torch.nn.utils.rnn import pad_sequence
2929
from torch.utils.data import Dataset
3030
from torchtyping import TensorType
31-
from typing import Sequence, Tuple, Union
32-
from .._abc import DataTypeMixin, NamedTuple, default_collate_fn_map
31+
from typing import Sequence, Tuple, Union, NamedTuple
32+
from .._abc import default_collate_fn_map
3333

3434

3535
class ConformerDataPoint(NamedTuple):
@@ -38,11 +38,20 @@ class ConformerDataPoint(NamedTuple):
3838
distances: TensorType['atoms', 'atoms', int]
3939

4040

41-
class ConformerDataBatch(NamedTuple, DataTypeMixin):
41+
class ConformerDataBatch(NamedTuple):
4242
atoms: TensorType['batch', 'atoms', int]
4343
hydrogens: TensorType['batch', 'atoms', int]
4444
distances: TensorType['batch', 'atoms', 'atoms', int]
4545

46+
def to(self, *args, **kwargs):
47+
return ConformerDataBatch(*(x.to(*args, **kwargs) for x in self))
48+
49+
def cpu(self, *args, **kwargs):
50+
return ConformerDataBatch(*(x.cpu(*args, **kwargs) for x in self))
51+
52+
def cuda(self, *args, **kwargs):
53+
return ConformerDataBatch(*(x.cuda(*args, **kwargs) for x in self))
54+
4655

4756
def collate_conformers(batch, *, padding_left: bool = False, collate_fn_map=None) -> ConformerDataBatch:
4857
"""

chytorch/utils/data/molecule/encoder.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@
2727
from torch.nn.utils.rnn import pad_sequence
2828
from torch.utils.data import Dataset
2929
from torchtyping import TensorType
30-
from typing import Sequence, Union
30+
from typing import Sequence, Union, NamedTuple
3131
from zlib import decompress
32-
from .._abc import DataTypeMixin, NamedTuple, default_collate_fn_map
32+
from .._abc import default_collate_fn_map
3333

3434

3535
class MoleculeDataPoint(NamedTuple):
@@ -38,11 +38,20 @@ class MoleculeDataPoint(NamedTuple):
3838
distances: TensorType['atoms', 'atoms', int]
3939

4040

41-
class MoleculeDataBatch(NamedTuple, DataTypeMixin):
41+
class MoleculeDataBatch(NamedTuple):
4242
atoms: TensorType['batch', 'atoms', int]
4343
neighbors: TensorType['batch', 'atoms', int]
4444
distances: TensorType['batch', 'atoms', 'atoms', int]
4545

46+
def to(self, *args, **kwargs):
47+
return MoleculeDataBatch(*(x.to(*args, **kwargs) for x in self))
48+
49+
def cpu(self, *args, **kwargs):
50+
return MoleculeDataBatch(*(x.cpu(*args, **kwargs) for x in self))
51+
52+
def cuda(self, *args, **kwargs):
53+
return MoleculeDataBatch(*(x.cuda(*args, **kwargs) for x in self))
54+
4655

4756
def collate_molecules(batch, *, padding_left: bool = False, collate_fn_map=None) -> MoleculeDataBatch:
4857
"""

chytorch/utils/data/reaction/encoder.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@
2626
from torch.nn.utils.rnn import pad_sequence
2727
from torch.utils.data import Dataset
2828
from torchtyping import TensorType
29-
from typing import Sequence, Union
29+
from typing import Sequence, Union, NamedTuple
3030
from ..molecule import MoleculeDataset
31-
from .._abc import DataTypeMixin, NamedTuple, default_collate_fn_map
31+
from .._abc import default_collate_fn_map
3232

3333

3434
class ReactionEncoderDataPoint(NamedTuple):
@@ -38,12 +38,21 @@ class ReactionEncoderDataPoint(NamedTuple):
3838
roles: TensorType['atoms', int]
3939

4040

41-
class ReactionEncoderDataBatch(NamedTuple, DataTypeMixin):
41+
class ReactionEncoderDataBatch(NamedTuple):
4242
atoms: TensorType['batch', 'atoms', int]
4343
neighbors: TensorType['batch', 'atoms', int]
4444
distances: TensorType['batch', 'atoms', 'atoms', int]
4545
roles: TensorType['batch', 'atoms', int]
4646

47+
def to(self, *args, **kwargs):
48+
return ReactionEncoderDataBatch(*(x.to(*args, **kwargs) for x in self))
49+
50+
def cpu(self, *args, **kwargs):
51+
return ReactionEncoderDataBatch(*(x.cpu(*args, **kwargs) for x in self))
52+
53+
def cuda(self, *args, **kwargs):
54+
return ReactionEncoderDataBatch(*(x.cuda(*args, **kwargs) for x in self))
55+
4756

4857
def collate_encoded_reactions(batch, *, padding_left: bool = False, collate_fn_map=None) -> ReactionEncoderDataBatch:
4958
"""

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = 'chytorch'
3-
version = '1.59'
3+
version = '1.60'
44
description = 'Library for modeling molecules and reactions in torch way'
55
authors = ['Ramil Nugmanov <nougmanoff@protonmail.com>']
66
license = 'MIT'
@@ -32,7 +32,7 @@ python = '>=3.8,<3.12'
3232
torchtyping = '^0.1.4'
3333
chython = '^1.70'
3434
scipy = '^1.10'
35-
torch = '>=2.0,>=1.8'
35+
torch = '>=1.8'
3636
lmdb = {version='^1.4.1', optional = true}
3737
psycopg2-binary = {version='^2.9', optional = true}
3838
rdkit = {version = '^2023.9.1', optional = true}

0 commit comments

Comments
 (0)