Skip to content

Commit 898f0d5

Browse files
authored
Add support for SEED-IV dataset (#28)
1 parent 098c23a commit 898f0d5

File tree

5 files changed

+347
-0
lines changed

5 files changed

+347
-0
lines changed
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Finetuning configuration: MMM encoder on the SEED-IV emotion dataset.
data:
  type: SEED_IV
  window_size: 1
  # the subject index of the dataset
  subject_index: 1
  # the upper directory of the dataset
  prefix: ./data

network:
  type: MMM_Encoder
  depth: 6
  num_heads: 8
  encoder_dim: 64
  # 79 = 62 electrodes + 17 functional-area supernodes — presumably matches
  # the channel layout built in SEED_IV.get_coordination; TODO confirm.
  channel_num: 79
  in_chans: 5
  pe_type: 2d

decoder_network: # used only during pre-training. Can be omitted if only finetuning.
  type: MMM_Encoder
  depth: 6
  encoder_dim: 64
  channel_num: 79
  in_chans: 16

model:
  type: MMM_Finetune
  task: multiclassification
  # set up pre-trained model path, leave blank for training from scratch
  # E.g.
  # uncomment the following line to use the pre-trained model
  # model_path: /path/to/tuh_pretrained_encoder_base.pt
  optimizer: Adam
  lr: 0.00005
  weight_decay: 0.005
  loss_fn: cross_entropy
  metrics: [accuracy]
  observe: accuracy
  lower_is_better: False
  max_epochs: 100
  early_stop: 70
  batch_size: 32
  # SEED-IV is a 4-class (emotion) task; matches SEED_IV.out_size.
  out_size: 4
  mask_ratio: 0.

runtime:
  seed: 51
  use_cuda: true
  output_dir: outputs/MMM_SEED_IV/1/

physiopro/dataset/SEED_IV.py

Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
from typing import Optional
2+
3+
import torch
4+
from torch.utils.data import Dataset
5+
import numpy as np
6+
import scipy.io as sio
7+
from scipy.stats import rankdata
8+
9+
from .base import DATASETS
10+
11+
12+
@DATASETS.register_module()
class SEED_IV(Dataset):
    """SEED-IV 4-class EEG emotion dataset built from pre-extracted
    differential-entropy (DE) features.

    Expects the layout produced by ``scripts/extract.py``:
    ``{prefix}/SEED-IV/{name}/DE_{k}.mat`` (one file per subject/session,
    key ``DE_feature``) and ``DE_{e}_labels.mat`` (per-session labels,
    key ``de_labels``). Loading pipeline: normalize (statistics computed
    on the training portion only), split train/valid along time, then
    optionally stack a sliding time window, reorder the 62 electrodes
    into functional-area order, and precompute channel coordinates plus
    an attention mask with 17 area supernodes and one global node.
    """

    def __init__(
        self,
        prefix: str = "./data/",          # upper directory containing SEED-IV/
        name: str = "DE",                 # feature sub-folder name
        window_size: int = 1,             # temporal window length when addtime=True
        addtime: bool = False,            # stack a sliding time window per sample
        subject_index: Optional[int] = -1,  # -1 loads all 45 files (15 subjects x 3 sessions)
        dataset_name: Optional[str] = 'train',  # 'train' or 'valid'
        channel: int = 62,                # NOTE(review): stored but never used below — dead parameter?
        local: bool = False,              # restrict attention to within-area links only
        normalize: str = 'gaussian',      # 'gaussian' or 'minmax'
    ):
        super().__init__()
        self.window_size = window_size
        self.addtime = addtime
        self.dataset_name = dataset_name
        self.channel = channel
        self.local = local
        # Four emotion classes in SEED-IV.
        self.out_size = 4
        file = prefix + "/SEED-IV/" + name + "/"
        data_file_path = file + "DE_{}.mat"
        # One label vector per session (sessions share labels across subjects).
        de_label={}
        for i in range(3):
            de_label[i] = np.array(sio.loadmat(file + f"DE_{i+1}_labels.mat")['de_labels']).squeeze(0)
        # 45 files = 15 subjects x 3 sessions, saved as DE_{sub*3 + exp + 1}.mat,
        # hence file index % 3 recovers the session index.
        self.candidate_list = (
            [subject_index] if subject_index != -1 else list(range(45))
        )
        self.label = [de_label[i%3] for i in self.candidate_list]
        # loadmat gives (C, T, F); transpose to (T, C, F) per candidate.
        self.data =[
            np.array(sio.loadmat(data_file_path.format(i + 1))["DE_feature"]).transpose(1,0,2)
            for i in self.candidate_list
        ]
        self._normalize(normalize)
        self._split(self.dataset_name)

        if addtime:
            self._addtimewindow(window_size)
            # N,T,C,F -> N,C,T,F
            self.data = self.data.transpose(0,2,1,3)
        else:
            self.data = np.concatenate(self.data, axis=0)
            self.label = np.concatenate(self.label, axis=0)
        # Permutation of the 62 electrodes — appears to group channels of the
        # same functional area contiguously so that func_areas in
        # get_coordination indexes contiguous runs; TODO confirm against montage.
        idx = [0,1,2,3,4,5,6,7,8,9,
               10,17,18,19,11,12,13,
               14,15,16,20,21,22,23,
               24,25,26,27,28,29,30,
               31,32,33,34,35,36,37,
               44,45,46,38,39,40,41,
               42,43,47,48,49,50,51,
               57,52,53,54,58,59,60,
               55,56,61]
        idx=torch.tensor(idx)
        self.data = torch.tensor(self.data)
        # dim=1 is the channel axis in both layouts ((N,C,F) and (N,C,T,F)).
        self.data = torch.index_select(self.data,dim=1,index=idx)
        self.data = self.data.numpy()

        self.get_coordination()

    def _normalize(self,method='minmax'):
        """Normalize each frequency band in place, per candidate.

        Statistics are computed only on the training prefix of each session
        (no leakage from the validation tail).
        """
        # Number of training samples per session (sessions differ in length).
        train_size = [610, 558, 567]
        # min-max normalization
        if method == 'minmax':
            for i, candidate in enumerate(self.candidate_list):
                for j in range(5):  # 5 frequency bands
                    # 0~train_size[candidate%3] is the training set, train_size[candidate%3]:: is the valid set
                    minn = np.min(self.data[i][ :train_size[candidate%3], :, j])
                    maxx = np.max(self.data[i][ :train_size[candidate%3], :, j])
                    self.data[i][:,:,j] = (self.data[i][:,:,j] - minn) / (maxx-minn)

        # gaussian standardization
        if method == 'gaussian':
            for i, candidate in enumerate(self.candidate_list):
                for j in range(5):  # 5 frequency bands
                    # 0~train_size[candidate%3] is the training set, train_size[candidate%3]:: is the valid set
                    mean = np.mean(self.data[i][ :train_size[candidate%3], :, j])
                    std = np.std(self.data[i][ :train_size[candidate%3], :, j])
                    self.data[i][:, :, j] = (self.data[i][:, :, j] - mean) / std

    def _addtimewindow(self, window):
        """Turn each (N, C, F) sequence into (N, window, C, F) sliding windows.

        The window starting at sample j covers samples [j, j+window); samples
        whose window crosses a trial (label) boundary are zero-padded past the
        boundary so every window carries a single label.
        """
        S = len(self.data)
        data_results = []
        label_results = []
        for i in range(S):
            # padding from the last sample, to make sure the sample number is the same after addtimewindow operation
            data = self.data[i]
            label = self.label[i]
            print(data.shape)   # NOTE(review): debug prints left in
            print(label.shape)
            N, C, F = data.shape
            data = np.concatenate([data, data[-(window):, :, :]], 0)
            label = np.concatenate([label, label[-(window):]], 0)
            data_res = np.zeros(shape=(N, window, C, F))
            label_res = np.zeros(shape=N)
            for j in range(N):
                # met the corner case
                # NOTE(review): when label[j] != label[j+window-1] neither branch
                # assigns `nearest`, so it keeps its value from a previous
                # iteration (NameError if it never ran). Also nearest=j+window
                # makes the padded branch identical to the else branch
                # (zero-length padding). Verify the intended boundary handling.
                if (
                    label[j] == label[j + window - 1]
                    and label[j] != label[j + window]
                ):
                    nearest = j + window
                #
                elif label[j] == label[j + window - 1]:
                    nearest = -1
                if nearest != -1:
                    data_res[j, :, :, :] = np.concatenate(
                        [
                            data[j:nearest, :, :],
                            np.zeros(shape=(window - nearest + j, C, F)),
                        ],
                        0,
                    )
                else:
                    data_res[j, :, :, :] = data[j : j + window, :, :]
                label_res[j] = label[j]
            data_results.append(data_res)
            label_results.extend(label_res)

        self.data = np.concatenate(data_results, 0)
        self.label = np.array(label_results)

    def _split(self, dataset_name):
        """Keep only the train prefix or valid suffix of each session, in place."""
        # Same per-session boundaries as in _normalize.
        train_size = [610, 558, 567]

        if dataset_name == "train":
            for idx, candidate in enumerate(self.candidate_list):
                print(self.data[idx].shape)  # NOTE(review): debug print left in
                self.data[idx] = self.data[idx][:train_size[candidate%3]]
                self.label[idx] = self.label[idx][:train_size[candidate%3]]
        elif dataset_name == "valid":
            for idx, candidate in enumerate(self.candidate_list):
                self.data[idx] = self.data[idx][train_size[candidate%3]:]
                self.label[idx] = self.label[idx][train_size[candidate%3]:]
        else:
            raise ValueError("dataset_name should be train or valid")

    def get_coordination(self):
        """Build channel coordinates, supernode coordinates and the attention mask.

        Populates:
          - self.coordination / self.sph_coordination: per-channel 2D/spherical
            coordinates (converted to dense grid ranks) extended with one
            averaged coordinate per functional area via area_gather.
          - self.attn_mask: (80, 80) boolean mask (True = blocked) over
            62 electrodes + 17 area supernodes + 1 global node.
        """
        # Channel indices (after the reordering in __init__) of the 17
        # functional areas.
        func_areas = [[0,1,2,3,4],[5,6,7],[8,9,10,11,12,13],[14,15,16],[17,18,19],[20,21,22],[23,24,25],
                      [26,27,28],[29,30,31],[32,33,34],[35,36,37,38,39,40],[41,42,43],[44,45,46],[47,48,49],
                      [50,51,52],[53,54,55,56,57,58],[59,60,61]]
        coordination = np.array([[
            -27, 0, 27, -36, 36, -71, -64, -48, -25, 0, 25,
            -33, 0, 33, 48, 64, 71, -83, -78, -59, 59, 78, 83,
            -87, -82, -63, -34, 0, 34, 63, 82, 87, -83, -78,
            -59, -33, 0, 33, -25, 0, 25, 59, 78, 83, -71,
            -64, -48, 48, 64, 71, -51, -40, -36, -36, 0,
            36, -27, 0, 27, 40, 51, 36
        ],
        [
            83, 87, 83, 76, 76, 51, 55, 59, 62, 63, 62, 33, 34, 33, 59, 55, 51, 27,
            30, 31, 31, 30, 27, 0, 0, 0, 0, 0, 0, 0, 0, 0, -27, -30, -31, -33, -34,
            -33, -62, -63, -62, -31, -30, -27, -51, -55, -59, -59, -55, -51,
            -71, -76, -83, -76, -82, -76, -83, -87, -83, -76, -71, -83
        ],
        [
            -3, -3, -3, 24, 24, -3, 23, 44, 56, 61, 56, 74, 81, 74, 44, 23, -3,
            -3, 27, 56, 56, 27, -3, -3, 31, 61, 81, 88, 81, 61, 31, -3, -3, 27,
            56, 74, 81, 74, 56, 61, 56, 56, 27, -3, -3, 23, 44, 44, 23, -3, -3, 24,
            -3, 24, 31, 24, -3, -3, -3, 24, -3, -3
        ]])
        # Replace raw millimetre coordinates with dense ranks per axis,
        # i.e. integer grid positions (presumably for the 2d positional
        # encoding — TODO confirm against MMM_Encoder's pe_type handling).
        for i in range(coordination.shape[0]):
            arr = coordination[i]
            rank = rankdata(arr, method="dense") - 1
            coordination[i] = rank
        # Spherical coordinates: 2 axes only (azimuth/elevation-like angles).
        sph_coordination = np.array([[
            18,0,-18,25,-25,54,49,39,22,0,-22,45,0,-45,-39,-49,-54,72,69,62,-62,
            -69,-72,90,90,90,90,-90,-90,-90,-90,-90,108,111,118,135,-180,-135,158,
            -180,-158,-118,-111,-108,126,131,141,-141,-131,-126,144,155,162,155,-180,
            -155,162,-180,-162,-155,-144,-162
        ],
        [
            -2,-2,-2,16,16,-2,15,30,40,44,40,58,67,58,30,15,-2,-2,
            18,40,40,18,-2,-2,21,44,67,90,67,44,21,-2,-2,18,40,58,
            67,58,40,44,40,40,18,-2,-2,15,30,30,15,-2,-2,-2,-2,16,
            21,16,-2,-2,-2,-2,-2,-2,
        ]])

        # process attention mask
        # attn_mask[i, j] = 1 means node i may attend to node j; it is
        # inverted to a boolean "blocked" mask at the end.
        attn_mask = torch.zeros((80,80),dtype=torch.int)
        if self.local:
            # Electrodes attend only within their own functional area.
            for func_area in func_areas:
                for i in func_area:
                    for j in func_area:
                        attn_mask[i,j] = 1
        else:
            # Electrodes attend to all electrodes.
            for i in range(62):
                for j in range(62):
                    attn_mask[i,j] = 1
        # Each area supernode (rows/cols 62..78) connects to its member electrodes.
        for i, _ in enumerate(func_areas):
            for j in func_areas[i]:
                attn_mask[62+i,j] = 1
                attn_mask[j,62+i] = 1
        # Supernodes attend to each other.
        for i in range(17):
            for j in range(17):
                attn_mask[62+i,62+j] = 1
        # The global node (index 79) connects to everything.
        for i in range(62+17+1):
            attn_mask[62+17, i] = 1
            attn_mask[i, 62+17] = 1
        self.attn_mask = (1-attn_mask).bool()
        #process supernode coordination
        self.coordination = area_gather(coordination, func_areas)
        self.sph_coordination = area_gather(sph_coordination, func_areas)

    def get_index(self):
        # NOTE(review): self.label is a numpy array here, which has no
        # `.index` attribute — this would raise AttributeError if called.
        # Looks like a leftover from a pandas-based dataset API; verify.
        return self.label.index

    def __len__(self):
        """Number of samples in the selected split."""
        return self.label.shape[0]

    def __getitem__(self, idx):
        """Return (features as float tensor, label as scalar long tensor)."""
        return (
            torch.tensor(
                self.data[idx],
                dtype=torch.float,
            ),
            torch.tensor(
                self.label[idx],
                dtype=torch.long,
            ).squeeze(),
        )

    def freeup(self):
        # No-op: data is kept in memory; hook required by the dataset API.
        pass

    def load(self):
        # No-op: data is loaded eagerly in __init__; hook required by the API.
        pass
240+
241+
def area_gather(coordination, areas):
    """Append one averaged "supernode" coordinate per functional area.

    Args:
        coordination: array of shape (D, C) — D coordinate axes for C channels.
        areas: list of channel-index lists, one per functional area.

    Returns:
        Array of shape (D, C + len(areas)): the original coordinates followed
        by, for each area, the mean coordinate of its member channels.
    """
    axes = coordination.shape[0]
    supernodes = np.zeros([axes, len(areas)])
    for col, members in enumerate(areas):
        for channel in members:
            supernodes[:, col] += coordination[:, channel] / len(members)
    return np.concatenate((coordination, supernodes), axis=1)
250+
251+
if __name__=='__main__':
    # Smoke test: load subject 0's training split with 10-step time windows
    # and count the samples produced by a shuffled DataLoader.
    # NOTE(review): requires the DE_*.mat files under ./data/SEED-IV/DE/, and
    # the relative import `.base` above means this must be run with
    # `python -m physiopro.dataset.SEED_IV`, not as a plain script — verify.
    dataset = SEED_IV(subject_index=0,window_size=10,addtime=True)
    dl = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
    cnt=0
    for data,label in dl:
        print(data.shape)
        cnt+=data.shape[0]
    print(cnt)

physiopro/dataset/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@
77
from .df import DfDataset
88
from .cinc2020 import CinC2020
99
from .SEED import SEED
10+
from .SEED_IV import SEED_IV
1011
from .tpp import EventDataset

physiopro/model/mmm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from ..common.utils import AverageMeter, GlobalTracker, to_torch,printt
1616
from .base import MODELS,BaseModel
1717

18+
1819
@MODELS.register_module()
1920
class MMM_Finetune(BaseModel):
2021
def __init__(

scripts/extract.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Convert the raw SEED-IV smoothed-DE features (eeg_feature_smooth) into the
# per-file layout consumed by physiopro.dataset.SEED_IV:
#   DE_{sub*3 + exp + 1}.mat  with key 'DE_feature'  (one per subject/session)
#   DE_{exp+1}_labels.mat     with key 'de_labels'   (one per session)
import os
import glob
import numpy as np
import scipy.io as sio


data_path = '/home/yansenwang/data/SEED-IV/SEED-IV/eeg_feature_smooth/' # path to the raw SEED dataset
save_path = '/home/yansenwang/data/New_SEED_IV/DE/'
os.makedirs(save_path, exist_ok=True)
# Per-session trial labels: 3 sessions x 24 trials, classes 0-3
# (presumably copied from the official SEED-IV ReadMe — verify).
labels = [
    [1,2,3,0,2,0,0,1,0,1,2,1,1,1,2,3,2,2,3,3,0,3,0,3],
    [2,1,3,0,0,2,0,2,3,3,2,3,2,0,1,1,2,1,0,3,0,1,3,1],
    [1,2,2,1,3,3,3,1,1,2,1,0,2,3,3,0,2,3,0,0,2,0,1,0]
]
final_labels = {}

for exp in range(3):  # 3 recording sessions
    filenames = glob.glob(os.path.join(data_path, f'{exp+1}/*.mat'))
    # NOTE(review): lexicographic sort — the mapping from sorted position to
    # subject index depends on the raw file naming (e.g. '10_*' sorts before
    # '2_*'); confirm this matches the intended subject order.
    filenames.sort()
    for sub in range(15):  # 15 subjects per session
        session_label = []
        data = []
        mat_path = filenames[sub]
        print(mat_path)
        T = sio.loadmat(mat_path)

        # Concatenate the 24 smoothed-DE trials along the time axis.
        for trial in range(24):
            temp = T['de_LDS' + str(trial + 1)]
            data.append(temp)

            # Labels depend only on the session, so build them once (sub == 0):
            # repeat the trial's label for each of its time samples.
            if sub == 0:
                temp_label = np.tile(labels[exp][trial], temp.shape[1])
                session_label.extend(temp_label)
        if sub == 0:
            final_labels[exp] = session_label
        data = np.concatenate(data, axis=1)
        sio.savemat(os.path.join(save_path, 'DE_' + str(sub * 3 + exp+1) + '.mat'), {'DE_feature': np.array(data)}) # save the features
# Save one label vector per session, shared by all subjects.
for exp in range(3):
    sio.savemat(os.path.join(save_path, 'DE_' + str(exp+1) + '_labels.mat'), {'de_labels': np.array(final_labels[exp])})

0 commit comments

Comments
 (0)