Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 26 additions & 48 deletions src/mudata/_core/mudata.py
Original file line number Diff line number Diff line change
Expand Up @@ -1312,43 +1312,49 @@ def obs_names_make_unique(self):
obs_names = [obs for a in self.mod.values() for obs in a.obs_names.values]
self._obs.index = obs_names

@property
def obs_names(self) -> pd.Index:
"""Names of variables (alias for `.obs.index`)."""
return self.obs.index

@obs_names.setter
def obs_names(self, names: Sequence[str]):
"""Set the observation names for all the nested AnnData/MuData objects."""
def _set_names(self, attr: str, axis: int, names: Sequence[str]):
if isinstance(names, pd.Index):
if not isinstance(names.name, str | type(None)):
raise ValueError(
f"MuData expects .obs.index.name to be a string or None, "
f"MuData expects .{attr}.index.name to be a string or None, "
f"but you passed a name of type {type(names.name).__name__!r}"
)
else:
names = pd.Index(names)
if not isinstance(names.name, str | type(None)):
names.name = None

mod_obs_sum = np.sum([a.n_obs for a in self.mod.values()])
if mod_obs_sum != self.n_obs:
self.update_obs()
mod_shape_sum = np.sum([a.shape[axis] for a in self.mod.values()])
if mod_shape_sum != self.shape[axis]:
self._update_attr(attr, axis=1 - axis)

if len(names) != self.n_obs:
if len(names) != self.shape[axis]:
raise ValueError(
f"The length of provided observation names {len(names)} does not match the length {self.shape[0]} of MuData.obs."
f"The length of provided observation names {len(names)} does not match the length {self.shape[axis]} of MuData.{attr}."
)

if self.is_view:
self._init_as_actual(self.copy())

self._obs.index = names
for mod in self.mod.keys():
indices = self.obsmap[mod]
self.mod[mod].obs_names = names[indices[indices != 0] - 1]
getattr(self, attr).index = names
map = getattr(self, f"{attr}map")
for modname, mod in self.mod.items():
newnames = np.empty(mod.shape[axis], dtype=object)
modmap = map[modname].ravel()
mask = modmap > 0
newnames[modmap[mask] - 1] = names[mask]
setattr(mod, f"{attr}_names", newnames)

self.update_obs()
self._update_attr(attr, axis=1 - axis)

@property
def obs_names(self) -> pd.Index:
"""Names of variables (alias for `.obs.index`)."""
return self.obs.index

@obs_names.setter
def obs_names(self, names: Sequence[str]):
self._set_names("obs", 0, names)

@property
def var(self) -> pd.DataFrame:
Expand Down Expand Up @@ -1442,35 +1448,7 @@ def var_names(self) -> pd.Index:
@var_names.setter
def var_names(self, names: Sequence[str]):
"""Set the variable names for all the nested AnnData/MuData objects."""
if isinstance(names, pd.Index):
if not isinstance(names.name, str | type(None)):
raise ValueError(
f"MuData expects .var.index.name to be a string or None, "
f"but you passed a name of type {type(names.name).__name__!r}"
)
else:
names = pd.Index(names)
if not isinstance(names.name, str | type(None)):
names.name = None

mod_var_sum = np.sum([a.n_vars for a in self.mod.values()])
if mod_var_sum != self.n_vars:
self.update_var()

if len(names) != self.n_vars:
raise ValueError(
f"The length of provided variable names {len(names)} does not match the length {self.shape[0]} of MuData.var."
)

if self.is_view:
self._init_as_actual(self.copy())

self._var.index = names
for mod in self.mod.keys():
indices = self.varmap[mod]
self.mod[mod].var_names = names[indices[indices != 0] - 1]

self.update_var()
self._set_names("var", 1, names)

# Multi-dimensional annotations (.obsm and .varm)

Expand Down
6 changes: 6 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import pytest


