diff --git a/docs/builtin.rst b/docs/builtin.rst index e802ab47b0..1dda0ee9bb 100644 --- a/docs/builtin.rst +++ b/docs/builtin.rst @@ -292,6 +292,14 @@ Available | `Richman et al. (2000) `_ - Done - 0.0.4 + * - :class:`.FunctionalConnectivityLaggedParcels` + - Compute lagged functional connectivity over parcellation + - Done + - 0.0.6 + * - :class:`.FunctionalConnectivityLaggedSpheres` + - Compute lagged functional connectivity over spheres placed on coordinates + - Done + - 0.0.6 Planned ~~~~~~~ diff --git a/docs/changes/newsfragments/435.feature b/docs/changes/newsfragments/435.feature new file mode 100644 index 0000000000..f1f4c9cfca --- /dev/null +++ b/docs/changes/newsfragments/435.feature @@ -0,0 +1 @@ +Add support for :class:`.FunctionalConnectivityLaggedParcels` and :class:`.FunctionalConnectivityLaggedSpheres` markers by `Synchon Mandal`_ diff --git a/junifer/markers/__init__.pyi b/junifer/markers/__init__.pyi index 73462565fe..3462f1bf77 100644 --- a/junifer/markers/__init__.pyi +++ b/junifer/markers/__init__.pyi @@ -5,6 +5,8 @@ __all__ = [ "SphereAggregation", "FunctionalConnectivityParcels", "FunctionalConnectivitySpheres", + "FunctionalConnectivityLaggedParcels", + "FunctionalConnectivityLaggedSpheres", "CrossParcellationFC", "EdgeCentricFCParcels", "EdgeCentricFCSpheres", @@ -24,6 +26,8 @@ from .sphere_aggregation import SphereAggregation from .functional_connectivity import ( FunctionalConnectivityParcels, FunctionalConnectivitySpheres, + FunctionalConnectivityLaggedParcels, + FunctionalConnectivityLaggedSpheres, CrossParcellationFC, EdgeCentricFCParcels, EdgeCentricFCSpheres, diff --git a/junifer/markers/functional_connectivity/__init__.pyi b/junifer/markers/functional_connectivity/__init__.pyi index dd7790d35a..0d212fd972 100644 --- a/junifer/markers/functional_connectivity/__init__.pyi +++ b/junifer/markers/functional_connectivity/__init__.pyi @@ -1,6 +1,8 @@ __all__ = [ "FunctionalConnectivityParcels", "FunctionalConnectivitySpheres", + "FunctionalConnectivityLaggedParcels", + "FunctionalConnectivityLaggedSpheres", "CrossParcellationFC", "EdgeCentricFCParcels", "EdgeCentricFCSpheres", @@ -8,6 +10,12 @@ __all__ = [ from .functional_connectivity_parcels import FunctionalConnectivityParcels from .functional_connectivity_spheres import FunctionalConnectivitySpheres +from .functional_connectivity_lagged_parcels import ( + FunctionalConnectivityLaggedParcels, +) +from .functional_connectivity_lagged_spheres import ( + FunctionalConnectivityLaggedSpheres, +) from .crossparcellation_functional_connectivity import CrossParcellationFC from .edge_functional_connectivity_parcels import EdgeCentricFCParcels from .edge_functional_connectivity_spheres import EdgeCentricFCSpheres diff --git a/junifer/markers/functional_connectivity/functional_connectivity_lagged_base.py b/junifer/markers/functional_connectivity/functional_connectivity_lagged_base.py new file mode 100644 index 0000000000..dc4cc978f9 --- /dev/null +++ b/junifer/markers/functional_connectivity/functional_connectivity_lagged_base.py @@ -0,0 +1,174 @@ +"""Provide abstract base class for lagged functional connectivity (FC).""" + +# Authors: Synchon Mandal +# License: AGPL + + +from abc import abstractmethod +from itertools import product +from typing import Any, ClassVar, Optional, Union + +import numpy as np +from scipy.signal import correlate + +from ...typing import Dependencies, MarkerInOutMappings +from ...utils import raise_error +from ..base import BaseMarker + + +__all__ = ["FunctionalConnectivityLaggedBase"] + + +class FunctionalConnectivityLaggedBase(BaseMarker): + """Abstract base class for lagged functional connectivity markers. + + Parameters + ---------- + max_lag : int + The time lag range. The lag ranges from ``-max_lag`` to ``+max_lag`` + time points. + agg_method : str, optional + The method to perform aggregation using. + Check valid options in :func:`.get_aggfunc_by_name` + (default "mean"). + agg_method_params : dict, optional + Parameters to pass to the aggregation function. + Check valid options in :func:`.get_aggfunc_by_name` + (default None). + masks : str, dict or list of dict or str, optional + The specification of the masks to apply to regions before extracting + signals. Check :ref:`Using Masks ` for more details. + If None, will not apply any mask (default None). + name : str, optional + The name of the marker. If None, will use ``BOLD_`` + (default None). + + """ + + _DEPENDENCIES: ClassVar[Dependencies] = {"numpy", "scipy"} + + _MARKER_INOUT_MAPPINGS: ClassVar[MarkerInOutMappings] = { + "BOLD": { + "functional_connectivity": "matrix", + "lag": "matrix", + }, + } + + def __init__( + self, + max_lag: int, + agg_method: str = "mean", + agg_method_params: Optional[dict] = None, + masks: Union[str, dict, list[Union[dict, str]], None] = None, + name: Optional[str] = None, + ) -> None: + self.max_lag = max_lag + self.agg_method = agg_method + self.agg_method_params = agg_method_params + self.masks = masks + super().__init__(on="BOLD", name=name) + + @abstractmethod + def aggregate( + self, + input: dict[str, Any], + extra_input: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: + """Perform aggregation.""" + raise_error( + msg="Concrete classes need to implement aggregate().", + klass=NotImplementedError, + ) + + def compute( + self, + input: dict[str, Any], + extra_input: Optional[dict] = None, + ) -> dict: + """Compute. + + Parameters + ---------- + input : dict + A single input from the pipeline data object in which to compute + the marker. + extra_input : dict, optional + The other fields in the pipeline data object. Useful for accessing + other data kind that needs to be used in the computation. For + example, the functional connectivity markers can make use of the + confounds if available (default None). + + Returns + ------- + dict + The computed result as dictionary. This will be either returned + to the user or stored in the storage by calling the store method + with this as a parameter. The dictionary has the following keys: + + * ``functional_connectivity`` : dictionary with the following keys: + + - ``data`` : functional connectivity matrix as ``numpy.ndarray`` + - ``row_names`` : ROI labels as list of str + - ``col_names`` : ROI labels as list of str + - ``matrix_kind`` : the kind of matrix (tril, triu or full) + + * ``lag`` : dictionary with the following keys: + + - ``data`` : lag matrix as ``numpy.ndarray`` + - ``row_names`` : ROI labels as list of str + - ``col_names`` : ROI labels as list of str + - ``matrix_kind`` : the kind of matrix (tril, triu or full) + + Notes + ----- + Pearson correlation is used to perform connectivity measure. + + """ + # Perform necessary aggregation + aggregation = self.aggregate(input, extra_input=extra_input) + # Initialize variables + # transposed to get (n_rois, n_timepoints) + fmri_data = aggregation["aggregation"]["data"].T + n_rois = fmri_data.shape[0] + fc_matrix = np.ones((n_rois, n_rois)) + lag_matrix = np.zeros((n_rois, n_rois), dtype=int) + lags = np.arange(-self.max_lag, self.max_lag + 1) + # Compute + for i, j in product(range(n_rois), range(n_rois)): + if i != j: + x = fmri_data[i] + y = fmri_data[j] + # Compute cross-correlation function (CCF) + ccf = correlate(x, y, mode="full", method="auto") + # Normalize CCF + ccf /= np.sqrt(np.sum(np.abs(x**2)) * np.sum(np.abs(y**2))) + # Limit lag range + ccf = ccf[ + len(ccf) // 2 + - self.max_lag : len(ccf) // 2 + + self.max_lag + + 1 + ] + # Find peak correlation and corresponding lag + peak_idx = np.argmax(np.abs(ccf)) # Peak correlation index + peak_corr = ccf[peak_idx] # Peak correlation value + peak_lag = lags[peak_idx] # Corresponding lag + # Store in matrices + fc_matrix[i, j] = peak_corr + lag_matrix[i, j] = peak_lag + # Create dictionary for output + roi_labels = aggregation["aggregation"]["col_names"] + return { + "functional_connectivity": { + "data": fc_matrix, + "row_names": roi_labels, + "col_names": roi_labels, + "matrix_kind": "tril", + }, + "lag": { + "data": lag_matrix, + "row_names": roi_labels, + "col_names": roi_labels, + "matrix_kind": "full", + }, + } diff --git a/junifer/markers/functional_connectivity/functional_connectivity_lagged_parcels.py b/junifer/markers/functional_connectivity/functional_connectivity_lagged_parcels.py new file mode 100644 index 0000000000..7a16f3c661 --- /dev/null +++ b/junifer/markers/functional_connectivity/functional_connectivity_lagged_parcels.py @@ -0,0 +1,102 @@ +"""Provide class for lagged functional connectivity using parcels.""" + +# Authors: Synchon Mandal +# License: AGPL + + +from typing import Any, Optional, Union + +from ...api.decorators import register_marker +from ..parcel_aggregation import ParcelAggregation +from .functional_connectivity_lagged_base import ( + FunctionalConnectivityLaggedBase, +) + + +__all__ = ["FunctionalConnectivityLaggedParcels"] + + +@register_marker +class FunctionalConnectivityLaggedParcels(FunctionalConnectivityLaggedBase): + """Class for lagged functional connectivity using parcellations. + + Parameters + ---------- + parcellation : str or list of str + The name(s) of the parcellation(s) to use. + See :func:`.list_data` for options. + max_lag : int + The time lag range. The lag ranges from ``-max_lag`` to ``+max_lag`` + time points. + agg_method : str, optional + The method to perform aggregation using. + See :func:`.get_aggfunc_by_name` for options + (default "mean"). + agg_method_params : dict, optional + Parameters to pass to the aggregation function. + See :func:`.get_aggfunc_by_name` for options + (default None). + masks : str, dict or list of dict or str, optional + The specification of the masks to apply to regions before extracting + signals. Check :ref:`Using Masks ` for more details. + If None, will not apply any mask (default None). + name : str, optional + The name of the marker. If None, will use + ``BOLD_FunctionalConnectivityLaggedParcels`` (default None). + + """ + + def __init__( + self, + parcellation: Union[str, list[str]], + max_lag: int, + agg_method: str = "mean", + agg_method_params: Optional[dict] = None, + masks: Union[str, dict, list[Union[dict, str]], None] = None, + name: Optional[str] = None, + ) -> None: + self.parcellation = parcellation + super().__init__( + max_lag=max_lag, + agg_method=agg_method, + agg_method_params=agg_method_params, + masks=masks, + name=name, + ) + + def aggregate( + self, input: dict[str, Any], extra_input: Optional[dict] = None + ) -> dict: + """Perform parcel aggregation. + + Parameters + ---------- + input : dict + A single input from the pipeline data object in which to compute + the marker. + extra_input : dict, optional + The other fields in the pipeline data object. Useful for accessing + other data kind that needs to be used in the computation. For + example, the functional connectivity markers can make use of the + confounds if available (default None). + + Returns + ------- + dict + The computed result as dictionary. This will be either returned + to the user or stored in the storage by calling the store method + with this as a parameter. The dictionary has the following keys: + + * ``aggregation`` : dictionary with the following keys: + + - ``data`` : ROI values as ``numpy.ndarray`` + - ``col_names`` : ROI labels as list of str + + """ + return ParcelAggregation( + parcellation=self.parcellation, + method=self.agg_method, + method_params=self.agg_method_params, + masks=self.masks, + on="BOLD", + ).compute(input=input, extra_input=extra_input) diff --git a/junifer/markers/functional_connectivity/functional_connectivity_lagged_spheres.py b/junifer/markers/functional_connectivity/functional_connectivity_lagged_spheres.py new file mode 100644 index 0000000000..19c2267a51 --- /dev/null +++ b/junifer/markers/functional_connectivity/functional_connectivity_lagged_spheres.py @@ -0,0 +1,119 @@ +"""Provide class for lagged functional connectivity using spheres.""" + +# Authors: Synchon Mandal +# License: AGPL + + +from typing import Any, Optional, Union + +from ...api.decorators import register_marker +from ..sphere_aggregation import SphereAggregation +from ..utils import raise_error +from .functional_connectivity_lagged_base import ( + FunctionalConnectivityLaggedBase, +) + + +__all__ = ["FunctionalConnectivityLaggedSpheres"] + + +@register_marker +class FunctionalConnectivityLaggedSpheres(FunctionalConnectivityLaggedBase): + """Class for lagged functional connectivity using coordinates (spheres). + + Parameters + ---------- + coords : str + The name of the coordinates list to use. + See :func:`.list_data` for options. + radius : positive float, optional + The radius of the sphere around each coordinates in millimetres. + If None, the signal will be extracted from a single voxel. + See :class:`.JuniferNiftiSpheresMasker` for more information + (default None). + allow_overlap : bool, optional + Whether to allow overlapping spheres. If False, an error is raised if + the spheres overlap (default False). + max_lag : int + The time lag range. The lag ranges from ``-max_lag`` to ``+max_lag`` + time points. + agg_method : str, optional + The method to perform aggregation using. + See :func:`.get_aggfunc_by_name` for options + (default "mean"). + agg_method_params : dict, optional + Parameters to pass to the aggregation function. + See :func:`.get_aggfunc_by_name` for options + (default None). + masks : str, dict or list of dict or str, optional + The specification of the masks to apply to regions before extracting + signals. Check :ref:`Using Masks ` for more details. + If None, will not apply any mask (default None). + name : str, optional + The name of the marker. If None, will use + ``BOLD_FunctionalConnectivityLaggedSpheres`` (default None). + + """ + + def __init__( + self, + coords: str, + max_lag: int, + radius: Optional[float] = None, + allow_overlap: bool = False, + agg_method: str = "mean", + agg_method_params: Optional[dict] = None, + masks: Union[str, dict, list[Union[dict, str]], None] = None, + name: Optional[str] = None, + ) -> None: + self.coords = coords + self.radius = radius + self.allow_overlap = allow_overlap + if radius is None or radius <= 0: + raise_error(f"radius should be > 0: provided {radius}") + super().__init__( + max_lag=max_lag, + agg_method=agg_method, + agg_method_params=agg_method_params, + masks=masks, + name=name, + ) + + def aggregate( + self, input: dict[str, Any], extra_input: Optional[dict] = None + ) -> dict: + """Perform sphere aggregation. + + Parameters + ---------- + input : dict + A single input from the pipeline data object in which to compute + the marker. + extra_input : dict, optional + The other fields in the pipeline data object. Useful for accessing + other data kind that needs to be used in the computation. For + example, the functional connectivity markers can make use of the + confounds if available (default None). + + Returns + ------- + dict + The computed result as dictionary. This will be either returned + to the user or stored in the storage by calling the store method + with this as a parameter. The dictionary has the following keys: + + * ``aggregation`` : dictionary with the following keys: + + - ``data`` : ROI values as ``numpy.ndarray`` + - ``col_names`` : ROI labels as list of str + + """ + return SphereAggregation( + coords=self.coords, + radius=self.radius, + allow_overlap=self.allow_overlap, + method=self.agg_method, + method_params=self.agg_method_params, + masks=self.masks, + on="BOLD", + ).compute(input=input, extra_input=extra_input) diff --git a/junifer/markers/functional_connectivity/tests/test_functional_connectivity_lagged_parcels.py b/junifer/markers/functional_connectivity/tests/test_functional_connectivity_lagged_parcels.py new file mode 100644 index 0000000000..975e25a318 --- /dev/null +++ b/junifer/markers/functional_connectivity/tests/test_functional_connectivity_lagged_parcels.py @@ -0,0 +1,72 @@ +"""Provide tests for lagged functional connectivity using parcels.""" + +# Authors: Synchon Mandal +# License: AGPL + +from pathlib import Path + +from junifer.datareader import DefaultDataReader +from junifer.markers.functional_connectivity import ( + FunctionalConnectivityLaggedParcels, +) +from junifer.storage import HDF5FeatureStorage +from junifer.testing.datagrabbers import PartlyCloudyTestingDataGrabber + + +def test_FunctionalConnectivityLaggedParcels( + tmp_path: Path, +) -> None: + """Test FunctionalConnectivityLaggedParcels. + + Parameters + ---------- + tmp_path : pathlib.Path + The path to the test directory. + + """ + with PartlyCloudyTestingDataGrabber() as dg: + # Get element data + element_data = DefaultDataReader().fit_transform(dg["sub-01"]) + # Setup marker + marker = FunctionalConnectivityLaggedParcels( + parcellation="TianxS1x3TxMNInonlinear2009cAsym", + max_lag=10, + ) + # Check correct outputs + assert "matrix" == marker.get_output_type( + input_type="BOLD", output_feature="functional_connectivity" + ) + assert "matrix" == marker.get_output_type( + input_type="BOLD", output_feature="lag" + ) + + # Fit-transform the data and check + lagged_fc = marker.fit_transform(element_data) + + for feature in ("functional_connectivity", "lag"): + lagged_fc_bold = lagged_fc["BOLD"][feature] + + assert "data" in lagged_fc_bold + assert "row_names" in lagged_fc_bold + assert "col_names" in lagged_fc_bold + assert lagged_fc_bold["data"].shape == (16, 16) + assert len(set(lagged_fc_bold["row_names"])) == 16 + assert len(set(lagged_fc_bold["col_names"])) == 16 + + # Store + storage = HDF5FeatureStorage( + uri=tmp_path / "test_lagged_fc_parcels.hdf5", + ) + marker.fit_transform(input=element_data, storage=storage) + features = storage.list_features() + assert all( + x["name"] + in ( + ( + "BOLD_FunctionalConnectivityLaggedParcels_" + "functional_connectivity" + ), + "BOLD_FunctionalConnectivityLaggedParcels_lag", + ) + for x in features.values() + ) diff --git a/junifer/markers/functional_connectivity/tests/test_functional_connectivity_lagged_spheres.py b/junifer/markers/functional_connectivity/tests/test_functional_connectivity_lagged_spheres.py new file mode 100644 index 0000000000..c26463713f --- /dev/null +++ b/junifer/markers/functional_connectivity/tests/test_functional_connectivity_lagged_spheres.py @@ -0,0 +1,83 @@ +"""Provide tests for lagged functional connectivity using spheres.""" + +# Authors: Synchon Mandal +# License: AGPL + +from pathlib import Path + +import pytest + +from junifer.datareader import DefaultDataReader +from junifer.markers.functional_connectivity import ( + FunctionalConnectivityLaggedSpheres, +) +from junifer.storage import HDF5FeatureStorage +from junifer.testing.datagrabbers import SPMAuditoryTestingDataGrabber + + +def test_FunctionalConnectivitySpheres( + tmp_path: Path, +) -> None: + """Test FunctionalConnectivityLaggedSpheres. + + Parameters + ---------- + tmp_path : pathlib.Path + The path to the test directory. + + """ + with SPMAuditoryTestingDataGrabber() as dg: + # Get element data + element_data = DefaultDataReader().fit_transform(dg["sub001"]) + # Setup marker + marker = FunctionalConnectivityLaggedSpheres( + coords="DMNBuckner", + radius=5.0, + max_lag=10, + ) + # Check correct outputs + assert "matrix" == marker.get_output_type( + input_type="BOLD", output_feature="functional_connectivity" + ) + assert "matrix" == marker.get_output_type( + input_type="BOLD", output_feature="lag" + ) + + # Fit-transform the data + lagged_fc = marker.fit_transform(element_data) + + for feature in ("functional_connectivity", "lag"): + lagged_fc_bold = lagged_fc["BOLD"][feature] + + assert "data" in lagged_fc_bold + assert "row_names" in lagged_fc_bold + assert "col_names" in lagged_fc_bold + assert lagged_fc_bold["data"].shape == (6, 6) + assert len(set(lagged_fc_bold["row_names"])) == 6 + assert len(set(lagged_fc_bold["col_names"])) == 6 + + # Store + storage = HDF5FeatureStorage( + uri=tmp_path / "test_lagged_fc_spheres.hdf5", + ) + marker.fit_transform(input=element_data, storage=storage) + features = storage.list_features() + assert all( + x["name"] + in ( + ( + "BOLD_FunctionalConnectivityLaggedSpheres_" + "functional_connectivity" + ), + "BOLD_FunctionalConnectivityLaggedSpheres_lag", + ) + for x in features.values() + ) + + +def test_FunctionalConnectivityLaggedSpheres_error() -> None: + """Test FunctionalConnectivityLaggedSpheres errors.""" + with pytest.raises(ValueError, match="radius should be > 0"): + FunctionalConnectivityLaggedSpheres( + coords="DMNBuckner", radius=-0.1, max_lag=10 + ) diff --git a/junifer/pipeline/pipeline_component_registry.py b/junifer/pipeline/pipeline_component_registry.py index 55a8fa31ad..a0d1a602cd 100644 --- a/junifer/pipeline/pipeline_component_registry.py +++ b/junifer/pipeline/pipeline_component_registry.py @@ -91,6 +91,12 @@ def __init__(self) -> None: "FunctionalConnectivitySpheres": ( "FunctionalConnectivitySpheres" ), + "FunctionalConnectivityLaggedParcels": ( + "FunctionalConnectivityLaggedParcels" + ), + "FunctionalConnectivityLaggedSpheres": ( + "FunctionalConnectivityLaggedSpheres" + ), "ParcelAggregation": "ParcelAggregation", "ReHoParcels": "ReHoParcels", "ReHoSpheres": "ReHoSpheres",