From 45f23397c94ebf7e5ed0b26ad860387915576e3f Mon Sep 17 00:00:00 2001 From: Laure Ciernik <64899273+lciernik@users.noreply.github.com> Date: Mon, 11 Aug 2025 11:09:49 +0200 Subject: [PATCH 1/4] fixed some bug --- thingsvision/core/extraction/base.py | 82 ++++++++++++++++----------- thingsvision/core/extraction/torch.py | 76 +++++++++++++++++++++---- 2 files changed, 115 insertions(+), 43 deletions(-) diff --git a/thingsvision/core/extraction/base.py b/thingsvision/core/extraction/base.py index 6b5db03..d8ba337 100644 --- a/thingsvision/core/extraction/base.py +++ b/thingsvision/core/extraction/base.py @@ -164,6 +164,14 @@ def _module_and_output_check( output_type in self.get_output_types() ), f"\nData type of output feature matrix must be set to one of the following available data types: {self.get_output_types()}\n" + def _save_features(self, features, features_file, extension): + if extension == "npy": + np.save(features_file, features) + elif extension == "pt": + torch.save(features, features_file) + else: + raise ValueError(f"Invalid extension: {extension}") + def extract_features( self, batches: Iterator[Union[TensorType["b", "c", "h", "w"], Array]], @@ -280,31 +288,35 @@ def extract_features( features[module_name].append(modules_features[module_name]) if output_dir and (i % step_size == 0 or i == len(batches)): + curr_output_dir = os.path.join(output_dir, module_name) + if not os.path.exists(curr_output_dir): + print(f"Creating output directory: {curr_output_dir}") + os.makedirs(curr_output_dir) + if self.get_backend() == "pt": features_subset = torch.cat(features[module_name]) if output_type == "ndarray": features_subset = self._to_numpy(features_subset) - features_subset_file = os.path.join( - output_dir, - f"{module_name}/features{file_name_suffix}_{last_image_ct}-{image_ct}.npy", - ) - np.save(features_subset_file, features_subset) - else: # output_type = tensor - features_subset_file = os.path.join( - output_dir, - f"{module_name}/features{file_name_suffix}_{last_image_ct}-{image_ct}.pt", - ) - torch.save(features_subset, features_subset_file) + file_extension = "npy" + else: + file_extension = "pt" else: - features_subset_file = os.path.join( - output_dir, - f"{module_name}/features{file_name_suffix}_{last_image_ct}-{image_ct}.npy", - ) features_subset = np.vstack(features[module_name]) - np.save(features_subset_file, features_subset) - features = defaultdict(list) - last_image_ct = image_ct + file_extension = "npy" + + features_subset_file = os.path.join( + curr_output_dir, + f"features{file_name_suffix}_{last_image_ct}-{image_ct}.{file_extension}", + ) + self._save_features( + features_subset, features_subset_file, file_extension + ) + + # Note: we add full file paths to feature_file_names to be able to load the features later feature_file_names[module_name].append(features_subset_file) + features[module_name] = [] + last_image_ct = image_ct + print( f"...Features successfully extracted for all {image_ct} images in the database." ) @@ -316,29 +328,31 @@ def extract_features( features = [] for file in feature_file_names[module_name]: if self.get_backend() == "pt" and output_type != "ndarray": - if file.endswith(".pt"): - features.append( - torch.load(os.path.join(output_dir, file)) - ) + features.append(torch.load(file)) + elif file.endswith(".npy"): + features.append(np.load(file)) else: - if file.endswith(".npy"): - features.append( - np.load(os.path.join(output_dir, file)) - ) + raise ValueError( + f"Invalid or unsupported file extension: {file}" + ) + features_file = os.path.join( output_dir, f"{module_name}/features{file_name_suffix}" ) if output_type == "ndarray": - np.save(f"{features_file}.npy", np.concatenate(features)) - else: # output_type = tensor - torch.save(torch.cat(features), f"{features_file}.pt") + self._save_features( + np.concatenate(features), features_file + ".npy", "npy" + ) + else: + self._save_features( + torch.cat(features), features_file + ".pt", "pt" + ) print( f"...Features for module '{module_name}' were saved to {features_file}." ) - # remove temporary files for file in feature_file_names[module_name]: os.remove(os.path.join(output_dir, file)) - + print(f"...Features were saved to {output_dir}.") return None else: @@ -349,9 +363,11 @@ def extract_features( features[module_name] = self._to_numpy(features[module_name]) else: features[module_name] = np.vstack(features[module_name]) - print(f"...Features shape: {features[module_name].shape}") + print( + f"...Features for module '{module_name}' have shape: {features[module_name].shape}" + ) + if single_module_call: - # for backward compatibility return features[module_name] return features diff --git a/thingsvision/core/extraction/torch.py b/thingsvision/core/extraction/torch.py index c9e0bcd..15c9cc7 100644 --- a/thingsvision/core/extraction/torch.py +++ b/thingsvision/core/extraction/torch.py @@ -7,6 +7,7 @@ from torchvision import transforms as T import torch +from torch.utils.data import DataLoader from .base import BaseExtractor @@ -101,7 +102,10 @@ def batch_extraction( ) -> object: """Allows mini-batch extraction for custom data pipeline using a with-statement.""" return BatchExtraction( - extractor=self, module_name=module_name, module_names=module_names, output_type=output_type + extractor=self, + module_name=module_name, + module_names=module_names, + output_type=output_type, ) def extract_batch( @@ -206,6 +210,8 @@ def extract_features( output_type: str = "ndarray", output_dir: Optional[str] = None, step_size: Optional[int] = None, + file_name_suffix: str = "", + save_in_one_file: bool = False, ): if not bool(module_name) ^ bool(module_names): raise ValueError( @@ -217,15 +223,19 @@ def extract_features( self._register_hooks(module_names=[module_name]) else: self._register_hooks(module_names=module_names) - features = super().extract_features( - batches=batches, - module_name=module_name, - module_names=module_names, - flatten_acts=flatten_acts, - output_type=output_type, - output_dir=output_dir, - step_size=step_size, - ) + + with ImageOnlyDataloaderModifier(dataloader=batches) as dataloader: + features = super().extract_features( + batches=dataloader, + module_name=module_name, + module_names=module_names, + flatten_acts=flatten_acts, + output_type=output_type, + output_dir=output_dir, + step_size=step_size, + file_name_suffix=file_name_suffix, + save_in_one_file=save_in_one_file, + ) self._unregister_hooks() return features @@ -369,3 +379,49 @@ def __exit__(self, *args): delattr(self.extractor, "module_name") delattr(self.extractor, "module_names") delattr(self.extractor, "output_type") + + +class ImageOnlyDataloaderModifier: + """Class to temporarily replace the collate function of a dataloader with images only collate. + This is useful when we want to use the dataloader for feature extraction with thingsvision. + Assumes that a dataloader either returns image or a tuple of (image, *args) or a list of (image, *args). + If the dataloader returns a tuple of (image, *args) or a list of (image, *args), it will return only the images. + The class does not modify otherwise (e.g., specialized dataloaders). + """ + + def __init__(self, dataloader: DataLoader) -> None: + self.dataloader = dataloader + self.new_collate_fn = self._images_only_collate + self.original_collate_fn = None + self.should_replace = False + + @staticmethod + def _images_only_collate(batch: list[tuple[torch.Tensor, ...]]) -> torch.Tensor: + """Collate function to return only the images from the batch. + This is useful when we want to use the dataloader for feature extraction with thingsvision. + """ + return torch.stack([item[0] for item in batch]) + + def _check_dataloader_format(self) -> bool: + sample_batch = next(iter(self.dataloader)) + return isinstance(sample_batch, (tuple, list)) and len(sample_batch) == 2 + + def __enter__(self) -> DataLoader: + """Enter the context manager and replace the collate function with images only collate, + if the dataloader is in the correct format. + """ + self.should_replace = self._check_dataloader_format() + if self.should_replace: + warnings.warn( + "\nThe dataloader is not in the correct format. The collate function will be replaced with images only collate.\n" + ) + self.original_collate_fn = self.dataloader.collate_fn + self.dataloader.collate_fn = self.new_collate_fn + return self.dataloader + + def __exit__(self, *args) -> None: + """Exit the context manager and restore the original collate function, + if the dataloader was not in the correct format. + """ + if self.should_replace and self.original_collate_fn is not None: + self.dataloader.collate_fn = self.original_collate_fn From bd929e66bd62bc0d5e66d155cab69f5365b984f3 Mon Sep 17 00:00:00 2001 From: Laure Ciernik <64899273+lciernik@users.noreply.github.com> Date: Mon, 11 Aug 2025 11:37:52 +0200 Subject: [PATCH 2/4] another bug --- thingsvision/core/extraction/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thingsvision/core/extraction/base.py b/thingsvision/core/extraction/base.py index d8ba337..d3edd6d 100644 --- a/thingsvision/core/extraction/base.py +++ b/thingsvision/core/extraction/base.py @@ -351,7 +351,7 @@ def extract_features( f"...Features for module '{module_name}' were saved to {features_file}." ) for file in feature_file_names[module_name]: - os.remove(os.path.join(output_dir, file)) + os.remove(file) print(f"...Features were saved to {output_dir}.") return None From a7ae787652c798bf79056997d021fc489574c6a3 Mon Sep 17 00:00:00 2001 From: Laure Ciernik <64899273+lciernik@users.noreply.github.com> Date: Mon, 11 Aug 2025 14:27:31 +0200 Subject: [PATCH 3/4] new test and package deoendency --- requirements.txt | 1 + setup.py | 1 + tests/test_features.py | 14 +++++++++++--- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index 3b2b1f5..8bcb4c0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,6 +16,7 @@ torch>=2.0.0 torchvision==0.15.2 torchtyping tqdm +accelerate<1.10.0 transformers==4.40.1 pytest git+https://github.com/openai/CLIP.git diff --git a/setup.py b/setup.py index ff2df9e..a4ee525 100644 --- a/setup.py +++ b/setup.py @@ -25,6 +25,7 @@ "torchvision==0.15.2", "torchtyping", "tqdm", + "accelerate<1.10.0", "transformers==4.40.1", "pytest", ] diff --git a/tests/test_features.py b/tests/test_features.py index 0dbd266..83facdb 100644 --- a/tests/test_features.py +++ b/tests/test_features.py @@ -44,7 +44,7 @@ def get_4D_features(self): flatten_acts=False, ) return features - + def get_multi_features(self): model_name = "vgg16_bn" extractor, _, batches = helper.create_extractor_and_dataloader( @@ -102,7 +102,7 @@ def test_storing_4d(self): ) self.check_file_exists("features", format, False) - + def test_storing_multi(self): features = self.get_multi_features() for _, feature in features.items(): @@ -115,6 +115,14 @@ def test_storing_multi(self): ) self.check_file_exists(f"features", format, False) + def test_extract_multi(self): + features = self.get_multi_features() + row_counts = [feature.shape[0] for feature in features.values()] + self.assertTrue( + all(count == row_counts[0] for count in row_counts), + "Not all features have the same number of rows!", + ) + def test_splitting_2d(self): n_splits = 3 features = self.get_2D_features() @@ -154,7 +162,7 @@ def test_splitting_4d(self): file_format="txt", n_splits=n_splits, ) - + def test_splitting_multi(self): n_splits = 3 features = self.get_multi_features() From bae13a4284db71b395572b0803e176d09b4ff757 Mon Sep 17 00:00:00 2001 From: Laure Ciernik <64899273+lciernik@users.noreply.github.com> Date: Mon, 11 Aug 2025 16:00:12 +0200 Subject: [PATCH 4/4] additional test for ImageOnlyDataloader modification --- tests/helper.py | 31 +++++++++++++- tests/test_rest.py | 101 ++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 130 insertions(+), 2 deletions(-) diff --git a/tests/helper.py b/tests/helper.py index 8521de0..2733a98 100644 --- a/tests/helper.py +++ b/tests/helper.py @@ -12,7 +12,7 @@ from thingsvision import get_extractor from thingsvision.utils.data import DataLoader, ImageDataset -from torch.utils.data import Subset +from torch.utils.data import Subset, Dataset DATA_PATH = "./data" TEST_PATH = "./test_images" @@ -361,6 +361,35 @@ def __len__(self) -> int: return len(self.values) +class MockImageDataset(Dataset): + """Mock dataset that returns (image, label) tuples""" + + def __init__(self, size=10): + self.size = size + + def __len__(self): + return self.size + + def __getitem__(self, idx): + image = torch.randn(3, 32, 32) + label = torch.tensor(idx % 5) + return image, label + + +class MockImageOnlyDataset(Dataset): + """Mock dataset that returns only images (not tuples)""" + + def __init__(self, size=10): + self.size = size + + def __len__(self): + return self.size + + def __getitem__(self, idx): + # Return only image tensor (no tuple) + return torch.randn(3, 32, 32) + + def iterate_through_all_model_combinations(): for model_config in MODEL_AND_MODULE_NAMES.values(): model_name = model_config["model_name"] diff --git a/tests/test_rest.py b/tests/test_rest.py index ecb3ba9..fcfb06f 100644 --- a/tests/test_rest.py +++ b/tests/test_rest.py @@ -1,12 +1,15 @@ import os import unittest - +import warnings import numpy as np +import torch +from torch.utils.data import DataLoader import tests.helper as helper from thingsvision.core.cka import get_cka from thingsvision.core.rsa import compute_rdm, correlate_rdms, plot_rdm from thingsvision.utils.storing import save_features +from thingsvision.core.extraction.torch import ImageOnlyDataloaderModifier class RSATestCase(unittest.TestCase): @@ -114,3 +117,99 @@ def test_filenames(self): if f.endswith("png"): img_files.append(os.path.join(root, f)) self.assertEqual(sorted(file_names), sorted(img_files)) + + +class TestImageOnlyDataloaderModifier(unittest.TestCase): + + def test_context_manager_with_tuple_format_dataloader(self): + """ + Test 1: Test the context manager with a dataloader that returns (image, label) tuples. + Should replace collate function and extract only images. + """ + dataset = helper.MockImageDataset(size=4) + dataloader = DataLoader(dataset, batch_size=2, shuffle=False) + modifier = ImageOnlyDataloaderModifier(dataloader) + + original_collate_fn = dataloader.collate_fn + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + with modifier as modified_dataloader: + self.assertEqual(len(w), 1) + self.assertIn( + "The dataloader is not in the correct format", str(w[0].message) + ) + + self.assertNotEqual(modified_dataloader.collate_fn, original_collate_fn) + self.assertEqual( + modified_dataloader.collate_fn, modifier.new_collate_fn + ) + self.assertTrue(modifier.should_replace) + + batch = next(iter(modified_dataloader)) + self.assertIsInstance(batch, torch.Tensor) + self.assertEqual(batch.shape, (2, 3, 32, 32)) + + self.assertEqual(dataloader.collate_fn, original_collate_fn) + self.assertEqual(modifier.original_collate_fn, original_collate_fn) + + def test_context_manager_with_image_only_dataloader(self): + """ + Test 2: Test the context manager with a dataloader that already returns only images. + Should NOT replace collate function since format is already correct. + """ + + dataset = helper.MockImageOnlyDataset(size=4) + dataloader = DataLoader(dataset, batch_size=2, shuffle=False) + modifier = ImageOnlyDataloaderModifier(dataloader) + + original_collate_fn = dataloader.collate_fn + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + with modifier as modified_dataloader: + assert len(w) == 0 + assert modified_dataloader.collate_fn == original_collate_fn + assert modifier.should_replace is False + assert modifier.original_collate_fn is None + + batch = next(iter(modified_dataloader)) + assert isinstance(batch, torch.Tensor) + assert batch.shape == (2, 3, 32, 32) + + assert dataloader.collate_fn == original_collate_fn + + def test_images_only_collate_function(self): + """ + Test 3: Test the static _images_only_collate function directly. + Verify it correctly extracts images from tuples. + """ + mock_batch = [ + (torch.randn(3, 32, 32), torch.tensor(0)), + (torch.randn(3, 32, 32), torch.tensor(1)), + (torch.randn(3, 32, 32), torch.tensor(2)), + ] + + result = ImageOnlyDataloaderModifier._images_only_collate(mock_batch) + + assert isinstance(result, torch.Tensor) + assert result.shape == (3, 3, 32, 32) + + for i, (original_image, _) in enumerate(mock_batch): + torch.testing.assert_close(result[i], original_image) + + def test_check_dataloader_format_method(self): + """ + Bonus Test: Test the _check_dataloader_format method directly. + """ + dataset_tuple = helper.MockImageDataset(size=2) + dataloader_tuple = DataLoader(dataset_tuple, batch_size=1) + modifier_tuple = ImageOnlyDataloaderModifier(dataloader_tuple) + assert modifier_tuple._check_dataloader_format() is True + + dataset_image = helper.MockImageOnlyDataset(size=2) + dataloader_image = DataLoader(dataset_image, batch_size=1) + modifier_image = ImageOnlyDataloaderModifier(dataloader_image) + assert modifier_image._check_dataloader_format() is False