Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,20 @@ results = sl.fit(
n_chunks=10, # The number of chunks to split the data into (to reduce memory usage)
verbose=1, # Verbosity level
)
```
```

## API

`SearchLight` is the main interface of the package. It takes a 4D data array of
shape `(x, y, z, samples)` and a function that will be applied to the data inside
each spherical region. The most important arguments are:

* `data` – the input array.
* `sl_fn` – a callable that receives the data from one searchlight sphere and
returns a value (or an array of values) for that location.
* `radius` – radius of the searchlight sphere in voxels.
* `mask` – optional binary mask to restrict the analysis to specific voxels.

Calling `fit` will return a 3D array with the computed result at every voxel
location. The shape of the returned array is `(x, y, z, output_size)` where
`output_size` is determined by the output of `sl_fn`.
19 changes: 15 additions & 4 deletions src/pysearchlight/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,21 @@


def fit_clf(data, labels, clf=SVC()):
"""
Returns classification accuracy for a single voxel (cross-validated). Labels can either be
a numpy array or a list of numpy arrays (in which case the classification accuracies for every
label array in the list is returned).
"""Example searchlight function using a classifier.

Parameters
----------
data : np.ndarray
Data from one searchlight sphere with shape ``(n_voxels, n_samples)``.
labels : np.ndarray
Array of labels for the samples.
clf : sklearn.base.BaseEstimator, optional
Classifier instance used for cross validation.

Returns
-------
float
Mean cross-validation accuracy of ``clf`` for the given data.
"""
data = np.nan_to_num(data)

Expand Down
65 changes: 48 additions & 17 deletions src/pysearchlight/sl.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,22 @@

@jit(nopython=True, cache=True)
def get_searchlight_data(data, coords):
"""Return the data for each searchlight sphere.

Parameters
----------
data : np.ndarray
Flattened data array of shape ``(n_voxels, n_features)``.
coords : np.ndarray
Iterable of arrays containing voxel indices for every searchlight
sphere.

Returns
-------
np.ndarray
Array with shape ``(n_centers, n_sphere_voxels, n_features)`` containing
the extracted data for every searchlight center.
"""
Returns data in a sphere with radius self.radius at position x, y, z
"""
# Get coordinates of all voxels in sphere
out = np.zeros((len(coords), len(coords[0]), data.shape[1]))
for i, sphere_coords in enumerate(coords):
for j, coord in enumerate(sphere_coords):
Expand All @@ -21,8 +33,23 @@ def get_searchlight_data(data, coords):


class SearchLight:
"""
SearchLight analysis with cross-validation and specified classifier.
"""Perform a searchlight analysis on 4D data.

The class takes care of iterating a spherical region across the provided
data and applying a user supplied function to each sphere.

Parameters
----------
data : np.ndarray
Input array with shape ``(x, y, z, samples)``.
sl_fn : Callable
Function that receives the data within a sphere as
``(n_voxels_in_sphere, n_samples)`` and returns one or more values.
radius : int
Radius of the spherical searchlight in voxels.
mask : np.ndarray, optional
Optional boolean mask of shape ``(x, y, z)`` limiting the voxels that
are evaluated.
"""

def __init__(
Expand Down Expand Up @@ -58,21 +85,28 @@ def fit(
n_chunks: int = 1,
verbose: int = 1,
) -> np.array:
"""
Returns a 3D array with the classification accuracy for each voxel.
"""Run the searchlight analysis.

Parameters
----------
coords : list of tuples, optional
List of coordinates to use as searchlight centers. If not specified, all voxels will be used.
coords : list of tuple of int, optional
Specific coordinates to use as centers of the searchlight. When not
provided all voxels (or the provided mask) are used.
output_size : int, optional
Number of output values per voxel - allows for multiple outputs per voxel (e.g. for running multiple classifiers for each voxel).
Expected number of values returned by ``sl_fn`` for each voxel.
n_jobs : int, optional
Number of jobs to run in parallel.
Number of parallel jobs.
n_chunks : int, optional
Number of chunks to split data into for data preloading, to avoid memory issues.
Split the computation into this many chunks in order to reduce
memory consumption.
verbose : int, optional
Verbosity level.
Verbosity level forwarded to ``joblib``.

Returns
-------
np.ndarray
Array of shape ``(x, y, z, output_size)`` with the computed result
for each voxel.
"""
x_shape, y_shape, z_shape = self.original_data_shape[:3]

Expand Down Expand Up @@ -117,10 +151,7 @@ def fit(
return results

def get_sphere_coords(self, x, y, z):
"""
Returns coordinates of all voxels in a sphere with radius self.radius
at position x, y, z
"""
"""Return voxel indices for the sphere centered at ``(x, y, z)``."""
# Get coordinates of all voxels in cube with side length 2*self.radius+1
x_coords, y_coords, z_coords = np.mgrid[
x - self.radius : x + self.radius + 1,
Expand Down