import math
import time

import torch
import torch.nn as nn
import transformers

# Run Hessian accumulation and Cholesky factorisation in strict fp32:
# TF32 rounding would perturb the second-order statistics used below.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

class AblateGPT:
    """Per-layer state for the ablation pruning variants.

    Wraps a single prunable layer (nn.Linear / nn.Conv2d /
    transformers.Conv1D) and accumulates calibration statistics:
    `H` is a running input Gram ("Hessian") matrix used by the
    SparseGPT-style weight update in `fasterprune`, and `scaler_row`
    is a running mean of squared per-input-channel activation norms
    used by the Wanda metric.
    """

    def __init__(self, layer):
        self.layer = layer
        self.dev = self.layer.weight.device
        W = layer.weight.data.clone()
        # Normalise the weight to a 2-D (rows, columns) view so the same
        # statistics code works for conv and Conv1D layers.
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        self.rows = W.shape[0]
        self.columns = W.shape[1]
        # Running (columns x columns) Gram matrix of layer inputs.
        self.H = torch.zeros((self.columns, self.columns), device=self.dev)
        self.nsamples = 0

        # Running mean of per-input-channel squared activation norms.
        self.scaler_row = torch.zeros((self.columns), device=self.dev)

    def add_batch(self, inp, out):
        """Fold one calibration batch's inputs into `H` and `scaler_row`."""
        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        tmp = inp.shape[0]  # number of samples in this batch
        if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
            if len(inp.shape) == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            inp = inp.t()
        # Rescale the running averages before adding the new contribution.
        self.H *= self.nsamples / (self.nsamples + tmp)

        self.scaler_row *= self.nsamples / (self.nsamples+tmp)

        self.nsamples += tmp
        # sqrt(2/n) scaling makes H accumulate (2/n) * X X^T.
        inp = math.sqrt(2 / self.nsamples) * inp.float()
        self.H += inp.matmul(inp.t())
        # NOTE(review): unlike WrappedGPT.add_batch, `inp` here is already
        # scaled by sqrt(2/n), so scaler_row differs from the Wanda one by a
        # constant factor — presumably fine since only the ranking matters;
        # confirm.
        self.scaler_row += torch.norm(inp, p=2, dim=1) ** 2 / self.nsamples

    def get_wanda_mask(self, sparsity, prunen, prunem):
        """Return a boolean mask (True = prune) from the Wanda metric
        |W| * sqrt(scaler_row): n:m structured when prunen != 0 (prunen
        weights out of every prunem consecutive columns, per row),
        otherwise unstructured per-row at the given sparsity ratio."""
        W_metric = torch.abs(self.layer.weight.data) * torch.sqrt(self.scaler_row.reshape((1,-1)))
        W_mask = (torch.zeros_like(W_metric) == 1) ## initialize a mask to be all False
        if prunen != 0:
            # Within each group of `prunem` columns, mark the `prunen`
            # smallest-metric weights of each row for pruning.
            for ii in range(W_metric.shape[1]):
                if ii % prunem == 0:
                    tmp = W_metric[:,ii:(ii+prunem)].float()
                    W_mask.scatter_(1,ii+torch.topk(tmp, prunen,dim=1, largest=False)[1], True)
        else:
            # Per row, prune the smallest `sparsity` fraction of weights.
            sort_res = torch.sort(W_metric, dim=-1, stable=True)
            indices = sort_res[1][:,:int(W_metric.shape[1]*sparsity)]
            W_mask.scatter_(1, indices, True)

        return W_mask

    def get_mag_mask(self, sparsity, prunen, prunem):
        """Return a boolean prune mask based on weight magnitude only."""
        W = self.layer.weight.data
        W_metric = torch.abs(W)
        if prunen != 0:
            W_mask = (torch.zeros_like(W)==1)
            for ii in range(W_metric.shape[1]):
                if ii % prunem == 0:
                    tmp = W_metric[:,ii:(ii+prunem)].float()
                    W_mask.scatter_(1,ii+torch.topk(tmp, prunen,dim=1, largest=False)[1], True)
        else:
            # Global (whole-tensor) magnitude threshold — note this differs
            # from get_wanda_mask's per-row selection.
            thresh = torch.sort(W_metric.flatten().cuda())[0][int(W.numel()*sparsity)].cpu()
            W_mask = (W_metric<=thresh)

        return W_mask

    def fasterprune(
        self, args, sparsity, mask=None, prune_n=0, prune_m=0, blocksize=128, percdamp=.01
    ):
        """Prune the wrapped layer in place with the SparseGPT block update:
        zero masked weights and redistribute their error over the remaining
        weights using Cholesky factors of the inverse Hessian.

        If `mask` is given it is used as-is; otherwise a mask is built per
        block from args.prune_method ("wanda" or "mag" metric).
        NOTE(review): in the unstructured branch, a prune_method matching
        neither "wanda" nor "mag" leaves `tmp` unbound and raises — presumed
        unreachable for the supported ablate_* methods; confirm.
        """
        W = self.layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        W = W.float()

        tick = time.time()  # kept from upstream; elapsed time is not reported

        H = self.H
        del self.H
        # Dead input channels (never activated): make H invertible and zero
        # the corresponding weights outright.
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        W[:, dead] = 0

        Losses = torch.zeros(self.rows, device=self.dev)

        # Dampen the diagonal, then derive the upper Cholesky factor of H^-1
        # used by the OBS-style update below.
        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(self.columns, device=self.dev)
        H[diag, diag] += damp
        H = torch.linalg.cholesky(H)
        H = torch.cholesky_inverse(H)
        H = torch.linalg.cholesky(H, upper=True)
        Hinv = H

        # Process columns in blocks; error is propagated column-by-column
        # inside a block, then pushed to the remaining columns at once.
        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            count = i2 - i1

            W1 = W[:, i1:i2].clone()
            Q1 = torch.zeros_like(W1)
            Err1 = torch.zeros_like(W1)
            Losses1 = torch.zeros_like(W1)
            Hinv1 = Hinv[i1:i2, i1:i2]

            if prune_n == 0 or mask is not None:
                if mask is not None:
                    mask1 = mask[:, i1:i2]
                else:
                    # tmp = W1 ** 2 / (torch.diag(Hinv1).reshape((1, -1))) ** 2
                    if "wanda" in args.prune_method:
                        tmp = torch.abs(W1) * torch.sqrt(self.scaler_row[i1:i2].reshape((1,-1)))
                    elif "mag" in args.prune_method:
                        tmp = torch.abs(W1)
                    thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)]
                    mask1 = tmp <= thresh
            else:
                # n:m structured: mask is filled group-by-group below.
                mask1 = torch.zeros_like(W1) == 1

            for i in range(count):
                w = W1[:, i]
                d = Hinv1[i, i]

                if prune_n != 0 and i % prune_m == 0 and mask is None:
                    # tmp = W1[:, i:(i + prune_m)] ** 2 / (torch.diag(Hinv1)[i:(i + prune_m)].reshape((1, -1))) ** 2
                    if "wanda" in args.prune_method:
                        tmp = torch.abs(W1[:, i:(i+prune_m)]) * torch.sqrt(self.scaler_row[(i+i1):(i+i1+prune_m)].reshape((1,-1)))
                    elif "mag" in args.prune_method:
                        tmp = torch.abs(W1[:, i:(i+prune_m)])
                    mask1.scatter_(1, i + torch.topk(tmp, prune_n, dim=1, largest=False)[1], True)

                q = w.clone()
                q[mask1[:, i]] = 0

                Q1[:, i] = q
                Losses1[:, i] = (w - q) ** 2 / d ** 2

                # Spread this column's pruning error over the not-yet-
                # processed columns of the block.
                err1 = (w - q) / d
                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
                Err1[:, i] = err1

            W[:, i1:i2] = Q1
            Losses += torch.sum(Losses1, 1) / 2

            # Propagate the accumulated block error to all later columns.
            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

        torch.cuda.synchronize()
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        # Write the pruned weights back in the layer's original dtype.
        self.layer.weight.data = W.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)

    def free(self):
        """Release the Hessian buffer and return cached CUDA memory."""
        self.H = None
        torch.cuda.empty_cache()

