-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathafs.py
More file actions
executable file
·139 lines (131 loc) · 6.07 KB
/
afs.py
File metadata and controls
executable file
·139 lines (131 loc) · 6.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import argparse
import utils.Audit as Audit
import utils.Purification as Forget
import sys
def parser(argv=None):
    """Build the AFS command-line interface and parse the arguments.

    Two sub-commands are registered:

    * ``audit``  -- options for the audit run; sets ``command_class`` to 0.
    * ``forget`` -- options for the forget run (teacher model, KD data,
      risk loss, training hyper-parameters); sets ``command_class`` to 1.

    :param argv: list of argument strings to parse; ``None`` (the default)
        parses ``sys.argv[1:]``.
    :return: the parsed ``argparse.Namespace``.
    :raises SystemExit: (via argparse) if the arguments are invalid or no
        sub-command was given.
    """
    arg_parser = argparse.ArgumentParser(prog='AFS')
    subparsers = arg_parser.add_subparsers(help='sub-command help')
    _add_audit_subcommand(subparsers)
    _add_forget_subcommand(subparsers)
    args = arg_parser.parse_args(argv)
    # Sub-commands are optional by default in Python 3, so an invocation
    # without one yields an empty Namespace and the dispatcher would later
    # fail on the missing ``command_class``. Report a usage error instead.
    if not hasattr(args, 'command_class'):
        arg_parser.error('a sub-command is required: audit or forget')
    return args


def _add_audit_subcommand(subparsers):
    """Register the ``audit`` sub-command and its options on *subparsers*."""
    parser_audit = subparsers.add_parser('audit')
    parser_audit.add_argument('--root',
                              default='../template/MNIST',
                              help='root dir to the project')
    parser_audit.add_argument('--query_label',
                              default='EXP1',
                              help='label of the query data, defined in Dataset.py/Config')
    parser_audit.add_argument('--cal_label',
                              default='CAL1',
                              help='label of the calibration data, defined in Dataset.py/Config')
    parser_audit.add_argument('--cal_test_label',
                              default='CALTEST1',
                              help='label of the calibration test data, defined in Dataset.py/Config')
    parser_audit.add_argument('--test_label',
                              default='TEST1',
                              help='label of the test data, defined in Dataset.py/Config')
    parser_audit.add_argument('--model2audit',
                              default='./models/base/best_model.pth',
                              help='relative path of model to be auditted to the root')
    parser_audit.add_argument('--model2cal',
                              default='./models/cal/best_model.pth',
                              help='relative path of the calibration model to the root')
    parser_audit.add_argument('--device',
                              default='cuda:0')
    parser_audit.add_argument('--KP_infer_batch_size',
                              type=int,
                              default=1024,
                              help='batch size for inference during membership attack')
    parser_audit.add_argument('--nclass',
                              type=int,
                              default=10,
                              help='number of classes')
    parser_audit.add_argument('--num_workers',
                              type=int,
                              default=5,
                              help='number of num_workers')
    # Selects the dispatch branch in main(); 0 means "audit".
    parser_audit.add_argument('--command_class',
                              default=0,
                              type=int,
                              help='for internal use only, no change')


def _add_forget_subcommand(subparsers):
    """Register the ``forget`` sub-command and its options on *subparsers*."""
    parser_forget = subparsers.add_parser('forget')
    parser_forget.add_argument('--root',
                               default='../template/MNIST',
                               help='root dir to the project')
    parser_forget.add_argument('--expname',
                               default='EXP1',
                               help='name of exp, will affect the path and dataset splitting')
    parser_forget.add_argument('--teacher_model',
                               default='./models/EXP1/base/best_model.pth',
                               help='relative path of model to be distilled to the root')
    parser_forget.add_argument('--KD_label',
                               default='KD0.25',
                               help='the name of base dataset used for KD, should be defined in CONFIG')
    parser_forget.add_argument('--test_label',
                               default='TEST1',
                               help='label of the test data, defined in Dataset.py/Config')
    parser_forget.add_argument('--cal_label',
                               default='CAL1',
                               help='label of the calibration data, defined in Dataset.py/Config')
    parser_forget.add_argument('--cal_test_label',
                               default='CALTEST1',
                               help='label of the calibration test data, defined in Dataset.py/Config')
    parser_forget.add_argument('--query_label',
                               default='QO1',
                               help='label of the query data, defined in Dataset.py/Config, here the query dataset should overlap with training dataset')
    parser_forget.add_argument('--add_risk_loss',
                               type=int,
                               default=1,
                               help='1: will add risk loss when running KP, 0: same as pure KD')
    parser_forget.add_argument('--nclass',
                               type=int,
                               default=10,
                               help='number of classes')
    parser_forget.add_argument('--train_batch_size',
                               type=int,
                               default=32)
    parser_forget.add_argument('--KP_infer_batch_size',
                               type=int,
                               default=128,
                               help='batch size for inference during membership attack')
    parser_forget.add_argument('--device',
                               default='cuda:0')
    parser_forget.add_argument('--epochs',
                               type=int,
                               default=20,
                               help='number of epochs')
    parser_forget.add_argument('--T',
                               type=float,
                               default=4.0,
                               help='temperature for ST')
    parser_forget.add_argument('--lr',
                               type=float,
                               default=0.1,
                               help='initial learning rate')
    # NOTE(review): argparse does not apply ``type`` to non-string defaults,
    # so when these two options are omitted the values stay ints (1 and 10).
    # Kept as-is in case downstream code formats them into names/paths.
    parser_forget.add_argument('--lambda_kd',
                               type=float,
                               default=1,
                               help='trade-off parameter for kd loss')
    parser_forget.add_argument('--lambda_risk',
                               type=float,
                               default=10,
                               help='trade-off parameter for risk loss')
    parser_forget.add_argument('--num_workers',
                               type=int,
                               default=5,
                               help='number of num_workers')
    # Selects the dispatch branch in main(); 1 means "forget".
    parser_forget.add_argument('--command_class',
                               default=1,
                               type=int,
                               help='for internal use only, no change')
def main(args):
    """Dispatch the parsed CLI arguments to the selected sub-command.

    :param args: namespace produced by :func:`parser`. Its ``command_class``
        attribute selects the action: 0 calls ``Audit.one_command_api(args)``,
        1 calls ``Forget.one_command_api(args)``.
    :raises ValueError: if ``command_class`` is missing or is neither 0 nor 1,
        so a mis-parsed command line fails loudly instead of doing nothing.
    """
    # getattr: a Namespace parsed without a sub-command has no command_class.
    command_class = getattr(args, 'command_class', None)
    if command_class == 0:
        Audit.one_command_api(args)
    elif command_class == 1:
        Forget.one_command_api(args)
    else:
        raise ValueError(
            'unknown command_class %r: expected 0 (audit) or 1 (forget); '
            'was a sub-command given on the command line?' % (command_class,))
if __name__ == '__main__':
    # Script entry point: parse the command line, echo the parsed options
    # for the run log, then hand them to main() for dispatch.
    cli_args = parser()
    print(cli_args)
    main(cli_args)