From fe67863629cbde503051be4a6ac63258eecbaefb Mon Sep 17 00:00:00 2001 From: Sravya Tirukkovalur Date: Fri, 6 Jun 2025 09:40:14 -0700 Subject: [PATCH 1/7] Moved the dataset --- .../dataset/h5ad_sentence_dataset.py | 115 ++++++++++++++++++ tests/test_h5ad_sentence_dataset.py | 36 ++++++ 2 files changed, 151 insertions(+) create mode 100644 src/cell_load/dataset/h5ad_sentence_dataset.py create mode 100644 tests/test_h5ad_sentence_dataset.py diff --git a/src/cell_load/dataset/h5ad_sentence_dataset.py b/src/cell_load/dataset/h5ad_sentence_dataset.py new file mode 100644 index 0000000..2457214 --- /dev/null +++ b/src/cell_load/dataset/h5ad_sentence_dataset.py @@ -0,0 +1,115 @@ +import h5py +import logging +import torch +import torch.utils.data as data +import functools +import numpy as np +from typing import Dict +from .. import utils + +log = logging.getLogger(__file__) + +EXPONENTIATED_UMIS_LIMIT = 5_000_000 +RAW_COUNT_HEURISTIC_THRESHOLD = 35 + +class H5adSentenceDataset(data.Dataset): + def __init__(self, cfg, test=False, datasets=None, shape_dict=None, adata=None, adata_name=None) -> None: + super(H5adSentenceDataset, self).__init__() + + self.adata = None + self.adata_name = adata_name + self.test = test + if adata is not None: + self.adata = adata + self.datasets = [adata_name] + self.shapes_dict = {self.datasets[0]: adata.shape} + elif datasets is None: + ds_path = utils.get_dataset_cfg(cfg).train + if test: + ds_path = utils.get_dataset_cfg(cfg).val + _, self.datasets, self.shapes_dict, self.dataset_path_map, self.dataset_group_map = utils.get_shapes_dict( + ds_path, utils.get_dataset_cfg(cfg).get("filter_by_species") + ) + else: + assert shape_dict is not None + assert len(datasets) == len(shape_dict) + self.datasets = datasets + self.shapes_dict = shape_dict + self.dataset_path_map = {dataset: dataset for dataset in datasets} + + self.datasets = sorted(self.datasets) + self.cfg = cfg + + self.num_cells = {} + self.num_genes = {} + + self.total_num_cells = 0 + for name in self.datasets: + num_cells, num_genes = self.shapes_dict[name] + self.num_cells[name] = num_cells + self.num_genes[name] = num_genes + self.total_num_cells += num_cells + + self.datasets_to_num = {k: v for k, v in zip(self.datasets, range(len(self.datasets)))} + + @functools.lru_cache + def dataset_file(self, dataset): + datafile = self.dataset_path_map[dataset] + return h5py.File(datafile, "r") + + def _compute_index(self, idx): + for dataset in self.datasets: + if idx < self.num_cells[dataset]: + return dataset, idx + else: + idx -= self.num_cells[dataset] + raise IndexError + + def __getitem__(self, idx): + if self.adata is not None: + # block is only used during validation + # if .X is a numpy.ndarray + if isinstance(self.adata.X, np.ndarray): + counts = torch.tensor(self.adata.X[idx]).reshape(1, -1) + else: + counts = torch.tensor(self.adata.X[idx].todense()) + + dataset = self.adata_name + dataset_num = 0 + return counts, idx, dataset, dataset_num + + dataset, ds_idx = self._compute_index(idx) + h5f = self.dataset_file(dataset) + attrs = dict(h5f["X"].attrs) + try: + if attrs["encoding-type"] == "csr_matrix": + indptr = h5f["X"].indptr + indices = h5f["X"].indices + data_ = h5f["X"].data + start = indptr[ds_idx] + end = indptr[ds_idx + 1] + sub_indices = torch.tensor(indices[start:end], dtype=torch.int64) + sub_data = torch.tensor(data_[start:end], dtype=torch.float32) + counts = torch.sparse_csr_tensor( + [0, sub_indices.shape[0]], + sub_indices, + sub_data, + (1, self.num_genes[dataset]), + ) + counts = 
counts.to_dense() + else: + log.info(ds_idx) + counts = torch.tensor(h5f["X"][ds_idx]).unsqueeze(0) + + except Exception as iex: + log.exception(f"Error in dataset {dataset} at index {ds_idx}") + raise iex + + dataset_num = self.datasets_to_num[dataset] + return counts, idx, dataset, dataset_num + + def __len__(self) -> int: + return self.total_num_cells + + def get_dim(self) -> Dict[str, int]: + return self.num_genes diff --git a/tests/test_h5ad_sentence_dataset.py b/tests/test_h5ad_sentence_dataset.py new file mode 100644 index 0000000..47c2c31 --- /dev/null +++ b/tests/test_h5ad_sentence_dataset.py @@ -0,0 +1,36 @@ +import numpy as np +import torch +import pytest +from types import SimpleNamespace +from cell_load.dataset.h5ad_sentence_dataset import H5adSentenceDataset + +class DummyAdata: + def __init__(self, X, var_names=None): + self.X = X + self.shape = X.shape + self.var = {'gene_name': np.array(var_names) if var_names is not None else np.array(['a', 'b', 'c'])} + self.var_names = np.array(var_names) if var_names is not None else np.array(['a', 'b', 'c']) + +@pytest.fixture +def dummy_cfg(): + # Minimal config with required attributes + cfg = SimpleNamespace() + cfg.model = SimpleNamespace() + cfg.model.batch_size = 2 + cfg.dataset = SimpleNamespace() + cfg.dataset.pad_length = 3 + cfg.dataset.P = 1 + cfg.dataset.N = 1 + cfg.dataset.S = 1 + return cfg + +def test_h5ad_sentence_dataset_with_adata(dummy_cfg): + X = np.array([[1, 2, 3], [4, 5, 6]]) + adata = DummyAdata(X, var_names=['a', 'b', 'c']) + ds = H5adSentenceDataset(cfg=dummy_cfg, adata=adata, adata_name='dummy') + assert len(ds) == 2 + counts, idx, dataset, dataset_num = ds[0] + assert isinstance(counts, torch.Tensor) + assert counts.shape[1] == 3 + assert dataset == 'dummy' + assert dataset_num == 0 From 2c45cb12cebeb3ab27cf6271fd78439287ab777a Mon Sep 17 00:00:00 2001 From: Sravya Tirukkovalur Date: Fri, 6 Jun 2025 10:29:06 -0700 Subject: [PATCH 2/7] Moved the cell sentence dataloader and the collator --- .../data_modules/cell_sentence_dataloader.py | 41 ++ .../dataset/cell_sentence_dataset.py | 407 ++++++++++++++++++ .../dataset/h5ad_sentence_dataset.py | 115 ----- tests/test_cell_sentence_dataset.py | 32 ++ tests/test_h5ad_sentence_dataset.py | 36 -- 5 files changed, 480 insertions(+), 151 deletions(-) create mode 100644 src/cell_load/data_modules/cell_sentence_dataloader.py create mode 100644 src/cell_load/dataset/cell_sentence_dataset.py delete mode 100644 src/cell_load/dataset/h5ad_sentence_dataset.py create mode 100644 tests/test_cell_sentence_dataset.py delete mode 100644 tests/test_h5ad_sentence_dataset.py diff --git a/src/cell_load/data_modules/cell_sentence_dataloader.py b/src/cell_load/data_modules/cell_sentence_dataloader.py new file mode 100644 index 0000000..d0707b4 --- /dev/null +++ b/src/cell_load/data_modules/cell_sentence_dataloader.py @@ -0,0 +1,41 @@ +import torch +from torch.utils.data import DataLoader +from cell_load.dataset.cell_sentence_dataset import FilteredGenesCounts +from cell_load.dataset.cell_sentence_dataset import CellSentenceCollator + +def create_cell_sentence_dataloader( + cfg, + workers=1, + data_dir=None, + datasets=None, + shape_dict=None, + adata=None, + adata_name=None, + shuffle=False, + sentence_collator=None, +): + """ + Expected to be used for inference + Either datasets and shape_dict or adata and adata_name should be provided + """ + if datasets is None and adata is None: + raise ValueError("Either datasets and shape_dict or adata and adata_name should be provided") + if 
adata is not None: + shuffle = False + if data_dir: + cfg.model.data_dir = data_dir + dataset = FilteredGenesCounts(cfg, datasets=datasets, shape_dict=shape_dict, adata=adata, adata_name=adata_name) + if sentence_collator is None: + sentence_collator = CellSentenceCollator( + cfg, valid_gene_mask=dataset.valid_gene_index, ds_emb_mapping_inference=dataset.ds_emb_map, is_train=False + ) + sentence_collator.training = False + dataloader = DataLoader( + dataset, + batch_size=cfg.model.batch_size, + shuffle=shuffle, + collate_fn=sentence_collator, + num_workers=workers, + persistent_workers=True, + ) + return dataloader diff --git a/src/cell_load/dataset/cell_sentence_dataset.py b/src/cell_load/dataset/cell_sentence_dataset.py new file mode 100644 index 0000000..ff20507 --- /dev/null +++ b/src/cell_load/dataset/cell_sentence_dataset.py @@ -0,0 +1,407 @@ +import h5py +import logging +import torch +import torch.utils.data as data +import functools +import numpy as np +from typing import Dict +from .. import utils + +log = logging.getLogger(__file__) + +class CellSentenceDataset(data.Dataset): + def __init__(self, cfg, test=False, datasets=None, shape_dict=None, adata=None, adata_name=None) -> None: + super().__init__() + self.adata = None + self.adata_name = adata_name + self.test = test + if adata is not None: + self.adata = adata + self.datasets = [adata_name] + self.shapes_dict = {self.datasets[0]: adata.shape} + elif datasets is None: + ds_path = utils.get_dataset_cfg(cfg).train + if test: + ds_path = utils.get_dataset_cfg(cfg).val + _, self.datasets, self.shapes_dict, self.dataset_path_map, self.dataset_group_map = utils.get_shapes_dict( + ds_path, utils.get_dataset_cfg(cfg).get("filter_by_species") + ) + else: + assert shape_dict is not None + assert len(datasets) == len(shape_dict) + self.datasets = datasets + self.shapes_dict = shape_dict + self.dataset_path_map = {dataset: dataset for dataset in datasets} + self.datasets = sorted(self.datasets) + self.cfg = cfg + self.num_cells = {} + self.num_genes = {} + self.total_num_cells = 0 + for name in self.datasets: + num_cells, num_genes = self.shapes_dict[name] + self.num_cells[name] = num_cells + self.num_genes[name] = num_genes + self.total_num_cells += num_cells + self.datasets_to_num = {k: v for k, v in zip(self.datasets, range(len(self.datasets)))} + @functools.lru_cache + def dataset_file(self, dataset): + datafile = self.dataset_path_map[dataset] + return h5py.File(datafile, "r") + def _compute_index(self, idx): + for dataset in self.datasets: + if idx < self.num_cells[dataset]: + return dataset, idx + else: + idx -= self.num_cells[dataset] + raise IndexError + def __getitem__(self, idx): + if self.adata is not None: + if isinstance(self.adata.X, np.ndarray): + counts = torch.tensor(self.adata.X[idx]).reshape(1, -1) + else: + counts = torch.tensor(self.adata.X[idx].todense()) + dataset = self.adata_name + dataset_num = 0 + return counts, idx, dataset, dataset_num + dataset, ds_idx = self._compute_index(idx) + h5f = self.dataset_file(dataset) + attrs = dict(h5f["X"].attrs) + try: + if attrs["encoding-type"] == "csr_matrix": + indptrs = h5f["/X/indptr"] + start_ptr = indptrs[ds_idx] + end_ptr = indptrs[ds_idx + 1] + sub_data = torch.tensor(h5f["/X/data"][start_ptr:end_ptr], dtype=torch.float) + sub_indices = torch.tensor(h5f["/X/indices"][start_ptr:end_ptr], dtype=torch.int32) + counts = torch.sparse_csr_tensor( + [0], + sub_indices, + sub_data, + (1, self.num_genes[dataset]), + ) + counts = counts.to_dense() + else: + log.info(ds_idx) + 
counts = torch.tensor(h5f["X"][ds_idx]).unsqueeze(0) + except Exception as iex: + log.exception(f"Error in dataset {dataset} at index {ds_idx}") + raise iex + dataset_num = self.datasets_to_num[dataset] + return counts, idx, dataset, dataset_num + def __len__(self) -> int: + return self.total_num_cells + def get_dim(self) -> Dict[str, int]: + return self.num_genes + +class FilteredGenesCounts(CellSentenceDataset): + def __init__(self, cfg, test=False, datasets=None, shape_dict=None, adata=None, adata_name=None) -> None: + super().__init__(cfg, test, datasets, shape_dict, adata, adata_name) + self.valid_gene_index = {} + _, self.datasets, self.shapes_dict, self.dataset_path_map, self.dataset_group_map = utils.get_shapes_dict( + "/home/aadduri/state/h5ad_all.csv" + ) + emb_cfg = utils.get_embedding_cfg(self.cfg) + try: + self.ds_emb_map = torch.load(emb_cfg.ds_emb_mapping, weights_only=False) + except (FileNotFoundError, IOError): + self.ds_emb_map = {} + if adata_name is not None: + self.datasets.append(adata_name) + self.shapes_dict[adata_name] = adata.shape + esm_data = torch.load(emb_cfg.all_embeddings, weights_only=False) + valid_genes_list = list(esm_data.keys()) + global_pos = {g: i for i, g in enumerate(valid_genes_list)} + gene_names = np.array(adata.var_names) + new_mapping = np.array([global_pos.get(g, -1) for g in gene_names]) + if (new_mapping == -1).all(): + gene_names = adata.var["gene_name"].values + new_mapping = np.array([global_pos.get(g, -1) for g in gene_names]) + self.ds_emb_map[adata_name] = new_mapping + if utils.get_embedding_cfg(self.cfg).ds_emb_mapping is not None: + esm_data = torch.load(utils.get_embedding_cfg(self.cfg)["all_embeddings"], weights_only=False) + valid_genes_list = list(esm_data.keys()) + for name in self.datasets: + if not utils.is_valid_uuid(name): + if adata is None: + a = self.dataset_file(name) + try: + gene_names = np.array( + [g.decode("utf-8") for g in a["/var/gene_name"][:]] + ) + except: + gene_categories = a["/var/gene_name/categories"][:] + gene_codes = np.array(a["/var/gene_name/codes"][:]) + gene_names = np.array([g.decode("utf-8") for g in gene_categories[gene_codes]]) + valid_mask = np.isin(gene_names, valid_genes_list) + self.valid_gene_index[name] = valid_mask + else: + gene_names = np.array(adata.var_names) + valid_mask = np.isin(gene_names, valid_genes_list) + if not valid_mask.any(): + gene_names = adata.var["gene_name"].values + valid_mask = np.isin(gene_names, valid_genes_list) + self.valid_gene_index[name] = valid_mask + +class CellSentenceCollator(object): + def __init__(self, cfg, valid_gene_mask=None, ds_emb_mapping_inference=None, is_train=True): + self.pad_length = cfg.dataset.pad_length + self.P = cfg.dataset.P + self.N = cfg.dataset.N + self.S = cfg.dataset.S + self.cfg = cfg + self.training = is_train + self.use_dataset_info = getattr(cfg.model, "dataset_correction", False) + self.batch_tabular_loss = getattr(cfg.model, "batch_tabular_loss", False) + if valid_gene_mask is not None: + self.valid_gene_mask = valid_gene_mask + self.dataset_to_protein_embeddings = ds_emb_mapping_inference + else: + gene_mask_file = utils.get_embedding_cfg(self.cfg).valid_genes_masks + if gene_mask_file is not None: + self.valid_gene_mask = torch.load(gene_mask_file, weights_only=False) + else: + self.valid_gene_mask = None + self.dataset_to_protein_embeddings = torch.load( + utils.get_embedding_cfg(self.cfg).ds_emb_mapping.format(utils.get_embedding_cfg(self.cfg).size), + weights_only=False, + ) + self.global_size = 
utils.get_embedding_cfg(self.cfg).num + self.global_to_local = {} + for dataset_name, ds_emb_idxs in self.dataset_to_protein_embeddings.items(): + ds_emb_idxs = torch.tensor(ds_emb_idxs, dtype=torch.long) + reverse_mapping = torch.full((self.global_size,), -1, dtype=torch.int64) + local_indices = torch.arange(ds_emb_idxs.size(0), dtype=torch.int64) + mask = (ds_emb_idxs >= 0) & (ds_emb_idxs < self.global_size) + reverse_mapping[ds_emb_idxs[mask]] = local_indices[mask] + self.global_to_local[dataset_name] = reverse_mapping + print(len(self.global_to_local)) + def __call__(self, batch): + num_aug = getattr(self.cfg.model, "num_downsample", 1) + if num_aug > 1 and self.training: + batch = [item for item in batch for _ in range(num_aug)] + batch_size = len(batch) + batch_sentences = torch.zeros((batch_size, self.pad_length), dtype=torch.int32) + batch_sentences_counts = torch.zeros((batch_size, self.pad_length)) + masks = torch.zeros((batch_size, self.pad_length), dtype=torch.bool) + idxs = torch.zeros(batch_size, dtype=torch.int32) + if self.cfg.loss.name == "tabular": + Xs = torch.zeros((batch_size, self.pad_length, self.P)) + Ys = torch.zeros((batch_size, self.pad_length, self.N)) + batch_weights = torch.ones((batch_size, self.pad_length)) + else: + Xs = Ys = batch_weights = None + dataset_nums = torch.zeros(batch_size, dtype=torch.int32) + total_counts_all = torch.zeros(batch_size) + for i, (counts, idx, dataset, dataset_num) in enumerate(batch): + batch_sentences[i, :counts.shape[1]] = counts.squeeze() + idxs[i] = idx + dataset_nums[i] = dataset_num + return ( + batch_sentences, + Xs, + Ys, + idxs, + batch_weights, + masks, + total_counts_all if getattr(self.cfg.model, "rda", False) else None, + batch_sentences_counts if getattr(self.cfg.model, "counts", False) else None, + dataset_nums if self.use_dataset_info else None, + ) + + def __init__(self, cfg, test=False, datasets=None, shape_dict=None, adata=None, adata_name=None) -> None: + super().__init__() + self.adata = None + self.adata_name = adata_name + self.test = test + if adata is not None: + self.adata = adata + self.datasets = [adata_name] + self.shapes_dict = {self.datasets[0]: adata.shape} + elif datasets is None: + ds_path = utils.get_dataset_cfg(cfg).train + if test: + ds_path = utils.get_dataset_cfg(cfg).val + _, self.datasets, self.shapes_dict, self.dataset_path_map, self.dataset_group_map = utils.get_shapes_dict( + ds_path, utils.get_dataset_cfg(cfg).get("filter_by_species") + ) + else: + assert shape_dict is not None + assert len(datasets) == len(shape_dict) + self.datasets = datasets + self.shapes_dict = shape_dict + self.dataset_path_map = {dataset: dataset for dataset in datasets} + self.datasets = sorted(self.datasets) + self.cfg = cfg + self.num_cells = {} + self.num_genes = {} + self.total_num_cells = 0 + for name in self.datasets: + num_cells, num_genes = self.shapes_dict[name] + self.num_cells[name] = num_cells + self.num_genes[name] = num_genes + self.total_num_cells += num_cells + self.datasets_to_num = {k: v for k, v in zip(self.datasets, range(len(self.datasets)))} + @functools.lru_cache + def dataset_file(self, dataset): + datafile = self.dataset_path_map[dataset] + return h5py.File(datafile, "r") + def _compute_index(self, idx): + for dataset in self.datasets: + if idx < self.num_cells[dataset]: + return dataset, idx + else: + idx -= self.num_cells[dataset] + raise IndexError + def __getitem__(self, idx): + if self.adata is not None: + if isinstance(self.adata.X, np.ndarray): + counts = 
torch.tensor(self.adata.X[idx]).reshape(1, -1) + else: + counts = torch.tensor(self.adata.X[idx].todense()) + dataset = self.adata_name + dataset_num = 0 + return counts, idx, dataset, dataset_num + dataset, ds_idx = self._compute_index(idx) + h5f = self.dataset_file(dataset) + attrs = dict(h5f["X"].attrs) + try: + if attrs["encoding-type"] == "csr_matrix": + indptrs = h5f["/X/indptr"] + start_ptr = indptrs[ds_idx] + end_ptr = indptrs[ds_idx + 1] + sub_data = torch.tensor(h5f["/X/data"][start_ptr:end_ptr], dtype=torch.float) + sub_indices = torch.tensor(h5f["/X/indices"][start_ptr:end_ptr], dtype=torch.int32) + counts = torch.sparse_csr_tensor( + [0], + sub_indices, + sub_data, + (1, self.num_genes[dataset]), + ) + counts = counts.to_dense() + else: + log.info(ds_idx) + counts = torch.tensor(h5f["X"][ds_idx]).unsqueeze(0) + except Exception as iex: + log.exception(f"Error in dataset {dataset} at index {ds_idx}") + raise iex + dataset_num = self.datasets_to_num[dataset] + return counts, idx, dataset, dataset_num + def __len__(self) -> int: + return self.total_num_cells + def get_dim(self) -> Dict[str, int]: + return self.num_genes + +class CellSentenceCollator(object): + def __init__(self, cfg, valid_gene_mask=None, ds_emb_mapping_inference=None, is_train=True): + self.pad_length = cfg.dataset.pad_length + self.P = cfg.dataset.P + self.N = cfg.dataset.N + self.S = cfg.dataset.S + self.cfg = cfg + self.training = is_train + self.use_dataset_info = getattr(cfg.model, "dataset_correction", False) + self.batch_tabular_loss = getattr(cfg.model, "batch_tabular_loss", False) + if valid_gene_mask is not None: + self.valid_gene_mask = valid_gene_mask + self.dataset_to_protein_embeddings = ds_emb_mapping_inference + else: + gene_mask_file = utils.get_embedding_cfg(self.cfg).valid_genes_masks + if gene_mask_file is not None: + self.valid_gene_mask = torch.load(gene_mask_file, weights_only=False) + else: + self.valid_gene_mask = None + self.dataset_to_protein_embeddings = torch.load( + utils.get_embedding_cfg(self.cfg).ds_emb_mapping.format(utils.get_embedding_cfg(self.cfg).size), + weights_only=False, + ) + self.global_size = utils.get_embedding_cfg(self.cfg).num + self.global_to_local = {} + for dataset_name, ds_emb_idxs in self.dataset_to_protein_embeddings.items(): + ds_emb_idxs = torch.tensor(ds_emb_idxs, dtype=torch.long) + reverse_mapping = torch.full((self.global_size,), -1, dtype=torch.int64) + local_indices = torch.arange(ds_emb_idxs.size(0), dtype=torch.int64) + mask = (ds_emb_idxs >= 0) & (ds_emb_idxs < self.global_size) + reverse_mapping[ds_emb_idxs[mask]] = local_indices[mask] + self.global_to_local[dataset_name] = reverse_mapping + print(len(self.global_to_local)) + def __call__(self, batch): + num_aug = getattr(self.cfg.model, "num_downsample", 1) + if num_aug > 1 and self.training: + batch = [item for item in batch for _ in range(num_aug)] + batch_size = len(batch) + batch_sentences = torch.zeros((batch_size, self.pad_length), dtype=torch.int32) + batch_sentences_counts = torch.zeros((batch_size, self.pad_length)) + masks = torch.zeros((batch_size, self.pad_length), dtype=torch.bool) + idxs = torch.zeros(batch_size, dtype=torch.int32) + if self.cfg.loss.name == "tabular": + Xs = torch.zeros((batch_size, self.pad_length, self.P)) + Ys = torch.zeros((batch_size, self.pad_length, self.N)) + batch_weights = torch.ones((batch_size, self.pad_length)) + else: + Xs = Ys = batch_weights = None + dataset_nums = torch.zeros(batch_size, dtype=torch.int32) + total_counts_all = 
torch.zeros(batch_size) + for i, (counts, idx, dataset, dataset_num) in enumerate(batch): + batch_sentences[i, :counts.shape[1]] = counts.squeeze() + idxs[i] = idx + dataset_nums[i] = dataset_num + return ( + batch_sentences, + Xs, + Ys, + idxs, + batch_weights, + masks, + total_counts_all if getattr(self.cfg.model, "rda", False) else None, + batch_sentences_counts if getattr(self.cfg.model, "counts", False) else None, + dataset_nums if self.use_dataset_info else None, + ) + +class FilteredGenesCounts(CellSentenceDataset): + def __init__(self, cfg, test=False, datasets=None, shape_dict=None, adata=None, adata_name=None) -> None: + super().__init__(cfg, test, datasets, shape_dict, adata, adata_name) + self.valid_gene_index = {} + _, self.datasets, self.shapes_dict, self.dataset_path_map, self.dataset_group_map = utils.get_shapes_dict( + "/home/aadduri/state/h5ad_all.csv" + ) + emb_cfg = utils.get_embedding_cfg(self.cfg) + try: + self.ds_emb_map = torch.load(emb_cfg.ds_emb_mapping, weights_only=False) + except (FileNotFoundError, IOError): + self.ds_emb_map = {} + if adata_name is not None: + self.datasets.append(adata_name) + self.shapes_dict[adata_name] = adata.shape + esm_data = torch.load(emb_cfg.all_embeddings, weights_only=False) + valid_genes_list = list(esm_data.keys()) + global_pos = {g: i for i, g in enumerate(valid_genes_list)} + gene_names = np.array(adata.var_names) + new_mapping = np.array([global_pos.get(g, -1) for g in gene_names]) + if (new_mapping == -1).all(): + gene_names = adata.var["gene_name"].values + new_mapping = np.array([global_pos.get(g, -1) for g in gene_names]) + self.ds_emb_map[adata_name] = new_mapping + if utils.get_embedding_cfg(self.cfg).ds_emb_mapping is not None: + esm_data = torch.load(utils.get_embedding_cfg(self.cfg)["all_embeddings"], weights_only=False) + valid_genes_list = list(esm_data.keys()) + for name in self.datasets: + if not utils.is_valid_uuid(name): + if adata is None: + a = self.dataset_file(name) + try: + gene_names = np.array( + [g.decode("utf-8") for g in a["/var/gene_name"][:]] + ) + except: + gene_categories = a["/var/gene_name/categories"][:] + gene_codes = np.array(a["/var/gene_name/codes"][:]) + gene_names = np.array([g.decode("utf-8") for g in gene_categories[gene_codes]]) + valid_mask = np.isin(gene_names, valid_genes_list) + self.valid_gene_index[name] = valid_mask + else: + gene_names = np.array(adata.var_names) + valid_mask = np.isin(gene_names, valid_genes_list) + if not valid_mask.any(): + gene_names = adata.var["gene_name"].values + valid_mask = np.isin(gene_names, valid_genes_list) + self.valid_gene_index[name] = valid_mask diff --git a/src/cell_load/dataset/h5ad_sentence_dataset.py b/src/cell_load/dataset/h5ad_sentence_dataset.py deleted file mode 100644 index 2457214..0000000 --- a/src/cell_load/dataset/h5ad_sentence_dataset.py +++ /dev/null @@ -1,115 +0,0 @@ -import h5py -import logging -import torch -import torch.utils.data as data -import functools -import numpy as np -from typing import Dict -from .. 
import utils - -log = logging.getLogger(__file__) - -EXPONENTIATED_UMIS_LIMIT = 5_000_000 -RAW_COUNT_HEURISTIC_THRESHOLD = 35 - -class H5adSentenceDataset(data.Dataset): - def __init__(self, cfg, test=False, datasets=None, shape_dict=None, adata=None, adata_name=None) -> None: - super(H5adSentenceDataset, self).__init__() - - self.adata = None - self.adata_name = adata_name - self.test = test - if adata is not None: - self.adata = adata - self.datasets = [adata_name] - self.shapes_dict = {self.datasets[0]: adata.shape} - elif datasets is None: - ds_path = utils.get_dataset_cfg(cfg).train - if test: - ds_path = utils.get_dataset_cfg(cfg).val - _, self.datasets, self.shapes_dict, self.dataset_path_map, self.dataset_group_map = utils.get_shapes_dict( - ds_path, utils.get_dataset_cfg(cfg).get("filter_by_species") - ) - else: - assert shape_dict is not None - assert len(datasets) == len(shape_dict) - self.datasets = datasets - self.shapes_dict = shape_dict - self.dataset_path_map = {dataset: dataset for dataset in datasets} - - self.datasets = sorted(self.datasets) - self.cfg = cfg - - self.num_cells = {} - self.num_genes = {} - - self.total_num_cells = 0 - for name in self.datasets: - num_cells, num_genes = self.shapes_dict[name] - self.num_cells[name] = num_cells - self.num_genes[name] = num_genes - self.total_num_cells += num_cells - - self.datasets_to_num = {k: v for k, v in zip(self.datasets, range(len(self.datasets)))} - - @functools.lru_cache - def dataset_file(self, dataset): - datafile = self.dataset_path_map[dataset] - return h5py.File(datafile, "r") - - def _compute_index(self, idx): - for dataset in self.datasets: - if idx < self.num_cells[dataset]: - return dataset, idx - else: - idx -= self.num_cells[dataset] - raise IndexError - - def __getitem__(self, idx): - if self.adata is not None: - # block is only used during validation - # if .X is a numpy.ndarray - if isinstance(self.adata.X, np.ndarray): - counts = torch.tensor(self.adata.X[idx]).reshape(1, -1) - else: - counts = torch.tensor(self.adata.X[idx].todense()) - - dataset = self.adata_name - dataset_num = 0 - return counts, idx, dataset, dataset_num - - dataset, ds_idx = self._compute_index(idx) - h5f = self.dataset_file(dataset) - attrs = dict(h5f["X"].attrs) - try: - if attrs["encoding-type"] == "csr_matrix": - indptr = h5f["X"].indptr - indices = h5f["X"].indices - data_ = h5f["X"].data - start = indptr[ds_idx] - end = indptr[ds_idx + 1] - sub_indices = torch.tensor(indices[start:end], dtype=torch.int64) - sub_data = torch.tensor(data_[start:end], dtype=torch.float32) - counts = torch.sparse_csr_tensor( - [0, sub_indices.shape[0]], - sub_indices, - sub_data, - (1, self.num_genes[dataset]), - ) - counts = counts.to_dense() - else: - log.info(ds_idx) - counts = torch.tensor(h5f["X"][ds_idx]).unsqueeze(0) - - except Exception as iex: - log.exception(f"Error in dataset {dataset} at index {ds_idx}") - raise iex - - dataset_num = self.datasets_to_num[dataset] - return counts, idx, dataset, dataset_num - - def __len__(self) -> int: - return self.total_num_cells - - def get_dim(self) -> Dict[str, int]: - return self.num_genes diff --git a/tests/test_cell_sentence_dataset.py b/tests/test_cell_sentence_dataset.py new file mode 100644 index 0000000..d80572b --- /dev/null +++ b/tests/test_cell_sentence_dataset.py @@ -0,0 +1,32 @@ +import pytest +import numpy as np +import torch +from cell_load.dataset.cell_sentence_dataset import CellSentenceDataset + +class DummyAdata: + def __init__(self, shape=(3, 5)): + self.X = 
np.arange(np.prod(shape)).reshape(shape) + self.var_names = [f"gene{i}" for i in range(shape[1])] + self.shape = shape + self.var = {'gene_name': np.array(self.var_names)} + +@pytest.fixture +def dummy_adata(): + return DummyAdata() + +def test_cell_sentence_dataset_basic(dummy_adata): + cfg = type('cfg', (), {})() + cfg.model = type('model', (), {})() + cfg.model.batch_size = 2 + cfg.dataset = type('dataset', (), {})() + cfg.dataset.pad_length = 5 + cfg.dataset.P = 1 + cfg.dataset.N = 1 + cfg.dataset.S = 1 + dataset = CellSentenceDataset(cfg, adata=dummy_adata, adata_name='dummy') + assert len(dataset) == dummy_adata.shape[0] + counts, idx, dataset_name, dataset_num = dataset[0] + assert isinstance(counts, torch.Tensor) + assert counts.shape[1] == dummy_adata.shape[1] + assert dataset_name == 'dummy' + assert dataset_num == 0 diff --git a/tests/test_h5ad_sentence_dataset.py b/tests/test_h5ad_sentence_dataset.py deleted file mode 100644 index 47c2c31..0000000 --- a/tests/test_h5ad_sentence_dataset.py +++ /dev/null @@ -1,36 +0,0 @@ -import numpy as np -import torch -import pytest -from types import SimpleNamespace -from cell_load.dataset.h5ad_sentence_dataset import H5adSentenceDataset - -class DummyAdata: - def __init__(self, X, var_names=None): - self.X = X - self.shape = X.shape - self.var = {'gene_name': np.array(var_names) if var_names is not None else np.array(['a', 'b', 'c'])} - self.var_names = np.array(var_names) if var_names is not None else np.array(['a', 'b', 'c']) - -@pytest.fixture -def dummy_cfg(): - # Minimal config with required attributes - cfg = SimpleNamespace() - cfg.model = SimpleNamespace() - cfg.model.batch_size = 2 - cfg.dataset = SimpleNamespace() - cfg.dataset.pad_length = 3 - cfg.dataset.P = 1 - cfg.dataset.N = 1 - cfg.dataset.S = 1 - return cfg - -def test_h5ad_sentence_dataset_with_adata(dummy_cfg): - X = np.array([[1, 2, 3], [4, 5, 6]]) - adata = DummyAdata(X, var_names=['a', 'b', 'c']) - ds = H5adSentenceDataset(cfg=dummy_cfg, adata=adata, adata_name='dummy') - assert len(ds) == 2 - counts, idx, dataset, dataset_num = ds[0] - assert isinstance(counts, torch.Tensor) - assert counts.shape[1] == 3 - assert dataset == 'dummy' - assert dataset_num == 0 From e1a693a7e199c045ca028effc217e810a356c0bc Mon Sep 17 00:00:00 2001 From: Sravya Tirukkovalur Date: Fri, 6 Jun 2025 10:54:29 -0700 Subject: [PATCH 3/7] Ruff formating --- .../data_modules/cell_sentence_dataloader.py | 18 +- .../dataset/cell_sentence_dataset.py | 161 ++++++++++++++---- tests/test_cell_sentence_dataset.py | 15 +- 3 files changed, 155 insertions(+), 39 deletions(-) diff --git a/src/cell_load/data_modules/cell_sentence_dataloader.py b/src/cell_load/data_modules/cell_sentence_dataloader.py index d0707b4..edc2a0e 100644 --- a/src/cell_load/data_modules/cell_sentence_dataloader.py +++ b/src/cell_load/data_modules/cell_sentence_dataloader.py @@ -3,6 +3,7 @@ from cell_load.dataset.cell_sentence_dataset import FilteredGenesCounts from cell_load.dataset.cell_sentence_dataset import CellSentenceCollator + def create_cell_sentence_dataloader( cfg, workers=1, @@ -19,15 +20,26 @@ def create_cell_sentence_dataloader( Either datasets and shape_dict or adata and adata_name should be provided """ if datasets is None and adata is None: - raise ValueError("Either datasets and shape_dict or adata and adata_name should be provided") + raise ValueError( + "Either datasets and shape_dict or adata and adata_name should be provided" + ) if adata is not None: shuffle = False if data_dir: cfg.model.data_dir = data_dir - 
dataset = FilteredGenesCounts(cfg, datasets=datasets, shape_dict=shape_dict, adata=adata, adata_name=adata_name) + dataset = FilteredGenesCounts( + cfg, + datasets=datasets, + shape_dict=shape_dict, + adata=adata, + adata_name=adata_name, + ) if sentence_collator is None: sentence_collator = CellSentenceCollator( - cfg, valid_gene_mask=dataset.valid_gene_index, ds_emb_mapping_inference=dataset.ds_emb_map, is_train=False + cfg, + valid_gene_mask=dataset.valid_gene_index, + ds_emb_mapping_inference=dataset.ds_emb_map, + is_train=False, ) sentence_collator.training = False dataloader = DataLoader( diff --git a/src/cell_load/dataset/cell_sentence_dataset.py b/src/cell_load/dataset/cell_sentence_dataset.py index ff20507..3cf2b1a 100644 --- a/src/cell_load/dataset/cell_sentence_dataset.py +++ b/src/cell_load/dataset/cell_sentence_dataset.py @@ -9,8 +9,17 @@ log = logging.getLogger(__file__) + class CellSentenceDataset(data.Dataset): - def __init__(self, cfg, test=False, datasets=None, shape_dict=None, adata=None, adata_name=None) -> None: + def __init__( + self, + cfg, + test=False, + datasets=None, + shape_dict=None, + adata=None, + adata_name=None, + ) -> None: super().__init__() self.adata = None self.adata_name = adata_name @@ -23,7 +32,13 @@ def __init__(self, cfg, test=False, datasets=None, shape_dict=None, adata=None, ds_path = utils.get_dataset_cfg(cfg).train if test: ds_path = utils.get_dataset_cfg(cfg).val - _, self.datasets, self.shapes_dict, self.dataset_path_map, self.dataset_group_map = utils.get_shapes_dict( + ( + _, + self.datasets, + self.shapes_dict, + self.dataset_path_map, + self.dataset_group_map, + ) = utils.get_shapes_dict( ds_path, utils.get_dataset_cfg(cfg).get("filter_by_species") ) else: @@ -42,11 +57,15 @@ def __init__(self, cfg, test=False, datasets=None, shape_dict=None, adata=None, self.num_cells[name] = num_cells self.num_genes[name] = num_genes self.total_num_cells += num_cells - self.datasets_to_num = {k: v for k, v in zip(self.datasets, range(len(self.datasets)))} + self.datasets_to_num = { + k: v for k, v in zip(self.datasets, range(len(self.datasets))) + } + @functools.lru_cache def dataset_file(self, dataset): datafile = self.dataset_path_map[dataset] return h5py.File(datafile, "r") + def _compute_index(self, idx): for dataset in self.datasets: if idx < self.num_cells[dataset]: @@ -54,6 +73,7 @@ def _compute_index(self, idx): else: idx -= self.num_cells[dataset] raise IndexError + def __getitem__(self, idx): if self.adata is not None: if isinstance(self.adata.X, np.ndarray): @@ -71,8 +91,12 @@ def __getitem__(self, idx): indptrs = h5f["/X/indptr"] start_ptr = indptrs[ds_idx] end_ptr = indptrs[ds_idx + 1] - sub_data = torch.tensor(h5f["/X/data"][start_ptr:end_ptr], dtype=torch.float) - sub_indices = torch.tensor(h5f["/X/indices"][start_ptr:end_ptr], dtype=torch.int32) + sub_data = torch.tensor( + h5f["/X/data"][start_ptr:end_ptr], dtype=torch.float + ) + sub_indices = torch.tensor( + h5f["/X/indices"][start_ptr:end_ptr], dtype=torch.int32 + ) counts = torch.sparse_csr_tensor( [0], sub_indices, @@ -88,18 +112,33 @@ def __getitem__(self, idx): raise iex dataset_num = self.datasets_to_num[dataset] return counts, idx, dataset, dataset_num + def __len__(self) -> int: return self.total_num_cells + def get_dim(self) -> Dict[str, int]: return self.num_genes + class FilteredGenesCounts(CellSentenceDataset): - def __init__(self, cfg, test=False, datasets=None, shape_dict=None, adata=None, adata_name=None) -> None: + def __init__( + self, + cfg, + test=False, + 
datasets=None, + shape_dict=None, + adata=None, + adata_name=None, + ) -> None: super().__init__(cfg, test, datasets, shape_dict, adata, adata_name) self.valid_gene_index = {} - _, self.datasets, self.shapes_dict, self.dataset_path_map, self.dataset_group_map = utils.get_shapes_dict( - "/home/aadduri/state/h5ad_all.csv" - ) + ( + _, + self.datasets, + self.shapes_dict, + self.dataset_path_map, + self.dataset_group_map, + ) = utils.get_shapes_dict("/home/aadduri/state/h5ad_all.csv") emb_cfg = utils.get_embedding_cfg(self.cfg) try: self.ds_emb_map = torch.load(emb_cfg.ds_emb_mapping, weights_only=False) @@ -118,7 +157,9 @@ def __init__(self, cfg, test=False, datasets=None, shape_dict=None, adata=None, new_mapping = np.array([global_pos.get(g, -1) for g in gene_names]) self.ds_emb_map[adata_name] = new_mapping if utils.get_embedding_cfg(self.cfg).ds_emb_mapping is not None: - esm_data = torch.load(utils.get_embedding_cfg(self.cfg)["all_embeddings"], weights_only=False) + esm_data = torch.load( + utils.get_embedding_cfg(self.cfg)["all_embeddings"], weights_only=False + ) valid_genes_list = list(esm_data.keys()) for name in self.datasets: if not utils.is_valid_uuid(name): @@ -131,7 +172,9 @@ def __init__(self, cfg, test=False, datasets=None, shape_dict=None, adata=None, except: gene_categories = a["/var/gene_name/categories"][:] gene_codes = np.array(a["/var/gene_name/codes"][:]) - gene_names = np.array([g.decode("utf-8") for g in gene_categories[gene_codes]]) + gene_names = np.array( + [g.decode("utf-8") for g in gene_categories[gene_codes]] + ) valid_mask = np.isin(gene_names, valid_genes_list) self.valid_gene_index[name] = valid_mask else: @@ -142,8 +185,11 @@ def __init__(self, cfg, test=False, datasets=None, shape_dict=None, adata=None, valid_mask = np.isin(gene_names, valid_genes_list) self.valid_gene_index[name] = valid_mask + class CellSentenceCollator(object): - def __init__(self, cfg, valid_gene_mask=None, ds_emb_mapping_inference=None, is_train=True): + def __init__( + self, cfg, valid_gene_mask=None, ds_emb_mapping_inference=None, is_train=True + ): self.pad_length = cfg.dataset.pad_length self.P = cfg.dataset.P self.N = cfg.dataset.N @@ -162,7 +208,9 @@ def __init__(self, cfg, valid_gene_mask=None, ds_emb_mapping_inference=None, is_ else: self.valid_gene_mask = None self.dataset_to_protein_embeddings = torch.load( - utils.get_embedding_cfg(self.cfg).ds_emb_mapping.format(utils.get_embedding_cfg(self.cfg).size), + utils.get_embedding_cfg(self.cfg).ds_emb_mapping.format( + utils.get_embedding_cfg(self.cfg).size + ), weights_only=False, ) self.global_size = utils.get_embedding_cfg(self.cfg).num @@ -175,6 +223,7 @@ def __init__(self, cfg, valid_gene_mask=None, ds_emb_mapping_inference=None, is_ reverse_mapping[ds_emb_idxs[mask]] = local_indices[mask] self.global_to_local[dataset_name] = reverse_mapping print(len(self.global_to_local)) + def __call__(self, batch): num_aug = getattr(self.cfg.model, "num_downsample", 1) if num_aug > 1 and self.training: @@ -193,7 +242,7 @@ def __call__(self, batch): dataset_nums = torch.zeros(batch_size, dtype=torch.int32) total_counts_all = torch.zeros(batch_size) for i, (counts, idx, dataset, dataset_num) in enumerate(batch): - batch_sentences[i, :counts.shape[1]] = counts.squeeze() + batch_sentences[i, : counts.shape[1]] = counts.squeeze() idxs[i] = idx dataset_nums[i] = dataset_num return ( @@ -204,11 +253,21 @@ def __call__(self, batch): batch_weights, masks, total_counts_all if getattr(self.cfg.model, "rda", False) else None, - 
batch_sentences_counts if getattr(self.cfg.model, "counts", False) else None, + batch_sentences_counts + if getattr(self.cfg.model, "counts", False) + else None, dataset_nums if self.use_dataset_info else None, ) - def __init__(self, cfg, test=False, datasets=None, shape_dict=None, adata=None, adata_name=None) -> None: + def __init__( + self, + cfg, + test=False, + datasets=None, + shape_dict=None, + adata=None, + adata_name=None, + ) -> None: super().__init__() self.adata = None self.adata_name = adata_name @@ -221,7 +280,13 @@ def __init__(self, cfg, test=False, datasets=None, shape_dict=None, adata=None, ds_path = utils.get_dataset_cfg(cfg).train if test: ds_path = utils.get_dataset_cfg(cfg).val - _, self.datasets, self.shapes_dict, self.dataset_path_map, self.dataset_group_map = utils.get_shapes_dict( + ( + _, + self.datasets, + self.shapes_dict, + self.dataset_path_map, + self.dataset_group_map, + ) = utils.get_shapes_dict( ds_path, utils.get_dataset_cfg(cfg).get("filter_by_species") ) else: @@ -240,11 +305,15 @@ def __init__(self, cfg, test=False, datasets=None, shape_dict=None, adata=None, self.num_cells[name] = num_cells self.num_genes[name] = num_genes self.total_num_cells += num_cells - self.datasets_to_num = {k: v for k, v in zip(self.datasets, range(len(self.datasets)))} + self.datasets_to_num = { + k: v for k, v in zip(self.datasets, range(len(self.datasets))) + } + @functools.lru_cache def dataset_file(self, dataset): datafile = self.dataset_path_map[dataset] return h5py.File(datafile, "r") + def _compute_index(self, idx): for dataset in self.datasets: if idx < self.num_cells[dataset]: @@ -252,6 +321,7 @@ def _compute_index(self, idx): else: idx -= self.num_cells[dataset] raise IndexError + def __getitem__(self, idx): if self.adata is not None: if isinstance(self.adata.X, np.ndarray): @@ -269,8 +339,12 @@ def __getitem__(self, idx): indptrs = h5f["/X/indptr"] start_ptr = indptrs[ds_idx] end_ptr = indptrs[ds_idx + 1] - sub_data = torch.tensor(h5f["/X/data"][start_ptr:end_ptr], dtype=torch.float) - sub_indices = torch.tensor(h5f["/X/indices"][start_ptr:end_ptr], dtype=torch.int32) + sub_data = torch.tensor( + h5f["/X/data"][start_ptr:end_ptr], dtype=torch.float + ) + sub_indices = torch.tensor( + h5f["/X/indices"][start_ptr:end_ptr], dtype=torch.int32 + ) counts = torch.sparse_csr_tensor( [0], sub_indices, @@ -286,13 +360,18 @@ def __getitem__(self, idx): raise iex dataset_num = self.datasets_to_num[dataset] return counts, idx, dataset, dataset_num + def __len__(self) -> int: return self.total_num_cells + def get_dim(self) -> Dict[str, int]: return self.num_genes + class CellSentenceCollator(object): - def __init__(self, cfg, valid_gene_mask=None, ds_emb_mapping_inference=None, is_train=True): + def __init__( + self, cfg, valid_gene_mask=None, ds_emb_mapping_inference=None, is_train=True + ): self.pad_length = cfg.dataset.pad_length self.P = cfg.dataset.P self.N = cfg.dataset.N @@ -311,7 +390,9 @@ def __init__(self, cfg, valid_gene_mask=None, ds_emb_mapping_inference=None, is_ else: self.valid_gene_mask = None self.dataset_to_protein_embeddings = torch.load( - utils.get_embedding_cfg(self.cfg).ds_emb_mapping.format(utils.get_embedding_cfg(self.cfg).size), + utils.get_embedding_cfg(self.cfg).ds_emb_mapping.format( + utils.get_embedding_cfg(self.cfg).size + ), weights_only=False, ) self.global_size = utils.get_embedding_cfg(self.cfg).num @@ -324,6 +405,7 @@ def __init__(self, cfg, valid_gene_mask=None, ds_emb_mapping_inference=None, is_ reverse_mapping[ds_emb_idxs[mask]] = 
local_indices[mask] self.global_to_local[dataset_name] = reverse_mapping print(len(self.global_to_local)) + def __call__(self, batch): num_aug = getattr(self.cfg.model, "num_downsample", 1) if num_aug > 1 and self.training: @@ -342,7 +424,7 @@ def __call__(self, batch): dataset_nums = torch.zeros(batch_size, dtype=torch.int32) total_counts_all = torch.zeros(batch_size) for i, (counts, idx, dataset, dataset_num) in enumerate(batch): - batch_sentences[i, :counts.shape[1]] = counts.squeeze() + batch_sentences[i, : counts.shape[1]] = counts.squeeze() idxs[i] = idx dataset_nums[i] = dataset_num return ( @@ -353,17 +435,32 @@ def __call__(self, batch): batch_weights, masks, total_counts_all if getattr(self.cfg.model, "rda", False) else None, - batch_sentences_counts if getattr(self.cfg.model, "counts", False) else None, + batch_sentences_counts + if getattr(self.cfg.model, "counts", False) + else None, dataset_nums if self.use_dataset_info else None, ) + class FilteredGenesCounts(CellSentenceDataset): - def __init__(self, cfg, test=False, datasets=None, shape_dict=None, adata=None, adata_name=None) -> None: + def __init__( + self, + cfg, + test=False, + datasets=None, + shape_dict=None, + adata=None, + adata_name=None, + ) -> None: super().__init__(cfg, test, datasets, shape_dict, adata, adata_name) self.valid_gene_index = {} - _, self.datasets, self.shapes_dict, self.dataset_path_map, self.dataset_group_map = utils.get_shapes_dict( - "/home/aadduri/state/h5ad_all.csv" - ) + ( + _, + self.datasets, + self.shapes_dict, + self.dataset_path_map, + self.dataset_group_map, + ) = utils.get_shapes_dict("/home/aadduri/state/h5ad_all.csv") emb_cfg = utils.get_embedding_cfg(self.cfg) try: self.ds_emb_map = torch.load(emb_cfg.ds_emb_mapping, weights_only=False) @@ -382,7 +479,9 @@ def __init__(self, cfg, test=False, datasets=None, shape_dict=None, adata=None, new_mapping = np.array([global_pos.get(g, -1) for g in gene_names]) self.ds_emb_map[adata_name] = new_mapping if utils.get_embedding_cfg(self.cfg).ds_emb_mapping is not None: - esm_data = torch.load(utils.get_embedding_cfg(self.cfg)["all_embeddings"], weights_only=False) + esm_data = torch.load( + utils.get_embedding_cfg(self.cfg)["all_embeddings"], weights_only=False + ) valid_genes_list = list(esm_data.keys()) for name in self.datasets: if not utils.is_valid_uuid(name): @@ -395,7 +494,9 @@ def __init__(self, cfg, test=False, datasets=None, shape_dict=None, adata=None, except: gene_categories = a["/var/gene_name/categories"][:] gene_codes = np.array(a["/var/gene_name/codes"][:]) - gene_names = np.array([g.decode("utf-8") for g in gene_categories[gene_codes]]) + gene_names = np.array( + [g.decode("utf-8") for g in gene_categories[gene_codes]] + ) valid_mask = np.isin(gene_names, valid_genes_list) self.valid_gene_index[name] = valid_mask else: diff --git a/tests/test_cell_sentence_dataset.py b/tests/test_cell_sentence_dataset.py index d80572b..6019fcb 100644 --- a/tests/test_cell_sentence_dataset.py +++ b/tests/test_cell_sentence_dataset.py @@ -3,30 +3,33 @@ import torch from cell_load.dataset.cell_sentence_dataset import CellSentenceDataset + class DummyAdata: def __init__(self, shape=(3, 5)): self.X = np.arange(np.prod(shape)).reshape(shape) self.var_names = [f"gene{i}" for i in range(shape[1])] self.shape = shape - self.var = {'gene_name': np.array(self.var_names)} + self.var = {"gene_name": np.array(self.var_names)} + @pytest.fixture def dummy_adata(): return DummyAdata() + def test_cell_sentence_dataset_basic(dummy_adata): - cfg = type('cfg', (), 
{})() - cfg.model = type('model', (), {})() + cfg = type("cfg", (), {})() + cfg.model = type("model", (), {})() cfg.model.batch_size = 2 - cfg.dataset = type('dataset', (), {})() + cfg.dataset = type("dataset", (), {})() cfg.dataset.pad_length = 5 cfg.dataset.P = 1 cfg.dataset.N = 1 cfg.dataset.S = 1 - dataset = CellSentenceDataset(cfg, adata=dummy_adata, adata_name='dummy') + dataset = CellSentenceDataset(cfg, adata=dummy_adata, adata_name="dummy") assert len(dataset) == dummy_adata.shape[0] counts, idx, dataset_name, dataset_num = dataset[0] assert isinstance(counts, torch.Tensor) assert counts.shape[1] == dummy_adata.shape[1] - assert dataset_name == 'dummy' + assert dataset_name == "dummy" assert dataset_num == 0 From 8e3079f655bc2abf752fd47a27177c0f322a4670 Mon Sep 17 00:00:00 2001 From: Sravya Tirukkovalur Date: Fri, 6 Jun 2025 11:12:02 -0700 Subject: [PATCH 4/7] Remove duplicate code --- .../data_modules/cell_sentence_dataloader.py | 9 +- .../dataset/cell_sentence_dataset.py | 761 ++++++++++++------ 2 files changed, 520 insertions(+), 250 deletions(-) diff --git a/src/cell_load/data_modules/cell_sentence_dataloader.py b/src/cell_load/data_modules/cell_sentence_dataloader.py index edc2a0e..b675022 100644 --- a/src/cell_load/data_modules/cell_sentence_dataloader.py +++ b/src/cell_load/data_modules/cell_sentence_dataloader.py @@ -4,7 +4,7 @@ from cell_load.dataset.cell_sentence_dataset import CellSentenceCollator -def create_cell_sentence_dataloader( +def create_dataloader( cfg, workers=1, data_dir=None, @@ -23,10 +23,14 @@ def create_cell_sentence_dataloader( raise ValueError( "Either datasets and shape_dict or adata and adata_name should be provided" ) + if adata is not None: shuffle = False + if data_dir: cfg.model.data_dir = data_dir + # ? utils.get_dataset_cfg(cfg).data_dir = data_dir + dataset = FilteredGenesCounts( cfg, datasets=datasets, @@ -41,7 +45,10 @@ def create_cell_sentence_dataloader( ds_emb_mapping_inference=dataset.ds_emb_map, is_train=False, ) + + # validation should not use cell augmentations sentence_collator.training = False + dataloader = DataLoader( dataset, batch_size=cfg.model.batch_size, diff --git a/src/cell_load/dataset/cell_sentence_dataset.py b/src/cell_load/dataset/cell_sentence_dataset.py index 3cf2b1a..f6638d4 100644 --- a/src/cell_load/dataset/cell_sentence_dataset.py +++ b/src/cell_load/dataset/cell_sentence_dataset.py @@ -1,3 +1,11 @@ +""" +cell_sentence_dataset.py + +Dataset classes and utilities for handling single-cell gene expression data in sentence form. + +This module provides PyTorch Dataset classes and related utilities for loading, filtering, and collating single-cell gene expression data and integrating gene/protein embeddings. +""" + import h5py import logging import torch @@ -11,6 +19,18 @@ class CellSentenceDataset(data.Dataset): + """ + PyTorch Dataset for single-cell gene expression data, supporting multiple datasets and formats. + + Args: + cfg: Configuration object. + test (bool): If True, uses validation data. + datasets (list, optional): List of dataset names/IDs. + shape_dict (dict, optional): Mapping from dataset name to shape (num_cells, num_genes). + adata (AnnData, optional): AnnData object for in-memory data. + adata_name (str, optional): Name/ID for the AnnData object. 
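+
+    Example (illustrative sketch, mirroring the in-memory AnnData path exercised by the
+    tests; ``cfg`` and ``adata`` are assumed to be supplied by the caller):
+
+        >>> ds = CellSentenceDataset(cfg, adata=adata, adata_name="my_adata")
+        >>> counts, idx, name, num = ds[0]
+        >>> counts.shape[1] == adata.shape[1]
+        True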
+ """ + def __init__( self, cfg, @@ -47,25 +67,25 @@ def __init__( self.datasets = datasets self.shapes_dict = shape_dict self.dataset_path_map = {dataset: dataset for dataset in datasets} + self.datasets = sorted(self.datasets) self.cfg = cfg + self.num_cells = {} self.num_genes = {} + self.total_num_cells = 0 for name in self.datasets: num_cells, num_genes = self.shapes_dict[name] self.num_cells[name] = num_cells self.num_genes[name] = num_genes + self.total_num_cells += num_cells + self.datasets_to_num = { k: v for k, v in zip(self.datasets, range(len(self.datasets))) } - @functools.lru_cache - def dataset_file(self, dataset): - datafile = self.dataset_path_map[dataset] - return h5py.File(datafile, "r") - def _compute_index(self, idx): for dataset in self.datasets: if idx < self.num_cells[dataset]: @@ -74,15 +94,24 @@ def _compute_index(self, idx): idx -= self.num_cells[dataset] raise IndexError + @functools.lru_cache + def dataset_file(self, dataset): + datafile = self.dataset_path_map[dataset] + return h5py.File(datafile, "r") + def __getitem__(self, idx): if self.adata is not None: + # block is only used during validation + # if .X is a numpy.ndarray if isinstance(self.adata.X, np.ndarray): counts = torch.tensor(self.adata.X[idx]).reshape(1, -1) else: counts = torch.tensor(self.adata.X[idx].todense()) + dataset = self.adata_name dataset_num = 0 return counts, idx, dataset, dataset_num + dataset, ds_idx = self._compute_index(idx) h5f = self.dataset_file(dataset) attrs = dict(h5f["X"].attrs) @@ -97,8 +126,11 @@ def __getitem__(self, idx): sub_indices = torch.tensor( h5f["/X/indices"][start_ptr:end_ptr], dtype=torch.int32 ) + counts = torch.sparse_csr_tensor( - [0], + [ + 0, + ], sub_indices, sub_data, (1, self.num_genes[dataset]), @@ -107,9 +139,11 @@ def __getitem__(self, idx): else: log.info(ds_idx) counts = torch.tensor(h5f["X"][ds_idx]).unsqueeze(0) + except Exception as iex: log.exception(f"Error in dataset {dataset} at index {ds_idx}") raise iex + dataset_num = self.datasets_to_num[dataset] return counts, idx, dataset, dataset_num @@ -121,6 +155,12 @@ def get_dim(self) -> Dict[str, int]: class FilteredGenesCounts(CellSentenceDataset): + """ + Dataset class that filters genes based on available protein/gene embeddings. + + Extends CellSentenceDataset to provide valid gene indices and embedding mappings for each dataset. 
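+
+    Example (illustrative; assumes ``cfg`` points at valid embedding and mapping files):
+
+        >>> ds = FilteredGenesCounts(cfg, adata=adata, adata_name="my_adata")
+        >>> mask = ds.valid_gene_index["my_adata"]   # boolean mask over this dataset's genes
+        >>> idx_map = ds.ds_emb_map["my_adata"]      # global embedding index per gene, -1 if missing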
+ """ + def __init__( self, cfg, @@ -132,6 +172,8 @@ def __init__( ) -> None: super().__init__(cfg, test, datasets, shape_dict, adata, adata_name) self.valid_gene_index = {} + + # make sure we get training datasets ( _, self.datasets, @@ -139,36 +181,52 @@ def __init__( self.dataset_path_map, self.dataset_group_map, ) = utils.get_shapes_dict("/home/aadduri/state/h5ad_all.csv") + emb_cfg = utils.get_embedding_cfg(self.cfg) try: self.ds_emb_map = torch.load(emb_cfg.ds_emb_mapping, weights_only=False) except (FileNotFoundError, IOError): self.ds_emb_map = {} + + # for inference, let's make sure this dataset's valid mask is available if adata_name is not None: + # append it to self.datasets self.datasets.append(adata_name) self.shapes_dict[adata_name] = adata.shape + + # compute its embedding‐index vector esm_data = torch.load(emb_cfg.all_embeddings, weights_only=False) valid_genes_list = list(esm_data.keys()) + # make a gene→global‐index lookup global_pos = {g: i for i, g in enumerate(valid_genes_list)} + + # grab var_names from the AnnData gene_names = np.array(adata.var_names) + + # for each gene in this dataset, find its global idx or -1 if missing new_mapping = np.array([global_pos.get(g, -1) for g in gene_names]) if (new_mapping == -1).all(): + # probably it contains ensembl id's instead gene_names = adata.var["gene_name"].values new_mapping = np.array([global_pos.get(g, -1) for g in gene_names]) + self.ds_emb_map[adata_name] = new_mapping + if utils.get_embedding_cfg(self.cfg).ds_emb_mapping is not None: esm_data = torch.load( utils.get_embedding_cfg(self.cfg)["all_embeddings"], weights_only=False ) valid_genes_list = list(esm_data.keys()) for name in self.datasets: - if not utils.is_valid_uuid(name): + if not utils.is_valid_uuid( + name + ): # had to add this in for now as cellxgene h5ad fles don't have gene_name object but tahoe does if adata is None: a = self.dataset_file(name) try: gene_names = np.array( [g.decode("utf-8") for g in a["/var/gene_name"][:]] - ) + ) # Decode byte strings except: gene_categories = a["/var/gene_name/categories"][:] gene_codes = np.array(a["/var/gene_name/codes"][:]) @@ -180,329 +238,534 @@ def __init__( else: gene_names = np.array(adata.var_names) valid_mask = np.isin(gene_names, valid_genes_list) + if not valid_mask.any(): + # none of the genes were valid, probably ensembl id's gene_names = adata.var["gene_name"].values valid_mask = np.isin(gene_names, valid_genes_list) + self.valid_gene_index[name] = valid_mask + def __getitem__(self, idx): + counts, idx, dataset, dataset_num = super().__getitem__(idx) + return counts, idx, dataset, dataset_num + class CellSentenceCollator(object): + """ + Collate function for batching single-cell gene expression data with gene/protein embedding support. + """ + def __init__( self, cfg, valid_gene_mask=None, ds_emb_mapping_inference=None, is_train=True ): + """ + Initialize the CellSentenceCollator. + + Args: + cfg: Configuration object. + valid_gene_mask (dict, optional): Mapping from dataset name to valid gene mask arrays. + ds_emb_mapping_inference (dict, optional): Dataset-to-embedding index mapping. + is_train (bool): Whether collator is for training or inference. 
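+
+        Example (sketch of the inference wiring used by ``create_dataloader``;
+        ``ds`` is assumed to be a ``FilteredGenesCounts`` instance):
+
+            >>> collator = CellSentenceCollator(
+            ...     cfg,
+            ...     valid_gene_mask=ds.valid_gene_index,
+            ...     ds_emb_mapping_inference=ds.ds_emb_map,
+            ...     is_train=False,
+            ... )
+            >>> loader = DataLoader(ds, batch_size=cfg.model.batch_size, collate_fn=collator)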
+ """ self.pad_length = cfg.dataset.pad_length self.P = cfg.dataset.P self.N = cfg.dataset.N self.S = cfg.dataset.S self.cfg = cfg self.training = is_train + + # Load the dataset mappings self.use_dataset_info = getattr(cfg.model, "dataset_correction", False) self.batch_tabular_loss = getattr(cfg.model, "batch_tabular_loss", False) + if valid_gene_mask is not None: + # this branch is for inference self.valid_gene_mask = valid_gene_mask self.dataset_to_protein_embeddings = ds_emb_mapping_inference else: + # otherwise for training, load from config gene_mask_file = utils.get_embedding_cfg(self.cfg).valid_genes_masks if gene_mask_file is not None: + # we have a config for training self.valid_gene_mask = torch.load(gene_mask_file, weights_only=False) else: + # we don't have a config for training self.valid_gene_mask = None + self.dataset_to_protein_embeddings = torch.load( utils.get_embedding_cfg(self.cfg).ds_emb_mapping.format( utils.get_embedding_cfg(self.cfg).size ), weights_only=False, ) + self.global_size = utils.get_embedding_cfg(self.cfg).num self.global_to_local = {} for dataset_name, ds_emb_idxs in self.dataset_to_protein_embeddings.items(): + # make sure tensor with long data type ds_emb_idxs = torch.tensor(ds_emb_idxs, dtype=torch.long) + # assert ds_emb_idxs.unique().numel() == ds_emb_idxs.numel(), f"duplicate global IDs in dataset {dataset_name}!" + + # Create a tensor filled with -1 (indicating not present in this dataset) reverse_mapping = torch.full((self.global_size,), -1, dtype=torch.int64) + local_indices = torch.arange(ds_emb_idxs.size(0), dtype=torch.int64) mask = (ds_emb_idxs >= 0) & (ds_emb_idxs < self.global_size) reverse_mapping[ds_emb_idxs[mask]] = local_indices[mask] self.global_to_local[dataset_name] = reverse_mapping + print(len(self.global_to_local)) def __call__(self, batch): num_aug = getattr(self.cfg.model, "num_downsample", 1) if num_aug > 1 and self.training: + # for each original sample, duplicate it num_aug times batch = [item for item in batch for _ in range(num_aug)] + batch_size = len(batch) + batch_sentences = torch.zeros((batch_size, self.pad_length), dtype=torch.int32) batch_sentences_counts = torch.zeros((batch_size, self.pad_length)) masks = torch.zeros((batch_size, self.pad_length), dtype=torch.bool) + idxs = torch.zeros(batch_size, dtype=torch.int32) if self.cfg.loss.name == "tabular": - Xs = torch.zeros((batch_size, self.pad_length, self.P)) - Ys = torch.zeros((batch_size, self.pad_length, self.N)) - batch_weights = torch.ones((batch_size, self.pad_length)) + task_num = self.P + self.N + self.S + else: + task_num = self.P + self.N + Xs = torch.zeros((batch_size, (task_num)), dtype=torch.int32) + Ys = torch.zeros((batch_size, (task_num))) + + largest_cnt = max([x[0].shape[1] for x in batch]) + batch_weights = torch.zeros((batch_size, largest_cnt)) + + total_counts_all = None + if self.cfg.model.rda: + total_counts_all = torch.zeros(batch_size) + + datasets = [] + for ( + _, + _, + ds_name, + _, + ) in batch: + datasets.append(ds_name) + + if self.cfg.loss.name == "tabular": + if "batch_tabular_loss" in self.__dict__ and self.batch_tabular_loss: + # Find genes shared across all datasets + shared_mask = None + for dataset in datasets: + dataset_mask = self.global_to_local[dataset] >= 0 + if shared_mask is None: + shared_mask = dataset_mask + else: + shared_mask &= dataset_mask + + # Get indices of shared genes + shared_indices = torch.where(shared_mask)[0] + + # Repeat shared genes to reach size S + n_shared = shared_indices.size(0) + if n_shared > 
0: + # Calculate how many times to repeat and remainder + repeats = self.S // n_shared + remainder = self.S % n_shared + + # Repeat the full sequence + shared_genes = shared_indices.repeat(repeats) + + # Add remaining genes needed + if remainder > 0: + shared_genes = torch.cat( + [shared_genes, shared_indices[:remainder]] + ) + else: + # If no shared genes, sample randomly from global gene space + shared_genes = torch.randint( + low=0, + high=self.global_size, + size=(self.S,), + device=masks.device, + dtype=torch.long, + ) + else: + if "global_size" not in self.__dict__: + self.global_size = utils.get_embedding_cfg(self.cfg).num + shared_genes = torch.randint( + low=0, + high=self.global_size, + size=(self.S,), + device=masks.device, + dtype=torch.long, + ) else: - Xs = Ys = batch_weights = None + shared_genes = None + dataset_nums = torch.zeros(batch_size, dtype=torch.int32) - total_counts_all = torch.zeros(batch_size) - for i, (counts, idx, dataset, dataset_num) in enumerate(batch): - batch_sentences[i, : counts.shape[1]] = counts.squeeze() + + i = 0 + max_len = 0 + for counts, idx, dataset, dataset_num in batch: + if self.valid_gene_mask is not None: + if dataset in self.valid_gene_mask: + valid_mask = self.valid_gene_mask[dataset] + else: + valid_mask = None + else: + valid_mask = None + + # compute downsample fraction. this is the first sample of the augmentation then + # use no downsampling + downsample_fraction = ( + 1.0 if (num_aug > 1 and i % num_aug == 0 and self.training) else None + ) + ( + bs, + xx, + yy, + batch_weight, + mask, + cell_total_counts, + cell_sentence_counts, + ) = self.sample_cell_sentences( + counts, dataset, shared_genes, valid_mask, downsample_fraction + ) + + batch_sentences[i, :] = bs + masks[i, :] = mask + batch_weight = batch_weight.squeeze() + batch_weights[i, : len(batch_weight)] = batch_weight + + max_len = max(max_len, self.cfg.dataset.pad_length) idxs[i] = idx + + Xs[i] = xx # [pn_idx] + Ys[i] = yy.squeeze() # [pn_idx] dataset_nums[i] = dataset_num + + if self.cfg.model.rda and cell_total_counts is not None: + total_counts_all[i] = cell_total_counts[0] + if self.cfg.model.counts and cell_sentence_counts is not None: + batch_sentences_counts[i, :] = cell_sentence_counts + i += 1 + return ( - batch_sentences, + batch_sentences[:, :max_len], Xs, Ys, idxs, batch_weights, masks, - total_counts_all if getattr(self.cfg.model, "rda", False) else None, - batch_sentences_counts - if getattr(self.cfg.model, "counts", False) - else None, + total_counts_all if self.cfg.model.rda else None, + batch_sentences_counts if self.cfg.model.counts else None, dataset_nums if self.use_dataset_info else None, ) - def __init__( + def softmax(self, x): + e_x = np.exp(x - np.max(x)) + return e_x / e_x.sum() + + def is_raw_integer_counts(self, counts: torch.Tensor) -> bool: + """ + Heuristic check to decide whether `counts` are raw integer UMI counts + versus log1p-transformed counts. + + 1. If any entry > RAW_COUNT_HEURISTIC_THRESHOLD, assume raw ints. + 2. Otherwise, invert log1p (via expm1) and sum: + - If the total UMIs exceeds EXPONENTIATED_UMIS_LIMIT, it means + the data were actually raw ints that we mistakenly log-transformed. + - Otherwise, assume the data were correctly log1p counts. 
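+
+        Illustrative example (editor's note; the numbers are invented, not from a
+        real dataset): a cell whose largest entry is 812 is treated as raw counts by
+        rule 1 alone. A cell whose largest entry is 6.9 is ambiguous, so rule 2
+        applies: if expm1 of the whole vector sums to ~12,000,000 "UMIs" the limit is
+        exceeded and the values were raw integers after all, while a sum of ~40,000
+        keeps the log1p interpretation.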
+ """ + max_val = torch.max(counts).item() + + # Primary heuristic: very large individual counts => raw counts + if max_val > RAW_COUNT_HEURISTIC_THRESHOLD: + return True + + # Ambiguous case: try undoing log1p + total_umis = int(torch.expm1(counts).sum().item()) + if total_umis > EXPONENTIATED_UMIS_LIMIT: + return True + + return False + + # sampling a single cell sentence + # counts_raw is a view of a cell + def sample_cell_sentences( self, - cfg, - test=False, - datasets=None, - shape_dict=None, - adata=None, - adata_name=None, - ) -> None: - super().__init__() - self.adata = None - self.adata_name = adata_name - self.test = test - if adata is not None: - self.adata = adata - self.datasets = [adata_name] - self.shapes_dict = {self.datasets[0]: adata.shape} - elif datasets is None: - ds_path = utils.get_dataset_cfg(cfg).train - if test: - ds_path = utils.get_dataset_cfg(cfg).val - ( - _, - self.datasets, - self.shapes_dict, - self.dataset_path_map, - self.dataset_group_map, - ) = utils.get_shapes_dict( - ds_path, utils.get_dataset_cfg(cfg).get("filter_by_species") - ) - else: - assert shape_dict is not None - assert len(datasets) == len(shape_dict) - self.datasets = datasets - self.shapes_dict = shape_dict - self.dataset_path_map = {dataset: dataset for dataset in datasets} - self.datasets = sorted(self.datasets) - self.cfg = cfg - self.num_cells = {} - self.num_genes = {} - self.total_num_cells = 0 - for name in self.datasets: - num_cells, num_genes = self.shapes_dict[name] - self.num_cells[name] = num_cells - self.num_genes[name] = num_genes - self.total_num_cells += num_cells - self.datasets_to_num = { - k: v for k, v in zip(self.datasets, range(len(self.datasets))) - } + counts_raw, + dataset, + shared_genes=None, + valid_gene_mask=None, + downsample_frac=None, + ): + if torch.isnan(counts_raw).any(): + log.error(f"NaN values in counts for dataset {dataset}") - @functools.lru_cache - def dataset_file(self, dataset): - datafile = self.dataset_path_map[dataset] - return h5py.File(datafile, "r") + if torch.any(counts_raw < 0): + counts_raw = F.relu(counts_raw) - def _compute_index(self, idx): - for dataset in self.datasets: - if idx < self.num_cells[dataset]: - return dataset, idx - else: - idx -= self.num_cells[dataset] - raise IndexError + if self.is_raw_integer_counts(counts_raw): # CAN WE CHANGE THIS TO INT VS REAL + total_umis = int(counts_raw.sum(axis=1).item()) + count_expr_dist = counts_raw / counts_raw.sum(axis=1, keepdim=True) + counts_raw = torch.log1p(counts_raw) + else: # counts are already log1p + exp_log_counts = torch.expm1(counts_raw) + total_umis = int(exp_log_counts.sum(axis=1).item()) + count_expr_dist = exp_log_counts / exp_log_counts.sum(axis=1, keepdim=True) - def __getitem__(self, idx): - if self.adata is not None: - if isinstance(self.adata.X, np.ndarray): - counts = torch.tensor(self.adata.X[idx]).reshape(1, -1) - else: - counts = torch.tensor(self.adata.X[idx].todense()) - dataset = self.adata_name - dataset_num = 0 - return counts, idx, dataset, dataset_num - dataset, ds_idx = self._compute_index(idx) - h5f = self.dataset_file(dataset) - attrs = dict(h5f["X"].attrs) - try: - if attrs["encoding-type"] == "csr_matrix": - indptrs = h5f["/X/indptr"] - start_ptr = indptrs[ds_idx] - end_ptr = indptrs[ds_idx + 1] - sub_data = torch.tensor( - h5f["/X/data"][start_ptr:end_ptr], dtype=torch.float - ) - sub_indices = torch.tensor( - h5f["/X/indices"][start_ptr:end_ptr], dtype=torch.int32 + ### At this point, counts_raw is assumed to be log counts ### + + # store the raw 
counts here, we need them as targets + original_counts_raw = counts_raw.clone() + + # if we are using downsample augmentation, decide if we need to update counts_raw + num_aug = getattr(self.cfg.model, "num_downsample", 1) + if num_aug > 1: + if downsample_frac is None: + downsample_frac = torch.empty(1).uniform_(0.3, 1.0).item() + + down_umis = int(total_umis * downsample_frac) + if down_umis > 0 and downsample_frac < 1.0: + # build a distribution over raw counts + genes_sampled = torch.multinomial( + count_expr_dist.squeeze(), down_umis, replacement=True ) - counts = torch.sparse_csr_tensor( - [0], - sub_indices, - sub_data, - (1, self.num_genes[dataset]), + # flatten to a 1D gene vector, and get the counts for the newly sampled genes + flat = counts_raw.view(-1) + counts_aug_flat = torch.zeros_like(flat) + counts_aug_flat.scatter_add_( + 0, + genes_sampled, + torch.ones( + down_umis, + dtype=counts_aug_flat.dtype, + ), ) - counts = counts.to_dense() + # restore original shape (1, D) + counts_aug = counts_aug_flat.view_as(counts_raw) + counts_aug = torch.log1p(counts_aug) else: - log.info(ds_idx) - counts = torch.tensor(h5f["X"][ds_idx]).unsqueeze(0) - except Exception as iex: - log.exception(f"Error in dataset {dataset} at index {ds_idx}") - raise iex - dataset_num = self.datasets_to_num[dataset] - return counts, idx, dataset, dataset_num - - def __len__(self) -> int: - return self.total_num_cells + counts_aug = counts_raw - def get_dim(self) -> Dict[str, int]: - return self.num_genes + # if we are using an augmentation, update the raw counts here + counts_raw = counts_aug + # logic to sample a single cell sentence and task sentence here + ds_emb_idxs = torch.tensor( + self.dataset_to_protein_embeddings[dataset], dtype=torch.long + ) -class CellSentenceCollator(object): - def __init__( - self, cfg, valid_gene_mask=None, ds_emb_mapping_inference=None, is_train=True - ): - self.pad_length = cfg.dataset.pad_length - self.P = cfg.dataset.P - self.N = cfg.dataset.N - self.S = cfg.dataset.S - self.cfg = cfg - self.training = is_train - self.use_dataset_info = getattr(cfg.model, "dataset_correction", False) - self.batch_tabular_loss = getattr(cfg.model, "batch_tabular_loss", False) + original_counts = original_counts_raw + counts = counts_raw if valid_gene_mask is not None: - self.valid_gene_mask = valid_gene_mask - self.dataset_to_protein_embeddings = ds_emb_mapping_inference - else: - gene_mask_file = utils.get_embedding_cfg(self.cfg).valid_genes_masks - if gene_mask_file is not None: - self.valid_gene_mask = torch.load(gene_mask_file, weights_only=False) + if ds_emb_idxs.shape[0] == valid_gene_mask.shape[0]: + # Filter the dataset embedding indices based on the valid gene mask + ds_emb_idxs = ds_emb_idxs[valid_gene_mask] else: - self.valid_gene_mask = None - self.dataset_to_protein_embeddings = torch.load( - utils.get_embedding_cfg(self.cfg).ds_emb_mapping.format( - utils.get_embedding_cfg(self.cfg).size - ), - weights_only=False, - ) - self.global_size = utils.get_embedding_cfg(self.cfg).num - self.global_to_local = {} - for dataset_name, ds_emb_idxs in self.dataset_to_protein_embeddings.items(): - ds_emb_idxs = torch.tensor(ds_emb_idxs, dtype=torch.long) - reverse_mapping = torch.full((self.global_size,), -1, dtype=torch.int64) - local_indices = torch.arange(ds_emb_idxs.size(0), dtype=torch.int64) - mask = (ds_emb_idxs >= 0) & (ds_emb_idxs < self.global_size) - reverse_mapping[ds_emb_idxs[mask]] = local_indices[mask] - self.global_to_local[dataset_name] = reverse_mapping - 
print(len(self.global_to_local)) + # Our preprocessing is such that sometimes the ds emb idxs are already filtered + # in this case we do nothing to (no subsetting) but assert that the mask matches + assert valid_gene_mask.sum() == ds_emb_idxs.shape[0], ( + f"Something wrong with filtering or mask for dataset {dataset}" + ) - def __call__(self, batch): - num_aug = getattr(self.cfg.model, "num_downsample", 1) - if num_aug > 1 and self.training: - batch = [item for item in batch for _ in range(num_aug)] - batch_size = len(batch) - batch_sentences = torch.zeros((batch_size, self.pad_length), dtype=torch.int32) - batch_sentences_counts = torch.zeros((batch_size, self.pad_length)) - masks = torch.zeros((batch_size, self.pad_length), dtype=torch.bool) - idxs = torch.zeros(batch_size, dtype=torch.int32) - if self.cfg.loss.name == "tabular": - Xs = torch.zeros((batch_size, self.pad_length, self.P)) - Ys = torch.zeros((batch_size, self.pad_length, self.N)) - batch_weights = torch.ones((batch_size, self.pad_length)) + # Counts are never filtered in our preprocessing step, so we always need to apply the valid genes mask + if counts_raw.shape[1] == valid_gene_mask.shape[0]: + counts = counts_raw[:, valid_gene_mask] + original_counts = original_counts_raw[:, valid_gene_mask] + + if counts.sum() == 0: + expression_weights = F.softmax(counts, dim=1) else: - Xs = Ys = batch_weights = None - dataset_nums = torch.zeros(batch_size, dtype=torch.int32) - total_counts_all = torch.zeros(batch_size) - for i, (counts, idx, dataset, dataset_num) in enumerate(batch): - batch_sentences[i, : counts.shape[1]] = counts.squeeze() - idxs[i] = idx - dataset_nums[i] = dataset_num - return ( - batch_sentences, - Xs, - Ys, - idxs, - batch_weights, - masks, - total_counts_all if getattr(self.cfg.model, "rda", False) else None, - batch_sentences_counts - if getattr(self.cfg.model, "counts", False) - else None, - dataset_nums if self.use_dataset_info else None, + expression_weights = counts / torch.sum(counts, dim=1, keepdim=True) + + cell_sentences = torch.zeros((counts.shape[0], self.cfg.dataset.pad_length)) + cell_sentence_counts = torch.zeros( + (counts.shape[0], self.cfg.dataset.pad_length) + ) + mask = torch.zeros( + (counts.shape[0], self.cfg.dataset.pad_length), dtype=torch.bool ) + if self.cfg.loss.name == "tabular": + # include capacity for shared genes + task_num = self.cfg.dataset.P + self.cfg.dataset.N + self.cfg.dataset.S + else: + task_num = self.cfg.dataset.P + self.cfg.dataset.N -class FilteredGenesCounts(CellSentenceDataset): - def __init__( - self, - cfg, - test=False, - datasets=None, - shape_dict=None, - adata=None, - adata_name=None, - ) -> None: - super().__init__(cfg, test, datasets, shape_dict, adata, adata_name) - self.valid_gene_index = {} - ( - _, - self.datasets, - self.shapes_dict, - self.dataset_path_map, - self.dataset_group_map, - ) = utils.get_shapes_dict("/home/aadduri/state/h5ad_all.csv") - emb_cfg = utils.get_embedding_cfg(self.cfg) - try: - self.ds_emb_map = torch.load(emb_cfg.ds_emb_mapping, weights_only=False) - except (FileNotFoundError, IOError): - self.ds_emb_map = {} - if adata_name is not None: - self.datasets.append(adata_name) - self.shapes_dict[adata_name] = adata.shape - esm_data = torch.load(emb_cfg.all_embeddings, weights_only=False) - valid_genes_list = list(esm_data.keys()) - global_pos = {g: i for i, g in enumerate(valid_genes_list)} - gene_names = np.array(adata.var_names) - new_mapping = np.array([global_pos.get(g, -1) for g in gene_names]) - if (new_mapping == -1).all(): - 
gene_names = adata.var["gene_name"].values - new_mapping = np.array([global_pos.get(g, -1) for g in gene_names]) - self.ds_emb_map[adata_name] = new_mapping - if utils.get_embedding_cfg(self.cfg).ds_emb_mapping is not None: - esm_data = torch.load( - utils.get_embedding_cfg(self.cfg)["all_embeddings"], weights_only=False + task_counts = torch.zeros((counts.shape[0], task_num)) + task_sentence = torch.zeros((counts.shape[0], task_num)) + + if self.cfg.model.rda: + cell_total_counts = torch.zeros((counts.shape[0],)) + else: + cell_total_counts = None + + # len(counts) = 1, e.g., we are looping over [cell] + for c, cell in enumerate(counts): + num_pos_genes = torch.sum(cell > 0) + # this is either the number of positive genes, or the first pad_length / 2 most expressed genes + # the first is only used if you have more expressed genes than pad_length / 2 + assert self.cfg.model.counts + # shuffle before argsort - randomly break ties so we select random unexpressed genes each time, if pad_length > num_non_zero genes + indices = torch.randperm(cell.shape[-1]) + shuffled_cell = cell[indices] + shuffled_genes_ranked_exp = torch.argsort(shuffled_cell, descending=True) + genes_ranked_exp = indices[shuffled_genes_ranked_exp] + cell_sentences[c, 0] = self.cfg.dataset.cls_token_idx + if len(genes_ranked_exp) >= self.cfg.dataset.pad_length - 1: + cell_sentences[c, 1:] = genes_ranked_exp[ + : self.cfg.dataset.pad_length - 1 + ] + else: + # take the nonzero genes first + num_nonzero = min(num_pos_genes, self.cfg.dataset.pad_length - 1) + cell_sentences[c, 1 : num_nonzero + 1] = genes_ranked_exp[:num_nonzero] + + # sample the unexpressed genes with replacement + remaining_slots = self.cfg.dataset.pad_length - 1 - num_nonzero + unexpressed_genes = genes_ranked_exp[num_nonzero:] + cell_sentences[c, num_nonzero + 1 :] = unexpressed_genes[ + torch.randint(len(unexpressed_genes), (remaining_slots,)) + ] + + cell_sentence_counts[c, :] = ( + 100 * expression_weights[c, cell_sentences[c, :].to(torch.int32)] ) - valid_genes_list = list(esm_data.keys()) - for name in self.datasets: - if not utils.is_valid_uuid(name): - if adata is None: - a = self.dataset_file(name) - try: - gene_names = np.array( - [g.decode("utf-8") for g in a["/var/gene_name"][:]] - ) - except: - gene_categories = a["/var/gene_name/categories"][:] - gene_codes = np.array(a["/var/gene_name/codes"][:]) - gene_names = np.array( - [g.decode("utf-8") for g in gene_categories[gene_codes]] - ) - valid_mask = np.isin(gene_names, valid_genes_list) - self.valid_gene_index[name] = valid_mask - else: - gene_names = np.array(adata.var_names) - valid_mask = np.isin(gene_names, valid_genes_list) - if not valid_mask.any(): - gene_names = adata.var["gene_name"].values - valid_mask = np.isin(gene_names, valid_genes_list) - self.valid_gene_index[name] = valid_mask + + # Convert tokens to Embeddings - local to global + # this also includes the cls token, but we will override it later with a learnable torch vector + cell_sentences[c, :] = ds_emb_idxs[cell_sentences[c, :].to(torch.int32)] + + # pick P expressed genes to mask for MLM + exp_genes = torch.where(cell > 0)[0] + if len(exp_genes) > self.cfg.dataset.P: + task_sentence[c, : self.cfg.dataset.P] = exp_genes[ + torch.randperm(len(exp_genes))[0 : self.cfg.dataset.P] + ] + elif len(exp_genes) > 0: + task_sentence[c, : self.cfg.dataset.P] = exp_genes[ + torch.randint(len(exp_genes), (self.cfg.dataset.P,)) + ] + + # get the total number of genes unique to this cell; everything + # past this are shared genes across 
all cells in a batch, used for tabular loss + unshared_num = self.cfg.dataset.P + self.cfg.dataset.N + + unexp_genes = torch.where(cell < 1)[0] + if len(unexp_genes) > self.cfg.dataset.N: + task_sentence[c, self.cfg.dataset.P : unshared_num] = unexp_genes[ + torch.randperm(len(unexp_genes))[0 : self.cfg.dataset.N] + ] + else: + task_sentence[c, self.cfg.dataset.P : unshared_num] = unexp_genes[ + torch.randint(len(unexp_genes), (self.cfg.dataset.N,)) + ] + + # set counts for unshared genes + task_idxs = task_sentence[c, :unshared_num].to(torch.int32) + task_counts[c, :unshared_num] = original_counts[c, task_idxs] + + # convert from dataset specific gene indices to global gene indices + # only do this for everything up to shared genes, which are already global indices + task_sentence[c, :unshared_num] = ds_emb_idxs[ + task_sentence[c, :unshared_num].to(torch.int32) + ] + + # now take care of shared genes across all cells in the batch + if shared_genes is not None: + # Overwrite the final positions of task_sentence + + task_sentence[c, unshared_num:] = ( + shared_genes # in the old impl these are global gene indices + ) + # task_sentence[c, unshared_num:] = ds_emb_idxs[shared_genes.to(torch.int32)] # in the new impl these are local gene indices + + # convert the shared_genes, which are global indices, to the dataset specific indices + local_indices = self.global_to_local[dataset][shared_genes].to( + cell.device + ) # in the old impl these are global gene indices + # local_indices = shared_genes # in the new impl these are local gene indices + + shared_counts = torch.zeros( + local_indices.shape, dtype=cell.dtype, device=cell.device + ) + valid_mask = local_indices != -1 + if valid_mask.any(): + shared_counts[valid_mask] = original_counts_raw[ + c, local_indices[valid_mask] + ] + + # for indices which are -1, count is 0, else index into cell + task_counts[c, unshared_num:] = shared_counts + + assert self.cfg.model.rda + # sum the counts of the task sentence + cell_total_counts[c] = torch.sum(task_counts[c]) + + if self.cfg.loss.name == "cross_entropy": + # binarize the counts to 0/1 + task_counts[c] = (task_counts[c] > 0).float() + + # mask out the task genes from the cell sentence + task_gene_set = torch.tensor( + task_sentence[c].tolist(), dtype=cell_sentences.dtype + ) + potential_mask = torch.isin(cell_sentences[c], task_gene_set) + + # Calculate target number of masked tokens + target_mask_count = int(self.cfg.task.mask * self.cfg.dataset.pad_length) + current_mask_count = potential_mask.sum().item() + + if current_mask_count > target_mask_count: + # Too many tokens are being masked - randomly select subset + # Only consider indices after the CLS token (index 0) + mask_indices = ( + torch.where(potential_mask[1:])[0] + 1 + ) # +1 to adjust for offset + keep_indices = torch.randperm(len(mask_indices))[:target_mask_count] + selected_indices = mask_indices[keep_indices] + + # Create new mask with only the selected indices, ensuring CLS is not masked + final_mask = torch.zeros_like(potential_mask) + final_mask[selected_indices] = True + mask[c] = final_mask + elif current_mask_count < target_mask_count: + # Not enough tokens masked - we need to mask additional tokens + non_masked = ~potential_mask + + # Exclude the CLS token (index 0) by only considering indices 1 and up + non_masked_indices = ( + torch.where(non_masked[1:])[0] + 1 + ) # +1 to adjust for offset + + # Calculate how many more tokens to mask + additional_needed = target_mask_count - current_mask_count + additional_needed = 
min(additional_needed, len(non_masked_indices)) + + if len(non_masked_indices) > 0 and additional_needed > 0: + additional_indices = non_masked_indices[ + torch.randperm(len(non_masked_indices))[:additional_needed] + ] + potential_mask[additional_indices] = True + + mask[c] = potential_mask + else: + # Exactly self.cfg.task.mask percent are masked, use the potential mask as is + mask[c] = potential_mask + + # make sure that the CLS token is never masked out. + mask[c, 0] = False + + return ( + cell_sentences, + task_sentence, + task_counts, + counts, + mask, + cell_total_counts if self.cfg.model.rda else None, + cell_sentence_counts if self.cfg.model.counts else None, + ) From 27c95bc2ff242e4c88ff3cf203ee5e13e4caa631 Mon Sep 17 00:00:00 2001 From: Sravya Tirukkovalur Date: Fri, 6 Jun 2025 13:05:52 -0700 Subject: [PATCH 5/7] linter fixes --- src/cell_load/config.py | 2 +- src/cell_load/data_modules/cell_sentence_dataloader.py | 1 - src/cell_load/data_modules/perturbation_dataloader.py | 3 +-- src/cell_load/utils/data_utils.py | 2 +- 4 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/cell_load/config.py b/src/cell_load/config.py index 0dacda2..7610c00 100644 --- a/src/cell_load/config.py +++ b/src/cell_load/config.py @@ -2,7 +2,7 @@ import logging from dataclasses import dataclass from pathlib import Path -from typing import Dict, List, Optional, Set +from typing import Dict, List, Set import toml diff --git a/src/cell_load/data_modules/cell_sentence_dataloader.py b/src/cell_load/data_modules/cell_sentence_dataloader.py index b675022..4eab5b6 100644 --- a/src/cell_load/data_modules/cell_sentence_dataloader.py +++ b/src/cell_load/data_modules/cell_sentence_dataloader.py @@ -1,4 +1,3 @@ -import torch from torch.utils.data import DataLoader from cell_load.dataset.cell_sentence_dataset import FilteredGenesCounts from cell_load.dataset.cell_sentence_dataset import CellSentenceCollator diff --git a/src/cell_load/data_modules/perturbation_dataloader.py b/src/cell_load/data_modules/perturbation_dataloader.py index 0c05c1c..db49eb3 100644 --- a/src/cell_load/data_modules/perturbation_dataloader.py +++ b/src/cell_load/data_modules/perturbation_dataloader.py @@ -1,7 +1,6 @@ import logging -from collections import defaultdict from pathlib import Path -from typing import Dict, List, Literal, Optional, Set, Tuple +from typing import Dict, List, Literal, Optional, Set import h5py import numpy as np diff --git a/src/cell_load/utils/data_utils.py b/src/cell_load/utils/data_utils.py index 2ca3678..596b7f1 100644 --- a/src/cell_load/utils/data_utils.py +++ b/src/cell_load/utils/data_utils.py @@ -7,7 +7,7 @@ import torch import scipy.sparse as sp -from typing import List, Optional +from typing import Optional from .singleton import Singleton log = logging.getLogger(__name__) From dd0a612a2b522d0e50f7bb7659df357e8748c11a Mon Sep 17 00:00:00 2001 From: Sravya Tirukkovalur Date: Fri, 6 Jun 2025 16:40:31 -0700 Subject: [PATCH 6/7] Making paths and thresholds configurable --- src/cell_load/config.py | 10 +++++++ .../dataset/cell_sentence_dataset.py | 27 +++++++++++-------- 2 files changed, 26 insertions(+), 11 deletions(-) diff --git a/src/cell_load/config.py b/src/cell_load/config.py index 7610c00..ec263f7 100644 --- a/src/cell_load/config.py +++ b/src/cell_load/config.py @@ -25,6 +25,13 @@ class ExperimentConfig: # Fewshot perturbation assignments (dataset.celltype -> {split: [perts]}) fewshot: Dict[str, Dict[str, List[str]]] + # Path to h5ad CSV summary file for gene/protein embedding mapping + 
h5ad_csv_path: str = "" + + # Thresholds for raw count heuristics + RAW_COUNT_HEURISTIC_THRESHOLD: int = 1000 + EXPONENTIATED_UMIS_LIMIT: int = 1000000 + @classmethod def from_toml(cls, toml_path: str) -> "ExperimentConfig": """Load configuration from TOML file.""" @@ -36,6 +43,9 @@ def from_toml(cls, toml_path: str) -> "ExperimentConfig": training=config.get("training", {}), zeroshot=config.get("zeroshot", {}), fewshot=config.get("fewshot", {}), + h5ad_csv_path=config.get("h5ad_csv_path", ""), + RAW_COUNT_HEURISTIC_THRESHOLD=config.get("RAW_COUNT_HEURISTIC_THRESHOLD", 1000), + EXPONENTIATED_UMIS_LIMIT=config.get("EXPONENTIATED_UMIS_LIMIT", 1000000), ) def get_all_datasets(self) -> Set[str]: diff --git a/src/cell_load/dataset/cell_sentence_dataset.py b/src/cell_load/dataset/cell_sentence_dataset.py index f6638d4..c2a8fec 100644 --- a/src/cell_load/dataset/cell_sentence_dataset.py +++ b/src/cell_load/dataset/cell_sentence_dataset.py @@ -3,13 +3,15 @@ Dataset classes and utilities for handling single-cell gene expression data in sentence form. -This module provides PyTorch Dataset classes and related utilities for loading, filtering, and collating single-cell gene expression data and integrating gene/protein embeddings. +This module provides PyTorch Dataset classes and related utilities for loading, filtering, and collating +single-cell gene expression data and integrating gene/protein embeddings. """ import h5py import logging import torch import torch.utils.data as data +import torch.nn.functional as F import functools import numpy as np from typing import Dict @@ -180,7 +182,7 @@ def __init__( self.shapes_dict, self.dataset_path_map, self.dataset_group_map, - ) = utils.get_shapes_dict("/home/aadduri/state/h5ad_all.csv") + ) = utils.get_shapes_dict(getattr(cfg, 'h5ad_csv_path', '')) emb_cfg = utils.get_embedding_cfg(self.cfg) try: @@ -227,7 +229,7 @@ def __init__( gene_names = np.array( [g.decode("utf-8") for g in a["/var/gene_name"][:]] ) # Decode byte strings - except: + except KeyError: gene_categories = a["/var/gene_name/categories"][:] gene_codes = np.array(a["/var/gene_name/codes"][:]) gene_names = np.array( @@ -305,7 +307,7 @@ def __init__( for dataset_name, ds_emb_idxs in self.dataset_to_protein_embeddings.items(): # make sure tensor with long data type ds_emb_idxs = torch.tensor(ds_emb_idxs, dtype=torch.long) - # assert ds_emb_idxs.unique().numel() == ds_emb_idxs.numel(), f"duplicate global IDs in dataset {dataset_name}!" 
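# Editor's note: a minimal, self-contained sketch of the global -> local reverse
# mapping that the surrounding hunk touches. The names (global_size, ds_emb_idxs,
# reverse_mapping) mirror the patch, but the sizes and values below are invented
# for illustration; this is not the project's data.
import torch

global_size = 10                                          # pretend global protein-embedding vocabulary
ds_emb_idxs = torch.tensor([7, 2, 9], dtype=torch.long)   # local gene 0 -> global 7, 1 -> 2, 2 -> 9

reverse_mapping = torch.full((global_size,), -1, dtype=torch.int64)  # -1 = not present in this dataset
local_indices = torch.arange(ds_emb_idxs.size(0), dtype=torch.int64)
valid = (ds_emb_idxs >= 0) & (ds_emb_idxs < global_size)
reverse_mapping[ds_emb_idxs[valid]] = local_indices[valid]

assert reverse_mapping[2].item() == 1    # global gene 2 maps back to local column 1
assert reverse_mapping[0].item() == -1   # global gene 0 is not measured in this dataset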
+ # Create a tensor filled with -1 (indicating not present in this dataset) reverse_mapping = torch.full((self.global_size,), -1, dtype=torch.int64) @@ -482,12 +484,14 @@ def is_raw_integer_counts(self, counts: torch.Tensor) -> bool: max_val = torch.max(counts).item() # Primary heuristic: very large individual counts => raw counts - if max_val > RAW_COUNT_HEURISTIC_THRESHOLD: + threshold = getattr(self.cfg, 'RAW_COUNT_HEURISTIC_THRESHOLD', 35) + if max_val > threshold: return True # Ambiguous case: try undoing log1p total_umis = int(torch.expm1(counts).sum().item()) - if total_umis > EXPONENTIATED_UMIS_LIMIT: + umi_limit = getattr(self.cfg, 'EXPONENTIATED_UMIS_LIMIT', 5_000_000) + if total_umis > umi_limit: return True return False @@ -508,7 +512,7 @@ def sample_cell_sentences( if torch.any(counts_raw < 0): counts_raw = F.relu(counts_raw) - if self.is_raw_integer_counts(counts_raw): # CAN WE CHANGE THIS TO INT VS REAL + if self.is_raw_integer_counts(counts_raw): total_umis = int(counts_raw.sum(axis=1).item()) count_expr_dist = counts_raw / counts_raw.sum(axis=1, keepdim=True) counts_raw = torch.log1p(counts_raw) @@ -566,7 +570,7 @@ def sample_cell_sentences( # Filter the dataset embedding indices based on the valid gene mask ds_emb_idxs = ds_emb_idxs[valid_gene_mask] else: - # Our preprocessing is such that sometimes the ds emb idxs are already filtered + # our preprocessing is such that sometimes the ds emb idxs are already filtered # in this case we do nothing to (no subsetting) but assert that the mask matches assert valid_gene_mask.sum() == ds_emb_idxs.shape[0], ( f"Something wrong with filtering or mask for dataset {dataset}" @@ -610,7 +614,8 @@ def sample_cell_sentences( # this is either the number of positive genes, or the first pad_length / 2 most expressed genes # the first is only used if you have more expressed genes than pad_length / 2 assert self.cfg.model.counts - # shuffle before argsort - randomly break ties so we select random unexpressed genes each time, if pad_length > num_non_zero genes + # shuffle before argsort - randomly break ties so we select random unexpressed genes each time. 
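# Editor's note: an illustrative sketch (invented toy vector, not project code) of the
# shuffle-before-argsort trick described in the comment above. torch.argsort does not
# guarantee any particular order among tied values, so permuting the gene axis first
# makes the ordering of equally-expressed (mostly zero-count) genes random on each call.
import torch

cell = torch.tensor([0.0, 3.2, 0.0, 1.1, 0.0])            # toy log1p expression vector
indices = torch.randperm(cell.shape[-1])                  # random permutation of gene positions
shuffled_cell = cell[indices]
shuffled_order = torch.argsort(shuffled_cell, descending=True)
genes_ranked_exp = indices[shuffled_order]                # original gene ids, highest expression first

# The expressed genes (ids 1 and 3) always lead; the zero-count genes that follow
# come out in a different order from run to run.
print(genes_ranked_exp)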
+ indices = torch.randperm(cell.shape[-1]) shuffled_cell = cell[indices] shuffled_genes_ranked_exp = torch.argsort(shuffled_cell, descending=True) @@ -682,13 +687,13 @@ def sample_cell_sentences( task_sentence[c, unshared_num:] = ( shared_genes # in the old impl these are global gene indices ) - # task_sentence[c, unshared_num:] = ds_emb_idxs[shared_genes.to(torch.int32)] # in the new impl these are local gene indices + # convert the shared_genes, which are global indices, to the dataset specific indices local_indices = self.global_to_local[dataset][shared_genes].to( cell.device ) # in the old impl these are global gene indices - # local_indices = shared_genes # in the new impl these are local gene indices + shared_counts = torch.zeros( local_indices.shape, dtype=cell.dtype, device=cell.device From b1abc52763af59168c30438c0b465b6800ab01d8 Mon Sep 17 00:00:00 2001 From: Sravya Tirukkovalur Date: Fri, 6 Jun 2025 16:50:13 -0700 Subject: [PATCH 7/7] minor fixed --- src/cell_load/config.py | 8 +++++--- src/cell_load/dataset/cell_sentence_dataset.py | 11 ++++------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/src/cell_load/config.py b/src/cell_load/config.py index 1832092..2c361fd 100644 --- a/src/cell_load/config.py +++ b/src/cell_load/config.py @@ -29,8 +29,8 @@ class ExperimentConfig: h5ad_csv_path: str = "" # Thresholds for raw count heuristics - RAW_COUNT_HEURISTIC_THRESHOLD: int = 1000 - EXPONENTIATED_UMIS_LIMIT: int = 1000000 + RAW_COUNT_HEURISTIC_THRESHOLD: int = 35 + EXPONENTIATED_UMIS_LIMIT: int = 5_000_000 @classmethod def from_toml(cls, toml_path: str) -> "ExperimentConfig": @@ -44,7 +44,9 @@ def from_toml(cls, toml_path: str) -> "ExperimentConfig": zeroshot=config.get("zeroshot", {}), fewshot=config.get("fewshot", {}), h5ad_csv_path=config.get("h5ad_csv_path", ""), - RAW_COUNT_HEURISTIC_THRESHOLD=config.get("RAW_COUNT_HEURISTIC_THRESHOLD", 1000), + RAW_COUNT_HEURISTIC_THRESHOLD=config.get( + "RAW_COUNT_HEURISTIC_THRESHOLD", 1000 + ), EXPONENTIATED_UMIS_LIMIT=config.get("EXPONENTIATED_UMIS_LIMIT", 1000000), ) diff --git a/src/cell_load/dataset/cell_sentence_dataset.py b/src/cell_load/dataset/cell_sentence_dataset.py index c2a8fec..2d7606a 100644 --- a/src/cell_load/dataset/cell_sentence_dataset.py +++ b/src/cell_load/dataset/cell_sentence_dataset.py @@ -182,7 +182,7 @@ def __init__( self.shapes_dict, self.dataset_path_map, self.dataset_group_map, - ) = utils.get_shapes_dict(getattr(cfg, 'h5ad_csv_path', '')) + ) = utils.get_shapes_dict(getattr(cfg, "h5ad_csv_path", "")) emb_cfg = utils.get_embedding_cfg(self.cfg) try: @@ -307,7 +307,6 @@ def __init__( for dataset_name, ds_emb_idxs in self.dataset_to_protein_embeddings.items(): # make sure tensor with long data type ds_emb_idxs = torch.tensor(ds_emb_idxs, dtype=torch.long) - # Create a tensor filled with -1 (indicating not present in this dataset) reverse_mapping = torch.full((self.global_size,), -1, dtype=torch.int64) @@ -484,13 +483,13 @@ def is_raw_integer_counts(self, counts: torch.Tensor) -> bool: max_val = torch.max(counts).item() # Primary heuristic: very large individual counts => raw counts - threshold = getattr(self.cfg, 'RAW_COUNT_HEURISTIC_THRESHOLD', 35) + threshold = getattr(self.cfg, "RAW_COUNT_HEURISTIC_THRESHOLD", 35) if max_val > threshold: return True # Ambiguous case: try undoing log1p total_umis = int(torch.expm1(counts).sum().item()) - umi_limit = getattr(self.cfg, 'EXPONENTIATED_UMIS_LIMIT', 5_000_000) + umi_limit = getattr(self.cfg, "EXPONENTIATED_UMIS_LIMIT", 5_000_000) if total_umis > 
umi_limit:
             return True
 
@@ -512,7 +511,7 @@ def sample_cell_sentences(
         if torch.any(counts_raw < 0):
             counts_raw = F.relu(counts_raw)
 
-        if self.is_raw_integer_counts(counts_raw):  
+        if self.is_raw_integer_counts(counts_raw):
             total_umis = int(counts_raw.sum(axis=1).item())
             count_expr_dist = counts_raw / counts_raw.sum(axis=1, keepdim=True)
             counts_raw = torch.log1p(counts_raw)
@@ -688,13 +687,11 @@ def sample_cell_sentences(
                     shared_genes  # in the old impl these are global gene indices
                 )
-
                 # convert the shared_genes, which are global indices, to the dataset specific indices
                 local_indices = self.global_to_local[dataset][shared_genes].to(
                     cell.device
                 )  # in the old impl these are global gene indices
-
                 shared_counts = torch.zeros(
                     local_indices.shape, dtype=cell.dtype, device=cell.device
                 )
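Editor's note, appended after the series because the flattened hunks above leave no clean
in-line anchor: a standalone sketch of the count-downsampling augmentation that
sample_cell_sentences performs when num_downsample > 1. The function below is an
illustration under invented inputs, not the project's implementation; it assumes a 1-D
raw integer count vector and a fixed downsample fraction.

import torch

def downsample_cell(raw_counts: torch.Tensor, frac: float) -> torch.Tensor:
    """Resample a fraction of a cell's UMIs by expression weight, return log1p counts (sketch)."""
    total_umis = int(raw_counts.sum().item())
    down_umis = int(total_umis * frac)
    expr_dist = raw_counts / raw_counts.sum()             # per-gene sampling probabilities
    genes_sampled = torch.multinomial(expr_dist, down_umis, replacement=True)
    counts_aug = torch.zeros_like(raw_counts)
    counts_aug.scatter_add_(0, genes_sampled, torch.ones(down_umis, dtype=raw_counts.dtype))
    return torch.log1p(counts_aug)

cell = torch.tensor([10.0, 0.0, 4.0, 1.0])                # toy raw counts, 15 UMIs total
shallow = downsample_cell(cell, frac=0.5)                 # ~7 resampled UMIs, spread by expression
print(shallow)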