-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathTrainingAnomalyDetector_public.py
More file actions
99 lines (83 loc) · 3.76 KB
/
TrainingAnomalyDetector_public.py
File metadata and controls
99 lines (83 loc) · 3.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import argparse
import os
from os import path
import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
from features_loader import FeaturesLoader
from network.TorchUtils import TorchModel
from network.anomaly_detector_model import AnomalyDetector, objective, RegularizedLoss
from utils.callbacks import DefaultModelCallback, TensorBoardCallback
from utils.utils import register_logger, get_torch_device
def get_args(argv=None):
    """Build and parse the command-line options for anomaly-detector training.

    Args:
        argv: Optional list of argument strings to parse. Defaults to ``None``,
            which makes ``argparse`` fall back to ``sys.argv[1:]`` — identical
            to the previous behavior. Passing an explicit list makes the
            parser usable (and testable) without touching ``sys.argv``.

    Returns:
        argparse.Namespace with all options below.
    """
    parser = argparse.ArgumentParser(description="PyTorch Video Classification Parser")

    # io
    parser.add_argument('--features_path', default='features',
                        help="path to features")
    parser.add_argument('--annotation_path', default="Train_Annotation.txt",
                        help="path to train annotation")
    parser.add_argument('--log_file', type=str, default="log.log",
                        help="set logging file.")
    parser.add_argument('--exps_dir', type=str, default="exps",
                        help="path to the directory where models and tensorboard would be saved.")
    parser.add_argument('--checkpoint', type=str,
                        help="load a model for resume training")

    # optimization
    parser.add_argument('--batch_size', type=int, default=60,
                        help="batch size")
    parser.add_argument('--feature_dim', type=int, default=2048,
                        help="feature dimension")
    parser.add_argument('--save_every', type=int, default=1,
                        help="epochs interval for saving the model checkpoints")
    parser.add_argument('--lr_base', type=float, default=0.01,
                        help="learning rate")
    parser.add_argument('--iterations_per_epoch', type=int, default=20000,
                        help="number of training iterations")
    parser.add_argument('--epochs', type=int, default=2,
                        help="number of training epochs")
    parser.add_argument('--k', type=int, default=1,
                        help="set top-k Ranking Loss parameters")

    return parser.parse_args(argv)
if __name__ == "__main__":
    args = get_args()

    # Register logger and create the experiment directory tree
    # (<exps_dir>/models for checkpoints, <exps_dir>/tensorboard for logs).
    register_logger(log_file=args.log_file)
    os.makedirs(args.exps_dir, exist_ok=True)
    models_dir = path.join(args.exps_dir, 'models')
    tb_dir = path.join(args.exps_dir, 'tensorboard')
    os.makedirs(models_dir, exist_ok=True)
    os.makedirs(tb_dir, exist_ok=True)

    # Optimizations
    device = get_torch_device()
    cudnn.benchmark = True  # enable cudnn tune

    # Data loader: serves pre-extracted features for
    # `iterations_per_epoch` iterations per epoch.
    train_loader = FeaturesLoader(features_path=args.features_path,
                                  feature_dim=args.feature_dim,
                                  annotation_path=args.annotation_path,
                                  iterations=args.iterations_per_epoch)

    # Model: resume from a checkpoint when one is given, otherwise start fresh.
    if args.checkpoint is not None and path.exists(args.checkpoint):
        model = TorchModel.load_model(args.checkpoint)
        # BUG FIX: `network` was previously assigned only in the else-branch,
        # so resuming from a checkpoint crashed with NameError at the
        # RegularizedLoss(...) call below. Reuse the loaded wrapper's
        # underlying module for the regularizer.
        # NOTE(review): assumes TorchModel stores the wrapped net as `.model`;
        # falls back to the wrapper itself (it holds the same parameters) —
        # TODO confirm against network/TorchUtils.py.
        network = getattr(model, "model", model)
    else:
        network = AnomalyDetector(args.feature_dim)
        model = TorchModel(network)
    model = model.to(device).train()

    # Training parameters.
    # In the original paper: lr = 0.01, epsilon = 1e-8.
    optimizer = torch.optim.Adadelta(model.parameters(), lr=args.lr_base, eps=1e-8)
    criterion = RegularizedLoss(network, objective, args.k).to(device)

    # Callbacks: default console/visualization logging plus tensorboard.
    tb_writer = SummaryWriter(log_dir=tb_dir)
    model.register_callback(DefaultModelCallback(visualization_dir=args.exps_dir))
    model.register_callback(TensorBoardCallback(tb_writer=tb_writer))

    # Training: checkpoints are written under models_dir every `save_every` epochs.
    model.fit(train_iter=train_loader,
              criterion=criterion,
              optimizer=optimizer,
              epochs=args.epochs,
              network_model_path_base=models_dir,
              save_every=args.save_every)