-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathmain.py
More file actions
128 lines (114 loc) · 5.67 KB
/
main.py
File metadata and controls
128 lines (114 loc) · 5.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import os,utils,lm_eval,logging
import quant._get_args as get_args
import quant.ost_model_utils as ost_model_utils,utils.data_utils as data_utils,utils.eval_utils as eval_utils,utils.gptq_utils as gptq_utils
import torch,torch.nn as nn,torch.nn.functional as F,transformers,sys,datetime,datasets,lm_eval,lm_eval.tasks,accelerate
from loguru import logger
from quant.ost_train import rotate_smooth_train
from lm_eval.tasks import TaskManager
from lm_eval.utils import make_table
from lm_eval.models.huggingface import HFLM,eval_logger
# Silence lm-eval's HuggingFace model logger below ERROR.  Use setLevel()
# rather than assigning ``.level`` directly: setLevel() also invalidates the
# logger's cached isEnabledFor() results, so the new level reliably applies.
eval_logger.setLevel(logging.ERROR)
def main(args):
    """Run the rotate / quantize / evaluate pipeline selected by ``args``.

    Stages, each gated by a flag read from ``args``:

    * ``pre_eval``: log the float-model perplexity and run the lm-eval tasks.
    * ``rotate``: fuse layer norms, generate rotation parameters, then either
      train them (``train_rotate``) or restore them from ``resume_path``,
      and apply them to the model in place.  With ``test_static`` the
      lm-eval tasks and a static activation-quant evaluation are run too.
    * ``w_bits < 16``: quantize the weights, either by loading a saved
      quantized model (``load_qmodel_path``), by GPTQ (``w_gptq``), or by
      ``gptq_utils.rtn_fwrd`` otherwise; optionally save the result to
      ``save_qmodel_path``.
    * Finally log the dynamic (and, with ``test_static``, static)
      quantized perplexity and, with ``lm_eval``, run the lm-eval tasks.

    Args:
        args: Parsed run configuration (see ``quant._get_args``).
    """
    transformers.set_seed(args.seed)
    lm = ost_model_utils.LM(args)
    lm.model.eval()
    test_loader = data_utils.get_loaders(
        args.eval_dataset,
        seed=args.seed,
        model=args.model,
        seqlen=lm.seqlen,
        eval_mode=True,
    )

    if args.pre_eval:
        # No trailing comma after the call: it used to wrap the result in a
        # 1-tuple, so the log line printed "(ppl,)" instead of the value.
        pre_ppl = eval_utils.evaluator(lm.model, test_loader, utils.DEV, args)
        logger.info(f"Float ppl:{pre_ppl}")
        eval_tasks(lm, args)

    if args.rotate:
        lm.fuse_layer_norms()
        lm.generate_rotate_parameters()
        if args.train_rotate:
            lm.set_quant_state(
                use_weight_quant=args.train_enable_wquant,
                use_act_quant=True,
                use_fully_quant=args.fully_quant,
            )
            lm = rotate_smooth_train(args, lm)
        elif args.resume_path is not None:
            # NOTE: torch.load unpickles the file; only resume from trusted
            # checkpoints.
            checkpoint = torch.load(args.resume_path)
            # strict=False tolerates a checkpoint that does not cover every
            # model parameter; surface keys that were ignored instead of
            # dropping them silently.
            load_result = lm.model.load_state_dict(checkpoint, strict=False)
            if load_result.unexpected_keys:
                logger.warning(
                    f"ignored unexpected keys in {args.resume_path}: "
                    f"{load_result.unexpected_keys}"
                )
            logger.info(f"resume from {args.resume_path}")
        lm.rotate_smooth_model_inplace()
        if args.test_static:
            eval_tasks(lm, args)
            static_ppl = static_eval(lm, test_loader, args)
            logger.info(f"After Static ActQuant PPL:{static_ppl}")

    if args.w_bits < 16:
        # Activation quantization is switched off while weights are quantized
        # and switched back on once that is done.
        lm.set_quant_state(use_act_quant=False, use_fully_quant=False)
        if args.force_clip:
            args.w_clip = True
        save_dict = {}
        if args.load_qmodel_path:
            assert args.rotate, "Model should be rotated to load a quantized model!"
            assert not args.save_qmodel_path, "Cannot save a quantized model if it is already loaded!"
            print("Load quantized model from ", args.load_qmodel_path)
            # NOTE: torch.load unpickles the file; only load trusted models.
            save_dict = torch.load(args.load_qmodel_path)
            lm.model.load_state_dict(save_dict["model"])
        elif args.w_gptq:
            trainloader = data_utils.get_loaders(
                args.cal_dataset,
                nsamples=args.nsamples,
                seed=args.seed,
                model=args.model,
                seqlen=lm.seqlen,
                eval_mode=False,
            )
            quantizers = gptq_utils.gptq_fwrd(lm.model, trainloader, utils.DEV, args)
            save_dict["w_quantizers"] = quantizers
        else:
            quantizers = gptq_utils.rtn_fwrd(lm.model, utils.DEV, args)
            save_dict["w_quantizers"] = quantizers
        if args.save_qmodel_path:
            save_dict["model"] = lm.model.state_dict()
            torch.save(save_dict, args.save_qmodel_path)
        lm.set_quant_state(use_act_quant=True, use_fully_quant=False)

    dynamic_ppl = dynamic_eval(
        lm,
        test_loader,
        args,
        use_act_quant=True,
        use_fully_quant=args.fully_quant,
    )
    logger.info(f"After Dynamic Act&Weight Quant PPL:{dynamic_ppl}")
    if args.test_static:
        static_ppl = static_eval(lm, test_loader, args)
        logger.info(f"After Static Act&Weight Quant PPL:{static_ppl}")

    if args.lm_eval:
        eval_tasks(lm, args)