# ---- head of data.py (next file in this patch) ----
# Code adapted from https://github.com/IST-DASLab/sparsegpt/blob/master/datautils.py

import numpy as np
import random
import torch
from datasets import load_dataset

# Set seed for reproducibility
def set_seed(seed):
    np.random.seed(seed)
    torch.random.manual_seed(seed)

# Wrapper for tokenized input IDs
class TokenizerWrapper:
    def __init__(self, input_ids):
        self.input_ids = input_ids
# Load and process the wikitext2 dataset
def get_wikitext2(nsamples, seed, seqlen, tokenizer):
    """Build a wikitext-2 calibration set.

    Returns (trainloader, testenc): `trainloader` is a list of `nsamples`
    (input_ids, target) pairs of length `seqlen` sampled at random offsets
    from the tokenized training split (targets are -100 everywhere except
    the final position), `testenc` is the tokenized test split.
    """
    # Load train and test datasets
    traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
    testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')

    # Encode datasets
    trainenc = tokenizer(" ".join(traindata['text']), return_tensors='pt')
    testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt')

    # Generate samples from training set
    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        # Mask all but the last token so loss (if used) covers one position.
        tar[:, :-1] = -100
        trainloader.append((inp, tar))
    return trainloader, testenc

# Load and process c4 dataset
def get_c4(nsamples, seed, seqlen, tokenizer):
    """Build a C4 calibration set.

    Samples `nsamples` random `seqlen`-token windows from documents long
    enough to contain one; returns (trainloader, valenc) where `valenc`
    is a TokenizerWrapper around a truncated validation encoding.
    """
    # Load train and validation datasets
    traindata = load_dataset('allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train')
    valdata = load_dataset('allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation')

    # Generate samples from training set
    random.seed(seed)
    trainloader = []
    for _ in range(nsamples):
        # Rejection-sample a document with more than `seqlen` tokens.
        while True:
            i = random.randint(0, len(traindata) - 1)
            trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
            if trainenc.input_ids.shape[1] > seqlen:
                break
        i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
        j = i + seqlen
        inp = trainenc.input_ids[:, i:j]
        tar = inp.clone()
        tar[:, :-1] = -100
        trainloader.append((inp, tar))

    # Prepare validation dataset
    valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt')
    valenc = valenc.input_ids[:, :(256 * seqlen)]
    valenc = TokenizerWrapper(valenc)
    return trainloader, valenc

