64 changes: 64 additions & 0 deletions examples/bb_batchwise_gradient_checkpointing.py
@@ -0,0 +1,64 @@
import sys, os
sys.path.insert(0, os.path.join(os.path.abspath(os.pardir),'src'))
from molearn.data import PDBData
from molearn.trainers import Sinkhorn_Trainer
from molearn.models.foldingnet import AutoEncoder
import torch



# This script is based on the sinkhorn example script.
# We may want a larger batch size for Sinkhorn training but may not have the GPU memory for it.
# Gradient checkpointing in the batch dimension gives us access to dramatically larger batch sizes
# in the same memory (we are trading compute for memory, so training may be a bit slower).
from molearn.utils import CheckpointBatch

class CustomAutoEncoder(AutoEncoder):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
        # Gradient checkpointing, but in the batch dimension.
        # I have been able to use batch sizes of up to 4000 structures (my entire dataset).
        # This wrapper is ideally used with models that don't use batch norm, which FoldingNet does use;
        # I would recommend replacing it with layer norm or group norm in the future.
        # A minimal sketch of the checkpointing idea follows this class.
self.decoder = CheckpointBatch(self.decoder, backward_batch_size = 16, forward_batch_size=16)
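
# A minimal sketch of the idea behind CheckpointBatch (illustrative only; the real molearn
# implementation may differ): run the wrapped module on micro-batches under
# torch.utils.checkpoint, so intermediate activations are recomputed during the backward pass
# instead of being stored, trading compute for memory. The class name and arguments below are
# hypothetical, for illustration only.
import torch.utils.checkpoint as cp

class BatchwiseCheckpoint(torch.nn.Module):
    def __init__(self, module, micro_batch_size=16):
        super().__init__()
        self.module = module
        self.micro_batch_size = micro_batch_size

    def forward(self, x):
        # Split the batch into micro-batches and checkpoint each one separately,
        # then concatenate the outputs back along the batch dimension.
        # (use_reentrant=False requires a reasonably recent PyTorch.)
        chunks = x.split(self.micro_batch_size, dim=0)
        out = [cp.checkpoint(self.module, chunk, use_reentrant=False) for chunk in chunks]
        return torch.cat(out, dim=0)

# e.g. self.decoder = BatchwiseCheckpoint(self.decoder, micro_batch_size=16) would behave
# similarly to the CheckpointBatch wrapper used above.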




if __name__ == '__main__':

##### Load Data #####
data = PDBData()
data.import_pdb('data/MurD_closed_selection.pdb')
data.import_pdb('data/MurD_open_selection.pdb')
data.fix_terminal()
data.atomselect(atoms = ['CA', 'C', 'N', 'CB', 'O'])

##### Prepare Trainer #####
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
trainer = Sinkhorn_Trainer(device=device)

trainer.set_data(data, batch_size=8, validation_split=0.1, manual_seed = 25)
trainer.prepare_physics(remove_NB = True)

trainer.set_autoencoder(CustomAutoEncoder, out_points = data.dataset.shape[-1])
trainer.prepare_optimiser()


##### Training Loop #####
    # Keep training until the loss has not improved for 32 consecutive epochs

runkwargs = dict(
log_filename='log_file.dat',
log_folder='xbb_sinkhorn_checkpoints',
checkpoint_folder='xbb_sinkhorn_checkpoints',
)

best = 1e24
while True:
        trainer.run(max_epochs=32 + trainer.epoch, **runkwargs)
        if not best > trainer.best:
break
best = trainer.best
print(f'best {trainer.best}, best_filename {trainer.best_name}')
68 changes: 68 additions & 0 deletions examples/bb_sinkhorn_trainer.py
@@ -0,0 +1,68 @@
import sys, os
sys.path.insert(0, os.path.join(os.path.abspath(os.pardir),'src'))
from molearn.data import PDBData
from molearn.trainers import Sinkhorn_Trainer
from molearn.models.foldingnet import AutoEncoder
import torch


from geomloss import SamplesLoss


#This outlines how Sinkhorn_Trainer is implemented in molearn.trainers.sinkhorn_trainer.
#Only the decoder of the autoencoder is used,
#i.e. we are training a generator, but the autoencoder terminology is kept because these classes
#are subclassed from the original trainer for training autoencoders.
'''
#define loss function
self.sinkhorn = SamplesLoss(**kwargs)

#### Sample from a normal distribution
z = torch.randn(batch.shape[0], self.latent_dim, 1).to(self.device)

#### Decode those latent structures
structures = self.autoencoder.decode(z)[:,:,:batch.shape[2]]

#### Calculate the Sinkhorn distance between real structures and generated structures
loss = self.sinkhorn(structures.reshape(structures.size(0), -1), batch.reshape(batch.size(0),-1))

#### We also calculate a physical energy loss and add it
final_loss = results['sinkhorn']+scale*results['physics_loss']
'''
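
# A self-contained sketch of one such generator training step, assembled from the snippets
# above (illustrative only; the actual logic lives inside Sinkhorn_Trainer, and the names
# `decoder`, `batch`, `latent_dim`, `physics_loss` and `scale` are assumptions here).
def sinkhorn_generator_step(decoder, batch, latent_dim, physics_loss=0.0, scale=0.1):
    # geomloss Sinkhorn divergence between point clouds / flattened coordinates
    sinkhorn = SamplesLoss(loss='sinkhorn', p=2, blur=0.05)
    # sample one latent vector per structure in the batch from a standard normal
    z = torch.randn(batch.shape[0], latent_dim, 1, device=batch.device)
    # decode the latent vectors into generated structures, cropped to the batch's atom count
    structures = decoder(z)[:, :, :batch.shape[2]]
    # Sinkhorn distance between flattened generated and real structures
    sinkhorn_loss = sinkhorn(structures.reshape(structures.size(0), -1),
                             batch.reshape(batch.size(0), -1))
    # add the scaled physics term, as in the trainer
    return sinkhorn_loss + scale * physics_loss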

if __name__ == '__main__':

##### Load Data #####
data = PDBData()
data.import_pdb('data/MurD_closed_selection.pdb')
data.import_pdb('data/MurD_open_selection.pdb')
data.fix_terminal()
data.atomselect(atoms = ['CA', 'C', 'N', 'CB', 'O'])

##### Prepare Trainer #####
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
trainer = Sinkhorn_Trainer(device=device)

trainer.set_data(data, batch_size=8, validation_split=0.1, manual_seed = 25)
trainer.prepare_physics(remove_NB = True)

trainer.set_autoencoder(AutoEncoder, out_points = data.dataset.shape[-1])
trainer.prepare_optimiser()


##### Training Loop #####
    # Keep training until the loss has not improved for 32 consecutive epochs

runkwargs = dict(
log_filename='log_file.dat',
log_folder='xbb_sinkhorn_checkpoints',
checkpoint_folder='xbb_sinkhorn_checkpoints',
)

best = 1e24
while True:
        trainer.run(max_epochs=32 + trainer.epoch, **runkwargs)
        if not best > trainer.best:
break
best = trainer.best
print(f'best {trainer.best}, best_filename {trainer.best_name}')
15 changes: 15 additions & 0 deletions src/molearn/trainers/deprecated/__init__.py
@@ -0,0 +1,15 @@
# Copyright (c) 2022 Samuel C. Musson
#
# Molearn is free software ;
# you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation ;
# either version 2 of the License, or (at your option) any later version.
# Molearn is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY ;
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License for more details.
# You should have received a copy of the GNU General Public License along with molearn ;
# if not, write to the Free Software Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
"""
Deprecated trainers; holds classes for training networks
"""

