From 121500fcf207279f6f25fac447c3072c491827a1 Mon Sep 17 00:00:00 2001 From: Johannes Roth Date: Tue, 3 Jun 2025 21:02:22 +0200 Subject: [PATCH] Fix documentation, coordinate handling bug, clarify docstring, add test --- README.md | 2 +- src/pysearchlight/__init__.py | 3 ++- src/pysearchlight/sl.py | 16 ++++++++++++---- tests/test_sl.py | 12 ++++++++++++ 4 files changed, 27 insertions(+), 6 deletions(-) create mode 100644 tests/test_sl.py diff --git a/README.md b/README.md index 85b5ac7..0527a94 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ pip install pysearchlight PySearchlight takes care of moving the sphere across the data and applying the user-defined function to the data within the sphere. The user only needs to provide the data, the function to apply within the searchlight, and the radius of the searchlight sphere. The function should take a single argument, which is the data within the sphere centered at a voxel location. The function can also take additional arguments, which can be passed using `functools.partial`. -Here is a simple example of how to use PySearchLight to train and evaluate classifier on data within a searchlight: +Here is a simple example of how to use PySearchLight to train and evaluate a classifier on data within a searchlight: ```python import numpy as np diff --git a/src/pysearchlight/__init__.py b/src/pysearchlight/__init__.py index c594f17..f76da0f 100644 --- a/src/pysearchlight/__init__.py +++ b/src/pysearchlight/__init__.py @@ -1 +1,2 @@ -from .sl import Searchlight +# Re-export the main SearchLight class +from .sl import SearchLight diff --git a/src/pysearchlight/sl.py b/src/pysearchlight/sl.py index cddbf49..89736e7 100644 --- a/src/pysearchlight/sl.py +++ b/src/pysearchlight/sl.py @@ -8,8 +8,12 @@ @jit(nopython=True, cache=True) def get_searchlight_data(data, coords): - """ - Returns data in a sphere with radius self.radius at position x, y, z + """Return the data for each list of voxel indices in ``coords``. + + ``coords`` should contain one iterable of voxel indices per sphere. Each + iterable lists the flattened indices of voxels belonging to that sphere. The + returned array therefore has shape ``(len(coords), len(coords[0]), + data.shape[1])``. """ # Get coordinates of all voxels in sphere out = np.zeros((len(coords), len(coords[0]), data.shape[1])) @@ -77,13 +81,17 @@ def fit( x_shape, y_shape, z_shape = self.original_data_shape[:3] if coords is not None: - coords = np.array(coords) + coords = list(coords) elif self.mask is not None: print("Using mask to determine searchlight coordinates") coords = list(zip(*np.nonzero(self.mask))) else: print("Neither mask nor coordinates specified, using all voxels") - coords = itertools.product(range(x_shape), range(y_shape), range(z_shape)) + coords = list(itertools.product(range(x_shape), range(y_shape), range(z_shape))) + # ``coords`` may originate from a generator (e.g. ``itertools.product``) + # which would be exhausted after computing ``sphere_coords`` below. By + # converting it to a list we ensure it can be iterated over multiple + # times, both for computing sphere indices and for assigning results. results = np.zeros((x_shape, y_shape, z_shape, output_size)) diff --git a/tests/test_sl.py b/tests/test_sl.py new file mode 100644 index 0000000..6e711b9 --- /dev/null +++ b/tests/test_sl.py @@ -0,0 +1,12 @@ +import numpy as np +from pysearchlight.sl import SearchLight + +def dummy(data): + return 0 + +def test_get_sphere_coords_size_consistency(): + data = np.zeros((5, 5, 5, 1)) + sl = SearchLight(data=data, sl_fn=dummy, radius=1) + center = sl.get_sphere_coords(2, 2, 2) + edge = sl.get_sphere_coords(0, 0, 0) + assert len(center) == len(edge)