# Function to select the appropriate loader based on dataset name
def get_loaders(name, nsamples=128, seed=0, seqlen=2048, tokenizer=None):
    """Dispatch to the dataset-specific calibration loader by substring
    match on `name` ('wikitext2' or 'c4').

    Raises:
        ValueError: if `name` matches no supported dataset. (Previously
        this fell through and returned None, which only surfaced later as
        an opaque unpacking TypeError at the call site.)
    """
    if 'wikitext2' in name:
        return get_wikitext2(nsamples, seed, seqlen, tokenizer)
    if "c4" in name:
        return get_c4(nsamples, seed, seqlen, tokenizer)
    raise ValueError(f"Unsupported calibration dataset: {name!r}")
# Define WrappedGPT class
class WrappedGPT:
    """Wraps a single layer and accumulates Wanda input statistics.

    `scaler_row` holds, per input channel, the running mean over all
    calibration samples of the squared L2 norm of that channel's
    activations; the Wanda metric multiplies |weight| by its square root.
    """

    def __init__(self, layer, layer_id=0, layer_name="none"):
        self.layer = layer
        self.dev = layer.weight.device
        weight = layer.weight.data
        self.rows = weight.shape[0]
        self.columns = weight.shape[1]

        # One accumulator slot per input channel (weight column).
        self.scaler_row = torch.zeros((self.columns), device=self.dev)
        self.nsamples = 0

        # Bookkeeping only; not used by the statistics themselves.
        self.layer_id = layer_id
        self.layer_name = layer_name

    def add_batch(self, inp, out):
        """Fold one calibration batch's inputs into `scaler_row`.

        `out` is accepted (forward-hook signature) but unused.
        """
        x = inp.unsqueeze(0) if len(inp.shape) == 2 else inp
        batch = x.shape[0]
        if isinstance(self.layer, nn.Linear):
            # Collapse (batch, seq, features) to (tokens, features), then
            # transpose so rows are input channels.
            if len(x.shape) == 3:
                x = x.reshape((-1, x.shape[-1]))
            x = x.t()

        # Shrink the running average before folding in the new batch.
        self.scaler_row *= self.nsamples / (self.nsamples + batch)
        self.nsamples += batch

        x = x.type(torch.float32)
        self.scaler_row += torch.norm(x, p=2, dim=1) ** 2 / self.nsamples
