-
Notifications
You must be signed in to change notification settings - Fork 1
attack: added Wanda Pruning (attack) #26
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
esveee
wants to merge
2
commits into
main
Choose a base branch
from
esveee/wanda-final
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Empty file.
161 changes: 161 additions & 0 deletions
161
src/safetunebed/whitebox/attacks/wanda_pruning/ablate.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,161 @@ | ||
| import math | ||
| import time | ||
|
|
||
| import torch | ||
| import torch.nn as nn | ||
| import transformers | ||
|
|
||
| torch.backends.cuda.matmul.allow_tf32 = False | ||
| torch.backends.cudnn.allow_tf32 = False | ||
|
|
||
| class AblateGPT: | ||
|
|
||
| def __init__(self, layer): | ||
| self.layer = layer | ||
| self.dev = self.layer.weight.device | ||
| W = layer.weight.data.clone() | ||
| if isinstance(self.layer, nn.Conv2d): | ||
| W = W.flatten(1) | ||
| if isinstance(self.layer, transformers.Conv1D): | ||
| W = W.t() | ||
| self.rows = W.shape[0] | ||
| self.columns = W.shape[1] | ||
| self.H = torch.zeros((self.columns, self.columns), device=self.dev) | ||
| self.nsamples = 0 | ||
|
|
||
| self.scaler_row = torch.zeros((self.columns), device=self.dev) | ||
|
|
||
| def add_batch(self, inp, out): | ||
| if len(inp.shape) == 2: | ||
| inp = inp.unsqueeze(0) | ||
| tmp = inp.shape[0] | ||
| if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D): | ||
| if len(inp.shape) == 3: | ||
| inp = inp.reshape((-1, inp.shape[-1])) | ||
| inp = inp.t() | ||
| self.H *= self.nsamples / (self.nsamples + tmp) | ||
|
|
||
| self.scaler_row *= self.nsamples / (self.nsamples+tmp) | ||
|
|
||
| self.nsamples += tmp | ||
| inp = math.sqrt(2 / self.nsamples) * inp.float() | ||
| self.H += inp.matmul(inp.t()) | ||
| self.scaler_row += torch.norm(inp, p=2, dim=1) ** 2 / self.nsamples | ||
|
|
||
| def get_wanda_mask(self, sparsity, prunen, prunem): | ||
| W_metric = torch.abs(self.layer.weight.data) * torch.sqrt(self.scaler_row.reshape((1,-1))) | ||
| W_mask = (torch.zeros_like(W_metric) == 1) ## initialize a mask to be all False | ||
| if prunen != 0: | ||
| for ii in range(W_metric.shape[1]): | ||
| if ii % prunem == 0: | ||
| tmp = W_metric[:,ii:(ii+prunem)].float() | ||
| W_mask.scatter_(1,ii+torch.topk(tmp, prunen,dim=1, largest=False)[1], True) | ||
| else: | ||
| sort_res = torch.sort(W_metric, dim=-1, stable=True) | ||
| indices = sort_res[1][:,:int(W_metric.shape[1]*sparsity)] | ||
| W_mask.scatter_(1, indices, True) | ||
|
|
||
| return W_mask | ||
|
|
||
| def get_mag_mask(self, sparsity, prunen, prunem): | ||
| W = self.layer.weight.data | ||
| W_metric = torch.abs(W) | ||
| if prunen != 0: | ||
| W_mask = (torch.zeros_like(W)==1) | ||
| for ii in range(W_metric.shape[1]): | ||
| if ii % prunem == 0: | ||
| tmp = W_metric[:,ii:(ii+prunem)].float() | ||
| W_mask.scatter_(1,ii+torch.topk(tmp, prunen,dim=1, largest=False)[1], True) | ||
| else: | ||
| thresh = torch.sort(W_metric.flatten().cuda())[0][int(W.numel()*sparsity)].cpu() | ||
| W_mask = (W_metric<=thresh) | ||
|
|
||
| return W_mask | ||
|
|
||
| def fasterprune( | ||
| self, args, sparsity, mask=None, prune_n=0, prune_m=0, blocksize=128, percdamp=.01 | ||
| ): | ||
| W = self.layer.weight.data.clone() | ||
| if isinstance(self.layer, nn.Conv2d): | ||
| W = W.flatten(1) | ||
| if isinstance(self.layer, transformers.Conv1D): | ||
| W = W.t() | ||
| W = W.float() | ||
|
|
||
| tick = time.time() | ||
|
|
||
| H = self.H | ||
| del self.H | ||
| dead = torch.diag(H) == 0 | ||
| H[dead, dead] = 1 | ||
| W[:, dead] = 0 | ||
|
|
||
| Losses = torch.zeros(self.rows, device=self.dev) | ||
|
|
||
| damp = percdamp * torch.mean(torch.diag(H)) | ||
| diag = torch.arange(self.columns, device=self.dev) | ||
| H[diag, diag] += damp | ||
| H = torch.linalg.cholesky(H) | ||
| H = torch.cholesky_inverse(H) | ||
| H = torch.linalg.cholesky(H, upper=True) | ||
| Hinv = H | ||
|
|
||
| for i1 in range(0, self.columns, blocksize): | ||
| i2 = min(i1 + blocksize, self.columns) | ||
| count = i2 - i1 | ||
|
|
||
| W1 = W[:, i1:i2].clone() | ||
| Q1 = torch.zeros_like(W1) | ||
| Err1 = torch.zeros_like(W1) | ||
| Losses1 = torch.zeros_like(W1) | ||
| Hinv1 = Hinv[i1:i2, i1:i2] | ||
|
|
||
| if prune_n == 0 or mask is not None: | ||
| if mask is not None: | ||
| mask1 = mask[:, i1:i2] | ||
| else: | ||
| # tmp = W1 ** 2 / (torch.diag(Hinv1).reshape((1, -1))) ** 2 | ||
| if "wanda" in args.prune_method: | ||
| tmp = torch.abs(W1) * torch.sqrt(self.scaler_row[i1:i2].reshape((1,-1))) | ||
| elif "mag" in args.prune_method: | ||
| tmp = torch.abs(W1) | ||
| thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)] | ||
| mask1 = tmp <= thresh | ||
| else: | ||
| mask1 = torch.zeros_like(W1) == 1 | ||
|
|
||
| for i in range(count): | ||
| w = W1[:, i] | ||
| d = Hinv1[i, i] | ||
|
|
||
| if prune_n != 0 and i % prune_m == 0 and mask is None: | ||
| # tmp = W1[:, i:(i + prune_m)] ** 2 / (torch.diag(Hinv1)[i:(i + prune_m)].reshape((1, -1))) ** 2 | ||
| if "wanda" in args.prune_method: | ||
| tmp = torch.abs(W1[:, i:(i+prune_m)]) * torch.sqrt(self.scaler_row[(i+i1):(i+i1+prune_m)].reshape((1,-1))) | ||
| elif "mag" in args.prune_method: | ||
| tmp = torch.abs(W1[:, i:(i+prune_m)]) | ||
| mask1.scatter_(1, i + torch.topk(tmp, prune_n, dim=1, largest=False)[1], True) | ||
|
|
||
| q = w.clone() | ||
| q[mask1[:, i]] = 0 | ||
|
|
||
| Q1[:, i] = q | ||
| Losses1[:, i] = (w - q) ** 2 / d ** 2 | ||
|
|
||
| err1 = (w - q) / d | ||
| W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) | ||
| Err1[:, i] = err1 | ||
|
|
||
| W[:, i1:i2] = Q1 | ||
| Losses += torch.sum(Losses1, 1) / 2 | ||
|
|
||
| W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) | ||
|
|
||
| torch.cuda.synchronize() | ||
| if isinstance(self.layer, transformers.Conv1D): | ||
| W = W.t() | ||
| self.layer.weight.data = W.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) | ||
|
|
||
| def free(self): | ||
| self.H = None | ||
| torch.cuda.empty_cache() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,73 @@ | ||
| # Code adapted from https://github.com/IST-DASLab/sparsegpt/blob/master/datautils.py | ||
|
|
||
| import numpy as np | ||
| import random | ||
| import torch | ||
| from datasets import load_dataset | ||
|
|
||
| # Set seed for reproducibility | ||
| def set_seed(seed): | ||
| np.random.seed(seed) | ||
| torch.random.manual_seed(seed) | ||
|
|
||
| # Wrapper for tokenized input IDs | ||
| class TokenizerWrapper: | ||
| def __init__(self, input_ids): | ||
| self.input_ids = input_ids | ||
|
|
||
| # Load and process wikitext2 dataset | ||
| def get_wikitext2(nsamples, seed, seqlen, tokenizer): | ||
| # Load train and test datasets | ||
| traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train') | ||
| testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test') | ||
|
|
||
| # Encode datasets | ||
| trainenc = tokenizer(" ".join(traindata['text']), return_tensors='pt') | ||
| testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt') | ||
|
|
||
| # Generate samples from training set | ||
| random.seed(seed) | ||
| trainloader = [] | ||
| for _ in range(nsamples): | ||
| i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) | ||
| j = i + seqlen | ||
| inp = trainenc.input_ids[:, i:j] | ||
| tar = inp.clone() | ||
| tar[:, :-1] = -100 | ||
| trainloader.append((inp, tar)) | ||
| return trainloader, testenc | ||
|
|
||
| # Load and process c4 dataset | ||
| def get_c4(nsamples, seed, seqlen, tokenizer): | ||
| # Load train and validation datasets | ||
| traindata = load_dataset('allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train') | ||
| valdata = load_dataset('allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation') | ||
|
|
||
| # Generate samples from training set | ||
| random.seed(seed) | ||
| trainloader = [] | ||
| for _ in range(nsamples): | ||
| while True: | ||
| i = random.randint(0, len(traindata) - 1) | ||
| trainenc = tokenizer(traindata[i]['text'], return_tensors='pt') | ||
| if trainenc.input_ids.shape[1] > seqlen: | ||
| break | ||
| i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) | ||
| j = i + seqlen | ||
| inp = trainenc.input_ids[:, i:j] | ||
| tar = inp.clone() | ||
| tar[:, :-1] = -100 | ||
| trainloader.append((inp, tar)) | ||
|
|
||
| # Prepare validation dataset | ||
| valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt') | ||
| valenc = valenc.input_ids[:, :(256 * seqlen)] | ||
| valenc = TokenizerWrapper(valenc) | ||
| return trainloader, valenc | ||
|
|
||
| # Function to select the appropriate loader based on dataset name | ||
| def get_loaders(name, nsamples=128, seed=0, seqlen=2048, tokenizer=None): | ||
| if 'wikitext2' in name: | ||
| return get_wikitext2(nsamples, seed, seqlen, tokenizer) | ||
| if "c4" in name: | ||
| return get_c4(nsamples, seed, seqlen, tokenizer) |
35 changes: 35 additions & 0 deletions
35
src/safetunebed/whitebox/attacks/wanda_pruning/layerwrapper.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,35 @@ | ||
| import torch | ||
| import torch.nn as nn | ||
|
|
||
| # Define WrappedGPT class | ||
| class WrappedGPT: | ||
| """ | ||
| This class wraps a GPT layer for specific operations. | ||
| """ | ||
|
|
||
| def __init__(self, layer, layer_id=0, layer_name="none"): | ||
| self.layer = layer | ||
| self.dev = self.layer.weight.device | ||
| self.rows = layer.weight.data.shape[0] | ||
| self.columns = layer.weight.data.shape[1] | ||
|
|
||
| self.scaler_row = torch.zeros((self.columns), device=self.dev) | ||
| self.nsamples = 0 | ||
|
|
||
| self.layer_id = layer_id | ||
| self.layer_name = layer_name | ||
|
|
||
| def add_batch(self, inp, out): | ||
| if len(inp.shape) == 2: | ||
| inp = inp.unsqueeze(0) | ||
| tmp = inp.shape[0] | ||
| if isinstance(self.layer, nn.Linear): | ||
| if len(inp.shape) == 3: | ||
| inp = inp.reshape((-1, inp.shape[-1])) | ||
| inp = inp.t() | ||
|
|
||
| self.scaler_row *= self.nsamples / (self.nsamples+tmp) | ||
| self.nsamples += tmp | ||
|
|
||
| inp = inp.type(torch.float32) | ||
| self.scaler_row += torch.norm(inp, p=2, dim=1) ** 2 / self.nsamples |
121 changes: 121 additions & 0 deletions
121
src/safetunebed/whitebox/attacks/wanda_pruning/sparsegpt.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,121 @@ | ||
|
|
||
| import math | ||
| import time | ||
|
|
||
| import torch | ||
| import torch.nn as nn | ||
| import transformers | ||
|
|
||
| torch.backends.cuda.matmul.allow_tf32 = False | ||
| torch.backends.cudnn.allow_tf32 = False | ||
|
|
||
| ## SparseGPT: https://github.com/IST-DASLab/sparsegpt/tree/f5c25005a61f96a0933ca2f95705a963585aafaa | ||
| class SparseGPT: | ||
|
|
||
| def __init__(self, layer): | ||
| self.layer = layer | ||
| self.dev = self.layer.weight.device | ||
| W = layer.weight.data.clone() | ||
| if isinstance(self.layer, nn.Conv2d): | ||
| W = W.flatten(1) | ||
| if isinstance(self.layer, transformers.Conv1D): | ||
| W = W.t() | ||
| self.rows = W.shape[0] | ||
| self.columns = W.shape[1] | ||
| self.H = torch.zeros((self.columns, self.columns), device=self.dev) | ||
| self.nsamples = 0 | ||
|
|
||
| def add_batch(self, inp, out): | ||
| if len(inp.shape) == 2: | ||
| inp = inp.unsqueeze(0) | ||
| tmp = inp.shape[0] | ||
| if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D): | ||
| if len(inp.shape) == 3: | ||
| inp = inp.reshape((-1, inp.shape[-1])) | ||
| inp = inp.t() | ||
| self.H *= self.nsamples / (self.nsamples + tmp) | ||
| self.nsamples += tmp | ||
| inp = math.sqrt(2 / self.nsamples) * inp.float() | ||
| self.H += inp.matmul(inp.t()) | ||
|
|
||
| def fasterprune( | ||
| self, sparsity, prune_n=0, prune_m=0, blocksize=128, percdamp=.01 | ||
| ): | ||
| W = self.layer.weight.data.clone() | ||
| if isinstance(self.layer, nn.Conv2d): | ||
| W = W.flatten(1) | ||
| if isinstance(self.layer, transformers.Conv1D): | ||
| W = W.t() | ||
| W = W.float() | ||
|
|
||
| tick = time.time() | ||
|
|
||
| H = self.H | ||
| del self.H | ||
| dead = torch.diag(H) == 0 | ||
| H[dead, dead] = 1 | ||
| W[:, dead] = 0 | ||
|
|
||
| Losses = torch.zeros(self.rows, device=self.dev) | ||
|
|
||
| damp = percdamp * torch.mean(torch.diag(H)) | ||
| diag = torch.arange(self.columns, device=self.dev) | ||
| H[diag, diag] += damp | ||
| H = torch.linalg.cholesky(H) | ||
| H = torch.cholesky_inverse(H) | ||
| H = torch.linalg.cholesky(H, upper=True) | ||
| Hinv = H | ||
|
|
||
| mask = None | ||
|
|
||
| for i1 in range(0, self.columns, blocksize): | ||
| i2 = min(i1 + blocksize, self.columns) | ||
| count = i2 - i1 | ||
|
|
||
| W1 = W[:, i1:i2].clone() | ||
| Q1 = torch.zeros_like(W1) | ||
| Err1 = torch.zeros_like(W1) | ||
| Losses1 = torch.zeros_like(W1) | ||
| Hinv1 = Hinv[i1:i2, i1:i2] | ||
|
|
||
| if prune_n == 0: | ||
| if mask is not None: | ||
| mask1 = mask[:, i1:i2] | ||
| else: | ||
| tmp = W1 ** 2 / (torch.diag(Hinv1).reshape((1, -1))) ** 2 | ||
| thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)] | ||
| mask1 = tmp <= thresh | ||
| else: | ||
| mask1 = torch.zeros_like(W1) == 1 | ||
|
|
||
| for i in range(count): | ||
| w = W1[:, i] | ||
| d = Hinv1[i, i] | ||
|
|
||
| if prune_n != 0 and i % prune_m == 0: | ||
| tmp = W1[:, i:(i + prune_m)] ** 2 / (torch.diag(Hinv1)[i:(i + prune_m)].reshape((1, -1))) ** 2 | ||
| mask1.scatter_(1, i + torch.topk(tmp, prune_n, dim=1, largest=False)[1], True) | ||
|
|
||
| q = w.clone() | ||
| q[mask1[:, i]] = 0 | ||
|
|
||
| Q1[:, i] = q | ||
| Losses1[:, i] = (w - q) ** 2 / d ** 2 | ||
|
|
||
| err1 = (w - q) / d | ||
| W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) | ||
| Err1[:, i] = err1 | ||
|
|
||
| W[:, i1:i2] = Q1 | ||
| Losses += torch.sum(Losses1, 1) / 2 | ||
|
|
||
| W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) | ||
|
|
||
| torch.cuda.synchronize() | ||
| if isinstance(self.layer, transformers.Conv1D): | ||
| W = W.t() | ||
| self.layer.weight.data = W.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype) | ||
|
|
||
| def free(self): | ||
| self.H = None | ||
| torch.cuda.empty_cache() |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can add citation and link to original code here, like punya did in his PR src/safetunebed/whitebox/attacks/gcg/init.py