Problem
The RealESRGANPairedDataset in Real-ESRGAN currently does not support grayscale images (1 input/output channel) out of the box.
Even after modifying the num_in_ch and num_out_ch values in the .yml config file (network_g and network_d), the code fails when working with grayscale datasets.
The root cause is the img2tensor() function in BasicSR/basicsr/utils/img_util.py.
Note
The fix proposed below only works when both num_in_ch and num_out_ch are set to 1.
If only one of the two is 1, training will still crash. This is acceptable for most grayscale workflows, which use a single channel for both input and output.
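The constraint in the note can be checked before training starts, so that a mixed configuration fails with a clear message instead of a shape error deep in the training loop. The following is a minimal sketch only: `check_grayscale_opts` is a hypothetical helper that does not exist in BasicSR or Real-ESRGAN, and it assumes the parsed .yml options are a plain dict with a `network_g` entry.

```python
def check_grayscale_opts(opt):
    """Return True for a 1-channel setup, False for a multi-channel one.

    Hypothetical helper (not part of BasicSR or Real-ESRGAN).
    Raises ValueError when only one of num_in_ch / num_out_ch is 1,
    which is the mixed case the note describes as unsupported.
    """
    net_g = opt['network_g']
    num_in_ch = net_g['num_in_ch']
    num_out_ch = net_g['num_out_ch']
    if (num_in_ch == 1) != (num_out_ch == 1):
        raise ValueError(
            f'num_in_ch={num_in_ch} and num_out_ch={num_out_ch}: '
            'set both to 1 for grayscale, or neither.')
    return num_in_ch == 1
```

Such a check could run once right after the options are parsed; its return value is the same flag that Change 2 derives from num_in_ch.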
Proposed Fix
The following changes (3 files) allow Real-ESRGAN to train/validate on grayscale datasets without errors.
Change 1: BasicSR/basicsr/utils/img_util.py
-def img2tensor(imgs, bgr2rgb=True, float32=True):
+def img2tensor(imgs, bgr2rgb=True, float32=True, grayscale=False):
-    def _totensor(img, bgr2rgb, float32):
+    def _totensor(img, bgr2rgb, float32, grayscale):
         if img.ndim == 3 and img.shape[2] == 3:
             if img.dtype == 'float64':
                 img = img.astype('float32')
             if bgr2rgb:
                 img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+            if grayscale:
+                # Channels are in RGB order only if bgr2rgb was applied above
+                code = cv2.COLOR_RGB2GRAY if bgr2rgb else cv2.COLOR_BGR2GRAY
+                img = cv2.cvtColor(img, code)
+
+        if img.ndim == 2:
+            img = img[:, :, None]
-    if isinstance(imgs, list):
-        return [_totensor(img, bgr2rgb, float32) for img in imgs]
-    else:
-        return _totensor(imgs, bgr2rgb, float32)
+    if isinstance(imgs, list):
+        return [_totensor(img, bgr2rgb, float32, grayscale) for img in imgs]
+    else:
+        return _totensor(imgs, bgr2rgb, float32, grayscale)
Change 2: BasicSR/basicsr/data/realesrgan_paired_dataset.py
-        # BGR to RGB, HWC to CHW, numpy to tensor
-        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=False, float32=True)
+        # BGR to RGB, HWC to CHW, numpy to tensor (grayscale if num_in_ch == 1)
+        grayscale = self.opt['network_g']['num_in_ch'] == 1
+        img_gt, img_lq = img2tensor([img_gt, img_lq],
+                                    bgr2rgb=False,
+                                    float32=True,
+                                    grayscale=grayscale)
Change 3: BasicSR/basicsr/train.py
-    for phase, dataset_opt in opt['datasets'].items():
-        if phase == 'train':
+    for phase, dataset_opt in opt['datasets'].items():
+        # Expose the generator options to the dataset so it can read num_in_ch
+        dataset_opt['network_g'] = opt['network_g']
+        if phase == 'train':
Result
After applying these changes:
- Grayscale datasets (1-channel input/output) work without errors.
- Training and validation proceed as expected.
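The shape handling that Change 1 adds can be illustrated without OpenCV or PyTorch. The sketch below is an assumption-laden stand-in, not the patched BasicSR code: `to_chw_gray` is a hypothetical name, and the weighted sum uses the BT.601 luma weights that OpenCV documents for its RGB-to-gray conversion in place of `cv2.cvtColor`.

```python
import numpy as np

# BT.601 luma weights in R, G, B order, as documented for OpenCV's
# RGB/BGR-to-gray conversion. Used here as a stand-in for cv2.cvtColor.
_LUMA_RGB = np.array([0.299, 0.587, 0.114], dtype=np.float32)


def to_chw_gray(img, is_bgr=True):
    """Convert an HxW or HxWx3 image to a 1xHxW float32 array.

    Hypothetical helper for illustration only.
    """
    img = np.asarray(img, dtype=np.float32)
    if img.ndim == 3 and img.shape[2] == 3:
        # Reverse the weights when the channels are stored as B, G, R
        weights = _LUMA_RGB[::-1] if is_bgr else _LUMA_RGB
        img = img @ weights  # HxWx3 -> HxW
    if img.ndim == 2:
        img = img[:, :, None]  # HxW -> HxWx1
    return np.ascontiguousarray(img.transpose(2, 0, 1))  # HWC -> CHW
```

Both a 3-channel image and an already 2-D image end up with a leading channel axis of size 1, which is the layout a network configured with num_in_ch: 1 expects.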
I can open a PR if this seems useful.