Skip to content

Commit ad53898

Browse files
committed
Transformed the reaction subpackage into a module; dropped support for PyTorch 1.x; added a LLaMA-style MLP.
1 parent d10859e commit ad53898

File tree

13 files changed

+143
-139
lines changed

13 files changed

+143
-139
lines changed
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22
#
3-
# Copyright 2021-2023 Ramil Nugmanov <nougmanoff@protonmail.com>
3+
# Copyright 2021-2024 Ramil Nugmanov <nougmanoff@protonmail.com>
44
#
55
# Permission is hereby granted, free of charge, to any person obtaining a copy
66
# of this software and associated documentation files (the “Software”), to deal
@@ -24,9 +24,9 @@
2424
from torch import zeros_like, float as t_float
2525
from torch.nn import Embedding, GELU, Module
2626
from torchtyping import TensorType
27-
from ..molecule import MoleculeEncoder
28-
from ..transformer import EncoderLayer
29-
from ...utils.data import ReactionEncoderDataBatch
27+
from .molecule import MoleculeEncoder
28+
from .transformer import EncoderLayer
29+
from ..utils.data import ReactionEncoderDataBatch
3030

3131

3232
class ReactionEncoder(Module):

chytorch/nn/reaction/__init__.py

Lines changed: 0 additions & 26 deletions
This file was deleted.

chytorch/nn/transformer/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -*- coding: utf-8 -*-
22
#
3-
# Copyright 2023 Ramil Nugmanov <nougmanoff@protonmail.com>
3+
# Copyright 2023, 2024 Ramil Nugmanov <nougmanoff@protonmail.com>
44
#
55
# Permission is hereby granted, free of charge, to any person obtaining a copy
66
# of this software and associated documentation files (the “Software”), to deal
@@ -25,4 +25,6 @@
2525

2626

2727
__all__ = ['EncoderLayer',
28+
'MLP',
29+
'LLaMAMLP',
2830
'GraphormerAttention']

chytorch/nn/transformer/encoder.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
# SOFTWARE.
2222
#
2323
from torch import Tensor, nn
24-
from torch.nn import Dropout, GELU, LayerNorm, Module
24+
from torch.nn import Dropout, GELU, LayerNorm, Module, SiLU
2525
from typing import Tuple, Optional, Type
2626
from warnings import warn
2727
from .attention import GraphormerAttention
@@ -54,6 +54,23 @@ def forward(self, x):
5454
return self.linear2(self.dropout(self.activation(self.linear1(x))))
5555

5656

