From 3a4b2e1dde995e6fb117ed9efe26352ba3ac6bc3 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Mon, 6 Oct 2025 10:51:26 +0200 Subject: [PATCH 1/8] pull_attr fixes - don't raise exception if mods is used together with common or prefixed there is nothing in the logic preventing it - don't raise if columns is used together with common or prefixed, warn instead - minor code cleanup --- src/mudata/_core/mudata.py | 34 +++++++++++++++++++--------------- src/mudata/_core/utils.py | 29 ++++++++--------------------- 2 files changed, 27 insertions(+), 36 deletions(-) diff --git a/src/mudata/_core/mudata.py b/src/mudata/_core/mudata.py index 76721dd..c1112b2 100644 --- a/src/mudata/_core/mudata.py +++ b/src/mudata/_core/mudata.py @@ -1965,21 +1965,20 @@ def _pull_attr( derived_name_count = Counter([col["derived_name"] for col in cols]) # - axis == self.axis - # e.g. combine var from multiple modalities (with unique vars) + # e.g. combine obs from multiple modalities (with shared obs) # - 1 - axis == self.axis - # . e.g. combine obs from multiple modalities (with shared obs) - axis = 0 if attr == "var" else 1 + # e.g. combine var from multiple modalities (with unique vars) + axis = 0 if attr == "obs" else 1 - if 1 - axis == self.axis or self.axis == -1: + if axis == self.axis or self.axis == -1: if join_common or join_nonunique: raise ValueError(f"Cannot join columns with the same name for shared {attr}_names.") if join_common is None: - join_common = False - if attr == "var": - join_common = self.axis == 0 - elif attr == "obs": + if attr == "obs": join_common = self.axis == 1 + else: + join_common = self.axis == 0 if join_nonunique is None: join_nonunique = False @@ -2242,8 +2241,6 @@ def _push_attr( raise ValueError("All mods should be present in mdata.mod") elif len(mods) == self.n_mod: mods = None - for k, v in {"common": common, "prefixed": prefixed}.items(): - assert v is None, f"Cannot use mods with {k}." if only_drop: drop = True @@ -2252,15 +2249,22 @@ def _push_attr( if columns is not None: for k, v in {"common": common, "prefixed": prefixed}.items(): - assert v is None, f"Cannot use columns with {k}." + if v: + warnings.warn( + f"Both columns and {k} given. Columns take precedence, {k} will be ignored", + RuntimeWarning, + stacklevel=2, + ) # - modname1:column -> [modname1:column] # - column -> [modname1:column, modname2:column, ...] - cols = [col for col in cols if col["name"] in columns or col["derived_name"] in columns] - # preemptively drop columns from other modalities - if mods is not None: - cols = [col for col in cols if col["prefix"] in mods or col["prefix"] == ""] + cols = [ + col + for col in cols + if (col["name"] in columns or col["derived_name"] in columns) + and (col["prefix"] == "" or mods is not None and col["prefix"] in mods) + ] else: if common is None: common = True diff --git a/src/mudata/_core/utils.py b/src/mudata/_core/utils.py index 712dc36..b6cbae4 100644 --- a/src/mudata/_core/utils.py +++ b/src/mudata/_core/utils.py @@ -122,30 +122,17 @@ def _classify_prefixed_columns( res: list[dict[str, str]] = [] for name in names: - name_common = { - "name": name, - "prefix": "", - "derived_name": name, - } - name_split = name.split(":", 1) - - if len(name_split) < 2: - res.append(name_common) + if len(name_split := name.split(":", 1)) < 2 or name_split[0] not in prefixes: + res.append({"name": name, "prefix": "", "derived_name": name, "class": "common"}) else: - maybe_modname, derived_name = name_split - - if maybe_modname in prefixes: - name_prefixed = { + res.append( + { "name": name, - "prefix": maybe_modname, - "derived_name": derived_name, + "prefix": name_split[0], + "derived_name": name_split[1], + "class": "prefixed", } - res.append(name_prefixed) - else: - res.append(name_common) - - for name_res in res: - name_res["class"] = "common" if name_res["prefix"] == "" else "prefixed" + ) return res From 8647ddd3aefe16068666f52b839d54628143374e Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Mon, 6 Oct 2025 15:50:47 +0200 Subject: [PATCH 2/8] push_attr fixes - don't raise exception if mods is used together with common, nonunique, or unique: there is nothing in the logic preventing it - don't raise if columns is used together with common, nonunique, or unique, warn instead - fix ordering of pushed column - minor code cleanup --- src/mudata/_core/mudata.py | 99 ++++++++++++++++++-------------------- src/mudata/_core/utils.py | 77 ++++++++++++----------------- tests/test_pull_push.py | 47 ++++++++++++++---- 3 files changed, 115 insertions(+), 108 deletions(-) diff --git a/src/mudata/_core/mudata.py b/src/mudata/_core/mudata.py index c1112b2..a30817d 100644 --- a/src/mudata/_core/mudata.py +++ b/src/mudata/_core/mudata.py @@ -1915,37 +1915,37 @@ def _pull_attr( if mods is not None: if isinstance(mods, str): mods = [mods] - mods = list(dict.fromkeys(mods)) if not all(m in self.mod for m in mods): raise ValueError("All mods should be present in mdata.mod") elif len(mods) == self.n_mod: mods = None - for k, v in {"common": common, "nonunique": nonunique, "unique": unique}.items(): - assert v is None, f"Cannot use mods with {k}." if only_drop: drop = True cols = _classify_attr_columns( - np.concatenate( - [ - [f"{m}:{val}" for val in getattr(mod, attr).columns.values] - for m, mod in self.mod.items() - ] - ), - self.mod.keys(), + {modname: getattr(mod, attr).columns for modname, mod in self.mod.items()} ) if columns is not None: for k, v in {"common": common, "nonunique": nonunique, "unique": unique}.items(): - assert v is None, f"Cannot use {k} with columns." + if v is not None: + warnings.warn( + f"Both columns and {k} given. Columns take precedence, {k} will be ignored", + RuntimeWarning, + stacklevel=2, + ) # - modname1:column -> [modname1:column] # - column -> [modname1:column, modname2:column, ...] - cols = [col for col in cols if col["name"] in columns or col["derived_name"] in columns] - - if mods is not None: - cols = [col for col in cols if col["prefix"] in mods] + cols = { + prefix: [ + col + for col in modcols + if col["name"] in columns or col["derived_name"] in columns + ] + for prefix, modcols in cols.items() + } # TODO: Counter for columns in order to track their usage # and error out if some columns were not used @@ -1959,10 +1959,17 @@ def _pull_attr( unique = True selector = {"common": common, "nonunique": nonunique, "unique": unique} + cols = { + prefix: [col for col in modcols if selector[col["class"]]] + for prefix, modcols in cols.items() + } - cols = [col for col in cols if selector[col["class"]]] + if mods is not None: + cols = {prefix: cols[prefix] for prefix in mods} - derived_name_count = Counter([col["derived_name"] for col in cols]) + derived_name_count = Counter( + [col["derived_name"] for modcols in cols.values() for col in modcols] + ) # - axis == self.axis # e.g. combine obs from multiple modalities (with shared obs) @@ -1994,44 +2001,36 @@ def _pull_attr( n_attr = self.n_vars if attr == "var" else self.n_obs dfs: list[pd.DataFrame] = [] - for m, mod in self.mod.items(): - if mods is not None and m not in mods: - continue + for m, modcols in cols.items(): + mod = self.mod[m] mod_map = attrmap[m].ravel() - mod_n_attr = mod.n_vars if attr == "var" else mod.n_obs - mask = mod_map != 0 - - mod_df = getattr(mod, attr) - mod_columns = [ - col["derived_name"] for col in cols if col["prefix"] == "" or col["prefix"] == m - ] - mod_df = mod_df[mod_df.columns.intersection(mod_columns)] + mask = mod_map > 0 + mod_df = getattr(mod, attr)[[col["derived_name"] for col in modcols]] if drop: getattr(mod, attr).drop(columns=mod_df.columns, inplace=True) - # Don't use modname: prefix if columns need to be joined - if join_common or join_nonunique or (not prefix_unique): - cols_special = [ - col["derived_name"] - for col in cols - if ( - (col["class"] == "common") & join_common - or (col["class"] == "nonunique") & join_nonunique - or (col["class"] == "unique") & (not prefix_unique) + mod_df.rename( + columns={ + col["derived_name"]: col["name"] + for col in modcols + if not ( + ( + join_common + and col["class"] == "common" + or join_nonunique + and col["class"] == "nonunique" + or not prefix_unique + and col["class"] == "unique" + ) + and derived_name_count[col["derived_name"]] == col["count"] ) - and col["prefix"] == m - and derived_name_count[col["derived_name"]] == col["count"] - ] - mod_df.columns = [ - col if col in cols_special else f"{m}:{col}" for col in mod_df.columns - ] - else: - mod_df.columns = [f"{m}:{col}" for col in mod_df.columns] + }, + inplace=True, + ) mod_df = ( _maybe_coerce_to_boolean(mod_df) - .set_index(np.arange(mod_n_attr)) .iloc[mod_map[mask] - 1] .set_index(np.arange(n_attr)[mask]) .reindex(np.arange(n_attr)) @@ -2296,7 +2295,7 @@ def _push_attr( if mods is not None and m not in mods: continue - mod_map = attrmap[m] + mod_map = attrmap[m].ravel() mask = mod_map != 0 mod_n_attr = mod.n_vars if attr == "var" else mod.n_obs @@ -2304,11 +2303,7 @@ def _push_attr( df = getattr(self, attr)[mask].loc[:, [col["name"] for col in mod_cols]] df.columns = [col["derived_name"] for col in mod_cols] - df = ( - df.set_index(np.arange(mod_n_attr)) - .iloc[mod_map[mask] - 1] - .set_index(np.arange(mod_n_attr)) - ) + df = df.iloc[np.argsort(mod_map[mask])].set_index(np.arange(mod_n_attr)) if not only_drop: # TODO: _maybe_coerce_to_bool diff --git a/src/mudata/_core/utils.py b/src/mudata/_core/utils.py index b6cbae4..f8ce782 100644 --- a/src/mudata/_core/utils.py +++ b/src/mudata/_core/utils.py @@ -1,5 +1,5 @@ from collections import Counter -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from typing import TypeVar import numpy as np @@ -38,9 +38,7 @@ def _maybe_coerce_to_boolean(df: T) -> T: return df -def _classify_attr_columns( - names: Sequence[str], prefixes: Sequence[str] -) -> Sequence[dict[str, str]]: +def _classify_attr_columns(names: Mapping[str, Sequence[str]]) -> dict[str, list[dict[str, str]]]: """ Classify names into common, non-unique, and unique w.r.t. to the list of prefixes. @@ -53,50 +51,35 @@ def _classify_attr_columns( and there is only one modality prefix for a column with a certain name. - E.g. ["global", "mod1:annotation", "mod2:annotation", "mod1:unique"] will be classified - into [ - {"name": "global", "prefix": "", "derived_name": "global", "count": 1, "class": "common"}, - {"name": "mod1:annotation", "prefix": "mod1", "derived_name": "annotation", "count": 2, "class": "nonunique"}, - {"name": "mod2:annotation", "prefix": "mod2", "derived_name": "annotation", "count": 2, "class": "nonunique"}, - {"name": "mod1:unique", "prefix": "mod1", "derived_name": "annotation", "count": 2, "class": "unique"}, - ] + E.g. {"mod1": ["annotation", "unique"], "mod2": ["annotation"]} will be classified + into {"mod1": [{"name": "mod1:annotation", "derived_name": "annotation", "count": 2, "class": "nonunique"}, + {"name": "mod1:unique", "derived_name": "unique", "count": 1, "class": "unique"}}], + "mod2": [{"name": "mod2:annotation", "derived_name": "annotation", "count": 2, "class": "nonunique"}], + } """ - n_mod = len(prefixes) - res: list[dict[str, str]] = [] - - for name in names: - name_common = { - "name": name, - "prefix": "", - "derived_name": name, - } - name_split = name.split(":", 1) - - if len(name_split) < 2: - res.append(name_common) - else: - maybe_modname, derived_name = name_split - - if maybe_modname in prefixes: - name_prefixed = { - "name": name, - "prefix": maybe_modname, - "derived_name": derived_name, + n_mod = len(names) + res: dict[str, list[dict[str, str]]] = {} + + derived_name_counts = Counter() + for prefix, names in names.items(): + cres = [] + for name in names: + cres.append( + { + "name": f"{prefix}:{name}", + "derived_name": name, } - res.append(name_prefixed) - else: - res.append(name_common) - - derived_name_counts = Counter(name_res["derived_name"] for name_res in res) - for name_res in res: - name_res["count"] = derived_name_counts[name_res["derived_name"]] - - for name_res in res: - name_res["class"] = ( - "common" - if name_res["count"] == n_mod - else "unique" if name_res["count"] == 1 else "nonunique" - ) + ) + derived_name_counts[name] += 1 + res[prefix] = cres + + for prefix, names in res.items(): + for name_res in names: + count = derived_name_counts[name_res["derived_name"]] + name_res["count"] = count + name_res["class"] = ( + "common" if count == n_mod else "unique" if count == 1 else "nonunique" + ) return res @@ -138,7 +121,7 @@ def _classify_prefixed_columns( def _update_and_concat(df1: pd.DataFrame, df2: pd.DataFrame) -> pd.DataFrame: - df = df1.copy() + df = df1.copy(deep=False) # This converts boolean to object dtype, unfortunately # df.update(df2) common_cols = df1.columns.intersection(df2.columns) diff --git a/tests/test_pull_push.py b/tests/test_pull_push.py index ddb0520..7251966 100644 --- a/tests/test_pull_push.py +++ b/tests/test_pull_push.py @@ -5,7 +5,7 @@ import pytest from anndata import AnnData -from mudata import MuData +from mudata import MuData, set_options @pytest.fixture() @@ -21,7 +21,8 @@ def modalities(request, obs_n, var_unique): mods[m].var["mod"] = m # common column - mods[m].var["highly_variable"] = np.tile([False, True], mods[m].n_vars // 2) + mods[m].var["highly_variable"] = np.random.choice([False, True], size=mods[m].n_vars) + mods[m].obs["common_obs_col"] = np.random.randint(0, int(1e6), size=mods[m].n_obs) if var_unique: mods[m].var_names = [f"mod{m}_var{j}" for j in range(mods[m].n_vars)] @@ -88,7 +89,6 @@ def test_pull_var(self, modalities): """ mdata = MuData(modalities) mdata.update() - mdata.pull_var() assert "mod" in mdata.var.columns @@ -165,6 +165,15 @@ def test_pull_obs_simple(self, modalities): for m in mdata.mod.keys(): assert f"{m}:mod" in mdata.obs.columns + assert f"{m}:common_obs_col" in mdata.obs.columns + + modmap = mdata.obsmap[m].ravel() + mask = modmap > 0 + assert ( + mdata.obs[f"{m}:common_obs_col"][mask].to_numpy() + == mdata.mod[m].obs["common_obs_col"].to_numpy()[modmap[mask] - 1] + ).all() + # join_common shouldn't work with pytest.raises(ValueError, match="shared obs_names"): mdata.pull_obs(join_common=True) @@ -182,14 +191,24 @@ def test_push_var_simple(self, modalities): mdata = MuData(modalities) mdata.update() - mdata.var["pushed"] = True - mdata.var["mod2:mod2_pushed"] = True + mdata.var["pushed"] = np.random.randint(0, int(1e6), size=mdata.n_var) + mdata.var["mod2:mod2_pushed"] = np.random.randint(0, int(1e6), size=mdata.n_var) mdata.push_var() # pushing should work - for mod in mdata.mod.values(): + for modname, mod in mdata.mod.items(): assert "pushed" in mod.var.columns + + map = mdata.varmap[modname].ravel() + mask = map > 0 + assert (mdata.var["pushed"][mask] == mod.var["pushed"][map[mask] - 1]).all() + assert "mod2_pushed" in mdata["mod2"].var.columns + map = mdata.varmap["mod2"].ravel() + mask = map > 0 + assert ( + mdata.var["mod2:mod2_pushed"][mask] == mdata["mod2"].var["mod2_pushed"][map[mask] - 1] + ).all() @pytest.mark.parametrize("var_unique", [True, False]) @pytest.mark.parametrize("obs_n", ["joint", "disjoint"]) @@ -200,14 +219,24 @@ def test_push_obs_simple(self, modalities): mdata = MuData(modalities) mdata.update() - mdata.obs["pushed"] = True - mdata.obs["mod2:mod2_pushed"] = True + mdata.obs["pushed"] = np.random.randint(0, int(1e6), size=mdata.n_obs) + mdata.obs["mod2:mod2_pushed"] = np.random.randint(0, int(1e6), size=mdata.n_obs) mdata.push_obs() # pushing should work - for mod in mdata.mod.values(): + for modname, mod in mdata.mod.items(): assert "pushed" in mod.obs.columns + + map = mdata.obsmap[modname].ravel() + mask = map > 0 + assert (mdata.obs["pushed"][mask] == mod.obs["pushed"][map[mask] - 1]).all() + assert "mod2_pushed" in mdata["mod2"].obs.columns + map = mdata.obsmap["mod2"].ravel() + mask = map > 0 + assert ( + mdata.obs["mod2:mod2_pushed"][mask] == mdata["mod2"].obs["mod2_pushed"][map[mask] - 1] + ).all() @pytest.mark.usefixtures("filepath_h5mu") From 2b26ebf6fc695d5e300bfb6d1615172ffa80a912 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Mon, 13 Oct 2025 10:57:23 +0200 Subject: [PATCH 3/8] push/pull: replace dict holding column information with custom class --- src/mudata/_core/mudata.py | 45 +++++++------- src/mudata/_core/utils.py | 116 ++++++++++++++++++------------------- 2 files changed, 80 insertions(+), 81 deletions(-) diff --git a/src/mudata/_core/mudata.py b/src/mudata/_core/mudata.py index a30817d..8caa898 100644 --- a/src/mudata/_core/mudata.py +++ b/src/mudata/_core/mudata.py @@ -27,8 +27,8 @@ from .file_backing import MuDataFileManager from .repr import MUDATA_CSS, block_matrix, details_block_table from .utils import ( + MetadataColumn, _classify_attr_columns, - _classify_prefixed_columns, _make_index_unique, _maybe_coerce_to_bool, _maybe_coerce_to_boolean, @@ -1940,9 +1940,7 @@ def _pull_attr( # - column -> [modname1:column, modname2:column, ...] cols = { prefix: [ - col - for col in modcols - if col["name"] in columns or col["derived_name"] in columns + col for col in modcols if col.name in columns or col.derived_name in columns ] for prefix, modcols in cols.items() } @@ -1960,7 +1958,7 @@ def _pull_attr( selector = {"common": common, "nonunique": nonunique, "unique": unique} cols = { - prefix: [col for col in modcols if selector[col["class"]]] + prefix: [col for col in modcols if selector[col.klass]] for prefix, modcols in cols.items() } @@ -1968,7 +1966,7 @@ def _pull_attr( cols = {prefix: cols[prefix] for prefix in mods} derived_name_count = Counter( - [col["derived_name"] for modcols in cols.values() for col in modcols] + [col.derived_name for modcols in cols.values() for col in modcols] ) # - axis == self.axis @@ -2006,24 +2004,24 @@ def _pull_attr( mod_map = attrmap[m].ravel() mask = mod_map > 0 - mod_df = getattr(mod, attr)[[col["derived_name"] for col in modcols]] + mod_df = getattr(mod, attr)[[col.derived_name for col in modcols]] if drop: getattr(mod, attr).drop(columns=mod_df.columns, inplace=True) mod_df.rename( columns={ - col["derived_name"]: col["name"] + col.derived_name: col.name for col in modcols if not ( ( join_common - and col["class"] == "common" + and col.klass == "common" or join_nonunique - and col["class"] == "nonunique" + and col.klass == "nonunique" or not prefix_unique - and col["class"] == "unique" + and col.klass == "unique" ) - and derived_name_count[col["derived_name"]] == col["count"] + and derived_name_count[col.derived_name] == col.count ) }, inplace=True, @@ -2244,7 +2242,10 @@ def _push_attr( if only_drop: drop = True - cols = _classify_prefixed_columns(getattr(self, attr).columns.values, self.mod.keys()) + cols = [ + MetadataColumn(allowed_prefixes=self.mod.keys(), name=name) + for name in getattr(self, attr).columns + ] if columns is not None: for k, v in {"common": common, "prefixed": prefixed}.items(): @@ -2261,8 +2262,8 @@ def _push_attr( cols = [ col for col in cols - if (col["name"] in columns or col["derived_name"] in columns) - and (col["prefix"] == "" or mods is not None and col["prefix"] in mods) + if (col.name in columns or col.derived_name in columns) + and (col.prefix is None or mods is not None and col.prefix in mods) ] else: if common is None: @@ -2270,14 +2271,14 @@ def _push_attr( if prefixed is None: prefixed = True - selector = {"common": common, "prefixed": prefixed} + selector = {"common": common, "unknown": prefixed} - cols = [col for col in cols if selector[col["class"]]] + cols = [col for col in cols if selector[col.klass]] if len(cols) == 0: return - derived_name_count = Counter([col["derived_name"] for col in cols]) + derived_name_count = Counter([col.derived_name for col in cols]) for c, count in derived_name_count.items(): # if count > 1, there are both colname and modname:colname present if count > 1 and c in getattr(self, attr).columns: @@ -2299,9 +2300,9 @@ def _push_attr( mask = mod_map != 0 mod_n_attr = mod.n_vars if attr == "var" else mod.n_obs - mod_cols = [col for col in cols if col["prefix"] == m or col["class"] == "common"] - df = getattr(self, attr)[mask].loc[:, [col["name"] for col in mod_cols]] - df.columns = [col["derived_name"] for col in mod_cols] + mod_cols = [col for col in cols if col.prefix == m or col.klass == "common"] + df = getattr(self, attr)[mask].loc[:, [col.name for col in mod_cols]] + df.columns = [col.derived_name for col in mod_cols] df = df.iloc[np.argsort(mod_map[mask])].set_index(np.arange(mod_n_attr)) @@ -2316,7 +2317,7 @@ def _push_attr( if drop: for col in cols: - getattr(self, attr).drop(col["name"], axis=1, inplace=True) + getattr(self, attr).drop(col.name, axis=1, inplace=True) def push_obs( self, diff --git a/src/mudata/_core/utils.py b/src/mudata/_core/utils.py index f8ce782..b03ff3f 100644 --- a/src/mudata/_core/utils.py +++ b/src/mudata/_core/utils.py @@ -1,6 +1,6 @@ from collections import Counter from collections.abc import Mapping, Sequence -from typing import TypeVar +from typing import Literal, TypeVar import numpy as np import pandas as pd @@ -38,7 +38,56 @@ def _maybe_coerce_to_boolean(df: T) -> T: return df -def _classify_attr_columns(names: Mapping[str, Sequence[str]]) -> dict[str, list[dict[str, str]]]: +class MetadataColumn: + __slots__ = ("prefix", "derived_name", "count", "_allowed_prefixes") + + def __init__( + self, + *, + allowed_prefixes: Sequence[str], + prefix: str | None = None, + name: str | None = None, + count: int = 0, + ): + self._allowed_prefixes = allowed_prefixes + if prefix is None: + self.name = name + else: + self.prefix = prefix + self.derived_name = name + self.count = count + + @property + def name(self) -> str: + if self.prefix is not None: + return f"{self.prefix}:{self.derived_name}" + else: + return self.derived_name + + @name.setter + def name(self, new_name): + if ( + len(name_split := new_name.split(":", 1)) < 2 + or name_split[0] not in self._allowed_prefixes + ): + self.prefix = None + self.derived_name = new_name + else: + self.prefix, self.derived_name = name_split + + @property + def klass(self) -> Literal["common", "unique", "nonunique", "unknown"]: + if self.prefix is None or self.count == len(self._allowed_prefixes): + return "common" + elif self.count == 1: + return "unique" + elif self.count > 0: + return "nonunique" + else: + return "unknown" + + +def _classify_attr_columns(names: Mapping[str, Sequence[str]]) -> dict[str, list[MetadataColumn]]: """ Classify names into common, non-unique, and unique w.r.t. to the list of prefixes. @@ -50,72 +99,21 @@ def _classify_attr_columns(names: Mapping[str, Sequence[str]]) -> dict[str, list - Unique columns are prefixed by modality names, and there is only one modality prefix for a column with a certain name. - - E.g. {"mod1": ["annotation", "unique"], "mod2": ["annotation"]} will be classified - into {"mod1": [{"name": "mod1:annotation", "derived_name": "annotation", "count": 2, "class": "nonunique"}, - {"name": "mod1:unique", "derived_name": "unique", "count": 1, "class": "unique"}}], - "mod2": [{"name": "mod2:annotation", "derived_name": "annotation", "count": 2, "class": "nonunique"}], - } """ - n_mod = len(names) - res: dict[str, list[dict[str, str]]] = {} + res: dict[str, list[MetadataColumn]] = {} derived_name_counts = Counter() - for prefix, names in names.items(): + for prefix, pnames in names.items(): cres = [] - for name in names: - cres.append( - { - "name": f"{prefix}:{name}", - "derived_name": name, - } - ) + for name in pnames: + cres.append(MetadataColumn(allowed_prefixes=names.keys(), prefix=prefix, name=name)) derived_name_counts[name] += 1 res[prefix] = cres for prefix, names in res.items(): for name_res in names: - count = derived_name_counts[name_res["derived_name"]] - name_res["count"] = count - name_res["class"] = ( - "common" if count == n_mod else "unique" if count == 1 else "nonunique" - ) - - return res - - -def _classify_prefixed_columns( - names: Sequence[str], prefixes: Sequence[str] -) -> Sequence[dict[str, str]]: - """ - Classify names into common and prefixed - w.r.t. to the list of prefixes. - - - Common columns do not have modality prefixes. - - Prefixed columns are prefixed by modality names. - - E.g. ["global", "mod1:annotation", "mod2:annotation", "mod1:unique"] will be classified - into [ - {"name": "global", "prefix": "", "derived_name": "global", "class": "common"}, - {"name": "mod1:annotation", "prefix": "mod1", "derived_name": "annotation", "class": "prefixed"}, - {"name": "mod2:annotation", "prefix": "mod2", "derived_name": "annotation", "class": "prefixed"}, - {"name": "mod1:unique", "prefix": "mod1", "derived_name": "annotation", "class": "prefixed"}, - ] - """ - res: list[dict[str, str]] = [] - - for name in names: - if len(name_split := name.split(":", 1)) < 2 or name_split[0] not in prefixes: - res.append({"name": name, "prefix": "", "derived_name": name, "class": "common"}) - else: - res.append( - { - "name": name, - "prefix": name_split[0], - "derived_name": name_split[1], - "class": "prefixed", - } - ) + count = derived_name_counts[name_res.derived_name] + name_res.count = count return res From 21b2e836aa98c732e6665b27634393ec1e497040 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Mon, 13 Oct 2025 14:21:31 +0200 Subject: [PATCH 4/8] get rid of _classify_attr_columns, sprinkle more comments throughout _pull_attr/_push_attr --- src/mudata/_core/mudata.py | 53 ++++++++++++++++++++++++++++---------- src/mudata/_core/utils.py | 43 ++++++------------------------- 2 files changed, 47 insertions(+), 49 deletions(-) diff --git a/src/mudata/_core/mudata.py b/src/mudata/_core/mudata.py index 8caa898..fd3c6ab 100644 --- a/src/mudata/_core/mudata.py +++ b/src/mudata/_core/mudata.py @@ -28,7 +28,6 @@ from .repr import MUDATA_CSS, block_matrix, details_block_table from .utils import ( MetadataColumn, - _classify_attr_columns, _make_index_unique, _maybe_coerce_to_bool, _maybe_coerce_to_boolean, @@ -1923,9 +1922,29 @@ def _pull_attr( if only_drop: drop = True - cols = _classify_attr_columns( - {modname: getattr(mod, attr).columns for modname, mod in self.mod.items()} - ) + cols: dict[str, list[MetadataColumn]] = {} + + # get all columns from all modalities and count how many times each column is present + derived_name_counts = Counter() + for prefix, mod in self.mod.items(): + modcols = getattr(mod, attr).columns + ccols = [] + for name in modcols: + ccols.append( + MetadataColumn( + allowed_prefixes=self.mod.keys(), + prefix=prefix, + name=name, + strip_prefix=False, + ) + ) + derived_name_counts[name] += 1 + cols[prefix] = ccols + + for prefix, modcols in cols.items(): + for col in modcols: + count = derived_name_counts[col.derived_name] + col.count = count # this is important to classify columns if columns is not None: for k, v in {"common": common, "nonunique": nonunique, "unique": unique}.items(): @@ -1936,8 +1955,7 @@ def _pull_attr( stacklevel=2, ) - # - modname1:column -> [modname1:column] - # - column -> [modname1:column, modname2:column, ...] + # keep only requested columns cols = { prefix: [ col for col in modcols if col.name in columns or col.derived_name in columns @@ -1956,15 +1974,18 @@ def _pull_attr( if unique is None: unique = True + # filter columns by class, keep only those that were requested selector = {"common": common, "nonunique": nonunique, "unique": unique} cols = { prefix: [col for col in modcols if selector[col.klass]] for prefix, modcols in cols.items() } + # filter columns, keep only requested modalities if mods is not None: cols = {prefix: cols[prefix] for prefix in mods} + # count final filtered column names, required later to decide whether to prefix a column with its source modality derived_name_count = Counter( [col.derived_name for modcols in cols.values() for col in modcols] ) @@ -2008,6 +2029,8 @@ def _pull_attr( if drop: getattr(mod, attr).drop(columns=mod_df.columns, inplace=True) + # prepend modality prefix to column names if requested via arguments and there are no skipped modalities with + # the same column name (prefixing those columns may cause problems with future pulls or pushes) mod_df.rename( columns={ col.derived_name: col.name @@ -2027,6 +2050,7 @@ def _pull_attr( inplace=True, ) + # reorder modality DF to conform to global order mod_df = ( _maybe_coerce_to_boolean(mod_df) .iloc[mod_map[mask] - 1] @@ -2242,6 +2266,7 @@ def _push_attr( if only_drop: drop = True + # get all global columns cols = [ MetadataColumn(allowed_prefixes=self.mod.keys(), name=name) for name in getattr(self, attr).columns @@ -2256,9 +2281,7 @@ def _push_attr( stacklevel=2, ) - # - modname1:column -> [modname1:column] - # - column -> [modname1:column, modname2:column, ...] - # preemptively drop columns from other modalities + # keep only requested columns cols = [ col for col in cols @@ -2271,8 +2294,8 @@ def _push_attr( if prefixed is None: prefixed = True + # filter columns by class, keep only those that were requested selector = {"common": common, "unknown": prefixed} - cols = [col for col in cols if selector[col.klass]] if len(cols) == 0: @@ -2290,20 +2313,22 @@ def _push_attr( ) attrmap = getattr(self, f"{attr}map") - _n_attr = self.n_vars if attr == "var" else self.n_obs - for m, mod in self.mod.items(): if mods is not None and m not in mods: continue mod_map = attrmap[m].ravel() - mask = mod_map != 0 - mod_n_attr = mod.n_vars if attr == "var" else mod.n_obs + mask = mod_map > 0 + mod_n_attr = mod.n_obs if attr == "obs" else mod.n_vars + # get all common and modality-specific columns for the current modality mod_cols = [col for col in cols if col.prefix == m or col.klass == "common"] df = getattr(self, attr)[mask].loc[:, [col.name for col in mod_cols]] + + # strip modality prefix where necessary df.columns = [col.derived_name for col in mod_cols] + # reorder global DF to conform to modality order df = df.iloc[np.argsort(mod_map[mask])].set_index(np.arange(mod_n_attr)) if not only_drop: diff --git a/src/mudata/_core/utils.py b/src/mudata/_core/utils.py index b03ff3f..520d182 100644 --- a/src/mudata/_core/utils.py +++ b/src/mudata/_core/utils.py @@ -1,5 +1,5 @@ from collections import Counter -from collections.abc import Mapping, Sequence +from collections.abc import Sequence from typing import Literal, TypeVar import numpy as np @@ -39,7 +39,7 @@ def _maybe_coerce_to_boolean(df: T) -> T: class MetadataColumn: - __slots__ = ("prefix", "derived_name", "count", "_allowed_prefixes") + __slots__ = ("prefix", "derived_name", "count", "_allowed_prefixes", "_strip_prefix") def __init__( self, @@ -48,9 +48,12 @@ def __init__( prefix: str | None = None, name: str | None = None, count: int = 0, + strip_prefix: bool = True, ): + self._strip_prefix = strip_prefix self._allowed_prefixes = allowed_prefixes - if prefix is None: + self.prefix = prefix + if prefix is None and strip_prefix: self.name = name else: self.prefix = prefix @@ -67,7 +70,8 @@ def name(self) -> str: @name.setter def name(self, new_name): if ( - len(name_split := new_name.split(":", 1)) < 2 + not self._strip_prefix + or len(name_split := new_name.split(":", 1)) < 2 or name_split[0] not in self._allowed_prefixes ): self.prefix = None @@ -87,37 +91,6 @@ def klass(self) -> Literal["common", "unique", "nonunique", "unknown"]: return "unknown" -def _classify_attr_columns(names: Mapping[str, Sequence[str]]) -> dict[str, list[MetadataColumn]]: - """ - Classify names into common, non-unique, and unique - w.r.t. to the list of prefixes. - - - Common columns do not have modality prefixes. - - Non-unqiue columns have a modality prefix, - and there are multiple columns that differ - only by their modality prefix. - - Unique columns are prefixed by modality names, - and there is only one modality prefix - for a column with a certain name. - """ - res: dict[str, list[MetadataColumn]] = {} - - derived_name_counts = Counter() - for prefix, pnames in names.items(): - cres = [] - for name in pnames: - cres.append(MetadataColumn(allowed_prefixes=names.keys(), prefix=prefix, name=name)) - derived_name_counts[name] += 1 - res[prefix] = cres - - for prefix, names in res.items(): - for name_res in names: - count = derived_name_counts[name_res.derived_name] - name_res.count = count - - return res - - def _update_and_concat(df1: pd.DataFrame, df2: pd.DataFrame) -> pd.DataFrame: df = df1.copy(deep=False) # This converts boolean to object dtype, unfortunately From 8b8d1c1cfeb62ceb0f4e65b89bcdeb8ba09b9bd5 Mon Sep 17 00:00:00 2001 From: ilia-kats Date: Wed, 22 Oct 2025 13:00:34 +0200 Subject: [PATCH 5/8] Apply suggestion from @ilan-gold Co-authored-by: Ilan Gold --- src/mudata/_core/mudata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mudata/_core/mudata.py b/src/mudata/_core/mudata.py index fd3c6ab..0d67958 100644 --- a/src/mudata/_core/mudata.py +++ b/src/mudata/_core/mudata.py @@ -2323,7 +2323,7 @@ def _push_attr( # get all common and modality-specific columns for the current modality mod_cols = [col for col in cols if col.prefix == m or col.klass == "common"] - df = getattr(self, attr)[mask].loc[:, [col.name for col in mod_cols]] + df = getattr(self, attr)[mask][[col.name for col in mod_cols]] # strip modality prefix where necessary df.columns = [col.derived_name for col in mod_cols] From f4d5eca1a5626bca39be6a63e08203db905738bd Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Wed, 22 Oct 2025 17:53:49 +0200 Subject: [PATCH 6/8] improve push performance scaling no need to argsort --- src/mudata/_core/mudata.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/mudata/_core/mudata.py b/src/mudata/_core/mudata.py index 0d67958..8d5977d 100644 --- a/src/mudata/_core/mudata.py +++ b/src/mudata/_core/mudata.py @@ -2329,7 +2329,9 @@ def _push_attr( df.columns = [col.derived_name for col in mod_cols] # reorder global DF to conform to modality order - df = df.iloc[np.argsort(mod_map[mask])].set_index(np.arange(mod_n_attr)) + idx = np.empty(mod_n_attr, dtype=mod_map.dtype) + idx[mod_map[mask] - 1] = np.arange(mod_n_attr) + df = df.iloc[idx].set_index(np.arange(mod_n_attr)) if not only_drop: # TODO: _maybe_coerce_to_bool From 0d646ed0d9ce6a55c637beab9b83a915f174a97d Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Wed, 22 Oct 2025 18:16:45 +0200 Subject: [PATCH 7/8] fixup! improve push performance scaling --- src/mudata/_core/mudata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mudata/_core/mudata.py b/src/mudata/_core/mudata.py index 8d5977d..54d5e7d 100644 --- a/src/mudata/_core/mudata.py +++ b/src/mudata/_core/mudata.py @@ -2337,7 +2337,7 @@ def _push_attr( # TODO: _maybe_coerce_to_bool # TODO: _maybe_coerce_to_int # TODO: _prune_unused_categories - mod_df = getattr(mod, attr).set_index(np.arange(mod_n_attr)) + mod_df = getattr(mod, attr).set_index(np.arange(mod_n_attr, dtype=mod_map.dtype)) mod_df = _update_and_concat(mod_df, df) mod_df = mod_df.set_index(getattr(mod, f"{attr}_names")) setattr(mod, attr, mod_df) From 717c3624cf6190729c5c99f917e57e206aeb2964 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Wed, 22 Oct 2025 18:27:25 +0200 Subject: [PATCH 8/8] fixup! improve push performance scaling --- src/mudata/_core/mudata.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mudata/_core/mudata.py b/src/mudata/_core/mudata.py index 54d5e7d..c896592 100644 --- a/src/mudata/_core/mudata.py +++ b/src/mudata/_core/mudata.py @@ -2331,13 +2331,13 @@ def _push_attr( # reorder global DF to conform to modality order idx = np.empty(mod_n_attr, dtype=mod_map.dtype) idx[mod_map[mask] - 1] = np.arange(mod_n_attr) - df = df.iloc[idx].set_index(np.arange(mod_n_attr)) + df = df.iloc[idx].set_index(np.arange(mod_n_attr, dtype=mod_map.dtype)) if not only_drop: # TODO: _maybe_coerce_to_bool # TODO: _maybe_coerce_to_int # TODO: _prune_unused_categories - mod_df = getattr(mod, attr).set_index(np.arange(mod_n_attr, dtype=mod_map.dtype)) + mod_df = getattr(mod, attr).set_index(np.arange(mod_n_attr)) mod_df = _update_and_concat(mod_df, df) mod_df = mod_df.set_index(getattr(mod, f"{attr}_names")) setattr(mod, attr, mod_df)