-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathIEMOCAP.py
More file actions
209 lines (169 loc) · 6.66 KB
/
IEMOCAP.py
File metadata and controls
209 lines (169 loc) · 6.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
import os
import re
from pathlib import Path
from typing import Optional, Tuple, Union
from torch import Tensor
from torch.utils.data import Dataset
import logging
import os
import tarfile
import zipfile
from typing import Any, List, Optional
import torchaudio
# Module-level logger; used by the archive helpers to report skipped members.
_LG = logging.getLogger(name=__name__)
def _extract_tar(from_path: str, to_path: Optional[str] = None, overwrite: bool = False) -> List[str]:
    """Extract a tar archive and return the paths of its regular-file members.

    Args:
        from_path: Path of the tar archive (opened with mode ``"r"``).
        to_path: Destination directory. Defaults to the directory containing the archive.
        overwrite: If ``False``, members whose destination path already exists are
            skipped; if ``True`` they are extracted again.

    Returns:
        ``os.path.join(to_path, member.name)`` for every regular-file member, including
        files that were skipped because they already existed.

    Raises:
        ValueError: If a member name would place it outside ``to_path`` (e.g. ``../x``
            or an absolute path). Members processed before the offending one may already
            have been extracted.
    """
    if to_path is None:
        to_path = os.path.dirname(from_path)
    # Normalised absolute destination used for the containment check below.
    dest_root = os.path.abspath(to_path)
    with tarfile.open(from_path, "r") as tar:
        files = []
        for file_ in tar:
            file_path = os.path.join(to_path, file_.name)
            # Reject "tar-slip" members: tarfile joins member names onto the destination
            # as-is, so ".." components or absolute names would escape it. abspath()
            # collapses ".." lexically. NOTE(review): link-based escapes are not checked
            # here; on Python 3.12+ ``tar.extract(..., filter="data")`` also covers those.
            resolved = os.path.abspath(os.path.join(dest_root, file_.name))
            if os.path.commonpath([dest_root, resolved]) != dest_root:
                raise ValueError(
                    f"Refusing to extract {file_.name!r}: it resolves outside {to_path!r}"
                )
            if file_.isfile():
                files.append(file_path)
            if os.path.exists(file_path):
                _LG.info("%s already extracted.", file_path)
                if not overwrite:
                    continue
            tar.extract(file_, to_path)
    return files
def _extract_zip(from_path: str, to_path: Optional[str] = None, overwrite: bool = False) -> List[str]:
    """Extract a zip archive and return the names of all its members.

    Args:
        from_path: Path of the ``.zip`` archive.
        to_path: Destination directory. Defaults to the directory containing the archive.
        overwrite: If ``False``, members whose destination path already exists are
            skipped; if ``True`` they are extracted again.

    Returns:
        Every name reported by ``ZipFile.namelist()``, including members that were
        skipped because they already existed.
    """
    destination = os.path.dirname(from_path) if to_path is None else to_path
    with zipfile.ZipFile(from_path, "r") as archive:
        names = archive.namelist()
        for name in names:
            target = os.path.join(destination, name)
            present = os.path.exists(target)
            if present:
                _LG.info("%s already extracted.", target)
            if present and not overwrite:
                continue
            archive.extract(name, destination)
    return names
def _load_waveform(
    root: str,
    filename: str,
    exp_sample_rate: int,
):
    """Load the audio file ``root/filename`` and check its sample rate.

    Args:
        root: Directory that ``filename`` is relative to.
        filename: Path of the audio file, relative to ``root``.
        exp_sample_rate: Sample rate the file is required to have.

    Returns:
        The waveform returned by ``torchaudio.load``.

    Raises:
        ValueError: If the file's sample rate differs from ``exp_sample_rate``.
    """
    audio, actual_rate = torchaudio.load(os.path.join(root, filename))
    if actual_rate == exp_sample_rate:
        return audio
    raise ValueError(f"sample rate should be {exp_sample_rate}, but got {actual_rate}")
# Sample rate every IEMOCAP file must have; __getitem__ passes it to _load_waveform,
# which raises ValueError on a mismatch.
_SAMPLE_RATE = 16_000
def _get_wavs_paths(data_dir: Union[str, Path]) -> List[str]:
    """Return the sorted wav files of one IEMOCAP session, relative to the dataset root.

    Args:
        data_dir (str or Path): A session directory, e.g. ``<root>/IEMOCAP/Session1``.
            Files are matched by ``data_dir/sentences/wav/*/*.wav``.

    Returns:
        List[str]: Paths relative to ``data_dir``'s parent, e.g.
        ``Session1/sentences/wav/<dialog>/<utterance>.wav`` (OS path separators).
        An empty list if nothing matches.
    """
    data_dir = Path(data_dir)
    wav_dir = data_dir / "sentences" / "wav"
    # Strip the prefix structurally rather than searching for the substring "Session":
    # a text search matches the first occurrence anywhere in the path (e.g. a parent
    # folder named "MySessions") and degrades to the last character when absent.
    # Every match shares the same prefix, so sorting the relative paths gives the
    # same order as sorting the full ones.
    return sorted(str(p.relative_to(data_dir.parent)) for p in wav_dir.glob("*/*.wav"))
