From 17f21c9c1422d12bda5f9906d726478dfd870ae3 Mon Sep 17 00:00:00 2001 From: Prakhar Kulshreshtha Date: Wed, 3 Dec 2025 14:33:04 -0800 Subject: [PATCH] compare the value of the enum instead of enum directly --- src/dataset/__init__.py | 2 +- src/dataset/base_depth_dataset.py | 2 +- src/dataset/base_iid_dataset.py | 2 +- src/dataset/base_normals_dataset.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/dataset/__init__.py b/src/dataset/__init__.py index 1ef1229..88895ac 100644 --- a/src/dataset/__init__.py +++ b/src/dataset/__init__.py @@ -86,7 +86,7 @@ def get_dataset( List[BaseNormalsDataset], ]: if "mixed" == cfg_data_split.name: - assert DatasetMode.TRAIN == mode, "Only training mode supports mixed datasets." + assert DatasetMode.TRAIN.value == mode.value, "Only training mode supports mixed datasets." dataset_ls = [ get_dataset(_cfg, base_data_dir, mode, **kwargs) for _cfg in cfg_data_split.dataset_list diff --git a/src/dataset/base_depth_dataset.py b/src/dataset/base_depth_dataset.py index 5b7a997..de1c39f 100644 --- a/src/dataset/base_depth_dataset.py +++ b/src/dataset/base_depth_dataset.py @@ -116,7 +116,7 @@ def __len__(self): def __getitem__(self, index): rasters, other = self._get_data_item(index) - if DatasetMode.TRAIN == self.mode: + if DatasetMode.TRAIN.value == self.mode.value: rasters = self._training_preprocess(rasters) # merge outputs = rasters diff --git a/src/dataset/base_iid_dataset.py b/src/dataset/base_iid_dataset.py index a7c302f..4eca8f3 100644 --- a/src/dataset/base_iid_dataset.py +++ b/src/dataset/base_iid_dataset.py @@ -99,7 +99,7 @@ def __len__(self): def __getitem__(self, index): rasters, other = self._get_data_item(index) - if DatasetMode.TRAIN == self.mode: + if DatasetMode.TRAIN.value == self.mode.value: rasters = self._training_preprocess(rasters) # merge outputs = rasters diff --git a/src/dataset/base_normals_dataset.py b/src/dataset/base_normals_dataset.py index 6b5cc36..dbda493 100644 --- a/src/dataset/base_normals_dataset.py +++ b/src/dataset/base_normals_dataset.py @@ -92,7 +92,7 @@ def __len__(self): def __getitem__(self, index): rasters, other = self._get_data_item(index) - if DatasetMode.TRAIN == self.mode: + if DatasetMode.TRAIN.value == self.mode.value: rasters = self._training_preprocess(rasters) # merge outputs = rasters