diff --git a/src/cell_load/config.py b/src/cell_load/config.py
index 58e7f5d..2c361fd 100644
--- a/src/cell_load/config.py
+++ b/src/cell_load/config.py
@@ -25,6 +25,13 @@ class ExperimentConfig:
     # Fewshot perturbation assignments (dataset.celltype -> {split: [perts]})
     fewshot: dict[str, dict[str, list[str]]]
 
+    # Path to the CSV summary of h5ad datasets used for gene/protein embedding mapping
+    h5ad_csv_path: str = ""
+
+    # Thresholds for raw count heuristics
+    RAW_COUNT_HEURISTIC_THRESHOLD: int = 35
+    EXPONENTIATED_UMIS_LIMIT: int = 5_000_000
+
     @classmethod
     def from_toml(cls, toml_path: str) -> "ExperimentConfig":
         """Load configuration from TOML file."""
@@ -36,6 +43,11 @@ def from_toml(cls, toml_path: str) -> "ExperimentConfig":
             training=config.get("training", {}),
             zeroshot=config.get("zeroshot", {}),
             fewshot=config.get("fewshot", {}),
+            h5ad_csv_path=config.get("h5ad_csv_path", ""),
+            RAW_COUNT_HEURISTIC_THRESHOLD=config.get(
+                "RAW_COUNT_HEURISTIC_THRESHOLD", 35
+            ),
+            EXPONENTIATED_UMIS_LIMIT=config.get("EXPONENTIATED_UMIS_LIMIT", 5_000_000),
         )
 
     def get_all_datasets(self) -> Set[str]:
diff --git a/src/cell_load/data_modules/cell_sentence_dataloader.py b/src/cell_load/data_modules/cell_sentence_dataloader.py
new file mode 100644
index 0000000..4eab5b6
--- /dev/null
+++ b/src/cell_load/data_modules/cell_sentence_dataloader.py
@@ -0,0 +1,59 @@
+from torch.utils.data import DataLoader
+from cell_load.dataset.cell_sentence_dataset import FilteredGenesCounts
+from cell_load.dataset.cell_sentence_dataset import CellSentenceCollator
+
+
+def create_dataloader(
+    cfg,
+    workers=1,
+    data_dir=None,
+    datasets=None,
+    shape_dict=None,
+    adata=None,
+    adata_name=None,
+    shuffle=False,
+    sentence_collator=None,
+):
+    """
+    Build a DataLoader for inference.
+    Either `datasets` and `shape_dict` or `adata` and `adata_name` must be provided.
+    """
+    if datasets is None and adata is None:
+        raise ValueError(
+            "Either datasets and shape_dict or adata and adata_name should be provided"
+        )
+
+    if adata is not None:
+        shuffle = False
+
+    if data_dir:
+        cfg.model.data_dir = data_dir
+        # TODO: should utils.get_dataset_cfg(cfg).data_dir also be updated here?
+
+    dataset = FilteredGenesCounts(
+        cfg,
+        datasets=datasets,
+        shape_dict=shape_dict,
+        adata=adata,
+        adata_name=adata_name,
+    )
+    if sentence_collator is None:
+        sentence_collator = CellSentenceCollator(
+            cfg,
+            valid_gene_mask=dataset.valid_gene_index,
+            ds_emb_mapping_inference=dataset.ds_emb_map,
+            is_train=False,
+        )
+
+    # validation should not use cell augmentations
+    sentence_collator.training = False
+
+    dataloader = DataLoader(
+        dataset,
+        batch_size=cfg.model.batch_size,
+        shuffle=shuffle,
+        collate_fn=sentence_collator,
+        num_workers=workers,
+        persistent_workers=True,
+    )
+    return dataloader
diff --git a/src/cell_load/dataset/cell_sentence_dataset.py b/src/cell_load/dataset/cell_sentence_dataset.py
new file mode 100644
index 0000000..2d7606a
--- /dev/null
+++ b/src/cell_load/dataset/cell_sentence_dataset.py
@@ -0,0 +1,773 @@
+"""
+cell_sentence_dataset.py
+
+Dataset classes and utilities for handling single-cell gene expression data in sentence form.
+
+This module provides PyTorch Dataset classes and related utilities for loading, filtering, and collating
+single-cell gene expression data and integrating gene/protein embeddings.
+"""
+
+import h5py
+import logging
+import torch
+import torch.utils.data as data
+import torch.nn.functional as F
+import functools
+import numpy as np
+from typing import Dict
+from .. import utils
+
+log = logging.getLogger(__file__)
+
+
+class CellSentenceDataset(data.Dataset):
+    """
+    PyTorch Dataset for single-cell gene expression data, supporting multiple datasets and formats.
+
+    Args:
+        cfg: Configuration object.
+        test (bool): If True, uses validation data.
+        datasets (list, optional): List of dataset names/IDs.
+        shape_dict (dict, optional): Mapping from dataset name to shape (num_cells, num_genes).
+        adata (AnnData, optional): AnnData object for in-memory data.
+        adata_name (str, optional): Name/ID for the AnnData object.
+    """
+
+    def __init__(
+        self,
+        cfg,
+        test=False,
+        datasets=None,
+        shape_dict=None,
+        adata=None,
+        adata_name=None,
+    ) -> None:
+        super().__init__()
+        self.adata = None
+        self.adata_name = adata_name
+        self.test = test
+        if adata is not None:
+            self.adata = adata
+            self.datasets = [adata_name]
+            self.shapes_dict = {self.datasets[0]: adata.shape}
+        elif datasets is None:
+            ds_path = utils.get_dataset_cfg(cfg).train
+            if test:
+                ds_path = utils.get_dataset_cfg(cfg).val
+            (
+                _,
+                self.datasets,
+                self.shapes_dict,
+                self.dataset_path_map,
+                self.dataset_group_map,
+            ) = utils.get_shapes_dict(
+                ds_path, utils.get_dataset_cfg(cfg).get("filter_by_species")
+            )
+        else:
+            assert shape_dict is not None
+            assert len(datasets) == len(shape_dict)
+            self.datasets = datasets
+            self.shapes_dict = shape_dict
+            self.dataset_path_map = {dataset: dataset for dataset in datasets}
+
+        self.datasets = sorted(self.datasets)
+        self.cfg = cfg
+
+        self.num_cells = {}
+        self.num_genes = {}
+
+        self.total_num_cells = 0
+        for name in self.datasets:
+            num_cells, num_genes = self.shapes_dict[name]
+            self.num_cells[name] = num_cells
+            self.num_genes[name] = num_genes
+
+            self.total_num_cells += num_cells
+
+        self.datasets_to_num = {
+            k: v for k, v in zip(self.datasets, range(len(self.datasets)))
+        }
+
+    def _compute_index(self, idx):
+        # map a global cell index to (dataset name, index within that dataset)
+        for dataset in self.datasets:
+            if idx < self.num_cells[dataset]:
+                return dataset, idx
+            else:
+                idx -= self.num_cells[dataset]
+        raise IndexError
+
+    @functools.lru_cache
+    def dataset_file(self, dataset):
+        datafile = self.dataset_path_map[dataset]
+        return h5py.File(datafile, "r")
+
+    def __getitem__(self, idx):
+        if self.adata is not None:
+            # this block is only used during validation;
+            # .X may be a dense numpy.ndarray or a sparse matrix
+            if isinstance(self.adata.X, np.ndarray):
+                counts = torch.tensor(self.adata.X[idx]).reshape(1, -1)
+            else:
+                counts = torch.tensor(self.adata.X[idx].todense())
+
+            dataset = self.adata_name
+            dataset_num = 0
+            return counts, idx, dataset, dataset_num
+
+        dataset, ds_idx = self._compute_index(idx)
+        h5f = self.dataset_file(dataset)
+        attrs = dict(h5f["X"].attrs)
+        try:
+            if attrs["encoding-type"] == "csr_matrix":
+                indptrs = h5f["/X/indptr"]
+                start_ptr = indptrs[ds_idx]
+                end_ptr = indptrs[ds_idx + 1]
+                sub_data = torch.tensor(
+                    h5f["/X/data"][start_ptr:end_ptr], dtype=torch.float
+                )
+                sub_indices = torch.tensor(
+                    h5f["/X/indices"][start_ptr:end_ptr], dtype=torch.int32
+                )
+
+                counts = torch.sparse_csr_tensor(
+                    # CSR row pointers for a single row: [0, nnz]
+                    torch.tensor([0, sub_data.shape[0]], dtype=torch.int32),
+                    sub_indices,
+                    sub_data,
+                    (1, self.num_genes[dataset]),
+                )
+                counts = counts.to_dense()
+            else:
+                log.info(ds_idx)
+                counts = torch.tensor(h5f["X"][ds_idx]).unsqueeze(0)
+
+        except Exception as iex:
+            log.exception(f"Error in dataset {dataset} at index {ds_idx}")
+            raise iex
+
+        dataset_num = self.datasets_to_num[dataset]
+        return counts, idx, dataset, dataset_num
+
+    def __len__(self) -> int:
+        return self.total_num_cells
+
+    def get_dim(self) -> Dict[str, int]:
+        return self.num_genes
+
+
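Note (editor's illustration, not part of the patch): the CSR branch of `__getitem__` above reads one row of `X` straight from the h5ad file via `indptr`/`indices`/`data` slices, without materializing the full matrix. A minimal standalone sketch of the same access pattern, assuming the standard AnnData CSR layout and a hypothetical file path:

import h5py
import torch

def read_row_dense(h5ad_path: str, row: int, num_genes: int) -> torch.Tensor:
    """Densify one CSR row of X from an h5ad file without loading the whole matrix."""
    with h5py.File(h5ad_path, "r") as h5f:
        indptr = h5f["/X/indptr"]
        start, end = indptr[row], indptr[row + 1]
        values = torch.tensor(h5f["/X/data"][start:end], dtype=torch.float)
        cols = torch.tensor(h5f["/X/indices"][start:end], dtype=torch.int64)
    dense = torch.zeros(1, num_genes)
    dense[0, cols] = values  # scatter the nonzero entries into their gene columns
    return dense

# counts = read_row_dense("pbmc.h5ad", row=0, num_genes=20000)  # hypothetical file
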
+class FilteredGenesCounts(CellSentenceDataset): + """ + Dataset class that filters genes based on available protein/gene embeddings. + + Extends CellSentenceDataset to provide valid gene indices and embedding mappings for each dataset. + """ + + def __init__( + self, + cfg, + test=False, + datasets=None, + shape_dict=None, + adata=None, + adata_name=None, + ) -> None: + super().__init__(cfg, test, datasets, shape_dict, adata, adata_name) + self.valid_gene_index = {} + + # make sure we get training datasets + ( + _, + self.datasets, + self.shapes_dict, + self.dataset_path_map, + self.dataset_group_map, + ) = utils.get_shapes_dict(getattr(cfg, "h5ad_csv_path", "")) + + emb_cfg = utils.get_embedding_cfg(self.cfg) + try: + self.ds_emb_map = torch.load(emb_cfg.ds_emb_mapping, weights_only=False) + except (FileNotFoundError, IOError): + self.ds_emb_map = {} + + # for inference, let's make sure this dataset's valid mask is available + if adata_name is not None: + # append it to self.datasets + self.datasets.append(adata_name) + self.shapes_dict[adata_name] = adata.shape + + # compute its embedding‐index vector + esm_data = torch.load(emb_cfg.all_embeddings, weights_only=False) + valid_genes_list = list(esm_data.keys()) + # make a gene→global‐index lookup + global_pos = {g: i for i, g in enumerate(valid_genes_list)} + + # grab var_names from the AnnData + gene_names = np.array(adata.var_names) + + # for each gene in this dataset, find its global idx or -1 if missing + new_mapping = np.array([global_pos.get(g, -1) for g in gene_names]) + if (new_mapping == -1).all(): + # probably it contains ensembl id's instead + gene_names = adata.var["gene_name"].values + new_mapping = np.array([global_pos.get(g, -1) for g in gene_names]) + + self.ds_emb_map[adata_name] = new_mapping + + if utils.get_embedding_cfg(self.cfg).ds_emb_mapping is not None: + esm_data = torch.load( + utils.get_embedding_cfg(self.cfg)["all_embeddings"], weights_only=False + ) + valid_genes_list = list(esm_data.keys()) + for name in self.datasets: + if not utils.is_valid_uuid( + name + ): # had to add this in for now as cellxgene h5ad fles don't have gene_name object but tahoe does + if adata is None: + a = self.dataset_file(name) + try: + gene_names = np.array( + [g.decode("utf-8") for g in a["/var/gene_name"][:]] + ) # Decode byte strings + except KeyError: + gene_categories = a["/var/gene_name/categories"][:] + gene_codes = np.array(a["/var/gene_name/codes"][:]) + gene_names = np.array( + [g.decode("utf-8") for g in gene_categories[gene_codes]] + ) + valid_mask = np.isin(gene_names, valid_genes_list) + self.valid_gene_index[name] = valid_mask + else: + gene_names = np.array(adata.var_names) + valid_mask = np.isin(gene_names, valid_genes_list) + + if not valid_mask.any(): + # none of the genes were valid, probably ensembl id's + gene_names = adata.var["gene_name"].values + valid_mask = np.isin(gene_names, valid_genes_list) + + self.valid_gene_index[name] = valid_mask + + def __getitem__(self, idx): + counts, idx, dataset, dataset_num = super().__getitem__(idx) + return counts, idx, dataset, dataset_num + + +class CellSentenceCollator(object): + """ + Collate function for batching single-cell gene expression data with gene/protein embedding support. + """ + + def __init__( + self, cfg, valid_gene_mask=None, ds_emb_mapping_inference=None, is_train=True + ): + """ + Initialize the CellSentenceCollator. + + Args: + cfg: Configuration object. + valid_gene_mask (dict, optional): Mapping from dataset name to valid gene mask arrays. 
+ ds_emb_mapping_inference (dict, optional): Dataset-to-embedding index mapping. + is_train (bool): Whether collator is for training or inference. + """ + self.pad_length = cfg.dataset.pad_length + self.P = cfg.dataset.P + self.N = cfg.dataset.N + self.S = cfg.dataset.S + self.cfg = cfg + self.training = is_train + + # Load the dataset mappings + self.use_dataset_info = getattr(cfg.model, "dataset_correction", False) + self.batch_tabular_loss = getattr(cfg.model, "batch_tabular_loss", False) + + if valid_gene_mask is not None: + # this branch is for inference + self.valid_gene_mask = valid_gene_mask + self.dataset_to_protein_embeddings = ds_emb_mapping_inference + else: + # otherwise for training, load from config + gene_mask_file = utils.get_embedding_cfg(self.cfg).valid_genes_masks + if gene_mask_file is not None: + # we have a config for training + self.valid_gene_mask = torch.load(gene_mask_file, weights_only=False) + else: + # we don't have a config for training + self.valid_gene_mask = None + + self.dataset_to_protein_embeddings = torch.load( + utils.get_embedding_cfg(self.cfg).ds_emb_mapping.format( + utils.get_embedding_cfg(self.cfg).size + ), + weights_only=False, + ) + + self.global_size = utils.get_embedding_cfg(self.cfg).num + self.global_to_local = {} + for dataset_name, ds_emb_idxs in self.dataset_to_protein_embeddings.items(): + # make sure tensor with long data type + ds_emb_idxs = torch.tensor(ds_emb_idxs, dtype=torch.long) + + # Create a tensor filled with -1 (indicating not present in this dataset) + reverse_mapping = torch.full((self.global_size,), -1, dtype=torch.int64) + + local_indices = torch.arange(ds_emb_idxs.size(0), dtype=torch.int64) + mask = (ds_emb_idxs >= 0) & (ds_emb_idxs < self.global_size) + reverse_mapping[ds_emb_idxs[mask]] = local_indices[mask] + self.global_to_local[dataset_name] = reverse_mapping + + print(len(self.global_to_local)) + + def __call__(self, batch): + num_aug = getattr(self.cfg.model, "num_downsample", 1) + if num_aug > 1 and self.training: + # for each original sample, duplicate it num_aug times + batch = [item for item in batch for _ in range(num_aug)] + + batch_size = len(batch) + + batch_sentences = torch.zeros((batch_size, self.pad_length), dtype=torch.int32) + batch_sentences_counts = torch.zeros((batch_size, self.pad_length)) + masks = torch.zeros((batch_size, self.pad_length), dtype=torch.bool) + + idxs = torch.zeros(batch_size, dtype=torch.int32) + if self.cfg.loss.name == "tabular": + task_num = self.P + self.N + self.S + else: + task_num = self.P + self.N + Xs = torch.zeros((batch_size, (task_num)), dtype=torch.int32) + Ys = torch.zeros((batch_size, (task_num))) + + largest_cnt = max([x[0].shape[1] for x in batch]) + batch_weights = torch.zeros((batch_size, largest_cnt)) + + total_counts_all = None + if self.cfg.model.rda: + total_counts_all = torch.zeros(batch_size) + + datasets = [] + for ( + _, + _, + ds_name, + _, + ) in batch: + datasets.append(ds_name) + + if self.cfg.loss.name == "tabular": + if "batch_tabular_loss" in self.__dict__ and self.batch_tabular_loss: + # Find genes shared across all datasets + shared_mask = None + for dataset in datasets: + dataset_mask = self.global_to_local[dataset] >= 0 + if shared_mask is None: + shared_mask = dataset_mask + else: + shared_mask &= dataset_mask + + # Get indices of shared genes + shared_indices = torch.where(shared_mask)[0] + + # Repeat shared genes to reach size S + n_shared = shared_indices.size(0) + if n_shared > 0: + # Calculate how many times to repeat and remainder 
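+                    # illustrative arithmetic (values assumed): with S = 512 and
+                    # n_shared = 200, repeats = 2 and remainder = 112, so the shared
+                    # slots hold 200 + 200 + 112 = 512 gene indices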
+ repeats = self.S // n_shared + remainder = self.S % n_shared + + # Repeat the full sequence + shared_genes = shared_indices.repeat(repeats) + + # Add remaining genes needed + if remainder > 0: + shared_genes = torch.cat( + [shared_genes, shared_indices[:remainder]] + ) + else: + # If no shared genes, sample randomly from global gene space + shared_genes = torch.randint( + low=0, + high=self.global_size, + size=(self.S,), + device=masks.device, + dtype=torch.long, + ) + else: + if "global_size" not in self.__dict__: + self.global_size = utils.get_embedding_cfg(self.cfg).num + shared_genes = torch.randint( + low=0, + high=self.global_size, + size=(self.S,), + device=masks.device, + dtype=torch.long, + ) + else: + shared_genes = None + + dataset_nums = torch.zeros(batch_size, dtype=torch.int32) + + i = 0 + max_len = 0 + for counts, idx, dataset, dataset_num in batch: + if self.valid_gene_mask is not None: + if dataset in self.valid_gene_mask: + valid_mask = self.valid_gene_mask[dataset] + else: + valid_mask = None + else: + valid_mask = None + + # compute downsample fraction. this is the first sample of the augmentation then + # use no downsampling + downsample_fraction = ( + 1.0 if (num_aug > 1 and i % num_aug == 0 and self.training) else None + ) + ( + bs, + xx, + yy, + batch_weight, + mask, + cell_total_counts, + cell_sentence_counts, + ) = self.sample_cell_sentences( + counts, dataset, shared_genes, valid_mask, downsample_fraction + ) + + batch_sentences[i, :] = bs + masks[i, :] = mask + batch_weight = batch_weight.squeeze() + batch_weights[i, : len(batch_weight)] = batch_weight + + max_len = max(max_len, self.cfg.dataset.pad_length) + idxs[i] = idx + + Xs[i] = xx # [pn_idx] + Ys[i] = yy.squeeze() # [pn_idx] + dataset_nums[i] = dataset_num + + if self.cfg.model.rda and cell_total_counts is not None: + total_counts_all[i] = cell_total_counts[0] + if self.cfg.model.counts and cell_sentence_counts is not None: + batch_sentences_counts[i, :] = cell_sentence_counts + i += 1 + + return ( + batch_sentences[:, :max_len], + Xs, + Ys, + idxs, + batch_weights, + masks, + total_counts_all if self.cfg.model.rda else None, + batch_sentences_counts if self.cfg.model.counts else None, + dataset_nums if self.use_dataset_info else None, + ) + + def softmax(self, x): + e_x = np.exp(x - np.max(x)) + return e_x / e_x.sum() + + def is_raw_integer_counts(self, counts: torch.Tensor) -> bool: + """ + Heuristic check to decide whether `counts` are raw integer UMI counts + versus log1p-transformed counts. + + 1. If any entry > RAW_COUNT_HEURISTIC_THRESHOLD, assume raw ints. + 2. Otherwise, invert log1p (via expm1) and sum: + - If the total UMIs exceeds EXPONENTIATED_UMIS_LIMIT, it means + the data were actually raw ints that we mistakenly log-transformed. + - Otherwise, assume the data were correctly log1p counts. 
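+
+        Illustrative example (assuming the default thresholds of 35 and 5,000,000):
+        a cell whose largest entry is 18 does not trigger check 1, but expm1(18)
+        alone is ~6.6e7, so the summed "UMIs" exceed the limit and the data are
+        treated as raw integer counts. A genuinely log1p-transformed cell (max
+        around 8, expm1 sum in the tens of thousands) stays under the limit.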
+ """ + max_val = torch.max(counts).item() + + # Primary heuristic: very large individual counts => raw counts + threshold = getattr(self.cfg, "RAW_COUNT_HEURISTIC_THRESHOLD", 35) + if max_val > threshold: + return True + + # Ambiguous case: try undoing log1p + total_umis = int(torch.expm1(counts).sum().item()) + umi_limit = getattr(self.cfg, "EXPONENTIATED_UMIS_LIMIT", 5_000_000) + if total_umis > umi_limit: + return True + + return False + + # sampling a single cell sentence + # counts_raw is a view of a cell + def sample_cell_sentences( + self, + counts_raw, + dataset, + shared_genes=None, + valid_gene_mask=None, + downsample_frac=None, + ): + if torch.isnan(counts_raw).any(): + log.error(f"NaN values in counts for dataset {dataset}") + + if torch.any(counts_raw < 0): + counts_raw = F.relu(counts_raw) + + if self.is_raw_integer_counts(counts_raw): + total_umis = int(counts_raw.sum(axis=1).item()) + count_expr_dist = counts_raw / counts_raw.sum(axis=1, keepdim=True) + counts_raw = torch.log1p(counts_raw) + else: # counts are already log1p + exp_log_counts = torch.expm1(counts_raw) + total_umis = int(exp_log_counts.sum(axis=1).item()) + count_expr_dist = exp_log_counts / exp_log_counts.sum(axis=1, keepdim=True) + + ### At this point, counts_raw is assumed to be log counts ### + + # store the raw counts here, we need them as targets + original_counts_raw = counts_raw.clone() + + # if we are using downsample augmentation, decide if we need to update counts_raw + num_aug = getattr(self.cfg.model, "num_downsample", 1) + if num_aug > 1: + if downsample_frac is None: + downsample_frac = torch.empty(1).uniform_(0.3, 1.0).item() + + down_umis = int(total_umis * downsample_frac) + if down_umis > 0 and downsample_frac < 1.0: + # build a distribution over raw counts + genes_sampled = torch.multinomial( + count_expr_dist.squeeze(), down_umis, replacement=True + ) + # flatten to a 1D gene vector, and get the counts for the newly sampled genes + flat = counts_raw.view(-1) + counts_aug_flat = torch.zeros_like(flat) + counts_aug_flat.scatter_add_( + 0, + genes_sampled, + torch.ones( + down_umis, + dtype=counts_aug_flat.dtype, + ), + ) + # restore original shape (1, D) + counts_aug = counts_aug_flat.view_as(counts_raw) + counts_aug = torch.log1p(counts_aug) + else: + counts_aug = counts_raw + + # if we are using an augmentation, update the raw counts here + counts_raw = counts_aug + + # logic to sample a single cell sentence and task sentence here + ds_emb_idxs = torch.tensor( + self.dataset_to_protein_embeddings[dataset], dtype=torch.long + ) + + original_counts = original_counts_raw + counts = counts_raw + if valid_gene_mask is not None: + if ds_emb_idxs.shape[0] == valid_gene_mask.shape[0]: + # Filter the dataset embedding indices based on the valid gene mask + ds_emb_idxs = ds_emb_idxs[valid_gene_mask] + else: + # our preprocessing is such that sometimes the ds emb idxs are already filtered + # in this case we do nothing to (no subsetting) but assert that the mask matches + assert valid_gene_mask.sum() == ds_emb_idxs.shape[0], ( + f"Something wrong with filtering or mask for dataset {dataset}" + ) + + # Counts are never filtered in our preprocessing step, so we always need to apply the valid genes mask + if counts_raw.shape[1] == valid_gene_mask.shape[0]: + counts = counts_raw[:, valid_gene_mask] + original_counts = original_counts_raw[:, valid_gene_mask] + + if counts.sum() == 0: + expression_weights = F.softmax(counts, dim=1) + else: + expression_weights = counts / torch.sum(counts, dim=1, 
keepdim=True) + + cell_sentences = torch.zeros((counts.shape[0], self.cfg.dataset.pad_length)) + cell_sentence_counts = torch.zeros( + (counts.shape[0], self.cfg.dataset.pad_length) + ) + mask = torch.zeros( + (counts.shape[0], self.cfg.dataset.pad_length), dtype=torch.bool + ) + + if self.cfg.loss.name == "tabular": + # include capacity for shared genes + task_num = self.cfg.dataset.P + self.cfg.dataset.N + self.cfg.dataset.S + else: + task_num = self.cfg.dataset.P + self.cfg.dataset.N + + task_counts = torch.zeros((counts.shape[0], task_num)) + task_sentence = torch.zeros((counts.shape[0], task_num)) + + if self.cfg.model.rda: + cell_total_counts = torch.zeros((counts.shape[0],)) + else: + cell_total_counts = None + + # len(counts) = 1, e.g., we are looping over [cell] + for c, cell in enumerate(counts): + num_pos_genes = torch.sum(cell > 0) + # this is either the number of positive genes, or the first pad_length / 2 most expressed genes + # the first is only used if you have more expressed genes than pad_length / 2 + assert self.cfg.model.counts + # shuffle before argsort - randomly break ties so we select random unexpressed genes each time. + + indices = torch.randperm(cell.shape[-1]) + shuffled_cell = cell[indices] + shuffled_genes_ranked_exp = torch.argsort(shuffled_cell, descending=True) + genes_ranked_exp = indices[shuffled_genes_ranked_exp] + cell_sentences[c, 0] = self.cfg.dataset.cls_token_idx + if len(genes_ranked_exp) >= self.cfg.dataset.pad_length - 1: + cell_sentences[c, 1:] = genes_ranked_exp[ + : self.cfg.dataset.pad_length - 1 + ] + else: + # take the nonzero genes first + num_nonzero = min(num_pos_genes, self.cfg.dataset.pad_length - 1) + cell_sentences[c, 1 : num_nonzero + 1] = genes_ranked_exp[:num_nonzero] + + # sample the unexpressed genes with replacement + remaining_slots = self.cfg.dataset.pad_length - 1 - num_nonzero + unexpressed_genes = genes_ranked_exp[num_nonzero:] + cell_sentences[c, num_nonzero + 1 :] = unexpressed_genes[ + torch.randint(len(unexpressed_genes), (remaining_slots,)) + ] + + cell_sentence_counts[c, :] = ( + 100 * expression_weights[c, cell_sentences[c, :].to(torch.int32)] + ) + + # Convert tokens to Embeddings - local to global + # this also includes the cls token, but we will override it later with a learnable torch vector + cell_sentences[c, :] = ds_emb_idxs[cell_sentences[c, :].to(torch.int32)] + + # pick P expressed genes to mask for MLM + exp_genes = torch.where(cell > 0)[0] + if len(exp_genes) > self.cfg.dataset.P: + task_sentence[c, : self.cfg.dataset.P] = exp_genes[ + torch.randperm(len(exp_genes))[0 : self.cfg.dataset.P] + ] + elif len(exp_genes) > 0: + task_sentence[c, : self.cfg.dataset.P] = exp_genes[ + torch.randint(len(exp_genes), (self.cfg.dataset.P,)) + ] + + # get the total number of genes unique to this cell; everything + # past this are shared genes across all cells in a batch, used for tabular loss + unshared_num = self.cfg.dataset.P + self.cfg.dataset.N + + unexp_genes = torch.where(cell < 1)[0] + if len(unexp_genes) > self.cfg.dataset.N: + task_sentence[c, self.cfg.dataset.P : unshared_num] = unexp_genes[ + torch.randperm(len(unexp_genes))[0 : self.cfg.dataset.N] + ] + else: + task_sentence[c, self.cfg.dataset.P : unshared_num] = unexp_genes[ + torch.randint(len(unexp_genes), (self.cfg.dataset.N,)) + ] + + # set counts for unshared genes + task_idxs = task_sentence[c, :unshared_num].to(torch.int32) + task_counts[c, :unshared_num] = original_counts[c, task_idxs] + + # convert from dataset specific gene indices to global 
gene indices + # only do this for everything up to shared genes, which are already global indices + task_sentence[c, :unshared_num] = ds_emb_idxs[ + task_sentence[c, :unshared_num].to(torch.int32) + ] + + # now take care of shared genes across all cells in the batch + if shared_genes is not None: + # Overwrite the final positions of task_sentence + + task_sentence[c, unshared_num:] = ( + shared_genes # in the old impl these are global gene indices + ) + + # convert the shared_genes, which are global indices, to the dataset specific indices + local_indices = self.global_to_local[dataset][shared_genes].to( + cell.device + ) # in the old impl these are global gene indices + + shared_counts = torch.zeros( + local_indices.shape, dtype=cell.dtype, device=cell.device + ) + valid_mask = local_indices != -1 + if valid_mask.any(): + shared_counts[valid_mask] = original_counts_raw[ + c, local_indices[valid_mask] + ] + + # for indices which are -1, count is 0, else index into cell + task_counts[c, unshared_num:] = shared_counts + + assert self.cfg.model.rda + # sum the counts of the task sentence + cell_total_counts[c] = torch.sum(task_counts[c]) + + if self.cfg.loss.name == "cross_entropy": + # binarize the counts to 0/1 + task_counts[c] = (task_counts[c] > 0).float() + + # mask out the task genes from the cell sentence + task_gene_set = torch.tensor( + task_sentence[c].tolist(), dtype=cell_sentences.dtype + ) + potential_mask = torch.isin(cell_sentences[c], task_gene_set) + + # Calculate target number of masked tokens + target_mask_count = int(self.cfg.task.mask * self.cfg.dataset.pad_length) + current_mask_count = potential_mask.sum().item() + + if current_mask_count > target_mask_count: + # Too many tokens are being masked - randomly select subset + # Only consider indices after the CLS token (index 0) + mask_indices = ( + torch.where(potential_mask[1:])[0] + 1 + ) # +1 to adjust for offset + keep_indices = torch.randperm(len(mask_indices))[:target_mask_count] + selected_indices = mask_indices[keep_indices] + + # Create new mask with only the selected indices, ensuring CLS is not masked + final_mask = torch.zeros_like(potential_mask) + final_mask[selected_indices] = True + mask[c] = final_mask + elif current_mask_count < target_mask_count: + # Not enough tokens masked - we need to mask additional tokens + non_masked = ~potential_mask + + # Exclude the CLS token (index 0) by only considering indices 1 and up + non_masked_indices = ( + torch.where(non_masked[1:])[0] + 1 + ) # +1 to adjust for offset + + # Calculate how many more tokens to mask + additional_needed = target_mask_count - current_mask_count + additional_needed = min(additional_needed, len(non_masked_indices)) + + if len(non_masked_indices) > 0 and additional_needed > 0: + additional_indices = non_masked_indices[ + torch.randperm(len(non_masked_indices))[:additional_needed] + ] + potential_mask[additional_indices] = True + + mask[c] = potential_mask + else: + # Exactly self.cfg.task.mask percent are masked, use the potential mask as is + mask[c] = potential_mask + + # make sure that the CLS token is never masked out. 
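+            # (position 0 holds the CLS token, which is later overridden with a
+            # learnable vector, so it must stay visible to the model)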
+ mask[c, 0] = False + + return ( + cell_sentences, + task_sentence, + task_counts, + counts, + mask, + cell_total_counts if self.cfg.model.rda else None, + cell_sentence_counts if self.cfg.model.counts else None, + ) diff --git a/tests/test_cell_sentence_dataset.py b/tests/test_cell_sentence_dataset.py new file mode 100644 index 0000000..6019fcb --- /dev/null +++ b/tests/test_cell_sentence_dataset.py @@ -0,0 +1,35 @@ +import pytest +import numpy as np +import torch +from cell_load.dataset.cell_sentence_dataset import CellSentenceDataset + + +class DummyAdata: + def __init__(self, shape=(3, 5)): + self.X = np.arange(np.prod(shape)).reshape(shape) + self.var_names = [f"gene{i}" for i in range(shape[1])] + self.shape = shape + self.var = {"gene_name": np.array(self.var_names)} + + +@pytest.fixture +def dummy_adata(): + return DummyAdata() + + +def test_cell_sentence_dataset_basic(dummy_adata): + cfg = type("cfg", (), {})() + cfg.model = type("model", (), {})() + cfg.model.batch_size = 2 + cfg.dataset = type("dataset", (), {})() + cfg.dataset.pad_length = 5 + cfg.dataset.P = 1 + cfg.dataset.N = 1 + cfg.dataset.S = 1 + dataset = CellSentenceDataset(cfg, adata=dummy_adata, adata_name="dummy") + assert len(dataset) == dummy_adata.shape[0] + counts, idx, dataset_name, dataset_num = dataset[0] + assert isinstance(counts, torch.Tensor) + assert counts.shape[1] == dummy_adata.shape[1] + assert dataset_name == "dummy" + assert dataset_num == 0
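Usage note (sketch only, not part of the patch): `create_dataloader` in cell_sentence_dataloader.py is the inference entry point. It assumes a fully populated experiment config exposing the fields read above (`cfg.model.batch_size`, `cfg.dataset.pad_length`, `cfg.loss.name`, `cfg.task.mask`, ...) and the embedding/mapping files referenced via `utils.get_embedding_cfg(cfg)` on disk. The config loader and file names below are hypothetical placeholders:

import anndata as ad
from cell_load.data_modules.cell_sentence_dataloader import create_dataloader

adata = ad.read_h5ad("query_cells.h5ad")        # hypothetical query dataset
cfg = load_experiment_config("inference.yaml")  # hypothetical helper; use the project's config loader

loader = create_dataloader(cfg, workers=2, adata=adata, adata_name="query_cells")

for batch in loader:
    (batch_sentences, Xs, Ys, idxs, batch_weights,
     masks, total_counts, sentence_counts, dataset_nums) = batch
    # forward batch_sentences, masks, etc. through the model here
    break

With an in-memory AnnData, `shuffle` is forced to `False` and the collator is built with `is_train=False`, so the downsampling augmentation is never applied at inference time.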