From 1b06dcd1e67585903d0190b6716666dcf02c0b95 Mon Sep 17 00:00:00 2001
From: PabloNA97
Date: Sat, 23 Aug 2025 17:38:58 +0200
Subject: [PATCH] feat: extend the BaseCV class to support learning rate
schedulers that monitor a metric
To use schedulers such as ReduceLROnPlateau or OneCycleLR, configure_optimizers must return, alongside the optimizer, a dictionary (lr_scheduler_config) containing the scheduler and its associated configuration (e.g. the metric to monitor).
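For example, a ReduceLROnPlateau scheduler that monitors a validation metric can be requested through the options dictionary, as sketched below (the CV class, the layer sizes and the metric name 'valid_loss' are illustrative assumptions, not part of this patch):

    import torch
    from mlcolvar.cvs import AutoEncoderCV

    options = {
        # scheduler class plus its constructor keyword arguments
        "lr_scheduler": {
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau,
            "factor": 0.5,
            "patience": 10,
        },
        # Lightning lr_scheduler_config options, e.g. the metric to monitor
        "lr_scheduler_config": {
            "monitor": "valid_loss",  # assumed name of the logged validation metric
            "interval": "epoch",
            "frequency": 1,
        },
    }

    cv = AutoEncoderCV(encoder_layers=[10, 5, 2], options=options)
    # With these options configure_optimizers() returns
    # ([optimizer], [{"scheduler": ReduceLROnPlateau(...), "monitor": "valid_loss", ...}])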
---
mlcolvar/cvs/cv.py | 35 +++++++++++++++++++++++++++--------
1 file changed, 27 insertions(+), 8 deletions(-)
diff --git a/mlcolvar/cvs/cv.py b/mlcolvar/cvs/cv.py
index 7bfd95c5..17c989d5 100644
--- a/mlcolvar/cvs/cv.py
+++ b/mlcolvar/cvs/cv.py
@@ -47,6 +47,7 @@ def __init__(
self._optimizer_name = "Adam"
self.optimizer_kwargs = {}
self.lr_scheduler_kwargs = {}
+ self.lr_scheduler_config = {}
# PRE/POST
self.preprocessing = preprocessing
@@ -88,6 +89,8 @@ def parse_options(self, options: dict = None):
self.optimizer_kwargs.update(options[o])
elif o == "lr_scheduler":
self.lr_scheduler_kwargs.update(options[o])
+ elif o == "lr_scheduler_config":
+ self.lr_scheduler_config.update(options[o])
else:
raise ValueError(
- f'The key {o} is not available in this class. The available keys are: {", ".join(self.BLOCKS)}, optimizer and lr_scheduler.'
+ f'The key {o} is not available in this class. The available keys are: {", ".join(self.BLOCKS)}, optimizer, lr_scheduler and lr_scheduler_config.'
@@ -192,24 +195,40 @@ def optimizer_name(self, optimizer_name: str):
def configure_optimizers(self):
"""
Initialize the optimizer based on self._optimizer_name and self.optimizer_kwargs.
+ If self.lr_scheduler_kwargs is not empty, a learning rate scheduler is also created:
+ its key 'scheduler' must hold the scheduler class, while the remaining keys are passed
+ to the scheduler constructor as keyword arguments.
Returns
-------
torch.optim
Torch optimizer
+
+ dict, optional
+ Learning rate scheduler configuration for Lightning (only if a scheduler is defined)
"""
+ # Create the optimizer from the optimizer name and kwargs
optimizer = getattr(torch.optim, self._optimizer_name)(
self.parameters(), **self.optimizer_kwargs
)
-
- if self.lr_scheduler_kwargs:
- scheduler_cls = self.lr_scheduler_kwargs['scheduler']
- scheduler_kwargs = {k: v for k, v in self.lr_scheduler_kwargs.items() if k != 'scheduler'}
- lr_scheduler = scheduler_cls(optimizer, **scheduler_kwargs)
- return [optimizer] , [lr_scheduler]
- else:
+
+ # Return just the optimizer if no scheduler is defined
+ if not self.lr_scheduler_kwargs:
return optimizer
+
+ # Otherwise create the scheduler: 'scheduler' holds the class, the rest are constructor kwargs
+ scheduler_cls = self.lr_scheduler_kwargs['scheduler']
+ scheduler_kwargs = {k: v for k, v in self.lr_scheduler_kwargs.items() if k != 'scheduler'}
+ lr_scheduler = scheduler_cls(optimizer, **scheduler_kwargs)
+ lr_scheduler_config = {
+ "scheduler": lr_scheduler
+ }
+
+ # Add any additional Lightning config options (e.g. 'monitor' for schedulers that track a metric)
+ if self.lr_scheduler_config:
+ lr_scheduler_config.update(self.lr_scheduler_config)
+ return [optimizer], [lr_scheduler_config]
def __setattr__(self, key, value):
# PyTorch overrides __setattr__ to raise a TypeError when you try to assign