-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
97 lines (73 loc) · 3.2 KB
/
main.py
File metadata and controls
97 lines (73 loc) · 3.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import argparse
import os
from os import path
from network.inception_resnet_v1 import InceptionResnetV1
from network.fc_layers import Identity
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from network.TorchUtils import TorchModel
from utils.callbacks import DefaultModelCallback, TensorBoardCallback
from utils.utils import register_logger, get_torch_device
from dataloader import get_dataloader
def get_args():
    """Parse and return command-line arguments for the face-parser trainer.

    Returns:
        argparse.Namespace with I/O paths (``inputs_path``, ``log_file``,
        ``exps_dir``, ``checkpoint``) and optimization settings
        (``save_every``, ``lr_base``, ``epochs``).
    """
    parser = argparse.ArgumentParser(description="PyTorch CIV6 Face Parser")

    # io
    parser.add_argument('--inputs_path', type=str, default='features',
                        help="path to inputs")
    parser.add_argument('--log_file', type=str, default="log.log",
                        help="set logging file.")
    parser.add_argument('--exps_dir', type=str, default="exps",
                        help="path to the directory where models and tensorboard would be saved.")
    parser.add_argument('--checkpoint', type=str,
                        help="load a model for resume training")

    # optimization
    parser.add_argument('--save_every', type=int, default=1,
                        help="epochs interval for saving the model checkpoints")
    parser.add_argument('--lr_base', type=float, default=0.001,
                        help="learning rate")
    parser.add_argument('--epochs', type=int, default=10,
                        help="number of training epochs")

    return parser.parse_args()
if __name__ == "__main__":
    args = get_args()
    register_logger()

    # Prepare the experiment output tree: <exps_dir>/models for checkpoints,
    # <exps_dir>/tensorboard for event files.
    os.makedirs(args.exps_dir, exist_ok=True)
    models_dir = path.join(args.exps_dir, 'models')
    tb_dir = path.join(args.exps_dir, 'tensorboard')
    os.makedirs(models_dir, exist_ok=True)
    os.makedirs(tb_dir, exist_ok=True)

    device = get_torch_device()
    train_loader, test_loader = get_dataloader()

    if args.checkpoint is not None and path.exists(args.checkpoint):
        # Resume: load a previously saved TorchModel wrapper.
        model = TorchModel.load_model(args.checkpoint)
    else:
        # Fresh start from a pretrained backbone. map_location='cpu' lets the
        # checkpoint load even on a machine without the GPU it was saved on;
        # the model is moved to `device` below anyway.
        model = InceptionResnetV1()
        param_dict = torch.load("pretrained/pretrained.pth", map_location='cpu')
        model.load_state_dict(param_dict)

        # Freeze the backbone and replace the final linear layer with a
        # pass-through, so the network acts as a fixed feature extractor.
        for param in model.parameters():
            param.requires_grad = False
        model.last_linear = Identity()

        model = TorchModel(model)

    # Training-progress callbacks: visualization dumps + tensorboard logging.
    tb_writer = SummaryWriter(log_dir=tb_dir)
    model.register_callback(DefaultModelCallback(visualization_dir=args.exps_dir))
    model.register_callback(TensorBoardCallback(tb_writer=tb_writer))

    model = model.to(device).train()

    # NOTE(review): the backbone is frozen above, so the trainable parameter
    # set depends on what TorchModel/Identity add — confirm Adam receives a
    # non-empty parameter list when training from scratch.
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr_base)
    criterion = torch.nn.CrossEntropyLoss(reduction='mean')

    model.fit(train_iter=train_loader,
              eval_iter=test_loader,
              criterion=criterion,
              optimizer=optimizer,
              epochs=args.epochs,
              network_model_path_base=models_dir,
              save_every=args.save_every,
              evaluate_every=True)