-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
31 lines (25 loc) · 1.04 KB
/
utils.py
File metadata and controls
31 lines (25 loc) · 1.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from torch.utils.data import DataLoader
from dataset import BirdDataset
import torch
def load_data_set(image_paths, image_dir, segmentation_dir, transforms, batch_size=8, shuffle=True):
dataset = BirdDataset(image_paths,
image_dir,
segmentation_dir,
transform_image=transforms[0],
transform_mask=transforms[1])
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [11772, 16])
return DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=shuffle
), DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=shuffle
)
if __name__ == "__main__":
    # Smoke check: build the loaders for the CUB-200-2011 layout and report
    # split sizes.
    image_dir = "CUB_200_2011/CUB_200_2011/images"
    segmentation_dir = "CUB_200_2011/CUB_200_2011/segmentations"
    image_paths = "CUB_200_2011/CUB_200_2011/images.txt"
    # `transforms` is a required argument of load_data_set. The previous call
    # omitted it and failed with TypeError before doing anything.
    # NOTE(review): (None, None) is a placeholder that assumes BirdDataset
    # accepts None as "no transform". TODO: confirm, or plug in the real
    # image/mask transforms.
    train_loader, val_loader = load_data_set(
        image_paths, image_dir, segmentation_dir, transforms=(None, None)
    )
    # len(DataLoader) counts batches; len(loader.dataset) counts samples.
    print(
        f"train: {len(train_loader.dataset)} samples "
        f"({len(train_loader)} batches), "
        f"val: {len(val_loader.dataset)} samples "
        f"({len(val_loader)} batches)"
    )