From 5fcdf8b38200719818623065085b4ffaf5a9bb2b Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Fri, 23 Jan 2026 16:35:09 +0100 Subject: [PATCH] fix settings obs_names/var_names (closes #112) --- src/mudata/_core/mudata.py | 74 ++++++++++++++------------------------ tests/conftest.py | 6 ++++ tests/test_obs_var.py | 51 ++++++++++++++++++++------ 3 files changed, 72 insertions(+), 59 deletions(-) diff --git a/src/mudata/_core/mudata.py b/src/mudata/_core/mudata.py index f8c9557..b5a0e11 100644 --- a/src/mudata/_core/mudata.py +++ b/src/mudata/_core/mudata.py @@ -1312,18 +1312,11 @@ def obs_names_make_unique(self): obs_names = [obs for a in self.mod.values() for obs in a.obs_names.values] self._obs.index = obs_names - @property - def obs_names(self) -> pd.Index: - """Names of variables (alias for `.obs.index`).""" - return self.obs.index - - @obs_names.setter - def obs_names(self, names: Sequence[str]): - """Set the observation names for all the nested AnnData/MuData objects.""" + def _set_names(self, attr: str, axis: int, names: Sequence[str]): if isinstance(names, pd.Index): if not isinstance(names.name, str | type(None)): raise ValueError( - f"MuData expects .obs.index.name to be a string or None, " + f"MuData expects .{attr}.index.name to be a string or None, " f"but you passed a name of type {type(names.name).__name__!r}" ) else: @@ -1331,24 +1324,37 @@ def obs_names(self, names: Sequence[str]): if not isinstance(names.name, str | type(None)): names.name = None - mod_obs_sum = np.sum([a.n_obs for a in self.mod.values()]) - if mod_obs_sum != self.n_obs: - self.update_obs() + mod_shape_sum = np.sum([a.shape[axis] for a in self.mod.values()]) + if mod_shape_sum != self.shape[axis]: + self._update_attr(attr, axis=1 - axis) - if len(names) != self.n_obs: + if len(names) != self.shape[axis]: raise ValueError( - f"The length of provided observation names {len(names)} does not match the length {self.shape[0]} of MuData.obs." + f"The length of provided observation names {len(names)} does not match the length {self.shape[axis]} of MuData.{attr}." ) if self.is_view: self._init_as_actual(self.copy()) - self._obs.index = names - for mod in self.mod.keys(): - indices = self.obsmap[mod] - self.mod[mod].obs_names = names[indices[indices != 0] - 1] + getattr(self, attr).index = names + map = getattr(self, f"{attr}map") + for modname, mod in self.mod.items(): + newnames = np.empty(mod.shape[axis], dtype=object) + modmap = map[modname].ravel() + mask = modmap > 0 + newnames[modmap[mask] - 1] = names[mask] + setattr(mod, f"{attr}_names", newnames) - self.update_obs() + self._update_attr(attr, axis=1 - axis) + + @property + def obs_names(self) -> pd.Index: + """Names of variables (alias for `.obs.index`).""" + return self.obs.index + + @obs_names.setter + def obs_names(self, names: Sequence[str]): + self._set_names("obs", 0, names) @property def var(self) -> pd.DataFrame: @@ -1442,35 +1448,7 @@ def var_names(self) -> pd.Index: @var_names.setter def var_names(self, names: Sequence[str]): """Set the variable names for all the nested AnnData/MuData objects.""" - if isinstance(names, pd.Index): - if not isinstance(names.name, str | type(None)): - raise ValueError( - f"MuData expects .var.index.name to be a string or None, " - f"but you passed a name of type {type(names.name).__name__!r}" - ) - else: - names = pd.Index(names) - if not isinstance(names.name, str | type(None)): - names.name = None - - mod_var_sum = np.sum([a.n_vars for a in self.mod.values()]) - if mod_var_sum != self.n_vars: - self.update_var() - - if len(names) != self.n_vars: - raise ValueError( - f"The length of provided variable names {len(names)} does not match the length {self.shape[0]} of MuData.var." - ) - - if self.is_view: - self._init_as_actual(self.copy()) - - self._var.index = names - for mod in self.mod.keys(): - indices = self.varmap[mod] - self.mod[mod].var_names = names[indices[indices != 0] - 1] - - self.update_var() + self._set_names("var", 1, names) # Multi-dimensional annotations (.obsm and .varm) diff --git a/tests/conftest.py b/tests/conftest.py index 6faed70..3efce29 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +import numpy as np import pytest @@ -19,3 +20,8 @@ def filepath_zarr(tmp_path): @pytest.fixture def filepath2_zarr(tmp_path): return tmp_path / "testB.zarr" + + +@pytest.fixture(scope="module") +def rng(): + return np.random.default_rng(42) diff --git a/tests/test_obs_var.py b/tests/test_obs_var.py index 167f711..6702621 100644 --- a/tests/test_obs_var.py +++ b/tests/test_obs_var.py @@ -8,15 +8,23 @@ from mudata import MuData -@pytest.fixture() -def mdata(): +@pytest.fixture(params=(0, 1)) +def mdata(rng, request): + axis = request.param + attr = "obs" if axis == 0 else "var" + oattr = "var" if axis == 0 else "obs" + mod1 = AnnData(np.arange(0, 100, 0.1).reshape(-1, 10)) mod2 = AnnData(np.arange(101, 2101, 1).reshape(-1, 20)) mods = {"mod1": mod1, "mod2": mod2} - # Make var_names different in different modalities - for m in ["mod1", "mod2"]: - mods[m].var_names = [f"{m}_var{i}" for i in range(mods[m].n_vars)] - mdata = MuData(mods) + for modname, mod in mods.items(): + setattr( + mod, + f"{attr}_names", + [f"{attr}_{i}" for i in rng.choice(mod.shape[axis], size=mod.shape[axis], replace=False)], + ) + setattr(mod, f"{oattr}_names", [f"{modname}_{oattr}_{i}" for i in range(mod.shape[1 - axis])]) + mdata = MuData(mods, axis=axis) return mdata @@ -27,10 +35,19 @@ def test_obs_global_columns(self, mdata, filepath_h5mu): mod.obs["demo"] = m mdata.obs["demo"] = "global" mdata.update() - assert list(mdata.obs.columns.values) == [f"{m}:demo" for m in mdata.mod.keys()] + ["demo"] + if mdata.axis == 0: + assert list(mdata.obs.columns.values) == [f"{m}:demo" for m in mdata.mod.keys()] + ["demo"] + else: + assert list(mdata.obs.columns.values) == ["demo"] mdata.write(filepath_h5mu) mdata_ = mudata.read(filepath_h5mu) - assert list(mdata_.obs.columns.values) == [f"{m}:demo" for m in mdata_.mod.keys()] + ["demo"] + assert list(mdata_.obs.columns.values) == list(mdata.obs.columns.values) + + def test_set_obs_names(self, mdata): # https://github.com/scverse/mudata/issues/112 + names = {m: mod.obs_names for m, mod in mdata.mod.items()} + mdata.obs_names = mdata.obs_names + for m, mod in mdata.mod.items(): + assert np.all(mod.obs_names == names[m]) def test_var_global_columns(self, mdata, filepath_h5mu): for m, mod in mdata.mod.items(): @@ -38,13 +55,25 @@ def test_var_global_columns(self, mdata, filepath_h5mu): mdata.update() mdata.var["global"] = "global_var" mdata.update() - assert list(mdata.var.columns.values) == ["demo", "global"] + if mdata.axis == 0: + assert list(mdata.var.columns.values) == ["demo", "global"] + else: + assert list(mdata.var.columns.values) == [f"{m}:demo" for m in mdata.mod.keys()] + ["global"] del mdata.var["global"] mdata.update() - assert list(mdata.var.columns.values) == ["demo"] + if mdata.axis == 0: + assert list(mdata.var.columns.values) == ["demo"] + else: + assert list(mdata.var.columns.values) == [f"{m}:demo" for m in mdata.mod.keys()] mdata.write(filepath_h5mu) mdata_ = mudata.read(filepath_h5mu) - assert list(mdata_.var.columns.values) == ["demo"] + assert list(mdata_.var.columns.values) == list(mdata.var.columns.values) + + def test_set_var_names(self, mdata): # https://github.com/scverse/mudata/issues/112 + names = {m: mod.var_names for m, mod in mdata.mod.items()} + mdata.var_names = mdata.var_names + for m, mod in mdata.mod.items(): + assert np.all(mod.var_names == names[m]) if __name__ == "__main__":