## SparseGPT: https://github.com/IST-DASLab/sparsegpt/tree/f5c25005a61f96a0933ca2f95705a963585aafaa
class SparseGPT:
    """Layer-wise SparseGPT pruning state.

    Accumulates a running input Gram ("Hessian") matrix H over calibration
    batches; `fasterprune` then zeroes low-saliency weights and compensates
    the remaining ones via an OBS-style update built from Cholesky factors
    of H^-1.
    """

    def __init__(self, layer):
        self.layer = layer
        self.dev = self.layer.weight.device
        W = layer.weight.data.clone()
        # Normalise to a 2-D (rows, columns) view for conv / Conv1D layers.
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        self.rows = W.shape[0]
        self.columns = W.shape[1]
        # Running (columns x columns) Gram matrix of layer inputs.
        self.H = torch.zeros((self.columns, self.columns), device=self.dev)
        self.nsamples = 0

    def add_batch(self, inp, out):
        """Fold one calibration batch's inputs into the running Hessian H."""
        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        tmp = inp.shape[0]  # number of samples in this batch
        if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
            if len(inp.shape) == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            inp = inp.t()
        # Rescale the running average, then add (2/n) * X X^T.
        self.H *= self.nsamples / (self.nsamples + tmp)
        self.nsamples += tmp
        inp = math.sqrt(2 / self.nsamples) * inp.float()
        self.H += inp.matmul(inp.t())

    def fasterprune(
        self, sparsity, prune_n=0, prune_m=0, blocksize=128, percdamp=.01
    ):
        """Prune the wrapped layer in place.

        Saliency is w^2 / diag(H^-1)^2; masked weights are zeroed and their
        error redistributed column-by-column over the remaining weights
        (n:m structured when prune_n != 0, else unstructured at `sparsity`).
        """
        W = self.layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        W = W.float()

        tick = time.time()  # kept from upstream; elapsed time is not reported

        H = self.H
        del self.H
        # Dead input channels: make H invertible and zero their weights.
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        W[:, dead] = 0

        Losses = torch.zeros(self.rows, device=self.dev)

        # Dampen the diagonal, then derive the upper Cholesky factor of H^-1.
        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(self.columns, device=self.dev)
        H[diag, diag] += damp
        H = torch.linalg.cholesky(H)
        H = torch.cholesky_inverse(H)
        H = torch.linalg.cholesky(H, upper=True)
        Hinv = H

        mask = None  # no precomputed global mask in this variant

        # Process columns in blocks; error is propagated column-by-column
        # inside a block, then pushed to the remaining columns at once.
        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            count = i2 - i1

            W1 = W[:, i1:i2].clone()
            Q1 = torch.zeros_like(W1)
            Err1 = torch.zeros_like(W1)
            Losses1 = torch.zeros_like(W1)
            Hinv1 = Hinv[i1:i2, i1:i2]

            if prune_n == 0:
                if mask is not None:
                    mask1 = mask[:, i1:i2]
                else:
                    # Unstructured: threshold the block's saliency scores.
                    tmp = W1 ** 2 / (torch.diag(Hinv1).reshape((1, -1))) ** 2
                    thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)]
                    mask1 = tmp <= thresh
            else:
                # n:m structured: mask is filled group-by-group below.
                mask1 = torch.zeros_like(W1) == 1

            for i in range(count):
                w = W1[:, i]
                d = Hinv1[i, i]

                if prune_n != 0 and i % prune_m == 0:
                    tmp = W1[:, i:(i + prune_m)] ** 2 / (torch.diag(Hinv1)[i:(i + prune_m)].reshape((1, -1))) ** 2
                    mask1.scatter_(1, i + torch.topk(tmp, prune_n, dim=1, largest=False)[1], True)

                q = w.clone()
                q[mask1[:, i]] = 0

                Q1[:, i] = q
                Losses1[:, i] = (w - q) ** 2 / d ** 2

                # Spread this column's pruning error over later columns.
                err1 = (w - q) / d
                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
                Err1[:, i] = err1

            W[:, i1:i2] = Q1
            Losses += torch.sum(Losses1, 1) / 2

            # Propagate the accumulated block error to remaining columns.
            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

        torch.cuda.synchronize()
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        # Write the pruned weights back in the layer's original dtype.
        self.layer.weight.data = W.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)

    def free(self):
        """Release the Hessian buffer and return cached CUDA memory."""
        self.H = None
        torch.cuda.empty_cache()
