From c24344e41a0e60c115aee782b0bb2b9543dee941 Mon Sep 17 00:00:00 2001
From: smallred-ops <62491096+smallred-ops@users.noreply.github.com>
Date: Wed, 1 Apr 2020 21:38:41 +0800
Subject: [PATCH] Add files via upload

---
 main.py         | 176 +++++++++++++++++++++
 model.py        |  60 ++++++++
 optimizer.py    | 402 ++++++++++++++++++++++++++++++++++++++++++++++++
 search_space.py |  55 +++++++
 utils.py        | 185 ++++++++++++++++++++++
 5 files changed, 878 insertions(+)
 create mode 100644 main.py
 create mode 100644 model.py
 create mode 100644 optimizer.py
 create mode 100644 search_space.py
 create mode 100644 utils.py

diff --git a/main.py b/main.py
new file mode 100644
index 0000000..2c69bf9
--- /dev/null
+++ b/main.py
@@ -0,0 +1,176 @@
+from __future__ import print_function
+import argparse
+import torch
+import torch.nn.functional as F
+from optimizer import PruneAdam
+from model import LeNet, AlexNet
+from utils import regularized_nll_loss, admm_loss, \
+    initialize_Z_and_U, update_X, update_Z, update_Z_l1, update_U, \
+    print_convergence, print_prune, apply_l1_prune, apply_pattern_prune
+from torchvision import datasets, transforms
+from tqdm import tqdm
+
+
+def train(args, model, device, train_loader, test_loader, optimizer):
+    for epoch in range(args.num_pre_epochs):  # pre-training epochs
+        print('Pre epoch: {}'.format(epoch + 1))
+        # for pa in model.named_parameters():
+        #     print(pa[0])
+        model.train()  # switch to training mode
+        for batch_idx, (data, target) in enumerate(tqdm(train_loader)):
+            data, target = data.to(device), target.to(device)
+            optimizer.zero_grad()
+            output = model(data)
+            loss = regularized_nll_loss(args, model, output, target)
+            loss.backward()
+            optimizer.step()
+        test(args, model, device, test_loader)
+
+    Z, U = initialize_Z_and_U(model)
+    for epoch in range(args.num_epochs):  # ADMM training epochs
+        model.train()
+        print('Epoch: {}'.format(epoch + 1))
+        for batch_idx, (data, target) in enumerate(tqdm(train_loader)):
+            data, target = data.to(device), target.to(device)
+            optimizer.zero_grad()
+            output = model(data)
+            loss = admm_loss(args, device, model, Z, U, output, target)
+            loss.backward()
+            optimizer.step()
+        X = update_X(model)
+        Z = update_Z_l1(X, U, args) if args.l1 else update_Z(X, U, args)
+        U = update_U(U, X, Z)
+        print_convergence(model, X, Z)
+        test(args, model, device, test_loader)
+
+
+def test(args, model, device, test_loader):
+    model.eval()
+    test_loss = 0
+    correct = 0
+    with torch.no_grad():
+        for data, target in test_loader:
+            data, target = data.to(device), target.to(device)
+            output = model(data)
+            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
+            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
+            correct += pred.eq(target.view_as(pred)).sum().item()
+
+    test_loss /= len(test_loader.dataset)
+
+    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
+        test_loss, correct, len(test_loader.dataset),
+        100. * correct / len(test_loader.dataset)))
+
+
+def retrain(args, model, mask, device, train_loader, test_loader, optimizer):
+    for epoch in range(args.num_re_epochs):
+        print('Re epoch: {}'.format(epoch + 1))
+        model.train()
+        for batch_idx, (data, target) in enumerate(tqdm(train_loader)):
+            data, target = data.to(device), target.to(device)
+            optimizer.zero_grad()
+            output = model(data)
+            loss = F.nll_loss(output, target)
+            loss.backward()
+            optimizer.prune_step(mask)
+
+        test(args, model, device, test_loader)
+
+
+def main():
+    # Training settings
+    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
+    parser.add_argument('--dataset', type=str, default="mnist", choices=["mnist", "cifar10"],
+                        metavar='D', help='training dataset (mnist or cifar10)')
+    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
+                        help='input batch size for training (default: 64)')
+    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
+                        help='input batch size for testing (default: 1000)')
+    parser.add_argument('--percent', type=list, default=[0.8, 0.92, 0.991, 0.93],
+                        metavar='P', help='per-layer pruning percentages (default: [0.8, 0.92, 0.991, 0.93])')
+    parser.add_argument('--alpha', type=float, default=5e-4, metavar='L',
+                        help='l2 norm weight (default: 5e-4)')
+    parser.add_argument('--rho', type=float, default=1e-2, metavar='R',
+                        help='cardinality weight (default: 1e-2)')
+    parser.add_argument('--l1', default=False, action='store_true',
+                        help='prune weights with l1 regularization instead of cardinality')
+    parser.add_argument('--l2', default=False, action='store_true',
+                        help='apply l2 regularization')
+    parser.add_argument('--num_pre_epochs', type=int, default=3, metavar='P',
+                        help='number of epochs to pretrain (default: 3)')
+    parser.add_argument('--num_epochs', type=int, default=10, metavar='N',
+                        help='number of epochs to train (default: 10)')
+    parser.add_argument('--num_re_epochs', type=int, default=10, metavar='R',
+                        help='number of epochs to retrain (default: 10)')
+    parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
+                        help='learning rate (default: 1e-3)')
+    parser.add_argument('--adam_epsilon', type=float, default=1e-8, metavar='E',
+                        help='adam epsilon (default: 1e-8)')
+    parser.add_argument('--no-cuda', action='store_true', default=False,
+                        help='disables CUDA training')
+    parser.add_argument('--seed', type=int, default=1, metavar='S',
+                        help='random seed (default: 1)')
+    parser.add_argument('--save-model', action='store_true', default=False,
+                        help='For Saving the current Model')
+    args = parser.parse_args()
+
+    use_cuda = not args.no_cuda and torch.cuda.is_available()
+
+    torch.manual_seed(args.seed)  # seed the RNG for reproducibility
+
+    device = torch.device("cuda" if use_cuda else "cpu")
+
+    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}  # DataLoader options when CUDA is available
+
+    if args.dataset == "mnist":  # load the MNIST dataset
+        train_loader = torch.utils.data.DataLoader(
+            datasets.MNIST('data', train=True, download=True,
+                           transform=transforms.Compose([
+                               transforms.ToTensor(),
+                               transforms.Normalize((0.1307,), (0.3081,))  # single-channel mean and std
+                           ])),
+            batch_size=args.batch_size, shuffle=True, **kwargs)
+
+        test_loader = torch.utils.data.DataLoader(
+            datasets.MNIST('data', train=False, transform=transforms.Compose([
+                transforms.ToTensor(),
+                transforms.Normalize((0.1307,), (0.3081,))
+            ])),
+            batch_size=args.test_batch_size, shuffle=True, **kwargs)
+
+    else:  # load the CIFAR-10 dataset
+        args.percent = [0.8, 0.92, 0.93, 0.94, 0.95, 0.99, 0.99, 0.93]  # per-layer pruning percentages
+        # epochs for the three stages: pre-training, ADMM training, retraining
+        args.num_pre_epochs = 5
+        args.num_epochs = 20
+        args.num_re_epochs = 5
+        train_loader = torch.utils.data.DataLoader(
+            datasets.CIFAR10('data', train=True, download=True,
+                             transform=transforms.Compose([
+                                 transforms.ToTensor(),
+                                 transforms.Normalize((0.49139968, 0.48215827, 0.44653124),
+                                                      (0.24703233, 0.24348505, 0.26158768))  # per-channel mean and std
+                             ])), shuffle=True, batch_size=args.batch_size, **kwargs)
+
+        test_loader = torch.utils.data.DataLoader(
+            datasets.CIFAR10('data', train=False, download=True,
+                             transform=transforms.Compose([
+                                 transforms.ToTensor(),
+                                 transforms.Normalize((0.49139968, 0.48215827, 0.44653124),
+                                                      (0.24703233, 0.24348505, 0.26158768))
+                             ])), shuffle=True, batch_size=args.test_batch_size, **kwargs)
+
+    model = LeNet().to(device) if args.dataset == "mnist" else AlexNet().to(device)  # LeNet for MNIST, AlexNet for CIFAR-10
+    optimizer = PruneAdam(model.named_parameters(), lr=args.lr, eps=args.adam_epsilon)
+    # the optimizer receives (name, parameter) pairs so that pruned weights can be kept frozen
+    # train(args, model, device, train_loader, test_loader, optimizer)  # pre-training and ADMM training
+    mask = apply_pattern_prune(model, device, '18')  # pattern key: '0' for no pruning, or an ascending combination of positions 1-9
+    # print(mask)
+    print_prune(model)  # report how the model was pruned
+    test(args, model, device, test_loader)  # evaluate the pruned model
+    retrain(args, model, mask, device, train_loader, test_loader, optimizer)  # retrain the pruned model with the mask applied
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..d57e20e
--- /dev/null
+++ b/model.py
@@ -0,0 +1,60 @@
+import torch.nn as nn
+import torch.nn.functional as F
+# network definitions: LeNet (MNIST) and AlexNet (CIFAR-10)
+
+class LeNet(nn.Module):
+    def __init__(self):
+        super(LeNet, self).__init__()  # input: 28*28
+        self.conv1 = nn.Conv2d(1, 20, 3, 1)  # 3*3 kernels
+        self.conv2 = nn.Conv2d(20, 50, 3, 1)
+        self.fc1 = nn.Linear(5*5*50, 500)
+        self.fc2 = nn.Linear(500, 10)
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))  # output: 26*26*20
+        x = F.max_pool2d(x, 2, 2)  # output: 13*13*20
+        x = F.relu(self.conv2(x))  # output: 11*11*50
+        x = F.max_pool2d(x, 2, 2)  # output: 5*5*50
+        x = x.view(-1, 5*5*50)
+        x = F.relu(self.fc1(x))  # output: 500
+        x = self.fc2(x)  # output: 10
+        return F.log_softmax(x, dim=1)
+
+
+class AlexNet(nn.Module):
+    def __init__(self):
+        super(AlexNet, self).__init__()
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1)  # CIFAR-10 images have 3 channels
+        # print(self.conv1.weight[0][2], len(self.conv1.weight))
+
+        self.features = nn.Sequential(
+            self.conv1,
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=2),
+            nn.Conv2d(64, 192, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=2),
+            nn.Conv2d(192, 384, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(384, 256, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(256, 256, kernel_size=3, padding=1),
+            nn.ReLU(inplace=True),
+            nn.MaxPool2d(kernel_size=2),
+        )
+
+        self.classifier = nn.Sequential(
+            nn.Dropout(),
+            nn.Linear(256 * 2 * 2, 4096),  # a 32*32 input is downsampled to 2*2 by the stride-2 conv and three poolings
+            nn.ReLU(inplace=True),
+            nn.Dropout(),
+            nn.Linear(4096, 4096),
+            nn.ReLU(inplace=True),
+            nn.Linear(4096, 10),
+        )
+
+    def forward(self, x):
+        x = self.features(x)
+        x = x.view(x.shape[0], -1)
+        x = self.classifier(x)
+        return F.log_softmax(x, dim=1)
diff --git a/optimizer.py b/optimizer.py
new file mode 100644
index 0000000..464d22a
--- /dev/null
+++ b/optimizer.py
@@ -0,0 +1,402 @@
+"""
+This code is adapted from the official PyTorch optimizer (https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html).
+The optimizer is modified to track parameter names so that pruned weights are not updated by gradients.
+"""
+
+import math
+from collections import defaultdict
+from torch._six import container_abcs
+import torch
+from copy import deepcopy
+from itertools import chain
+
+
+class _RequiredParameter(object):
+    """Singleton class representing a required parameter for an Optimizer."""
+    def __repr__(self):
+        return "<required parameter>"
+
+required = _RequiredParameter()
+
+
+class NameOptimizer(object):
+    r"""Base class for all optimizers.
+
+    .. warning::
+        Parameters need to be specified as collections that have a deterministic
+        ordering that is consistent between runs. Examples of objects that don't
+        satisfy those properties are sets and iterators over values of dictionaries.
+
+    Arguments:
+        named_params (iterable): an iterable of (name, :class:`torch.Tensor`) pairs or
+            :class:`dict` s. Specifies what Tensors should be optimized.
+        defaults: (dict): a dict containing default values of optimization
+            options (used when a parameter group doesn't specify them).
+    """
+
+    def __init__(self, named_params, defaults):
+        self.defaults = defaults
+
+        if isinstance(named_params, torch.Tensor):
+            raise TypeError("params argument given to the optimizer should be "
+                            "an iterable of Tensors or dicts, but got " +
+                            torch.typename(named_params))
+
+        self.state = defaultdict(dict)
+        self.param_groups = []
+
+        param_groups = list(named_params)
+        if len(param_groups) == 0:
+            raise ValueError("optimizer got an empty parameter list")
+        if not isinstance(param_groups[0], dict):
+            param_groups = [{'params': param_groups}]
+
+        for param_group in param_groups:
+            self.add_param_group(param_group)
+
+    def __getstate__(self):
+        return {
+            'defaults': self.defaults,
+            'state': self.state,
+            'param_groups': self.param_groups,
+        }
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+
+    def __repr__(self):
+        format_string = self.__class__.__name__ + ' ('
+        for i, group in enumerate(self.param_groups):
+            format_string += '\n'
+            format_string += 'Parameter Group {0}\n'.format(i)
+            for key in sorted(group.keys()):
+                if key != 'params':
+                    format_string += '    {0}: {1}\n'.format(key, group[key])
+        format_string += ')'
+        return format_string
+
+    def state_dict(self):
+        r"""Returns the state of the optimizer as a :class:`dict`.
+
+        It contains two entries:
+
+        * state - a dict holding current optimization state. Its content
+            differs between optimizer classes.
+        * param_groups - a dict containing all parameter groups
+        """
+        # Save ids instead of Tensors
+        def pack_group(group):
+            packed = {k: v for k, v in group.items() if k != 'params'}
+            packed['params'] = [id(p) for p in group['params']]
+            return packed
+        param_groups = [pack_group(g) for g in self.param_groups]
+        # Remap state to use ids as keys
+        packed_state = {(id(k) if isinstance(k, torch.Tensor) else k): v
+                        for k, v in self.state.items()}
+        return {
+            'state': packed_state,
+            'param_groups': param_groups,
+        }
+
+    def load_state_dict(self, state_dict):
+        r"""Loads the optimizer state.
+
+        Arguments:
+            state_dict (dict): optimizer state. Should be an object returned
+                from a call to :meth:`state_dict`.
+        """
+        # deepcopy, to be consistent with module API
+        state_dict = deepcopy(state_dict)
+        # Validate the state_dict
+        groups = self.param_groups
+        saved_groups = state_dict['param_groups']
+
+        if len(groups) != len(saved_groups):
+            raise ValueError("loaded state dict has a different number of "
+                             "parameter groups")
+        param_lens = (len(g['params']) for g in groups)
+        saved_lens = (len(g['params']) for g in saved_groups)
+        if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
+            raise ValueError("loaded state dict contains a parameter group "
+                             "that doesn't match the size of optimizer's group")
+
+        # Update the state
+        id_map = {old_id: p for old_id, p in
+                  zip(chain(*(g['params'] for g in saved_groups)),
+                      chain(*(g['params'] for g in groups)))}
+
+        def cast(param, value):
+            r"""Make a deep copy of value, casting all tensors to device of param."""
+            if isinstance(value, torch.Tensor):
+                # Floating-point types are a bit special here. They are the only ones
+                # that are assumed to always match the type of params.
+                if param.is_floating_point():
+                    value = value.to(param.dtype)
+                value = value.to(param.device)
+                return value
+            elif isinstance(value, dict):
+                return {k: cast(param, v) for k, v in value.items()}
+            elif isinstance(value, container_abcs.Iterable):
+                return type(value)(cast(param, v) for v in value)
+            else:
+                return value
+
+        # Copy state assigned to params (and cast tensors to appropriate types).
+        # State that is not assigned to params is copied as is (needed for
+        # backward compatibility).
+        state = defaultdict(dict)
+        for k, v in state_dict['state'].items():
+            if k in id_map:
+                param = id_map[k]
+                state[param] = cast(param, v)
+            else:
+                state[k] = v
+
+        # Update parameter groups, setting their 'params' value
+        def update_group(group, new_group):
+            new_group['params'] = group['params']
+            return new_group
+        param_groups = [
+            update_group(g, ng) for g, ng in zip(groups, saved_groups)]
+        self.__setstate__({'state': state, 'param_groups': param_groups})
+
+    def zero_grad(self):
+        r"""Clears the gradients of all optimized :class:`torch.Tensor` s."""
+        for group in self.param_groups:
+            for name, p in group['params']:
+                if p.grad is not None:
+                    p.grad.detach_()
+                    p.grad.zero_()
+
+    def step(self, closure):
+        r"""Performs a single optimization step (parameter update).
+
+        Arguments:
+            closure (callable): A closure that reevaluates the model and
+                returns the loss. Optional for most optimizers.
+        """
+        raise NotImplementedError
+
+    def add_param_group(self, param_group):
+        r"""Add a param group to the :class:`Optimizer` s `param_groups`.
+
+        This can be useful when fine tuning a pre-trained network as frozen layers can be made
+        trainable and added to the :class:`Optimizer` as training progresses.
+
+        Arguments:
+            param_group (dict): Specifies what Tensors should be optimized along with group
+            specific optimization options.
+        """
+        assert isinstance(param_group, dict), "param group must be a dict"
+
+        params = param_group['params']
+        if isinstance(params, torch.Tensor):
+            param_group['params'] = [params]
+        elif isinstance(params, set):
+            raise TypeError('optimizer parameters need to be organized in ordered collections, but '
+                            'the ordering of tensors in sets will change between runs. Please use a list instead.')
+        else:
+            param_group['params'] = list(params)
+
+        for name, param in param_group['params']:
+            if not isinstance(param, torch.Tensor):
+                raise TypeError("optimizer can only optimize Tensors, "
+                                "but one of the params is " + torch.typename(param))
+            if not param.is_leaf:
+                raise ValueError("can't optimize a non-leaf Tensor")
+
+        for name, default in self.defaults.items():
+            if default is required and name not in param_group:
+                raise ValueError("parameter group didn't specify a value of required optimization parameter " +
+                                 name)
+            else:
+                param_group.setdefault(name, default)
+
+        param_set = set()
+        for group in self.param_groups:
+            param_set.update(set(group['params']))
+
+        if not param_set.isdisjoint(set(param_group['params'])):
+            raise ValueError("some parameters appear in more than one parameter group")
+
+        self.param_groups.append(param_group)
+
+
+class PruneAdam(NameOptimizer):
+    r"""Implements the Adam algorithm.
+    Adapted from the published Adam optimizer to support pruning masks.
+    It has been proposed in `Adam: A Method for Stochastic Optimization`_.
+
+    Arguments:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr (float, optional): learning rate (default: 1e-3)
+        betas (Tuple[float, float], optional): coefficients used for computing
+            running averages of gradient and its square (default: (0.9, 0.999))
+        eps (float, optional): term added to the denominator to improve
+            numerical stability (default: 1e-8)
+        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
+            algorithm from the paper `On the Convergence of Adam and Beyond`_
+            (default: False)
+
+    .. _Adam\: A Method for Stochastic Optimization:
+        https://arxiv.org/abs/1412.6980
+    .. _On the Convergence of Adam and Beyond:
+        https://openreview.net/forum?id=ryQu7f-RZ
+    """
+
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+                 weight_decay=0, amsgrad=False):
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+        defaults = dict(lr=lr, betas=betas, eps=eps,
+                        weight_decay=weight_decay, amsgrad=amsgrad)
+        super(PruneAdam, self).__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super(PruneAdam, self).__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('amsgrad', False)
+
+    def step(self, closure=None):
+        """Performs a single optimization step.
+
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:  # param_groups is a list of dicts
+            for name, p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data  # gradient
+                if grad.is_sparse:
+                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
+                amsgrad = group['amsgrad']
+
+                state = self.state[p]  # dict holding this parameter's optimization state
+
+                # State initialization
+                if len(state) == 0:
+                    state['step'] = 0
+                    # Exponential moving average of gradient values
+                    state['exp_avg'] = torch.zeros_like(p.data)
+                    # Exponential moving average of squared gradient values
+                    state['exp_avg_sq'] = torch.zeros_like(p.data)
+                    if amsgrad:
+                        # Maintains max of all exp. moving avg. of sq. grad. values
+                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)
+
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                if amsgrad:
+                    max_exp_avg_sq = state['max_exp_avg_sq']
+                beta1, beta2 = group['betas']  # decay rates for the first and second moment estimates
+
+                state['step'] += 1
+
+                if group['weight_decay'] != 0:  # weight decay
+                    grad.add_(group['weight_decay'], p.data)
+
+                # Decay the first and second moment running average coefficient
+                exp_avg.mul_(beta1).add_(1 - beta1, grad)  # first moment estimate (m_t)
+                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)  # second moment estimate (v_t)
+                if amsgrad:
+                    # Maintains the maximum of all 2nd moment running avg. till now
+                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
+                    # Use the max. for normalizing running avg. of gradient
+                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
+                else:
+                    denom = exp_avg_sq.sqrt().add_(group['eps'])
+
+                bias_correction1 = 1 - beta1 ** state['step']  # bias correction for the first moment
+                bias_correction2 = 1 - beta2 ** state['step']  # bias correction for the second moment
+                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
+
+                p.data.addcdiv_(-step_size, exp_avg, denom)  # p.data -= step_size * exp_avg / denom
+
+        return loss
+
+    def prune_step(self, mask, closure=None):
+        """Performs a single optimization step.
+
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+            mask: pruning mask that keeps pruned weights from being updated.
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+        for group in self.param_groups:
+            for name, p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data
+                if grad.is_sparse:
+                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
+                amsgrad = group['amsgrad']
+
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+                    state['step'] = 0
+                    # Exponential moving average of gradient values
+                    state['exp_avg'] = torch.zeros_like(p.data)
+                    # Exponential moving average of squared gradient values
+                    state['exp_avg_sq'] = torch.zeros_like(p.data)
+                    if amsgrad:
+                        # Maintains max of all exp. moving avg. of sq. grad. values
+                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)
+
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                if amsgrad:
+                    max_exp_avg_sq = state['max_exp_avg_sq']
+                beta1, beta2 = group['betas']
+
+                state['step'] += 1
+
+                if group['weight_decay'] != 0:
+                    grad.add_(group['weight_decay'], p.data)
+
+                # Decay the first and second moment running average coefficient
+                exp_avg.mul_(beta1).add_(1 - beta1, grad)  # first moment estimate
+                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)  # second moment estimate
+
+                # if name.split('.')[-1] == "weight":  # i.e. only weight tensors
+                a = name.split('.')[-1]
+                b = name.split('.')[-2]
+                if (a == "weight" and b == "conv1") or (a == "weight" and b == "conv2"):
+                    exp_avg_sq.mul_(mask[name])  # zero the second moment at pruned positions
+
+                if amsgrad:
+                    # Maintains the maximum of all 2nd moment running avg. till now
+                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
+                    # Use the max. for normalizing running avg. of gradient
+                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
+                else:
+                    denom = exp_avg_sq.sqrt().add_(group['eps'])
+
+                bias_correction1 = 1 - beta1 ** state['step']  # bias corrections for the moment estimates
+                bias_correction2 = 1 - beta2 ** state['step']
+                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
+
+                # if name.split('.')[-1] == "weight":
+                a = name.split('.')[-1]
+                b = name.split('.')[-2]
+                if (a == "weight" and b == "conv1") or (a == "weight" and b == "conv2"):
+                    exp_avg.mul_(mask[name])  # zero the first moment at pruned positions
+                p.data.addcdiv_(-step_size, exp_avg, denom)  # p.data -= step_size * exp_avg / denom
+
+        return loss
diff --git a/search_space.py b/search_space.py
new file mode 100644
index 0000000..0ce74ee
--- /dev/null
+++ b/search_space.py
@@ -0,0 +1,55 @@
+import copy
+
+def combination(length):  # enumerate combinations of positions 1-9; length is the number of weights to prune
+    data = [i for i in range(1,10)]  # positions in a 3*3 kernel
+    result = []
+    temp = [0] * length
+    l = len(data)
+    def next_num(li = 0,ni = 0):
+        if ni == length:
+            result.append(copy.copy(temp))
+            return
+        for lj in range(li, l):
+            temp[ni] = data[lj]
+            next_num(lj+1,ni+1)
+    next_num()
+    return result
+
+def search_space():  # build the pruning search space
+    space = {}  # maps a pruning-position key to a 3*3 kernel mask
+    result = []  # all possible pruning patterns; [0] means no pruning
+    for i in range(10):  # add every combination, i.e. any set of pruned positions
+        combi_list = combination(i)
+        for e in combi_list:
+            result.append(e)
+    for e in result:
+        key = ''  # key into the search space
+        temp = [[1 for i in range(3)]for j in range(3)]  # 3*3 list; each entry marks pruned (0) or kept (1)
+        if not e:
+            key = '0'
+        else:
+            for e1 in e:
+                key = key + str(e1)
+                if e1 == 1:temp[0][0] = 0
+                if e1 == 2:temp[0][1] = 0
+                if e1 == 3:temp[0][2] = 0
+                if e1 == 4:temp[1][0] = 0
+                if e1 == 5:temp[1][1] = 0
+                if e1 == 6:temp[1][2] = 0
+                if e1 == 7:temp[2][0] = 0
+                if e1 == 8:temp[2][1] = 0
+                if e1 == 9:temp[2][2] = 0
+        space[key] = temp
+    return space
+
+def get_pattern(pattern):  # pattern is '0' (no pruning) or an ascending combination of 1-9, e.g. '12345' prunes those positions
+    space = search_space()
+    # print(space)
+    return space[pattern]
+
+# print(get_pattern('4567'))
+
+
+
+
+
\ No newline at end of file
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..c690263
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,185 @@
+import torch
+import torch.nn.functional as F
+import numpy as np
+from search_space import get_pattern
+
+
+def regularized_nll_loss(args, model, output, target):
+    index = 0
+    loss = F.nll_loss(output, target)
+    if args.l2:
+        for name, param in model.named_parameters():  # iterate over model parameters
+            if name.split('.')[-1] == "weight":
+                loss += args.alpha * param.norm()
+                index += 1
+    return loss
+
+def admm_loss(args, device, model, Z, U, output, target):  # NLL loss plus the ADMM augmented-Lagrangian penalty
+    idx = 0
+    loss = F.nll_loss(output, target)
+    for name, param in model.named_parameters():
+        if name.split('.')[-1] == "weight":
+            u = U[idx].to(device)
+            z = Z[idx].to(device)
+            loss += args.rho / 2 * (param - z + u).norm()
+            if args.l2:
+                loss += args.alpha * param.norm()
+            idx += 1
+    return loss
+
+
+def initialize_Z_and_U(model):
+    Z = ()
+    U = ()
+    for name, param in model.named_parameters():
+        if name.split('.')[-1] == "weight":
+            Z += (param.detach().cpu().clone(),)  # detach cuts the graph so Z is not touched by backprop; clone copies the detached tensor
+            U += (torch.zeros_like(param).cpu(),)  # dual variable, zero-initialized with the same shape as param
+    return Z, U
+
+
+def update_X(model):
+    X = ()
+    for name, param in model.named_parameters():
+        if name.split('.')[-1] == "weight":
+            X += (param.detach().cpu().clone(),)
+    return X
+
+
+def update_Z(X, U, args):
+    new_Z = ()
+    idx = 0
+    for x, u in zip(X, U):
+        z = x + u
+        pcen = np.percentile(abs(z), 100*args.percent[idx])
+        under_threshold = abs(z) < pcen
+        z.data[under_threshold] = 0
+        new_Z += (z,)
+        idx += 1
+    return new_Z
+
+
+def update_Z_l1(X, U, args):
+    new_Z = ()
+    delta = args.alpha / args.rho
+    for x, u in zip(X, U):
+        z = x + u
+        new_z = z.clone()
+        if (z > delta).sum() != 0:
+            new_z[z > delta] = z[z > delta] - delta
+        if (z < -delta).sum() != 0:
+            new_z[z < -delta] = z[z < -delta] + delta
+        if (abs(z) <= delta).sum() != 0:
+            new_z[abs(z) <= delta] = 0
+        new_Z += (new_z,)
+    return new_Z
+
+
+def update_U(U, X, Z):
+    new_U = ()
+    for u, x, z in zip(U, X, Z):
+        new_u = u + x - z
+        new_U += (new_u,)
+    return new_U
+
+
+def prune_weight(weight, device, percent):  # magnitude pruning with a percentile threshold
+    # to work with admm, we calculate percentile based on all elements instead of nonzero elements.
+    weight_numpy = weight.detach().cpu().numpy()
+    pcen = np.percentile(abs(weight_numpy), 100*percent)
+    under_threshold = abs(weight_numpy) < pcen
+    weight_numpy[under_threshold] = 0
+    mask = torch.Tensor(abs(weight_numpy) >= pcen).to(device)
+    return mask
+
+
+def prune_l1_weight(weight, device, delta):  # pruning with the l1 threshold
+    weight_numpy = weight.detach().cpu().numpy()
+    under_threshold = abs(weight_numpy) < delta
+    weight_numpy[under_threshold] = 0
+    mask = torch.Tensor(abs(weight_numpy) >= delta).to(device)
+    return mask
+
+
+def prune_pattern_weight(weight, device, pattern, layer):  # pattern pruning: every 3*3 kernel gets the same 0/1 pattern
+    weight_numpy = weight.detach().cpu().numpy()
+    if layer == 1:  # conv layer 1: 1 input channel, 20 output channels
+        for i in range(20):
+            for j in range(1):
+                weight_numpy[i][j] = get_pattern(pattern)  # overwrite the kernel with the 0/1 pattern to build the mask
+    elif layer == 2:  # conv layer 2: 20 input channels, 50 output channels
+        for i in range(50):
+            for j in range(20):
+                weight_numpy[i][j] = get_pattern(pattern)
+    mask = torch.Tensor(weight_numpy).to(device)
+    return mask
+
+
+def apply_prune(model, device, args):  # prune according to the configured pruning percentages
+    # returns dictionary of non_zero_values' indices
+    print("Apply Pruning based on percentile")
+    dict_mask = {}
+    idx = 0
+    for name, param in model.named_parameters():
+        if name.split('.')[-1] == "weight":
+            mask = prune_weight(param, device, args.percent[idx])
+            param.data.mul_(mask)
+            # param.data = torch.Tensor(weight_pruned).to(device)
+            dict_mask[name] = mask
+            idx += 1
+    return dict_mask
+
+
+def apply_l1_prune(model, device, args):  # prune according to the l1 threshold
+    delta = args.alpha / args.rho
+    print("Apply Pruning based on l1")
+    dict_mask = {}
+    idx = 0
+    for name, param in model.named_parameters():
+        if name.split('.')[-1] == "weight":
+            mask = prune_l1_weight(param, device, delta)
+            param.data.mul_(mask)
+            dict_mask[name] = mask
+            idx += 1
+    return dict_mask
+
+
+def apply_pattern_prune(model, device, pattern):  # pattern-based pruning of the convolutional layers
+    print("Apply Pruning based on pattern")
+    dict_mask = {}
+    for name, param in model.named_parameters():
+        a = name.split('.')[-1]
+        b = name.split('.')[-2]
+        if a == "weight" and b == "conv1":
+            mask = prune_pattern_weight(param, device, pattern, 1)
+            param.data.mul_(mask)  # zero the pruned positions in the parameter data
+            dict_mask[name] = mask
+        elif a == "weight" and b == "conv2":
+            mask = prune_pattern_weight(param, device, pattern, 2)
+            param.data.mul_(mask)
+            dict_mask[name] = mask
+            print(name, dict_mask[name])
+    return dict_mask
+
+def print_convergence(model, X, Z):  # report how far the weights are from their projections
+    idx = 0
+    print("normalized norm of (weight - projection)")
+    for name, _ in model.named_parameters():
+        if name.split('.')[-1] == "weight":
+            x, z = X[idx], Z[idx]
+            print("({}): {:.4f}".format(name, (x-z).norm().item() / x.norm().item()))
+            idx += 1
+
+
+def print_prune(model):  # report pruning statistics
+    prune_param, total_param = 0, 0
+    for name, param in model.named_parameters():
+        if name.split('.')[-1] == "weight":
+            print("[at weight {}]".format(name))
+            print("percentage of pruned: {:.4f}%".format(100 * (abs(param) == 0).sum().item() / param.numel()))
+            print("nonzero parameters after pruning: {} / {}\n".format((param != 0).sum().item(), param.numel()))
+        total_param += param.numel()
+        prune_param += (param != 0).sum().item()
+    print("total nonzero parameters after pruning: {} / {} ({:.4f}%)".
+          format(prune_param, total_param,
+                 100 * (total_param - prune_param) / total_param))
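
Note (not part of the patch): a minimal sketch of how the pattern key passed to apply_pattern_prune in main.py maps to a 3*3 kernel mask via search_space.get_pattern, and how that mask silences weights. The helper name demo_pattern_mask and the LeNet-conv1-shaped tensor are illustrative assumptions, not code from this patch.

# Illustrative sketch only: exercises search_space.get_pattern the same way
# utils.prune_pattern_weight does; assumes search_space.py from this patch is importable.
import torch
from search_space import get_pattern

def demo_pattern_mask(pattern_key='18'):
    # '18' zeroes kernel positions 1 and 8, i.e. entries [0][0] and [2][1]
    kernel_mask = torch.tensor(get_pattern(pattern_key), dtype=torch.float32)
    print(kernel_mask)
    # expected:
    # tensor([[0., 1., 1.],
    #         [1., 1., 1.],
    #         [1., 0., 1.]])
    # Broadcast the same 3*3 pattern over a conv1-shaped weight (20 x 1 x 3 x 3),
    # mirroring the mask that prune_pattern_weight builds for layer 1.
    weight = torch.randn(20, 1, 3, 3)
    weight.mul_(kernel_mask.expand_as(weight))  # pruned positions become exactly zero
    assert (weight[:, :, 0, 0] == 0).all() and (weight[:, :, 2, 1] == 0).all()

if __name__ == '__main__':
    demo_pattern_mask()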
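
Note (not part of the patch): a small worked sketch of the per-layer projection performed by utils.update_Z during ADMM training: Z is set to X + U with the smallest-magnitude fraction of entries (args.percent[idx]) zeroed. The function name project_smallest_to_zero is an illustrative assumption.

# Illustrative sketch only: one layer's Z-update with percent = 0.8.
import numpy as np
import torch

def project_smallest_to_zero(x, u, percent=0.8):
    z = x + u
    threshold = np.percentile(z.abs().numpy(), 100 * percent)
    z[z.abs() < threshold] = 0  # keep only the largest-magnitude (1 - percent) of entries
    return z

x = torch.randn(5, 5)
u = torch.zeros_like(x)
z = project_smallest_to_zero(x, u)
print((z != 0).float().mean())  # roughly 0.2 of the entries survive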