-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
91 lines (68 loc) · 3.15 KB
/
train.py
File metadata and controls
91 lines (68 loc) · 3.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import numpy as np
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
import tqdm
import os
import argparse
from utils.evaluate import get_acc
from utils.loader import DatasetLoader
from utils.modelutils import load_empty_model, save_target_model, model_manager
from lira.prepare import inference_and_score
def train_model(args, key_dict=None):
    """Train one (target or shadow) model on its keep-subset, save it, and
    run LiRA inference/scoring over the full training set.

    Args:
        args: dict of config values; converted to an argparse.Namespace below.
            Keys used here: seed, legend, root, lr, epochs, batch_size,
            pkeep, n_shadows, shadow_id, n_queries.
        key_dict: optional metadata forwarded unchanged to save_target_model.

    Returns:
        Test-set accuracy (as returned by get_acc) of the trained model.
    """
    args = argparse.Namespace(**args)
    train_loader, test_loader, keep_bool = get_loaders(args)
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    # TODO: more model architectures such as resnet50 and densenet
    model = load_empty_model(args.legend)
    model.to(DEVICE)
    optim = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=args.epochs)
    for epoch in range(1, args.epochs + 1):
        model.train()
        loss_total = 0.0
        pbar = tqdm.tqdm(train_loader,
                         total=len(train_loader),
                         desc=f"Epoch {epoch}/{args.epochs}",
                         position=1,
                         leave=False)
        for itr, (x, y) in enumerate(pbar):
            x, y = x.to(DEVICE), y.to(DEVICE)
            loss = F.cross_entropy(model(x), y)
            optim.zero_grad()
            loss.backward()
            optim.step()
            # BUGFIX: accumulate the Python float, not the tensor. Summing the
            # loss tensor itself keeps every iteration's autograd graph alive
            # for the whole epoch, steadily leaking (GPU) memory.
            loss_total += loss.item()
            avg_loss = loss_total / (itr + 1)
            pbar.set_postfix({"loss": f"{avg_loss:.4f}",})
        sched.step()
    model.eval()
    tqdm.tqdm.write(f"Accuracy: {get_acc(model, test_loader)}%")
    save_target_model(model, args, keep_bool, key_dict=key_dict)
    # Reload the FULL (un-subsetted) training set: LiRA scores every sample,
    # member or not, against the trained model.
    train_loader = DatasetLoader(args.seed, args.legend, args.root).load_with_keepfile(True)
    if args.n_shadows is not None:
        # Shadow model: scores go under experiments/shadows-<legend>/<id>/<id>.pt
        save = os.path.join(os.path.dirname(__file__), f"./experiments/shadows-{args.legend}", str(args.shadow_id), f'{args.shadow_id}.pt')
        inference_and_score(args.legend, args.root, save, args.n_queries, train_loader, model)
    else:
        # Target model: the save path comes from the project's model_manager.
        inference_and_score(args.legend, args.root, model_manager.get_target()[0], args.n_queries, train_loader, model)
    return get_acc(model, test_loader)
def get_loaders(args):
    """Construct the subsetted train loader, the test loader, and the boolean
    membership mask recording which training samples were kept.

    Args:
        args: Namespace with seed, legend, root, n_shadows, shadow_id,
            pkeep, and batch_size.

    Returns:
        (train_loader, test_loader, keep_bool) where keep_bool[i] is True iff
        training sample i is in this model's training subset.
    """
    full_train = DatasetLoader(args.seed, args.legend, args.root)._get_dataset(True)
    n_samples = len(full_train)
    np.random.seed(args.seed.get_seed())
    if args.n_shadows is None:
        # Single target model: keep a uniformly sampled pkeep-fraction.
        selected = np.random.choice(n_samples, size=int(args.pkeep * n_samples), replace=False)
        selected.sort()
    else:
        # Shadow models: rank one uniform score matrix down each column so every
        # sample appears in roughly pkeep * n_shadows shadow training sets.
        # NOTE(review): this requires all shadow runs to seed np.random
        # identically above so each shadow_id slices the SAME keep matrix —
        # presumably args.seed.get_seed() guarantees that; confirm.
        scores = np.random.uniform(0, 1, size=(args.n_shadows, n_samples))
        in_mask = scores.argsort(0) < int(args.pkeep * args.n_shadows)
        selected = np.array(in_mask[args.shadow_id], dtype=bool).nonzero()[0]
    keep_bool = np.full((n_samples), False)
    keep_bool[selected] = True
    train_loader = DataLoader(torch.utils.data.Subset(full_train, selected),
                              batch_size=args.batch_size, num_workers=4, shuffle=True)
    test_loader = DatasetLoader(args.seed, args.legend, args.root).load_with_keepfile(False)
    return train_loader, test_loader, keep_bool