diff --git a/CHANGELOG.md b/CHANGELOG.md index 15f3ab1..15c6056 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,13 +12,15 @@ and this project adheres to [Semantic Versioning][]. ### Added -- `pycea.tl.partition_test` to test for statistically significant differences between leaf partitions. +- `pycea.tl.partition_test` to test for statistically significant differences between leaf partitions. (#40) ### Changed +- Replaced `tdata.obs_keys()` with `tdata.obs.keys()` to conform with anndata API changes. (#41) + ### Fixed -- Replaced `tdata.obs_keys()` with `tdata.obs.keys()` to conform with anndata API changes. (#41) +- Fixed node plotting when `isinstance(nodes,str)`. (#39) ## [0.1.0] - 2025-09-19 diff --git a/src/pycea/get/palette.py b/src/pycea/get/palette.py index 8088375..98bad33 100755 --- a/src/pycea/get/palette.py +++ b/src/pycea/get/palette.py @@ -72,12 +72,17 @@ def palette( """ Get color palette for a given key. + This function gets the mapping from category → color for a given key + in ``tdata``. If no customizations are provided, the function will return + a previously stored palette in ``tdata.uns`` if it exists. Otherwise, + a new palette is generated. + Parameters ---------- tdata The `treedata.TreeData` object. key - A key from `obs_keys`, `obsm_keys`, or `obsp_keys` to generate a color palette for. + A key from `obs.keys()`, `obsm.keys()`, or `obsp.keys()` to generate a color palette for. custom A dictionary mapping specific values to colors (e.g., `{"category1": "red"}`). cmap @@ -94,6 +99,21 @@ def palette( Returns ------- palette - Color palette for the given key + + Examples + -------- + Get character color palette with saturation adjusted by indel probability: + + >>> tdata = py.datasets.yang22() + >>> indel_palette = py.get.palette( + ... tdata, + ... "characters", + ... custom={"-": "white", "*": "lightgrey"}, + ... cmap="gist_rainbow", + ... priors=tdata.uns["priors"], + ... sort="random", + ... ) + """ # Setup tdata._sanitize() diff --git a/src/pycea/pl/plot_tree.py b/src/pycea/pl/plot_tree.py index a3491ba..5785016 100644 --- a/src/pycea/pl/plot_tree.py +++ b/src/pycea/pl/plot_tree.py @@ -57,6 +57,13 @@ def branches( """\ Plot the branches of a tree. + Plots the branches of one or more trees stored in ``tdata.obst`` as a + :class:`matplotlib.collections.LineCollection`. Branch appearance (`color` and `linewidth`) + can be fixed scalars or set based on edge attributes (continuous or categorical). + Coloring of continuous variables is based on a colormap (`cmap`), while categorical + variables can be colored using a custom `palette` or a default categorical color set. + Polar coordinates are used when `polar` is True. + Parameters ---------- tdata @@ -94,6 +101,22 @@ def branches( Returns ------- ax - The axes that the plot was drawn on. + + Notes + ----- + * If ``ax`` is provided the coordinate system must match the ``polar`` setting. + * Continuous color attributes use ``cmap`` with ``vmin``/``vmax`` normalization. + + Examples + -------- + Plot tree branches: + + >>> tdata = py.datasets.packer19() + >>> py.pl.branches(tdata, depth_key="time") + + Plot tree branches in polar coordinates, colored by clade: + + >>> py.pl.branches(tdata, depth_key="time", polar=True, color="clade") """ # noqa: D205 # Setup tdata._sanitize() @@ -207,6 +230,11 @@ def nodes( """\ Plot the nodes of a tree. + Plot the nodes of one or more trees from ``tdata.obst`` on the current axes using + :func:`matplotlib.pyplot.scatter`. 
Appearance can be fixed (single color/marker/size) or set based on
+    node attributes (continuous or categorical). You can plot only leaves, only
+    internal nodes, all nodes, or an explicit list of node names.
+
     Parameters
     ----------
     tdata
@@ -249,6 +277,25 @@
     Returns
     -------
     ax
-        The axes that the plot was drawn on.
+
+    Notes
+    -----
+    * Must call :func:`pycea.pl.branches` or :func:`pycea.pl.tree` before calling this function.
+    * Continuous color attributes use ``cmap`` with ``vmin``/``vmax`` normalization.
+
+    Examples
+    --------
+    Plot internal nodes colored by depth:
+
+    >>> tdata = py.datasets.packer19()
+    >>> py.pl.branches(tdata, depth_key="time")
+    >>> py.pl.nodes(tdata, nodes="internal", color="time", cmap="plasma")
+
+    Color nodes by "elt-2" expression and highlight the "EMS" node with a star marker:
+
+    >>> py.pl.branches(tdata, depth_key="time")
+    >>> py.pl.nodes(tdata, color="elt-2", nodes="all")
+    >>> py.pl.nodes(tdata, color="red", nodes="EMS", style="*", size=200, slot="obst")
     """  # noqa: D205
     # Setup
     kwargs = kwargs if kwargs else {}
@@ -282,6 +329,8 @@
     elif nodes == "internal":
         plot_nodes = [node for node in all_nodes if node[1] not in attrs["leaves"]]
     elif isinstance(nodes, Sequence):
+        if isinstance(nodes, str):
+            nodes = [nodes]
         if len(attrs["tree_keys"]) > 1 and len(tree_keys) > 1:
             raise ValueError("Multiple trees are present. To plot a list of nodes, you must specify the tree.")
         plot_nodes = [(tree_keys[0], node) for node in nodes]
@@ -395,12 +444,17 @@ annotation(
     """\
     Plot leaf annotations for a tree.

+    Plots one or more leaf annotations (small heatmap-like bars) next to the tree’s
+    leaves, preserving the leaf order/layout used by :func:`pycea.pl.branches`. Each key can be
+    continuous (colored via `cmap`) or categorical (colored via a `palette`). Multiple
+    keys are stacked horizontally (or radially if the tree plot is polar).
+
     Parameters
     ----------
     tdata
         The TreeData object.
     keys
-        One or more `obs_keys`, `var_names`, `obsm_keys`, or `obsp_keys` to plot.
+        One or more `obs.keys()`, `var_names`, `obsm.keys()`, or `obsp.keys()` to plot.
     width
         The width of the annotation bar relative to the tree.
     gap
@@ -432,6 +486,25 @@
     Returns
     -------
     ax
-        The axes that the plot was drawn on.
+
+    Notes
+    -----
+    * Must call :func:`pycea.pl.branches` or :func:`pycea.pl.tree` before calling this function.
+    * Continuous color attributes use ``cmap`` with ``vmin``/``vmax`` normalization.
+
+    Examples
+    --------
+    Plot leaf annotations for "elt-2" and "pal-1" expression:
+
+    >>> tdata = py.datasets.packer19(tree="observed")
+    >>> py.pl.branches(tdata, depth_key="time")
+    >>> py.pl.annotation(tdata, keys=["elt-2", "pal-1"])
+
+    Plot leaf annotation for spatial distance between leaves:
+
+    >>> py.tl.distance(tdata, key="spatial")
+    >>> py.pl.branches(tdata, depth_key="time")
+    >>> py.pl.annotation(tdata, keys="spatial_distances", cmap="magma")
     """  # noqa: D205
     # Setup
     if tree:  # TODO: Annotate only the leaves for the given tree
@@ -598,12 +671,16 @@ tree(
     """\
     Plot a tree with branches, nodes, and annotations.

+    This function combines :func:`pycea.pl.branches`, :func:`pycea.pl.nodes`, and :func:`pycea.pl.annotation` to enable
+    plotting a complete tree with branches, nodes, and leaf annotations in a single call. Each component
+    (branches, nodes, annotations) can be customized independently using the respective parameters.
+
     Parameters
     ----------
     tdata
         The TreeData object.
keys - One or more `obs_keys`, `var_names`, `obsm_keys`, or `obsp_keys` annotations. + One or more `obs.keys()`, `var_names`, `obsm.keys()`, or `obsp.keys()` annotations. nodes Either "all", "leaves", "internal", or a list of nodes to plot. Defaults to "internal" if node color, style, or size is set. polar @@ -641,6 +718,18 @@ def tree( Returns ------- ax - The axes that the plot was drawn on. + + Notes + ----- + * If ``ax`` is provided the coordinate system must match the ``polar`` setting. + * Continuous color attributes use ``cmap`` with ``vmin``/``vmax`` normalization. + + Examples + -------- + Plot a tree with nodes and leaves colored by "elt-2" expression: + + >>> tdata = py.datasets.packer19() + >>> py.pl.tree(tdata, nodes="all", node_color="elt-2", keys="elt-2", depth_key="time") """ # noqa: D205 # Setup branch_legend = legend.get("branch", None) if isinstance(legend, Mapping) else legend diff --git a/src/pycea/tl/_aggregators.py b/src/pycea/tl/_aggregators.py index b593f61..30c9fea 100644 --- a/src/pycea/tl/_aggregators.py +++ b/src/pycea/tl/_aggregators.py @@ -7,7 +7,7 @@ from numpy.typing import NDArray _AggregatorFn = Callable[[np.ndarray], np.ndarray | float] -_Aggregator = Literal["mean", "median", "sum", "var"] +_Aggregator = Literal["mean", "median", "sum", "min", "max", "var"] def _reduce(fn, X: NDArray[np.generic]) -> NDArray[np.generic] | float: @@ -24,6 +24,8 @@ def _var(X: NDArray[np.generic]) -> NDArray[np.generic] | float: "mean": lambda X: _reduce(np.mean, X), "median": lambda X: _reduce(np.median, X), "sum": lambda X: _reduce(np.sum, X), + "min": lambda X: _reduce(np.min, X), + "max": lambda X: _reduce(np.max, X), "var": _var, } diff --git a/src/pycea/tl/ancestral_states.py b/src/pycea/tl/ancestral_states.py index 211ee88..142c01f 100755 --- a/src/pycea/tl/ancestral_states.py +++ b/src/pycea/tl/ancestral_states.py @@ -260,12 +260,17 @@ def ancestral_states( ) -> pd.DataFrame | None: """Reconstructs ancestral states for an attribute. + This function reconstructs ancestral (internal node) states for categorical or + continuous attributes defined on tree leaves. Several reconstruction methods + are supported, ranging from simple aggregation rules to the Sankoff and Fitch-Hartigan + algorithms for discrete character data, or a custom aggregation function can be provided. + Parameters ---------- tdata TreeData object. keys - One or more `obs_keys`, `var_names`, `obsm_keys`, or `obsp_keys` to reconstruct. + One or more `obs.keys()`, `var_names`, `obsm.keys()`, or `obsp.keys()` to reconstruct. method Method to reconstruct ancestral states: @@ -279,7 +284,7 @@ def ancestral_states( default_state The expected state for the root node. costs - A pd.DataFrame with the costs of changing states (from rows to columns). + A pd.DataFrame with the costs of changing states (from rows to columns). Only used if method is 'sankoff'. keys_added Attribute keys of `tdata.obst[tree].nodes` where ancestral states will be stored. If `None`, `keys` are used. tree @@ -295,6 +300,18 @@ def ancestral_states( * `tdata.obst[tree].nodes[key_added]` : `float` | `Object` | `List[Object]` - Inferred ancestral states. List of states if data was an array. 
+
+    Examples
+    --------
+    Infer the expression of Krt20 and Cd74 based on their mean value in descendant cells:
+
+    >>> tdata = py.datasets.yang22()
+    >>> py.tl.ancestral_states(tdata, keys=["Krt20", "Cd74"], method="mean")
+
+    Reconstruct ancestral character states using the Fitch-Hartigan algorithm:
+
+    >>> py.tl.ancestral_states(tdata, keys="characters", method="fitch_hartigan", missing_state=-1)
+
     """
     if isinstance(keys, str):
         keys = [keys]
diff --git a/src/pycea/tl/autocorr.py b/src/pycea/tl/autocorr.py
index 1b85c8f..18899e3 100755
--- a/src/pycea/tl/autocorr.py
+++ b/src/pycea/tl/autocorr.py
@@ -84,14 +84,48 @@ def autocorr(
     layer: str | None = None,
     copy: Literal[True, False] = False,
 ) -> pd.DataFrame | None:
-    """Calculate autocorrelation statistic.
+    r"""Calculate autocorrelation statistic.
+
+    This function computes autocorrelation for one or more variables using
+    either **Moran’s I** or **Geary’s C** statistic, based on a specified connectivity
+    graph between observations.
+
+    Mathematically, the two statistics are defined as follows:
+
+    .. math::
+
+        I =
+        \frac{
+            N \sum_{i,j} w_{i,j} (x_i - \bar{x})(x_j - \bar{x})
+        }{
+            W \sum_i (x_i - \bar{x})^2
+        }
+
+        C =
+        \frac{
+            (N - 1)\sum_{i,j} w_{i,j} (x_i - x_j)^2
+        }{
+            2W \sum_i (x_i - \bar{x})^2
+        }
+
+    where:
+    * :math:`N` is the number of observations,
+    * :math:`x_i` is the value of observation *i*,
+    * :math:`\bar{x}` is the mean of all observations,
+    * :math:`w_{i,j}` is the spatial weight between *i* and *j*, and
+    * :math:`W = \sum_{i,j} w_{i,j}`.
+
+    A Moran’s I value close to 1 indicates strong positive autocorrelation,
+    while values near 0 suggest randomness. Geary’s C behaves inversely:
+    values less than 1 indicate positive autocorrelation, while values
+    greater than 1 indicate negative autocorrelation.

     Parameters
     ----------
     tdata
         TreeData object.
     keys
-        One or more `obs_keys`, `var_names`, `obsm_keys`, or `obsp_keys` to calculate autocorrelation for. Defaults to all 'var_names'.
+        One or more `obs.keys()`, `var_names`, `obsm.keys()`, or `obsp.keys()` to calculate autocorrelation for. Defaults to all 'var_names'.
     connect_key
         `tdata.obsp` connectivity key specifying set of neighbors for each observation.
     method
@@ -116,6 +150,14 @@

     * `tdata.uns['moranI']` : Above DataFrame for if method is `'moran'`.
     * `tdata.uns['gearyC']` : Above DataFrame for if method is `'geary'`.
+
+    Examples
+    --------
+    Estimate gene expression heritability using Moran's I autocorrelation:
+
+    >>> tdata = py.datasets.yang22()
+    >>> py.tl.tree_neighbors(tdata, n_neighbors=10)
+    >>> py.tl.autocorr(tdata, connect_key="tree_connectivities", method="moran")
     """
     # Setup
     if keys is None:
diff --git a/src/pycea/tl/clades.py b/src/pycea/tl/clades.py
index 121ebee..4c1f08b 100755
--- a/src/pycea/tl/clades.py
+++ b/src/pycea/tl/clades.py
@@ -103,6 +103,20 @@
 ) -> None | pd.DataFrame:
     """Marks clades in a tree.

+    A clade is defined by an **ancestral node**; all nodes and
+    edges in the ancestral node's descendant subtree inherit the same clade label.
+    You can specify clades in two ways:
+
+    * **Depth-based:**
+      Given a ``depth`` threshold, all nodes that are extant
+      at that depth are considered ancestral nodes. Each ancestral node and its descendants
+      are assigned a unique clade label.
+
+    * **Explicit mapping:**
+      When ``clades`` (a dictionary mapping nodes to clade labels) is provided,
+      those nodes are considered ancestral nodes. Each such node and its descendants are assigned
+      the corresponding clade label.
+ Parameters ---------- tdata @@ -134,6 +148,17 @@ def clades( - Clade assignment for each observation. * `tdata.obst[tree].nodes[key_added]` : `Object` - Clade assignment for each node. + + Examples + -------- + Mark clades at specified depth + + >>> tdata = pycea.datasets.koblan25() + >>> pycea.tl.clades(tdata, depth=4, depth_key="time") + + Highlight descendants of 'node6' + + >>> pycea.tl.clades(tdata, clades={"node6": "node6_descendants"}, key_added="highlight") """ # Setup tree_keys = tree diff --git a/src/pycea/tl/distance.py b/src/pycea/tl/distance.py index 2691981..7c1a5e9 100755 --- a/src/pycea/tl/distance.py +++ b/src/pycea/tl/distance.py @@ -98,6 +98,11 @@ def distance( ) -> None | np.ndarray | sp.sparse.csr_matrix: """Computes distances between observations. + Supports full pairwise distances, distances from a single observation to all others, + distances within a specified subset, or distances for an explicit list of pairs. + Distances can be computed using a named metric (e.g. ``"euclidean"``, ``"cosine"``, + ``"manhattan"``) or a user-supplied callable. + Parameters ---------- tdata @@ -143,6 +148,28 @@ def distance( - Connectivity between observations. * `tdata.obs['{key_added}_distances']` : :class:`Series ` (dtype `float`) if `obs` is a string. - Distance from specified observation to others. + + Notes + ----- + * When both ``connect_key`` and ``sample_n`` are provided, sampling is performed + **within** the connected pairs induced by the connectivity. + * If you pass a callable metric, it must accept two 1D vectors and return a scalar. + + Examples + -------- + Calculate pairwise spatial distance between all observations: + + >>> tdata = py.datasets.koblan25() + >>> py.tl.distance(tdata, key="spatial") + + Calculate spatial distance between closely related observations: + + >>> py.tl.tree_neighbors(tdata, n_neighbors=10, depth_key="time") + >>> py.tl.distance(tdata, key="spatial", connect_key="tree_connectivities") + + Calculate distance from a single observation to all others: + + >>> py.tl.distance(tdata, key="spatial", obs="M3-1-19") """ # Setup _set_random_state(random_state) @@ -270,6 +297,12 @@ def compare_distance( ) -> pd.DataFrame: """Get pairwise observation distances. + This function gathers distances between the same observation pairs from one or + more entries in ``tdata.obsp`` and returns them side-by-side in a tidy + :class:`pandas.DataFrame`. Only pairs for which **all** requested distance + matrices have defined values are included. Optionally, comparisons can be + restricted within groups and/or randomly subsampled. + Parameters ---------- tdata @@ -294,6 +327,15 @@ def compare_distance( * `obs1` and `obs2` are the observation names. * `{dist_key}_distances` are the distances between the observations. + Examples + -------- + Compare spatial and tree distances for 1000 random pairs of observations: + + >>> tdata = py.datasets.koblan25() + >>> py.tl.distance(tdata, key="spatial", sample_n=1000) + >>> py.tl.tree_distance(tdata, key="tree", connect_key="spatial_connectivities") + >>> df = py.tl.compare_distance(tdata, dist_keys=["spatial_distances", "tree_distances"]) + """ # Setup _set_random_state(random_state) diff --git a/src/pycea/tl/n_extant.py b/src/pycea/tl/n_extant.py index ebdb8e1..4c87816 100644 --- a/src/pycea/tl/n_extant.py +++ b/src/pycea/tl/n_extant.py @@ -23,8 +23,6 @@ def n_extant( slot: Literal["obst", "obs"] = "obst", copy: Literal[True, False] = True, ) -> pd.DataFrame: ... 
- - @overload def n_extant( tdata: td.TreeData, @@ -38,8 +36,6 @@ def n_extant( slot: Literal["obst", "obs"] = "obst", copy: Literal[True, False] = False, ) -> None: ... - - def n_extant( tdata: td.TreeData, depth_key: str = "depth", @@ -55,8 +51,14 @@ def n_extant( """ Counts extant branches over time. - Computes the number of extant branches at each depth bin in the tree, - optionally stratified by node grouping variable(s). + This function computes the number of lineages that are alive (extant) + at each depth bin. Lineages are counted by sweeping along each edge + (parent → child): a lineage is present from the parent’s depth (inclusive) up to + the child’s depth (exclusive), unless ``extend_branches=True`` and the child is a + leaf, in which case the lineage extends to the maximum depth bin. + + Grouping is supported by one or more node attributes (``groupby``); counts are + computed separately per unique group combination. Parameters ---------- @@ -91,6 +93,13 @@ def n_extant( * `tdata.uns[key_added]` : :class:`DataFrame ` with columns depth_key, `n_extant`, grouping variables, and `tree`. + + Examples + -------- + Calculate number cells in each clade over time: + + >>> tdata = py.datasets.koblan25() + >>> py.tl.n_extant(tdata, depth_key="time", groupby="clade", bins=10) """ # Validate tree keys and get trees tree_keys = tree @@ -165,4 +174,3 @@ def n_extant( if copy: return extant return None - return None diff --git a/src/pycea/tl/neighbor_distance.py b/src/pycea/tl/neighbor_distance.py index 2022396..1659768 100755 --- a/src/pycea/tl/neighbor_distance.py +++ b/src/pycea/tl/neighbor_distance.py @@ -1,6 +1,5 @@ from __future__ import annotations -from collections.abc import Callable from typing import Literal, overload import numpy as np @@ -8,20 +7,10 @@ import scipy as sp import treedata as td +from ._aggregators import _Aggregator, _AggregatorFn, _get_aggregator from ._utils import _csr_data_mask, _format_keys -def _get_agg_func(method): - """Returns aggregation function.""" - agg_funcs = {"mean": np.mean, "max": np.max, "min": np.min, "median": np.median} - if method in agg_funcs: - return agg_funcs[method] - elif callable(method): - return method - else: - raise ValueError(f"Invalid method: {method}") - - def _assert_distance_specified(dist, mask): """Asserts that distance is specified for where connected""" if isinstance(dist, sp.sparse.csr_matrix): @@ -36,7 +25,7 @@ def neighbor_distance( tdata: td.TreeData, connect_key: str | None = None, dist_key: str | None = None, - method: str | Callable = "mean", + method: _AggregatorFn | _Aggregator = "mean", key_added: str = "neighbor_distances", copy: Literal[True, False] = True, ) -> pd.Series: ... @@ -45,7 +34,7 @@ def neighbor_distance( tdata: td.TreeData, connect_key: str | None = None, dist_key: str | None = None, - method: str | Callable = "mean", + method: _AggregatorFn | _Aggregator = "mean", key_added: str = "neighbor_distances", copy: Literal[True, False] = False, ) -> None: ... @@ -53,11 +42,24 @@ def neighbor_distance( tdata: td.TreeData, connect_key: str | None = None, dist_key: str | None = None, - method: str | Callable = "mean", + method: _AggregatorFn | _Aggregator = "mean", key_added: str = "neighbor_distances", copy: Literal[True, False] = False, ) -> None | pd.Series: - """Aggregates distance to neighboring observations. + r"""Aggregates distance to neighboring observations. 
+ + For each observation :math:`i`, this function collects the distances + :math:`\\{ D_{ij} : j \in \mathcal{N}(i) \\}` to its neighbors (as defined by a + binary/weighted connectivity in ``tdata.obsp[connect_key]``) and reduces them to a + single value via an aggregation function :math:`g`: + + .. math:: + + d_i = g\big( \{ D_{ij} : j \in \mathcal{N}(i) \} \big) + + The aggregator :math:`g` can be the mean, median, min, max, or a user-supplied + callable. If an observation has no neighbors, the result for that observation is + ``NaN``. Parameters ---------- @@ -68,14 +70,7 @@ def neighbor_distance( dist_key `tdata.obsp` distances key specifying distances between observations. method - Method to calculate neighbor distances: - - * 'mean' : The mean distance to neighboring observations. - * 'median' : The median distance to neighboring observations. - * 'min' : The minimum distance to neighboring observations. - * 'max' : The maximum distance to neighboring observations. - * Any function that takes a list of values and returns a single value. - + Aggregation function used to calculate neighbor distances. key_added `tdata.obs` key to store neighbor distances. copy @@ -89,6 +84,16 @@ def neighbor_distance( * `tdata.obs[key_added]` : :class:`Series ` (dtype `float`) - Neighbor distances for each observation. + + Examples + -------- + Calculate mean spatial distance to tree neighbors: + + >>> tdata = py.datasets.koblan25() + >>> py.tl.tree_neighbors(tdata, n_neighbors=5, depth_key="time") + >>> py.tl.distance(tdata, key="spatial", connect_key="tree_connectivities") + >>> py.tl.neighbor_distance(tdata, dist_key="spatial_distances", connect_key="tree_connectivities", method="mean") + """ # Setup if connect_key is None: @@ -99,18 +104,21 @@ def neighbor_distance( _format_keys(connect_key, "connectivities") if dist_key not in tdata.obsp.keys(): _format_keys(dist_key, "distances") - agg_func = _get_agg_func(method) + agg_func = _get_aggregator(method) mask = tdata.obsp[connect_key] > 0 dist = tdata.obsp[dist_key] _assert_distance_specified(dist, mask) # Calculate neighbor distances agg_dist = [] for i in range(dist.shape[0]): # type: ignore - if isinstance(mask, sp.sparse.csr_matrix): + if sp.sparse.issparse(mask): indices = mask[i].indices else: indices = np.nonzero(mask[i])[0] - row_dist = dist[i, indices] + if sp.sparse.issparse(dist): + row_dist = dist[i, indices].toarray().ravel() + else: + row_dist = dist[i, indices] if row_dist.size > 0: agg_dist.append(agg_func(row_dist)) else: diff --git a/src/pycea/tl/partition_test.py b/src/pycea/tl/partition_test.py index 77b6500..64972ef 100644 --- a/src/pycea/tl/partition_test.py +++ b/src/pycea/tl/partition_test.py @@ -128,16 +128,16 @@ def partition_test( descended from each internal node (group1) to the set of leaves defined by the `comparison` parameter (group2): - • **comparison='siblings':** + * ``comparison='siblings':`` Compare to the descendants of sibling nodes. When there is more than one sibling (i.e., a non-binary split), each child node is compared individually to the pooled set of all other siblings. - • **comparison='rest':** + * ``comparison='rest':`` Compare to all other leaves in the tree not descended from the given node. 
The `test` parameter defines how the two groups are compared: - • **test='permutation':** + * ``test='permutation':`` a two-sided permutation test is performed by repeatedly shuffling the pooled rows (group1 + group2), applying the ``aggregate`` function, and then recomputing the split statistic using the `metric` function. @@ -146,13 +146,20 @@ def partition_test( ``comb(n_left + n_right, n_left)``). The p-value is computed with standard +1 smoothing: - pval = ( #{ \|perm_stat\| >= \|observed\| } + 1 ) / ( permutations_performed + 1 ) + .. math:: - • **test='test-t':** + p_\text{val} = + \frac{ + \#\{\,|\mathrm{perm\_stat}| \ge |\mathrm{observed}|\,\} + 1 + }{ + N_\text{perm} + 1 + } + + * ``test='test-t':`` a two-sided t-test is performed for each group. Note that for small numbers of leaves the p-value of this t-test can be unreliable. - • **test=None:** + * ``test=None:`` no statistical test is performed; only the partition statistic is computed. P-values are calculated as long as both groups have at least ``min_group_leaves`` leaves; @@ -217,6 +224,13 @@ def partition_test( - P-value for the partition test at that edge (if performed). * `tdata.obst[tree].edges[f"{key_added}_metric"]` : `float` - Metric value for the partition at that edge (only if test="permutation"). + + Examples + -------- + Identify clades with the highest expression of "elt-2": + + >>> tdata = py.datasets.packer19() + >>> py.tl.partition_test(tdata, keys=["elt-2"], test="t-test", comparison="rest") """ _set_random_state(random_state) if isinstance(keys, str): diff --git a/src/pycea/tl/sort.py b/src/pycea/tl/sort.py index f3eb06a..e0a58ab 100755 --- a/src/pycea/tl/sort.py +++ b/src/pycea/tl/sort.py @@ -18,9 +18,7 @@ def _sort_tree(tree: nx.DiGraph, key: str, reverse: bool = False) -> nx.DiGraph: # children. children = list(tree.successors(node)) try: - sorted_children = sorted( - children, key=lambda x: tree.nodes[x][key], reverse=reverse - ) + sorted_children = sorted(children, key=lambda x: tree.nodes[x][key], reverse=reverse) except KeyError as err: raise KeyError( f"Node {next(iter(children))} does not have a {key} attribute.", @@ -28,22 +26,25 @@ def _sort_tree(tree: nx.DiGraph, key: str, reverse: bool = False) -> nx.DiGraph: ) from err # Capture edge attributes prior to removal - edge_data = { - child: tree.get_edge_data(node, child, default={}) for child in children - } + edge_data = {child: tree.get_edge_data(node, child, default={}) for child in children} # Remove existing edges and re-add them in the sorted order with # their associated metadata. tree.remove_edges_from((node, child) for child in children) - tree.add_edges_from( - (node, child, edge_data[child]) for child in sorted_children - ) + tree.add_edges_from((node, child, edge_data[child]) for child in sorted_children) return tree def sort(tdata: td.TreeData, key: str, reverse: bool = False, tree: str | Sequence[str] | None = None) -> None: """Reorders branches based on a node attribute. + For every internal node with multiple children, reorders outgoing edges (child branches) + based on a given node attribute. The order is applied in-place preserving all node and edge metadata. + + Sorting allows for consistent or meaningful ordering of descendants in + tree visualizations, e.g., ordering by inferred ancestral state values + or other numeric or categorical metrics. 
+
     Parameters
     ----------
     tdata
@@ -58,6 +59,16 @@ def sort(tdata: td.TreeData, key: str, reverse: bool = False, tree: str | Sequen
     Returns
     -------
     Returns `None` and does not set any fields.
+
+    Examples
+    --------
+    Sort branches by the number of descendant leaves:
+
+    >>> tdata = py.datasets.yang22()
+    >>> tdata.obs["n"] = 1
+    >>> py.tl.ancestral_states(tdata, keys="n", method="sum")
+    >>> py.tl.sort(tdata, key="n")
+
     """
     trees = get_trees(tdata, tree)
     for name, t in trees.items():
diff --git a/src/pycea/tl/tree_distance.py b/src/pycea/tl/tree_distance.py
index 4e2201d..dd2c3f8 100755
--- a/src/pycea/tl/tree_distance.py
+++ b/src/pycea/tl/tree_distance.py
@@ -143,7 +143,28 @@
     tree: str | Sequence[Any] | None = None,
     copy: Literal[True, False] = False,
 ) -> None | sp.sparse.csr_matrix | np.ndarray:
-    """Computes tree distances between observations.
+    r"""Computes tree distances between observations.
+
+    This function calculates distances between observations (typically tree leaves)
+    based on their positions and depths in the tree. It supports *lowest common ancestor (lca)*
+    and *path* distances.
+
+    Given two nodes :math:`i` and :math:`j` in a rooted tree, with depths
+    :math:`d_i` and :math:`d_j`, and with their lowest common ancestor having
+    depth :math:`d_{LCA(i,j)}`:
+
+    .. math::
+
+        D_{ij}^{lca} = d_{LCA(i,j)}
+
+    .. math::

+        D_{ij}^{path} = d_i + d_j - 2 d_{LCA(i,j)}
+
+    :math:`D_{ij}^{lca}` represents the depth of the nodes' shared ancestor
+    (larger values indicate greater shared ancestry). In contrast, :math:`D_{ij}^{path}`
+    measures the distance along the tree between two nodes (smaller values indicate
+    closer proximity).

     Parameters
     ----------
@@ -191,6 +212,17 @@
         - Connectivity between observations.
     * `tdata.obs['{key_added}_distances']` : :class:`Series ` (dtype `float`) if `obs` is a string.
         - Distance from specified observation to others.
+
+    Examples
+    --------
+    Compute full pairwise path distances for tree leaves:
+
+    >>> tdata = py.datasets.koblan25()
+    >>> py.tl.tree_distance(tdata, metric="path")
+
+    Sample 1000 random LCA distances using node 'time' as depth:
+
+    >>> py.tl.tree_distance(tdata, metric="lca", sample_n=1000, depth_key="time")
     """
     # Setup
     _set_random_state(random_state)
diff --git a/src/pycea/tl/tree_neighbors.py b/src/pycea/tl/tree_neighbors.py
index 6d98358..830dd64 100755
--- a/src/pycea/tl/tree_neighbors.py
+++ b/src/pycea/tl/tree_neighbors.py
@@ -120,6 +120,17 @@
 ) -> None | tuple[sp.sparse.csr_matrix, sp.sparse.csr_matrix]:
     """Identifies neighbors in the tree.

+    For each leaf, this function identifies neighbors according to a chosen
+    tree distance `metric` and either:
+
+    * the top-``n_neighbors`` closest leaves (ties broken at random), or
+
+    * all leaves within a distance threshold ``max_dist``.
+
+    Results are stored as sparse connectivities and distances, or returned when
+    ``copy=True``. You can restrict the operation to a subset of leaves via
+    ``obs`` and/or to specific trees via ``tree``.
+
     Parameters
     ----------
     tdata
@@ -165,6 +176,13 @@
         - Set of neighbors for each observation.
     * `tdata.obs['{key_added}_neighbors']` : :class:`Series ` (dtype `bool`) if `obs` is a string.
         - Set of neighbors for specified observation.
+ + Examples + -------- + Identify the 5 closest neighbors for each leaf based on path distance: + + >>> tdata = py.datasets.koblan25() + >>> py.tl.tree_neighbors(tdata, n_neighbors=5, depth_key="time") """ # Setup _set_random_state(random_state) diff --git a/src/pycea/utils.py b/src/pycea/utils.py index f42ce34..fe4bb42 100755 --- a/src/pycea/utils.py +++ b/src/pycea/utils.py @@ -1,7 +1,7 @@ from __future__ import annotations import random -from collections.abc import Mapping, Sequence, Hashable +from collections.abc import Hashable, Mapping, Sequence from typing import Any, cast import networkx as nx @@ -203,7 +203,7 @@ def get_keyed_obs_data( if slot is None: raise ValueError( f"Key {key!r} is invalid! You must pass a valid observation annotation. " - f"One of obs_keys, var_names, obsm_keys, obsp_keys." + f"One of obs.keys(), var_names, obsm.keys(), obsp.keys()." ) else: raise ValueError( @@ -308,6 +308,7 @@ def _check_tree_overlap( else: raise ValueError("Tree keys must be a string, list of strings, or None.") + def _get_descendant_leaves(t: nx.DiGraph) -> dict[Hashable, list]: """ Return a dict mapping each node -> list of leaf descendants (including itself if it is a leaf). @@ -321,9 +322,9 @@ def _get_descendant_leaves(t: nx.DiGraph) -> dict[Hashable, list]: # One topological sort, then a single bottom-up sweep for u in reversed(list(nx.topological_sort(t))): children = list(t.successors(u)) - if not children: # leaf + if not children: # leaf leaves_sets[u] = {u} - else: # union of children's leaf sets + else: # union of children's leaf sets s = set() for v in children: s |= leaves_sets[v] @@ -331,4 +332,3 @@ def _get_descendant_leaves(t: nx.DiGraph) -> dict[Hashable, list]: # Convert sets to lists (unsorted to avoid type-comparison issues) return {u: list(s) for u, s in leaves_sets.items()} -
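A minimal end-to-end sketch of how the APIs touched in this diff fit together, assembled from the doctest calls added above (the `koblan25` loader, `tree_neighbors`, `distance`, and `neighbor_distance` with the newly supported `"max"` aggregator). The `import pycea as py` alias and the final inspection line are assumptions for illustration, not part of the patch:

```python
import pycea as py  # assumed alias; the doctests above use the `py.` prefix

# Example dataset used throughout the new doctests (assumed to be available).
tdata = py.datasets.koblan25()

# Find each leaf's 5 nearest tree neighbors, using node "time" as depth.
py.tl.tree_neighbors(tdata, n_neighbors=5, depth_key="time")

# Compute spatial distances only between those neighboring pairs.
py.tl.distance(tdata, key="spatial", connect_key="tree_connectivities")

# Aggregate each cell's neighbor distances with "max", one of the
# aggregators newly added to `_Aggregator` in this diff.
py.tl.neighbor_distance(
    tdata,
    connect_key="tree_connectivities",
    dist_key="spatial_distances",
    method="max",
)

# Results are stored under the default key (hypothetical inspection step).
print(tdata.obs["neighbor_distances"].head())
```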