Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,31 @@ and this project adheres to [Semantic Versioning][].
[keep a changelog]: https://keepachangelog.com/en/1.0.0/
[semantic versioning]: https://semver.org/spec/v2.0.0.html

## [Unreleased]
## [Unreleased]

### Added

### Changed

### Fixed

## [0.1.0] - 2025-09-19

### Added

- `pycea.get` module for data retrieval (#32)
- `pycea.tl.n_extant` and `pycea.pl.n_extant` for calculating and plotting the number of extant lineages over time (#33)
- `pycea.tl.fitness` for estimating fitness of nodes in a tree (#35)

### Changed

- Only require `tree` parameter to be specified when trees in `tdata` actually overlap (#37)

### Fixed

- Sorting now preserves edge metadata (#31)

## [0.0.1]

### Added

Expand Down
10 changes: 5 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ requires = [ "hatchling" ]

[project]
name = "pycea-lineage"
version = "0.0.1"
version = "0.1.0"
description = "Scverse lineage tracing toolkit"
readme = "README.md"
license = { file = "LICENSE" }
Expand Down Expand Up @@ -34,7 +34,7 @@ dependencies = [
"scikit-learn",
"scipy",
"session-info",
"treedata>=0.2",
"treedata>=0.2.2",
]
optional-dependencies.dev = [
"pre-commit",
Expand Down Expand Up @@ -78,6 +78,9 @@ scripts.clean = "git clean -fdX -- {args:docs}"
[tool.hatch.envs.hatch-test]
features = [ "test" ]

[tool.hatch.build.targets.wheel]
packages = [ "src/pycea" ]

[tool.ruff]
line-length = 120
src = [ "src" ]
Expand Down Expand Up @@ -145,6 +148,3 @@ skip = [
"docs/references.md",
"docs/notebooks/example.ipynb",
]

[tool.hatch.build.targets.wheel]
packages = ["src/pycea"]
17 changes: 0 additions & 17 deletions src/pycea/pl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,23 +347,6 @@ def _series_to_rgb_array(
return rgb_array


def _check_tree_overlap(
    tdata: td.TreeData,
    tree_keys: str | Sequence[str] | None = None,
) -> None:
    """Validate a tree selection against ``tdata.allow_overlap``.

    A single key passed as a string is always accepted. When no key is given,
    the selection is rejected only if ``tdata.allow_overlap`` is truthy and
    ``tdata.obst`` holds more than one tree. Any other sequence of keys is
    rejected whenever ``tdata.allow_overlap`` is truthy.

    Raises
    ------
    ValueError
        If the selection is rejected by the rules above, or if ``tree_keys``
        is not a string, a sequence, or None.
    """
    # One named tree can never be ambiguous.
    if isinstance(tree_keys, str):
        return
    # Nothing requested: only a problem when several trees are stored.
    if tree_keys is None:
        if tdata.allow_overlap and len(tdata.obst.keys()) > 1:
            raise ValueError("Must specify a tree when tdata.allow_overlap is True.")
        return
    if not isinstance(tree_keys, Sequence):
        raise ValueError("Tree keys must be a string, list of strings, or None.")
    if tdata.allow_overlap:
        raise ValueError("Cannot request multiple trees when tdata.allow_overlap is True.")


def _get_colors(
tdata: td.TreeData,
key: str,
Expand Down
3 changes: 1 addition & 2 deletions src/pycea/pl/plot_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@
from matplotlib.axes import Axes
from matplotlib.collections import LineCollection

from pycea.utils import get_keyed_edge_data, get_keyed_node_data, get_keyed_obs_data, get_trees
from pycea.utils import _check_tree_overlap, get_keyed_edge_data, get_keyed_node_data, get_keyed_obs_data, get_trees

from ._docs import _doc_params, doc_common_plot_args
from ._legend import _categorical_legend, _cbar_legend, _render_legends
from ._utils import (
_check_tree_overlap,
_get_categorical_colors,
_get_categorical_markers,
_get_colors,
Expand Down
17 changes: 1 addition & 16 deletions src/pycea/pp/setup_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pandas as pd
import treedata as td

from pycea.utils import get_keyed_leaf_data, get_keyed_node_data, get_root, get_trees
from pycea.utils import _check_tree_overlap, get_keyed_leaf_data, get_keyed_node_data, get_root, get_trees


def _add_depth(tree, depth_key):
Expand All @@ -17,21 +17,6 @@ def _add_depth(tree, depth_key):
nx.set_node_attributes(tree, depths, depth_key)


def _check_tree_overlap(tdata, tree_keys):
    """Reject ambiguous tree selections when ``tdata.allow_overlap`` is set.

    ``tree_keys`` may be None (no tree requested), a single key as a string,
    or a sequence of keys. A ``ValueError`` is raised when no tree is requested
    while ``tdata.allow_overlap`` is truthy and ``tdata.obst`` holds more than
    one tree, when a non-string sequence is requested while
    ``tdata.allow_overlap`` is truthy, or when ``tree_keys`` has any other type.
    """
    # Work out which complaint (if any) applies, then raise in one place.
    problem = None
    if tree_keys is None:
        stored_keys = tdata.obst.keys()
        if tdata.allow_overlap and len(stored_keys) > 1:
            problem = "Must specify a tree when tdata.allow_overlap is True."
    elif not isinstance(tree_keys, str):
        if not isinstance(tree_keys, Sequence):
            problem = "Tree keys must be a string, list of strings, or None."
        elif tdata.allow_overlap:
            problem = "Cannot request multiple trees when tdata.allow_overlap is True."
    if problem is not None:
        raise ValueError(problem)


@overload
def add_depth(
tdata: td.TreeData,
Expand Down
15 changes: 0 additions & 15 deletions src/pycea/tl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,21 +97,6 @@ def _assert_param_xor(params):
return None


def _check_tree_overlap(tdata, tree_keys):
    """Ensure the requested trees are unambiguous under ``tdata.allow_overlap``.

    Accepts None, a single string key, or a sequence of keys. Raises
    ``ValueError`` if ``tdata.allow_overlap`` is truthy and either no tree was
    requested while ``tdata.obst`` stores more than one, or a non-string
    sequence was requested. Also raises ``ValueError`` for any other
    ``tree_keys`` type.
    """
    if isinstance(tree_keys, str):
        # A single string key needs no further checks.
        return None
    if tree_keys is None:
        every_key = tdata.obst.keys()
        if not (tdata.allow_overlap and len(every_key) > 1):
            return None
        raise ValueError("Must specify a tree when tdata.allow_overlap is True.")
    if not isinstance(tree_keys, Sequence):
        raise ValueError("Tree keys must be a string, list of strings, or None.")
    if tdata.allow_overlap:
        raise ValueError("Cannot request multiple trees when tdata.allow_overlap is True.")
    return None


def _remove_attribute(tree, key, nodes=True, edges=True):
"""Remove node attribute from tree if it exists"""
if nodes:
Expand Down
6 changes: 2 additions & 4 deletions src/pycea/tl/ancestral_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@
import pandas as pd
import treedata as td

from pycea.utils import get_keyed_node_data, get_keyed_obs_data, get_root, get_trees

from ._utils import _check_tree_overlap
from pycea.utils import _check_tree_overlap, get_keyed_node_data, get_keyed_obs_data, get_root, get_trees


def _most_common(arr: np.ndarray) -> Any:
Expand Down Expand Up @@ -337,5 +335,5 @@ def ancestral_states(
nx.set_node_attributes(t, data[key].to_dict(), key_added)
_ancestral_states(t, key_added, method, costs, missing_state, default_state)
if copy:
return get_keyed_node_data(tdata, keys_added, tree_keys)
return get_keyed_node_data(tdata, keys_added, tree_keys, slot="obst")
return get_keyed_node_data(tdata, keys_added, tree_keys)
19 changes: 15 additions & 4 deletions src/pycea/tl/clades.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,16 @@
import pandas as pd
import treedata as td

from pycea.utils import check_tree_has_key, get_keyed_leaf_data, get_root, get_trees
from pycea.utils import (
_check_tree_overlap,
check_tree_has_key,
get_keyed_leaf_data,
get_keyed_node_data,
get_root,
get_trees,
)

from ._utils import _check_tree_overlap, _remove_attribute
from ._utils import _remove_attribute


def _nodes_at_depth(tree, parent, nodes, depth, depth_key):
Expand Down Expand Up @@ -145,8 +152,12 @@ def clades(
tree_lcas["tree"] = key
lcas.append(tree_lcas)
# Update TreeData and return
leaf_to_clade = get_keyed_leaf_data(tdata, key_added, tree_keys)
tdata.obs[key_added] = tdata.obs.index.map(leaf_to_clade[key_added])
if tdata.alignment == "leaves":
node_to_clade = get_keyed_leaf_data(tdata, key_added, tree_keys)
else:
node_to_clade = get_keyed_node_data(tdata, key_added, tree_keys, slot="obst")
node_to_clade.index = node_to_clade.index.droplevel(0)
tdata.obs[key_added] = tdata.obs.index.map(node_to_clade[key_added])
if copy:
return pd.concat(lcas)
return pd.concat(lcas)
12 changes: 10 additions & 2 deletions src/pycea/tl/fitness.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,18 @@
import treedata as td
from scipy.interpolate import interp1d

from pycea.utils import check_tree_has_key, get_keyed_leaf_data, get_keyed_node_data, get_leaves, get_root, get_trees
from pycea.utils import (
_check_tree_overlap,
check_tree_has_key,
get_keyed_leaf_data,
get_keyed_node_data,
get_leaves,
get_root,
get_trees,
)

from ._metrics import _path_distance
from ._utils import _check_tree_overlap, _set_random_state
from ._utils import _set_random_state
from .tree_distance import _tree_distance

non_negativity_cutoff = 1e-20
Expand Down
4 changes: 1 addition & 3 deletions src/pycea/tl/n_extant.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
import pandas as pd
import treedata as td

from pycea.utils import check_tree_has_key, get_keyed_node_data, get_root, get_trees

from ._utils import _check_tree_overlap
from pycea.utils import _check_tree_overlap, check_tree_has_key, get_keyed_node_data, get_root, get_trees


@overload
Expand Down
10 changes: 8 additions & 2 deletions src/pycea/tl/tree_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,18 @@
import scipy as sp
import treedata as td

from pycea.utils import check_tree_has_key, get_obs_to_tree_map, get_root, get_tree_to_obs_map, get_trees
from pycea.utils import (
_check_tree_overlap,
check_tree_has_key,
get_obs_to_tree_map,
get_root,
get_tree_to_obs_map,
get_trees,
)

from ._metrics import _get_tree_metric, _TreeMetric
from ._utils import (
_check_previous_params,
_check_tree_overlap,
_csr_data_mask,
_format_keys,
_set_distances_and_connectivities,
Expand Down
3 changes: 1 addition & 2 deletions src/pycea/tl/tree_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@
import scipy as sp
import treedata as td

from pycea.utils import check_tree_has_key, get_leaves, get_trees
from pycea.utils import _check_tree_overlap, check_tree_has_key, get_leaves, get_trees

from ._metrics import _get_tree_metric, _TreeMetric
from ._utils import (
_assert_param_xor,
_check_previous_params,
_check_tree_overlap,
_csr_data_mask,
_set_distances_and_connectivities,
_set_random_state,
Expand Down
17 changes: 17 additions & 0 deletions src/pycea/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,3 +290,20 @@ def get_obs_to_tree_map(tdata: td.TreeData, tree_keys: set[str] | Sequence[str]
if obs in node_to_tree:
obs_to_tree[obs] = list(node_to_tree[obs] & tree_keys)[0]
return obs_to_tree


def _check_tree_overlap(
    tdata: td.TreeData,
    tree_keys: str | Sequence[str] | None = None,
) -> None:
    """Check that a single tree is requested when ``tdata.has_overlap`` is True.

    Parameters
    ----------
    tdata
        Object whose ``has_overlap`` attribute is consulted.
    tree_keys
        Requested tree selection: a single key, a sequence of keys, or None
        when no tree was specified.

    Raises
    ------
    ValueError
        If ``tdata.has_overlap`` is truthy and ``tree_keys`` is None or a
        non-string sequence (regardless of its length), or if ``tree_keys`` is
        not a string, a sequence, or None.
    """
    if tree_keys is None:
        if tdata.has_overlap:
            raise ValueError("Must specify a tree when tdata.has_overlap is True.")
    elif isinstance(tree_keys, str):
        # A single named tree is always unambiguous; checked before Sequence
        # because str is itself a Sequence.
        pass
    elif isinstance(tree_keys, Sequence):
        if tdata.has_overlap:
            raise ValueError("Cannot request multiple trees when tdata.has_overlap is True.")
    else:
        raise ValueError("Tree keys must be a string, list of strings, or None.")
20 changes: 20 additions & 0 deletions tests/test_ancestral_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,19 @@ def tdata():
yield tdata


@pytest.fixture
def nodes_tdata():
    """Node-aligned TreeData with one five-node tree and 2-D "spatial" coordinates."""
    node_names = ["root", "B", "C", "D", "E"]
    edges = [("root", "B"), ("root", "C"), ("C", "D"), ("C", "E")]
    # One coordinate row per observation; the first row is all-NaN
    # (presumably matching "root" via obs order — confirm against TreeData).
    coords = np.array([[np.nan, np.nan], [0, 1], [1, 1], [2, 1], [4, 4]])
    yield td.TreeData(
        obs=pd.DataFrame(index=node_names),
        obst={"tree": nx.DiGraph(edges)},
        obsm={"spatial": coords},  # type: ignore
        alignment="nodes",
    )


def test_ancestral_states(tdata):
# Mean
states = ancestral_states(tdata, "value", method="mean", copy=True)
Expand Down Expand Up @@ -100,6 +113,13 @@ def test_ancestral_states_sankoff(tdata):
assert tdata.obst["tree1"].nodes["root"]["characters"] == ["2", "1"]


def test_ancestral_states_nodes_tdata(nodes_tdata):
    """Mean reconstruction on node-aligned data yields a "spatial" value for the root."""
    states = ancestral_states(nodes_tdata, "spatial", method="mean", copy=True)
    # The attribute must have been written onto the tree itself, not only returned.
    assert "spatial" in nodes_tdata.obst["tree"].nodes["root"]
    assert states.loc[("tree", "root"), "spatial"] == [2.0, 2.0]


def test_ancestral_states_invalid(tdata):
with pytest.raises(ValueError):
ancestral_states(tdata, "characters", method="sankoff")
Expand Down
12 changes: 12 additions & 0 deletions tests/test_clades.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import networkx as nx
import numpy as np
import pandas as pd
import pytest
import treedata as td
Expand All @@ -20,6 +21,12 @@ def tdata(tree):
yield tdata


@pytest.fixture
def nodes_tdata(tree):
    """Node-aligned TreeData that stores the ``tree`` fixture under the key "tree"."""
    obs_frame = pd.DataFrame(index=["A", "B", "C", "D", "E"])
    yield td.TreeData(
        obs=obs_frame,
        obst={"tree": tree},
        alignment="nodes",
    )


def test_nodes_at_depth(tree):
assert _nodes_at_depth(tree, "A", [], 0, "depth") == ["A"]
assert _nodes_at_depth(tree, "A", [], 1, "depth") == ["B", "C"]
Expand Down Expand Up @@ -77,6 +84,11 @@ def test_clades_multiple_trees():
assert pd.isna(tdata.obs.loc["B", "test"])


def test_clades_nodes_tdata(nodes_tdata):
    """Clade labels are written to ``obs`` for node-aligned data; the first entry stays unlabeled."""
    clades(nodes_tdata, depth=1)
    labels = nodes_tdata.obs["clade"].tolist()
    # NaN never equals itself; comparing whole lists with == would only match
    # through an object-identity shortcut, so the missing entry is checked on its own.
    assert np.isnan(labels[0])
    assert labels[1:] == ["0", "1", "1", "1"]


def test_clades_dtype(tdata):
clades(tdata, depth=0, dtype=int)
assert tdata.obs["clade"].dtype == int
Expand Down
Loading