From 47b3c23fe341ddb34acb50d860ec6ca8895eb9ab Mon Sep 17 00:00:00 2001
From: Sallyliubj
Date: Sun, 8 Mar 2026 14:20:46 -0400
Subject: [PATCH 1/2] implemented Exponentiated Gradient Descent

---
 rate/eg_optimizer.py | 26 +++++++++++++++
 rate/main.py         | 77 +++++++++++++++++++++++++++++++++++++++-----
 rate/model.py        | 46 +++++++++++++++++++++-----
 3 files changed, 133 insertions(+), 16 deletions(-)
 create mode 100644 rate/eg_optimizer.py

diff --git a/rate/eg_optimizer.py b/rate/eg_optimizer.py
new file mode 100644
index 0000000..cab974e
--- /dev/null
+++ b/rate/eg_optimizer.py
@@ -0,0 +1,26 @@
+"""Exponentiated Gradient Optimizer"""
+
+import torch
+
+class ExponentiatedGradient(torch.optim.Optimizer):
+    def __init__(self, params, lr=1e-2, weight_decay=0.0,
+                 eps=1e-12, max_norm=None, norm_axis=None):
+        defaults = dict(lr=lr, weight_decay=weight_decay,
+                        eps=eps, max_norm=max_norm, norm_axis=norm_axis)
+        super().__init__(params, defaults)
+
+    @torch.no_grad()
+    def step(self, closure=None):
+        for group in self.param_groups:
+            lr, wd, eps = group['lr'], group['weight_decay'], group['eps']
+            max_norm, norm_axis = group['max_norm'], group['norm_axis']
+            for p in group['params']:
+                if p.grad is None: continue
+                g = p.grad + wd * p if wd else p.grad
+                update = torch.exp(-lr * g).clamp_min(torch.finfo(p.dtype).tiny)
+                p.mul_(update).clamp_min_(eps)
+                # optional per-column norm control
+                if max_norm is not None:
+                    n = p.norm(dim=norm_axis, keepdim=True)
+                    scale = (max_norm / (n + 1e-12)).clamp_max(1.0)
+                    p.mul_(scale)
\ No newline at end of file
diff --git a/rate/main.py b/rate/main.py
index 722a5b0..27e410a 100644
--- a/rate/main.py
+++ b/rate/main.py
@@ -29,6 +29,9 @@
 from model import loss_op
 
+# Exponentiated Gradient optimizer
+from eg_optimizer import ExponentiatedGradient
+
 # Parse input arguments
 parser = argparse.ArgumentParser(description='Training rate RNNs')
 parser.add_argument('--gpu', required=False,
@@ -60,6 +63,12 @@
         type=str, default='sigmoid', help="Activation function (sigmoid, clipped_relu)")
 parser.add_argument("--loss_fn", required=True,
         type=str, default='l2', help="Loss function (either L1 or L2)")
+parser.add_argument("--optimizer", required=False,
+        type=str, default='adam', help="Optimizer to use (adam or eg)")
+parser.add_argument("--momentum", required=False,
+        type=float, default = 0.0, help="Momentum for the EG optimizer")
+parser.add_argument("--weight_decay", required=False,
+        type=float, default = 0.0, help="Weight decay for the EG optimizer")
 parser.add_argument("--apply_dale", required=True,
         type=str2bool, default='True', help="Apply Dale's principle?")
 parser.add_argument("--decay_taus", required=True,
@@ -165,14 +174,48 @@
         'eval_amp_threh': 0.3, # amplitude threshold during response window
         'activation': args.act.lower(), # activation function
         'loss_fn': args.loss_fn.lower(), # loss function ('L1' or 'L2')
-        'P_rec': 0.20
+        'P_rec': 0.20, # initial connectivity probability
+        'momentum': args.momentum, # momentum (alpha)
+        'weight_decay': args.weight_decay, # weight decay (gamma)
+        'optimizer': args.optimizer.lower(), # optimizer to use (`adam` or `eg`)
         }
 
 '''
 Set up optimizer
 '''
 if args.mode.lower() == 'train':
-    optimizer = optim.Adam(net.parameters(), lr=training_params['learning_rate'])
+    if args.optimizer.lower() == 'adam':
+        print('Using Adam optimizer for all parameters...')
+        optimizer = optim.Adam(net.parameters(), lr=training_params['learning_rate'])
+        training_params['optimizer'] = 'adam'
+
+    elif args.optimizer.lower() == 'eg':
+        print('Using Exponentiated Gradient (EG) for recurrent weights (w) and Adam for others...')
+        eg_params = []
+        adam_params = []
+
+        for name, param in net.named_parameters():
+            if name == 'w':
+                # Apply EG only to the recurrent weight magnitudes 'w'
+                eg_params.append(param)
+            else:
+                # Apply Adam to all other parameters (w_in, w_out, b_out)
+                adam_params.append(param)
+
+        # Create two optimizers
+        optimizer_eg = ExponentiatedGradient(eg_params,
+                                             lr=training_params['learning_rate'],
+                                             weight_decay=training_params.get('weight_decay', 0.0))
+
+        optimizer_adam = optim.Adam(adam_params, lr=training_params['learning_rate'])
+
+        # use a list to hold both and iterate during train step
+        optimizer = [optimizer_eg, optimizer_adam]
+        training_params['optimizer'] = 'eg'
+
+    else:
+        raise ValueError(f"Unknown optimizer: {args.optimizer}")
+
     print('Set up optimizer...')
 
 '''
@@ -200,7 +243,11 @@
     start_time = time.time()
 
     # Zero gradients
-    optimizer.zero_grad()
+    if isinstance(optimizer, list):
+        for opt in optimizer:
+            opt.zero_grad()
+    else:
+        optimizer.zero_grad()
 
     # Generate a task-specific input signal
     u, target, label = task.simulate_trial()
@@ -213,13 +260,19 @@
     # Forward pass
     stim, x, r, o, w, w_in, m, som_m, w_out, b_out, taus_gaus = \
         net.forward(u_tensor, settings['taus'], training_params, settings)
-    
+
     # Compute loss
     t_loss = loss_op(o, target, training_params)
 
     # Backward pass
     t_loss.backward()
-    optimizer.step()
+
+    # Optimizer step
+    if isinstance(optimizer, list):
+        for opt in optimizer:
+            opt.step()
+    else:
+        optimizer.step()
 
     print('Loss: ', t_loss.item())
     losses[tr] = t_loss.item()
@@ -404,15 +457,23 @@
     scipy.io.savemat(os.path.join(out_dir, fname), var)
 
     # Also save the PyTorch model
-    torch.save({
+    if isinstance(optimizer, list):
+        optimizer_state = {f'optimizer_{i}_state_dict': opt.state_dict()
+                           for i, opt in enumerate(optimizer)}
+    else:
+        optimizer_state = {'optimizer_state_dict': optimizer.state_dict()}
+
+    save_dict = {
         'model_state_dict': net.state_dict(),
-        'optimizer_state_dict': optimizer.state_dict(),
         'settings': settings,
         'training_params': training_params,
         'losses': losses,
         'final_loss': t_loss.item(),
         'trial': tr,
-    }, os.path.join(out_dir, fname.replace('.mat', '.pth')))
+    }
+    save_dict.update(optimizer_state)
+
+    torch.save(save_dict, os.path.join(out_dir, fname.replace('.mat', '.pth')))
 
 
diff --git a/rate/model.py b/rate/model.py
index fe5933a..d11afa8 100644
--- a/rate/model.py
+++ b/rate/model.py
@@ -138,7 +138,12 @@ def initialize_W(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
 
         if self.apply_dale == True:
            w = np.abs(w)
-    
+
+        # add a small positive value to the weights to avoid division by zero
+        # nonzero_mask = w > 0
+        # w[nonzero_mask] = np.maximum(w[nonzero_mask], 1e-8)
+
+        # Mask matrix
         mask = np.eye(self.N, dtype=np.float32)
         mask[np.where(self.inh==True)[0], np.where(self.inh==True)[0]] = -1
@@ -198,6 +203,22 @@ def display(self) -> None:
         print('\t Zero Weights: %2.2f %%' % (zero_w/(self.N*self.N)*100))
         print('\t Positive Weights: %2.2f %%' % (pos_w/(self.N*self.N)*100))
         print('\t Negative Weights: %2.2f %%' % (neg_w/(self.N*self.N)*100))
+
+    def project_weights(self, training_params: Dict[str, Any]) -> None:
+        """
+        Project weights to satisfy Dale's principle constraints (optional).
+
+        Note: With F.relu() in the forward pass, this projection is technically
+        optional since negative weights are masked during computation anyway.
+        However, it keeps the stored weights cleaner.
+
+        Args:
+            training_params (Dict[str, Any]): Training parameters including optimizer type.
+ """ + # if self.apply_dale == True and training_params.get('optimizer', 'adam') != 'eg': + # with torch.no_grad(): + # self.w.data.clamp_(min=0.0) + pass def forward(self, stim: torch.Tensor, taus: List[float], training_params: Dict[str, Any], settings: Dict[str, Any]) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor], @@ -254,14 +275,16 @@ def forward(self, stim: torch.Tensor, taus: List[float], training_params: Dict[s # Forward pass through time for t in range(1, T): - if self.apply_dale == True: + # When using EG optimizer, weights are non-negative by construction + # so no need for ReLU. For Adam optimizer with Dale's principle, use ReLU. + if self.apply_dale == True and training_params.get('optimizer', 'adam') != 'eg': # Parametrize the weight matrix to enforce exc/inh synaptic currents - w_pos = F.relu(self.w) + w = F.relu(self.w) else: - w_pos = self.w + w = self.w # Compute effective weight matrix - ww = torch.matmul(w_pos, self.mask) + ww = torch.matmul(w, self.mask) ww = ww * self.som_mask # Compute time constants @@ -412,7 +435,7 @@ def loss_op(o: List[torch.Tensor], z: Union[np.ndarray, torch.Tensor], training_ return loss def eval_rnn(net: FR_RNN_dale, settings: Dict[str, Any], u: np.ndarray, - device: torch.device) -> Tuple[List[float], List[np.ndarray], List[np.ndarray]]: + device: torch.device, training_params: Dict[str, Any] = None) -> Tuple[List[float], List[np.ndarray], List[np.ndarray]]: """ Evaluate a trained PyTorch RNN. @@ -421,6 +444,8 @@ def eval_rnn(net: FR_RNN_dale, settings: Dict[str, Any], u: np.ndarray, settings (Dict[str, Any]): Dictionary containing task settings. u (np.ndarray): Stimulus matrix. device (torch.device): PyTorch device. + training_params (Dict[str, Any], optional): Dictionary containing training parameters + including optimizer type. Defaults to None. Returns: Tuple[List[float], List[np.ndarray], List[np.ndarray]]: Tuple containing: @@ -432,6 +457,10 @@ def eval_rnn(net: FR_RNN_dale, settings: Dict[str, Any], u: np.ndarray, DeltaT = settings['DeltaT'] taus = settings['taus'] + # Default training_params if not provided + if training_params is None: + training_params = {'optimizer': 'adam'} + net.eval() with torch.no_grad(): u_tensor = torch.tensor(u, dtype=torch.float32, device=device) @@ -443,10 +472,11 @@ def eval_rnn(net: FR_RNN_dale, settings: Dict[str, Any], u: np.ndarray, o = [] for t in range(1, T): - if net.apply_dale: + # When using EG optimizer, weights are non-negative by construction + if net.apply_dale and training_params.get('optimizer', 'adam') != 'eg': w_pos = F.relu(net.w) else: - w_pos = net.w + w_pos = net.w # non-negative by construction ww = torch.matmul(w_pos, net.mask) ww = ww * net.som_mask From ef2dffea885dedd3ef0bc6bc2af8df869f2de266 Mon Sep 17 00:00:00 2001 From: Sallyliubj Date: Fri, 20 Mar 2026 00:23:38 -0400 Subject: [PATCH 2/2] remove unimplemented function --- rate/model.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/rate/model.py b/rate/model.py index d11afa8..95d6332 100644 --- a/rate/model.py +++ b/rate/model.py @@ -204,21 +204,6 @@ def display(self) -> None: print('\t Positive Weights: %2.2f %%' % (pos_w/(self.N*self.N)*100)) print('\t Negative Weights: %2.2f %%' % (neg_w/(self.N*self.N)*100)) - def project_weights(self, training_params: Dict[str, Any]) -> None: - """ - Project weights to satisfy Dale's principle constraints (optional). 
- - Note: With F.relu() in the forward pass, this projection is technically - optional since negative weights are masked during computation anyway. - However, it keeps the stored weights cleaner. - - Args: - training_params (Dict[str, Any]): Training parameters including optimizer type. - """ - # if self.apply_dale == True and training_params.get('optimizer', 'adam') != 'eg': - # with torch.no_grad(): - # self.w.data.clamp_(min=0.0) - pass def forward(self, stim: torch.Tensor, taus: List[float], training_params: Dict[str, Any], settings: Dict[str, Any]) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor],
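
The core idea behind eg_optimizer.py is a multiplicative update: each recurrent weight
magnitude is scaled by exp(-lr * grad), so a parameter that starts positive stays
positive, which is why the forward pass can skip the F.relu() projection on `w` when
the 'eg' optimizer is selected. A minimal standalone sketch of that step follows; the
tensor values are illustrative only and not taken from the patch:

    import torch

    lr = 1e-2
    w = torch.tensor([0.5, 0.1, 0.8])        # non-negative recurrent magnitudes
    grad = torch.tensor([0.3, -0.2, 0.0])    # gradient of the loss w.r.t. w

    w = w * torch.exp(-lr * grad)            # multiplicative (EG) update
    w = w.clamp_min(1e-12)                   # floor, mirroring the optimizer's eps
    print(w)                                 # every entry remains strictly positive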
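Because --optimizer eg builds a plain Python list holding one ExponentiatedGradient
instance (for 'w') and one Adam instance (for everything else), the training loop zeroes
and steps both by iterating over that list. A self-contained sketch of the pattern is
below; it assumes eg_optimizer.py is importable, and TinyNet with its parameters other
than 'w' is made up for illustration rather than taken from rate/model.py:

    import torch
    import torch.optim as optim
    from eg_optimizer import ExponentiatedGradient  # assumes rate/ is on the path

    class TinyNet(torch.nn.Module):
        def __init__(self, n=4):
            super().__init__()
            self.w = torch.nn.Parameter(torch.rand(n, n))       # recurrent magnitudes, kept >= 0 by EG
            self.w_out = torch.nn.Parameter(torch.randn(1, n))  # readout, trained with Adam

        def forward(self, x):
            return self.w_out @ (self.w @ x)

    net = TinyNet()
    eg_params = [p for name, p in net.named_parameters() if name == 'w']
    adam_params = [p for name, p in net.named_parameters() if name != 'w']
    optimizer = [ExponentiatedGradient(eg_params, lr=1e-2),
                 optim.Adam(adam_params, lr=1e-2)]

    x = torch.randn(4, 1)
    for opt in optimizer:          # zero both optimizers, as in the patched train loop
        opt.zero_grad()
    loss = (net(x) ** 2).sum()
    loss.backward()
    for opt in optimizer:          # EG steps 'w', Adam steps the rest
        opt.step()
    assert (net.w >= 0).all()      # EG keeps the recurrent magnitudes non-negative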