-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataloader.py
More file actions
124 lines (104 loc) · 4.73 KB
/
dataloader.py
File metadata and controls
124 lines (104 loc) · 4.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import torch
import torchaudio.transforms as T
from torch.utils.data import Dataset, DataLoader
from torchvision.models import ViT_B_16_Weights
from IEMOCAP import *
import numpy as np
import random
import audiomentations
import torch
def augment_audio(waveform, sample_rate, augment_audio_flag=True, augment_spectro_flag=True):
    """Optionally augment a waveform and compute its mel spectrogram.

    Args:
        waveform: 1-D audio signal, as a ``torch.Tensor`` or ``numpy.ndarray``.
        sample_rate: sampling rate of ``waveform`` in Hz.
        augment_audio_flag: apply waveform-level augmentations when True.
        augment_spectro_flag: apply spectrogram-level augmentations when True.

    Returns:
        Tuple ``(augmented_waveform, mel_spec)`` where ``augmented_waveform``
        is a ``numpy.ndarray`` and ``mel_spec`` is a 2-D ``torch.Tensor``
        (the batch/channel dimension is squeezed out).
    """
    # Normalize the input type once: the original code mixed
    # torch.from_numpy(...) and .detach().numpy() on the same argument,
    # so one of its paths always raised depending on the input type.
    if torch.is_tensor(waveform):
        augmented_waveform = waveform.detach().cpu().numpy()
    else:
        augmented_waveform = np.asarray(waveform)

    if augment_audio_flag:
        # Waveform-level augmentation (audiomentations offers more options
        # and robustness than torchaudio's built-ins). Each transform is
        # applied independently with probability p.
        augment = audiomentations.Compose([
            audiomentations.TimeStretch(min_rate=0.9, max_rate=1.1, p=0.1),                 # time stretch
            audiomentations.PitchShift(min_semitones=-2, max_semitones=2, p=0.1),           # pitch shift
            audiomentations.AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.1),  # Gaussian noise
            audiomentations.Gain(min_gain_in_db=-5, max_gain_in_db=5, p=0.1)                # random gain
        ])
        augmented_waveform = augment(samples=augmented_waveform, sample_rate=sample_rate)

    # Mel spectrogram of the (possibly augmented) waveform; squeeze the
    # leading batch/channel dimension so every path yields a 2-D array.
    mel_spec = mel_spectrogram(torch.from_numpy(augmented_waveform).unsqueeze(0)).numpy().squeeze()

    if augment_spectro_flag:
        # Spectrogram-level augmentation applied directly on the mel bins.
        augment_melspec = audiomentations.SpecCompose(
            [
                audiomentations.SpecFrequencyMask(p=0.1),
            ]
        )
        mel_spec = augment_melspec(mel_spec)

    mel_spec = torch.from_numpy(mel_spec)
    return augmented_waveform, mel_spec
class IEMOCAP_MFCC(Dataset):
    """IEMOCAP dataset wrapper yielding ``(waveform, image, label)`` samples.

    Each sample is a fixed-length waveform segment (``L * H`` samples,
    zero-padded or truncated), a 3-channel mel-spectrogram image preprocessed
    with the ViT-B/16 ImageNet transforms, and an integer emotion class.

    NOTE(review): ``transform2`` is effectively required — when it is falsy
    and ``augmentation`` is True, ``img`` is never bound and ``__getitem__``
    raises ``NameError`` at the return. Behavior kept as-is; confirm intent.
    """

    def __init__(self, root, sessions, speakers=None,
                 utterance_type=None, transform2=None, target_transform=None, augmentation=True):
        # Underlying project-local IEMOCAP loader.
        self.dataset = IEMOCAP(
            root=root,
            sessions=sessions,
            utterance_type=utterance_type,
            speakers=speakers,
        )
        self.speakers = speakers
        self.transform2 = transform2
        self.target_transform = target_transform
        self.augmentation = augmentation
        # 'exc' (excited) is merged into the 'hap' (happy) class: both map to 1.
        self.label_mapping = {'ang': 0, 'hap': 1, 'sad': 2, 'neu': 3, 'exc': 1}
        self.L = 6      # number of waveform segments
        self.H = 10000  # samples per segment
        self.weights = ViT_B_16_Weights.IMAGENET1K_V1
        self.preprocess = self.weights.transforms()

    def __len__(self):
        """Number of utterances in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(waveform, img, class_index)`` for utterance ``idx``."""
        waveform, sample_rate, ser, label, utterance_id = self.dataset[idx]
        if self.augmentation:
            # Audio + spectrogram augmentation (see augment_audio in this module).
            augmented_waveform, mel_spec = augment_audio(waveform, sample_rate,
                                                         augment_audio_flag=True,
                                                         augment_spectro_flag=True)
        else:
            augmented_waveform = waveform
            mel_spec = self.transform2(augmented_waveform)
            mel_spec = mel_spec[0, :, :]
        if self.transform2:
            # Replicate the single-channel spectrogram to 3 channels and run
            # the ViT ImageNet preprocessing (resize/normalize).
            img = torch.stack([mel_spec, mel_spec, mel_spec], 0)
            img = self.preprocess(img)
        if self.target_transform:
            label = self.target_transform(label)
        # Pad with zeros (or truncate) so every waveform is exactly L * H samples.
        augmented_waveform = torch.Tensor(augmented_waveform).squeeze()
        total_length = self.L * self.H
        if augmented_waveform.shape[-1] < total_length:
            augmented_waveform = torch.nn.functional.pad(
                augmented_waveform,
                (0, total_length - augmented_waveform.shape[-1]),
                "constant", 0)
        else:
            augmented_waveform = augmented_waveform[:total_length]
        return augmented_waveform, img, self.label_mapping[label]
# STFT / mel-filterbank configuration for 16 kHz audio.
n_fft = 400        # FFT size: 25 ms window at 16 kHz
win_length = 400   # analysis window length (same as n_fft)
hop_length = 200   # 12.5 ms hop -> 50% overlap
n_mels = 128       # number of mel bins
# Module-level mel-spectrogram transform used by augment_audio above.
# Produces a power (power=2.0) spectrogram with Slaney-normalized filters
# on the HTK mel scale, normalized per-frame by torchaudio.
mel_spectrogram = T.MelSpectrogram(
    sample_rate=16000,
    n_fft=n_fft,
    win_length=win_length,
    hop_length=hop_length,
    center=True,
    pad_mode="reflect",
    normalized=True,
    power=2.0,
    norm="slaney",
    n_mels=n_mels,
    mel_scale="htk",
)