-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathaugmentation.py
More file actions
53 lines (37 loc) · 1.89 KB
/
augmentation.py
File metadata and controls
53 lines (37 loc) · 1.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import os
from torchvision.transforms import v2
from utils import load_images
# Registry of augmentations selectable by name (see select_augmentations).
# Each value is a torchvision v2 transform instance built once at import time;
# every lookup returns that same shared instance.
AVAILABLE_AUGMENTATIONS = dict(
    # p=1: the flip is applied to every image, not randomly.
    HorizontalFlip=v2.RandomHorizontalFlip(p=1),
    # Rotation angle is drawn from the range [-10, 10] degrees.
    RandomRotation=v2.RandomRotation(degrees=(-10, 10)),
    # NOTE: despite the key, this is not a crop transform. It is an affine
    # transform with no rotation, translation or shear and a scale range of
    # (1.0, 1.3) -- effectively a random zoom-in. The key name is kept so
    # existing augmentation strings keep working.
    RandomCrop=v2.RandomAffine(degrees=0, translate=(0, 0), scale=(1.0, 1.3), shear=0),
    # Automatic augmentation policy with torchvision's default settings.
    TrivialAugmentWide=v2.TrivialAugmentWide(),
    # RandAugment with magnitude=4; remaining parameters use library defaults.
    RandAugment=v2.RandAugment(magnitude=4),
)
def select_augmentations(augmentations_str: str) -> list[v2.Transform]:
    """Parse a comma-separated string of augmentation names into transforms.

    Each comma-separated entry is stripped of surrounding whitespace and looked
    up in ``AVAILABLE_AUGMENTATIONS``. Empty entries (from a trailing comma,
    a doubled comma, or a whitespace-only string) are skipped.

    Args:
        augmentations_str: Names such as ``"HorizontalFlip, RandomRotation"``.
            An empty string (or any falsy value) selects nothing.

    Returns:
        The selected transforms, in the order given. These are the shared
        instances stored in ``AVAILABLE_AUGMENTATIONS``, not copies. The list
        is empty when nothing is selected.

    Raises:
        ValueError: If an entry is not a key of ``AVAILABLE_AUGMENTATIONS``.
    """
    if not augmentations_str:
        return []

    selected = []
    for name in (token.strip() for token in augmentations_str.split(",")):
        if not name:
            # Tolerate "A, B," / "A,,B" instead of rejecting an empty name.
            continue
        try:
            selected.append(AVAILABLE_AUGMENTATIONS[name])
        except KeyError:
            supported = ", ".join(AVAILABLE_AUGMENTATIONS)
            raise ValueError(
                f"Unsupported augmentation: {name}. "
                f"Supported augmentations are: {supported}"
            ) from None
    return selected
def tensor_to_image(tensor):
    """Convert an image tensor to a PIL image using torchvision's ``v2.ToPILImage``.

    A fresh converter is constructed on every call. How the tensor's dtype and
    channel count map to a PIL mode is decided by ``ToPILImage`` itself.
    """
    converter = v2.ToPILImage()
    pil_image = converter(tensor)
    return pil_image
def augment_images(input_dir, output_dir, augmentations):
    """Apply every augmentation to every image in ``input_dir`` and save the results.

    For an input file ``<stem>.<ext>`` and the i-th entry of ``augmentations``,
    the result is written to ``<output_dir>/<stem>_aug<i>.jpg``. All outputs are
    JPEG regardless of the input format. Inputs that share a stem (e.g.
    ``a.png`` and ``a.jpg``) therefore overwrite each other's outputs.

    Args:
        input_dir: Directory to read images from. A relative path is resolved
            against the current working directory.
        output_dir: Directory to write augmented images to. It is created
            (including parents) if it does not exist.
        augmentations: Sequence of callables that take and return an image
            tensor, e.g. the list returned by ``select_augmentations``.
    """
    input_dir = os.path.join(os.getcwd(), input_dir)
    if output_dir:
        # Image.save does not create directories; without this the first save
        # into a new output path fails.
        os.makedirs(output_dir, exist_ok=True)

    # NOTE(review): assumes load_images returns (path, tensor) pairs for the
    # images in the given directories -- confirm against utils.load_images.
    # If it is lazy and output_dir == input_dir, freshly written outputs might
    # also be picked up; verify.
    images = load_images([input_dir])
    for img_path, tensor in images:
        stem = os.path.splitext(os.path.basename(img_path))[0]
        for i, augmentation in enumerate(augmentations):
            augmented_image = tensor_to_image(augmentation(tensor))
            # JPEG cannot store alpha or palette modes (e.g. RGBA from a PNG);
            # Pillow raises OSError for those, so flatten to RGB first.
            if augmented_image.mode not in ("RGB", "L"):
                augmented_image = augmented_image.convert("RGB")
            augmented_image.save(os.path.join(output_dir, f"{stem}_aug{i}.jpg"))
if __name__ == "__main__":
    # NOTE: input and output directories are the same, so augmented files are
    # written next to the originals; a later run will presumably augment the
    # *_aug<i>.jpg outputs as well.
    output_dir = "data/augmentations"
    augmentations = select_augmentations("TrivialAugmentWide")
    augment_images("data/augmentations", output_dir, augmentations)
    # Report the directory that was actually written to.
    print(f"Augmented images saved to {output_dir}.")