forked from abhijitmishra/Thought2Text
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_eeg_classifier.py
More file actions
83 lines (69 loc) · 2.29 KB
/
test_eeg_classifier.py
File metadata and controls
83 lines (69 loc) · 2.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import random
import torch
import numpy as np
from tqdm import tqdm
from datautils import EEGDataset, Splitter
from channelnet.model import ChannelNetModel
from channelnet.config import EEGModelConfig
from args import get_args_for_encoder_training
from torch.utils.data import DataLoader, Dataset
import evaluate
def set_seed(seed):
    """Seed every random number generator used by this script.

    Seeds Python's ``random`` module, NumPy's global RNG, and PyTorch's CPU
    and CUDA generators (current device and all devices) with the same value.
    It then forces cuDNN onto deterministic kernels and turns off its
    autotuner, so repeated runs produce the same results.

    Args:
        seed: Integer seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,       # current CUDA device
        torch.cuda.manual_seed_all,   # every CUDA device (multi-GPU)
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    # Autotuning may pick different kernels run to run; disable it for reproducibility.
    cudnn.benchmark = False
    cudnn.deterministic = True
def set_gradients(module, requires_grad):
    """Freeze or unfreeze every parameter of ``module``.

    Args:
        module: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        requires_grad: Value assigned to each parameter's ``requires_grad``.
    """
    for parameter in module.parameters():
        setattr(parameter, "requires_grad", requires_grad)
def main():
    """Evaluate a pretrained ChannelNet EEG classifier on the test split.

    Arguments come from ``get_args_for_encoder_training``. The model weights
    are loaded from ``args.output`` and run on ``args.device``. Top-1 accuracy
    over the whole test split is printed as ``{"acc": <float>}``.
    """
    args = get_args_for_encoder_training()
    set_seed(42)

    dataset = EEGDataset(args=args)
    # Only the test split is evaluated, so only its loader is built.
    # Every test sample must be scored exactly once:
    # - drop_last=False keeps the final partial batch. With drop_last=True it
    #   was silently left out of the reported accuracy.
    # - Shuffling has no effect on a metric over the full split.
    # NOTE(review): assumes ChannelNetModel accepts a last batch smaller than
    # args.batch_size. Confirm if the architecture has a fixed batch size.
    test_loader = DataLoader(
        Splitter(
            dataset,
            split_path=args.splits_path,
            split_num=args.split_num,
            split_name="test",
        ),
        batch_size=args.batch_size,
        drop_last=False,
        shuffle=False,
    )

    config = EEGModelConfig()
    model = ChannelNetModel.from_pretrained(
        pretrained_model_name_or_path=args.output, config=config
    )
    model.to(args.device)
    model.eval()

    metric = evaluate.load("accuracy")
    all_labels = []
    all_preds = []
    # Inference only: disable autograd so no computation graph is retained.
    with torch.no_grad():
        for batch in tqdm(test_loader):
            # The raw image is not needed for classification; it stays on the CPU.
            _image_raw, eeg_data, labels = batch
            eeg_data = eeg_data.to(args.device)
            _emb_output, cls_output = model(eeg_data)
            # Softmax is monotonic, so argmax over the logits picks the same class.
            preds = cls_output.argmax(dim=1)
            all_labels.extend(labels.tolist())
            all_preds.extend(preds.tolist())

    test_metric = metric.compute(predictions=all_preds, references=all_labels)
    print({"acc": test_metric["accuracy"]})
# Run the evaluation only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()