Skip to content

【数据处理BUG反馈】iSAID官方数据集转换脚本生成的标签掩码(label mask)只有0/1像素值,类别信息严重缺失 #25

@caixiaoshun

Description

@caixiaoshun

详细描述

问题描述

在使用本仓库官方提供的 iSAID 数据集处理脚本(tools/dataset_converters/isaid.py)时,发现生成的标签掩码(mask)PNG 文件只有像素值 [0, 1],而不是应该有的全部类别(例如 iSAID 应该是 0-15)。这样生成的数据标签无法用于语义分割的训练和评估。


复现步骤

  1. 按照 README 指南下载 iSAID 数据集。

  2. 运行官方数据转换脚本,例如:

    python tools/dataset_converters/isaid.py /path/to/iSAID
  3. 用 numpy 检查任意输出的 label mask(PNG):

    import numpy as np
    from PIL import Image
    mask = np.array(Image.open('你的label_mask.png'))
    print(np.unique(mask))
  4. 实际输出只有 [0, 1],没有所有类别ID。


问题分析

  • 官方脚本保存 label mask 时,使用了 PIL 的 'P'(调色板)模式,但没有显式设置调色板
  • PIL 会自动重新映射/压缩调色板索引,导致绝大多数类别信息丢失,只剩下 [0, 1] 或极少数值。
  • 这种标签图像根本无法用于分割任务,类别信息严重缺失。

期望行为

  • 输出的标签掩码应保留全部类别像素值(如 iSAID 应为 0-15,UAVid 应为 0-7)。
  • 用 numpy 读取 PNG 时,可以看到所有类别索引。

建议修复

  • 用于训练和评测的 label mask 推荐使用 PIL 'L'(灰度)模式保存,例如:

    Image.fromarray(mask.astype(np.uint8), mode='L').save(...)
  • 如果确实需要彩色可视化,再用 'P' 模式,并在保存前用 putpalette() 明确指定调色板。


补充说明:多进程高效处理与论文结果复现

在发现并修复了标签保存方式的问题后,我改写了处理流程,采用 Python 的 ProcessPoolExecutor 实现了图像和标签的多进程高效切片。修复后采用 'L' 模式保存标签,类别信息完全保留。
使用该流程后,我能够完全复现原论文中 iSAID 各类别 IoU、mIoU、Acc 等分割结果。

主要修正代码片段举例

import argparse
import glob
import os
import os.path as osp
import shutil
import tempfile
import zipfile

import mmcv
import numpy as np
from mmengine.utils import ProgressBar, mkdir_or_exist
from PIL import Image
from concurrent.futures import ProcessPoolExecutor, as_completed

# RGB colour used in the colour-coded semantic masks for each class index.
# The position of a colour in the list below is its class index (0-15).
iSAID_palette = dict(enumerate([
    (0, 0, 0),
    (0, 0, 63),
    (0, 63, 63),
    (0, 63, 0),
    (0, 63, 127),
    (0, 63, 191),
    (0, 63, 255),
    (0, 127, 63),
    (0, 127, 127),
    (0, 0, 127),
    (0, 0, 191),
    (0, 0, 255),
    (0, 191, 127),
    (0, 127, 191),
    (0, 127, 255),
    (0, 100, 155),
]))

# Reverse lookup: RGB colour -> class index.
iSAID_invert_palette = dict(
    zip(iSAID_palette.values(), iSAID_palette.keys()))

def iSAID_convert_from_color(arr_3d, palette=iSAID_invert_palette):
    """Translate a colour-coded mask into a 2-D array of class indices.

    Args:
        arr_3d: Array indexed as (row, column, channel) holding three colour
            values per pixel.
        palette: Mapping from a colour triple to the class index written for
            pixels of that colour.

    Returns:
        A ``uint8`` array with the first two dimensions of ``arr_3d``. Pixels
        whose colour does not appear in ``palette`` keep the initial value 0.
    """
    height, width = arr_3d.shape[0], arr_3d.shape[1]
    labels = np.zeros((height, width), dtype=np.uint8)
    for color, class_id in palette.items():
        # A pixel matches only when all three channels equal the colour.
        target = np.array(color).reshape(1, 1, 3)
        matches = (arr_3d == target).all(axis=2)
        labels[matches] = class_id
    return labels

