diff --git a/src/pysearchlight/__init__.py b/src/pysearchlight/__init__.py index c594f17..9f4518b 100644 --- a/src/pysearchlight/__init__.py +++ b/src/pysearchlight/__init__.py @@ -1 +1,3 @@ -from .sl import Searchlight +from .sl import SearchLight + +__all__ = ["SearchLight"] diff --git a/src/pysearchlight/examples.py b/src/pysearchlight/examples.py index 82ac31d..c09f7ef 100644 --- a/src/pysearchlight/examples.py +++ b/src/pysearchlight/examples.py @@ -1,10 +1,15 @@ import numpy as np -import itertools from sklearn.model_selection import cross_val_score from sklearn.svm import SVC +from typing import Union, List +from numpy.typing import NDArray -def fit_clf(data, labels, clf=SVC()): +def fit_clf( + data: NDArray[np.floating], + labels: Union[NDArray, List[NDArray]], + clf: SVC = SVC(), +) -> float: """ Returns classification accuracy for a single voxel (cross-validated). Labels can either be a numpy array or a list of numpy arrays (in which case the classification accuracies for every diff --git a/src/pysearchlight/sl.py b/src/pysearchlight/sl.py index cddbf49..69cd3cb 100644 --- a/src/pysearchlight/sl.py +++ b/src/pysearchlight/sl.py @@ -3,14 +3,13 @@ from joblib import Parallel, delayed import tqdm from numba import jit -from typing import List, Callable, Tuple +from typing import List, Callable, Tuple, Union, Optional +from numpy.typing import NDArray @jit(nopython=True, cache=True) -def get_searchlight_data(data, coords): - """ - Returns data in a sphere with radius self.radius at position x, y, z - """ +def get_searchlight_data(data: NDArray[np.floating], coords: NDArray[np.int64]) -> NDArray[np.floating]: + """Return all voxel data for a list of sphere coordinates.""" # Get coordinates of all voxels in sphere out = np.zeros((len(coords), len(coords[0]), data.shape[1])) for i, sphere_coords in enumerate(coords): @@ -26,8 +25,12 @@ class SearchLight: """ def __init__( - self, data: np.array, sl_fn: Callable, radius: int, mask: np.array = None - ): + self, + data: NDArray[np.floating], + sl_fn: Callable[[NDArray[np.floating]], Union[float, NDArray[np.floating]]], + radius: int, + mask: Optional[NDArray[np.int_]] = None, + ) -> None: """ Parameters ---------- @@ -52,12 +55,12 @@ def __init__( def fit( self, - coords: List[Tuple] = None, + coords: Optional[List[Tuple[int, int, int]]] = None, output_size: int = 1, n_jobs: int = 1, n_chunks: int = 1, verbose: int = 1, - ) -> np.array: + ) -> NDArray[np.floating]: """ Returns a 3D array with the classification accuracy for each voxel. @@ -116,7 +119,7 @@ def fit( return results - def get_sphere_coords(self, x, y, z): + def get_sphere_coords(self, x: int, y: int, z: int) -> NDArray[np.int_]: """ Returns coordinates of all voxels in a sphere with radius self.radius at position x, y, z