-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfilter_classes.py
More file actions
210 lines (164 loc) · 7.48 KB
/
filter_classes.py
File metadata and controls
210 lines (164 loc) · 7.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
"""
Preprocessing steps for YOLO-format datasets:
remap_classes:
- Renames classes in-place (ORIGINAL_NAME -> NEW_NAME).
- The target name must already exist in the dataset (supports merging).
- Source classes are removed from the class list; their annotations are
re-labelled as the target class.
- data.yaml is updated to reflect the new class list.
filter_classes:
- Removes annotations whose class is not in `keep_classes`.
- Remaining class IDs are remapped to be consecutive starting from 0.
- If a label file becomes empty after filtering, both the label file and its
corresponding image are deleted.
- data.yaml is updated to reflect the new class list.
Run remap_classes before filter_classes so renamed classes are correctly
recognised by the filter step.
"""
import os
import yaml
from tqdm import tqdm
def remap_classes(dataset_path: str, remap: dict[str, str]) -> None:
    """
    Rename/merge classes in a YOLO-format dataset in-place.

    Every ``*.txt`` file in ``<split>/labels`` (for the ``train``, ``test`` and
    ``val`` splits that exist) is rewritten so that annotations of each source
    class are relabelled as its target class. All class IDs are renumbered to
    index into the new, shorter class list. data.yaml is then rewritten with the
    new ``names`` and ``nc``.

    All label files are read and validated before any of them is written, so a
    malformed label aborts the run without leaving a half-remapped dataset.

    Args:
        dataset_path: Path to the dataset root containing data.yaml.
        remap: Mapping of {original_class_name: target_class_name}.
            The target class must already exist in the dataset and must not
            itself be a remap source. After remapping the original class is
            removed from the class list and all its annotations are relabelled
            as the target class.

    Raises:
        FileNotFoundError: If data.yaml does not exist.
        ValueError: If a source or target class is unknown, if a target is also
            a remap source (chained or self remap), if data.yaml ``names`` is a
            mapping whose keys are not 0..nc-1, or if a label line has a
            non-integer or out-of-range class ID.
    """
    yaml_path = os.path.join(dataset_path, "data.yaml")
    if not os.path.exists(yaml_path):
        raise FileNotFoundError(f"data.yaml not found at {yaml_path}")
    with open(yaml_path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f)

    names = data["names"]
    if isinstance(names, dict):
        # data.yaml may store names as an {id: name} mapping; convert it to a
        # list indexed by class ID so the rest of the code can treat both alike.
        if sorted(names) != list(range(len(names))):
            raise ValueError(
                f"data.yaml 'names' keys must be the class IDs 0..{len(names) - 1}: {names}"
            )
        names = [names[i] for i in range(len(names))]
    original_names: list[str] = list(names)

    # Validate sources and targets
    for src, tgt in remap.items():
        if src not in original_names:
            raise ValueError(f"Remap source '{src}' not found in dataset classes: {original_names}")
        if tgt not in original_names:
            raise ValueError(
                f"Remap target '{tgt}' not found in dataset classes: {original_names}. "
                "The target class must already exist in the dataset."
            )
        if tgt in remap:
            # The target would itself be dropped from the class list, leaving
            # nothing to merge into (this also covers src == tgt).
            raise ValueError(
                f"Remap target '{tgt}' is also a remap source; chained or self "
                "remaps are not supported. Map directly to the final class."
            )

    print(f"Remapping classes: {remap}")

    # New names list: every source class is removed.
    new_names = [n for n in original_names if n not in remap]
    # old_id -> new_id, resolving each name through the remap first.
    old_id_to_new_id: dict[int, int] = {
        old_id: new_names.index(remap.get(name, name))
        for old_id, name in enumerate(original_names)
    }

    # Phase 1: compute the rewritten contents of every label file without
    # touching disk, so any invalid line aborts before anything is modified.
    pending_writes: list[tuple[str, str]] = []
    summaries: list[str] = []
    for split in ("train", "test", "val"):
        labels_dir = os.path.join(dataset_path, split, "labels")
        if not os.path.isdir(labels_dir):
            continue
        label_files = [f for f in os.listdir(labels_dir) if f.endswith(".txt")]
        remapped_count = 0
        for label_file in tqdm(label_files, desc=f"Remapping {split}"):
            label_path = os.path.join(labels_dir, label_file)
            with open(label_path, "r", encoding="utf-8") as f:
                lines = f.readlines()
            new_lines = []
            for line in lines:
                parts = line.split()
                if not parts:
                    continue
                old_class_id = int(parts[0])
                if old_class_id not in old_id_to_new_id:
                    raise ValueError(
                        f"{label_path}: class ID {old_class_id} is outside the "
                        f"{len(original_names)} classes listed in data.yaml"
                    )
                # Count only annotations whose class was actually renamed, not
                # those whose ID merely shifted because a source was removed.
                if original_names[old_class_id] in remap:
                    remapped_count += 1
                parts[0] = str(old_id_to_new_id[old_class_id])
                new_lines.append(" ".join(parts))
            # An empty label file stays empty instead of becoming "\n".
            pending_writes.append((label_path, "".join(ln + "\n" for ln in new_lines)))
        summaries.append(f" {split}: remapped {remapped_count} annotation(s)")

    # Phase 2: everything validated - write the label files.
    for label_path, content in pending_writes:
        with open(label_path, "w", encoding="utf-8") as f:
            f.write(content)
    for summary in summaries:
        print(summary)

    # Update data.yaml (keep its original key order).
    data["names"] = new_names
    data["nc"] = len(new_names)
    with open(yaml_path, "w", encoding="utf-8") as f:
        yaml.dump(data, f, default_flow_style=False, allow_unicode=True, sort_keys=False)
    print(f"Updated data.yaml after remap: nc={len(new_names)}, names={new_names}")
