diff --git a/datasets/README.md b/datasets/README.md
index c662f8e2..cb83a8db 100644
--- a/datasets/README.md
+++ b/datasets/README.md
@@ -9,7 +9,8 @@
 .
 ├── TransUNet
 │   ├──datasets
-│   │    └── dataset_*.py
+│   │    ├── dataset_*.py
+│   │    └── preprocess_data.py
 │   ├──train.py
 │   ├──test.py
 │   └──...
diff --git a/datasets/preprocess_data.py b/datasets/preprocess_data.py
new file mode 100644
index 00000000..7e5b2138
--- /dev/null
+++ b/datasets/preprocess_data.py
@@ -0,0 +1,104 @@
+import os
+import time
+import argparse
+from glob import glob
+from typing import Sequence
+
+import h5py
+import nibabel as nib
+import numpy as np
+from tqdm import tqdm
+
+
+# Intensity window applied to each image volume before optional rescaling.
+A_MIN, A_MAX = -125, 275
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--src_path', type=str,
+                    default='../data/Abdomen/RawData', help='download path for Synapse data')
+parser.add_argument('--dst_path', type=str,
+                    default='../data/Synapse', help='root dir for data')
+# A store_true flag that defaults to True can never be switched off, so
+# --no_normalize is paired with it to make the option controllable.
+parser.add_argument('--use_normalize', dest='use_normalize', action='store_true', default=True,
+                    help='use normalize')
+parser.add_argument('--no_normalize', dest='use_normalize', action='store_false',
+                    help='skip normalize')
+args = parser.parse_args()
+
+
+def _check_pairs(image_files: Sequence[str], label_files: Sequence[str]) -> None:
+    """Raise ValueError unless images and labels can be paired one-to-one."""
+    if len(image_files) != len(label_files):
+        raise ValueError(
+            f"found {len(image_files)} images but {len(label_files)} labels")
+
+
+def _case_number(image_file: str) -> str:
+    """Return XXXX from a path whose file name is imgXXXX.nii.gz."""
+    # basename() honours the platform path separator, unlike split('/').
+    return os.path.basename(image_file)[3:7]
+
+
+def _load_case(image_file: str, label_file: str):
+    """Load one image/label pair, window the image and reverse the axis order."""
+    image_data = nib.load(image_file).get_fdata()
+    label_data = nib.load(label_file).get_fdata()
+
+    image_data = np.clip(image_data, A_MIN, A_MAX)
+    if args.use_normalize:
+        image_data = (image_data - A_MIN) / (A_MAX - A_MIN)
+
+    # (2, 1, 0) reverses the axis order, so the last input axis is indexed first.
+    image_data = np.transpose(image_data, (2, 1, 0))
+    label_data = np.transpose(label_data, (2, 1, 0))
+    return image_data, label_data
+
+
+def preprocess_train_image(image_files: Sequence[str], label_files: Sequence[str]) -> None:
+    """Save every slice of every case as an .npz file under dst_path/train_npz.
+
+    Each file stores the 2-D ``image`` and ``label`` arrays of one slice.
+    """
+    _check_pairs(image_files, label_files)
+    os.makedirs(f"{args.dst_path}/train_npz", exist_ok=True)
+
+    for image_file, label_file in tqdm(zip(image_files, label_files), total=len(image_files)):
+        number = _case_number(image_file)
+        image_data, label_data = _load_case(image_file, label_file)
+
+        for dep in range(image_data.shape[0]):
+            save_path = f"{args.dst_path}/train_npz/case{number}_slice{dep:03d}.npz"
+            np.savez(save_path, label=label_data[dep], image=image_data[dep])
+
+
+def preprocess_valid_image(image_files: Sequence[str], label_files: Sequence[str]) -> None:
+    """Save each whole case as an HDF5 file under dst_path/test_vol_h5.
+
+    Each file stores the 3-D ``image`` and ``label`` datasets of one case.
+    """
+    _check_pairs(image_files, label_files)
+    os.makedirs(f"{args.dst_path}/test_vol_h5", exist_ok=True)
+
+    for image_file, label_file in tqdm(zip(image_files, label_files), total=len(image_files)):
+        number = _case_number(image_file)
+        image_data, label_data = _load_case(image_file, label_file)
+
+        save_path = f"{args.dst_path}/test_vol_h5/case{number}.npy.h5"
+        # The context manager closes the file even when a write fails.
+        with h5py.File(save_path, 'w') as f:
+            f['image'] = image_data
+            f['label'] = label_data
+
+
+if __name__ == "__main__":
+    data_root = f"{args.src_path}/Training"
+
+    # String sort; pairing assumes both folders sort into the same case order.
+    image_files = sorted(glob(f"{data_root}/img/*.nii.gz"))
+    label_files = sorted(glob(f"{data_root}/label/*.nii.gz"))
+
+    # NOTE(review): both outputs are built from the same Training cases;
+    # presumably the list files choose the cases of each split - confirm.
+    preprocess_train_image(image_files, label_files)
+    preprocess_valid_image(image_files, label_files)
diff --git a/requirements.txt b/requirements.txt
index 4abfe422..13806154 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
-torch==1.4.0
-torchvision==0.5.0
+torch>=1.4.0
+torchvision>=0.5.0
 numpy
 tqdm
 tensorboard
