-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
84 lines (71 loc) · 3.08 KB
/
test.py
File metadata and controls
84 lines (71 loc) · 3.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os
import torch
from torch.utils.data import DataLoader
from model.dataloader import MVA_Dataset, Normalize0, ShortSideScale
from model.dataloader import build_config
import argparse
from torchvision.transforms import (
CenterCrop,
Compose,
)
from model.MVA import TBCNet, TBCNet_test
def evaluate(label, pred):
    """Return classification accuracy as the fraction of matching positions.

    Args:
        label: Sequence of ground-truth class ids (plain ints or 0-d tensors).
        pred: Sequence of predicted class ids, aligned index-by-index with
            ``label``.

    Returns:
        float: number of positions where ``label[i] == pred[i]`` divided by
        ``len(label)``.

    Raises:
        ValueError: if ``label`` is empty or ``pred`` has a different length.
    """
    num_all = len(label)
    if num_all == 0:
        raise ValueError("cannot compute accuracy: no labels were provided")
    if len(pred) != num_all:
        # zip() would silently truncate the longer sequence and skew the ratio.
        raise ValueError(
            f"label/pred length mismatch: {num_all} labels vs {len(pred)} predictions"
        )
    # bool() collapses an element comparison (including a 0-d tensor result)
    # into a Python bool, so the sum and the ratio are plain numbers, not tensors.
    num_right = sum(bool(a == b) for a, b in zip(label, pred))
    return num_right / num_all
def valid(model, device, loader):
    """Run ``model`` over every batch of ``loader`` and report accuracy.

    Args:
        model: Classifier whose forward output has the class axis at dim 1
            (the prediction is ``torch.argmax(output, dim=1)``).
        device: Target passed to ``Tensor.to`` for the model inputs.
        loader: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        float: accuracy over the whole loader, as computed by ``evaluate``.
        It is also printed as a percentage.
    """
    model.eval()
    predictions = []
    targets = []
    with torch.no_grad():
        for clips, y in loader:
            # The first two axes are swapped before the forward pass. Presumably
            # the loader is batch-first and the model expects the other axis
            # first — TODO confirm against MVA_Dataset / TBCNet_test.
            clips = clips.transpose(0, 1).to(device)
            scores = model(clips)
            y_pred = torch.argmax(scores, dim=1)
            # Convert each batch to host ints once, rather than storing one
            # 0-d device tensor per sample and comparing them one by one later.
            predictions.extend(y_pred.cpu().tolist())
            # Labels are shifted down by one, i.e. the dataset appears to
            # number classes from 1.
            targets.extend((y - 1).tolist())
    print('---------Test---------')
    print('Results:')
    acc = evaluate(targets, predictions)
    print(f'ACC = {acc*100}%')
    return acc
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='test')
    # Required: the script cannot run without weights; previously a missing
    # value only failed later inside torch.load(None).
    parser.add_argument('--checkpoint', type=str, required=True,
                        help='Path to the pre-trained model.')
    parser.add_argument('--testtype', dest='testtype', default='CS',
                        choices=['CS', 'CV', 'CSet'], required=False,
                        help='Set test type from CS, CV or CSet')
    parser.add_argument('--dataset', type=str, default='ntu-60', required=False,
                        help='Dataset to use.',
                        choices=['ntu-120', 'ntu-60', 'pkummd', 'n-ucla'])
    parser.add_argument('--device', dest='device',
                        default='cuda:0' if torch.cuda.is_available() else 'cpu',
                        required=False,
                        help='Torch device to run on, e.g. "cuda:0" or "cpu"')
    parser.add_argument('--batch_size', default=16, type=int)
    # type=int: without it a CLI-supplied value arrives as a str and
    # DataLoader rejects it.
    parser.add_argument('--num_workers', default=8, type=int)
    args = parser.parse_args()

    # Normalize to a torch.device whether the value came from the default or the CLI.
    device = torch.device(args.device)
    high_feature_dim, low_feature_dim = 128, 256
    # The CV protocol builds the model with 2 views; CS and CSet use 3.
    view_num = 2 if args.testtype == 'CV' else 3

    transform_test = Compose([
        Normalize0([0.45, 0.45, 0.45], [0.225, 0.225, 0.225]),
        ShortSideScale(size=256),
        CenterCrop(224),
    ])
    cfg = build_config(args.dataset, args.testtype)
    data_gen = MVA_Dataset(cfg, 'test', transform=transform_test)
    # Evaluation order does not affect accuracy, so batches are not shuffled.
    # Pinned host memory only helps host-to-GPU copies.
    data_loader = DataLoader(data_gen, batch_size=args.batch_size, shuffle=False,
                             num_workers=args.num_workers,
                             pin_memory=device.type == 'cuda')

    print("Dataset:{}".format(args.dataset))
    print("Loading models...")
    model = TBCNet_test(view_num, low_feature_dim, high_feature_dim,
                        cfg.num_actions, device, "test")
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted
    # checkpoints (consider weights_only=True on recent torch versions).
    checkpoint0 = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(checkpoint0)
    model = model.to(device)
    # valid() takes (model, device, loader); the old extra view_num argument
    # raised TypeError.
    valid(model, device, data_loader)