Grayscale Image Support in Real-ESRGAN (1 Channel) Needs Manual Fixes #744

@mhach06

Problem

The RealESRGANPairedDataset in Real-ESRGAN currently does not support grayscale images (1 input/output channel) out of the box.

Even after modifying the num_in_ch and num_out_ch values in the .yml config file (network_g and network_d), the code fails when working with grayscale datasets.

The root cause is the img2tensor() function in BasicSR/basicsr/utils/img_util.py, which offers no way to produce a single-channel tensor.


Note

The fix proposed below only works if both num_in_ch and num_out_ch are set to 1.
If only one of them is 1, the code will crash, which is acceptable for most grayscale workflows.
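
If the crash is not acceptable, one option is to check the parsed options before training starts so the failure comes with a readable message. The sketch below is not part of Real-ESRGAN or BasicSR: check_grayscale_channels is a hypothetical helper, and opt is assumed to be the dict parsed from the .yml file.

```python
def check_grayscale_channels(opt):
    """Return True for a grayscale setup, False for a colour one.

    Hypothetical helper (not part of BasicSR): raises ValueError when
    only one of num_in_ch / num_out_ch in network_g is set to 1.
    """
    net_g = opt['network_g']
    in_ch, out_ch = net_g['num_in_ch'], net_g['num_out_ch']
    # Exactly one of the two being 1 is the mixed case the note warns about.
    if (in_ch == 1) != (out_ch == 1):
        raise ValueError(
            'Grayscale training needs num_in_ch == num_out_ch == 1, '
            f'got num_in_ch={in_ch}, num_out_ch={out_ch}')
    return in_ch == 1
```

Calling it once on the parsed options would turn a shape error deep in the training loop into an immediate configuration error.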


Proposed Fix

The following changes (3 files) allow Real-ESRGAN to train/validate on grayscale datasets without errors.


Change 1: BasicSR/basicsr/utils/img_util.py

-def img2tensor(imgs, bgr2rgb=True, float32=True):
+def img2tensor(imgs, bgr2rgb=True, float32=True, grayscale=False):

-    def _totensor(img, bgr2rgb, float32):
+    def _totensor(img, bgr2rgb, float32, grayscale):
         if img.ndim == 3 and img.shape[2] == 3:
             if img.dtype == 'float64':
                 img = img.astype('float32')
             if bgr2rgb:
                 img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
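+            if grayscale:
+                # The image is still BGR when bgr2rgb=False, so pick the matching code
+                code = cv2.COLOR_RGB2GRAY if bgr2rgb else cv2.COLOR_BGR2GRAY
+                img = cv2.cvtColor(img, code)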
+
+        if img.ndim == 2:
+            img = img[:, :, None]
 
-    if isinstance(imgs, list):
-        return [_totensor(img, bgr2rgb, float32) for img in imgs]
-    else:
-        return _totensor(imgs, bgr2rgb, float32)
+    if isinstance(imgs, list):
+        return [_totensor(img, bgr2rgb, float32, grayscale) for img in imgs]
+    else:
+        return _totensor(imgs, bgr2rgb, float32, grayscale)
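
To see what the shape handling in this change does, here is a minimal NumPy-only sketch. It is an illustration, not the BasicSR code: to_chw_gray is a hypothetical name, a weighted luma sum stands in for cv2.cvtColor, and the result stays a NumPy array instead of becoming a torch tensor.

```python
import numpy as np


def to_chw_gray(img, bgr=True):
    """Sketch of the patched shape handling (NumPy only, no cv2/torch)."""
    if img.ndim == 3 and img.shape[2] == 3:
        # Luma weights 0.299 R + 0.587 G + 0.114 B, ordered to match the channel layout.
        order = [0.114, 0.587, 0.299] if bgr else [0.299, 0.587, 0.114]
        weights = np.array(order, dtype=np.float32)
        img = img.astype(np.float32) @ weights  # (H, W, 3) -> (H, W)
    if img.ndim == 2:
        img = img[:, :, None]                   # (H, W) -> (H, W, 1)
    return np.ascontiguousarray(img.transpose(2, 0, 1))  # HWC -> CHW


color = np.zeros((4, 5, 3), dtype=np.float32)
gray = np.zeros((4, 5), dtype=np.float32)
print(to_chw_gray(color).shape, to_chw_gray(gray).shape)
```

Both a 3-channel image and an image that is already 2-D end up as a (1, H, W) array, which is the layout a network configured with num_in_ch = 1 expects.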

Change 2: BasicSR/basicsr/data/realesrgan_paired_dataset.py

-        # BGR to RGB, HWC to CHW, numpy to tensor
-        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=False, float32=True)
+        # BGR to RGB, HWC to CHW, numpy to tensor (grayscale if num_in_ch == 1)
+        grayscale = self.opt['network_g']['num_in_ch'] == 1
+        img_gt, img_lq = img2tensor([img_gt, img_lq],
+                                    bgr2rgb=False,
+                                    float32=True,
+                                    grayscale=grayscale)

Change 3: BasicSR/basicsr/train.py

-    for phase, dataset_opt in opt['datasets'].items():
-        if phase == 'train':
+    for phase, dataset_opt in opt['datasets'].items():
+        dataset_opt['network_g'] = opt['network_g']
+        if phase == 'train':

Result

After applying these changes:

  • Grayscale datasets (1-channel input/output) work without errors.
  • Training and validation proceed as expected.

I can open a PR if this seems useful.