@@ -9,3 +9,4 @@ medpy
 SimpleITK
 scipy
 h5py
+nibabel
diff --git a/test.py b/test.py
index 35a48027..f70ff0cb 100644
--- a/test.py
+++ b/test.py
@@ -20,7 +20,7 @@
 parser.add_argument('--dataset', type=str,
                     default='Synapse', help='experiment_name')
 parser.add_argument('--num_classes', type=int,
-                    default=4, help='output channel of network')
+                    default=14, help='output channel of network')
 parser.add_argument('--list_dir', type=str,
                     default='./lists/lists_Synapse', help='list dir')
 
@@ -51,8 +51,7 @@ def inference(args, model, test_save_path=None):
     for i_batch, sampled_batch in tqdm(enumerate(testloader)):
         h, w = sampled_batch["image"].size()[2:]
         image, label, case_name = sampled_batch["image"], sampled_batch["label"], sampled_batch['case_name'][0]
-        metric_i = test_single_volume(image, label, model, classes=args.num_classes, patch_size=[args.img_size, args.img_size],
-                                      test_save_path=test_save_path, case=case_name, z_spacing=args.z_spacing)
+        metric_i = test_single_volume(image, label, model, classes=args.num_classes, patch_size=[args.img_size, args.img_size], test_save_path=test_save_path, case=case_name, z_spacing=args.z_spacing)
         metric_list += np.array(metric_i)
         logging.info('idx %d case %s mean_dice %f mean_hd95 %f' % (i_batch, case_name, np.mean(metric_i, axis=0)[0], np.mean(metric_i, axis=0)[1]))
     metric_list = metric_list / len(db_test)
@@ -80,9 +79,9 @@ def inference(args, model, test_save_path=None):
     dataset_config = {
         'Synapse': {
             'Dataset': Synapse_dataset,
-            'volume_path': '../data/Synapse/test_vol_h5',
-            'list_dir': './lists/lists_Synapse',
-            'num_classes': 9,
+            'volume_path': args.volume_path,
+            'list_dir': args.list_dir,
+            'num_classes': args.num_classes,
             'z_spacing': 1,
         },
     }
diff --git a/train.py b/train.py
index 438dc76b..41125bad 100644
--- a/train.py
+++ b/train.py
@@ -17,7 +17,7 @@
 parser.add_argument('--list_dir', type=str,
                     default='./lists/lists_Synapse', help='list dir')
 parser.add_argument('--num_classes', type=int,
-                    default=9, help='output channel of network')
+                    default=14, help='output channel of network')
 parser.add_argument('--max_iterations', type=int,
                     default=30000, help='maximum epoch number to train')
 parser.add_argument('--max_epochs', type=int,
@@ -57,9 +57,9 @@
     dataset_name = args.dataset
     dataset_config = {
         'Synapse': {
-            'root_path': '../data/Synapse/train_npz',
-            'list_dir': './lists/lists_Synapse',
-            'num_classes': 9,
+            'root_path': args.root_path,
+            'list_dir': args.list_dir,
+            'num_classes': args.num_classes,
         },
     }
     args.num_classes = dataset_config[dataset_name]['num_classes']
diff --git a/utils.py b/utils.py
index 0e3a1bf9..d33dad6f 100644
--- a/utils.py
+++ b/utils.py
@@ -79,8 +79,7 @@ def test_single_volume(image, label, net, classes, patch_size=[256, 256], test_s
                     pred = out
                 prediction[ind] = pred
     else:
-        input = torch.from_numpy(image).unsqueeze(
-            0).unsqueeze(0).float().cuda()
+        input = torch.from_numpy(image).unsqueeze(0).unsqueeze(0).float().cuda()
         net.eval()
         with torch.no_grad():
             out = torch.argmax(torch.softmax(net(input), dim=1), dim=1).squeeze(0)