def filter_classes(dataset_path: str, keep_classes: list[str]) -> None:
    """
    Filter a YOLO-format dataset in-place, keeping only the listed classes.

    For each of the ``train``, ``test`` and ``val`` splits that has a
    ``labels`` directory:

    - Annotations whose class is not in ``keep_classes`` are dropped.
    - Remaining class IDs are renumbered to be consecutive from 0. The order is
      the order in which the kept classes appear in data.yaml, not the order
      of ``keep_classes``.
    - A label file that had annotations but lost all of them is deleted. Every
      image in ``<split>/images`` with the same file stem is deleted with it;
      extensions are matched case-insensitively.
    - Label files that were already empty (background images) are left as-is.

    Finally data.yaml is rewritten with the new ``names`` and ``nc``. All label
    files are parsed before anything on disk is changed, so a malformed label
    line aborts the run without leaving a half-filtered dataset.

    Args:
        dataset_path: Absolute or relative path to the dataset root (the folder
            that contains data.yaml and train/test/val sub-directories).
        keep_classes: List of class name strings to retain. Must not be empty.

    Raises:
        FileNotFoundError: If data.yaml does not exist.
        ValueError: If ``keep_classes`` is empty or names a class that is not
            in the dataset, if data.yaml ``names`` is a mapping whose keys are
            not 0..nc-1, or if a label line has a non-integer class ID.
    """
    yaml_path = os.path.join(dataset_path, "data.yaml")
    if not os.path.exists(yaml_path):
        raise FileNotFoundError(f"data.yaml not found at {yaml_path}")
    with open(yaml_path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f)

    names = data["names"]
    if isinstance(names, dict):
        # data.yaml may store names as an {id: name} mapping; convert it to a
        # list indexed by class ID so the rest of the code can treat both alike.
        if sorted(names) != list(range(len(names))):
            raise ValueError(
                f"data.yaml 'names' keys must be the class IDs 0..{len(names) - 1}: {names}"
            )
        names = [names[i] for i in range(len(names))]
    original_names: list[str] = list(names)
    print(f"Original classes ({len(original_names)}): {original_names}")

    # An empty keep list would delete every labelled image; refuse explicitly.
    if not keep_classes:
        raise ValueError("keep_classes is empty; refusing to remove every annotation in the dataset")

    # Validate that all requested classes actually exist in the dataset
    missing = [c for c in keep_classes if c not in original_names]
    if missing:
        raise ValueError(f"The following classes were not found in the dataset: {missing}")

    keep = set(keep_classes)
    # New IDs follow the dataset's own class order (not keep_classes order), so
    # new_names[new_id] always matches the renumbered labels.
    new_names = [name for name in original_names if name in keep]
    kept_old_ids = [old_id for old_id, name in enumerate(original_names) if name in keep]
    old_id_to_new_id: dict[int, int] = {old_id: new_id for new_id, old_id in enumerate(kept_old_ids)}
    print(f"Keeping classes ({len(new_names)}): {new_names}")

    # Phase 1: decide what to write and delete without touching disk.
    pending_writes: list[tuple[str, str]] = []
    pending_deletes: list[str] = []
    summaries: list[str] = []
    for split in ("train", "test", "val"):
        labels_dir = os.path.join(dataset_path, split, "labels")
        images_dir = os.path.join(dataset_path, split, "images")
        if not os.path.isdir(labels_dir):
            continue
        images_by_stem = _index_images_by_stem(images_dir)
        label_files = [f for f in os.listdir(labels_dir) if f.endswith(".txt")]
        removed_images = 0
        filtered_annotations = 0
        for label_file in tqdm(label_files, desc=f"Filtering {split}"):
            label_path = os.path.join(labels_dir, label_file)
            with open(label_path, "r", encoding="utf-8") as f:
                lines = f.readlines()
            kept_lines = []
            annotation_count = 0
            for line in lines:
                parts = line.split()
                if not parts:
                    continue
                annotation_count += 1
                new_class_id = old_id_to_new_id.get(int(parts[0]))
                if new_class_id is None:
                    continue
                parts[0] = str(new_class_id)
                kept_lines.append(" ".join(parts))
            filtered_annotations += annotation_count - len(kept_lines)

            if kept_lines:
                pending_writes.append((label_path, "".join(ln + "\n" for ln in kept_lines)))
            elif annotation_count:
                # Every annotation was filtered out: drop the label file and
                # all images sharing its stem.
                pending_deletes.append(label_path)
                stem_images = images_by_stem.get(os.path.splitext(label_file)[0], [])
                pending_deletes.extend(stem_images)
                removed_images += len(stem_images)
            # else: the file was already empty (background image) - keep it.

        summaries.append(
            f" {split}: removed {filtered_annotations} annotations, "
            f"deleted {removed_images} image(s) with no remaining labels"
        )

    # Phase 2: everything parsed successfully - apply the changes.
    for label_path, content in pending_writes:
        with open(label_path, "w", encoding="utf-8") as f:
            f.write(content)
    for path in pending_deletes:
        os.remove(path)
    for summary in summaries:
        print(summary)

    # Update data.yaml (keep its original key order).
    data["names"] = new_names
    data["nc"] = len(new_names)
    with open(yaml_path, "w", encoding="utf-8") as f:
        yaml.dump(data, f, default_flow_style=False, allow_unicode=True, sort_keys=False)
    print(f"Updated data.yaml: nc={len(new_names)}, names={new_names}")


def _index_images_by_stem(images_dir: str) -> dict[str, list[str]]:
    """Map each image file stem in ``images_dir`` to the full paths of the images with that stem.

    Extensions are compared case-insensitively against the supported image
    formats. A missing directory yields an empty mapping.
    """
    image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}
    by_stem: dict[str, list[str]] = {}
    if not os.path.isdir(images_dir):
        return by_stem
    for file_name in os.listdir(images_dir):
        stem, ext = os.path.splitext(file_name)
        if ext.lower() in image_extensions:
            by_stem.setdefault(stem, []).append(os.path.join(images_dir, file_name))
    return by_stem