def eval_tasks(lm, args):
    """Run the lm-eval tasks named in ``args.tasks`` and log a results table.

    The model is first placed on device(s): sharded via
    ``utils.distribute_model`` when ``args.distribute`` is set, otherwise
    moved to ``utils.DEV``.  It is then wrapped in ``HFLM`` and evaluated
    with ``lm_eval.simple_evaluate``.  An ``AVERAGE`` row holding the mean
    accuracy (and mean stderr, when available) is added to the results
    before the table is logged.

    Per task, ``acc_norm,none`` is preferred over ``acc,none``.  Tasks that
    report neither are left out of the average instead of raising.

    Args:
        lm: Model wrapper exposing ``model`` and ``tokenizer``.
        args: Run configuration; reads ``distribute`` and ``tasks``.
    """
    utils.cleanup_memory()
    if args.distribute:
        utils.distribute_model(lm.model)
    else:
        lm.model.to(utils.DEV)

    task_manager = TaskManager()
    tasks = task_manager.match_tasks(args.tasks)
    hflm = HFLM(pretrained=lm.model, tokenizer=lm.tokenizer, batch_size="auto", max_batch_size=256)
    results = lm_eval.simple_evaluate(hflm, tasks=tasks, batch_size="auto", max_batch_size=256)

    metric_vals = {}
    std_vals = {}
    for task, result in results['results'].items():
        # Resolve the fallback lazily: dict.get(key, result[other]) evaluates
        # result[other] up front and raises KeyError when it is absent, even
        # if the preferred key exists.
        acc = result.get('acc_norm,none')
        if acc is None:
            acc = result.get('acc,none')
        if acc is None:
            logger.warning(f"task {task} reports no accuracy metric; excluded from the average")
            continue
        metric_vals[task] = round(acc, 4)

        stderr = result.get('acc_norm_stderr,none')
        if stderr is None:
            stderr = result.get('acc_stderr,none')
        # A missing or non-numeric stderr (e.g. the string "N/A") cannot be
        # rounded or averaged, so it is skipped.
        if stderr is not None and not isinstance(stderr, str):
            std_vals[task] = round(stderr, 4)

    if metric_vals:
        average = {"acc,none": round(sum(metric_vals.values()) / len(metric_vals), 4)}
        if std_vals:
            average["acc_stderr,none"] = round(sum(std_vals.values()) / len(std_vals), 4)
        results['results']['AVERAGE'] = average

    logger.info("\n" + make_table(results))
def dynamic_eval(
    lm: ost_model_utils.LM,
    dataloader,
    args,
    use_act_quant=True,
    use_fully_quant=False,
    use_weight_quant=False,
):
    """Evaluate ``lm`` on ``dataloader`` with dynamic quantization enabled.

    Applies the requested quantization switches, puts the model wrapper in
    dynamic mode, and returns the result of ``eval_utils.evaluator`` run on
    ``utils.DEV`` (callers log it as perplexity).
    """
    quant_flags = {
        "use_weight_quant": use_weight_quant,
        "use_act_quant": use_act_quant,
        "use_fully_quant": use_fully_quant,
    }
    lm.set_quant_state(**quant_flags)
    lm.set_dynamic(True)
    return eval_utils.evaluator(lm.model, dataloader, utils.DEV, args)
def static_eval(lm: ost_model_utils.LM, dataloader, args):
    """Evaluate ``lm`` on ``dataloader`` with dynamic mode switched off.

    If the wrapper is not yet calibrated, a short pass over the data
    (``eval_samples=2``) is run first with observers enabled, after which
    observers are disabled again.  Returns the result of the final
    ``eval_utils.evaluator`` call (callers log it as perplexity).
    """
    # Positional flags are passed through unchanged; the parameter order of
    # LM.set_quant_state is not visible here.
    lm.set_quant_state(False, True, args.fully_quant)
    lm.set_dynamic(False)

    needs_calibration = not lm.calibrated
    if needs_calibration:
        lm.set_observer(True)
        # NOTE(review): this function uses args.dev while dynamic_eval uses
        # utils.DEV - confirm the two always agree.
        eval_utils.evaluator(lm.model, dataloader, args.dev, args, eval_samples=2)
        lm.set_observer(False)

    return eval_utils.evaluator(lm.model, dataloader, args.dev, args)
# Script entry point: obtain the run configuration from get_args.parse_args()
# and hand it to main().  Nothing runs when the module is merely imported.
if __name__ == "__main__":
    args = get_args.parse_args()
    main(args)