diff --git a/src/dataset/__init__.py b/src/dataset/__init__.py index 1ef1229..88895ac 100644 --- a/src/dataset/__init__.py +++ b/src/dataset/__init__.py @@ -86,7 +86,7 @@ def get_dataset( List[BaseNormalsDataset], ]: if "mixed" == cfg_data_split.name: - assert DatasetMode.TRAIN == mode, "Only training mode supports mixed datasets." + assert DatasetMode.TRAIN.value == mode.value, "Only training mode supports mixed datasets." dataset_ls = [ get_dataset(_cfg, base_data_dir, mode, **kwargs) for _cfg in cfg_data_split.dataset_list diff --git a/src/dataset/base_depth_dataset.py b/src/dataset/base_depth_dataset.py index 5b7a997..de1c39f 100644 --- a/src/dataset/base_depth_dataset.py +++ b/src/dataset/base_depth_dataset.py @@ -116,7 +116,7 @@ def __len__(self): def __getitem__(self, index): rasters, other = self._get_data_item(index) - if DatasetMode.TRAIN == self.mode: + if DatasetMode.TRAIN.value == self.mode.value: rasters = self._training_preprocess(rasters) # merge outputs = rasters diff --git a/src/dataset/base_iid_dataset.py b/src/dataset/base_iid_dataset.py index a7c302f..4eca8f3 100644 --- a/src/dataset/base_iid_dataset.py +++ b/src/dataset/base_iid_dataset.py @@ -99,7 +99,7 @@ def __len__(self): def __getitem__(self, index): rasters, other = self._get_data_item(index) - if DatasetMode.TRAIN == self.mode: + if DatasetMode.TRAIN.value == self.mode.value: rasters = self._training_preprocess(rasters) # merge outputs = rasters diff --git a/src/dataset/base_normals_dataset.py b/src/dataset/base_normals_dataset.py index 6b5cc36..dbda493 100644 --- a/src/dataset/base_normals_dataset.py +++ b/src/dataset/base_normals_dataset.py @@ -92,7 +92,7 @@ def __len__(self): def __getitem__(self, index): rasters, other = self._get_data_item(index) - if DatasetMode.TRAIN == self.mode: + if DatasetMode.TRAIN.value == self.mode.value: rasters = self._training_preprocess(rasters) # merge outputs = rasters