57+
class LLaMAMLP(Module):
58+
def __init__(self, d_model, dim_feedforward, dropout=0.1, activation=SiLU, bias: bool = False):
59+
super().__init__()
60+
self.linear1 = Linear(d_model, dim_feedforward, bias=bias)
61+
self.linear2 = Linear(d_model, dim_feedforward, bias=bias)
62+
self.linear3 = Linear(dim_feedforward, d_model, bias=bias)
63+
self.dropout = Dropout(dropout)
64+
65+
# ad-hoc for resolving class from name
66+
if isinstance(activation, str):
67+
activation = getattr(nn, activation)
68+
self.activation = activation()
69+
70+
def forward(self, x):
71+
return self.linear3(self.dropout(self.activation(self.linear1(x))) * self.linear2(x))
72+
73+
5774
class EncoderLayer(Module):
5875
r"""EncoderLayer based on torch.nn.TransformerEncoderLayer, but batch always first and returns also attention.
5976
@@ -96,4 +113,4 @@ def forward(self, x: Tensor, attn_mask: Optional[Tensor], *,
96113
return None, a
97114

98115

99-
__all__ = ['EncoderLayer', 'MLP']
116+
__all__ = ['EncoderLayer', 'MLP', 'LLaMAMLP']

chytorch/utils/data/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from ._utils import *
3131

3232

33-
__all__ = ['MoleculeDataset', 'collate_molecules',
33+
__all__ = ['MoleculeDataset', 'collate_molecules', 'left_padded_collate_molecules',
3434
'ConformerDataset', 'collate_conformers',
3535
'ReactionEncoderDataset', 'collate_encoded_reactions',
3636
'RDKitConformerDataset',

chytorch/utils/data/_abc.py

Lines changed: 0 additions & 29 deletions
This file was deleted.

chytorch/utils/data/molecule/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626
from .rdkit import *
2727

2828

29-
__all__ = ['MoleculeDataset', 'MoleculeDataPoint', 'MoleculeDataBatch', 'collate_molecules',
29+
__all__ = ['MoleculeDataset', 'MoleculeDataPoint', 'MoleculeDataBatch',
30+
'collate_molecules', 'left_padded_collate_molecules',
3031
'ConformerDataset', 'ConformerDataPoint', 'ConformerDataBatch', 'collate_conformers',
3132
'RDKitConformerDataset',
3233
'thiacalix_n_arene_dataset']

chytorch/utils/data/molecule/_unpack.pyx

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,14 +79,12 @@ def unpack(const unsigned char[::1] data not None, unsigned short add_cls, unsig
7979

8080
cdef cnp.ndarray[DTYPE_t, ndim=1] atoms, neighbors
8181
cdef cnp.ndarray[DTYPE_t, ndim=2] distance
82-
cdef DTYPE_t d, attention
82+
cdef DTYPE_t d
8383

8484
# read header
8585
if data[0] != 2:
8686
raise ValueError('invalid pack version')
8787

88-
attention = 1 if components_attention else 0
89-
9088
a, b, c = data[1], data[2], data[3]
9189
atoms_count = (a << 4| b >> 4) + add_cls
9290
cis_trans_count = (b & 0x0f) << 8 | c
@@ -170,7 +168,7 @@ def unpack(const unsigned char[::1] data not None, unsigned short add_cls, unsig
170168
d = distance[i, j]
171169
if d == 9999:
172170
# set attention between subgraphs
173-
distance[i, j] = distance[j, i] = attention
171+
distance[i, j] = distance[j, i] = components_attention
174172
elif d > max_distance:
175173
distance[i, j] = distance[j, i] = max_distance + 2
176174
else:

chytorch/utils/data/molecule/conformer.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@
2727
from torch import IntTensor, Size, zeros, ones as t_ones, int32 as t_int32, eye
2828
from torch.nn.utils.rnn import pad_sequence
2929
from torch.utils.data import Dataset
30+
from torch.utils.data._utils.collate import default_collate_fn_map
3031
from torchtyping import TensorType
3132
from typing import Sequence, Tuple, Union, NamedTuple
32-
from .._abc import default_collate_fn_map
3333

3434

3535
class ConformerDataPoint(NamedTuple):
@@ -53,7 +53,7 @@ def cuda(self, *args, **kwargs):
5353
return ConformerDataBatch(*(x.cuda(*args, **kwargs) for x in self))
5454

5555

56-
def collate_conformers(batch, *, padding_left: bool = False, collate_fn_map=None) -> ConformerDataBatch:
56+
def collate_conformers(batch, *, collate_fn_map=None) -> ConformerDataBatch:
5757
"""
5858
Prepares batches of conformers.
5959
@@ -62,25 +62,16 @@ def collate_conformers(batch, *, padding_left: bool = False, collate_fn_map=None
6262
atoms, hydrogens, distances = [], [], []
6363

6464
for a, h, d in batch:
65-
if padding_left:
66-
atoms.append(a.flipud())
67-
hydrogens.append(h.flipud())
68-
else:
69-
atoms.append(a)
70-
hydrogens.append(h)
65+
atoms.append(a)
66+
hydrogens.append(h)
7167
distances.append(d)
7268

7369
pa = pad_sequence(atoms, True)
7470
b, s = pa.shape
7571
tmp = eye(s, dtype=t_int32).repeat(b, 1, 1) # prevent nan in MHA softmax on padding
7672
for i, d in enumerate(distances):
7773
s = d.size(0)
78-
if padding_left:
79-
tmp[i, -s:, -s:] = d
80-
else:
81-
tmp[i, :s, :s] = d
82-
if padding_left:
83-
return ConformerDataBatch(pa.fliplr(), pad_sequence(hydrogens, True).fliplr(), tmp)
74+
tmp[i, :s, :s] = d
8475
return ConformerDataBatch(pa, pad_sequence(hydrogens, True), tmp)
8576

8677

@@ -144,10 +135,9 @@ def __getitem__(self, item: int) -> ConformerDataPoint:
144135
atoms = IntTensor(len(mol))
145136
hydrogens = IntTensor(len(mol))
146137

147-
hgs = mol._hydrogens # noqa
148138
for i, (n, a) in enumerate(mol.atoms(), self.add_cls):
149139
atoms[i] = a.atomic_number + 2
150-
hydrogens[i] = (hgs[n] or 0) + 2
140+
hydrogens[i] = (a.implicit_hydrogens or 0) + 2
151141

152142
xyz = empty((len(mol), 3))
153143
conformer = mol._conformers[0] # noqa

0 commit comments

Comments
 (0)