Expand All @@ -19,3 +20,8 @@ def filepath_zarr(tmp_path):
@pytest.fixture
def filepath2_zarr(tmp_path):
return tmp_path / "testB.zarr"


@pytest.fixture(scope="module")
def rng():
return np.random.default_rng(42)
51 changes: 40 additions & 11 deletions tests/test_obs_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,23 @@
from mudata import MuData


@pytest.fixture()
def mdata():
@pytest.fixture(params=(0, 1))
def mdata(rng, request):
axis = request.param
attr = "obs" if axis == 0 else "var"
oattr = "var" if axis == 0 else "obs"

mod1 = AnnData(np.arange(0, 100, 0.1).reshape(-1, 10))
mod2 = AnnData(np.arange(101, 2101, 1).reshape(-1, 20))
mods = {"mod1": mod1, "mod2": mod2}
# Make var_names different in different modalities
for m in ["mod1", "mod2"]:
mods[m].var_names = [f"{m}_var{i}" for i in range(mods[m].n_vars)]
mdata = MuData(mods)
for modname, mod in mods.items():
setattr(
mod,
f"{attr}_names",
[f"{attr}_{i}" for i in rng.choice(mod.shape[axis], size=mod.shape[axis], replace=False)],
)
setattr(mod, f"{oattr}_names", [f"{modname}_{oattr}_{i}" for i in range(mod.shape[1 - axis])])
mdata = MuData(mods, axis=axis)
return mdata


Expand All @@ -27,24 +35,45 @@ def test_obs_global_columns(self, mdata, filepath_h5mu):
mod.obs["demo"] = m
mdata.obs["demo"] = "global"
mdata.update()
assert list(mdata.obs.columns.values) == [f"{m}:demo" for m in mdata.mod.keys()] + ["demo"]
if mdata.axis == 0:
assert list(mdata.obs.columns.values) == [f"{m}:demo" for m in mdata.mod.keys()] + ["demo"]
else:
assert list(mdata.obs.columns.values) == ["demo"]
mdata.write(filepath_h5mu)
mdata_ = mudata.read(filepath_h5mu)
assert list(mdata_.obs.columns.values) == [f"{m}:demo" for m in mdata_.mod.keys()] + ["demo"]
assert list(mdata_.obs.columns.values) == list(mdata.obs.columns.values)

def test_set_obs_names(self, mdata): # https://github.com/scverse/mudata/issues/112
names = {m: mod.obs_names for m, mod in mdata.mod.items()}
mdata.obs_names = mdata.obs_names
for m, mod in mdata.mod.items():
assert np.all(mod.obs_names == names[m])

def test_var_global_columns(self, mdata, filepath_h5mu):
for m, mod in mdata.mod.items():
mod.var["demo"] = m
mdata.update()
mdata.var["global"] = "global_var"
mdata.update()
assert list(mdata.var.columns.values) == ["demo", "global"]
if mdata.axis == 0:
assert list(mdata.var.columns.values) == ["demo", "global"]
else:
assert list(mdata.var.columns.values) == [f"{m}:demo" for m in mdata.mod.keys()] + ["global"]
del mdata.var["global"]
mdata.update()
assert list(mdata.var.columns.values) == ["demo"]
if mdata.axis == 0:
assert list(mdata.var.columns.values) == ["demo"]
else:
assert list(mdata.var.columns.values) == [f"{m}:demo" for m in mdata.mod.keys()]
mdata.write(filepath_h5mu)
mdata_ = mudata.read(filepath_h5mu)
assert list(mdata_.var.columns.values) == ["demo"]
assert list(mdata_.var.columns.values) == list(mdata.var.columns.values)

def test_set_var_names(self, mdata): # https://github.com/scverse/mudata/issues/112
names = {m: mod.var_names for m, mod in mdata.mod.items()}
mdata.var_names = mdata.var_names
for m, mod in mdata.mod.items():
assert np.all(mod.var_names == names[m])


if __name__ == "__main__":
Expand Down