Collaborator comment: we can add a citation and a link to the original code here, like punya did in his PR (src/safetunebed/whitebox/attacks/gcg/__init__.py).

src/safetunebed/whitebox/attacks/wanda_pruning/__init__.py
Empty file.
161 changes: 161 additions & 0 deletions src/safetunebed/whitebox/attacks/wanda_pruning/ablate.py
@@ -0,0 +1,161 @@
import math
import time

import torch
import torch.nn as nn
import transformers

# Keep matmuls in full fp32 (no TF32) so the Hessian statistics stay accurate.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

## Adapted from Wanda (https://github.com/locuslab/wanda), which builds on
## SparseGPT (https://github.com/IST-DASLab/sparsegpt).
class AblateGPT:

    def __init__(self, layer):
        self.layer = layer
        self.dev = self.layer.weight.device
        W = layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        self.rows = W.shape[0]
        self.columns = W.shape[1]
        # Running Hessian estimate H = (2 / nsamples) * sum_i x_i x_i^T over calibration inputs.
        self.H = torch.zeros((self.columns, self.columns), device=self.dev)
        self.nsamples = 0

        # Per-input-channel accumulator of squared activation norms (the Wanda statistic).
        self.scaler_row = torch.zeros((self.columns), device=self.dev)

    def add_batch(self, inp, out):
        # Fold one batch of layer inputs into the running statistics.
        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        tmp = inp.shape[0]
        if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
            if len(inp.shape) == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            inp = inp.t()
        # Rescale the running averages before adding the new batch.
        self.H *= self.nsamples / (self.nsamples + tmp)
        self.scaler_row *= self.nsamples / (self.nsamples + tmp)

        self.nsamples += tmp
        inp = math.sqrt(2 / self.nsamples) * inp.float()
        self.H += inp.matmul(inp.t())
        self.scaler_row += torch.norm(inp, p=2, dim=1) ** 2 / self.nsamples

    def get_wanda_mask(self, sparsity, prunen, prunem):
        # Wanda importance: |W| scaled by the RMS activation norm of the matching input channel.
        W_metric = torch.abs(self.layer.weight.data) * torch.sqrt(self.scaler_row.reshape((1, -1)))
        W_mask = (torch.zeros_like(W_metric) == 1)  ## initialize a mask to be all False
        if prunen != 0:
            # Structured n:m sparsity: prune the prunen lowest-scoring weights in every block of prunem.
            for ii in range(W_metric.shape[1]):
                if ii % prunem == 0:
                    tmp = W_metric[:, ii:(ii + prunem)].float()
                    W_mask.scatter_(1, ii + torch.topk(tmp, prunen, dim=1, largest=False)[1], True)
        else:
            # Unstructured sparsity: prune the lowest-scoring fraction of weights per row.
            sort_res = torch.sort(W_metric, dim=-1, stable=True)
            indices = sort_res[1][:, :int(W_metric.shape[1] * sparsity)]
            W_mask.scatter_(1, indices, True)

        return W_mask

    def get_mag_mask(self, sparsity, prunen, prunem):
        # Magnitude importance: prune the smallest-|w| entries.
        W = self.layer.weight.data
        W_metric = torch.abs(W)
        if prunen != 0:
            W_mask = (torch.zeros_like(W) == 1)
            for ii in range(W_metric.shape[1]):
                if ii % prunem == 0:
                    tmp = W_metric[:, ii:(ii + prunem)].float()
                    W_mask.scatter_(1, ii + torch.topk(tmp, prunen, dim=1, largest=False)[1], True)
        else:
            # Global threshold at the target sparsity; sort on the layer's own device
            # instead of hard-coding .cuda(), so this also runs on CPU.
            thresh = torch.sort(W_metric.flatten())[0][int(W.numel() * sparsity)]
            W_mask = (W_metric <= thresh)

        return W_mask

    def fasterprune(
        self, args, sparsity, mask=None, prune_n=0, prune_m=0, blocksize=128, percdamp=.01
    ):
        # Prune block by block, SparseGPT-style: zero masked weights, then
        # compensate the surviving weights using the inverse-Hessian factors.
        W = self.layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        W = W.float()

        tick = time.time()

        H = self.H
        del self.H
        # Dead input channels (zero Hessian diagonal) get a dummy 1 on the
        # diagonal, and the corresponding weights are zeroed.
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        W[:, dead] = 0

        Losses = torch.zeros(self.rows, device=self.dev)

        # Dampen H for numerical stability, then compute the upper Cholesky
        # factor of H^-1, which drives the error-compensation updates below.
        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(self.columns, device=self.dev)
        H[diag, diag] += damp
        H = torch.linalg.cholesky(H)
        H = torch.cholesky_inverse(H)
        H = torch.linalg.cholesky(H, upper=True)
        Hinv = H

for i1 in range(0, self.columns, blocksize):
i2 = min(i1 + blocksize, self.columns)
count = i2 - i1

W1 = W[:, i1:i2].clone()
Q1 = torch.zeros_like(W1)
Err1 = torch.zeros_like(W1)
Losses1 = torch.zeros_like(W1)
Hinv1 = Hinv[i1:i2, i1:i2]

            if prune_n == 0 or mask is not None:
                if mask is not None:
                    # Reuse the precomputed mask for this block.
                    mask1 = mask[:, i1:i2]
                else:
                    # Score this block with the selected pruning metric.
                    # (SparseGPT reference: tmp = W1 ** 2 / (torch.diag(Hinv1).reshape((1, -1))) ** 2)
                    if "wanda" in args.prune_method:
                        tmp = torch.abs(W1) * torch.sqrt(self.scaler_row[i1:i2].reshape((1, -1)))
                    elif "mag" in args.prune_method:
                        tmp = torch.abs(W1)
                    thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)]
                    mask1 = tmp <= thresh
            else:
                mask1 = torch.zeros_like(W1) == 1

            for i in range(count):
                w = W1[:, i]
                d = Hinv1[i, i]

                if prune_n != 0 and i % prune_m == 0 and mask is None:
                    # Structured n:m pruning: mark the prune_n lowest-scoring
                    # weights within each block of prune_m columns.
                    if "wanda" in args.prune_method:
                        tmp = torch.abs(W1[:, i:(i + prune_m)]) * torch.sqrt(self.scaler_row[(i + i1):(i + i1 + prune_m)].reshape((1, -1)))
                    elif "mag" in args.prune_method:
                        tmp = torch.abs(W1[:, i:(i + prune_m)])
                    mask1.scatter_(1, i + torch.topk(tmp, prune_n, dim=1, largest=False)[1], True)

                q = w.clone()
                q[mask1[:, i]] = 0

                Q1[:, i] = q
                Losses1[:, i] = (w - q) ** 2 / d ** 2

                # Spread the pruning error of column i over the remaining columns.
                err1 = (w - q) / d
                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
                Err1[:, i] = err1

            W[:, i1:i2] = Q1
            Losses += torch.sum(Losses1, 1) / 2

            # Propagate the accumulated block error to all later columns.
            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

torch.cuda.synchronize()
if isinstance(self.layer, transformers.Conv1D):
W = W.t()
self.layer.weight.data = W.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)

def free(self):
self.H = None
torch.cuda.empty_cache()
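
For reviewers, a minimal sketch of how `AblateGPT` is meant to be driven (the layer and calibration batches below are illustrative stand-ins, not part of this PR):

```python
# Hypothetical usage sketch: calibrate one Linear layer, build a Wanda mask
# at 50% unstructured sparsity, and zero the selected weights.
import torch
import torch.nn as nn

layer = nn.Linear(1024, 1024)                  # stand-in projection layer
gpt = AblateGPT(layer)

for _ in range(4):                             # fake calibration batches
    gpt.add_batch(torch.randn(8, 1024), None)  # `out` is unused by add_batch

mask = gpt.get_wanda_mask(sparsity=0.5, prunen=0, prunem=0)
layer.weight.data[mask] = 0                    # apply the mask
gpt.free()
```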
73 changes: 73 additions & 0 deletions src/safetunebed/whitebox/attacks/wanda_pruning/data.py
@@ -0,0 +1,73 @@
# Code adapted from https://github.com/IST-DASLab/sparsegpt/blob/master/datautils.py

import numpy as np
import random
import torch
from datasets import load_dataset

# Set seed for reproducibility
def set_seed(seed):
np.random.seed(seed)
torch.random.manual_seed(seed)

# Wrapper for tokenized input IDs
class TokenizerWrapper:
def __init__(self, input_ids):
self.input_ids = input_ids

# Load and process wikitext2 dataset
def get_wikitext2(nsamples, seed, seqlen, tokenizer):
# Load train and test datasets
traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')

# Encode datasets
trainenc = tokenizer(" ".join(traindata['text']), return_tensors='pt')
testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt')

# Generate samples from training set
random.seed(seed)
trainloader = []
for _ in range(nsamples):
i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
j = i + seqlen
inp = trainenc.input_ids[:, i:j]
tar = inp.clone()
tar[:, :-1] = -100
trainloader.append((inp, tar))
return trainloader, testenc

# Load and process c4 dataset
def get_c4(nsamples, seed, seqlen, tokenizer):
# Load train and validation datasets
    # Note: the legacy 'allenai--c4' builder config has been removed from recent
    # `datasets` releases, so load the shards by data_files alone.
    traindata = load_dataset('allenai/c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train')
    valdata = load_dataset('allenai/c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation')

# Generate samples from training set
random.seed(seed)
trainloader = []
for _ in range(nsamples):
while True:
i = random.randint(0, len(traindata) - 1)
trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
if trainenc.input_ids.shape[1] > seqlen:
break
i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
j = i + seqlen
inp = trainenc.input_ids[:, i:j]
tar = inp.clone()
tar[:, :-1] = -100
trainloader.append((inp, tar))

# Prepare validation dataset
valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt')
valenc = valenc.input_ids[:, :(256 * seqlen)]
valenc = TokenizerWrapper(valenc)
return trainloader, valenc

# Select the appropriate loader based on the dataset name
def get_loaders(name, nsamples=128, seed=0, seqlen=2048, tokenizer=None):
    if 'wikitext2' in name:
        return get_wikitext2(nsamples, seed, seqlen, tokenizer)
    if "c4" in name:
        return get_c4(nsamples, seed, seqlen, tokenizer)
    raise ValueError(f"Unknown calibration dataset: {name}")
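
A quick usage sketch for the loaders (the tokenizer checkpoint is illustrative):

```python
# Hypothetical usage sketch: 16 wikitext2 calibration samples of length 512.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative tokenizer
trainloader, testenc = get_loaders("wikitext2", nsamples=16, seed=0, seqlen=512, tokenizer=tokenizer)

inp, tar = trainloader[0]
print(inp.shape)  # torch.Size([1, 512]); tar is inp with all but the last token set to -100
```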
35 changes: 35 additions & 0 deletions src/safetunebed/whitebox/attacks/wanda_pruning/layerwrapper.py
@@ -0,0 +1,35 @@
import torch
import torch.nn as nn

# Define WrappedGPT class
class WrappedGPT:
    """
    Wraps a single layer and accumulates the per-input-channel activation
    statistics (scaler_row) used by the Wanda pruning metric.
    """

def __init__(self, layer, layer_id=0, layer_name="none"):
self.layer = layer
self.dev = self.layer.weight.device
self.rows = layer.weight.data.shape[0]
self.columns = layer.weight.data.shape[1]

self.scaler_row = torch.zeros((self.columns), device=self.dev)
self.nsamples = 0

self.layer_id = layer_id
self.layer_name = layer_name

def add_batch(self, inp, out):
if len(inp.shape) == 2:
inp = inp.unsqueeze(0)
tmp = inp.shape[0]
if isinstance(self.layer, nn.Linear):
if len(inp.shape) == 3:
inp = inp.reshape((-1, inp.shape[-1]))
inp = inp.t()

        # Maintain a running mean of squared L2 norms for each input channel.
        self.scaler_row *= self.nsamples / (self.nsamples + tmp)
        self.nsamples += tmp

        inp = inp.type(torch.float32)
        self.scaler_row += torch.norm(inp, p=2, dim=1) ** 2 / self.nsamples
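
A minimal sketch of the hook-based calibration this wrapper is used with in the Wanda prune loop (the layer and batch are illustrative):

```python
# Hypothetical usage sketch: collect activation statistics via a forward hook,
# then score weights with the Wanda metric |W| * sqrt(scaler_row).
import torch
import torch.nn as nn

layer = nn.Linear(1024, 1024)
wrapped = WrappedGPT(layer)

def hook(_module, inputs, output):
    wrapped.add_batch(inputs[0].data, output.data)

handle = layer.register_forward_hook(hook)
with torch.no_grad():
    layer(torch.randn(4, 128, 1024))  # fake calibration batch
handle.remove()

W_metric = torch.abs(layer.weight.data) * torch.sqrt(wrapped.scaler_row.reshape((1, -1)))
```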
121 changes: 121 additions & 0 deletions src/safetunebed/whitebox/attacks/wanda_pruning/sparsegpt.py
@@ -0,0 +1,121 @@

import math
import time

import torch
import torch.nn as nn
import transformers

# Keep matmuls in full fp32 (no TF32) so the Hessian statistics stay accurate.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

## SparseGPT: https://github.com/IST-DASLab/sparsegpt/tree/f5c25005a61f96a0933ca2f95705a963585aafaa
class SparseGPT:

def __init__(self, layer):
self.layer = layer
self.dev = self.layer.weight.device
W = layer.weight.data.clone()
if isinstance(self.layer, nn.Conv2d):
W = W.flatten(1)
if isinstance(self.layer, transformers.Conv1D):
W = W.t()
self.rows = W.shape[0]
self.columns = W.shape[1]
self.H = torch.zeros((self.columns, self.columns), device=self.dev)
self.nsamples = 0

def add_batch(self, inp, out):
if len(inp.shape) == 2:
inp = inp.unsqueeze(0)
tmp = inp.shape[0]
if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
if len(inp.shape) == 3:
inp = inp.reshape((-1, inp.shape[-1]))
inp = inp.t()
self.H *= self.nsamples / (self.nsamples + tmp)
self.nsamples += tmp
inp = math.sqrt(2 / self.nsamples) * inp.float()
self.H += inp.matmul(inp.t())

def fasterprune(
self, sparsity, prune_n=0, prune_m=0, blocksize=128, percdamp=.01
):
W = self.layer.weight.data.clone()
if isinstance(self.layer, nn.Conv2d):
W = W.flatten(1)
if isinstance(self.layer, transformers.Conv1D):
W = W.t()
W = W.float()

        tick = time.time()

        H = self.H
        del self.H
        # Dead input channels (zero Hessian diagonal) get a dummy 1 on the
        # diagonal, and the corresponding weights are zeroed.
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        W[:, dead] = 0

        Losses = torch.zeros(self.rows, device=self.dev)

        # Dampen H for numerical stability, then compute the upper Cholesky
        # factor of H^-1, which drives the error-compensation updates below.
        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(self.columns, device=self.dev)
        H[diag, diag] += damp
        H = torch.linalg.cholesky(H)
        H = torch.cholesky_inverse(H)
        H = torch.linalg.cholesky(H, upper=True)
        Hinv = H

mask = None

for i1 in range(0, self.columns, blocksize):
i2 = min(i1 + blocksize, self.columns)
count = i2 - i1

W1 = W[:, i1:i2].clone()
Q1 = torch.zeros_like(W1)
Err1 = torch.zeros_like(W1)
Losses1 = torch.zeros_like(W1)
Hinv1 = Hinv[i1:i2, i1:i2]

            if prune_n == 0:
                if mask is not None:
                    mask1 = mask[:, i1:i2]
                else:
                    # SparseGPT saliency: w^2 / [H^-1]_jj^2; prune the lowest-scoring fraction.
                    tmp = W1 ** 2 / (torch.diag(Hinv1).reshape((1, -1))) ** 2
                    thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)]
                    mask1 = tmp <= thresh
            else:
                mask1 = torch.zeros_like(W1) == 1

for i in range(count):
w = W1[:, i]
d = Hinv1[i, i]

                if prune_n != 0 and i % prune_m == 0:
                    # Structured n:m pruning within each block of prune_m columns.
                    tmp = W1[:, i:(i + prune_m)] ** 2 / (torch.diag(Hinv1)[i:(i + prune_m)].reshape((1, -1))) ** 2
                    mask1.scatter_(1, i + torch.topk(tmp, prune_n, dim=1, largest=False)[1], True)

q = w.clone()
q[mask1[:, i]] = 0

Q1[:, i] = q
Losses1[:, i] = (w - q) ** 2 / d ** 2

                # Spread the pruning error of column i over the remaining columns.
                err1 = (w - q) / d
                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
                Err1[:, i] = err1

W[:, i1:i2] = Q1
Losses += torch.sum(Losses1, 1) / 2

            # Propagate the accumulated block error to all later columns.
            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

torch.cuda.synchronize()
if isinstance(self.layer, transformers.Conv1D):
W = W.t()
self.layer.weight.data = W.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)

def free(self):
self.H = None
torch.cuda.empty_cache()
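
And a minimal sketch of running `SparseGPT` end to end on one layer (illustrative stand-ins; note that `fasterprune` calls `torch.cuda.synchronize()`, so a CUDA build is assumed):

```python
# Hypothetical usage sketch: 2:4 structured pruning of a single Linear layer.
import torch
import torch.nn as nn

layer = nn.Linear(512, 512).cuda()  # fasterprune syncs CUDA, so stay on GPU
gpt = SparseGPT(layer)

for _ in range(8):                  # fake calibration batches
    gpt.add_batch(torch.randn(4, 512, device="cuda"), None)

gpt.fasterprune(sparsity=0.5, prune_n=2, prune_m=4)
gpt.free()

# Every group of 4 consecutive input weights should now contain at least 2 zeros.
print((layer.weight.data.view(512, -1, 4) == 0).sum(dim=-1).min())
```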