diff --git a/mlcolvar/cvs/cv.py b/mlcolvar/cvs/cv.py index 7bfd95c5..17c989d5 100644 --- a/mlcolvar/cvs/cv.py +++ b/mlcolvar/cvs/cv.py @@ -47,6 +47,7 @@ def __init__( self._optimizer_name = "Adam" self.optimizer_kwargs = {} self.lr_scheduler_kwargs = {} + self.lr_scheduler_config = {} # PRE/POST self.preprocessing = preprocessing @@ -88,6 +89,8 @@ def parse_options(self, options: dict = None): self.optimizer_kwargs.update(options[o]) elif o == "lr_scheduler": self.lr_scheduler_kwargs.update(options[o]) + elif o == "lr_scheduler_config": + self.lr_scheduler_config.update(options[o]) else: raise ValueError( f'The key {o} is not available in this class. The available keys are: {", ".join(self.BLOCKS)}, optimizer and lr_scheduler.' @@ -192,24 +195,43 @@ def optimizer_name(self, optimizer_name: str): def configure_optimizers(self): """ Initialize the optimizer based on self._optimizer_name and self.optimizer_kwargs. + It also adds the learning rate scheduler if self.lr_scheduler_kwargs is not empty. + The scheduler is given as a dictionary with the key 'scheduler' containing the scheduler class + and the rest of the keys are config options for the scheduler. Returns ------- torch.optim Torch optimizer + + dict, optional + Learning rate scheduler configuration (if any) """ + # Create the optimizer from the optimizer name and kwargs optimizer = getattr(torch.optim, self._optimizer_name)( self.parameters(), **self.optimizer_kwargs ) - - if self.lr_scheduler_kwargs: - scheduler_cls = self.lr_scheduler_kwargs['scheduler'] - scheduler_kwargs = {k: v for k, v in self.lr_scheduler_kwargs.items() if k != 'scheduler'} - lr_scheduler = scheduler_cls(optimizer, **scheduler_kwargs) - return [optimizer] , [lr_scheduler] - else: + + # Return just the optimizer if no scheduler is defined + if not self.lr_scheduler_kwargs: return optimizer + + # Create the scheduler from the lr_scheduler_kwargs if any + scheduler_cls = self.lr_scheduler_kwargs['scheduler'] + scheduler_kwargs = {k: v for k, v in self.lr_scheduler_kwargs.items() if k != 'scheduler'} + print("So far so good") + print(scheduler_cls) + print(scheduler_kwargs) + lr_scheduler = scheduler_cls(optimizer, **scheduler_kwargs) + lr_scheduler_config = { + "scheduler": lr_scheduler + } + + # Add possible additional config options + if self.lr_scheduler_config: + lr_scheduler_config.update(self.lr_scheduler_config) + return [optimizer], [lr_scheduler_config] def __setattr__(self, key, value): # PyTorch overrides __setattr__ to raise a TypeError when you try to assign