-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata_processing.py
More file actions
executable file
·110 lines (87 loc) · 3.27 KB
/
data_processing.py
File metadata and controls
executable file
·110 lines (87 loc) · 3.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
import numpy as np
import pandas as pd
from PIL import Image
import random
from typing import List
import json
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
class CustomImageDataset(Dataset):
    """Dataset of single-channel instrument images with GMoM-unit labels.

    Each row of ``data_df`` must provide the columns ``which_instrument``,
    ``GMoM_Unit_acronym`` and ``Filename``.  The image for a row is expected at
    ``<data_dir>/<which_instrument>/processed_data/<GMoM_Unit_acronym>/<Filename>``.
    ``__getitem__`` returns ``(image, gmom_unit)`` where ``image`` is the
    min-max-normalized RGB image (optionally transformed) and ``gmom_unit`` is
    the row's ``GMoM_Unit_acronym`` string.
    """

    def __init__(self, data_dir, data_df, if_pretrained: bool, transform=None):
        # Args:
        #   data_dir: root directory containing the per-instrument folders.
        #   data_df: pandas DataFrame describing the samples (see class docstring).
        #   if_pretrained: stored on the instance; not used inside this class.
        #   transform: optional torchvision transform applied to each image.
        self.df = data_df.reset_index(drop=True)
        self.transform = transform
        self.if_pretrained = if_pretrained
        self.image_paths = []
        self.gmom_units = []
        # Iterate the re-indexed copy (self.df) rather than the caller's frame,
        # so the stored df and the precomputed lists always agree row-for-row.
        for _, row in self.df.iterrows():
            image_path = os.path.join(
                data_dir,
                row["which_instrument"],
                "processed_data",
                row["GMoM_Unit_acronym"],
                row["Filename"],
            )
            self.image_paths.append(image_path)
            self.gmom_units.append(row["GMoM_Unit_acronym"])

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)

    def __getitem__(self, index):
        """Load, normalize, and return ``(image, gmom_unit)`` for ``index``."""
        image_path, gmom_unit = self.image_paths[index], self.gmom_units[index]
        # Load the raw image and min-max scale it to [0, 255].
        arr = np.array(Image.open(image_path)).astype(np.float32)
        value_range = arr.max() - arr.min()
        if value_range > 0:
            arr_normalized = 255 * (arr - arr.min()) / value_range
        else:
            # Constant-valued image: the original expression divided by zero
            # here and produced NaNs; map it to an all-zero image instead.
            arr_normalized = np.zeros_like(arr)
        image = Image.fromarray(arr_normalized.astype(np.uint8)).convert("RGB")
        # Apply transforms (e.g. augmentation + ToTensor + Normalize).
        if self.transform is not None:
            image = self.transform(image)
        return image, gmom_unit
def prepare_dataloaders(
    data_dir: str,
    data_df: str,
    if_pretrained: bool,
    which_instrument: List[str],
    batch_size: int,
    num_workers: int,
    pin_mem: bool
):
    """Build the train and validation DataLoaders.

    Args:
        data_dir: root directory of the processed image tree.
        data_df: path to a CSV with sample metadata, including a ``Split``
            column with values ``"train"`` / ``"val"``.
        if_pretrained: if True, normalize with ImageNet statistics; otherwise
            use the per-instrument statistics from ``utils/statistics.json``.
        which_instrument: instrument names; joined with ``", "`` to form the
            lookup key into ``utils/statistics.json``.
        batch_size: batch size for both loaders.
        num_workers: worker processes per loader.
        pin_mem: whether to pin host memory for faster GPU transfer.

    Returns:
        ``(train_dataloader, val_dataloader)``.
    """
    ### Define transforms
    instrument_key = ", ".join(which_instrument)
    with open("utils/statistics.json", "r") as f:
        INSTRUMENT_STATS = json.load(f)
    if if_pretrained:
        mean = INSTRUMENT_STATS["ImageNet"]["mean"]
        std = INSTRUMENT_STATS["ImageNet"]["std"]
    else:
        mean = INSTRUMENT_STATS[instrument_key]["mean"]
        std = INSTRUMENT_STATS[instrument_key]["std"]
    # Augmentation only on the training split; val gets deterministic transforms.
    train_transforms = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    val_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    ### Read data
    data_df = pd.read_csv(data_df)
    train_df = data_df[data_df["Split"]=="train"]
    val_df = data_df[data_df["Split"]=="val"]
    ### Create datasets
    dataset_train = CustomImageDataset(data_dir=data_dir, data_df=train_df, if_pretrained=if_pretrained, transform=train_transforms)
    dataset_val = CustomImageDataset(data_dir=data_dir, data_df=val_df, if_pretrained=if_pretrained, transform=val_transforms)
    ### Create dataloaders
    sampler_train = torch.utils.data.RandomSampler(dataset_train)
    # Validation must see every sample exactly once, in a fixed order.
    # The original used RandomSampler + drop_last=True here, which both
    # shuffled validation and silently discarded up to batch_size-1 samples
    # per epoch, biasing the reported metrics.
    sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    train_dataloader = DataLoader(
        dataset_train, sampler=sampler_train,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_mem,
        drop_last=True,  # keep batch shape constant during training
    )
    val_dataloader = DataLoader(
        dataset_val, sampler=sampler_val,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_mem,
        drop_last=False,  # evaluate on the full validation set
    )
    return train_dataloader, val_dataloader