Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,15 @@ and this project adheres to [Semantic Versioning][].

### Added

- `pycea.tl.partition_test` to test for statistically significant differences between leaf partitions.
- `pycea.tl.partition_test` to test for statistically significant differences between leaf partitions. (#40)

### Changed

- Replaced `tdata.obs_keys()` with `tdata.obs.keys()` to conform with anndata API changes. (#41)

### Fixed

- Replaced `tdata.obs_keys()` with `tdata.obs.keys()` to conform with anndata API changes. (#41)
- Fixed node plotting when `isinstance(nodes,str)`. (#39)

## [0.1.0] - 2025-09-19

Expand Down
22 changes: 21 additions & 1 deletion src/pycea/get/palette.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,17 @@ def palette(
"""
Get color palette for a given key.

This function gets the mapping from category → color for a given key
in ``tdata``. If no customizations are provided, the function will return
a previously stored palette in ``tdata.uns`` if it exists. Otherwise,
a new palette is generated.

Parameters
----------
tdata
The `treedata.TreeData` object.
key
A key from `obs_keys`, `obsm_keys`, or `obsp_keys` to generate a color palette for.
A key from `obs.keys()`, `obsm.keys()`, or `obsp.keys()` to generate a color palette for.
custom
A dictionary mapping specific values to colors (e.g., `{"category1": "red"}`).
cmap
Expand All @@ -94,6 +99,21 @@ def palette(
Returns
-------
palette - Color palette for the given key

Examples
--------
Get character color palette with saturation adjusted by indel probability:

>>> tdata = py.datasets.yang22()
>>> indel_palette = py.get.palette(
... tdata,
... "characters",
... custom={"-": "white", "*": "lightgrey"},
... cmap="gist_rainbow",
... priors=tdata.uns["priors"],
... sort="random",
... )

"""
# Setup
tdata._sanitize()
Expand Down
93 changes: 91 additions & 2 deletions src/pycea/pl/plot_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,13 @@ def branches(
"""\
Plot the branches of a tree.

Plots the branches of one or more trees stored in ``tdata.obst`` as a
:class:`matplotlib.collections.LineCollection`. Branch appearance (`color` and `linewidth`)
can be fixed scalars or set based on edge attributes (continuous or categorical).
Coloring of continuous variables is based on a colormap (`cmap`), while categorical
variables can be colored using a custom `palette` or a default categorical color set.
Polar coordinates are used when `polar` is True.

Parameters
----------
tdata
Expand Down Expand Up @@ -94,6 +101,22 @@ def branches(
Returns
-------
ax - The axes that the plot was drawn on.

Notes
-----
* If ``ax`` is provided the coordinate system must match the ``polar`` setting.
* Continuous color attributes use ``cmap`` with ``vmin``/``vmax`` normalization.

Examples
--------
Plot tree branches:

>>> tdata = py.datasets.packer19()
>>> py.pl.branches(tdata, depth_key="time")

Plot tree branches in polar coordinates, colored by clade:

>>> py.pl.branches(tdata, depth_key="time", polar=True, color="clade")
""" # noqa: D205
# Setup
tdata._sanitize()
Expand Down Expand Up @@ -207,6 +230,11 @@ def nodes(
"""\
Plot the nodes of a tree.

Plot the nodes of one or more trees from ``tdata.obst`` on the current axes using
:func:`matplotlib.pyplot.scatter`. Appearance can be fixed (single color/marker/size) or set based on
node attributes (continuous or categorical). You can plot only leaves, only
internal nodes, all nodes, or an explicit list of node names.

Parameters
----------
tdata
Expand Down Expand Up @@ -249,6 +277,25 @@ def nodes(
Returns
-------
ax - The axes that the plot was drawn on.

Notes
-----
* Must call :func:`pycea.pl.branches` or :func:`pycea.pl.tree` before calling this function.
* Continuous color attributes use ``cmap`` with ``vmin``/``vmax`` normalization.

Examples
--------
Plot internal nodes colored by depth:

>>> tdata = py.datasets.packer19()
>>> py.pl.branches(tdata, depth_key="time")
>>> py.pl.nodes(tdata, nodes="internal", color="time", cmap="plasma")

Color nodes by "elt-2" expression and highlight the "E" node with a star marker:

>>> py.pl.branches(tdata, depth_key="time")
>>> py.pl.nodes(tdata, color="elt-2", nodes="all")
>>> py.pl.nodes(tdata, color="red", nodes="EMS", style="*", size=200, slot="obst")
""" # noqa: D205
# Setup
kwargs = kwargs if kwargs else {}
Expand Down Expand Up @@ -282,6 +329,8 @@ def nodes(
elif nodes == "internal":
plot_nodes = [node for node in all_nodes if node[1] not in attrs["leaves"]]
elif isinstance(nodes, Sequence):
if isinstance(nodes, str):
nodes = [nodes]
if len(attrs["tree_keys"]) > 1 and len(tree_keys) > 1:
raise ValueError("Multiple trees are present. To plot a list of nodes, you must specify the tree.")
plot_nodes = [(tree_keys[0], node) for node in nodes]
Expand Down Expand Up @@ -395,12 +444,17 @@ def annotation(
"""\
Plot leaf annotations for a tree.

Plots one or more leaf annotations (small heatmap-like bars) next to the tree’s
leaves, preserving the leaf order/layout used by :func:`pycea.pl.branches`. Each key can be
continuous (colored via a `colormap`) or categorical (colored via a `palette`). Multiple
keys are stacked horizontally (or radially if your tree plot is polar).

Parameters
----------
tdata
The TreeData object.
keys
One or more `obs_keys`, `var_names`, `obsm_keys`, or `obsp_keys` to plot.
One or more `obs.keys()`, `var_names`, `obsm.keys()`, or `obsp.keys()` to plot.
width
The width of the annotation bar relative to the tree.
gap
Expand Down Expand Up @@ -432,6 +486,25 @@ def annotation(
Returns
-------
ax - The axes that the plot was drawn on.

Notes
-----
* Must call :func:`pycea.pl.branches` or :func:`pycea.pl.tree` before calling this function.
* Continuous color attributes use ``cmap`` with ``vmin``/``vmax`` normalization.

Examples
--------
Plot leaf annotations for "elt-2" and "pal-1" expression:

>>> tdata = py.datasets.packer19(tree = "observed")
>>> py.pl.branches(tdata, depth_key="time")
>>> py.pl.annotation(tdata, keys=["elt-2", "pal-1"])

Plot leaf annotation for spatial distance between leaves:

>>> py.tl.distance(tdata, key = "spatial")
>>> py.pl.branches(tdata, depth_key="time")
>>> py.pl.annotation(tdata, keys="spatial_distances", cmap="magma")
""" # noqa: D205
# Setup
if tree: # TODO: Annotate only the leaves for the given tree
Expand Down Expand Up @@ -598,12 +671,16 @@ def tree(
"""\
Plot a tree with branches, nodes, and annotations.

This function combines :func:`pycea.pl.branches`, :func:`pycea.pl.nodes`, and :func:`pycea.pl.annotation` to enable
plotting a complete tree with branches, nodes, and leaf annotations in a single call. Each component
(branches, nodes, annotations) can be customized independently using the respective parameters.

Parameters
----------
tdata
The TreeData object.
keys
One or more `obs_keys`, `var_names`, `obsm_keys`, or `obsp_keys` annotations.
One or more `obs.keys()`, `var_names`, `obsm.keys()`, or `obsp.keys()` annotations.
nodes
Either "all", "leaves", "internal", or a list of nodes to plot. Defaults to "internal" if node color, style, or size is set.
polar
Expand Down Expand Up @@ -641,6 +718,18 @@ def tree(
Returns
-------
ax - The axes that the plot was drawn on.

Notes
-----
* If ``ax`` is provided the coordinate system must match the ``polar`` setting.
* Continuous color attributes use ``cmap`` with ``vmin``/``vmax`` normalization.

Examples
--------
Plot a tree with nodes and leaves colored by "elt-2" expression:

>>> tdata = py.datasets.packer19()
>>> py.pl.tree(tdata, nodes="all", node_color="elt-2", keys="elt-2", depth_key="time")
""" # noqa: D205
# Setup
branch_legend = legend.get("branch", None) if isinstance(legend, Mapping) else legend
Expand Down
4 changes: 3 additions & 1 deletion src/pycea/tl/_aggregators.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from numpy.typing import NDArray

_AggregatorFn = Callable[[np.ndarray], np.ndarray | float]
_Aggregator = Literal["mean", "median", "sum", "var"]
_Aggregator = Literal["mean", "median", "sum", "min", "max", "var"]


def _reduce(fn, X: NDArray[np.generic]) -> NDArray[np.generic] | float:
Expand All @@ -24,6 +24,8 @@ def _var(X: NDArray[np.generic]) -> NDArray[np.generic] | float:
"mean": lambda X: _reduce(np.mean, X),
"median": lambda X: _reduce(np.median, X),
"sum": lambda X: _reduce(np.sum, X),
"min": lambda X: _reduce(np.min, X),
"max": lambda X: _reduce(np.max, X),
"var": _var,
}

Expand Down
21 changes: 19 additions & 2 deletions src/pycea/tl/ancestral_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,12 +260,17 @@ def ancestral_states(
) -> pd.DataFrame | None:
"""Reconstructs ancestral states for an attribute.

This function reconstructs ancestral (internal node) states for categorical or
continuous attributes defined on tree leaves. Several reconstruction methods
are supported, ranging from simple aggregation rules to the Sankoff and Fitch-Hartigan
algorithms for discrete character data, or a custom aggregation function can be provided.

Parameters
----------
tdata
TreeData object.
keys
One or more `obs_keys`, `var_names`, `obsm_keys`, or `obsp_keys` to reconstruct.
One or more `obs.keys()`, `var_names`, `obsm.keys()`, or `obsp.keys()` to reconstruct.
method
Method to reconstruct ancestral states:

Expand All @@ -279,7 +284,7 @@ def ancestral_states(
default_state
The expected state for the root node.
costs
A pd.DataFrame with the costs of changing states (from rows to columns).
A pd.DataFrame with the costs of changing states (from rows to columns). Only used if method is 'sankoff'.
keys_added
Attribute keys of `tdata.obst[tree].nodes` where ancestral states will be stored. If `None`, `keys` are used.
tree
Expand All @@ -295,6 +300,18 @@ def ancestral_states(

* `tdata.obst[tree].nodes[key_added]` : `float` | `Object` | `List[Object]`
- Inferred ancestral states. List of states if data was an array.

Examples
--------
Infer the expression of Krt20 and Cd74 based on their mean value in descendant cells:

>>> tdata = py.datasets.yang22()
>>> py.tl.ancestral_states(tdata, keys=["Krt20", "Cd74"], method="mean")

Reconstruct ancestral character states using the Fitch-Hartigan algorithm:

>>> py.tl.ancestral_states(tdata, keys="characters", method="fitch_hartigan", missing_state=-1)

"""
if isinstance(keys, str):
keys = [keys]
Expand Down
46 changes: 44 additions & 2 deletions src/pycea/tl/autocorr.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,48 @@ def autocorr(
layer: str | None = None,
copy: Literal[True, False] = False,
) -> pd.DataFrame | None:
"""Calculate autocorrelation statistic.
r"""Calculate autocorrelation statistic.

This function computes autocorrelation for one or more variables using
either **Moran’s I** or **Geary’s C** statistic, based on a specified connectivity
graph between observations.

Mathematically, the two statistics are defined as follows:

.. math::

I =
\frac{
N \sum_{i,j} w_{i,j} (x_i - \bar{x})(x_j - \bar{x})
}{
W \sum_i (x_i - \bar{x})^2
}

C =
\frac{
(N - 1)\sum_{i,j} w_{i,j} (x_i - x_j)^2
}{
2W \sum_i (x_i - \bar{x})^2
}

where:
* :math:`N` is the number of observations,
* :math:`x_i` is the value of observation *i*,
* :math:`\bar{x}` is the mean of all observations,
* :math:`w_{i,j}` is the spatial weight between *i* and *j*, and
* :math:`W = \sum_{i,j} w_{i,j}`.

A Moran’s I value close to 1 indicates strong positive autocorrelation,
while values near 0 suggest randomness. For Geary’s C behaves inversely:
values less than 1 indicate positive autocorrelation, while values
greater than 1 indicate negative autocorrelation.

Parameters
----------
tdata
TreeData object.
keys
One or more `obs_keys`, `var_names`, `obsm_keys`, or `obsp_keys` to calculate autocorrelation for. Defaults to all 'var_names'.
One or more `obs.keys()`, `var_names`, `obsm.keys()`, or `obsp.keys()` to calculate autocorrelation for. Defaults to all 'var_names'.
connect_key
`tdata.obsp` connectivity key specifying set of neighbors for each observation.
method
Expand All @@ -116,6 +150,14 @@ def autocorr(

* `tdata.uns['moranI']` : Above DataFrame for if method is `'moran'`.
* `tdata.uns['gearyC']` : Above DataFrame for if method is `'geary'`.

Examples
--------
Estimate gene expression heritability using Moran's I autocorrelation:

>>> tdata = py.datasets.yang22()
>>> py.tl.tree_neighbors(tdata, n_neighbors=10)
>>> py.tl.autocorr(tdata, connect_key="tree_connectivities", method="moran")
"""
# Setup
if keys is None:
Expand Down
25 changes: 25 additions & 0 deletions src/pycea/tl/clades.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,20 @@ def clades(
) -> None | pd.DataFrame:
"""Marks clades in a tree.

A clade is defined by a **ancestral node**; all nodes and
edges in the ancestral node's descendant subtree inherit the same clade label.
You can specify clades in two ways:

* **Depth-based:**
Given a ``depth`` threshold, all nodes that are extant
at that depth are considered ancestral nodes. Each ancestral node and its descendants
are assigned a unique clade label.

* **Explicit mapping:**
When ``clades``, a dictionary mapping nodes to clade labels is provided,
those nodes are considered ancestral nodes. Each such node and its descendants are assigned
the corresponding clade label.

Parameters
----------
tdata
Expand Down Expand Up @@ -134,6 +148,17 @@ def clades(
- Clade assignment for each observation.
* `tdata.obst[tree].nodes[key_added]` : `Object`
- Clade assignment for each node.

Examples
--------
Mark clades at specified depth

>>> tdata = pycea.datasets.koblan25()
>>> pycea.tl.clades(tdata, depth=4, depth_key="time")

Highlight descendants of 'node6'

>>> pycea.tl.clades(tdata, clades={"node6": "node6_descendants"}, key_added="highlight")
"""
# Setup
tree_keys = tree
Expand Down
Loading
Loading