+ """ + if type(module) in layers: + return {name: module} + res = {} + for name1, child in module.named_children(): + res.update(find_layers( + child, layers=layers, name=name + '.' + name1 if name != '' else name1 + )) + return res + +def check_sparsity(model): + use_cache = model.config.use_cache + model.config.use_cache = False + + layers = model.model.layers + count = 0 + total_params = 0 + for i in range(len(layers)): + layer = layers[i] + subset = find_layers(layer) + + sub_count = 0 + sub_params = 0 + for name in subset: + W = subset[name].weight.data + count += (W==0).sum().item() + total_params += W.numel() + + sub_count += (W==0).sum().item() + sub_params += W.numel() + + print(f"layer {i} sparsity {float(sub_count)/sub_params:.6f}") + + model.config.use_cache = use_cache + return float(count)/total_params + +def prepare_calibration_input(model, dataloader, device): + use_cache = model.config.use_cache + model.config.use_cache = False + layers = model.model.layers + + # dev = model.hf_device_map["model.embed_tokens"] + if "model.embed_tokens" in model.hf_device_map: + device = model.hf_device_map["model.embed_tokens"] + + dtype = next(iter(model.parameters())).dtype + inps = torch.zeros((128, model.seqlen, model.config.hidden_size), dtype=dtype, device=device) + inps.requires_grad = False + cache = {'i': 0, 'attention_mask': None, "position_ids": None} + + class Catcher(nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + def forward(self, inp, **kwargs): + inps[cache['i']] = inp + cache['i'] += 1 + cache['attention_mask'] = kwargs['attention_mask'] + cache['position_ids'] = kwargs['position_ids'] + raise ValueError + layers[0] = Catcher(layers[0]) + for batch in dataloader: + try: + model(batch[0].to(device)) + except ValueError: + pass + layers[0] = layers[0].module + + outs = torch.zeros_like(inps) + attention_mask = cache['attention_mask'] + position_ids = cache['position_ids'] + model.config.use_cache = use_cache + 
+ return inps, outs, attention_mask, position_ids + +def return_given_alpha(alpha, sort_res, W_metric, tmp_metric, sum_before): + thres_cumsum = sum_before * alpha + sort_mask = tmp_metric <= thres_cumsum.reshape((-1,1)) + thres = torch.gather(sort_res[0], dim=1, index=sort_mask.sum(dim=1, keepdims=True)-1) + W_mask = (W_metric <= thres) + cur_sparsity = (W_mask==True).sum() / W_mask.numel() + return W_mask, cur_sparsity + +def prune_magnitude(args, model, tokenizer, device=torch.device("cuda:0"), prune_n=0, prune_m=0): + layers = model.model.layers + + for i in range(len(layers)): + layer = layers[i] + subset = find_layers(layer) + + for name in subset: + W = subset[name].weight.data + W_metric = torch.abs(W) + if prune_n != 0: + W_mask = (torch.zeros_like(W)==1) + for ii in range(W_metric.shape[1]): + if ii % prune_m == 0: + tmp = W_metric[:,ii:(ii+prune_m)].float() + W_mask.scatter_(1,ii+torch.topk(tmp, prune_n,dim=1, largest=False)[1], True) + else: + thresh = torch.sort(W_metric.flatten().cuda())[0][int(W.numel()*args.sparsity_ratio)].cpu() + W_mask = (W_metric<=thresh) + + W[W_mask] = 0 + +def prune_wanda(args, model, tokenizer, device=torch.device("cuda:0"), prune_n=0, prune_m=0): + use_cache = model.config.use_cache + model.config.use_cache = False + + print("loading calibration data") + dataloader, _ = get_loaders("c4",nsamples=args.nsamples,seed=args.seed,seqlen=model.seqlen,tokenizer=tokenizer) + print("dataset loading complete") + with torch.no_grad(): + inps, outs, attention_mask, position_ids = prepare_calibration_input(model, dataloader, device) + + layers = model.model.layers + for i in range(len(layers)): + layer = layers[i] + subset = find_layers(layer) + + if f"model.layers.{i}" in model.hf_device_map: ## handle the case for llama-30B and llama-65B, when the device map has multiple GPUs; + dev = model.hf_device_map[f"model.layers.{i}"] + inps, outs, attention_mask, position_ids = inps.to(dev), outs.to(dev), attention_mask.to(dev), 
def prune_wanda(args, model, tokenizer, device=torch.device("cuda:0"), prune_n=0, prune_m=0):
    """Prune `model` in place with the Wanda metric |W| * sqrt(input norm).

    Walks the transformer blocks sequentially: captures each block's
    calibration inputs, gathers per-channel activation statistics via
    forward hooks, zeroes the lowest-metric weights (n:m structured when
    prune_n != 0, else unstructured at args.sparsity_ratio, optionally via
    the cumulative-mass "variant" search), then feeds the pruned block's
    outputs forward as the next block's inputs.
    """
    use_cache = model.config.use_cache
    model.config.use_cache = False

    print("loading calibration data")
    dataloader, _ = get_loaders("c4",nsamples=args.nsamples,seed=args.seed,seqlen=model.seqlen,tokenizer=tokenizer)
    print("dataset loading complete")
    with torch.no_grad():
        inps, outs, attention_mask, position_ids = prepare_calibration_input(model, dataloader, device)

    layers = model.model.layers
    for i in range(len(layers)):
        layer = layers[i]
        subset = find_layers(layer)

        if f"model.layers.{i}" in model.hf_device_map: ## handle the case for llama-30B and llama-65B, when the device map has multiple GPUs;
            dev = model.hf_device_map[f"model.layers.{i}"]
            inps, outs, attention_mask, position_ids = inps.to(dev), outs.to(dev), attention_mask.to(dev), position_ids.to(dev)

        wrapped_layers = {}
        for name in subset:
            wrapped_layers[name] = WrappedGPT(subset[name])

        def add_batch(name):
            # Bind `name` per hook so each wrapper sees its own layer's inputs.
            def tmp(_, inp, out):
                wrapped_layers[name].add_batch(inp[0].data, out.data)
            return tmp

        handles = []
        for name in wrapped_layers:
            handles.append(subset[name].register_forward_hook(add_batch(name)))
        # First pass: run calibration inputs through the still-dense block so
        # the hooks can gather activation statistics.
        for j in range(args.nsamples):
            with torch.no_grad():
                outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
        for h in handles:
            h.remove()

        for name in subset:
            print(f"pruning layer {i} name {name}")
            W_metric = torch.abs(subset[name].weight.data) * torch.sqrt(wrapped_layers[name].scaler_row.reshape((1,-1)))

            W_mask = (torch.zeros_like(W_metric) == 1)  ## initialize a mask to be all False
            if prune_n != 0:
                # structured n:m sparsity
                for ii in range(W_metric.shape[1]):
                    if ii % prune_m == 0:
                        tmp = W_metric[:,ii:(ii+prune_m)].float()
                        W_mask.scatter_(1,ii+torch.topk(tmp, prune_n,dim=1, largest=False)[1], True)
            else:
                sort_res = torch.sort(W_metric, dim=-1, stable=True)

                if args.use_variant:
                    # wanda variant: binary-search the per-row cumulative-mass
                    # threshold alpha until the overall sparsity matches the
                    # target within 1e-3 (or the bracket collapses).
                    tmp_metric = torch.cumsum(sort_res[0], dim=1)
                    sum_before = W_metric.sum(dim=1)

                    alpha = 0.4
                    alpha_hist = [0., 0.8]
                    W_mask, cur_sparsity = return_given_alpha(alpha, sort_res, W_metric, tmp_metric, sum_before)
                    while (torch.abs(cur_sparsity - args.sparsity_ratio)>0.001) and (alpha_hist[1]-alpha_hist[0]>=0.001):
                        if cur_sparsity > args.sparsity_ratio:
                            alpha_new = (alpha + alpha_hist[0]) / 2.0
                            alpha_hist[1] = alpha
                        else:
                            alpha_new = (alpha + alpha_hist[1]) / 2.0
                            alpha_hist[0] = alpha

                        alpha = alpha_new
                        W_mask, cur_sparsity = return_given_alpha(alpha, sort_res, W_metric, tmp_metric, sum_before)
                    print(f"alpha found {alpha} sparsity {cur_sparsity:.6f}")
                else:
                    # unstructured pruning
                    indices = sort_res[1][:,:int(W_metric.shape[1]*args.sparsity_ratio)]
                    W_mask.scatter_(1, indices, True)

            subset[name].weight.data[W_mask] = 0  ## set weights to zero

        # Second pass: recompute outputs with the pruned weights so the next
        # block is calibrated on what it will actually receive.
        for j in range(args.nsamples):
            with torch.no_grad():
                outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
        inps, outs = outs, inps

    model.config.use_cache = use_cache
    torch.cuda.empty_cache()
@torch.no_grad()
def prune_sparsegpt(args, model, tokenizer, dev, prune_n=0, prune_m=0):
    ## SparseGPT code available at: https://github.com/IST-DASLab/sparsegpt/tree/f5c25005a61f96a0933ca2f95705a963585aafaa
    """Prune `model` in place with SparseGPT.

    Captures calibration inputs to block 0, then for each transformer
    block accumulates per-sublayer Hessians via forward hooks, runs
    SparseGPT.fasterprune on every linear sublayer, and forwards the
    block's (pruned) outputs as the next block's inputs.
    """
    print('Starting ...')
    dataloader, _ = get_loaders("c4",nsamples=args.nsamples,seed=args.seed,seqlen=model.seqlen,tokenizer=tokenizer)

    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers

    if "model.embed_tokens" in model.hf_device_map:
        dev = model.hf_device_map["model.embed_tokens"]

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )
    cache = {'i': 0, 'attention_mask': None, "position_ids": None}

    class Catcher(nn.Module):
        # Records block-0 inputs, then aborts the forward pass.
        def __init__(self, module):
            super().__init__()
            self.module = module
        def forward(self, inp, **kwargs):
            inps[cache['i']] = inp
            cache['i'] += 1
            cache['attention_mask'] = kwargs['attention_mask']
            cache['position_ids'] = kwargs['position_ids']
            raise ValueError  # deliberate: unwind out of model(...)
    layers[0] = Catcher(layers[0])
    for batch in dataloader:
        try:
            model(batch[0].to(dev))
        except ValueError:
            pass
    layers[0] = layers[0].module
    torch.cuda.empty_cache()

    outs = torch.zeros_like(inps)
    attention_mask = cache['attention_mask']
    position_ids = cache['position_ids']

    print('Ready.')

    for i in range(len(layers)):
        layer = layers[i]
        if f"model.layers.{i}" in model.hf_device_map:
            # Multi-GPU device maps (llama-30B/65B): follow the block's device.
            dev = model.hf_device_map[f"model.layers.{i}"]
            print(f"layer {i} device {dev}")
            inps, outs, attention_mask, position_ids = inps.to(dev), outs.to(dev), attention_mask.to(dev), position_ids.to(dev)

        subset = find_layers(layer)

        gpts = {}
        for name in subset:
            gpts[name] = SparseGPT(subset[name])

        def add_batch(name):
            # Bind `name` per hook so each pruner sees its own layer's inputs.
            def tmp(_, inp, out):
                gpts[name].add_batch(inp[0].data, out.data)
            return tmp

        handles = []
        for name in gpts:
            handles.append(subset[name].register_forward_hook(add_batch(name)))

        # First pass: accumulate Hessians on the still-dense block.
        for j in range(args.nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
        for h in handles:
            h.remove()

        for name in gpts:
            print(i, name)
            print('Pruning ...')

            gpts[name].fasterprune(args.sparsity_ratio, prune_n=prune_n, prune_m=prune_m, percdamp=0.01, blocksize=128)
            gpts[name].free()

        # Second pass: recompute outputs with the pruned weights.
        for j in range(args.nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]

        layers[i] = layer
        torch.cuda.empty_cache()

        inps, outs = outs, inps

    model.config.use_cache = use_cache
    torch.cuda.empty_cache()



@torch.no_grad()
def prune_ablate(args, model, tokenizer, dev, prune_n=0, prune_m=0):
    ## SparseGPT code available at: https://github.com/IST-DASLab/sparsegpt/tree/f5c25005a61f96a0933ca2f95705a963585aafaa
    """Ablation pruning: same sequential pipeline as prune_sparsegpt, but
    the mask can be fixed up front from the Wanda or magnitude metric
    (ablate_wanda_seq / ablate_mag_seq) or rebuilt iteratively per block
    ("iter" methods) inside AblateGPT.fasterprune.
    """
    print('Starting ...')
    dataloader, _ = get_loaders("c4",nsamples=args.nsamples,seed=args.seed,seqlen=model.seqlen,tokenizer=tokenizer)

    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.layers

    if "model.embed_tokens" in model.hf_device_map:
        dev = model.hf_device_map["model.embed_tokens"]

    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )
    cache = {'i': 0, 'attention_mask': None, "position_ids": None}

    class Catcher(nn.Module):
        # Records block-0 inputs, then aborts the forward pass.
        def __init__(self, module):
            super().__init__()
            self.module = module
        def forward(self, inp, **kwargs):
            inps[cache['i']] = inp
            cache['i'] += 1
            cache['attention_mask'] = kwargs['attention_mask']
            cache['position_ids'] = kwargs['position_ids']
            raise ValueError  # deliberate: unwind out of model(...)
    layers[0] = Catcher(layers[0])
    for batch in dataloader:
        try:
            model(batch[0].to(dev))
        except ValueError:
            pass
    layers[0] = layers[0].module
    torch.cuda.empty_cache()

    outs = torch.zeros_like(inps)
    attention_mask = cache['attention_mask']
    position_ids = cache['position_ids']

    print('Ready.')

    for i in range(len(layers)):
        layer = layers[i]
        if f"model.layers.{i}" in model.hf_device_map:
            dev = model.hf_device_map[f"model.layers.{i}"]
            print(f"layer {i} device {dev}")
            inps, outs, attention_mask, position_ids = inps.to(dev), outs.to(dev), attention_mask.to(dev), position_ids.to(dev)

        subset = find_layers(layer)

        gpts = {}
        for name in subset:
            gpts[name] = AblateGPT(subset[name])

        def add_batch(name):
            def tmp(_, inp, out):
                gpts[name].add_batch(inp[0].data, out.data)
            return tmp

        handles = []
        for name in gpts:
            handles.append(subset[name].register_forward_hook(add_batch(name)))

        for j in range(args.nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]
        for h in handles:
            h.remove()

        for name in gpts:
            print(i, name)
            print('Pruning ...')

            if args.prune_method == "ablate_wanda_seq":
                prune_mask = gpts[name].get_wanda_mask(args.sparsity_ratio, prune_n, prune_m)
            elif args.prune_method == "ablate_mag_seq":
                prune_mask = gpts[name].get_mag_mask(args.sparsity_ratio, prune_n, prune_m)
            elif "iter" in args.prune_method:
                prune_mask = None
            # NOTE(review): any other prune_method leaves `prune_mask` unbound
            # and raises NameError below — confirm callers only pass the three
            # ablate_* variants here.

            gpts[name].fasterprune(args, args.sparsity_ratio, mask=prune_mask, prune_n=prune_n, prune_m=prune_m, percdamp=0.01, blocksize=128)
            gpts[name].free()

        for j in range(args.nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask, position_ids=position_ids)[0]

        layers[i] = layer
        torch.cuda.empty_cache()

        inps, outs = outs, inps

    model.config.use_cache = use_cache
    torch.cuda.empty_cache()
class WandaPruningAttack(TamperAttack[TamperAttackConfig]):
    """Implements weight-space tampering via WandA pruning."""

    # Default n:m sparsity pattern, used when the attack config does not
    # provide `prune_n` / `prune_m` attributes.
    DEFAULT_PRUNE_N = 4
    DEFAULT_PRUNE_M = 8

    def run_attack(self) -> None:
        """Prune the base checkpoint with Wanda and save the result.

        Loads the model from `base_input_checkpoint_path`, applies
        n:m-structured Wanda pruning on GPU, writes the pruned model and
        tokenizer to `base_output_checkpoint_path`, and copies any LoRA
        adapter folder through unchanged.

        NOTE(review): `prune_wanda` also reads `sparsity_ratio`,
        `nsamples`, `seed` and `use_variant` off the config object —
        confirm TamperAttackConfig exposes them.
        """
        cfg = self.attack_config

        print(f"[WandA] Loading model from: {cfg.base_input_checkpoint_path}")
        model = AutoModelForCausalLM.from_pretrained(cfg.base_input_checkpoint_path, torch_dtype=torch.float16)
        tokenizer = AutoTokenizer.from_pretrained(cfg.base_input_checkpoint_path)

        model.eval()
        model.cuda()

        # n:m sparsity pattern: read from the config when present, falling
        # back to the previously hard-coded 4:8 default (backward compatible).
        prune_n = getattr(cfg, "prune_n", self.DEFAULT_PRUNE_N)
        prune_m = getattr(cfg, "prune_m", self.DEFAULT_PRUNE_M)
        print("[WandA] Applying pruning...")
        prune_wanda(cfg, model, tokenizer, device="cuda", prune_n=prune_n, prune_m=prune_m)

        print(f"[WandA] Saving pruned model to {cfg.base_output_checkpoint_path}")
        model.save_pretrained(cfg.base_output_checkpoint_path)
        tokenizer.save_pretrained(cfg.base_output_checkpoint_path)

        # The attack does not touch the LoRA adapter; propagate it so the
        # evaluation sees a complete (base + adapter) checkpoint.
        if cfg.lora_input_folder and cfg.lora_output_folder:
            shutil.copytree(cfg.lora_input_folder, cfg.lora_output_folder, dirs_exist_ok=True)

        print("[WandA] Attack complete.")

    def evaluate_strong_reject(self) -> dict[str, float]:
        """Evaluate attack on the `StrongRejectEvaluation` evaluator."""
        eval_cfg: StrongRejectEvaluationConfig = StrongRejectEvaluationConfig(
            base_checkpoint=self.attack_config.base_output_checkpoint_path,
            lora_folder=self.attack_config.lora_output_folder,
            max_generation_length=self.attack_config.max_generation_length,
            batch_size=64,
            small=True,
        )
        evaluator: StrongRejectEvaluation[StrongRejectEvaluationConfig] = (
            StrongRejectEvaluation(eval_config=eval_cfg)
        )

        return evaluator.run_evaluation()