|
| 1 | +from typing import Optional |
| 2 | + |
| 3 | +import torch |
| 4 | +from torch.utils.data import Dataset |
| 5 | +import numpy as np |
| 6 | +import scipy.io as sio |
| 7 | +from scipy.stats import rankdata |
| 8 | + |
| 9 | +from .base import DATASETS |
| 10 | + |
| 11 | + |
@DATASETS.register_module()
class SEED_IV(Dataset):
    """SEED-IV differential-entropy (DE) EEG emotion dataset.

    Loads pre-extracted DE features for the selected recording(s),
    normalizes each frequency band with training-set statistics, keeps the
    train or valid portion, optionally stacks a sliding time window, reorders
    the 62 EEG channels into functional-area order, and prepares an attention
    mask plus electrode coordinates (with per-area "supernodes") for
    graph-style models.

    Args:
        prefix: Root data directory containing ``SEED-IV/``.
        name: Feature sub-directory name (e.g. ``"DE"``).
        window_size: Number of consecutive samples per window when
            ``addtime`` is True.
        addtime: If True, stack a sliding time window over samples.
        subject_index: Recording index in ``0..44``; ``-1`` selects all 45
            recordings (15 subjects x 3 sessions).
        dataset_name: ``'train'`` or ``'valid'``.
        channel: Number of EEG channels (62 for SEED-IV).
        local: If True, channel-to-channel attention is restricted to each
            functional area; otherwise all 62 channels attend to each other.
        normalize: ``'minmax'`` or ``'gaussian'`` normalization method.
    """

    # Permutation that regroups the 62 electrodes by functional area; applied
    # to the channel axis after loading.
    _CHANNEL_ORDER = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
                      10, 17, 18, 19, 11, 12, 13,
                      14, 15, 16, 20, 21, 22, 23,
                      24, 25, 26, 27, 28, 29, 30,
                      31, 32, 33, 34, 35, 36, 37,
                      44, 45, 46, 38, 39, 40, 41,
                      42, 43, 47, 48, 49, 50, 51,
                      57, 52, 53, 54, 58, 59, 60,
                      55, 56, 61]

    # Number of training samples per session (the three sessions differ in
    # length); the remainder of each recording is the validation portion.
    _TRAIN_SIZE = [610, 558, 567]

    def __init__(
        self,
        prefix: str = "./data/",
        name: str = "DE",
        window_size: int = 1,
        addtime: bool = False,
        subject_index: Optional[int] = -1,
        dataset_name: Optional[str] = 'train',
        channel: int = 62,
        local: bool = False,
        normalize: str = 'gaussian',
    ):
        super().__init__()
        self.window_size = window_size
        self.addtime = addtime
        self.dataset_name = dataset_name
        self.channel = channel
        self.local = local
        # SEED-IV has 4 emotion classes.
        self.out_size = 4

        feature_dir = prefix + "/SEED-IV/" + name + "/"
        data_file_path = feature_dir + "DE_{}.mat"
        # One label array per session; labels are shared across subjects.
        de_label = {}
        for i in range(3):
            de_label[i] = np.array(
                sio.loadmat(feature_dir + f"DE_{i + 1}_labels.mat")['de_labels']
            ).squeeze(0)
        self.candidate_list = (
            [subject_index] if subject_index != -1 else list(range(45))
        )
        # Recording i belongs to session i % 3 and reuses its labels.
        self.label = [de_label[i % 3] for i in self.candidate_list]
        # Stored layout is (C, N, F); transpose to (N, C, F).
        self.data = [
            np.array(sio.loadmat(data_file_path.format(i + 1))["DE_feature"]).transpose(1, 0, 2)
            for i in self.candidate_list
        ]
        self._normalize(normalize)
        self._split(self.dataset_name)

        if addtime:
            self._addtimewindow(window_size)
            # (N, T, C, F) -> (N, C, T, F)
            self.data = self.data.transpose(0, 2, 1, 3)
        else:
            self.data = np.concatenate(self.data, axis=0)
            self.label = np.concatenate(self.label, axis=0)
        # Reorder channels (axis 1 in both layouts) into functional-area
        # order; plain numpy fancy indexing replaces the original
        # numpy -> torch -> index_select -> numpy round trip.
        self.data = np.take(self.data, self._CHANNEL_ORDER, axis=1)

        self.get_coordination()

    def _normalize(self, method: str = 'minmax') -> None:
        """Normalize each of the 5 frequency bands per recording in place.

        Statistics are computed on the training portion only (the first
        ``_TRAIN_SIZE[session]`` samples) to avoid leaking validation data.

        Args:
            method: ``'minmax'`` for min-max scaling or ``'gaussian'`` for
                z-score standardization; any other value is a no-op.
        """
        for i, candidate in enumerate(self.candidate_list):
            n_train = self._TRAIN_SIZE[candidate % 3]
            for band in range(5):
                # [:n_train] is the training set; [n_train:] is the valid set.
                train_slice = self.data[i][:n_train, :, band]
                if method == 'minmax':
                    minn = np.min(train_slice)
                    maxx = np.max(train_slice)
                    # NOTE(review): a constant band (maxx == minn) would
                    # divide by zero here — assumed not to occur in practice.
                    self.data[i][:, :, band] = (self.data[i][:, :, band] - minn) / (maxx - minn)
                elif method == 'gaussian':
                    mean = np.mean(train_slice)
                    std = np.std(train_slice)
                    self.data[i][:, :, band] = (self.data[i][:, :, band] - mean) / std

    def _addtimewindow(self, window: int) -> None:
        """Stack a sliding window of ``window`` consecutive samples per item.

        Each recording is padded at the end with its own last ``window``
        samples so the sample count is unchanged. Windows that would cross a
        label boundary are zero-padded past the boundary so every window
        contains samples of a single label. Replaces the per-recording lists
        with concatenated arrays: data (N, window, C, F), label (N,).
        """
        data_results = []
        label_results = []
        for data, label in zip(self.data, self.label):
            N, C, F = data.shape
            # Pad the tail so windows starting near the end stay in range.
            data = np.concatenate([data, data[-window:, :, :]], 0)
            label = np.concatenate([label, label[-window:]], 0)
            data_res = np.zeros(shape=(N, window, C, F))
            label_res = np.zeros(shape=N)
            # Index just past the current label segment; -1 means "no label
            # boundary recorded" (the full window is used). Initializing here
            # fixes an UnboundLocalError in the original when the very first
            # window already crossed a boundary, and stops stale values from
            # leaking across recordings.
            nearest = -1
            for j in range(N):
                if (
                    label[j] == label[j + window - 1]
                    and label[j] != label[j + window]
                ):
                    # The label changes right after this window; remember the
                    # boundary for the following (crossing) windows.
                    nearest = j + window
                elif label[j] == label[j + window - 1]:
                    nearest = -1
                # When the window itself crosses a boundary, neither branch
                # fires and `nearest` keeps the previously recorded boundary.
                if nearest != -1:
                    data_res[j, :, :, :] = np.concatenate(
                        [
                            data[j:nearest, :, :],
                            np.zeros(shape=(window - nearest + j, C, F)),
                        ],
                        0,
                    )
                else:
                    data_res[j, :, :, :] = data[j:j + window, :, :]
                label_res[j] = label[j]
            data_results.append(data_res)
            label_results.extend(label_res)

        self.data = np.concatenate(data_results, 0)
        self.label = np.array(label_results)

    def _split(self, dataset_name: str) -> None:
        """Keep only the train or valid portion of each recording.

        Raises:
            ValueError: if ``dataset_name`` is neither 'train' nor 'valid'.
        """
        if dataset_name not in ("train", "valid"):
            raise ValueError("dataset_name should be train or valid")
        for idx, candidate in enumerate(self.candidate_list):
            n_train = self._TRAIN_SIZE[candidate % 3]
            if dataset_name == "train":
                self.data[idx] = self.data[idx][:n_train]
                self.label[idx] = self.label[idx][:n_train]
            else:
                self.data[idx] = self.data[idx][n_train:]
                self.label[idx] = self.label[idx][n_train:]

    def get_coordination(self):
        """Build the attention mask and electrode coordinates.

        Defines 17 functional areas over the 62 (reordered) channels, an
        80x80 attention mask over [62 channels + 17 area supernodes +
        1 global node], and coordinate arrays extended with per-area
        supernode coordinates via ``area_gather``.
        """
        func_areas = [[0,1,2,3,4],[5,6,7],[8,9,10,11,12,13],[14,15,16],[17,18,19],[20,21,22],[23,24,25],
                      [26,27,28],[29,30,31],[32,33,34],[35,36,37,38,39,40],[41,42,43],[44,45,46],[47,48,49],
                      [50,51,52],[53,54,55,56,57,58],[59,60,61]]
        # Cartesian (x, y, z) positions of the 62 electrodes, in
        # functional-area channel order.
        coordination = np.array([[
            -27, 0, 27, -36, 36, -71, -64, -48, -25, 0, 25,
            -33, 0, 33, 48, 64, 71, -83, -78, -59, 59, 78, 83,
            -87, -82, -63, -34, 0, 34, 63, 82, 87, -83, -78,
            -59, -33, 0, 33, -25, 0, 25, 59, 78, 83, -71,
            -64, -48, 48, 64, 71, -51, -40, -36, -36, 0,
            36, -27, 0, 27, 40, 51, 36
        ],
        [
            83, 87, 83, 76, 76, 51, 55, 59, 62, 63, 62, 33, 34, 33, 59, 55, 51, 27,
            30, 31, 31, 30, 27, 0, 0, 0, 0, 0, 0, 0, 0, 0, -27, -30, -31, -33, -34,
            -33, -62, -63, -62, -31, -30, -27, -51, -55, -59, -59, -55, -51,
            -71, -76, -83, -76, -82, -76, -83, -87, -83, -76, -71, -83
        ],
        [
            -3, -3, -3, 24, 24, -3, 23, 44, 56, 61, 56, 74, 81, 74, 44, 23, -3,
            -3, 27, 56, 56, 27, -3, -3, 31, 61, 81, 88, 81, 61, 31, -3, -3, 27,
            56, 74, 81, 74, 56, 61, 56, 56, 27, -3, -3, 23, 44, 44, 23, -3, -3, 24,
            -3, 24, 31, 24, -3, -3, -3, 24, -3, -3
        ]])
        # Replace raw positions by their dense ranks so every axis becomes a
        # compact 0-based integer grid.
        for i in range(coordination.shape[0]):
            arr = coordination[i]
            rank = rankdata(arr, method="dense") - 1
            coordination[i] = rank
        # Spherical angles (degrees) for the same electrodes; intentionally
        # left un-ranked, matching the original behavior.
        sph_coordination = np.array([[
            18,0,-18,25,-25,54,49,39,22,0,-22,45,0,-45,-39,-49,-54,72,69,62,-62,
            -69,-72,90,90,90,90,-90,-90,-90,-90,-90,108,111,118,135,-180,-135,158,
            -180,-158,-118,-111,-108,126,131,141,-141,-131,-126,144,155,162,155,-180,
            -155,162,-180,-162,-155,-144,-162
        ],
        [
            -2,-2,-2,16,16,-2,15,30,40,44,40,58,67,58,30,15,-2,-2,
            18,40,40,18,-2,-2,21,44,67,90,67,44,21,-2,-2,18,40,58,
            67,58,40,44,40,40,18,-2,-2,15,30,30,15,-2,-2,-2,-2,16,
            21,16,-2,-2,-2,-2,-2,-2,
        ]])

        # Build the attention mask: 1 = attention allowed (inverted at the end
        # to match PyTorch's "True means masked" convention).
        attn_mask = torch.zeros((80, 80), dtype=torch.int)
        if self.local:
            # Channels attend only within their own functional area.
            for func_area in func_areas:
                for i in func_area:
                    for j in func_area:
                        attn_mask[i, j] = 1
        else:
            # Global attention among all 62 channels.
            for i in range(62):
                for j in range(62):
                    attn_mask[i, j] = 1
        # Each area supernode attends to/from its member channels.
        for i, _ in enumerate(func_areas):
            for j in func_areas[i]:
                attn_mask[62 + i, j] = 1
                attn_mask[j, 62 + i] = 1
        # Supernodes attend among themselves.
        for i in range(17):
            for j in range(17):
                attn_mask[62 + i, 62 + j] = 1
        # The final node (index 79, a global token) attends to/from everything.
        for i in range(62 + 17 + 1):
            attn_mask[62 + 17, i] = 1
            attn_mask[i, 62 + 17] = 1
        self.attn_mask = (1 - attn_mask).bool()
        # Append per-area mean coordinates as supernode coordinates.
        self.coordination = area_gather(coordination, func_areas)
        self.sph_coordination = area_gather(sph_coordination, func_areas)

    def get_index(self):
        # NOTE(review): after __init__, self.label is a numpy array, which has
        # no `.index` attribute — calling this raises AttributeError. Kept
        # unchanged for interface compatibility; confirm intended semantics
        # (possibly a leftover from a pandas-based version).
        return self.label.index

    def __len__(self) -> int:
        """Number of samples in the selected split."""
        return self.label.shape[0]

    def __getitem__(self, idx):
        """Return ``(features, label)`` as ``(float tensor, scalar long)``."""
        return (
            torch.tensor(
                self.data[idx],
                dtype=torch.float,
            ),
            torch.tensor(
                self.label[idx],
                dtype=torch.long,
            ).squeeze(),
        )

    def freeup(self):
        # No-op hook of the dataset interface (data is kept in memory).
        pass

    def load(self):
        # No-op hook of the dataset interface (data is loaded eagerly).
        pass
| 240 | + |
def area_gather(coordination: np.ndarray, areas: list) -> np.ndarray:
    """Append per-area mean coordinates ("supernodes") to coordinates.

    Args:
        coordination: Array of shape (D, C) — D coordinate dimensions for
            C electrodes.
        areas: List of index lists; each sublist holds the electrode indices
            belonging to one functional area.

    Returns:
        Float array of shape (D, C + len(areas)): the original coordinates
        followed by one averaged supernode coordinate per area.
    """
    # Vectorized per-area mean replaces the original O(D*C*A) triple loop.
    supernode_coordination = np.stack(
        [coordination[:, area].mean(axis=1) for area in areas], axis=1
    )
    return np.concatenate((coordination, supernode_coordination), axis=1)
| 250 | + |
if __name__ == '__main__':
    # Smoke test: build one subject's windowed dataset and count samples.
    dataset = SEED_IV(subject_index=0, window_size=10, addtime=True)
    loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
    total = 0
    for batch_data, batch_label in loader:
        print(batch_data.shape)
        total += batch_data.shape[0]
    print(total)
0 commit comments