def _patch_origins(length, patch, stride):
    """Return the distinct patch start offsets along one axis.

    Windows advance by ``stride`` from 0. A window that would run past
    ``length`` is shifted back so that it ends exactly at ``length``; offsets
    that collapse onto the same position are reported once.

    Args:
        length: Size of the axis; must be at least ``patch``.
        patch: Window size along the axis.
        stride: Distance between consecutive window starts.

    Raises:
        ValueError: If ``stride`` is not positive (overlap >= patch size).
    """
    if stride <= 0:
        raise ValueError(
            'overlap must be smaller than the patch size '
            f'(got a stride of {stride})')
    last_start = length - patch
    origins = []
    for pos in range(0, length, stride):
        start = min(pos, last_start)
        # Starts are non-decreasing, so comparing with the last one is enough
        # to drop the duplicates produced once the window is clamped.
        if not origins or origins[-1] != start:
            origins.append(start)
    return origins


def _pad_to_patch(arr, patch_H, patch_W, pad_val):
    """Grow ``arr`` with ``mmcv.impad`` to at least ``patch_H`` x ``patch_W``.

    Each spatial dimension is handled independently, so an array that is too
    small in one dimension and exactly patch-sized in the other is padded too.
    Arrays that are already large enough are returned unchanged.
    """
    target = (max(arr.shape[0], patch_H), max(arr.shape[1], patch_W))
    if target == tuple(arr.shape[:2]):
        return arr
    return mmcv.impad(arr, shape=target, pad_val=pad_val)


def slide_crop_image(src_path, out_dir, mode, patch_H, patch_W, overlap):
    """Cut one image into overlapping patches and write them to disk.

    Patches are saved under ``<out_dir>/img_dir/<mode>/`` and named
    ``<stem>_<y0>_<y1>_<x0>_<x1>.png`` where ``stem`` is the source file name
    up to its first dot.

    Args:
        src_path: Path of the source image.
        out_dir: Root output directory.
        mode: Name of the sub-directory below ``img_dir``.
        patch_H: Patch height in pixels.
        patch_W: Patch width in pixels.
        overlap: Number of pixels shared by neighbouring patches.

    Raises:
        ValueError: If ``overlap`` is not smaller than both patch dimensions.
    """
    with Image.open(src_path) as src:
        img = np.asarray(src.convert('RGB'))
    # Images smaller than one patch are padded with zeros first.
    img = _pad_to_patch(img, patch_H, patch_W, pad_val=0)
    img_H, img_W = img.shape[:2]

    stem = osp.basename(src_path).split('.')[0]
    target_dir = osp.join(out_dir, 'img_dir', mode)
    y_origins = _patch_origins(img_H, patch_H, patch_H - overlap)
    for x_str in _patch_origins(img_W, patch_W, patch_W - overlap):
        x_end = x_str + patch_W
        for y_str in y_origins:
            y_end = y_str + patch_H
            img_patch = Image.fromarray(
                img[y_str:y_end, x_str:x_end, :].astype(np.uint8))
            name = f'{stem}_{y_str}_{y_end}_{x_str}_{x_end}.png'
            # NOTE(review): the data is encoded as BMP although the file name
            # ends in '.png'. Kept unchanged so the output stays compatible
            # with existing consumers - confirm whether this is intentional.
            img_patch.save(osp.join(target_dir, name), format='BMP')


