Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/pysearchlight/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from .sl import Searchlight
from .sl import SearchLight

__all__ = ["SearchLight"]
9 changes: 7 additions & 2 deletions src/pysearchlight/examples.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import numpy as np
import itertools
from sklearn.model_selection import cross_val_score
from sklearn.svm import SVC
from typing import Union, List
from numpy.typing import NDArray


def fit_clf(data, labels, clf=SVC()):
def fit_clf(
data: NDArray[np.floating],
labels: Union[NDArray, List[NDArray]],
clf: SVC = SVC(),
) -> float:
"""
Returns classification accuracy for a single voxel (cross-validated). Labels can either be
a numpy array or a list of numpy arrays (in which case the classification accuracies for every
Expand Down
23 changes: 13 additions & 10 deletions src/pysearchlight/sl.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@
from joblib import Parallel, delayed
import tqdm
from numba import jit
from typing import List, Callable, Tuple
from typing import List, Callable, Tuple, Union, Optional
from numpy.typing import NDArray


@jit(nopython=True, cache=True)
def get_searchlight_data(data, coords):
"""
Returns data in a sphere with radius self.radius at position x, y, z
"""
def get_searchlight_data(data: NDArray[np.floating], coords: NDArray[np.int64]) -> NDArray[np.floating]:
"""Return all voxel data for a list of sphere coordinates."""
# Get coordinates of all voxels in sphere
out = np.zeros((len(coords), len(coords[0]), data.shape[1]))
for i, sphere_coords in enumerate(coords):
Expand All @@ -26,8 +25,12 @@ class SearchLight:
"""

def __init__(
self, data: np.array, sl_fn: Callable, radius: int, mask: np.array = None
):
self,
data: NDArray[np.floating],
sl_fn: Callable[[NDArray[np.floating]], Union[float, NDArray[np.floating]]],
radius: int,
mask: Optional[NDArray[np.int_]] = None,
) -> None:
"""
Parameters
----------
Expand All @@ -52,12 +55,12 @@ def __init__(

def fit(
self,
coords: List[Tuple] = None,
coords: Optional[List[Tuple[int, int, int]]] = None,
output_size: int = 1,
n_jobs: int = 1,
n_chunks: int = 1,
verbose: int = 1,
) -> np.array:
) -> NDArray[np.floating]:
"""
Returns a 3D array with the classification accuracy for each voxel.

Expand Down Expand Up @@ -116,7 +119,7 @@ def fit(

return results

def get_sphere_coords(self, x, y, z):
def get_sphere_coords(self, x: int, y: int, z: int) -> NDArray[np.int_]:
"""
Returns coordinates of all voxels in a sphere with radius self.radius
at position x, y, z
Expand Down