diff --git a/configs/test_shakespeare.json b/configs/test_shakespeare.json
index 67f20d3..71bd6fd 100644
--- a/configs/test_shakespeare.json
+++ b/configs/test_shakespeare.json
@@ -1,10 +1,10 @@
 {
-    "batch_size": 16,
+    "batch_size": 64,
     "dataset": "shakespeare",
     "dataset_kwargs": {},
     "max_epoch": 10,
     "model": "llama",
-    "model_kwargs": {"vocab_size": 92, "dim": 384, "expand": 4, "n_layers": 3, "n_heads": 2, "mlp": "mlp", "seq_len": 512},
+    "model_kwargs": {"vocab_size": 92, "dim": 384, "expand": 4, "n_layers": 6, "n_heads": 6, "mlp": "mlp", "seq_len": 256},
     "opt": [{"name": "adam", "lr": [1e-3], "lr_schedule": "constant", "warmup_steps": 100, "stepwise_schedule": true}],
     "loss_func": "sequence_cross_entropy",
     "score_func": "sequence_cross_entropy_accuracy",
diff --git a/output/test_shakespeare.json b/output/test_shakespeare.json
index f4d629b..dbe0c35 100644
--- a/output/test_shakespeare.json
+++ b/output/test_shakespeare.json
@@ -1,7 +1,7 @@
 [
     {
         "config": {
-            "batch_size": 16,
+            "batch_size": 64,
             "dataset": "shakespeare",
             "dataset_kwargs": {},
             "loss_func": "sequence_cross_entropy",
@@ -11,9 +11,9 @@
                 "dim": 384,
                 "expand": 4,
                 "mlp": "mlp",
-                "n_heads": 2,
-                "n_layers": 3,
-                "seq_len": 512,
+                "n_heads": 6,
+                "n_layers": 6,
+                "seq_len": 256,
                 "vocab_size": 92
             },
             "opt": {
@@ -29,129 +29,129 @@
         "history": [
             {
                 "epoch": 0,
-                "grad_norm": 1.992753267288208,
+                "grad_norm": 1.121427059173584,
                 "learning_rate": 1e-13,
-                "model_norm": 64.88945770263672,
-                "train_epoch_time": 20.36537790298462,
-                "train_loss": 2.165691924560494,
-                "train_score": 0.36806988216568765,
-                "val_loss": 2.218102411840154,
-                "val_score": 0.3520025144363272
+                "model_norm": 87.65153503417969,
+                "train_epoch_time": 36.848896980285645,
+                "train_loss": 2.461630094821005,
+                "train_score": 0.29407953720865837,
+                "val_loss": 2.4997918203421765,
+                "val_score": 0.28291834143218164
             },
             {
                 "epoch": 1,
-                "grad_norm": 1.334716558456421,
-                "learning_rate": 0.001,
-                "model_norm": 66.11741638183594,
-                "train_epoch_time": 20.214939832687378,
-                "train_loss": 1.8825722357983157,
-                "train_score": 0.4417044885807278,
-                "val_loss": 2.0125111694993643,
-                "val_score": 0.4103133981255279
+                "grad_norm": 1.3179094791412354,
+                "learning_rate": 0.000540000000046,
+                "model_norm": 88.65203094482422,
+                "train_epoch_time": 36.27439522743225,
+                "train_loss": 2.0504614697296275,
+                "train_score": 0.40136858857687097,
+                "val_loss": 2.11395750226438,
+                "val_score": 0.3813056111609484
             },
             {
                 "epoch": 2,
-                "grad_norm": 1.0500348806381226,
+                "grad_norm": 1.0813289880752563,
                 "learning_rate": 0.001,
-                "model_norm": 67.5,
-                "train_epoch_time": 19.981853008270264,
-                "train_loss": 1.7142505491094666,
-                "train_score": 0.48880484908921845,
-                "val_loss": 1.9234138680600572,
-                "val_score": 0.4437095905172414
+                "model_norm": 89.66089630126953,
+                "train_epoch_time": 36.44406986236572,
+                "train_loss": 1.7927071080488317,
+                "train_score": 0.46938217360969936,
+                "val_loss": 1.929213602014579,
+                "val_score": 0.43094324097293935
             },
             {
                 "epoch": 3,
-                "grad_norm": 1.0010490417480469,
+                "grad_norm": 1.2653794288635254,
                 "learning_rate": 0.001,
-                "model_norm": 68.94657897949219,
-                "train_epoch_time": 19.904484033584595,
-                "train_loss": 1.6100540534773868,
-                "train_score": 0.5147796176741782,
-                "val_loss": 1.8390625857758796,
-                "val_score": 0.4670977011494253
+                "model_norm": 90.63284301757812,
+                "train_epoch_time": 36.52592372894287,
+                "train_loss": 1.6544971603232101,
+                "train_score": 0.5043736549497848,
+                "val_loss": 1.8547480991869378,
+                "val_score": 0.4552821832806317
             },
             {
                 "epoch": 4,
"grad_norm": 0.867782711982727, + "grad_norm": 0.9828493595123291, "learning_rate": 0.001, - "model_norm": 70.42530822753906, - "train_epoch_time": 19.90021586418152, - "train_loss": 1.5380675991148134, - "train_score": 0.5328275243806236, - "val_loss": 1.792324341576675, - "val_score": 0.4788793104818498 + "model_norm": 91.56520080566406, + "train_epoch_time": 36.71347689628601, + "train_loss": 1.5387006788034863, + "train_score": 0.5344926917056956, + "val_loss": 1.7642201408316156, + "val_score": 0.4774101251644327 }, { "epoch": 5, - "grad_norm": 0.7821542620658875, + "grad_norm": 1.0504655838012695, "learning_rate": 0.001, - "model_norm": 71.892822265625, - "train_epoch_time": 19.903041124343872, - "train_loss": 1.4823639410645215, - "train_score": 0.54531541538567, - "val_loss": 1.7534021857141078, - "val_score": 0.49129849144782145 + "model_norm": 92.474609375, + "train_epoch_time": 36.35963201522827, + "train_loss": 1.4742244492986452, + "train_score": 0.5499394725350772, + "val_loss": 1.7320461796017108, + "val_score": 0.49073442892009983 }, { "epoch": 6, - "grad_norm": 0.888006865978241, + "grad_norm": 0.8896018862724304, "learning_rate": 0.001, - "model_norm": 73.36498260498047, - "train_epoch_time": 19.90959620475769, - "train_loss": 1.4459222527787552, - "train_score": 0.5567617144705097, - "val_loss": 1.7188306109658602, - "val_score": 0.5006106322524191 + "model_norm": 93.3919906616211, + "train_epoch_time": 36.567251205444336, + "train_loss": 1.4220848283945573, + "train_score": 0.5644335094791232, + "val_loss": 1.6994645117892189, + "val_score": 0.5016414314405789 }, { "epoch": 7, - "grad_norm": 0.7781890034675598, + "grad_norm": 0.7851909399032593, "learning_rate": 0.001, - "model_norm": 74.81156921386719, - "train_epoch_time": 20.277348041534424, - "train_loss": 1.4094120469350628, - "train_score": 0.564420610495194, - "val_loss": 1.7047440726181557, - "val_score": 0.5073141165163325 + "model_norm": 94.29660034179688, + "train_epoch_time": 36.6595242023468, + "train_loss": 1.3814418076784745, + "train_score": 0.5747007264639418, + "val_loss": 1.6752228601381234, + "val_score": 0.5100907716904387 }, { "epoch": 8, - "grad_norm": 0.7673189043998718, + "grad_norm": 0.7439262866973877, "learning_rate": 0.001, - "model_norm": 76.27680969238281, - "train_epoch_time": 20.30754780769348, - "train_loss": 1.3734883747473103, - "train_score": 0.5729114344828986, - "val_loss": 1.693546010159898, - "val_score": 0.5119118175287356 + "model_norm": 95.21795654296875, + "train_epoch_time": 36.596500873565674, + "train_loss": 1.359556995163347, + "train_score": 0.5797659610573155, + "val_loss": 1.6677938104360166, + "val_score": 0.511516934831709 }, { "epoch": 9, - "grad_norm": 0.6948201060295105, + "grad_norm": 0.7252312898635864, "learning_rate": 0.001, - "model_norm": 77.71929931640625, - "train_epoch_time": 20.247668027877808, - "train_loss": 1.3485437734255699, - "train_score": 0.5800568226694924, - "val_loss": 1.672089501358997, - "val_score": 0.5159931754243785 + "model_norm": 96.16193389892578, + "train_epoch_time": 36.32176995277405, + "train_loss": 1.3168328549290662, + "train_score": 0.5910767124578292, + "val_loss": 1.6540300926850118, + "val_score": 0.5194550074474409 } ], "summary": { "data_parallel": "false", - "end_time": "2025-07-23 14:30:37.932485", - "final_model_norm": 77.71929931640625, - "init_model_norm": 64.04080963134766, + "end_time": "2025-12-01 15:12:04.720226", + "final_model_norm": 96.16193389892578, + "init_model_norm": 87.41546630859375, "input_dim": [ - 512 
+                256
             ],
-            "num_batches_per_epoch": 108,
+            "num_batches_per_epoch": 54,
             "num_workers": 0,
             "output_dim": [
-                512
+                256
             ],
-            "start_time": "2025-07-23 14:25:44.459311",
+            "start_time": "2025-12-01 15:03:08.470563",
             "step_scheduler_on_epoch": false
         }
     }
diff --git a/run.py b/run.py
index 2424dba..8ecc6fb 100644
--- a/run.py
+++ b/run.py
@@ -25,17 +25,18 @@
 parser.add_argument('--verbose', action="store_true", help="Verbose mode.")
 parser.add_argument('--force-deterministic', action="store_true", help="Use deterministic mode in Pytorch. Might require setting environment variables.")

-def run_one(exp_id: str,
-            config_dir: str=DEFAULTS.config_dir,
-            output_dir: str=DEFAULTS.output_dir,
-            data_dir: str=DEFAULTS.data_dir,
-            device: str=DEFAULTS.device,
-            num_workers: int=DEFAULTS.num_workers,
-            data_parallel: Union[list, None]=DEFAULTS.data_parallel,
-            log_every_k_steps: Union[int, None]=DEFAULTS.log_every_k_steps,
-            verbose: bool=DEFAULTS.verbose,
-            force_deterministic: bool=DEFAULTS.force_deterministic
-            ):
+def run_one(
+        exp_id: str,
+        config_dir: str=DEFAULTS.config_dir,
+        output_dir: str=DEFAULTS.output_dir,
+        data_dir: str=DEFAULTS.data_dir,
+        device: str=DEFAULTS.device,
+        num_workers: int=DEFAULTS.num_workers,
+        data_parallel: Union[list, None]=DEFAULTS.data_parallel,
+        log_every_k_steps: Union[int, None]=DEFAULTS.log_every_k_steps,
+        verbose: bool=DEFAULTS.verbose,
+        force_deterministic: bool=DEFAULTS.force_deterministic
+    ):
     """Function for running all runs from one config file.

     Default values for all arguments can be found in ``stepback/defaults.py``.
@@ -83,14 +84,16 @@ def run_one(exp_id: str,
     for j, config in enumerate(exp_list):
         # each run gets id, by position in the list
-        B = Base(name=exp_id + f'_{j}',
-                 config=config,
-                 device=device,
-                 data_dir=data_dir,
-                 num_workers=num_workers,
-                 data_parallel=data_parallel,
-                 log_every_k_steps=log_every_k_steps,
-                 verbose=verbose)
+        B = Base(
+            name=exp_id + f'_{j}',
+            config=config,
+            device=device,
+            data_dir=data_dir,
+            num_workers=num_workers,
+            data_parallel=data_parallel,
+            log_every_k_steps=log_every_k_steps,
+            verbose=verbose
+        )
         B.setup()
         B.run() # train and validate
@@ -106,15 +109,17 @@
     print(args)

-    run_one(args.id,
-            config_dir=args.config_dir,
-            output_dir=args.output_dir,
-            data_dir=args.data_dir,
-            device=args.device,
-            num_workers=args.num_workers,
-            data_parallel=args.data_parallel,
-            log_every_k_steps=args.log_every_k_steps,
-            verbose=args.verbose,
-            force_deterministic=args.force_deterministic)
+    run_one(
+        args.id,
+        config_dir=args.config_dir,
+        output_dir=args.output_dir,
+        data_dir=args.data_dir,
+        device=args.device,
+        num_workers=args.num_workers,
+        data_parallel=args.data_parallel,
+        log_every_k_steps=args.log_every_k_steps,
+        verbose=args.verbose,
+        force_deterministic=args.force_deterministic
+    )
diff --git a/stepback/base.py b/stepback/base.py
index 81b1a0a..9b92ebd 100644
--- a/stepback/base.py
+++ b/stepback/base.py
@@ -18,14 +18,17 @@
 from .defaults import DEFAULTS

 class Base:
-    def __init__(self, name: str,
-                 config: dict,
-                 device: str=DEFAULTS.device,
-                 data_dir: str=DEFAULTS.data_dir,
-                 num_workers: int=DEFAULTS.num_workers,
-                 data_parallel: Union[list, None]=DEFAULTS.data_parallel,
-                 log_every_k_steps: Union[int, None]=DEFAULTS.log_every_k_steps,
-                 verbose: bool=DEFAULTS.verbose):
+    def __init__(
+        self,
+        name: str,
+        config: dict,
+        device: str=DEFAULTS.device,
+        data_dir: str=DEFAULTS.data_dir,
+        num_workers: int=DEFAULTS.num_workers,
+        data_parallel: Union[list, None]=DEFAULTS.data_parallel,
+        log_every_k_steps: Union[int, None]=DEFAULTS.log_every_k_steps,
+        verbose: bool=DEFAULTS.verbose
+    ):
         """The main class. Performs one single training run plus evaluation.

         Parameters
@@ -87,9 +90,11 @@ def __init__(self, name: str,
         self.check_config()

         # Create dictionary for results
-        self.results = {'config': self.config,
-                        'history': {},
-                        'summary': {}}
+        self.results = {
+            'config': self.config,
+            'history': {},
+            'summary': {}
+        }

         self.results['summary']['num_workers'] = self.num_workers
         self.results['summary']['data_parallel'] = 'true' if self.data_parallel else 'false'
@@ -122,11 +127,12 @@ def _setup_data(self):
         self.results['summary']['input_dim'], self.results['summary']['output_dim'] = infer_shapes(self.train_set)

         # construct train loader
-        self.train_loader = get_loader(ds=self.train_set,
-                                       seed=self.run_seed,
-                                       batch_size=self.config['batch_size'],
-                                       num_workers=self.num_workers,
-                                       drop_last=True
+        self.train_loader = get_loader(
+            ds=self.train_set,
+            seed=self.run_seed,
+            batch_size=self.config['batch_size'],
+            num_workers=self.num_workers,
+            drop_last=True
         )

         return
@@ -136,9 +142,10 @@ def _setup_model(self):
         torch.manual_seed(self.seed) # Reseed to have same initialization
         torch.cuda.manual_seed_all(self.seed)

-        self.model = get_model(config=self.config,
-                               input_dim=self.results['summary'].get('input_dim',[]),
-                               output_dim=self.results['summary'].get('output_dim',[])
+        self.model = get_model(
+            config=self.config,
+            input_dim=self.results['summary'].get('input_dim',[]),
+            output_dim=self.results['summary'].get('output_dim',[])
         )

         self.model.to(self.device)
@@ -168,8 +175,13 @@ def setup(self):
         opt_obj, hyperp = get_optimizer(self.config['opt'])
         self._init_opt(opt_obj, hyperp)
-
-        self.sched, self._step_scheduler_on_epoch = get_scheduler(self.config['opt'], self.opt)
+
+        # total number of iters (either in steps or epochs) for LR schedule
+        if self.config['opt'].get('stepwise_schedule', False):
+            num_iter = self.config['max_epoch'] * len(self.train_loader)
+        else:
+            num_iter = self.config['max_epoch']
+        self.sched, self._step_scheduler_on_epoch = get_scheduler(self.config['opt'], num_iter, self.opt)

         #============ Results ==============
         opt_val = self._compute_opt_value()
@@ -224,15 +236,19 @@ def run(self):
             # Validation
             with torch.no_grad():
-                metric_dict = {'loss': Loss(self.config['loss_func'], backwards=False),
-                               'score': Loss(self.config['score_func'], backwards=False)}
+                metric_dict = {
+                    'loss': Loss(self.config['loss_func'], backwards=False),
+                    'score': Loss(self.config['score_func'], backwards=False)
+                }

-                train_dict = self.evaluate(self.train_set,
-                                           metric_dict = metric_dict,
+                train_dict = self.evaluate(
+                    self.train_set,
+                    metric_dict = metric_dict,
                 )
-                val_dict = self.evaluate(self.val_set,
-                                         metric_dict = metric_dict,
+                val_dict = self.evaluate(
+                    self.val_set,
+                    metric_dict = metric_dict,
                 )

             # Record metrics
@@ -241,8 +257,11 @@ def run(self):
             # Record metrics specific to MoMo methods
             if self.opt.state.get('step_size_list'):
-                score_dict['step_size_list'] = [float(np.format_float_scientific(t,5)) for t in self.opt.state['step_size_list']]
+                score_dict['step_size_list'] = [
+                    float(np.format_float_scientific(t,5)) for t in self.opt.state['step_size_list']
+                ]
                 self.opt.state['step_size_list'] = list()
+                print(score_dict['step_size_list'])

             # fstar estimator (could be zero)
             if self.opt.state.get('fstar', None) is not None:
                 score_dict['fstar'] = self.opt.state['fstar']
@@ -314,15 +333,14 @@ def train_epoch(self):
             pbar.set_description(f'Training - loss={loss_val:.3f} - time data: last={timings_dataloader[-1]:.3f},(mean={np.mean(timings_dataloader):.3f}) - time model+step: last={timings_model[-1]:.3f}(mean={np.mean(timings_model):.3f})')

             # Log loss_val and grad_norm every k steps
+            total_step_counter = len(self.train_loader) * self._epochs_trained + step_counter
             if self.log_every_k_steps is not None:
-                total_step_counter = len(self.train_loader) * self._epochs_trained + step_counter
                 if step_counter % self.log_every_k_steps == 0:
                     self._log_stepwise["loss"][total_step_counter] = loss_val.item()
                     self._log_stepwise["grad_norm"][total_step_counter] = grad_norm(self.model)
                     if not self._step_scheduler_on_epoch:
                         self._log_stepwise["lr"][total_step_counter] = self.sched.get_last_lr()[0]
-
             if not self._step_scheduler_on_epoch:
                 self.sched.step()
@@ -374,7 +392,9 @@ def evaluate(self, dataset, metric_dict):
             timings_model.append(t0-t1)

             pbar.set_description(f'Validating {dataset.split}')
-            pbar.set_description(f'Validating {dataset.split} - time data: last={timings_dataloader[-1]:.3f}(mean={np.mean(timings_dataloader):.3f}) - time model: last={timings_model[-1]:.3f}(mean={np.mean(timings_model):.3f})')
+            pbar.set_description(
+                f'Validating {dataset.split} - time data: last={timings_dataloader[-1]:.3f}(mean={np.mean(timings_dataloader):.3f}) - time model: last={timings_model[-1]:.3f}(mean={np.mean(timings_model):.3f})'
+            )

         for _met in metric_dict.keys():
@@ -388,11 +408,13 @@ def evaluate(self, dataset, metric_dict):

     def save_checkpoint(self, path):
         """See https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html"""
-        torch.save({'epoch': self._epochs_trained,
-                    'model_state_dict': self.model.state_dict(),
-                    'opt_state_dict': self.opt.state_dict(),
-                    },
-                    path + self.name + '.mt')
+        torch.save({
+            'epoch': self._epochs_trained,
+            'model_state_dict': self.model.state_dict(),
+            'opt_state_dict': self.opt.state_dict(),
+            },
+            path + self.name + '.mt'
+        )

         return
@@ -415,17 +437,19 @@ def _compute_opt_value(self):
             warnings.warn("Using bias and weight decay. Note that the implementation here will also penalize the bias.")

         if self.config['loss_func'] == 'squared':
-            opt_val = ridge_opt_value(X=self.train_set.dataset.tensors[0].detach().numpy(),
-                                      y=self.train_set.dataset.tensors[1].detach().numpy(),
-                                      lmbda = self.config['opt'].get('weight_decay', 0),
-                                      fit_intercept = fit_intercept
-                                      )
+            opt_val = ridge_opt_value(
+                X=self.train_set.dataset.tensors[0].detach().numpy(),
+                y=self.train_set.dataset.tensors[1].detach().numpy(),
+                lmbda = self.config['opt'].get('weight_decay', 0),
+                fit_intercept = fit_intercept
+            )
         elif self.config['loss_func'] == 'logistic':
-            opt_val = logreg_opt_value(X=self.train_set.dataset.tensors[0].detach().numpy(),
-                                       y=self.train_set.dataset.tensors[1].detach().numpy().astype(int).reshape(-1),
-                                       lmbda = self.config['opt'].get('weight_decay', 0),
-                                       fit_intercept = fit_intercept
-                                       )
+            opt_val = logreg_opt_value(
+                X=self.train_set.dataset.tensors[0].detach().numpy(),
+                y=self.train_set.dataset.tensors[1].detach().numpy().astype(int).reshape(-1),
+                lmbda = self.config['opt'].get('weight_decay', 0),
+                fit_intercept = fit_intercept
+            )
         else:
             opt_val = None
     else:
diff --git a/stepback/models/llama.py b/stepback/models/llama.py
index 27fbff6..ada5d39 100644
--- a/stepback/models/llama.py
+++ b/stepback/models/llama.py
@@ -1,4 +1,4 @@
-""" Adapted from Niccolo Ajroldi: /github.com/Niccolo-Ajroldi/plainLM
+""" Adapted from Niccolo Ajroldi: github.com/Niccolo-Ajroldi/plainLM

 Changes:
diff --git a/stepback/optim/main.py b/stepback/optim/main.py
index 4d8defe..e039642 100644
--- a/stepback/optim/main.py
+++ b/stepback/optim/main.py
@@ -9,6 +9,7 @@
 from .adabound import AdaBoundW
 from .adabelief import AdaBelief
 from .lion import Lion
+from .ngn import NGN

 # only applicable to linear regression
 from .spp import SPP
@@ -152,14 +153,22 @@ def get_optimizer(opt_config: dict) -> Tuple[torch.optim.Optimizer, dict]:
         hyperp = {'lr': opt_config.get('lr', 1e-3),
                   'weight_decay': opt_config.get('weight_decay', 0)
                   }
+
+    elif name == 'ngn':
+        opt_obj = NGN
+        hyperp = {'lr': opt_config.get('lr', 1e-3),
+                  }
+
     else:
         raise KeyError(f"Unknown optimizer name {name}.")

     return opt_obj, hyperp

-def get_scheduler(config: dict, opt: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
+def get_scheduler(config: dict, num_iter: int, opt: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
     """
     Main function mapping to a learning rate scheduler.
+
+    num_iter is either the number of epochs or the number of steps.
""" # if not specified, use constant step sizes name = config.get('lr_schedule', 'constant') @@ -180,7 +189,26 @@ def get_scheduler(config: dict, opt: torch.optim.Optimizer) -> torch.optim.lr_sc #lr_fun = lambda t: warmup_lr + (1-warmup_lr)*t/warmup_steps if t < warmup_steps else (t-warmup_steps+1)**(-1/2) lr_fun = lambda t: (t+1)**(-1/2) scheduler = LambdaLR(opt, lr_lambda=lr_fun) + + elif name[:3] == 'wsd': + # default cooldown is 20%, otherwise specify e.g wsd_0.1 for 10% + if name == 'wsd': + cd = 0.2 + else: + cd = float(name.split('_')[1]) + cd_start = int((1 - cd) * num_iter) + + # this map is called with t = iter - warmup_steps + # but we want to fix the cooldown start independent of warmup + # so it reads a bit hacky + lr_fun = lambda t: ( + 1 - (t+warmup_steps-cd_start) / (num_iter-cd_start) + if t + warmup_steps >= cd_start + else 1.0 + ) + scheduler = LambdaLR(opt, lr_lambda=lr_fun) + elif 'exponential' in name: # use sth like 'exponential_60_0.5': decay by factor 0.5 every 60 epochs/steps step_size = int(name.split('_')[1]) diff --git a/stepback/optim/ngn.py b/stepback/optim/ngn.py new file mode 100644 index 0000000..5c026ab --- /dev/null +++ b/stepback/optim/ngn.py @@ -0,0 +1,97 @@ +""" +Implements the NGN algorithm by Orvieto and Xiao. + +Reference: https://arxiv.org/pdf/2407.04358 +""" +import torch +import warnings +from math import sqrt + +from ..types import Params, LossClosure, OptFloat + +class NGN(torch.optim.Optimizer): + def __init__(self, + params: Params, + lr: float=1e-1, + ) -> None: + """ + NGN optimizer + + Parameters + ---------- + params : Params + Model parameters. + lr : float, optional + Learning rate, by default 1e-1. + """ + + if lr < 0.0: + raise ValueError("Invalid learning rate: {}".format(lr)) + + defaults = dict(lr=lr) + + super(NGN, self).__init__(params, defaults) + + # Initialization + self._number_steps = 0 + self.state['step_size_list'] = list() # for storing the adaptive step size term + + return + + def step(self, closure: LossClosure=None, loss: torch.Tensor=None) -> OptFloat: + """ + Performs a single optimization step. + + Parameters + ---------- + closure : LossClosure, optional + A callable that evaluates the model (possibly with backprop) and returns the loss, by default None. + + loss : torch.tensor, optional + The loss tensor. Use this when the backward step has already been performed. By default None. + + + Returns + ------- + (Stochastic) Loss function value. + """ + assert (closure is not None) or (loss is not None), "Either loss tensor or closure must be passed." + assert (closure is None) or (loss is None), "Pass either the loss tensor or the closure, not both." + + if closure is not None: + with torch.enable_grad(): + loss = closure() + + if len(self.param_groups) > 1: + warnings.warn("More than one param group. step_size_list contains adaptive term of last group.") + warnings.warn("More than one param group. This might cause issues for the step method.") + + self._number_steps += 1 + + # Update + grad_norm = self.compute_grad_norm() + for group in self.param_groups: + lr = group['lr'] + denom = 1 + lr / (2*loss) * (grad_norm**2) + gamma = (lr / denom).item() + + ### Update params + for p in group['params']: + p.data.add_(other=p.grad.data, alpha=-gamma) + + self.state['step_size_list'].append(gamma) + + return loss + + @torch.no_grad() + def compute_grad_norm(self): + grad_norm = 0. 
+        for group in self.param_groups:
+            for p in group['params']:
+                assert p.grad is not None
+
+                g = p.grad.data
+                grad_norm += torch.sum(torch.mul(g, g))
+
+        grad_norm = torch.sqrt(grad_norm)
+        return grad_norm
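
Usage sketch (not part of the patch above): the new NGN optimizer is driven like any other torch.optim.Optimizer, except that step() needs either a closure or the already-backpropagated loss tensor. The snippet below is a minimal, assumption-based illustration with a throwaway linear model; names such as model, loss_fn, x and y are hypothetical and not taken from this repository.

    import torch
    from stepback.optim.ngn import NGN

    model = torch.nn.Linear(10, 1)       # hypothetical toy model
    loss_fn = torch.nn.MSELoss()
    opt = NGN(model.parameters(), lr=0.1)

    x, y = torch.randn(32, 10), torch.randn(32, 1)   # hypothetical batch

    opt.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    # Backward has already run, so pass the loss tensor instead of a closure.
    # Internally the step size is gamma = lr / (1 + lr/(2*loss) * ||grad||^2).
    opt.step(loss=loss)
    print(opt.state['step_size_list'][-1])   # adaptive step size of this step

In the experiment configs, the optimizer should be reachable via "name": "ngn" in the "opt" block (see the new branch in get_optimizer), and the new 'wsd' / 'wsd_<fraction>' lr_schedule strings use num_iter to place the cooldown start, which is why get_scheduler now takes num_iter as an argument.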