def slide_crop_label(src_path, out_dir, mode, patch_H, patch_W, overlap):
    """Cut one colour-coded mask into class-index patches and write them.

    The mask is converted to class indices with ``iSAID_convert_from_color``
    and tiled with the same window layout as ``slide_crop_image``. Patches are
    saved as single-channel ('L' mode) images under
    ``<out_dir>/ann_dir/<mode>/`` and named
    ``<id>_<y0>_<y1>_<x0>_<x1>_instance_color_RGB.png`` where ``id`` is the
    source file name up to its first dot or underscore.

    Args:
        src_path: Path of the colour-coded mask.
        out_dir: Root output directory.
        mode: Name of the sub-directory below ``ann_dir``.
        patch_H: Patch height in pixels.
        patch_W: Patch width in pixels.
        overlap: Number of pixels shared by neighbouring patches.

    Raises:
        ValueError: If ``overlap`` is not smaller than both patch dimensions.
    """
    label = mmcv.imread(src_path, channel_order='rgb')
    label = iSAID_convert_from_color(label)
    # Padded pixels get 255 (presumably the ignore index used at training
    # time - confirm against the dataset config).
    label = _pad_to_patch(label, patch_H, patch_W, pad_val=255)
    img_H, img_W = label.shape

    stem = osp.basename(src_path).split('.')[0].split('_')[0]
    target_dir = osp.join(out_dir, 'ann_dir', mode)
    y_origins = _patch_origins(img_H, patch_H, patch_H - overlap)
    for x_str in _patch_origins(img_W, patch_W, patch_W - overlap):
        x_end = x_str + patch_W
        for y_str in y_origins:
            y_end = y_str + patch_H
            # 'L' mode stores the class indices verbatim as pixel values.
            lab_patch = Image.fromarray(
                label[y_str:y_end, x_str:x_end].astype(np.uint8), mode='L')
            name = (f'{stem}_{y_str}_{y_end}_{x_str}_{x_end}'
                    '_instance_color_RGB.png')
            lab_patch.save(osp.join(target_dir, name))

def process_image(args):
    """Worker entry point: slice the image described by one task tuple.

    Args:
        args: 6-tuple ``(img_path, out_dir, mode, patch_H, patch_W, overlap)``
            forwarded positionally to ``slide_crop_image``.

    Returns:
        The image path taken from the task tuple.
    """
    src, dst_root, split, tile_h, tile_w, tile_overlap = args
    slide_crop_image(src, dst_root, split, tile_h, tile_w, tile_overlap)
    return src

def process_label(args):
    """Worker entry point: slice the label mask described by one task tuple.

    Args:
        args: 6-tuple ``(lab_path, out_dir, mode, patch_H, patch_W, overlap)``
            forwarded positionally to ``slide_crop_label``.

    Returns:
        The label path taken from the task tuple.
    """
    src, dst_root, split, tile_h, tile_w, tile_overlap = args
    slide_crop_label(src, dst_root, split, tile_h, tile_w, tile_overlap)
    return src

