forked from mit-han-lab/smoothquant
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathppl_eval.py
More file actions
85 lines (69 loc) · 2.66 KB
/
ppl_eval.py
File metadata and controls
85 lines (69 loc) · 2.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM
from smoothquant.smooth import smooth_lm
from smoothquant.fake_quant import quantize_model
import tqdm
from datasets import load_dataset
import argparse
# Command-line interface: evaluation hyper-parameters and model/scale paths.
parser = argparse.ArgumentParser()
_cli_options = (
    ("--alpha", dict(type=float, default=0.5)),  # SmoothQuant migration strength
    ("--model_path", dict(type=str, default="meta-llama/Llama-2-7b-hf")),
    ("--act_scales_path", dict(type=str, default="act_scales/llama-2-7b.pt")),
    ("--n_samples", dict(type=int, default=None)),  # None -> evaluate all full windows
    ("--smooth", dict(action="store_true")),
    ("--quantize", dict(action="store_true")),
)
for _flag, _kwargs in _cli_options:
    parser.add_argument(_flag, **_kwargs)
args = parser.parse_args()

# Unpack into the plain module-level names the rest of the script reads.
alpha = args.alpha
model_path = args.model_path
act_scales_path = args.act_scales_path
n_samples = args.n_samples
class Evaluator:
    """Compute perplexity of a causal LM over a tokenized text corpus.

    The corpus is tokenized once into a single long id sequence and then
    evaluated in fixed-length, non-overlapping windows.
    """

    def __init__(self, dataset, tokenizer, device, n_samples=40, seq_len=2048):
        """
        Args:
            dataset: mapping with a "text" field (list of strings), e.g. a
                HuggingFace wikitext split.
            tokenizer: callable that, given a string and
                ``return_tensors="pt"``, returns an object with an
                ``input_ids`` tensor of shape (1, total_tokens).
            device: device the tokenized ids are kept on.
            n_samples: number of windows to evaluate; falsy (None/0) means
                "as many complete windows as the corpus holds".
            seq_len: evaluation window length in tokens (default 2048).
        """
        self.tokenizer = tokenizer
        self.device = device
        self.seq_len = seq_len
        # Tokenize the whole split as one "\n\n"-joined document.
        # (The original code first stored the raw dataset here, a dead
        # assignment immediately overwritten — removed.)
        self.dataset = tokenizer(
            "\n\n".join(dataset["text"]), return_tensors="pt"
        ).input_ids.to(device)
        self.n_samples = n_samples

    @torch.no_grad()
    def evaluate(self, model):
        """Return the perplexity of ``model`` over the evaluation windows."""
        model.eval()
        seq_len = self.seq_len
        nlls = []
        # Falsy n_samples -> use every complete window in the corpus.
        n_samples = self.n_samples if self.n_samples else self.dataset.size(1) // seq_len
        for i in tqdm.tqdm(range(n_samples), desc="Evaluating..."):
            batch = self.dataset[:, (i * seq_len) : ((i + 1) * seq_len)].to(model.device)
            # No inner torch.no_grad() needed: the decorator already covers it.
            lm_logits = model(batch).logits
            # Standard causal-LM shift: logits at position t predict token t+1.
            shift_logits = lm_logits[:, :-1, :].contiguous().float()
            # Slice labels from `batch` (already on model.device) so logits and
            # labels are guaranteed to be on the same device even when the
            # model is sharded via device_map="auto".
            shift_labels = batch[:, 1:]
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )
            # Per-window NLL weighted by the window length (original convention).
            nlls.append(loss.float() * seq_len)
        return torch.exp(torch.stack(nlls).sum() / (n_samples * seq_len))
# Build the evaluation harness: tokenizer + wikitext-2 test split on CUDA.
tokenizer = AutoTokenizer.from_pretrained(model_path)
test_split = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
evaluator = Evaluator(test_split, tokenizer, "cuda", n_samples=n_samples)

# Load the model in bf16, sharded automatically across available devices.
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.bfloat16, device_map="auto"
)

# Optionally apply SmoothQuant activation-to-weight smoothing ...
if args.smooth:
    smooth_lm(model, torch.load(act_scales_path), alpha)

# ... and optionally fake-quantize weights and activations.
if args.quantize:
    model = quantize_model(
        model,
        weight_quant="per_channel",
        act_quant="per_token",
        quantize_bmm_input=True,
    )

print(f"Perplexity: {evaluator.evaluate(model)}")