26 changes: 26 additions & 0 deletions rate/eg_optimizer.py
@@ -0,0 +1,26 @@
"""Exponentiated Gradient Optimizer"""

import torch

class ExponentiatedGradient(torch.optim.Optimizer):
def __init__(self, params, lr=1e-2, weight_decay=0.0,
eps=1e-12, max_norm=None, norm_axis=None):
defaults = dict(lr=lr, weight_decay=weight_decay,
eps=eps, max_norm=max_norm, norm_axis=norm_axis)
super().__init__(params, defaults)

@torch.no_grad()
def step(self, closure=None):
for group in self.param_groups:
lr, wd, eps = group['lr'], group['weight_decay'], group['eps']
max_norm, norm_axis = group['max_norm'], group['norm_axis']
for p in group['params']:
if p.grad is None: continue
g = p.grad + wd * p if wd else p.grad
update = torch.exp(-lr * g).clamp_min(torch.finfo(p.dtype).tiny)
p.mul_(update).clamp_min_(eps)
# optional per‑column norm control
if max_norm is not None:
n = p.norm(dim=norm_axis, keepdim=True)
scale = (max_norm / (n + 1e-12)).clamp_max(1.0)
p.mul_(scale)
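
A minimal standalone sanity check for the new optimizer (not part of this diff; the toy tensor and toy loss below are illustrative only) that exercises the multiplicative update and confirms positivity is preserved:

# Sketch: run a few EG steps on a toy positive parameter with a toy L2 loss.
import torch
from eg_optimizer import ExponentiatedGradient

w = torch.nn.Parameter(torch.rand(4, 4) + 0.1)   # strictly positive start
opt = ExponentiatedGradient([w], lr=0.1)

for _ in range(100):
    opt.zero_grad()
    loss = (w ** 2).sum()        # toy loss pulling w toward zero
    loss.backward()
    opt.step()                   # w <- w * exp(-lr * grad), clamped at eps

assert (w > 0).all()             # multiplicative updates never cross zero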
77 changes: 69 additions & 8 deletions rate/main.py
@@ -29,6 +29,9 @@

from model import loss_op

# Exponentiated Gradient optimizer
from eg_optimizer import ExponentiatedGradient

# Parse input arguments
parser = argparse.ArgumentParser(description='Training rate RNNs')
parser.add_argument('--gpu', required=False,
@@ -60,6 +63,12 @@
type=str, default='sigmoid', help="Activation function (sigmoid, clipped_relu)")
parser.add_argument("--loss_fn", required=True,
type=str, default='l2', help="Loss function (either L1 or L2)")
parser.add_argument("--optimizer", required=False,
type=str, default='adam', help="Optimizer to use (adam or eg)")
parser.add_argument("--momentum", required=False,
type=float, default = 0.0, help="Momentum for the EG optimizer")
parser.add_argument("--weight_decay", required=False,
type=float, default = 0.0, help="Weight decay for the EG optimizer")
parser.add_argument("--apply_dale", required=True,
type=str2bool, default='True', help="Apply Dale's principle?")
parser.add_argument("--decay_taus", required=True,
@@ -165,14 +174,48 @@
'eval_amp_threh': 0.3, # amplitude threshold during response window
'activation': args.act.lower(), # activation function
'loss_fn': args.loss_fn.lower(), # loss function ('L1' or 'L2')
'P_rec': 0.20
'P_rec': 0.20, # initial connectivity probability
'momentum': args.momentum, # momentum (alpha)
'weight_decay': args.weight_decay, # weight decay (gamma)
'optimizer': args.optimizer.lower(), # optimizer to use (`adam` or `eg`)
}

'''
Set up optimizer
'''
if args.mode.lower() == 'train':
optimizer = optim.Adam(net.parameters(), lr=training_params['learning_rate'])
if args.optimizer.lower() == 'adam':
print('Using Adam optimizer for all parameters...')
optimizer = optim.Adam(net.parameters(), lr=training_params['learning_rate'])
training_params['optimizer'] = 'adam'

elif args.optimizer.lower() == 'eg':
print('Using Exponentiated Gradient (EG) for recurrent weights (w) and Adam for others...')
eg_params = []
adam_params = []

for name, param in net.named_parameters():
if name == 'w':
# Apply EG only to the recurrent weight magnitudes 'w'
eg_params.append(param)
else:
# Apply Adam to all other parameters (w_in, w_out, b_out)
adam_params.append(param)

# Create two optimizers
optimizer_eg = ExponentiatedGradient(eg_params,
lr=training_params['learning_rate'],
weight_decay=training_params.get('weight_decay', 0.0))

optimizer_adam = optim.Adam(adam_params, lr=training_params['learning_rate'])

# Use a list to hold both optimizers and iterate over it during the train step
optimizer = [optimizer_eg, optimizer_adam]
training_params['optimizer'] = 'eg'

else:
raise ValueError(f"Unknown optimizer: {args.optimizer}")

print('Set up optimizer...')

'''
@@ -200,7 +243,11 @@
start_time = time.time()

# Zero gradients
optimizer.zero_grad()
if isinstance(optimizer, list):
for opt in optimizer:
opt.zero_grad()
else:
optimizer.zero_grad()

# Generate a task-specific input signal
u, target, label = task.simulate_trial()
@@ -213,13 +260,19 @@
# Forward pass
stim, x, r, o, w, w_in, m, som_m, w_out, b_out, taus_gaus = \
net.forward(u_tensor, settings['taus'], training_params, settings)

# Compute loss
t_loss = loss_op(o, target, training_params)

# Backward pass
t_loss.backward()
optimizer.step()

# Optimizer step
if isinstance(optimizer, list):
for opt in optimizer:
opt.step()
else:
optimizer.step()

print('Loss: ', t_loss.item())
losses[tr] = t_loss.item()
@@ -404,15 +457,23 @@
scipy.io.savemat(os.path.join(out_dir, fname), var)

# Also save the PyTorch model
torch.save({
if isinstance(optimizer, list):
optimizer_state = {f'optimizer_{i}_state_dict': opt.state_dict()
for i, opt in enumerate(optimizer)}
else:
optimizer_state = {'optimizer_state_dict': optimizer.state_dict()}

save_dict = {
'model_state_dict': net.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'settings': settings,
'training_params': training_params,
'losses': losses,
'final_loss': t_loss.item(),
'trial': tr,
}, os.path.join(out_dir, fname.replace('.mat', '.pth')))
}
save_dict.update(optimizer_state)

torch.save(save_dict, os.path.join(out_dir, fname.replace('.mat', '.pth')))
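
A matching restore path mirrors the dictionary written above; a hedged sketch (not in this diff), reusing the `net`, `optimizer`, `out_dir`, and `fname` names from main.py:

# Sketch: reload the checkpoint written above, handling one or two optimizers.
ckpt = torch.load(os.path.join(out_dir, fname.replace('.mat', '.pth')))
net.load_state_dict(ckpt['model_state_dict'])
if isinstance(optimizer, list):
    for i, opt in enumerate(optimizer):
        opt.load_state_dict(ckpt[f'optimizer_{i}_state_dict'])
else:
    optimizer.load_state_dict(ckpt['optimizer_state_dict'])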



31 changes: 23 additions & 8 deletions rate/model.py
@@ -138,7 +138,12 @@ def initialize_W(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:

if self.apply_dale == True:
w = np.abs(w)


# optionally floor the nonzero weights at a small positive value so they stay strictly positive
# nonzero_mask = w > 0
# w[nonzero_mask] = np.maximum(w[nonzero_mask], 1e-8)


# Mask matrix
mask = np.eye(self.N, dtype=np.float32)
mask[np.where(self.inh==True)[0], np.where(self.inh==True)[0]] = -1
@@ -198,6 +203,7 @@ def display(self) -> None:
print('\t Zero Weights: %2.2f %%' % (zero_w/(self.N*self.N)*100))
print('\t Positive Weights: %2.2f %%' % (pos_w/(self.N*self.N)*100))
print('\t Negative Weights: %2.2f %%' % (neg_w/(self.N*self.N)*100))


def forward(self, stim: torch.Tensor, taus: List[float], training_params: Dict[str, Any],
settings: Dict[str, Any]) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor],
@@ -254,14 +260,16 @@ def forward(self, stim: torch.Tensor, taus: List[float], training_params: Dict[s

# Forward pass through time
for t in range(1, T):
if self.apply_dale == True:
# When using EG optimizer, weights are non-negative by construction
# so no need for ReLU. For Adam optimizer with Dale's principle, use ReLU.
if self.apply_dale == True and training_params.get('optimizer', 'adam') != 'eg':
# Parametrize the weight matrix to enforce exc/inh synaptic currents
w_pos = F.relu(self.w)
w = F.relu(self.w)
else:
w_pos = self.w
w = self.w

# Compute effective weight matrix
ww = torch.matmul(w_pos, self.mask)
ww = torch.matmul(w, self.mask)
ww = ww * self.som_mask

# Compute time constants
@@ -412,7 +420,7 @@ def loss_op(o: List[torch.Tensor], z: Union[np.ndarray, torch.Tensor], training_
return loss

def eval_rnn(net: FR_RNN_dale, settings: Dict[str, Any], u: np.ndarray,
device: torch.device) -> Tuple[List[float], List[np.ndarray], List[np.ndarray]]:
device: torch.device, training_params: Dict[str, Any] = None) -> Tuple[List[float], List[np.ndarray], List[np.ndarray]]:
"""
Evaluate a trained PyTorch RNN.

Expand All @@ -421,6 +429,8 @@ def eval_rnn(net: FR_RNN_dale, settings: Dict[str, Any], u: np.ndarray,
settings (Dict[str, Any]): Dictionary containing task settings.
u (np.ndarray): Stimulus matrix.
device (torch.device): PyTorch device.
training_params (Dict[str, Any], optional): Dictionary containing training parameters
including optimizer type. Defaults to None.

Returns:
Tuple[List[float], List[np.ndarray], List[np.ndarray]]: Tuple containing:
@@ -432,6 +442,10 @@
DeltaT = settings['DeltaT']
taus = settings['taus']

# Default training_params if not provided
if training_params is None:
training_params = {'optimizer': 'adam'}

net.eval()
with torch.no_grad():
u_tensor = torch.tensor(u, dtype=torch.float32, device=device)
@@ -443,10 +457,11 @@
o = []

for t in range(1, T):
if net.apply_dale:
# When using EG optimizer, weights are non-negative by construction
if net.apply_dale and training_params.get('optimizer', 'adam') != 'eg':
w_pos = F.relu(net.w)
else:
w_pos = net.w
w_pos = net.w # with EG, Dale weights stay non-negative by construction, so no ReLU is needed

ww = torch.matmul(w_pos, net.mask)
ww = ww * net.som_mask
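
The reason the ReLU branch can be skipped under EG (as noted in the forward and eval_rnn changes above) is that each step multiplies the weights by exp(-lr * g), which is strictly positive, so a non-negative weight matrix stays non-negative. A small illustrative check (not part of the diff):

# Sketch: one EG-style multiplicative step on a non-negative matrix.
import torch

w = torch.tensor([[0.5, 0.0], [0.2, 1.0]])      # non-negative recurrent magnitudes
g = torch.randn(2, 2)                           # arbitrary gradient
w_next = w * torch.exp(-0.01 * g)               # exp(.) > 0, so signs never flip
assert (w_next >= 0).all()                      # zeros stay zero, positives stay positive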