def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` with ``dataset_path``, ``tmp_dir``,
        ``out_dir``, ``patch_width``, ``patch_height``, ``overlap_area``,
        ``mode`` (a list) and ``max_workers``.
    """
    parser = argparse.ArgumentParser(
        description='Convert iSAID dataset to mmsegmentation format')
    parser.add_argument('dataset_path', help='iSAID folder path')
    parser.add_argument('--tmp_dir', help='path of the temporary directory')
    parser.add_argument('-o', '--out_dir', help='output path')

    # Integer-valued patch geometry options: (flag, default, help text).
    geometry_options = (
        ('--patch_width', 896, 'Width of the cropped image patch'),
        ('--patch_height', 896, 'Height of the cropped image patch'),
        ('--overlap_area', 384, 'Overlap area'),
    )
    for flag, default_value, help_text in geometry_options:
        parser.add_argument(
            flag, default=default_value, type=int, help=help_text)

    parser.add_argument(
        '--mode',
        nargs='+',
        default=['val'],
        help='Which mode to process (default: val)')
    parser.add_argument(
        '--max_workers',
        default=16,
        type=int,
        help='Max worker processes')
    return parser.parse_args()

def _run_in_pool(worker, tasks, max_workers):
    """Run ``worker`` on every task in a process pool with a progress bar.

    Each future's result is retrieved so that an exception raised inside a
    worker process is re-raised here instead of being silently dropped.
    """
    prog_bar = ProgressBar(len(tasks))
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(worker, task) for task in tasks]
        for future in as_completed(futures):
            future.result()
            prog_bar.update()


def _extract_zips(zip_paths, dest_dir):
    """Extract every archive in ``zip_paths`` into ``dest_dir``."""
    for zip_path in zip_paths:
        with zipfile.ZipFile(zip_path) as zip_file:
            zip_file.extractall(dest_dir)


def main():
    """Convert the zipped iSAID dataset into sliced image/label patches.

    For every requested mode the image archives below
    ``<dataset_path>/<mode>/images`` are extracted to a temporary directory
    and sliced in parallel; for every mode except ``'test'`` the archives
    below ``<dataset_path>/<mode>/Semantic_masks`` are processed the same way.

    Raises:
        ValueError: If the overlap is not smaller than the patch size.
        FileNotFoundError: If a requested mode directory does not exist.
    """
    args = parse_args()
    dataset_path = args.dataset_path
    patch_H, patch_W = args.patch_height, args.patch_width
    overlap = args.overlap_area
    modes = args.mode
    max_workers = args.max_workers

    out_dir = osp.join('data', 'iSAID') if args.out_dir is None \
        else args.out_dir

    # Validate the inputs before touching the file system.
    if overlap >= min(patch_H, patch_W):
        raise ValueError(
            f'overlap_area ({overlap}) must be smaller than the patch size '
            f'({patch_H}x{patch_W})')
    for mode in modes:
        if not osp.exists(osp.join(dataset_path, mode)):
            raise FileNotFoundError(f'{mode} is not in {dataset_path}')

    print('Making directories...')
    for mode in modes:
        mkdir_or_exist(osp.join(out_dir, 'img_dir', mode))
        mkdir_or_exist(osp.join(out_dir, 'ann_dir', mode))

    with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
        for dataset_mode in modes:
            print(f'Extracting {dataset_mode} zip files...')
            img_zipp_list = glob.glob(
                os.path.join(dataset_path, dataset_mode, 'images', '*.zip'))
            print('Find image zips:', img_zipp_list)
            _extract_zips(img_zipp_list,
                          os.path.join(tmp_dir, dataset_mode, 'img'))
            src_path_list = glob.glob(
                os.path.join(tmp_dir, dataset_mode, 'img', 'images', '*.png'))

            # Slice the images in parallel worker processes.
            print(f'Start slicing {dataset_mode} images with {max_workers} workers...')
            src_tasks = [
                (img_path, out_dir, dataset_mode, patch_H, patch_W, overlap)
                for img_path in src_path_list
            ]
            _run_in_pool(process_image, src_tasks, max_workers)
            print(f'{dataset_mode} image slicing finished!')

            if dataset_mode != 'test':
                label_zipp_list = glob.glob(
                    os.path.join(dataset_path, dataset_mode, 'Semantic_masks', '*.zip'))
                print('Find label zips:', label_zipp_list)
                _extract_zips(label_zipp_list,
                              os.path.join(tmp_dir, dataset_mode, 'lab'))
                lab_path_list = glob.glob(
                    os.path.join(tmp_dir, dataset_mode, 'lab', 'images', '*.png'))

                # Slice the labels in parallel worker processes.
                print(f'Start slicing {dataset_mode} labels with {max_workers} workers...')
                lab_tasks = [
                    (lab_path, out_dir, dataset_mode, patch_H, patch_W, overlap)
                    for lab_path in lab_path_list
                ]
                _run_in_pool(process_label, lab_tasks, max_workers)
                print(f'{dataset_mode} label slicing finished!')

        # The temporary directory is deleted when the ``with`` block exits.
        print('Removing the temporary files...')
    print('Done!')

# Script entry point: run the conversion only when this file is executed
# directly, not when it is imported.
if __name__ == '__main__':
    main()

验证结果示例(iSAID 验证集):

+--------------------+-------+-------+
|       类别         |  IoU  |  Acc  |
+--------------------+-------+-------+
|     background     | 92.46 | 94.05 |
|        ship        |  15.0 | 15.81 |
|     store_tank     |  0.71 |  0.71 |
|  baseball_diamond  | 35.91 |  70.0 |
|    tennis_court    | 46.91 | 70.09 |
|  basketball_court  | 24.37 | 30.15 |
| Ground_Track_Field |  2.81 |  5.86 |
|       Bridge       |  3.2  | 39.18 |
|   Large_Vehicle    | 19.74 | 30.27 |
|   Small_Vehicle    |  8.46 |  9.3  |
|     Helicopter     | 17.15 | 35.47 |
|   Swimming_pool    |  6.37 | 50.19 |
|     Roundabout     |  1.36 | 68.22 |
| Soccer_ball_field  | 39.91 | 43.09 |
|       plane        | 29.19 | 80.32 |
|       Harbor       |  3.07 |  25.2 |
+--------------------+-------+-------+
aAcc: 92.29  mIoU: 21.66  mAcc: 41.75

与原论文结果高度一致。

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions