From 613a95ffd47802a3f07f14eac0ae793617032cbf Mon Sep 17 00:00:00 2001
From: carolinehaoud
Date: Tue, 29 Nov 2022 18:01:23 -0500
Subject: [PATCH] Fix batch size increment in training; still needs testing

Probing the first batch to infer the batch size consumed a trial batch
from the same generator that feeds the training loop, so the reported
iteration count drifted from the number of batches actually trained on.
train() now draws from separate generators for the batch size probe,
the training loop, and curriculum metric tests.

---
 psychrnn/backend/rnn.py | 163 +++++++++++++++++++++++-----------------
 1 file changed, 95 insertions(+), 68 deletions(-)

diff --git a/psychrnn/backend/rnn.py b/psychrnn/backend/rnn.py
index ab18136..20fcac0 100644
--- a/psychrnn/backend/rnn.py
+++ b/psychrnn/backend/rnn.py
@@ -1,4 +1,13 @@
 from __future__ import division
 from __future__ import print_function
+from psychrnn.backend.initializations import WeightInitializer, GaussianSpectralRadius
+from psychrnn.backend.loss_functions import LossFunction
+from psychrnn.backend.regularizations import Regularizer
+from inspect import isgenerator
+from os import makedirs, path
+from time import time
+import sys
+import numpy as np
+import tensorflow as tf
 
 from abc import ABCMeta, abstractmethod
@@ -6,20 +15,10 @@
 # abstract class python 2 & 3 compatible
 ABC = ABCMeta('ABC', (object,), {})
 
-import tensorflow as tf
-import numpy as np
-
-import sys
-from time import time
-from os import makedirs, path
-from inspect import isgenerator
-
-from psychrnn.backend.regularizations import Regularizer
-from psychrnn.backend.loss_functions import LossFunction
-from psychrnn.backend.initializations import WeightInitializer, GaussianSpectralRadius
 
 tf.compat.v1.disable_eager_execution()
 
+
 class RNN(ABC):
     """ The base recurrent neural network class.
@@ -65,11 +64,12 @@ class RNN(ABC):
         * **autapses** (*bool, optional*) -- If False, self connections are not allowed in N_rec, and diagonal of :data:`rec_connectivity` will be set to 0. Default: True.
         * **dale_ratio** (float, optional) -- Dale's ratio, used to construct Dale_rec and Dale_out. 0 <= dale_ratio <=1 if dale_ratio should be used. ``dale_ratio * N_rec`` recurrent units will be excitatory, the rest will be inhibitory. Default: None
         * **transfer_function** (*function, optional*) -- Transfer function to use for the network. Default: `tf.nn.relu <https://www.tensorflow.org/api_docs/python/tf/nn/relu>`_.
-    
+
     Inferred Parameters:
         * **alpha** (*float*) -- The number of unit time constants per simulation timestep.
 
     """
+
     def __init__(self, params):
 
         self.params = params
@@ -114,13 +114,13 @@ def __init__(self, params):
         except KeyError:
             print("You must pass 'dt' to RNN")
             raise
-        
+
         try:
             self.tau = params['tau']
         except KeyError:
             print("You must pass 'tau' to RNN")
             raise
-        try: 
+        try:
             self.tau = self.tau.astype('float32')
         except AttributeError:
             pass
@@ -130,11 +130,10 @@ def __init__(self, params):
         except KeyError:
             print("You must pass 'N_batch' to RNN")
             raise
-        
+
         self.alpha = (1.0 * self.dt) / self.tau
 
         self.rec_noise = params.get('rec_noise', 0.0)
-
         # ----------------------------------
         # Load weights path
         # ----------------------------------
@@ -145,7 +144,8 @@ def __init__(self, params):
         # ------------------------------------------------
         if self.load_weights_path is not None:
             # transfer function is passed in here only for backwards compatibility -- if you load weights saved before transfer_function was added to saved weights, the model will use the custom transfer function passed in.
-            self.initializer = WeightInitializer(load_weights_path=self.load_weights_path, transfer_function=params.get('transfer_function', tf.nn.relu))
+            self.initializer = WeightInitializer(
+                load_weights_path=self.load_weights_path, transfer_function=params.get('transfer_function', tf.nn.relu))
         elif params.get('W_rec', None) is not None:
             self.initializer = params.get('initializer',
                                           WeightInitializer(**params))
@@ -171,7 +171,8 @@ def __init__(self, params):
         # ---------------------------------------------------
         self.x = tf.compat.v1.placeholder("float", [None, N_steps, N_in])
         self.y = tf.compat.v1.placeholder("float", [None, N_steps, N_out])
-        self.output_mask = tf.compat.v1.placeholder("float", [None, N_steps, N_out])
+        self.output_mask = tf.compat.v1.placeholder(
+            "float", [None, N_steps, N_out])
 
         # --------------------------------------------------
         # Initialize variables in proper scope
@@ -183,19 +184,21 @@ def __init__(self, params):
             # ------------------------------------------------
             try:
                 self.init_state = tf.compat.v1.get_variable('init_state', [1, N_rec],
-                                                            initializer=self.initializer.get('init_state'),
-                                                            trainable=self.init_state_train)
+                                                            initializer=self.initializer.get(
+                                                                'init_state'),
+                                                            trainable=self.init_state_train)
             except ValueError as error:
-                raise UserWarning("Try calling model.destruct() or changing params['name'].")
-
+                raise UserWarning(
+                    "Try calling model.destruct() or changing params['name'].")
 
             self.init_state = tf.tile(self.init_state, [self.N_batch, 1])
 
             # Input weight matrix:
             self.W_in = \
                 tf.compat.v1.get_variable('W_in', [N_rec, N_in],
-                                          initializer=self.initializer.get('W_in'),
-                                          trainable=self.W_in_train)
+                                          initializer=self.initializer.get(
+                                              'W_in'),
+                                          trainable=self.W_in_train)
 
             # Recurrent weight matrix:
             self.W_rec = \
@@ -207,15 +210,16 @@ def __init__(self, params):
 
             # Output weight matrix:
             self.W_out = tf.compat.v1.get_variable('W_out', [N_out, N_rec],
-                                                   initializer=self.initializer.get('W_out'),
-                                                   trainable=self.W_out_train)
+                                                   initializer=self.initializer.get(
+                                                       'W_out'),
+                                                   trainable=self.W_out_train)
 
             # Recurrent bias:
             self.b_rec = tf.compat.v1.get_variable('b_rec', [N_rec],
                                                    initializer=self.initializer.get('b_rec'),
-                                                   trainable=self.b_rec_train)
+                                                   trainable=self.b_rec_train)
             # Output bias:
             self.b_out = tf.compat.v1.get_variable('b_out', [N_out],
                                                    initializer=self.initializer.get('b_out'),
-                                                   trainable=self.b_out_train)
+                                                   trainable=self.b_out_train)
 
             # ------------------------------------------------
             # Non-trainable variables:
             # ------------------------------------------------
@@ -224,24 +228,29 @@ def __init__(self, params):
 
             # Recurrent Dale's law weight matrix:
             self.Dale_rec = tf.compat.v1.get_variable('Dale_rec', [N_rec, N_rec],
-                                                      initializer=self.initializer.get('Dale_rec'),
-                                                      trainable=False)
+                                                      initializer=self.initializer.get(
+                                                          'Dale_rec'),
+                                                      trainable=False)
 
             # Output Dale's law weight matrix:
             self.Dale_out = tf.compat.v1.get_variable('Dale_out', [N_rec, N_rec],
-                                                      initializer=self.initializer.get('Dale_out'),
-                                                      trainable=False)
+                                                      initializer=self.initializer.get(
+                                                          'Dale_out'),
+                                                      trainable=False)
 
             # Connectivity weight matrices:
             self.input_connectivity = tf.compat.v1.get_variable('input_connectivity', [N_rec, N_in],
-                                                                initializer=self.initializer.get('input_connectivity'),
-                                                                trainable=False)
+                                                                initializer=self.initializer.get(
+                                                                    'input_connectivity'),
+                                                                trainable=False)
             self.rec_connectivity = tf.compat.v1.get_variable('rec_connectivity', [N_rec, N_rec],
-                                                              initializer=self.initializer.get('rec_connectivity'),
-                                                              trainable=False)
+                                                              initializer=self.initializer.get(
+                                                                  'rec_connectivity'),
+                                                              trainable=False)
             self.output_connectivity = tf.compat.v1.get_variable('output_connectivity', [N_out, N_rec],
-                                                                 initializer=self.initializer.get('output_connectivity'),
-                                                                 trainable=False)
+                                                                 initializer=self.initializer.get(
+                                                                     'output_connectivity'),
+                                                                 trainable=False)
 
         # --------------------------------------------------
         # Flag to check if variables initialized, model built
         # --------------------------------------------------
@@ -332,21 +341,22 @@ def get_effective_W_out(self):
         if self.dale_ratio:
             W_out = tf.matmul(tf.abs(W_out), self.Dale_out, name="in_2")
         return W_out
-    
+
     @abstractmethod
     def forward_pass(self):
         """ Run the RNN on a batch of task inputs.
 
         Note:
             This is an abstract function that must be defined in a child class.
-        
+
         Returns:
             tuple:
             * **predictions** (*ndarray(dtype=float, shape=(*:attr:`N_batch`, :attr:`N_steps`, :attr:`N_out` *))*) -- Network output on inputs found in self.x within the tf network.
            * **states** (*ndarray(dtype=float, shape=(*:attr:`N_batch`, :attr:`N_steps`, :attr:`N_rec` *))*) -- State variable values over the course of the trials found in self.x within the tf network.
 
         """
-        raise UserWarning("forward_pass must be implemented in child class. See Basic for example.")
+        raise UserWarning(
+            "forward_pass must be implemented in child class. See Basic for example.")
 
     def get_weights(self):
         """ Get weights used in the network.
@@ -380,17 +390,20 @@ def get_weights(self):
         if not self.is_initialized:
             self.sess.run(tf.compat.v1.global_variables_initializer())
             self.is_initialized = True
-        
+
         weights_dict = dict()
-        
+
         for var in tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, scope=self.name):
             # avoid saving duplicates
             if var.name.endswith(':0') and var.name.startswith(self.name):
                 name = var.name[len(self.name)+1:-2]
                 weights_dict.update({name: var.eval(session=self.sess)})
-        weights_dict.update({'W_rec': self.get_effective_W_rec().eval(session=self.sess)})
-        weights_dict.update({'W_in': self.get_effective_W_in().eval(session=self.sess)})
-        weights_dict.update({'W_out': self.get_effective_W_out().eval(session=self.sess)})
+        weights_dict.update(
+            {'W_rec': self.get_effective_W_rec().eval(session=self.sess)})
+        weights_dict.update(
+            {'W_in': self.get_effective_W_in().eval(session=self.sess)})
+        weights_dict.update(
+            {'W_out': self.get_effective_W_out().eval(session=self.sess)})
         weights_dict['dale_ratio'] = self.dale_ratio
         weights_dict['transfer_function'] = self.transfer_function
         return weights_dict
@@ -472,25 +485,38 @@ def train(self, trial_batch_generator, train_params={}):
         loss_epoch = train_params.get('loss_epoch', 10)
         verbosity = train_params.get('verbosity', True)
         save_weights_path = train_params.get('save_weights_path', None)
-        save_training_weights_epoch = train_params.get('save_training_weights_epoch', 100)
+        save_training_weights_epoch = train_params.get(
+            'save_training_weights_epoch', 100)
         training_weights_path = train_params.get('training_weights_path', None)
         curriculum = train_params.get('curriculum', None)
         optimizer = train_params.get('optimizer',
                                      tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate))
         clip_grads = train_params.get('clip_grads', True)
-        fixed_weights = train_params.get('fixed_weights', None) # array of zeroes and ones. One indicates to pin and not train that weight.
+        # array of zeroes and ones. One indicates to pin and not train that weight.
+        fixed_weights = train_params.get('fixed_weights', None)
         performance_cutoff = train_params.get('performance_cutoff', None)
         performance_measure = train_params.get('performance_measure', None)
 
         if (performance_cutoff is not None and performance_measure is None) or (performance_cutoff is None and performance_measure is not None):
-            raise UserWarning("training will not be cutoff based on performance. Make sure both performance_measure and performance_cutoff are defined")
+            raise UserWarning(
+                "training will not be cut off based on performance. Make sure both performance_measure and performance_cutoff are defined")
 
         if curriculum is not None:
-            trial_batch_generator = curriculum.get_generator_function()
+            trial_batch_generator_batch_size = curriculum.get_generator_function()
+            trial_batch_generator_training = curriculum.get_generator_function()
+            trial_batch_generator_curriculum = curriculum.get_generator_function()
+        elif isgenerator(trial_batch_generator):
+            # A generator was passed in directly; all uses must share it.
+            trial_batch_generator_batch_size = trial_batch_generator
+            trial_batch_generator_training = trial_batch_generator
+            trial_batch_generator_curriculum = trial_batch_generator
+        else:
+            # A Task was passed in; draw an independent generator for each use
+            # so the batch size probe does not consume a training batch.
+            trial_batch_generator_batch_size = trial_batch_generator.batch_generator()
+            trial_batch_generator_training = trial_batch_generator.batch_generator()
+            trial_batch_generator_curriculum = trial_batch_generator.batch_generator()
 
-        if not isgenerator(trial_batch_generator):
-            trial_batch_generator = trial_batch_generator.batch_generator()
-
         # --------------------------------------------------
         # Make weights folder if it doesn't already exist.
         # --------------------------------------------------
@@ -521,7 +547,6 @@ def train(self, trial_batch_generator, train_params={}):
                     grad = tf.multiply(grad, (1-fixed_weights[name]))
                     grads[i] = (grad, var)
 
-
         # --------------------------------------------------
         # Clip gradients
         # --------------------------------------------------
@@ -546,35 +571,40 @@ def train(self, trial_batch_generator, train_params={}):
         # Training loop
         # --------------------------------------------------
         epoch = 1
-        batch_size = next(trial_batch_generator)[0].shape[0]
+        batch_size = next(trial_batch_generator_batch_size)[0].shape[0]
         losses = []
         if performance_cutoff is not None:
             performance = performance_cutoff - 1
 
         while (epoch - 1) * batch_size < training_iters and (performance_cutoff is None or performance < performance_cutoff):
-            batch_x, batch_y, output_mask, _ = next(trial_batch_generator)
-            self.sess.run(optimize, feed_dict={self.x: batch_x, self.y: batch_y, self.output_mask: output_mask})
+            batch_x, batch_y, output_mask, _ = next(
+                trial_batch_generator_training)
+            self.sess.run(optimize, feed_dict={
+                          self.x: batch_x, self.y: batch_y, self.output_mask: output_mask})
 
             # --------------------------------------------------
             # Output batch loss
             # --------------------------------------------------
             if epoch % loss_epoch == 0:
                 reg_loss = self.sess.run(self.reg_loss,
-                                         feed_dict={self.x: batch_x, self.y: batch_y, self.output_mask: output_mask})
+                                         feed_dict={self.x: batch_x, self.y: batch_y, self.output_mask: output_mask})
                 losses.append(reg_loss)
                 if verbosity:
-                    print("Iter " + str(epoch * batch_size) + ", Minibatch Loss= " + \
+                    print("Iter " + str(epoch * batch_size) + ", Minibatch Loss= " +
                           "{:.6f}".format(reg_loss))
 
             # --------------------------------------------------
             # Allow for curriculum learning
            # --------------------------------------------------
             if curriculum is not None and epoch % curriculum.metric_epoch == 0:
-                trial_batch, trial_y, output_mask, _ = next(trial_batch_generator)
+                trial_batch, trial_y, output_mask, _ = next(
+                    trial_batch_generator_curriculum)
                 output, _ = self.test(trial_batch)
                 if curriculum.metric_test(trial_batch, trial_y, output_mask, output, epoch, losses, verbosity):
                     if curriculum.stop_training:
                         break
-                    trial_batch_generator = curriculum.get_generator_function()
+                    trial_batch_generator_batch_size = curriculum.get_generator_function()
+                    trial_batch_generator_training = curriculum.get_generator_function()
+                    trial_batch_generator_curriculum = curriculum.get_generator_function()
 
             # --------------------------------------------------
             # Save intermediary weights
@@ -583,15 +613,18 @@ def train(self, trial_batch_generator, train_params={}):
             # --------------------------------------------------
             if epoch % save_training_weights_epoch == 0:
                 if training_weights_path is not None:
                     self.save(training_weights_path + str(epoch))
                     if verbosity:
-                        print("Training weights saved in file: %s" % training_weights_path + str(epoch))
-
+                        print("Training weights saved in file: %s" %
+                              training_weights_path + str(epoch))
+
             # ---------------------------------------------------
             # Update performance value if necessary
             # ---------------------------------------------------
             if performance_measure is not None:
-                trial_batch, trial_y, output_mask, _ = next(trial_batch_generator)
+                trial_batch, trial_y, output_mask, _ = next(
+                    trial_batch_generator_training)
                 output, _ = self.test(trial_batch)
-                performance = performance_measure(trial_batch, trial_y, output_mask, output, epoch, losses, verbosity)
+                performance = performance_measure(
+                    trial_batch, trial_y, output_mask, output, epoch, losses, verbosity)
                 if verbosity:
                     print("performance: " + str(performance))
             epoch += 1
@@ -613,7 +646,6 @@ def train(self, trial_batch_generator, train_params={}):
         # --------------------------------------------------
         return losses, (t2 - t1), (t1 - t0)
 
-
     def train_curric(self, train_params):
         """Wrapper function for training with curriculum to streamline curriculum learning.
@@ -630,9 +662,11 @@ def train_curric(self, train_params):
         curriculum = train_params.get('curriculum', None)
 
         if curriculum is None:
-            raise UserWarning("train_curric requires a curriculum. Please pass in a curriculum or use train instead.")
-
-        losses, training_time, initialization_time = self.train(curriculum.get_generator_function(), train_params)
+            raise UserWarning(
+                "train_curric requires a curriculum. Please pass in a curriculum or use train instead.")
+
+        losses, training_time, initialization_time = self.train(
+            curriculum.get_generator_function(), train_params)
 
         return losses, training_time, initialization_time
 
@@ -641,7 +675,7 @@ def test(self, trial_batch):
 
         Arguments:
             trial_batch ((*ndarray(dtype=float, shape =(*:attr:`N_batch`, :attr:`N_steps`, :attr:`N_out` *))*): Task stimulus to run the network on. Stimulus from :func:`psychrnn.tasks.task.Task.get_trial_batch`, or from next(:func:`psychrnn.tasks.task.Task.batch_generator` ).
-        
+
         Returns:
             tuple:
             * **outputs** (*ndarray(dtype=float, shape =(*:attr:`N_batch`, :attr:`N_steps`, :attr:`N_out` *))*) -- Output time series of the network for each trial in the batch.
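
For review context, a minimal sketch (not part of the patch) of the generator
behavior this commit changes. ToyTask is a hypothetical stand-in for a
psychrnn Task whose batch_generator() yields (x, y, mask, params) tuples
indefinitely:

    import numpy as np

    class ToyTask(object):
        """Hypothetical stand-in for a psychrnn Task."""

        def __init__(self, N_batch):
            self.N_batch = N_batch
            self.batches_drawn = 0  # counts every batch handed out

        def batch_generator(self):
            # Yields (x, y, mask, params) batches forever, like a psychrnn Task.
            while True:
                self.batches_drawn += 1
                x = np.zeros((self.N_batch, 50, 2))
                y = np.zeros((self.N_batch, 50, 1))
                mask = np.ones((self.N_batch, 50, 1))
                yield x, y, mask, {}

    task = ToyTask(N_batch=64)

    # Before the patch: one shared generator. Probing the batch size consumes
    # a batch, so the reported iteration count (epoch * batch_size) runs ahead
    # of the number of batches the optimizer has actually seen.
    shared = task.batch_generator()
    batch_size = next(shared)[0].shape[0]  # draws a batch just to read its shape
    first_training_batch = next(shared)    # training starts on the second batch

    # After the patch: an independent generator per use, so the probe and the
    # curriculum metric tests no longer interleave with the training stream.
    probe = task.batch_generator()
    training = task.batch_generator()
    batch_size = next(probe)[0].shape[0]
    first_training_batch = next(training)  # training starts on its own first batch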