Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,6 @@ environment.yml

# Datasets
/datasets/

# Development
/dev/
1 change: 1 addition & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
tl.tree_distance
tl.tree_neighbors
tl.n_extant
tl.fitness
```

## Plotting
Expand Down
1 change: 1 addition & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

notebooks/getting-started
notebooks/plotting
notebooks/growth-dynamics

api.md
changelog.md
Expand Down
679 changes: 679 additions & 0 deletions docs/notebooks/growth-dynamics.ipynb

Large diffs are not rendered by default.

14 changes: 1 addition & 13 deletions docs/notebooks/plotting.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"id": "0694af36",
"metadata": {},
"source": [
"# Tree plotting\n",
"# Plotting trees\n",
"\n",
"Pycea implements an intuitive tree plotting language where complex plots can be built from simple components:\n",
"\n",
Expand Down Expand Up @@ -418,18 +418,6 @@
"display_name": "pycea",
"language": "python",
"name": "pycea"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
}
},
"nbformat": 4,
Expand Down
20 changes: 20 additions & 0 deletions docs/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,23 @@ @article{Richards_2013
keywords = {Cell lineage, Image analysis, Microscopy, Robustness, Stress},
pages = {12--23},
}

@article {Neher_2014,
article_type = {journal},
title = {Predicting evolution from the shape of genealogical trees},
author = {Neher, Richard A and Russell, Colin A and Shraiman, Boris I},
editor = {McVean, Gil},
volume = 3,
year = 2014,
month = {nov},
pub_date = {2014-11-11},
pages = {e03568},
citation = {eLife 2014;3:e03568},
doi = {10.7554/eLife.03568},
url = {https://doi.org/10.7554/eLife.03568},
abstract = {Given a sample of genome sequences from an asexual population, can one predict its evolutionary future? Here we demonstrate that the branching patterns of reconstructed genealogical trees contains information about the relative fitness of the sampled sequences and that this information can be used to predict successful strains. Our approach is based on the assumption that evolution proceeds by accumulation of small effect mutations, does not require species specific input and can be applied to any asexual population under persistent selection pressure. We demonstrate its performance using historical data on seasonal influenza A/H3N2 virus. We predict the progenitor lineage of the upcoming influenza season with near optimal performance in 30\% of cases and make informative predictions in 16 out of 19 years. Beyond providing a tool for prediction, our ability to make informative predictions implies persistent fitness variation among circulating influenza A/H3N2 viruses.},
keywords = {vaccine strain selection, adaptive evolution, population genetics},
journal = {eLife},
issn = {2050-084X},
publisher = {eLife Sciences Publications, Ltd},
}
4 changes: 2 additions & 2 deletions src/pycea/pl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def layout_nodes_and_branches(
children = list(tree.successors(node))
min_lon = min(node_coords[child][1] for child in children)
max_lon = max(node_coords[child][1] for child in children)
node_coords[node] = (tree.nodes[node].get(depth_key), (min_lon + max_lon) / 2)
node_coords[node] = (tree.nodes[node].get(depth_key), (min_lon + max_lon) / 2) # type: ignore
# Get branch coordinates
branch_coords = {}
for parent, child in tree.edges():
Expand Down Expand Up @@ -390,7 +390,7 @@ def _get_colors(
categories = tdata.obs[key].cat.categories
if set(data.unique()).issubset(categories):
data = pd.Series(
pd.Categorical(data, categories=categories),
pd.Categorical(data, categories=categories, ordered=True),
index=data.index,
)
color_map = _get_categorical_colors(tdata, str(key), data, palette)
Expand Down
18 changes: 17 additions & 1 deletion src/pycea/pl/plot_n_extant.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def n_extant(
n_extant_key: str | None = None,
stat: Literal["count", "proportion", "percent"] = "count",
order: Sequence[str] | None = None,
palette: dict[str, str] | None = None,
na_color: str | None = "lightgray",
legend: bool | None = None,
ax: Axes | None = None,
Expand All @@ -49,6 +50,15 @@ def n_extant(
Statistic to compute for the ribbons: 'count', 'fraction', or 'percent'.
order
Order of group categories in the stack.
palette
Colors to use for plotting categorical annotation groups.
The palette can be a valid :class:`~matplotlib.colors.ListedColormap` name
(`'Set2'`, `'tab20'`, …), a :class:`~cycler.Cycler` object, a dict mapping
categories to colors, or a sequence of colors. Colors must be valid to
matplotlib. (see :func:`~matplotlib.colors.is_color_like`).
If `None`, `mpl.rcParams["axes.prop_cycle"]` is used unless the categorical
variable already has colors stored in `tdata.uns["{var}_colors"]`.
If provided, values of `tdata.uns["{var}_colors"]` will be set.
na_color
The color to use for annotations with missing data.
legend
Expand Down Expand Up @@ -87,6 +97,7 @@ def n_extant(
if isinstance(color, Sequence) and not isinstance(color, str):
df["_group"] = df[list(color)].astype(str).agg("_".join, axis=1)
legend_title = "_".join(color)
df["_group"] = df["_group"].astype("category")
elif color is not None:
df["_group"] = df[color]
legend_title = str(color)
Expand Down Expand Up @@ -115,7 +126,12 @@ def n_extant(
ax = cast(Axes, ax)

color_key = legend_title
color_map = _get_categorical_colors(tdata, color_key, df["_group"])
color_map = _get_categorical_colors(
tdata,
color_key,
palette=palette,
data=df["_group"].cat.remove_categories("NA") if df["_group"].dtype.name == "category" else df["_group"],
)
if (na_color is not None) and ("NA" in df["_group"].values):
color_map["NA"] = na_color
legends: list[dict[str, Any]] = []
Expand Down
4 changes: 4 additions & 0 deletions src/pycea/pl/plot_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,8 @@ def nodes(
kwargs.update({"color": color})
elif isinstance(color, str):
color_data = get_keyed_node_data(tdata, color, tree_keys, slot=slot)[color]
if len(color_data) == 0:
raise ValueError(f"Key {color!r} is not present in any node.")
colors, color_legend, n_categories = _get_colors(
tdata, color, color_data, plot_nodes, palette, cmap, vmin, vmax, na_color, marker_type="marker"
)
Expand All @@ -319,6 +321,8 @@ def nodes(
kwargs.update({"s": size})
elif isinstance(size, str):
size_data = get_keyed_node_data(tdata, size, tree_keys, slot=slot)[size]
if len(size_data) == 0:
raise ValueError(f"Key {size!r} is not present in any node.")
marker_sizes, size_legend, n_categories = _get_sizes(
tdata, size, size_data, plot_nodes, sizes, na_value=na_size, marker_type="marker"
)
Expand Down
5 changes: 3 additions & 2 deletions src/pycea/tl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from .ancestral_states import ancestral_states
from .autocorr import autocorr
from .clades import clades
from .distance import compare_distance, distance
from .neighbor_distance import neighbor_distance
from .fitness import fitness
from .n_extant import n_extant
from .neighbor_distance import neighbor_distance
from .sort import sort
from .tree_distance import tree_distance
from .tree_neighbors import tree_neighbors
from .autocorr import autocorr
12 changes: 12 additions & 0 deletions src/pycea/tl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,15 @@ def _check_tree_overlap(tdata, tree_keys):
raise ValueError("Cannot request multiple trees when tdata.allow_overlap is True.")
else:
raise ValueError("Tree keys must be a string, list of strings, or None.")


def _remove_attribute(tree, key, nodes=True, edges=True):
"""Remove node attribute from tree if it exists"""
if nodes:
for node in tree.nodes:
if key in tree.nodes[node]:
del tree.nodes[node][key]
if edges:
for u, v in tree.edges:
if key in tree.edges[u, v]:
del tree.edges[u, v][key]
9 changes: 4 additions & 5 deletions src/pycea/tl/clades.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from pycea.utils import check_tree_has_key, get_keyed_leaf_data, get_root, get_trees

from ._utils import _check_tree_overlap
from ._utils import _check_tree_overlap, _remove_attribute


def _nodes_at_depth(tree, parent, nodes, depth, depth_key):
Expand Down Expand Up @@ -46,10 +46,6 @@ def _clades(tree, depth, depth_key, clades, clade_key, name_generator, update):
pass
else:
raise ValueError("Must specify either clades or depth.")
if not update:
for node in tree.nodes:
if clade_key in tree.nodes[node]:
del tree.nodes[node][clade_key]
for node, clade in clades.items():
# Leaf
if tree.out_degree(node) == 0:
Expand Down Expand Up @@ -142,6 +138,8 @@ def clades(
name_generator = _clade_name_generator(dtype=dtype)
lcas = []
for key, t in trees.items():
if not update:
_remove_attribute(t, key_added)
tree_lcas = _clades(t, depth, depth_key, clades, key_added, name_generator, update)
tree_lcas = pd.DataFrame(tree_lcas.items(), columns=["node", key_added])
tree_lcas["tree"] = key
Expand All @@ -151,3 +149,4 @@ def clades(
tdata.obs[key_added] = tdata.obs.index.map(leaf_to_clade[key_added])
if copy:
return pd.concat(lcas)
return pd.concat(lcas)
15 changes: 14 additions & 1 deletion src/pycea/tl/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,19 @@ def _sample_pairs(pairs: Any, sample_n: int | None, n_obs: int) -> Any:
return pairs


def _pairwise_with_nans(X, metric_fn):
"""Compute pairwise distances with NaNs"""
n = X.shape[0]
distances = np.full((n, n), np.nan)
# rows without any NaN
clean_idx = np.where(~np.isnan(X).any(axis=1))[0]
if clean_idx.size == 0:
return distances # all rows have NaNs
D = metric_fn.pairwise(X[clean_idx], X[clean_idx])
distances[np.ix_(clean_idx, clean_idx)] = D
return distances


@overload
def distance(
tdata: td.TreeData,
Expand Down Expand Up @@ -168,7 +181,7 @@ def distance(
# Distance given indices
elif obs is None or (isinstance(obs, Sequence) and isinstance(obs[0], str)):
if obs is None:
distances = metric_fn.pairwise(X)
distances = _pairwise_with_nans(X, metric_fn) if np.isnan(X).any() else metric_fn.pairwise(X)
else:
idx = [tdata.obs_names.get_loc(o) for o in obs]
distances = metric_fn.pairwise(X[idx]) # type: ignore
Expand Down
Loading
Loading