-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathdataset.py
More file actions
116 lines (102 loc) · 4.26 KB
/
dataset.py
File metadata and controls
116 lines (102 loc) · 4.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision.transforms import Compose
import utils.transforms as T
import utils.utilities as utils
from tqdm import tqdm
class FWIDataset(Dataset):
    ''' FWI dataset.

    For convenience, in this class, a "batch" refers to one npy file
    instead of the batch used during training.

    Args:
        anno: path to annotation file; each line is a data npy path,
            optionally followed by a tab and a label npy path
        preload: whether to load the whole dataset into memory
        sample_ratio: downsample ratio along the time axis of seismic data
        file_size: # of samples stored in each npy file
        transform_data|label: transformation applied to data or label
        mask_factor: fraction of samples whose mask flag is set to 0
            (only implemented for preload=True)

    Raises:
        FileNotFoundError: if the annotation file does not exist.
    '''
    def __init__(
        self,
        anno,
        preload=True,
        sample_ratio=1,
        file_size=500,
        transform_data=None,
        transform_label=None,
        mask_factor=0.0
    ):
        # Optional progress bar; fall back to a no-op when tqdm is absent.
        try:
            from tqdm import tqdm
        except ImportError:
            def tqdm(iterable):
                return iterable

        if not os.path.exists(anno):
            # Fail fast with the same exception type open() would raise later.
            raise FileNotFoundError(f'Annotation file {anno} does not exist')
        self.preload = preload
        self.sample_ratio = sample_ratio
        self.file_size = file_size
        self.transform_data = transform_data
        self.transform_label = transform_label
        with open(anno, 'r') as f:
            self.batches = f.readlines()

        # NOTE: mask_factor is only implemented for preload=True;
        # with preload=False, __getitem__ returns mask=None.
        if preload:
            data_parts, label_parts = [], []
            for batch in tqdm(self.batches):
                data, label = self.load_every(batch)
                data_parts.append(data)
                if label is not None:
                    label_parts.append(label)
            self.data_list = np.concatenate(data_parts, 0)
            # Unlabeled datasets yield an empty label_list (len == 0).
            self.label_list = (np.concatenate(label_parts, 0)
                               if label_parts else np.array([]))
            # Randomly flag a mask_factor fraction of samples with 0.
            mask_indices = np.random.choice(
                len(self.data_list),
                int(mask_factor * len(self.data_list)),
                replace=False)
            self.mask_list = np.ones(len(self.data_list), dtype=np.int8)
            self.mask_list[mask_indices] = 0
            print("Data concatenation complete.")
            # Transforms are applied ONCE here on the whole arrays;
            # __getitem__ must not apply them again on the preload path.
            if self.transform_label is not None and len(self.label_list) != 0:
                self.label_list = self.transform_label(self.label_list)
            if self.transform_data is not None:
                self.data_list = self.transform_data(self.data_list)

    def load_every(self, batch):
        ''' Load one annotation line: "<data_path>[\\t<label_path>]\\n".

        Returns (data, label) float32 arrays; label is None when the
        line has no label path.
        '''
        batch = batch.split('\t')
        # Without a tab, batch[0] still carries the trailing newline.
        data_path = batch[0] if len(batch) > 1 else batch[0][:-1]
        data = np.load(data_path)[:, :, ::self.sample_ratio, :]
        data = data.astype('float32')
        if len(batch) > 1:
            label_path = batch[1][:-1]  # strip trailing newline
            label = np.load(label_path)
            label = label.astype('float32')
        else:
            label = None
        return data, label

    def __getitem__(self, idx):
        ''' Return (mask, data, label); label is an empty array when absent. '''
        batch_idx, sample_idx = idx // self.file_size, idx % self.file_size
        if self.preload:
            # Preloaded arrays were already transformed in __init__.
            mask = self.mask_list[idx]
            data = self.data_list[idx]
            label = self.label_list[idx] if len(self.label_list) != 0 else None
        else:
            mask = None  # masking not implemented for on-demand loading
            data, label = self.load_every(self.batches[batch_idx])
            data = data[sample_idx]
            label = label[sample_idx] if label is not None else None
            if self.transform_data is not None:
                data = self.transform_data(data)
            # Fixed: the label transform was mistakenly transform_data.
            if self.transform_label is not None and label is not None:
                label = self.transform_label(label)
        return mask, data, label if label is not None else np.array([])

    def __len__(self):
        # Assumes every npy file holds exactly file_size samples.
        return len(self.batches) * self.file_size
if __name__ == '__main__':
    # Smoke test: build the standard FWI transforms and read one sample.
    transform_data = Compose([
        T.LogTransform(k=1),
        T.MinMaxNormalize(T.log_transform(-61, k=1), T.log_transform(120, k=1))
    ])
    transform_label = Compose([
        T.MinMaxNormalize(2000, 6000)
    ])
    dataset = FWIDataset('relevant_files/temp.txt',
                         transform_data=transform_data,
                         transform_label=transform_label,
                         file_size=1)
    # __getitem__ returns a 3-tuple (mask, data, label);
    # the old 2-name unpacking raised ValueError at runtime.
    mask, data, label = dataset[0]
    print(data.shape)
    print(label is None)