class IEMOCAP(Dataset):
    """*IEMOCAP* :cite:`iemocap` dataset.

    Only utterances whose categorical label is one of ``"neu"``, ``"hap"``, ``"ang"``,
    ``"sad"`` or ``"exc"`` are kept. Every other label (e.g. ``"fru"``) is discarded,
    as are utterances without a wav file and speakers not listed in ``speakers``.

    Args:
        root (str or Path): Root directory where the dataset's top level directory is found
        sessions (Tuple[int]): Tuple of sessions (1-5) to use. (Default: ``(1, 2, 3, 4, 5)``)
        utterance_type (str or None, optional): Which type(s) of utterances to include in the dataset.
            Options: ("scripted", "improvised", ``None``). If ``None``, both scripted and improvised
            data are used.
        speakers (Tuple[str], optional): Speaker IDs (the part of an utterance name before the
            first ``"_"``, e.g. ``"Ses01F"``) whose utterances are kept.
            (Default: all ten speakers ``"Ses01F"`` through ``"Ses05M"``)
    """

    # Categorical labels retained by the dataset; lines with any other label are skipped.
    _LABELS = ("neu", "hap", "ang", "sad", "exc")

    def __init__(
        self,
        root: Union[str, Path],
        sessions: Tuple[int, ...] = (1, 2, 3, 4, 5),
        utterance_type: Optional[str] = None,
        speakers: Tuple[str, ...] = (
            "Ses01F", "Ses01M", "Ses02F", "Ses02M", "Ses03F",
            "Ses03M", "Ses04F", "Ses04M", "Ses05F", "Ses05M",
        ),
    ):
        root = Path(root)
        self._path = root / "IEMOCAP"

        if not os.path.isdir(self._path):
            raise RuntimeError("Dataset not found.")

        if utterance_type not in ["scripted", "improvised", None]:
            raise ValueError("utterance_type must be one of ['scripted', 'improvised', or None]")

        # Stems of every utterance that has a wav file on disk. A set keeps the
        # per-label-line membership test O(1); a list made parsing quadratic.
        all_data = set()
        # Utterance stems in dataset order (session order, then sorted wav path).
        self.data = []
        # stem -> {"label": <label>, "path": <wav path relative to self._path>}
        self.mapping = {}

        for session in sessions:
            session_dir = self._path / f"Session{session}"

            wav_paths = _get_wavs_paths(session_dir)
            all_data.update(Path(wav_path).stem for wav_path in wav_paths)

            # Pick which evaluation files to read based on the requested utterance type.
            label_dir = session_dir / "dialog" / "EmoEvaluation"
            query = "*.txt"
            if utterance_type == "scripted":
                query = "*script*.txt"
            elif utterance_type == "improvised":
                query = "*impro*.txt"

            for label_path in label_dir.glob(query):
                with open(label_path, "r") as f:
                    for line in f:
                        # Only lines starting with "[" are parsed; their tab-separated
                        # fields 1 and 2 are read as the utterance stem and the label.
                        if not line.startswith("["):
                            continue
                        fields = re.split("[\t\n]", line)
                        wav_stem = fields[1]
                        label = fields[2]
                        speaker = wav_stem.split("_")[0]
                        if wav_stem not in all_data:
                            continue
                        if label not in self._LABELS:
                            continue
                        if speaker not in speakers:
                            continue
                        self.mapping[wav_stem] = {"label": label}

            # Keep only labelled utterances, in this session's sorted wav order.
            for wav_path in wav_paths:
                wav_stem = Path(wav_path).stem
                if wav_stem in self.mapping:
                    self.data.append(wav_stem)
                    self.mapping[wav_stem]["path"] = wav_path

    def get_metadata(self, n: int) -> Tuple[str, int, str, str, str]:
        """Get metadata for the n-th sample from the dataset. Returns filepath instead of waveform,
        but otherwise returns the same fields as :py:meth:`__getitem__`.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            str:
                Path to audio, relative to the dataset's ``IEMOCAP`` directory
            int:
                Sample rate
            str:
                File name
            str:
                Label (one of ``"neu"``, ``"hap"``, ``"ang"``, ``"sad"``, ``"exc"``)
            str:
                Speaker
        """
        wav_stem = self.data[n]
        wav_path = self.mapping[wav_stem]["path"]
        label = self.mapping[wav_stem]["label"]
        speaker = wav_stem.split("_")[0]
        return (wav_path, _SAMPLE_RATE, wav_stem, label, speaker)

    def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, str]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            Tuple of the following items;

            Tensor:
                Waveform
            int:
                Sample rate
            str:
                File name
            str:
                Label (one of ``"neu"``, ``"hap"``, ``"ang"``, ``"sad"``, ``"exc"``)
            str:
                Speaker
        """
        metadata = self.get_metadata(n)
        # metadata[0] is relative to self._path; metadata[1] is the required sample rate.
        waveform = _load_waveform(self._path, metadata[0], metadata[1])
        return (waveform,) + metadata[1:]

    def __len__(self):
        """Return the number of retained utterances."""
        return len(self.data)