diff --git a/README.md b/README.md index 2e46203..0a79688 100644 --- a/README.md +++ b/README.md @@ -1,39 +1,126 @@ -# AutoClip: Adaptive Gradient Clipping - -This repository accompanies the [paper](https://arxiv.org/abs/2007.14469): +# AutoClip +Pytorch and tensorflow implementations (and variations) of the AutoClip gradient smoothing procedure from [Seetharaman et al](https://arxiv.org/abs/2007.14469). > Prem Seetharaman, Gordon Wichern, Bryan Pardo, Jonathan Le Roux. "AutoClip: Adaptive Gradient Clipping for Source Separation Networks." 2020 IEEE 30th International Workshop on Machine Learning for Signal Processing (MLSP). IEEE, 2020. -At the moment it contains a [sample implementation of AutoClip](autoclip.py) that can be integrated into an ML project based on PyTorch easily. -Soon it will come as a Python package that can be installed and attached to a training script more easily. +## About -## Abstract -> Clipping the gradient is a known approach to improving gradient descent, but requires hand selection of a clipping threshold hyperparameter. We present AutoClip, a simple method for automatically and adaptively choosing a gradient clipping threshold, based on the history of gradient norms observed during training. Experimental results show that applying AutoClip results in improved generalization performance for audio source separation networks. Observation of the training dynamics of a separation network trained with and without AutoClip show that AutoClip guides optimization into smoother parts of the loss landscape. AutoClip is very simple to implement and can be integrated readily into a variety of applications across multiple domains. +While training your model, AutoClip keeps a running history of all of your model's gradient magnitudes. Using these, the gradient clipper can adaptively clamp outlier gradient values before they reach the optimizer of your choice. -## Presentation +While AutoClip is great as a preventative measure against exploding gradients, it also speeds up training time, and encourages the optimizer to find more optimal models. At an intuitive level, AutoClip compensates for the stochastic nature of training over batches, regularizing training effects. -This work was presented at MLSP2020 in a special session. If you missed my talk, no worries, there's a pandemic happening so it's recorded! [Here it is](https://share.descript.com/view/18725e02-95fe-4fb0-b32d-26c63617d482). +## Installation -## Citation +AutoClip is listed on pypi. To install AutoClip simply run the following command ``` -@inproceedings{seetharaman2020autoclip, - title={AutoClip: Adaptive Gradient Clipping for Source Separation Networks}, - author={Seetharaman, Prem, and Wichern, Gordon, and Pardo, Bryan, and Le Roux, Jonathan}, - booktitle={2020 IEEE 30th International Workshop on Machine Learning for Signal Processing (MLSP)}, - year={2020}, - organization={IEEE} -} +pip install autoclip ``` +and the `autoclip` package will be installed in your currently active environment. + +## Torch API + +Below are some examples how to use `autoclip`'s torch API. + +### Clippers as Optimizer Wrappers +Using the optimizer wrapping pattern is the recommended way to use AutoClip, and `autoclip`'s torch API supports wrapping arbitrary pytorch optimizers. The wrapping pattern allows you to avoid changing your training code when you want to use an AutoClip clipper. This is especially useful if you do not own the training code for whatever reason. (Say for example you are using someone else's Trainer class, as is often the case with frameworks like `huggingface`.) +The following is an example of how to integrate AutoClip into your model training using this pattern: +```python +import torch +from autoclip.torch import QuantileClip -## Training dynamics +model = torch.nn.Sequential( + torch.nn.Linear(100, 50), + torch.nn.ReLU(), + torch.nn.Linear(50, 2) +) -### Mask-inference loss +optimizer = torch.optim.AdamW(model.parameters()) +optimizer = QuantileClip.as_optimizer(optimizer=optimizer, quantile=0.9, history_length=1000) +``` +Now you can use the optimizer just like you would have before adding the clipper, and the clipping will be applied automatically. -![](images/mi.gif) +### Raw AutoClip Clippers +You can still use the clipper manually if you would like. If this is the case, then you would create your clipper like this: +```python +import torch +from autoclip.torch import QuantileClip -### Whitened K-Means loss +model = torch.nn.Sequential( + torch.nn.Linear(100, 50), + torch.nn.ReLU(), + torch.nn.Linear(50, 2) +) -![](images/wkm.gif) +clipper = QuantileClip(model.parameters(), quantile=0.9, history_length=1000) +``` +Then, to clip the model's gradients, simply run the clipper's `.step()` function during your training loop. Note that you should call the clipper's `step` before you call your optimizer's `step`. Calling it after would mean that your clipping will have no effect, since the model will have already been updated using the unclipped gradients. For example: +```python +for batch_num, batch in enumerate(training_dataset): + model_prediction = model(batch['data']) + loss = loss_function(model_prediction, batch['targets']) + loss.backward() + clipper.step() # clipper comes before optimizer + optimizer.step() +``` -Training dynamics of a smaller mask inference network (2 BLSTM layers with 300 hidden units) with mask-inference loss and whitened k-means loss, with and without AutoClip. The top left figure shows the norm of the step size taken on the model parameters. The top right figure shows the training loss over time, showing that AutoClip leads to better optimization. The bottom figures show the relationship between gradient norm and a measure of smoothness along the training trajectory. Statistics were recorded every 20 iterations during training. With AutoClip, we observe a stronger correlation (r-value of .86), compared to without (r-value of .62). All gradients to the right of the dashed black line in the bottom right plot are clipped. We show the location of the AutoClip threshold at the end of training. The threshold changes during training. +### Global vs Local Clipping +`autoclip`'s torch clippers support two clipping modes. The first is `global_clipping`, which is the original AutoClip as described in Seetherman et al. The second is local or parameter-wise clipping. In this mode a history is kept for every parameter, and each is clipped according to its own history. By default, the `autoclip` clippers will use the parameter-wise clipping. +To use the global mode, simply pass the appropriate flag: +```python +clipper = QuantileClip( + model.parameters(), + quantile=0.9, + history_length=1000, + global_clipping=True +) +``` + +### Checkpointing +The torch clippers also support checkpointing through `state_dict()` and `load_state_dict()`, just like torch models and optimizers. For example, if you want to checkpoint a clipper to `clipper.pth`: +```python +clipper = QuantileClip(model.parameters()) +torch.save(clipper.state_dict(), 'clipper.pth') + +# Then later +clipper = QuantileClip(model.parameters()) +clipper.load_state_dict(torch.load('clipper.pth')) +``` +Keep in mind that just like a torch optimizer this will error if you give the clipper differently sized model parameters. + +While it is generally recommended to use `state_dict`s instead (see the [pytorch documentation](https://pytorch.org/tutorials/beginner/saving_loading_models.html#save-load-entire-model) on this subject for more info), you may also use `torch.save` and `torch.load` directly to pickle the entire clipper object. + +## Tensorflow +`autoclip`'s tensorflow API does not currently have feature parity with the torch API (If you want to change this, feel free to [contribute](https://github.com/HesitantlyHuman/autoclip/issues/2)). +As it is, the tensorflow API currently only supports the original AutoClip algorithm, and does not support checkpointing. Below is a short example: +```python +import tensorflow as tf +from autoclip.tf import QuantileClip + +model = tf.keras.models.Sequential( + [ + tf.keras.layers.Dense(50), + tf.keras.layers.ReLU(), + tf.keras.layers.Dense(10), + tf.keras.layers.ReLU(), + tf.keras.layers.Dense( + 2, + activation=tf.keras.activations.tanh + ), + ] +) +model.compile( + optimizer=tf.keras.optimizers.Adam( + learning_rate=0.001, + gradient_transformers=[ + QuantileClip( + quantile=0.9, + history_length=1000 + ) + ] + ), + loss="mean_absolute_error", + metrics=["accuracy"], +) +model.fit(train_data, train_targets) +``` diff --git a/autoclip.pdf b/autoclip.pdf deleted file mode 100644 index 8ae8651..0000000 Binary files a/autoclip.pdf and /dev/null differ diff --git a/autoclip.py b/autoclip.py deleted file mode 100644 index dd65df2..0000000 --- a/autoclip.py +++ /dev/null @@ -1,29 +0,0 @@ -import numpy as np -import torch -from ignite.engine import EventEnum - -def _get_grad_norm(model): - total_norm = 0 - for p in model.parameters(): - if p.grad is not None: - param_norm = p.grad.data.norm(2) - total_norm += param_norm.item() ** 2 - total_norm = total_norm ** (1. / 2) - return total_norm - -# written for pytorch ignite -# fire this on backwards pass -class BackwardsEvents(EventEnum): - BACKWARDS_COMPLETED = 'backwards_completed' - -def add_autoclip_gradient_handler(engine, model, clip_percentile): - # Keep track of the history of gradients and select a cutoff - # to clip values to based on percentile. - grad_history = [] - - @engine.on(BackwardsEvents.BACKWARDS_COMPLETED) - def autoclip_gradient(engine): - obs_grad_norm = _get_grad_norm(model) - grad_history.append(obs_grad_norm) - clip_value = np.percentile(grad_history, clip_percentile) - torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value) diff --git a/autoclip/__init__.py b/autoclip/__init__.py new file mode 100644 index 0000000..9666cc2 --- /dev/null +++ b/autoclip/__init__.py @@ -0,0 +1,3 @@ +import os + +__location__ = os.path.dirname(__file__) diff --git a/autoclip/tf/__init__.py b/autoclip/tf/__init__.py new file mode 100644 index 0000000..d5ce209 --- /dev/null +++ b/autoclip/tf/__init__.py @@ -0,0 +1 @@ +from autoclip.tf.quantile import QuantileClip diff --git a/autoclip/tf/quantile.py b/autoclip/tf/quantile.py new file mode 100644 index 0000000..9cdada3 --- /dev/null +++ b/autoclip/tf/quantile.py @@ -0,0 +1,31 @@ +import tensorflow as tf +import tensorflow_probability as tfp + + +class QuantileClip: + def __init__(self, quantile: float = 0.9, history_length: int = 1000): + self.quantile = quantile * 100 + self.grad_history = tf.Variable(tf.zeros(history_length), trainable=False) + self.i = tf.Variable(0, trainable=False) + self.history_size = history_length + + def __call__(self, grads_and_vars): + grad_norms = [self._get_grad_norm(g) for g, _ in grads_and_vars] + total_norm = tf.norm(grad_norms) + assign_idx = tf.math.mod(self.i, self.history_size) + self.grad_history = self.grad_history[assign_idx].assign(total_norm) + self.i = self.i.assign_add(1) + clip_value = tfp.stats.percentile(self.grad_history[: self.i], q=self.quantile) + return [(tf.clip_by_norm(g, clip_value), v) for g, v in grads_and_vars] + + def _get_grad_norm(self, t, axes=None, name=None): + values = tf.convert_to_tensor( + t.values if isinstance(t, tf.IndexedSlices) else t, name="t" + ) + + # Calculate L2-norm, clip elements by ratio of clip_norm to L2-norm + l2sum = tf.math.reduce_sum(values * values, axes, keepdims=True) + pred = l2sum > 0 + # Two-tap tf.where trick to bypass NaN gradients + l2sum_safe = tf.where(pred, l2sum, tf.ones_like(l2sum)) + return tf.squeeze(tf.where(pred, tf.math.sqrt(l2sum_safe), l2sum)) diff --git a/autoclip/torch/__init__.py b/autoclip/torch/__init__.py new file mode 100644 index 0000000..37dea18 --- /dev/null +++ b/autoclip/torch/__init__.py @@ -0,0 +1,2 @@ +from autoclip.torch.quantile import QuantileClip +from autoclip.torch.std import StandardClip diff --git a/autoclip/torch/clipper.py b/autoclip/torch/clipper.py new file mode 100644 index 0000000..a97c561 --- /dev/null +++ b/autoclip/torch/clipper.py @@ -0,0 +1,256 @@ +from typing import Callable, Iterator, Iterable, Dict, Mapping, Union, Any, List + +import torch +from copy import deepcopy +from itertools import chain +from collections import defaultdict +from autoclip.torch.utils import deep_tensor_move + + +class Clipper: + """ + Modeled after torch.optim.Optimizer + """ + + def __init__( + self, + parameters: Iterator[torch.nn.parameter.Parameter], + defaults: Dict[str, Any], + ) -> None: + self.parameter_groups: List[Dict[str, Any]] = [] + self.verify_parameter_settings(settings=defaults) + self.defaults = defaults + self.state = defaultdict(dict) + + if not isinstance(parameters, (Iterator, Iterable)): + raise TypeError( + "parameters argument given to the clipper should be " + "an iterable of Tensors or dicts, but instead got " + + torch.typename(parameters) + ) + + parameter_groups = list(parameters) + if len(parameter_groups) == 0: + raise ValueError( + f"Clipper {type(self).__name__} got an empty parameter list" + ) + if not isinstance(parameter_groups[0], dict): + parameter_groups = [{"params": parameter_groups}] + + for parameter_group in parameter_groups: + self.add_param_group(parameter_group=parameter_group) + + @classmethod + def as_optimizer( + cls: "Clipper", + optimizer: torch.optim.Optimizer, + **kwargs, + ) -> "OptimizerWithClipping": + parameters = chain.from_iterable( + [parameter_group["params"] for parameter_group in optimizer.param_groups] + ) + clipper = cls(parameters=parameters, **kwargs) + return OptimizerWithClipping(optimizer=optimizer, clipper=clipper) + + def verify_parameter_settings(self, settings: Dict[str, Any]) -> None: + raise NotImplementedError + + def step(self) -> None: + raise NotImplementedError + + def state_dict(self) -> Dict[str, Any]: + packed_parameter_groups = [] + for parameter_group in self.parameter_groups: + packed_parameter_group = { + key: value for key, value in parameter_group.items() if key != "params" + } + packed_parameter_group["params"] = [ + id(parameter) for parameter in parameter_group["params"] + ] + packed_parameter_groups.append(packed_parameter_group) + + packed_state = { + (id(k) if isinstance(k, torch.Tensor) else k): v + for k, v in self.state.items() + } + + return { + "state": packed_state, + "param_groups": packed_parameter_groups, + } + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + loaded_state_dict = deepcopy(state_dict) + local_groups, saved_groups = ( + self.parameter_groups, + loaded_state_dict["param_groups"], + ) + + if len(local_groups) != len(saved_groups): + raise ValueError( + f"Loaded state dict has {len(saved_groups)} parameter " + f"groups, Clipper {type(self).__name__} has " + f"{len(local_groups)} parameter groups" + ) + local_lens = (len(g["params"]) for g in local_groups) + saved_lens = (len(g["params"]) for g in saved_groups) + if any(p_len != s_len for p_len, s_len in zip(local_lens, saved_lens)): + raise ValueError( + "Loaded state dict contains a parameter group " + "that doesn't match the size of Clipper " + f"{type(self).__name__}'s group" + ) + + saved_id_to_parameter = { + saved_id: parameter + for saved_id, parameter in zip( + chain.from_iterable([group["params"] for group in saved_groups]), + chain.from_iterable([group["params"] for group in local_groups]), + ) + } + + state = defaultdict(dict) + for key, value in loaded_state_dict["state"].items(): + if key in saved_id_to_parameter: + parameter = saved_id_to_parameter[key] + state[parameter] = deep_tensor_move(value, parameter.device) + else: + state[key] = value + + new_parameter_groups = [] + for local_group, saved_group in zip(local_groups, saved_groups): + saved_group["params"] = local_group["params"] + new_parameter_groups.append(saved_group) + + self.state = state + self.parameter_groups = new_parameter_groups + + def add_param_group( + self, + parameter_group: Dict[str, Union[torch.Tensor, List[torch.Tensor]]], + **kwargs, + ) -> None: + """Add a param_group to the :class:`Optimizer` s `param_groups`. + + This can be useful when fine tuning a pre-trained network as frozen layers can be made + trainable and added to the :class:`Optimizer` as training progresses. + + Args: + param_group (dict): Specifies what Tensors should be optimized along with group + specific optimization options. + """ + if not isinstance(parameter_group, Mapping): + parameter_group = {"params": parameter_group} + + parameters = parameter_group["params"] + if isinstance(parameters, torch.Tensor): + parameter_group["params"] = [parameters] + elif isinstance(parameters, set): + raise TypeError( + "Clipping parameters must be ordered collections. " + "The ordering of tensors in sets will change between runs." + "Please use a list instead." + ) + else: + parameter_group["params"] = list(parameters) + + for parameter in parameter_group["params"]: + if not isinstance(parameter, torch.Tensor): + raise TypeError( + f"Clipper {type(self).__name__} can only clip Tensors, " + f"but one of the params is {torch.typename(parameter)}" + ) + if not parameter.is_leaf: + raise ValueError( + "Gradients to clip will only accumulate on leaf Tensors. " + f"{type(self).__name__} recieved non-leaf Tensor." + ) + + for name, default in self.defaults.items(): + parameter_group.setdefault(name, default) + parameter_group.update(kwargs) + self.verify_parameter_settings(parameter_group) + + parameters = parameter_group["params"] + if len(parameters) != len(set(parameters)): + raise ValueError( + "Clipper contains a parameter group with duplicate parameters." + ) + + parameter_set = set() + for group in self.parameter_groups: + parameter_set.update(set(group["params"])) + + if not parameter_set.isdisjoint(set(parameter_group["params"])): + raise ValueError( + "Some clipping parameters appear in more than one parameter group" + ) + + self.parameter_groups.append(parameter_group) + + def __repr__(self): + format_string = self.__class__.__name__ + " (" + for i, group in enumerate(self.parameter_groups): + format_string += "\n" + format_string += "Parameter Group {0}\n".format(i) + for key in sorted(group.keys()): + if key != "params": + format_string += " {0}: {1}\n".format(key, group[key]) + format_string += ")" + return format_string + + +class OptimizerWithClipping(torch.optim.Optimizer): + def __init__(self, optimizer: torch.optim.Optimizer, clipper: Clipper) -> None: + self.optimizer = optimizer + self.clipper = clipper + + def step( + self, closure: Union[Callable[[], float], None] = None + ) -> Union[float, None]: + self.clipper.step() + return self.optimizer.step(closure=closure) + + def add_param_group( + self, param_group: Dict[str, Union[torch.Tensor, List[torch.Tensor]]], **kwargs + ) -> None: + self.optimizer.add_param_group(param_group=param_group) + self.clipper.add_param_group(parameter_group=param_group, **kwargs) + + def zero_grad(self, set_to_none: bool = False) -> None: + self.optimizer.zero_grad(set_to_none=set_to_none) + + def __getstate__(self): + return { + "optimizer": self.optimizer, + "clipper": self.clipper, + } + + def __setstate__(self, state): + self.__dict__.update(state) + self.optimizer._hook_for_profile() + + def state_dict(self) -> Dict[str, Any]: + """Returns the state dict of the optimizer and clipper as a :class:`dict`. + + It contains two entries: + + * optimizer - a dict holding the optimizer state dict. It will conain both + state and param_groups, as described in the :class:`torch.optim.Optimizer` docs. + * clipper - a dict holding the clipper state dict. It will contain its ow + state and param_groups, as described in the :class:`autoclip.torch.Clipper` docs. + """ + return { + "optimizer": self.optimizer.state_dict(), + "clipper": self.clipper.state_dict(), + } + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + self.optimizer.load_state_dict(state_dict=state_dict["optimizer"]) + self.clipper.load_state_dict(state_dict=state_dict["clipper"]) + + def __repr__(self) -> str: + return f"OptimizerWithClipping (\n{self.optimizer}\n{self.clipper})" + + def __getattr__(self, attr): + return getattr(self.optimizer, attr) diff --git a/autoclip/torch/quantile.py b/autoclip/torch/quantile.py new file mode 100644 index 0000000..0063d7e --- /dev/null +++ b/autoclip/torch/quantile.py @@ -0,0 +1,111 @@ +from typing import Iterator, List, Dict, Union, Any +import torch + +from autoclip.torch.clipper import Clipper, OptimizerWithClipping + + +class QuantileClip(Clipper): + def __init__( + self, + parameters: Iterator[torch.nn.parameter.Parameter], + quantile: float = 0.9, + history_length: int = 1000, + global_threshold: bool = False, + ) -> None: + self.global_threshold = global_threshold + self.global_quantile = None + self.global_history_length = None + if self.global_threshold: + self.global_quantile = quantile + self.global_history_length = history_length + + super().__init__( + parameters, + {"quantile": quantile, "history_length": history_length}, + ) + + @classmethod + def as_optimizer( + cls: "QuantileClip", + optimizer: torch.optim.Optimizer, + quantile: float = 0.9, + history_length: int = 1000, + global_threshold: bool = False, + ) -> "OptimizerWithClipping": + return super().as_optimizer( + optimizer, + quantile=quantile, + history_length=history_length, + global_threshold=global_threshold, + ) + + def verify_parameter_settings(self, settings: Dict[str, Any]) -> None: + quantile = settings["quantile"] + history_length = settings["history_length"] + if not isinstance(quantile, (float, torch.Tensor)): + raise TypeError("QuantileClip quantile value must be a float or a tensor.") + if not isinstance(history_length, int): + raise TypeError("QuantileClip history_length must be an int.") + if quantile < 0.0 or quantile > 1.0: + raise ValueError("QuantileClip quantile value must be between 0.0 and 1.0.") + if history_length <= 0: + raise ValueError("QuantileClip history length must be greater than zero.") + + def step(self) -> None: + if self.global_threshold: + self._clip_global() + else: + self._clip_local() + + def _clip_local(self): + for parameter_group in self.parameter_groups: + group_quantile = parameter_group["quantile"] + group_history_length = parameter_group["history_length"] + + for parameter in parameter_group["params"]: + if parameter.grad is None: + continue + + state = self.state[parameter] + if len(state) == 0: + state["history"] = torch.Tensor([]).to(parameter.device) + threshold = torch.inf + else: + threshold = torch.quantile(state["history"], group_quantile) + new_grad_norm = torch.nn.utils.clip_grad_norm_( + parameter, max_norm=threshold + ) + state["history"] = torch.hstack((state["history"], new_grad_norm))[ + -group_history_length: + ] + + def _clip_global(self): + parameters = [] + for parameter_group in self.parameter_groups: + parameters = parameters + parameter_group["params"] + + if len(self.state["global_history"]) == 0: + # Assumes all parameters are on the same device + self.state["global_history"] = torch.Tensor([]).to(parameters[0].device) + threshold = torch.inf + else: + threshold = torch.quantile( + self.state["global_history"], self.global_quantile + ) + new_grad_norm = torch.nn.utils.clip_grad_norm_(parameters, max_norm=threshold) + self.state["global_history"] = torch.hstack( + (self.state["global_history"], new_grad_norm) + )[-self.global_history_length :] + + def add_param_group( + self, + parameter_group: Dict[str, Union[torch.Tensor, List[torch.Tensor]]], + quantile: float = None, + history_length: int = None, + ) -> None: + parameter_group_args = {} + if quantile is not None: + parameter_group_args["quantile"] = quantile + if history_length is not None: + parameter_group_args["history_length"] = history_length + return super().add_param_group(parameter_group, **parameter_group_args) diff --git a/autoclip/torch/std.py b/autoclip/torch/std.py new file mode 100644 index 0000000..3e04c53 --- /dev/null +++ b/autoclip/torch/std.py @@ -0,0 +1,115 @@ +from typing import Iterator, List, Dict, Union, Any +import torch + +from autoclip.torch.clipper import Clipper, OptimizerWithClipping + + +class StandardClip(Clipper): + def __init__( + self, + parameters: Iterator[torch.nn.parameter.Parameter], + deviations: float = 2.0, + history_length: int = 1000, + global_threshold: bool = False, + ) -> None: + self.global_threshold = global_threshold + self.global_deviations = None + self.global_history_length = None + if self.global_threshold: + self.global_deviations = deviations + self.global_history_length = history_length + + super().__init__( + parameters, + {"deviations": deviations, "history_length": history_length}, + ) + + @classmethod + def as_optimizer( + cls: "StandardClip", + optimizer: torch.optim.Optimizer, + deviations: float = 2.0, + history_length: int = 1000, + global_threshold: bool = False, + ) -> "OptimizerWithClipping": + return super().as_optimizer( + optimizer, + deviations=deviations, + history_length=history_length, + global_threshold=global_threshold, + ) + + def verify_parameter_settings(self, settings: Dict[str, Any]) -> None: + quantile = settings["deviations"] + history_length = settings["history_length"] + if not isinstance(quantile, (int, float, torch.Tensor)): + raise TypeError( + "StandardClip deviations value must be an int, float or a tensor." + ) + if not isinstance(history_length, int): + raise TypeError("StandardClip history_length must be an int.") + if quantile < 0.0: + raise ValueError( + "StandardClip deviations value must be greater than or equal to 0." + ) + if history_length <= 0: + raise ValueError("StandardClip history length must be greater than zero.") + + def step(self) -> None: + if self.global_threshold: + self._clip_global() + else: + self._clip_local() + + def _clip_local(self): + for parameter_group in self.parameter_groups: + group_deviations = parameter_group["deviations"] + group_history_length = parameter_group["history_length"] + + for parameter in parameter_group["params"]: + if parameter.grad is None: + continue + + state = self.state[parameter] + if len(state) == 0: + state["history"] = torch.Tensor([]).to(parameter.device) + threshold = torch.inf + else: + std = torch.std(state["history"]) + threshold = std * group_deviations + new_grad_norm = torch.nn.utils.clip_grad_norm_( + parameter, max_norm=threshold + ) + state["history"] = torch.hstack((state["history"], new_grad_norm))[ + -group_history_length: + ] + + def _clip_global(self): + parameters = [] + for parameter_group in self.parameter_groups: + parameters = parameters + parameter_group["params"] + + if len(self.state["global_history"]) == 0: + # Assumes all parameters are on the same device + self.state["global_history"] = torch.Tensor([]).to(parameters[0].device) + threshold = torch.inf + else: + std = torch.std(self.state["global_history"]) + threshold = std * self.global_deviations + new_grad_norm = torch.nn.utils.clip_grad_norm_(parameters, max_norm=threshold) + self.state["global_history"] = torch.hstack( + (self.state["global_history"], new_grad_norm) + )[-self.global_history_length :] + + def add_param_group( + self, + parameter_group: Dict[str, Union[torch.Tensor, List[torch.Tensor]]], + deviations: float = None, + history_length: int = None, + ) -> None: + parameter_group_args = {} + if deviations is not None: + parameter_group_args["deviations"] = deviations + if history_length is not None: + parameter_group_args["history_length"] = history_length + return super().add_param_group(parameter_group, **parameter_group_args) diff --git a/autoclip/torch/utils.py b/autoclip/torch/utils.py new file mode 100644 index 0000000..79c1e3c --- /dev/null +++ b/autoclip/torch/utils.py @@ -0,0 +1,26 @@ +from typing import Union, Mapping, Collection +import torch + + +def deep_tensor_move( + tensors: Union[Mapping, Collection, torch.Tensor], device: Union[torch.device, str] +) -> Union[Mapping, Collection, torch.Tensor]: + """ + Extracted from torch.optim.Optimizer's load_state_dict method. + """ + if isinstance(tensors, torch.Tensor): + tensors = tensors.to(device=device) + return tensors + elif isinstance(tensors, Mapping): + return type(tensors)( + { + key: deep_tensor_move(tensors=value, device=device) + for key, value in tensors.items() + } + ) + elif isinstance(tensors, Collection) and not isinstance(tensors, str): + return type(tensors)( + [deep_tensor_move(tensors=value, device=device) for value in tensors] + ) + else: + return tensors diff --git a/autoclip_tf.py b/autoclip_tf.py deleted file mode 100644 index 264ff8e..0000000 --- a/autoclip_tf.py +++ /dev/null @@ -1,55 +0,0 @@ -import tensorflow as tf -import tensorflow_probability as tfp - - -class AutoClipper: - def __init__(self, clip_percentile, history_size=10000): - self.clip_percentile = clip_percentile - self.grad_history = tf.Variable(tf.zeros(history_size), trainable=False) - self.i = tf.Variable(0, trainable=False) - self.history_size = history_size - - def __call__(self, grads_and_vars): - grad_norms = [self._get_grad_norm(g) for g, _ in grads_and_vars] - total_norm = tf.norm(grad_norms) - assign_idx = tf.math.mod(self.i, self.history_size) - self.grad_history = self.grad_history[assign_idx].assign(total_norm) - self.i = self.i.assign_add(1) - clip_value = tfp.stats.percentile(self.grad_history[: self.i], q=self.clip_percentile) - return [(tf.clip_by_norm(g, clip_value), v) for g, v in grads_and_vars] - - def _get_grad_norm(self, t, axes=None, name=None): - values = tf.convert_to_tensor(t.values if isinstance(t, tf.IndexedSlices) else t, name="t") - - # Calculate L2-norm, clip elements by ratio of clip_norm to L2-norm - l2sum = tf.math.reduce_sum(values * values, axes, keepdims=True) - pred = l2sum > 0 - # Two-tap tf.where trick to bypass NaN gradients - l2sum_safe = tf.where(pred, l2sum, tf.ones_like(l2sum)) - return tf.squeeze(tf.where(pred, tf.math.sqrt(l2sum_safe), l2sum)) - - -if __name__ == "__main__": - # Example usage - model = tf.keras.models.Sequential( - [ - tf.keras.layers.Flatten(input_shape=(28, 28)), - tf.keras.layers.Dense(128, activation="relu"), - tf.keras.layers.Dropout(0.2), - tf.keras.layers.Dense(10), - ] - ) - - model.compile( - optimizer=tf.keras.optimizers.Adam( - learning_rate=0.001, gradient_transformers=[AutoClipper(10)] - ), - loss="mean_absolute_error", - metrics=["accuracy"], - ) - - (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data() - x_train, x_test = x_train / 255.0, x_test / 255.0 - - model.fit(x_train, y_train) - diff --git a/examples/mnist.py b/examples/mnist.py new file mode 100644 index 0000000..998a949 --- /dev/null +++ b/examples/mnist.py @@ -0,0 +1,157 @@ +""" +Adapted from the pytorch mnist example found at https://github.com/pytorch/examples/blob/main/mnist/main.py +""" + +from __future__ import print_function +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torchvision import datasets, transforms +import autoclip +from autoclip.torch import QuantileClip + + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 32, 3, 1) + self.conv2 = nn.Conv2d(32, 64, 3, 1) + self.dropout1 = nn.Dropout(0.25) + self.dropout2 = nn.Dropout(0.5) + self.fc1 = nn.Linear(9216, 128) + self.fc2 = nn.Linear(128, 10) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + x = F.relu(x) + x = F.max_pool2d(x, 2) + x = self.dropout1(x) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = F.relu(x) + x = self.dropout2(x) + x = self.fc2(x) + output = F.log_softmax(x, dim=1) + return output + + +def train( + model: nn.Module, + train_loader: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + lr_scheduler: torch.optim.lr_scheduler._LRScheduler, + epoch: int, + device: torch.device = torch.device("cuda"), + log_interval: int = 10, +): + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + output = model(data) + loss = F.nll_loss(output, target) + loss.backward() + optimizer.step() + lr_scheduler.step() + if batch_idx % log_interval == 0: + print( + "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( + epoch, + batch_idx * len(data), + len(train_loader.dataset), + 100.0 * batch_idx / len(train_loader), + loss.item(), + ) + ) + + +def test( + model: nn.Module, + test_loader: torch.utils.data.DataLoader, + device: torch.device = torch.device("cuda"), +): + model.eval() + test_loss = 0 + correct = 0 + with torch.no_grad(): + for data, target in test_loader: + data, target = data.to(device), target.to(device) + output = model(data) + test_loss += F.nll_loss( + output, target, reduction="sum" + ).item() # sum up batch loss + pred = output.argmax( + dim=1, keepdim=True + ) # get the index of the max log-probability + correct += pred.eq(target.view_as(pred)).sum().item() + + test_loss /= len(test_loader.dataset) + + print( + "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( + test_loss, + correct, + len(test_loader.dataset), + 100.0 * correct / len(test_loader.dataset), + ) + ) + + +def main(): + config = {"num_epochs": 14, "max_learning_rate": 1e-3, "weight_decay": 0.05} + train_kwargs = {"batch_size": 64} + test_kwargs = {"batch_size": 1000} + if torch.cuda.is_available(): + device = torch.device("cuda") + cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True} + train_kwargs.update(cuda_kwargs) + test_kwargs.update(cuda_kwargs) + else: + device = torch.device("cpu") + + transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] + ) + dataset1 = datasets.MNIST("../data", train=True, download=True, transform=transform) + dataset2 = datasets.MNIST("../data", train=False, transform=transform) + train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs) + test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs) + + model = Net().to(device) + optimizer = optim.AdamW( + model.parameters(), + lr=config["max_learning_rate"], + weight_decay=config["weight_decay"], + ) + optimizer = QuantileClip.as_optimizer( + optimizer=optimizer, + quantile=0.8, + history_length=1000, + ) + scheduler = torch.optim.lr_scheduler.OneCycleLR( + optimizer, + max_lr=config["max_learning_rate"], + steps_per_epoch=len(train_loader), + epochs=config["num_epochs"], + ) + + for epoch in range(1, config["num_epochs"] + 1): + train( + model=model, + train_loader=train_loader, + optimizer=optimizer, + lr_scheduler=scheduler, + epoch=epoch, + device=device, + log_interval=10, + ) + test(model=model, test_loader=test_loader, device=device) + + torch.save(model.state_dict(), "mnist_cnn.pth") + + +if __name__ == "__main__": + main() diff --git a/images/mi.gif b/images/mi.gif deleted file mode 100644 index d2dc896..0000000 Binary files a/images/mi.gif and /dev/null differ diff --git a/images/wkm.gif b/images/wkm.gif deleted file mode 100644 index 5e24039..0000000 Binary files a/images/wkm.gif and /dev/null differ diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..85d03f5 --- /dev/null +++ b/setup.py @@ -0,0 +1,19 @@ +import os +from setuptools import setup, find_packages + +PACKAGE_ROOT = os.path.dirname(os.path.realpath(__file__)) +README_FILE = open(os.path.join(PACKAGE_ROOT, "README.md"), "r").read() + +if __name__ == "__main__": + setup( + name="autoclip", + version="0.2.1", + description="Smart gradient clippers", + long_description=README_FILE, + long_description_content_type="text/markdown", + url="https://github.com/HesitantlyHuman/autoclip", + author="HesitantlyHuman", + author_email="tannersims@hesitantlyhuman.com", + license="MIT", + packages=find_packages(), + ) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..007fa40 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,4 @@ +import sys +import os + +sys.path.append(os.path.join(os.path.dirname(__file__), "utilities")) diff --git a/tests/torch/test_quantile.py b/tests/torch/test_quantile.py new file mode 100644 index 0000000..d6d5748 --- /dev/null +++ b/tests/torch/test_quantile.py @@ -0,0 +1,203 @@ +import torch +import pytest +from autoclip.torch.quantile import QuantileClip +from utilities.torch import ( + run_clipper_initialization_error_tests, + run_clipping_test, + run_clipping_test_wrapper, + run_add_param_group, +) + + +def test_create_clipper(): + example_model = torch.nn.Sequential( + torch.nn.Linear(10, 10), + torch.nn.ReLU(), + torch.nn.Linear(10, 1), + ) + clipper = QuantileClip(example_model.parameters()) + clipper = QuantileClip(example_model.parameters(), quantile=0.5, history_length=500) + clipper = QuantileClip(example_model.parameters(), quantile=0.0) + clipper = QuantileClip(example_model.parameters(), quantile=1.0) + + +def test_clipper_parameters_not_parameters(): + with pytest.raises(TypeError): + clipper = QuantileClip(1.0) + + with pytest.raises(TypeError): + clipper = QuantileClip(torch.nn.Linear(10, 10)) + + +def test_bad_history_values(): + run_clipper_initialization_error_tests( + QuantileClip, + "history_length", + [-1.0, 5.0, 0, -100], + [TypeError, TypeError, ValueError, ValueError], + ) + + +def test_bad_quantile_values(): + run_clipper_initialization_error_tests( + QuantileClip, "quantile", [-1.0, 5.0], [ValueError, ValueError] + ) + + +def test_create_optimizer_wrapper(): + example_model = torch.nn.Sequential( + torch.nn.Linear(10, 10), + torch.nn.ReLU(), + torch.nn.Linear(10, 1), + ) + optimizer = torch.optim.AdamW(example_model.parameters()) + clipper = QuantileClip.as_optimizer(optimizer=optimizer) + clipper = QuantileClip.as_optimizer( + optimizer=optimizer, quantile=0.5, history_length=500 + ) + clipper = QuantileClip.as_optimizer(optimizer=optimizer, quantile=0.0) + clipper = QuantileClip.as_optimizer(optimizer=optimizer, quantile=1.0) + + optimizer = torch.optim.LBFGS(example_model.parameters()) + clipper = QuantileClip.as_optimizer(optimizer=optimizer) + clipper = QuantileClip.as_optimizer( + optimizer=optimizer, quantile=0.5, history_length=500 + ) + clipper = QuantileClip.as_optimizer(optimizer=optimizer, quantile=0.0) + clipper = QuantileClip.as_optimizer(optimizer=optimizer, quantile=1.0) + + +def test_clip_local(): + run_clipping_test(QuantileClip, {}) + run_clipping_test(QuantileClip, {"quantile": 0.5}) + run_clipping_test(QuantileClip, {"history_length": 500}) + run_clipping_test(QuantileClip, {"quantile": 0.5, "history_length": 500}) + + +def test_clip_global(): + run_clipping_test(QuantileClip, {"global_threshold": True}) + run_clipping_test(QuantileClip, {"quantile": 0.5, "global_threshold": True}) + run_clipping_test(QuantileClip, {"history_length": 500, "global_threshold": True}) + run_clipping_test( + QuantileClip, {"quantile": 0.5, "history_length": 500, "global_threshold": True} + ) + + +def test_clip_wrapper(): + run_clipping_test_wrapper(QuantileClip, {}) + run_clipping_test_wrapper(QuantileClip, {"quantile": 0.5}) + run_clipping_test_wrapper(QuantileClip, {"history_length": 500}) + run_clipping_test_wrapper(QuantileClip, {"quantile": 0.5, "history_length": 500}) + + +def test_add_param_group(): + run_add_param_group(QuantileClip, {}) + run_add_param_group(QuantileClip, {"quantile": 0.5}) + run_add_param_group(QuantileClip, {"history_length": 500}) + run_add_param_group(QuantileClip, {"quantile": 0.5, "history_length": 500}) + + +def test_save_state_dict(): + example_model = torch.nn.Sequential( + torch.nn.Linear(10, 10), + torch.nn.ReLU(), + torch.nn.Linear(10, 1), + ) + clipper = QuantileClip(example_model.parameters()) + state_dict = clipper.state_dict() + clipper = QuantileClip(example_model.parameters(), global_threshold=True) + state_dict = clipper.state_dict() + + +def test_save_state_dict_wrapper(): + example_model = torch.nn.Sequential( + torch.nn.Linear(10, 10), + torch.nn.ReLU(), + torch.nn.Linear(10, 1), + ) + optimizer = torch.optim.AdamW(example_model.parameters()) + clipper = QuantileClip.as_optimizer(optimizer=optimizer) + state_dict = clipper.state_dict() + + +def test_load_state_dict(): + example_model = torch.nn.Sequential( + torch.nn.Linear(10, 10), + torch.nn.ReLU(), + torch.nn.Linear(10, 1), + ) + clipper = QuantileClip(example_model.parameters()) + state_dict = clipper.state_dict() + example_model = torch.nn.Sequential( + torch.nn.Linear(10, 10), + torch.nn.ReLU(), + torch.nn.Linear(10, 1), + ) + clipper = QuantileClip(example_model.parameters()) + clipper.load_state_dict(state_dict=state_dict) + + +def test_load_state_dict_wrapper(): + example_model = torch.nn.Sequential( + torch.nn.Linear(10, 10), + torch.nn.ReLU(), + torch.nn.Linear(10, 1), + ) + optimizer = torch.optim.AdamW(example_model.parameters()) + clipper = QuantileClip.as_optimizer(optimizer=optimizer) + state_dict = clipper.state_dict() + example_model = torch.nn.Sequential( + torch.nn.Linear(10, 10), + torch.nn.ReLU(), + torch.nn.Linear(10, 1), + ) + optimizer = torch.optim.AdamW(example_model.parameters()) + clipper = QuantileClip.as_optimizer(optimizer=optimizer) + clipper.load_state_dict(state_dict=state_dict) + + +def test_clip_after_state_dict_load(): + example_model = torch.nn.Sequential( + torch.nn.Linear(10, 10), + torch.nn.ReLU(), + torch.nn.Linear(10, 1), + ) + clipper = QuantileClip(example_model.parameters()) + state_dict = clipper.state_dict() + example_model = torch.nn.Sequential( + torch.nn.Linear(10, 10), + torch.nn.ReLU(), + torch.nn.Linear(10, 1), + ) + clipper = QuantileClip(example_model.parameters()) + clipper.load_state_dict(state_dict=state_dict) + loss_fn = torch.nn.MSELoss() + prediction = example_model(torch.rand((10))) + target = torch.Tensor([10.0]) + loss = loss_fn(prediction, target) + loss.backward() + clipper.step() + + +def test_pickle_optimizer_wrapper(): + import io + + example_model = torch.nn.Linear(10, 1) + optimizer = torch.optim.AdamW(example_model.parameters()) + clipper = QuantileClip.as_optimizer(optimizer=optimizer) + buffer = io.BytesIO() + torch.save(clipper, buffer) + buffer.seek(0) + clipper = torch.load(buffer) + clipper.optimizer + + +def test_pickle_clipper(): + import io + + example_model = torch.nn.Linear(10, 1) + clipper = QuantileClip(example_model.parameters()) + buffer = io.BytesIO() + torch.save(clipper, buffer) + buffer.seek(0) + clipper = torch.load(buffer) diff --git a/tests/torch/test_std.py b/tests/torch/test_std.py new file mode 100644 index 0000000..d53491a --- /dev/null +++ b/tests/torch/test_std.py @@ -0,0 +1,206 @@ +import torch +import pytest +from autoclip.torch import StandardClip +from utilities.torch import ( + run_clipper_initialization_error_tests, + run_clipping_test, + run_clipping_test_wrapper, + run_add_param_group, +) + + +def test_create_clipper(): + example_model = torch.nn.Sequential( + torch.nn.Linear(10, 10), + torch.nn.ReLU(), + torch.nn.Linear(10, 1), + ) + clipper = StandardClip(example_model.parameters()) + clipper = StandardClip( + example_model.parameters(), deviations=0.5, history_length=500 + ) + clipper = StandardClip(example_model.parameters(), deviations=0.0) + clipper = StandardClip(example_model.parameters(), deviations=1.0) + + +def test_clipper_parameters_not_parameters(): + with pytest.raises(TypeError): + clipper = StandardClip(1.0) + + with pytest.raises(TypeError): + clipper = StandardClip(torch.nn.Linear(10, 10)) + + +def test_bad_history_values(): + run_clipper_initialization_error_tests( + StandardClip, + "history_length", + [-1.0, 5.0, 0, -100], + [TypeError, TypeError, ValueError, ValueError], + ) + + +def test_bad_deviations_values(): + run_clipper_initialization_error_tests( + StandardClip, "deviations", [-1.0, -0.5], [ValueError, ValueError] + ) + + +def test_create_optimizer_wrapper(): + example_model = torch.nn.Sequential( + torch.nn.Linear(10, 10), + torch.nn.ReLU(), + torch.nn.Linear(10, 1), + ) + optimizer = torch.optim.AdamW(example_model.parameters()) + clipper = StandardClip.as_optimizer(optimizer=optimizer) + clipper = StandardClip.as_optimizer( + optimizer=optimizer, deviations=0.5, history_length=500 + ) + clipper = StandardClip.as_optimizer(optimizer=optimizer, deviations=0.0) + clipper = StandardClip.as_optimizer(optimizer=optimizer, deviations=1.0) + + optimizer = torch.optim.LBFGS(example_model.parameters()) + clipper = StandardClip.as_optimizer(optimizer=optimizer) + clipper = StandardClip.as_optimizer( + optimizer=optimizer, deviations=0.5, history_length=500 + ) + clipper = StandardClip.as_optimizer(optimizer=optimizer, deviations=0.0) + clipper = StandardClip.as_optimizer(optimizer=optimizer, deviations=1.0) + + +def test_clip_local(): + run_clipping_test(StandardClip, {}) + run_clipping_test(StandardClip, {"deviations": 0.5}) + run_clipping_test(StandardClip, {"history_length": 500}) + run_clipping_test(StandardClip, {"deviations": 0.5, "history_length": 500}) + + +def test_clip_global(): + run_clipping_test(StandardClip, {"global_threshold": True}) + run_clipping_test(StandardClip, {"deviations": 0.5, "global_threshold": True}) + run_clipping_test(StandardClip, {"history_length": 500, "global_threshold": True}) + run_clipping_test( + StandardClip, + {"deviations": 0.5, "history_length": 500, "global_threshold": True}, + ) + + +def test_clip_wrapper(): + run_clipping_test_wrapper(StandardClip, {}) + run_clipping_test_wrapper(StandardClip, {"deviations": 0.5}) + run_clipping_test_wrapper(StandardClip, {"history_length": 500}) + run_clipping_test_wrapper(StandardClip, {"deviations": 0.5, "history_length": 500}) + + +def test_add_param_group(): + run_add_param_group(StandardClip, {}) + run_add_param_group(StandardClip, {"deviations": 0.5}) + run_add_param_group(StandardClip, {"history_length": 500}) + run_add_param_group(StandardClip, {"deviations": 0.5, "history_length": 500}) + + +def test_save_state_dict(): + example_model = torch.nn.Sequential( + torch.nn.Linear(10, 10), + torch.nn.ReLU(), + torch.nn.Linear(10, 1), + ) + clipper = StandardClip(example_model.parameters()) + state_dict = clipper.state_dict() + clipper = StandardClip(example_model.parameters(), global_threshold=True) + state_dict = clipper.state_dict() + + +def test_save_state_dict_wrapper(): + example_model = torch.nn.Sequential( + torch.nn.Linear(10, 10), + torch.nn.ReLU(), + torch.nn.Linear(10, 1), + ) + optimizer = torch.optim.AdamW(example_model.parameters()) + clipper = StandardClip.as_optimizer(optimizer=optimizer) + state_dict = clipper.state_dict() + + +def test_load_state_dict(): + example_model = torch.nn.Sequential( + torch.nn.Linear(10, 10), + torch.nn.ReLU(), + torch.nn.Linear(10, 1), + ) + clipper = StandardClip(example_model.parameters()) + state_dict = clipper.state_dict() + example_model = torch.nn.Sequential( + torch.nn.Linear(10, 10), + torch.nn.ReLU(), + torch.nn.Linear(10, 1), + ) + clipper = StandardClip(example_model.parameters()) + clipper.load_state_dict(state_dict=state_dict) + + +def test_load_state_dict_wrapper(): + example_model = torch.nn.Sequential( + torch.nn.Linear(10, 10), + torch.nn.ReLU(), + torch.nn.Linear(10, 1), + ) + optimizer = torch.optim.AdamW(example_model.parameters()) + clipper = StandardClip.as_optimizer(optimizer=optimizer) + state_dict = clipper.state_dict() + example_model = torch.nn.Sequential( + torch.nn.Linear(10, 10), + torch.nn.ReLU(), + torch.nn.Linear(10, 1), + ) + optimizer = torch.optim.AdamW(example_model.parameters()) + clipper = StandardClip.as_optimizer(optimizer=optimizer) + clipper.load_state_dict(state_dict=state_dict) + + +def test_clip_after_state_dict_load(): + example_model = torch.nn.Sequential( + torch.nn.Linear(10, 10), + torch.nn.ReLU(), + torch.nn.Linear(10, 1), + ) + clipper = StandardClip(example_model.parameters()) + state_dict = clipper.state_dict() + example_model = torch.nn.Sequential( + torch.nn.Linear(10, 10), + torch.nn.ReLU(), + torch.nn.Linear(10, 1), + ) + clipper = StandardClip(example_model.parameters()) + clipper.load_state_dict(state_dict=state_dict) + loss_fn = torch.nn.MSELoss() + prediction = example_model(torch.rand((10))) + target = torch.Tensor([10.0]) + loss = loss_fn(prediction, target) + loss.backward() + clipper.step() + + +def test_pickle_optimizer_wrapper(): + import io + + example_model = torch.nn.Linear(10, 1) + optimizer = torch.optim.AdamW(example_model.parameters()) + clipper = StandardClip.as_optimizer(optimizer=optimizer) + buffer = io.BytesIO() + torch.save(clipper, buffer) + buffer.seek(0) + clipper = torch.load(buffer) + clipper.optimizer + + +def test_pickle_clipper(): + import io + + example_model = torch.nn.Linear(10, 1) + clipper = StandardClip(example_model.parameters()) + buffer = io.BytesIO() + torch.save(clipper, buffer) + buffer.seek(0) + clipper = torch.load(buffer) diff --git a/tests/torch/test_utils.py b/tests/torch/test_utils.py new file mode 100644 index 0000000..abff837 --- /dev/null +++ b/tests/torch/test_utils.py @@ -0,0 +1,44 @@ +import torch + +from autoclip.torch.utils import deep_tensor_move + + +def test_deep_tensor_move_dicts(): + structure = { + "some_value": torch.rand((10, 10)), + "another_value": {"some_nested_thing": torch.rand(5)}, + } + deep_tensor_move(structure, "cpu") + deep_tensor_move(structure, torch.device("cpu")) + + +def test_deep_tensor_move_lists(): + structure = [torch.rand((6, 12, 2)), [torch.rand(5), torch.rand(10, 4)]] + deep_tensor_move(structure, "cpu") + deep_tensor_move(structure, torch.device("cpu")) + + +def test_deep_tensor_move_tuples(): + structure = (torch.rand((6, 12, 2)), (torch.rand(5), torch.rand(10, 4))) + deep_tensor_move(structure, "cpu") + deep_tensor_move(structure, torch.device("cpu")) + + +def test_deep_tensor_move_tensors(): + structure = torch.rand((1, 2, 1, 6, 3)) + deep_tensor_move(structure, "cpu") + deep_tensor_move(structure, torch.device("cpu")) + + +def test_deep_tensor_move_non_tensors(): + structure = { + "value": 1.0, + "list": [ + "string", + { + "some-value": torch.rand((6, 7)), + }, + ], + } + deep_tensor_move(structure, "cpu") + deep_tensor_move(structure, torch.device("cpu")) diff --git a/tests/utilities/torch.py b/tests/utilities/torch.py new file mode 100644 index 0000000..14b9c72 --- /dev/null +++ b/tests/utilities/torch.py @@ -0,0 +1,73 @@ +import torch +import autoclip +import pytest +from autoclip.torch.clipper import Clipper +from autoclip.torch import QuantileClip, StandardClip + + +def run_clipper_initialization_error_tests( + clipper_type: Clipper, value_name: str, bad_values: list, error_types: list +): + example_model = torch.nn.Sequential( + torch.nn.Linear(10, 10), + torch.nn.ReLU(), + torch.nn.Linear(10, 1), + ) + for bad_value, error_type in zip(bad_values, error_types): + with pytest.raises(error_type) as _: + clipper = clipper_type( + example_model.parameters(), **{value_name: bad_value} + ) + + with pytest.raises(error_type): + additional_parameters = torch.nn.Sequential(torch.nn.Linear(10, 10)) + clipper = clipper_type(example_model.parameters()) + clipper.add_param_group( + additional_parameters.parameters(), **{value_name: bad_value} + ) + + +def run_clipping_test(clipper: Clipper, clipper_args: dict): + example_model = torch.nn.Sequential( + torch.nn.Linear(10, 10), + torch.nn.ReLU(), + torch.nn.Linear(10, 1), + ) + clipper = clipper(example_model.parameters(), **clipper_args) + loss_fn = torch.nn.MSELoss() + prediction = example_model(torch.rand((10))) + target = torch.Tensor([10.0]) + loss = loss_fn(prediction, target) + loss.backward() + clipper.step() + + +def run_clipping_test_wrapper(clipper: Clipper, clipper_args: dict): + example_model = torch.nn.Sequential( + torch.nn.Linear(10, 10), + torch.nn.ReLU(), + torch.nn.Linear(10, 1), + ) + loss_fn = torch.nn.MSELoss() + optimizer = torch.optim.AdamW(example_model.parameters()) + optimizer = clipper.as_optimizer(optimizer, **clipper_args) + prediction = example_model(torch.rand((10))) + target = torch.Tensor([10.0]) + loss = loss_fn(prediction, target) + loss.backward() + optimizer.step() + + +def run_add_param_group(clipper: Clipper, clipper_args: dict): + example_model = torch.nn.Sequential( + torch.nn.Linear(10, 10), + torch.nn.ReLU(), + torch.nn.Linear(10, 1), + ) + additional_model = torch.nn.Sequential( + torch.nn.Linear(10, 10), + torch.nn.ReLU(), + torch.nn.Linear(10, 1), + ) + clipper = clipper(example_model.parameters()) + clipper.add_param_group(additional_model.parameters(), **clipper_args)