diff --git a/cell2location/cell_comm/around_target.py b/cell2location/cell_comm/around_target.py index f99ec4ab..6778ba35 100755 --- a/cell2location/cell_comm/around_target.py +++ b/cell2location/cell_comm/around_target.py @@ -118,7 +118,7 @@ def compute_weighted_average_around_target( source_cell_type_data, index=adata.obs_names, columns=source_names, - ) + ).astype("float32") # get normalising quantile values source_normalisation_quantile = source_cell_type_data.quantile(normalisation_quantile, axis=0) # compute average abundance above this quantile @@ -151,7 +151,7 @@ def compute_weighted_average_around_target( # to account for locations with no neighbours within a bin (sum == 0) data_[np.isnan(data_)] = 0 # complete the average for a given sample - source_cell_type_data.loc[sample_ind, :] = data_ + source_cell_type_data.loc[sample_ind, :] = data_.astype("float32") # normalise data by normalising quantile (global value across distance bins) source_cell_type_data = source_cell_type_data / source_normalisation_quantile # account for cases of undetected signal @@ -183,8 +183,9 @@ def compute_weighted_average_around_target( weighted_avg_ = pd.Series(weighted_avg_, name=ct, index=source_names) - # hack to make self interactions less apparent - weighted_avg_[ct] = weighted_avg_[~weighted_avg_.index.isin([ct])].max() + 0.02 + if genes_to_use_as_source is None: + # hack to make self interactions less apparent + weighted_avg_[ct] = (weighted_avg_[~weighted_avg_.index.isin([ct])].max() + 0.02).astype("float32") # complete the results dataframe weighted_avg.loc[f"target {ct}", :] = weighted_avg_ diff --git a/cell2location/cluster_averages/cluster_averages.py b/cell2location/cluster_averages/cluster_averages.py index 2f1e5dba..ad77a952 100644 --- a/cell2location/cluster_averages/cluster_averages.py +++ b/cell2location/cluster_averages/cluster_averages.py @@ -3,7 +3,13 @@ from scipy.sparse import csr_matrix -def compute_cluster_averages(adata, labels, use_raw=True, layer=None): +def compute_cluster_averages( + adata, + labels, + use_raw=True, + layer=None, + use_dask=False, +): """ Compute average expression of each gene in each cluster @@ -44,7 +50,17 @@ def compute_cluster_averages(adata, labels, use_raw=True, layer=None): averages_mat = np.zeros((1, x.shape[1])) for c in all_clusters: - sparse_subset = csr_matrix(x[np.isin(adata.obs[labels], c), :]) + if use_dask: + from dask import config + from dask.array.core import Array + + with config.set(**{"array.slicing.split_large_chunks": False}): + cur_data = x[np.isin(adata.obs[labels], c), :] + if isinstance(cur_data, Array): + cur_data = cur_data.compute() + else: + cur_data = x[np.isin(adata.obs[labels], c), :] + sparse_subset = csr_matrix(cur_data) aver = sparse_subset.mean(0) averages_mat = np.concatenate((averages_mat, aver)) averages_mat = averages_mat[1:, :].T diff --git a/cell2location/dataloaders/__init__.py b/cell2location/dataloaders/__init__.py new file mode 100755 index 00000000..e69de29b diff --git a/cell2location/dataloaders/_defined_grid_dataloader.py b/cell2location/dataloaders/_defined_grid_dataloader.py new file mode 100755 index 00000000..d4dc1889 --- /dev/null +++ b/cell2location/dataloaders/_defined_grid_dataloader.py @@ -0,0 +1,610 @@ +# import geopandas as gpd +import math +from copy import copy + +# from shapely.geometry import Polygon +# from geopandas import GeoDataFrame +from typing import Iterator, Optional, TypeVar, Union + +import lightning.pytorch as pl +import numpy as np +import pandas as pd +import scvi +import 
torch +import torch.distributed as dist +from scipy.sparse import csc_matrix +from scvi.data import AnnDataManager +from scvi.dataloaders import AnnTorchDataset +from scvi.dataloaders._data_splitting import validate_data_split +from torch.utils.data import DataLoader +from torch.utils.data.distributed import Sampler + +__all__ = [ + "DistributedSampler", +] + +T_co = TypeVar("T_co", covariant=True) + + +def assign_tiles_to_locations( + adata, + rows: int, + cols: int, + spatial_key: str = "spatial", + batch_key: str = None, +): + """ + Create a grid of tiles and assign each location to a tile. + + Parameters + ---------- + adata + AnnData object with spatial coordinates in `adata.obsm[spatial_key]`. + rows + Number of rows in the grid. + cols + Number of columns in the grid. + spatial_key + Key in `adata.obsm` where the spatial coordinates are stored. + batch_key + Key in `adata.obs` where the batch information is stored. Tiles are created for each batch separately. + + Returns + + ------- + AnnData object with a new column in `adata.obs` called "tile" that contains the tile index for each location. + The object is sorted by tile index. + """ + if batch_key is None: + adata.obs["batch"] = "0" + batch_key = "batch" + adata.obs[f"{spatial_key}_x"] = adata.obsm[spatial_key][:, 0] + adata.obs[f"{spatial_key}_y"] = adata.obsm[spatial_key][:, 1] + if "tiles" in adata.obs.columns: + adata.obs["tiles"] = "" + for batch in adata.obs[batch_key].unique(): + adata_batch = adata[adata.obs[batch_key] == batch, :].copy() + x_start_positions = np.arange( + np.min(adata_batch.obsm[spatial_key][:, 0]), np.max(adata_batch.obsm[spatial_key][:, 0]), step=rows + ) + y_start_positions = np.arange( + np.min(adata_batch.obsm[spatial_key][:, 1]), np.max(adata_batch.obsm[spatial_key][:, 1]), step=cols + ) + ind_x = np.digitize(adata_batch.obsm[spatial_key][:, 0], x_start_positions) + ind_y = np.digitize(adata_batch.obsm[spatial_key][:, 1], y_start_positions) + adata.obs.loc[adata.obs[batch_key] == batch, "tiles"] = ( + adata.obs[batch_key].astype(str) + + pd.Series("_", index=adata.obs_names).astype(str) + + pd.Series(ind_x, index=adata.obs_names).astype(str) + + pd.Series("_", index=adata.obs_names).astype(str) + + pd.Series(ind_y, index=adata.obs_names).astype(str) + ) + sorting_index = adata.obs.sort_values(by=["tiles", f"{spatial_key}_x", f"{spatial_key}_y"]).index + adata = adata[sorting_index, :].copy() + adata.obsm["tiles"] = csc_matrix(pd.get_dummies(adata.obs["tiles"], sparse=True).values.astype("uint32")) + adata.uns["tiles_names"] = np.array(pd.get_dummies(adata.obs["tiles"], sparse=True).columns.values.astype("str")) + return adata + + +def expand_tiles( + adata_vis, + tile_key: str = "leiden", + distance: float = 2000.0, + distance_step: float = 100.0, + threshold: float = 0.001, + overlap: float = 2.0, + distances_key: str = "distances", +): + current_overlap = 0.0 + while current_overlap < overlap: + from scipy.sparse import csr_matrix + + distances = adata_vis.obsp[distances_key].copy() + distances.data[distances.data >= distance] = 0 + expanded = distances.astype("float32") @ csr_matrix(pd.get_dummies(adata_vis.obs[tile_key]).values).astype( + "float32" + ) + expanded = pd.DataFrame( + expanded.toarray(), + index=adata_vis.obs_names, + columns=pd.get_dummies(adata_vis.obs[tile_key]).columns, + ) + expanded = expanded > threshold + if current_overlap == expanded.sum(1).mean(): + break + current_overlap = expanded.sum(1).mean() + distance = distance + distance_step + return expanded, 
pd.get_dummies(adata_vis.obs[tile_key]) + + +class SpatialGridBatchSampler(torch.utils.data.sampler.BatchSampler): + """ + Custom torch Sampler that returns a list of indices of size batch_size. + Parameters + ---------- + indices + list of indices to sample from + batch_size + batch size of each iteration + shuffle + if ``True``, shuffles indices before sampling + drop_last + if int, drops the last batch if its length is less than drop_last. + if drop_last == True, drops last non-full batch. + if drop_last == False, iterate over all batches. + """ + + def __init__( + self, + batch_size: int = 1, + indices: np.ndarray = None, + tiles: csc_matrix = None, + shuffle: bool = True, + drop_last: Union[bool, int] = False, + ): + self.batch_size = batch_size + + self.indices = indices + self.n_obs = len(indices) + + self.tiles = tiles.astype("bool") + self.tiles_index = np.arange(tiles.shape[1]).astype("uint32") + self.n_tiles = tiles.shape[1] + + self.shuffle = shuffle + + # drop last WHAT? + last_batch_len = self.n_tiles % self.batch_size + if (drop_last is True) or (last_batch_len < drop_last): + drop_last_n = last_batch_len + elif (drop_last is False) or (last_batch_len >= drop_last): + drop_last_n = 0 + else: + raise ValueError("Invalid input for drop_last param. Must be bool or int.") + self.drop_last_n = drop_last_n + + def get_tile_batches(self): + """Get batches of tiles. + + Returns + ------- + Iterable over batches of tiles. + + """ + + if self.shuffle is True: + tile_idx = torch.randperm(self.n_tiles).numpy() + else: + tile_idx = torch.arange(self.n_tiles).numpy() + + if self.drop_last_n != 0: + tile_idx = tile_idx[: -self.drop_last_n] + + n_tiles = len(tile_idx) + batch_start_indices = np.arange(0, n_tiles, step=self.batch_size) + tile_batches = np.empty(len(batch_start_indices), dtype=object) + tile_batches[:] = [ + np.array(self.tiles_index[tile_idx[c : c + self.batch_size]], dtype=object) for c in batch_start_indices + ] # n_batches + + return tile_batches + + def get_obs_batches(self, tile_batches): + """ + Get batches of observations. + + Returns + ------- + Iterable over batches of observations. + + """ + obs_batches = np.empty(len(tile_batches), dtype=object) + obs_batches[:] = [ + np.array(self.indices[np.asarray(self.tiles[:, tiles].sum(1)).flatten().astype("bool")], dtype="int64") + for tiles in tile_batches + ] + return obs_batches + + @staticmethod + def apply_independent_multi_gpu_merge(batches, n_gpus): + # merge batches from multiple GPUs + if n_gpus > 0: + new_size = int(np.floor(batches.shape[0] / n_gpus)) + new_batches = np.empty(new_size, dtype=object) + for i in range(0, new_size): + new_batches[i] = np.concatenate(batches[i * n_gpus : i * n_gpus + n_gpus]) + return new_batches + else: + return batches + + def __iter__(self): + tile_batches = self.get_tile_batches() + obs_batches = self.get_obs_batches(tile_batches) + return iter(obs_batches) + + def __len__(self): + from math import ceil + + if self.drop_last_n != 0: + n_batches = self.n_tiles // self.batch_size + else: + n_batches = ceil(self.n_tiles / self.batch_size) + return n_batches + + +class DistributedSampler(Sampler[T_co]): + r"""Sampler that restricts data loading to a subset of the dataset. + + It is especially useful in conjunction with + :class:`torch.nn.parallel.DistributedDataParallel`. 
In such a case, each + process can pass a :class:`~torch.utils.data.DistributedSampler` instance as a + :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the + original dataset that is exclusive to it. + + .. note:: + Dataset is assumed to be of constant size and that any instance of it always + returns the same elements in the same order. + + Args: + dataset: Dataset used for sampling. + num_replicas (int, optional): Number of processes participating in + distributed training. By default, :attr:`world_size` is retrieved from the + current distributed group. + rank (int, optional): Rank of the current process within :attr:`num_replicas`. + By default, :attr:`rank` is retrieved from the current distributed + group. + shuffle (bool, optional): If ``True``, sampler will shuffle the + indices. Default is ``False`` because shuffling within a batch is irrelevant. + seed (int, optional): random seed used to shuffle the sampler if + :attr:`shuffle=True`. This number should be identical across all + processes in the distributed group. Default: ``0``. + drop_last (bool, optional): if ``True``, then the sampler will drop the + tail of the data to make it evenly divisible across the number of + replicas. If ``False``, the sampler will add extra indices to make + the data evenly divisible across the replicas. Default: ``True``. + + .. warning:: + In distributed mode, calling the :meth:`set_epoch` method at + the beginning of each epoch **before** creating the :class:`DataLoader` iterator + is necessary to make shuffling work properly across multiple epochs. Otherwise, + the same ordering will be always used. + + Example:: + + >>> # xdoctest: +SKIP + >>> sampler = DistributedSampler(dataset) if is_distributed else None + >>> loader = DataLoader(dataset, shuffle=(sampler is None), + ... sampler=sampler) + >>> for epoch in range(start_epoch, n_epochs): + ... if is_distributed: + ... sampler.set_epoch(epoch) + ... train(loader) + """ + + def __init__( + self, + iterable, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = False, + seed: int = 0, + drop_last: bool = True, + ) -> None: + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + if rank >= num_replicas or rank < 0: + raise ValueError( + "Invalid rank {}, rank should be in the interval" " [0, {}]".format(rank, num_replicas - 1) + ) + self.iterable = iterable + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.drop_last = drop_last + # If the dataset length is evenly divisible by # of replicas, then there + # is no need to drop any data, since the dataset will be split equally. + # TODO pick the correct sizes for total cells and genes in the current minibatch + # distributed sampler should only distribute cell indices not gene indices + self.total_size_tiles = len(self.iterable) + if self.drop_last and (self.total_size_tiles % self.num_replicas) != 0: # type: ignore[arg-type] + # Split to the nearest available length that is evenly divisible. + # This is to ensure each rank receives the same amount of data when + # using this Sampler. 
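            # Worked example with illustrative numbers: for total_size_tiles = 10 and
            # num_replicas = 3, drop_last=True gives num_tile_samples = ceil((10 - 3) / 3) = 3
            # and total_tile_size = 9, so one tile is dropped each epoch; drop_last=False
            # would require ceil(10 / 3) = 4 tiles per replica (12 > 10, i.e. padding),
            # which __iter__ below does not implement (it raises NotImplementedError).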
+ self.num_tile_samples = math.ceil( + (self.total_size_tiles - self.num_replicas) / self.num_replicas # type: ignore[arg-type] + ) + else: + self.num_tile_samples = math.ceil(self.total_size_tiles / self.num_replicas) # type: ignore[arg-type] + self.total_tile_size = self.num_tile_samples * self.num_replicas + self.shuffle = shuffle + self.seed = seed + + def __iter__(self) -> Iterator[T_co]: + tile_indices = self.iterable + + if not self.drop_last: + raise NotImplementedError("DistributedSampler with drop_last=False is not implemented yet") + else: + # remove tail of data to make it evenly divisible. + tile_indices = tile_indices[: self.total_tile_size] + assert len(tile_indices) == self.total_tile_size + + # subsample + items_per_gpu = int(self.total_tile_size / self.num_replicas) + tile_indices = tile_indices[self.rank * items_per_gpu : (self.rank + 1) * items_per_gpu] + assert len(tile_indices) == self.num_tile_samples + return iter(tile_indices) + + def __len__(self) -> int: + return self.num_replicas + + def set_epoch(self, epoch: int) -> None: + r""" + Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas + use a different random ordering for each epoch. Otherwise, the next iteration of this + sampler will yield the same ordering. + + Args: + epoch (int): Epoch number. + """ + self.epoch = epoch + + +class DistributedBatchSampler(SpatialGridBatchSampler): + """`BatchSampler` wrapper that distributes across each batch multiple workers. Copied from PyTorch NLP. + + Args: + batch_sampler (torch.utils.data.sampler.BatchSampler) + num_replicas (int, optional): Number of processes participating in distributed training. + rank (int, optional): Rank of the current process within num_replicas. + + Example: + >>> from torch.utils.data.sampler import BatchSampler + >>> from torch.utils.data.sampler import SequentialSampler + >>> sampler = SequentialSampler(list(range(12))) + >>> batch_sampler = BatchSampler(sampler, batch_size=4, drop_last=False) + >>> + >>> list(DistributedBatchSampler(batch_sampler, num_replicas=2, rank=0)) + [[0, 2], [4, 6], [8, 10]] + >>> list(DistributedBatchSampler(batch_sampler, num_replicas=2, rank=1)) + [[1, 3], [5, 7], [9, 11]] + """ + + def __init__(self, batch_sampler, **kwargs): + self.batch_sampler = batch_sampler + self.kwargs = kwargs + + def __iter__(self): + for batch in iter(self.batch_sampler.get_tile_batches()): + yield self.batch_sampler.get_obs_batches(list(DistributedSampler(batch, **self.kwargs))) + + def __len__(self): + return int(len(self.batch_sampler) / self.kwargs["num_replicas"]) + + +class SpatialGridAnnDataLoader(DataLoader): + """ + DataLoader for loading tensors from AnnData objects. + Parameters + ---------- + adata_manager + :class:`~scvi.data.AnnDataManager` object with a registered AnnData object. + shuffle + Whether the data should be shuffled + indices + The indices of the observations in the adata to load + batch_size + minibatch size to load each iteration + data_and_attributes + Dictionary with keys representing keys in data registry (``adata_manager.data_registry``) + and value equal to desired numpy loading type (later made into torch tensor). + If ``None``, defaults to all registered data. 
+ data_loader_kwargs + Keyword arguments for :class:`~torch.utils.data.DataLoader` + iter_ndarray + Whether to iterate over numpy arrays instead of torch tensors + """ + + def __init__( + self, + adata_manager: AnnDataManager, + indices: np.ndarray = None, + # tiles: np.ndarray = None, + shuffle: bool = True, + batch_size: int = 1, + data_and_attributes: Optional[dict] = None, + drop_last: Union[bool, int] = False, + iter_ndarray: bool = False, + use_ddp: bool = False, + **data_loader_kwargs, + ): + if adata_manager.adata is None: + raise ValueError("Please run register_fields() on your AnnDataManager object first.") + + if data_and_attributes is not None: + data_registry = adata_manager.data_registry + for key in data_and_attributes.keys(): + if key not in data_registry.keys(): + raise ValueError(f"{key} required for model but not registered with AnnDataManager.") + + self.dataset = AnnTorchDataset( + adata_manager, + getitem_tensors=data_and_attributes, + ) + # print(self.dataset[[[100, 53, 1], [0, 5, 6]]]) + + sampler_kwargs = { + "tiles": adata_manager.get_from_registry("tiles"), + "batch_size": batch_size, + "shuffle": shuffle, + "drop_last": drop_last, + } + + if indices is None: + indices = np.arange(adata_manager.adata.n_obs).astype("int64") + sampler_kwargs["indices"] = indices + else: + if hasattr(indices, "dtype") and indices.dtype is np.dtype("bool"): + indices = np.where(indices)[0].ravel() + indices = np.asarray(indices).astype("int64") + sampler_kwargs["indices"] = indices + + self.sampler_kwargs = sampler_kwargs + sampler = SpatialGridBatchSampler(**self.sampler_kwargs) + if use_ddp: + sampler = DistributedBatchSampler( + sampler, + ) + self.data_loader_kwargs = copy(data_loader_kwargs) + # do not touch batch size here, sampler gives batched indices + self.data_loader_kwargs.update({"sampler": sampler, "batch_size": None}) + + if iter_ndarray: + self.data_loader_kwargs.update({"collate_fn": _dummy_collate}) + + super().__init__(self.dataset, **self.data_loader_kwargs) + + +def _dummy_collate(b): + """Dummy collate to have dataloader return numpy ndarrays.""" + return b + + +class SpatialGridDataSplitter(pl.LightningDataModule): + """ + Creates data loaders ``train_set``, ``validation_set``, ``test_set``. + If ``train_size + validation_set < 1`` then ``test_set`` is non-empty. + Parameters + ---------- + adata_manager + :class:`~scvi.data.AnnDataManager` object that has been created via ``setup_anndata``. + train_size + float, or None (default is 0.9) + validation_size + float, or None (default is None) + use_gpu + Use default GPU if available (if None or True), or index of GPU to use (if int), + or name of GPU (if str, e.g., `'cuda:0'`), or use CPU (if False). + **kwargs + Keyword args for data loader. If adata has labeled data, data loader + class is :class:`~scvi.dataloaders.SemiSupervisedDataLoader`, + else data loader class is :class:`~scvi.dataloaders.AnnDataLoader`. 
+ Examples + -------- + >>> adata = scvi.data.synthetic_iid() + >>> scvi.model.SCVI.setup_anndata(adata) + >>> adata_manager = scvi.model.SCVI(adata).adata_manager + >>> splitter = DataSplitter(adata) + >>> splitter.setup() + >>> train_dl = splitter.train_dataloader() + """ + + def __init__( + self, + adata_manager: AnnDataManager, + train_size: float = 1.0, + validation_size: Optional[float] = None, + accelerator: str = "auto", + device: Union[int, str] = "auto", + use_ddp: bool = False, + shuffle_training: bool = True, + drop_last: bool = False, + pin_memory: bool = False, + shuffle_set_split: bool = True, + **kwargs, + ): + super().__init__() + self.adata_manager = adata_manager + self.train_size = float(train_size) + self.validation_size = validation_size + self.data_loader_kwargs = kwargs + self.accelerator = accelerator + self.device = device + self.use_ddp = use_ddp + self.shuffle_training = shuffle_training + self.drop_last = drop_last + self.pin_memory = pin_memory + self.shuffle_set_split = shuffle_set_split + + self.n_train_ = dict() + self.n_val_ = dict() + # if self.data_loader_kwargs.get("tiles", None) is None: + # raise ValueError("tiles must be specified in data_loader_kwargs") + # tiles = self.data_loader_kwargs.get("tiles", None) + tiles = self.adata_manager.get_from_registry("tiles") + n_tiles = tiles.shape[1] + self.n_train_["n_tiles"], self.n_val_["n_tiles"] = validate_data_split( + n_tiles, + self.train_size, + self.validation_size, + ) + + def setup(self, stage: Optional[str] = None): + """Split indices in train/test/val sets.""" + n_train = self.n_train_["n_tiles"] + n_val = self.n_val_["n_tiles"] + random_state = np.random.RandomState(seed=scvi.settings.seed) + + tiles = self.adata_manager.get_from_registry("tiles") + # tiles = self.data_loader_kwargs.get("tiles", None) + tiles_index = np.arange(tiles.shape[1]) + n_tiles = tiles.shape[1] + + tile_idx = np.arange(n_tiles) + if self.shuffle_set_split: + tile_idx = random_state.permutation(tile_idx) + + self.tile_idx_train_idx = tiles_index[tile_idx[:n_train]] + self.tile_idx_val_idx = tiles_index[tile_idx[n_train : (n_val + n_train)]] + self.tile_idx_test_idx = tiles_index[tile_idx[(n_val + n_train) :]] + + obs_idx = np.arange(self.adata_manager.adata.n_obs) + + self.val_idx = obs_idx[np.asarray(tiles[:, self.tile_idx_val_idx].sum(1)).ravel().astype("bool")] + self.train_idx = obs_idx[np.asarray(tiles[:, self.tile_idx_train_idx].sum(1)).ravel().astype("bool")] + self.test_idx = obs_idx[np.asarray(tiles[:, self.tile_idx_test_idx].sum(1)).ravel().astype("bool")] + + self.pin_memory = True if (self.pin_memory and self.accelerator == "gpu") else False + + def train_dataloader(self): + return SpatialGridAnnDataLoader( + self.adata_manager, + shuffle=self.shuffle_training, + indices=self.train_idx, + drop_last=self.drop_last, + pin_memory=self.pin_memory, + use_ddp=self.use_ddp, + **self.data_loader_kwargs, + ) + + def val_dataloader(self): + if len(self.val_idx) > 0: + return SpatialGridAnnDataLoader( + self.adata_manager, + indices=self.val_idx, + shuffle=False, + drop_last=self.drop_last, + pin_memory=self.pin_memory, + **self.data_loader_kwargs, + ) + else: + pass + + def test_dataloader(self): + if len(self.test_idx) > 0: + return SpatialGridAnnDataLoader( + self.adata_manager, + indices=self.test_idx, + shuffle=False, + drop_last=self.drop_last, + pin_memory=self.pin_memory, + **self.data_loader_kwargs, + ) + else: + pass diff --git a/cell2location/dataloaders/_square_dataloader.py 
b/cell2location/dataloaders/_square_dataloader.py new file mode 100755 index 00000000..e69de29b diff --git a/cell2location/distributions/AutoAmortisedNormalMessenger.py b/cell2location/distributions/AutoAmortisedNormalMessenger.py index b359d9e5..b0caea30 100755 --- a/cell2location/distributions/AutoAmortisedNormalMessenger.py +++ b/cell2location/distributions/AutoAmortisedNormalMessenger.py @@ -1,6 +1,8 @@ from copy import deepcopy -from typing import Callable, Literal, Optional, Union +from typing import Callable, Literal, Optional, Tuple, Union +import numpy as np +import pyro import pyro.distributions as dist import torch from pyro.distributions.distribution import Distribution @@ -348,6 +350,21 @@ def _get_params(self, name: str, prior: Distribution): linear_scale_encoder = deep_getattr(self.hidden2scales, f"{name}.encoder") loc = linear_loc(linear_loc_encoder(*x_in)) scale = self.softplus(linear_scale(linear_scale_encoder(*x_in)) + self._init_scale_unconstrained) + # determine parameter dimensions + out_dim = self.amortised_plate_sites["sites"][name] + if isinstance(out_dim, tuple): + from string import ascii_lowercase + + from einops import rearrange + + variables = [ascii_lowercase[i] for i in range(len(out_dim))] + variables_str = " ".join(variables) + loc = rearrange( + loc, f"z ({variables_str}) -> z {variables_str}", **{v: dim for v, dim in zip(variables, out_dim)} + ) + scale = rearrange( + scale, f"z ({variables_str}) -> z {variables_str}", **{v: dim for v, dim in zip(variables, out_dim)} + ) if (self._hierarchical_sites is None) or (name in self._hierarchical_sites): if self.weight_type == "element-wise": # weight is element-wise @@ -359,6 +376,12 @@ def _get_params(self, name: str, prior: Distribution): weight = self.softplus( linear_weight(linear_weight_encoder(hidden)) + self._init_weight_unconstrained ) + if isinstance(out_dim, tuple): + weight = rearrange( + weight, + f"z ({variables_str}) -> z {variables_str}", + **{v: dim for v, dim in zip(variables, out_dim)}, + ) if self.weight_type == "scalar": # weight is a single value parameter weight = deep_getattr(self.weights, name) @@ -385,6 +408,8 @@ def _get_params(self, name: str, prior: Distribution): n_hidden = self.n_hidden["single"] # determine parameter dimensions out_dim = self.amortised_plate_sites["sites"][name] + if isinstance(out_dim, tuple): + out_dim = np.product(out_dim) deep_setattr( self, @@ -566,3 +591,56 @@ def _get_mutual_information(self, name, prior): log_qz = log_sum_exp(log_density, dim=1) - torch.log(x_batch) return (neg_entropy - log_qz.mean(-1)).item() + + +class AutoNormalMessenger(pyro.infer.autoguide.AutoNormalMessenger): + def __init__( + self, + model: Callable, + *, + init_loc_fn: Callable = init_to_mean(fallback=init_to_feasible), + init_scale: float = 0.1, + amortized_plates: Tuple[str, ...] = (), + ): + if not isinstance(init_scale, float) or not (init_scale > 0): + raise ValueError("Expected init_scale > 0. 
but got {}".format(init_scale)) + super().__init__(model, amortized_plates=amortized_plates) + self.init_loc_fn = init_loc_fn + self._init_scale = init_scale + self._computing_median = False + self._computing_quantiles = False + + def get_posterior( + self, + name: str, + prior: Distribution, + ) -> Union[Distribution, torch.Tensor]: + if self._computing_quantiles: + return self._get_posterior_quantiles(name, prior) + if self._computing_median: + return self._get_posterior_median(name, prior) + + with helpful_support_errors({"name": name, "fn": prior}): + transform = biject_to(prior.support) + loc, scale = self._get_params(name, prior) + posterior = dist.TransformedDistribution( + dist.Normal(loc, scale).to_event(transform.domain.event_dim), + transform.with_cache(), + ) + return posterior + + def quantiles(self, quantiles, *args, **kwargs): + self._computing_quantiles = True + self._quantile_values = quantiles + try: + return self(*args, **kwargs) + finally: + self._computing_quantiles = False + + @torch.no_grad() + def _get_posterior_quantiles(self, name, prior): + transform = biject_to(prior.support) + loc, scale = self._get_params(name, prior) + site_quantiles = torch.tensor(self._quantile_values, dtype=loc.dtype, device=loc.device) + site_quantiles_values = dist.Normal(loc, scale).icdf(site_quantiles) + return transform(site_quantiles_values) diff --git a/cell2location/models/__init__.py b/cell2location/models/__init__.py index 58168bb3..4a5fa7da 100644 --- a/cell2location/models/__init__.py +++ b/cell2location/models/__init__.py @@ -3,11 +3,13 @@ LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelGeneAlphaPyroModel, ) from ._cell2location_WTA_model import Cell2location_WTA +from ._cellcomm_model import CellCommModel from .downstream import CoLocatedGroupsSklearnNMF from .reference import RegressionModel __all__ = [ "Cell2location", + "CellCommModel", "RegressionModel", "Cell2location_WTA", "LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelGeneAlphaPyroModel", diff --git a/cell2location/models/_cell2location_WTA_model.py b/cell2location/models/_cell2location_WTA_model.py index 77bb9e38..920c1659 100644 --- a/cell2location/models/_cell2location_WTA_model.py +++ b/cell2location/models/_cell2location_WTA_model.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Union +from typing import List, Optional import matplotlib.pyplot as plt import numpy as np @@ -18,22 +18,14 @@ NumericalObsField, ObsmField, ) -from scvi.dataloaders import DataSplitter, DeviceBackedDataSplitter from scvi.model.base import BaseModelClass, PyroSampleMixin, PyroSviTrainMixin -from scvi.model.base._pyromixin import PyroJitGuideWarmup -from scvi.train import TrainRunner from scvi.utils import setup_anndata_dsp from cell2location.models._cell2location_WTA_module import ( LocationModelWTAMultiExperimentHierarchicalGeneLevel, ) from cell2location.models.base._pyro_base_loc_module import Cell2locationBaseModule -from cell2location.models.base._pyro_mixin import ( - PltExportMixin, - PyroAggressiveConvergence, - PyroAggressiveTrainingPlan, - QuantileMixin, -) +from cell2location.models.base._pyro_mixin import PltExportMixin, QuantileMixin from cell2location.utils import select_slide @@ -47,8 +39,6 @@ class Cell2location_WTA(QuantileMixin, PyroSampleMixin, PyroSviTrainMixin, PltEx spatial AnnData object that has been registered via :func:`~scvi.data.setup_anndata`. 
cell_state_df pd.DataFrame with reference expression signatures for each gene (rows) in each cell type/population (columns). - use_gpu - Use the GPU? **model_kwargs Keyword args for :class:`~cell2location.models.LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelGeneAlphaPyroModel` @@ -214,95 +204,6 @@ def train( super().train(**kwargs) - def train_aggressive( - self, - max_epochs: Optional[int] = 1000, - use_gpu: Optional[Union[str, int, bool]] = None, - train_size: float = 1, - validation_size: Optional[float] = None, - batch_size: int = None, - early_stopping: bool = False, - lr: Optional[float] = None, - plan_kwargs: Optional[dict] = None, - **trainer_kwargs, - ): - """ - Train the model. - Parameters - ---------- - max_epochs - Number of passes through the dataset. If `None`, defaults to - `np.min([round((20000 / n_cells) * 400), 400])` - use_gpu - Use default GPU if available (if None or True), or index of GPU to use (if int), - or name of GPU (if str, e.g., `'cuda:0'`), or use CPU (if False). - train_size - Size of training set in the range [0.0, 1.0]. - validation_size - Size of the test set. If `None`, defaults to 1 - `train_size`. If - `train_size + validation_size < 1`, the remaining cells belong to a test set. - batch_size - Minibatch size to use during training. If `None`, no minibatching occurs and all - data is copied to device (e.g., GPU). - early_stopping - Perform early stopping. Additional arguments can be passed in `**kwargs`. - See :class:`~scvi.train.Trainer` for further options. - lr - Optimiser learning rate (default optimiser is :class:`~pyro.optim.ClippedAdam`). - Specifying optimiser via plan_kwargs overrides this choice of lr. - plan_kwargs - Keyword args for :class:`~scvi.train.TrainingPlan`. Keyword arguments passed to - `train()` will overwrite values present in `plan_kwargs`, when appropriate. - **trainer_kwargs - Other keyword args for :class:`~scvi.train.Trainer`. 
- """ - if max_epochs is None: - n_obs = self.adata_manager.adata.n_obs - max_epochs = np.min([round((20000 / n_obs) * 1000), 1000]) - - plan_kwargs = plan_kwargs if isinstance(plan_kwargs, dict) else dict() - if lr is not None and "optim" not in plan_kwargs.keys(): - plan_kwargs.update({"optim_kwargs": {"lr": lr}}) - - if batch_size is None: - # use data splitter which moves data to GPU once - data_splitter = DeviceBackedDataSplitter( - self.adata_manager, - train_size=train_size, - validation_size=validation_size, - batch_size=batch_size, - use_gpu=use_gpu, - ) - else: - data_splitter = DataSplitter( - self.adata_manager, - train_size=train_size, - validation_size=validation_size, - batch_size=batch_size, - use_gpu=use_gpu, - ) - training_plan = PyroAggressiveTrainingPlan(pyro_module=self.module, **plan_kwargs) - - es = "early_stopping" - trainer_kwargs[es] = early_stopping if es not in trainer_kwargs.keys() else trainer_kwargs[es] - - if "callbacks" not in trainer_kwargs.keys(): - trainer_kwargs["callbacks"] = [] - trainer_kwargs["callbacks"].append(PyroJitGuideWarmup()) - trainer_kwargs["callbacks"].append(PyroAggressiveConvergence()) - - runner = TrainRunner( - self, - training_plan=training_plan, - data_splitter=data_splitter, - max_epochs=max_epochs, - use_gpu=use_gpu, - **trainer_kwargs, - ) - res = runner() - self.mi_ = self.mi_ + training_plan.mi - return res - def export_posterior( self, adata, diff --git a/cell2location/models/_cell2location_WTA_module.py b/cell2location/models/_cell2location_WTA_module.py index e8db6ef3..d4fb0da7 100644 --- a/cell2location/models/_cell2location_WTA_module.py +++ b/cell2location/models/_cell2location_WTA_module.py @@ -247,7 +247,7 @@ def list_obs_plate_vars(self): def forward(self, x_data, neg_data, n_nuclei, idx, batch_index): self.n_neg_probes = neg_data.shape[1] - obs2sample = one_hot(batch_index, self.n_batch) + obs2sample = one_hot(batch_index, self.n_batch).float() obs_plate = self.create_plates(x_data, neg_data, n_nuclei, idx, batch_index) diff --git a/cell2location/models/_cell2location_model.py b/cell2location/models/_cell2location_model.py index 36afa834..827babb2 100755 --- a/cell2location/models/_cell2location_model.py +++ b/cell2location/models/_cell2location_model.py @@ -1,3 +1,4 @@ +import logging from typing import List, Optional, Union import matplotlib.pyplot as plt @@ -16,13 +17,14 @@ LayerField, NumericalJointObsField, NumericalObsField, + ObsmField, ) from scvi.dataloaders import DeviceBackedDataSplitter from scvi.model.base import BaseModelClass, PyroSampleMixin, PyroSviTrainMixin -from scvi.model.base._pyromixin import PyroJitGuideWarmup -from scvi.train import TrainRunner +from scvi.train import PyroTrainingPlan, TrainRunner from scvi.utils import setup_anndata_dsp +from cell2location.dataloaders._defined_grid_dataloader import SpatialGridDataSplitter from cell2location.models._cell2location_module import ( LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelGeneAlphaPyroModel, ) @@ -32,6 +34,7 @@ PyroAggressiveConvergence, PyroAggressiveTrainingPlan, QuantileMixin, + setup_pyro_model, ) from cell2location.utils import select_slide @@ -46,8 +49,6 @@ class Cell2location(QuantileMixin, PyroSampleMixin, PyroSviTrainMixin, PltExport spatial AnnData object that has been registered via :func:`~scvi.data.setup_anndata`. cell_state_df pd.DataFrame with reference expression signatures for each gene (rows) in each cell type/population (columns). - use_gpu - Use the GPU? 
**model_kwargs Keyword args for :class:`~cell2location.models.LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelGeneAlphaPyroModel` @@ -64,6 +65,7 @@ def __init__( model_class: Optional[PyroModule] = None, detection_mean_per_sample: bool = False, detection_mean_correction: float = 1.0, + on_load_batch_size: Optional[int] = None, **model_kwargs, ): # in case any other model was created before that shares the same parameter names. @@ -83,11 +85,20 @@ def __init__( self.n_factors_ = cell_state_df.shape[1] self.factor_names_ = cell_state_df.columns.values + # annotations for extra categorical covariates + if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry: + self.extra_categoricals_ = self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY) + self.n_extra_categoricals_ = self.extra_categoricals_.n_cats_per_key + model_kwargs["n_extra_categoricals"] = self.n_extra_categoricals_ + if not detection_mean_per_sample: # compute expected change in sensitivity (m_g in V1 or y_s in V2) sc_total = cell_state_df.sum(0).mean() sp_total = self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY).sum(1).mean() - self.detection_mean_ = (sp_total / model_kwargs.get("N_cells_per_location", 1)) / sc_total + N_cells_per_location = model_kwargs.get("N_cells_per_location", 1.0) + if isinstance(N_cells_per_location, np.ndarray): + N_cells_per_location = N_cells_per_location.mean() + self.detection_mean_ = (sp_total / N_cells_per_location) / sc_total self.detection_mean_ = self.detection_mean_ * detection_mean_correction model_kwargs["detection_mean"] = self.detection_mean_ else: @@ -96,7 +107,10 @@ def __init__( sp_total = self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY).sum(1) batch = self.adata_manager.get_from_registry(REGISTRY_KEYS.BATCH_KEY).flatten() sp_total = np.array([sp_total[batch == b].mean() for b in range(self.summary_stats["n_batch"])]) - self.detection_mean_ = (sp_total / model_kwargs.get("N_cells_per_location", 1)) / sc_total + N_cells_per_location = model_kwargs.get("N_cells_per_location", 1.0) + if isinstance(N_cells_per_location, np.ndarray): + N_cells_per_location = N_cells_per_location.mean() + self.detection_mean_ = (sp_total / N_cells_per_location) / sc_total self.detection_mean_ = self.detection_mean_ * detection_mean_correction model_kwargs["detection_mean"] = self.detection_mean_.reshape((self.summary_stats["n_batch"], 1)).astype( "float32" @@ -110,7 +124,14 @@ def __init__( model_kwargs["detection_alpha"] = self.detection_alpha_.values.reshape( (self.summary_stats["n_batch"], 1) ).astype("float32") - + if ( + (model_kwargs.get("amortised_sliding_window_size", 0) > 0) + or (model_kwargs.get("sliding_window_size", 0) > 0) + or ("tiles" in self.adata_manager.data_registry) + ): + on_load_batch_size = 1 + self._data_splitter_cls = SpatialGridDataSplitter + logging.info("Updating data splitter to SpatialGridDataSplitter.") self.module = Cell2locationBaseModule( model=model_class, n_obs=self.summary_stats["n_cells"], @@ -118,6 +139,10 @@ def __init__( n_factors=self.n_factors_, n_batch=self.summary_stats["n_batch"], cell_state_mat=self.cell_state_df_.values.astype("float32"), + on_load_kwargs={ + "batch_size": on_load_batch_size, + "max_epochs": 1, + }, **model_kwargs, ) self._model_summary_string = f'cell2location model with the following params: \nn_factors: {self.n_factors_} \nn_batch: {self.summary_stats["n_batch"]} ' @@ -129,8 +154,14 @@ def setup_anndata( cls, adata: AnnData, layer: Optional[str] = None, + layer_normalised: 
Optional[str] = None, batch_key: Optional[str] = None, labels_key: Optional[str] = None, + position_key: Optional[str] = None, + tiles_key: Optional[str] = None, + tiles_unexpanded_key: Optional[str] = None, + in_tissue_key: Optional[str] = None, + normalising_factor_y_s_key: Optional[str] = None, categorical_covariate_keys: Optional[List[str]] = None, continuous_covariate_keys: Optional[List[str]] = None, **kwargs, @@ -156,11 +187,145 @@ def setup_anndata( NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys), NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"), ] + if layer_normalised is not None: + anndata_fields.append(LayerField("x_data_normalised", layer_normalised, is_count_data=False)) + if position_key is not None: + anndata_fields.append(ObsmField("positions", position_key)) + if tiles_key is not None: + anndata_fields.append(ObsmField("tiles", tiles_key)) + if tiles_unexpanded_key is not None: + anndata_fields.append(ObsmField("tiles_unexpanded", tiles_unexpanded_key)) + if in_tissue_key is not None: + anndata_fields.append(NumericalObsField("in_tissue", in_tissue_key)) + if normalising_factor_y_s_key is not None: + anndata_fields.append(NumericalObsField("normalising_factor_y_s", normalising_factor_y_s_key)) + adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) def train( + self, + max_epochs: int = 30000, + batch_size: int = None, + train_size: float = 1, + lr: float = 0.002, + num_particles: int = 1, + scale_elbo: float = "auto", + accelerator: str = "auto", + device: Union[int, str] = "auto", + validation_size: Optional[float] = None, + shuffle_set_split: bool = True, + early_stopping: bool = False, + training_plan: Optional[PyroTrainingPlan] = None, + plan_kwargs: Optional[dict] = None, + datasplitter_kwargs: Optional[dict] = None, + **trainer_kwargs, + ): + """Train the model. + + Parameters + ---------- + max_epochs + Number of passes through the dataset. If `None`, defaults to + `np.min([round((20000 / n_cells) * 400), 400])` + %(param_use_gpu)s + %(param_accelerator)s + %(param_device)s + train_size + Size of training set in the range [0.0, 1.0]. + validation_size + Size of the test set. If `None`, defaults to 1 - `train_size`. If + `train_size + validation_size < 1`, the remaining cells belong to a test set. + shuffle_set_split + Whether to shuffle indices before splitting. If `False`, the val, train, and test set are split in the + sequential order of the data according to `validation_size` and `train_size` percentages. + batch_size + Minibatch size to use during training. If `None`, no minibatching occurs and all + data is copied to device (e.g., GPU). + early_stopping + Perform early stopping. Additional arguments can be passed in `**kwargs`. + See :class:`~scvi.train.Trainer` for further options. + lr + Optimiser learning rate (default optimiser is :class:`~pyro.optim.ClippedAdam`). + Specifying optimiser via plan_kwargs overrides this choice of lr. + training_plan + Training plan :class:`~scvi.train.PyroTrainingPlan`. + plan_kwargs + Keyword args for :class:`~scvi.train.PyroTrainingPlan`. Keyword arguments passed to + `train()` will overwrite values present in `plan_kwargs`, when appropriate. + **trainer_kwargs + Other keyword args for :class:`~scvi.train.Trainer`. 
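        Notes
        -----
        When tiles are registered (e.g. via `setup_anndata(..., tiles_key="tiles")`),
        `SpatialGridDataSplitter` is used and `batch_size` counts spatial tiles rather
        than individual locations; `batch_size=None` instead moves the full dataset to
        the device once via `DeviceBackedDataSplitter`. An illustrative call:

        >>> mod.train(max_epochs=10000, batch_size=8)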
+ """ + # if max_epochs is None: + # max_epochs = get_max_epochs_heuristic(self.adata.n_obs, epochs_cap=1000) + if datasplitter_kwargs is None: + datasplitter_kwargs = dict() + + if issubclass(self._data_splitter_cls, SpatialGridDataSplitter): + self.module.model.n_tiles = batch_size + + plan_kwargs = plan_kwargs if isinstance(plan_kwargs, dict) else {} + if lr is not None and "optim" not in plan_kwargs.keys(): + plan_kwargs.update({"optim_kwargs": {"lr": lr}}) + if getattr(self.module.model, "discrete_variables", None) and (len(self.module.model.discrete_variables) > 0): + plan_kwargs["loss_fn"] = TraceEnum_ELBO(num_particles=num_particles) + else: + plan_kwargs["loss_fn"] = Trace_ELBO(num_particles=num_particles) + if scale_elbo != 1.0: + if scale_elbo == "auto": + scale_elbo = 1.0 / (self.summary_stats["n_cells"] * self.summary_stats["n_vars"]) + plan_kwargs["scale_elbo"] = scale_elbo + + if batch_size is None: + # use data splitter which moves data to GPU once + data_splitter = DeviceBackedDataSplitter( + self.adata_manager, + train_size=train_size, + validation_size=validation_size, + batch_size=batch_size, + accelerator=accelerator, + device=device, + ) + else: + data_splitter = self._data_splitter_cls( + self.adata_manager, + train_size=train_size, + validation_size=validation_size, + shuffle_set_split=shuffle_set_split, + batch_size=batch_size, + **datasplitter_kwargs, + ) + + if training_plan is None: + training_plan = self._training_plan_cls(self.module, **plan_kwargs) + + es = "early_stopping" + trainer_kwargs[es] = early_stopping if es not in trainer_kwargs.keys() else trainer_kwargs[es] + + if "callbacks" not in trainer_kwargs.keys(): + trainer_kwargs["callbacks"] = [] + + # Initialise pyro model with data + from copy import copy + + dl = copy(data_splitter) + dl.setup() + dl = dl.train_dataloader() + setup_pyro_model(dl, training_plan) + + runner = self._train_runner_cls( + self, + training_plan=training_plan, + data_splitter=data_splitter, + max_epochs=max_epochs, + accelerator=accelerator, + devices=device, + **trainer_kwargs, + ) + return runner() + + def train_v1( self, max_epochs: int = 30000, batch_size: int = None, @@ -211,7 +376,6 @@ def train( def train_aggressive( self, max_epochs: Optional[int] = 1000, - use_gpu: Optional[Union[str, int, bool]] = None, accelerator: str = "auto", device: Union[int, str] = "auto", train_size: float = 1, @@ -230,9 +394,6 @@ def train_aggressive( max_epochs Number of passes through the dataset. If `None`, defaults to `np.min([round((20000 / n_cells) * 400), 400])` - use_gpu - Use default GPU if available (if None or True), or index of GPU to use (if int), - or name of GPU (if str, e.g., `'cuda:0'`), or use CPU (if False). train_size Size of training set in the range [0.0, 1.0]. 
validation_size @@ -268,7 +429,6 @@ def train_aggressive( train_size=train_size, validation_size=validation_size, batch_size=batch_size, - use_gpu=use_gpu, accelerator=accelerator, device=device, ) @@ -287,15 +447,20 @@ def train_aggressive( if "callbacks" not in trainer_kwargs.keys(): trainer_kwargs["callbacks"] = [] - trainer_kwargs["callbacks"].append(PyroJitGuideWarmup()) trainer_kwargs["callbacks"].append(PyroAggressiveConvergence()) + from copy import copy + + dl = copy(data_splitter) + dl.setup() + dl = dl.train_dataloader() + setup_pyro_model(dl, training_plan) + runner = TrainRunner( self, training_plan=training_plan, data_splitter=data_splitter, max_epochs=max_epochs, - use_gpu=use_gpu, accelerator=accelerator, devices=device, **trainer_kwargs, diff --git a/cell2location/models/_cell2location_module.py b/cell2location/models/_cell2location_module.py index 30c9e994..069eebac 100755 --- a/cell2location/models/_cell2location_module.py +++ b/cell2location/models/_cell2location_module.py @@ -5,14 +5,13 @@ import pyro import pyro.distributions as dist import torch +from einops import rearrange +from pyro.infer.autoguide.utils import deep_getattr, deep_setattr from pyro.nn import PyroModule from scipy.sparse import csr_matrix from scvi import REGISTRY_KEYS from scvi.nn import one_hot -# class NegativeBinomial(TorchDistributionMixin, ScVINegativeBinomial): -# pass - class LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelGeneAlphaPyroModel(PyroModule): r""" @@ -53,6 +52,7 @@ class LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelGen however, the model is robust to a range of similar values. In settings where suitable histology images are not available, the size of capture regions relative to the expected size of cells can be used to estimate `N_cells_per_location`. + `N_cells_per_location` has to be a scalar or an array of shape (n_obs, 1). 
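    For example, per-location nuclei counts from segmentation of a paired histology image
    can be supplied as `N_cells_per_location=nuclei_per_location.reshape(-1, 1).astype("float32")`
    (`nuclei_per_location` being an illustrative name for any per-observation estimate).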
The prior on detection efficiency per location :math:`y_s` is selected to discourage over-normalisation, such that unless data has evidence of strong technical effect, the effect is assumed to be small and close to @@ -72,6 +72,18 @@ class LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelGen training_wo_observed = False training_wo_initial = False + n_cell_compartments = 3 + named_dims = { + "cell_compartment_w_sfk": -3, + } + n_tiles = 1 + use_concatenated_cnn = False + + n_pathways = 8 + use_pathway_interaction_effect = True + dropout_rate = 0.0 + use_non_negative_weights = False + def __init__( self, n_obs, @@ -79,14 +91,17 @@ def __init__( n_factors, n_batch, cell_state_mat, + n_extra_categoricals: int = None, n_groups: int = 50, detection_mean=1 / 2, detection_alpha=20.0, m_g_gene_level_prior={"mean": 1, "mean_var_ratio": 1.0, "alpha_mean": 3.0}, - N_cells_per_location=8.0, - A_factors_per_location=7.0, - B_groups_per_location=7.0, - N_cells_mean_var_ratio=1.0, + N_cells_per_location: float = 8.0, # float or array + A_factors_per_location: float = 7.0, + B_groups_per_location: float = 7.0, + N_cells_mean_var_ratio: float = 1.0, + N_cells_per_location_alpha_prior: float = None, + A_B_per_location_alpha_prior: float = None, alpha_g_phi_hyp_prior={"alpha": 9.0, "beta": 3.0}, gene_add_alpha_hyp_prior={"alpha": 9.0, "beta": 3.0}, gene_add_mean_hyp_prior={ @@ -94,10 +109,32 @@ def __init__( "beta": 100.0, }, detection_hyp_prior={"mean_alpha": 10.0}, + gene_tech_prior={"mean": 1, "alpha": 200}, + detection_cell_type_prior_alpha=10.0, + use_per_cell_type_normalisation: bool = False, w_sf_mean_var_ratio=5.0, init_vals: Optional[dict] = None, - init_alpha=20.0, - dropout_p=0.0, + init_alpha: float = 20.0, + dropout_p: float = 0.0, + signal_bool: Optional[np.ndarray] = None, + receptor_bool: Optional[np.ndarray] = None, + receptor_bool_b: Optional[np.ndarray] = None, + signal_receptor_mask: Optional[np.ndarray] = None, + receptor_tf_mask: Optional[np.ndarray] = None, + use_learnable_mean_var_ratio: bool = False, + use_independent_prior_on_w_sf: bool = False, + use_proportion_factorisation_prior_on_w_sf: bool = False, + use_n_s_cells_per_location_limit: bool = False, + sliding_window_size: Optional[int] = 0, + amortised_sliding_window_size: Optional[int] = 0, + sliding_window_size_list: Optional[list] = None, + use_normalising_factor_y_s: bool = False, + image_size: Optional[tuple] = None, + use_aggregated_w_sf: bool = False, + use_aggregated_detection_y_s: bool = False, + use_cell_compartments: bool = False, + use_weigted_cnn_weights: bool = False, + n_hidden: int = 256, ): super().__init__() @@ -105,7 +142,9 @@ def __init__( self.n_vars = n_vars self.n_factors = n_factors self.n_batch = n_batch + self.n_extra_categoricals = n_extra_categoricals self.n_groups = n_groups + self.n_hidden = n_hidden self.m_g_gene_level_prior = m_g_gene_level_prior @@ -121,6 +160,35 @@ def __init__( if self.dropout_p is not None: self.dropout = torch.nn.Dropout(p=self.dropout_p) + if signal_bool is not None: + self.register_buffer("signal_bool", torch.tensor(signal_bool.astype("int32"))) + if receptor_bool is not None: + self.register_buffer("receptor_bool", torch.tensor(receptor_bool.astype("int32"))) + if receptor_bool_b is None: + raise ValueError("receptor_bool_b must be provided if receptor_bool is provided") + self.register_buffer("receptor_bool_b", torch.tensor(receptor_bool_b.astype("int32"))) + self.signal_receptor_mask = signal_receptor_mask + self.receptor_tf_mask = receptor_tf_mask 
+ + self.use_learnable_mean_var_ratio = use_learnable_mean_var_ratio + self.use_independent_prior_on_w_sf = use_independent_prior_on_w_sf + self.use_proportion_factorisation_prior_on_w_sf = use_proportion_factorisation_prior_on_w_sf + self.use_n_s_cells_per_location_limit = use_n_s_cells_per_location_limit + self.sliding_window_size = ( + sliding_window_size if sliding_window_size_list is None else max(sliding_window_size_list) + ) + self.amortised_sliding_window_size = amortised_sliding_window_size + self.sliding_window_size_list_exist = sliding_window_size_list is not None + self.sliding_window_size_list = np.array(sliding_window_size_list) + self.use_normalising_factor_y_s = use_normalising_factor_y_s + self.image_size = image_size + self.use_aggregated_w_sf = use_aggregated_w_sf + self.use_aggregated_detection_y_s = use_aggregated_detection_y_s + self.use_cell_compartments = use_cell_compartments + self.use_weigted_cnn_weights = use_weigted_cnn_weights + + self.weights = PyroModule() + if (init_vals is not None) & (type(init_vals) is dict): self.np_init_vals = init_vals for k in init_vals.keys(): @@ -130,6 +198,17 @@ def __init__( factors_per_groups = A_factors_per_location / B_groups_per_location + if n_extra_categoricals is not None: + self.gene_tech_prior = gene_tech_prior + self.register_buffer( + "gene_tech_prior_alpha", + torch.tensor(self.gene_tech_prior["alpha"]), + ) + self.register_buffer( + "gene_tech_prior_beta", + torch.tensor(self.gene_tech_prior["alpha"] / self.gene_tech_prior["mean"]), + ) + self.register_buffer( "detection_hyp_prior_alpha", torch.tensor(self.detection_hyp_prior["alpha"]), @@ -142,6 +221,12 @@ def __init__( "detection_mean_hyp_prior_beta", torch.tensor(self.detection_hyp_prior["mean_alpha"] / self.detection_hyp_prior["mean"]), ) + if use_per_cell_type_normalisation: + self.register_buffer( + "detection_cell_type_prior_alpha", + torch.tensor(detection_cell_type_prior_alpha), + ) + self.use_per_cell_type_normalisation = use_per_cell_type_normalisation # compute hyperparameters from mean and sd self.register_buffer("m_g_mu_hyp", torch.tensor(self.m_g_gene_level_prior["mean"])) @@ -155,10 +240,31 @@ def __init__( self.cell_state_mat = cell_state_mat self.register_buffer("cell_state", torch.tensor(cell_state_mat.T)) + if isinstance(N_cells_per_location, np.ndarray): + assert ( + N_cells_per_location.shape[0] == self.n_obs + ), "N_cells_per_location must have shape (n_obs, 1) or be a scalar" + if isinstance(N_cells_per_location, float) or isinstance(N_cells_per_location, int): + N_cells_per_location = np.array([N_cells_per_location], dtype="float32") self.register_buffer("N_cells_per_location", torch.tensor(N_cells_per_location)) + self.register_buffer("A_factors_per_location", torch.tensor(A_factors_per_location)) self.register_buffer("factors_per_groups", torch.tensor(factors_per_groups)) self.register_buffer("B_groups_per_location", torch.tensor(B_groups_per_location)) - self.register_buffer("N_cells_mean_var_ratio", torch.tensor(N_cells_mean_var_ratio)) + assert (N_cells_per_location_alpha_prior is None) or ( + N_cells_mean_var_ratio is None + ), "N_cells_per_location_alpha_prior and N_cells_mean_var_ratio cannot be provided at the same time" + if N_cells_per_location_alpha_prior is not None: + self.register_buffer( + "N_cells_per_location_alpha_prior", torch.tensor(float(N_cells_per_location_alpha_prior)) + ) + self.N_cells_mean_var_ratio = None + else: + self.register_buffer("N_cells_mean_var_ratio", torch.tensor(float(N_cells_mean_var_ratio))) + 
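        # Note: only the chosen parameterisation (alpha prior vs. mean/variance ratio)
        # is registered as a buffer; the other attribute is set to plain None,
        # presumably so downstream code can branch on which prior was supplied.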
self.N_cells_per_location_alpha_prior = None + if A_B_per_location_alpha_prior is not None: + self.register_buffer("A_B_per_location_alpha_prior", torch.tensor(float(A_B_per_location_alpha_prior))) + else: + self.A_B_per_location_alpha_prior = None self.register_buffer( "alpha_g_phi_hyp_prior_alpha", @@ -191,7 +297,10 @@ def __init__( self.register_buffer("n_groups_tensor", torch.tensor(self.n_groups)) self.register_buffer("ones", torch.ones((1, 1))) + self.register_buffer("ones_1d", torch.ones(1)) + self.register_buffer("zeros", torch.zeros((1, 1))) self.register_buffer("ones_1_n_groups", torch.ones((1, self.n_groups))) + self.register_buffer("ones_1_n_factors", torch.ones((1, self.n_factors))) self.register_buffer("ones_n_batch_1", torch.ones((self.n_batch, 1))) self.register_buffer("eps", torch.tensor(1e-8)) @@ -200,11 +309,63 @@ def _get_fn_args_from_batch(tensor_dict): x_data = tensor_dict[REGISTRY_KEYS.X_KEY] ind_x = tensor_dict["ind_x"].long().squeeze() batch_index = tensor_dict[REGISTRY_KEYS.BATCH_KEY] - return (x_data, ind_x, batch_index), {} - - def create_plates(self, x_data, idx, batch_index): + kwargs = {} + if "positions" in tensor_dict.keys(): + kwargs["positions"] = tensor_dict["positions"] + if "tiles" in tensor_dict.keys(): + kwargs["tiles"] = tensor_dict["tiles"] + if "tiles_unexpanded" in tensor_dict.keys(): + kwargs["tiles_unexpanded"] = tensor_dict["tiles_unexpanded"] + if "in_tissue" in tensor_dict.keys(): + kwargs["in_tissue"] = tensor_dict["in_tissue"].bool() + if "normalising_factor_y_s" in tensor_dict.keys(): + kwargs["normalising_factor_y_s"] = tensor_dict["normalising_factor_y_s"] + if "x_data_normalised" in tensor_dict.keys(): + kwargs["x_data_normalised"] = tensor_dict["x_data_normalised"] + if REGISTRY_KEYS.CAT_COVS_KEY in tensor_dict.keys(): + kwargs["extra_categoricals"] = tensor_dict[REGISTRY_KEYS.CAT_COVS_KEY] + return (x_data, ind_x, batch_index), kwargs + + def create_plates( + self, + x_data, + idx, + batch_index, + tiles: torch.Tensor = None, + tiles_unexpanded: torch.Tensor = None, + positions: torch.Tensor = None, + in_tissue: torch.Tensor = None, + normalising_factor_y_s: Optional[torch.Tensor] = None, + x_data_normalised: Optional[torch.Tensor] = None, + extra_categoricals: Optional[torch.Tensor] = None, + ): return pyro.plate("obs_plate", size=self.n_obs, dim=-2, subsample=idx) + def conv2d_aggregate(self, x_data): + x_data_agg = self.aggregate_conv2d( + x_data, + n_tiles=self.n_tiles, + size=max(self.amortised_sliding_window_size, self.sliding_window_size), + padding="same", + ) + x_data = torch.cat([x_data, x_data_agg], dim=-1) + return torch.log1p(x_data) + + def learnable_conv2d(self, x_data): + x_data = torch.log1p(x_data) + x_data_agg = self.learnable_neighbour_effect_conv2d_nn_layers( + x_data, + n_tiles=self.n_tiles, + name="amortised_sliding_window", + size=max(self.amortised_sliding_window_size, self.sliding_window_size), + n_out=self.n_hidden, + padding="same", + ) + # x_data = self.aggregate_conv2d(x_data, padding="same") + if self.use_concatenated_cnn: + x_data_agg = torch.cat([x_data, x_data_agg], dim=-1) + return x_data_agg + def list_obs_plate_vars(self): """ Create a dictionary with: @@ -218,12 +379,33 @@ def list_obs_plate_vars(self): * values - the dimensions in non-plate axis of each variable (used to construct output layer of encoder network when using amortised inference) """ - + input_transform = torch.log1p + n_in = self.n_vars + + if (self.amortised_sliding_window_size > 0) and (self.sliding_window_size == 0): + 
input_transform = self.learnable_conv2d + if self.use_concatenated_cnn: + n_in = self.n_vars + self.n_hidden + else: + n_in = self.n_hidden + elif (self.amortised_sliding_window_size == 0) and (self.sliding_window_size > 0): + input_transform = self.conv2d_aggregate + n_in = self.n_vars * 2 + elif (self.amortised_sliding_window_size > 0) and (self.sliding_window_size > 0): + input_transform = self.learnable_conv2d + if self.use_concatenated_cnn: + n_in = self.n_vars + self.n_hidden + else: + n_in = self.n_hidden + input = [0, 2] # expression data + (optional) batch index + if self.use_normalising_factor_y_s: + input = ["x_data_normalised", 2] return { "name": "obs_plate", - "input": [0, 2], # expression data + (optional) batch index + "input": input, # expression data + (optional) batch index + "n_in": n_in, "input_transform": [ - torch.log1p, + input_transform, lambda x: x, ], # how to transform input data before passing to NN "input_normalisation": [ @@ -233,112 +415,374 @@ def list_obs_plate_vars(self): "sites": { "n_s_cells_per_location": 1, "b_s_groups_per_location": 1, + "a_s_factors_per_location": 1, "z_sr_groups_factors": self.n_groups, "w_sf": self.n_factors, + "w_sf_proportion": self.n_factors, + "prior_w_sf": self.n_factors, + "cell_compartment_w_sfk": (self.n_factors, int(self.n_cell_compartments - 1)), "detection_y_s": 1, }, } - def forward(self, x_data, idx, batch_index): - obs2sample = one_hot(batch_index, self.n_batch) + def reshape_input_2d(self, x, n_tiles=1, axis=-2, axis_offset=-4): + # conv2d expects 4d input: [batch, channels, height, width] + if self.image_size is None: + sizex = sizey = int(np.sqrt(x.shape[axis] / n_tiles)) + else: + sizex, sizey = self.image_size + # here batch dim has just one element + if n_tiles > 1: + return rearrange(x, "(t p o) g -> t g p o", p=sizex, o=sizey, t=n_tiles) + else: + return rearrange(x, "(p o) g -> g p o", p=sizex, o=sizey).unsqueeze(axis_offset) - obs_plate = self.create_plates(x_data, idx, batch_index) + def reshape_input_2d_inverse(self, x, n_tiles=1, axis=-2, axis_offset=-4): + # conv2d expects 4d input: [batch, channels, height, width] + # here batch dim has just one element + if n_tiles > 1: + return rearrange(x.squeeze(axis_offset), "t g p o -> (t p o) g") + else: + return rearrange(x.squeeze(axis_offset), "g p o -> (p o) g") + + def crop_according_to_valid_padding(self, x, n_tiles=1): + # remove observations that will not be included after convolution with padding='valid' + # reshape to 2d + x = self.reshape_input_2d(x, n_tiles=n_tiles) + # crop to valid observations + indx = np.arange(self.sliding_window_size // 2, x.shape[-2] - (self.sliding_window_size // 2)) + indy = np.arange(self.sliding_window_size // 2, x.shape[-1] - (self.sliding_window_size // 2)) + x = np.take(x, indx, axis=-2) + x = np.take(x, indy, axis=-1) + # reshape back to 1d + x = self.reshape_input_2d_inverse(x, n_tiles=n_tiles) + return x + + def aggregate_conv2d(self, x, n_tiles=1, size=None, padding="valid", mean=False): + # conv2d expects 4d input: [batch, channels, height, width] + input = self.reshape_input_2d(x, n_tiles=n_tiles) + # conv2d expects 4d weights: [out_channels, in_channels/groups, height, width] + if size is None: + size = self.sliding_window_size + weights = torch.ones((x.shape[-1], 1, size, size), device=input.device) + if mean: + weights = weights / torch.tensor(size * size, device=input.device) + x = torch.nn.functional.conv2d( + input, + weights, + padding=padding, + groups=x.shape[-1], + ) + x = 
self.reshape_input_2d_inverse(x, n_tiles=n_tiles) + return x + + def learnable_neighbour_effect_conv2d(self, x, name, n_tiles=1, size=None, n_out=None, padding="valid"): + # pyro version + + # conv2d expects 4d input: [batch, channels, height, width] + input = self.reshape_input_2d(x, n_tiles=n_tiles) + # conv2d expects 4d weights: [out_channels, in_channels/groups, height, width] + if n_out is None: + n_out = x.shape[-1] + groups = x.shape[-1] + else: + groups = 1 + if size is None: + size = self.sliding_window_size + weights_shape = [n_out, int(x.shape[-1] / groups), size, size] + weights = pyro.sample( + f"{name}_weights", + dist.SoftLaplace(self.zeros, self.ones).expand(weights_shape).to_event(len(weights_shape)), + ) # [self.n_factors, self.n_factors] + x = torch.nn.functional.conv2d( + input, + weights, + padding=padding, + groups=groups, + ) + x = self.reshape_input_2d_inverse(x, n_tiles=n_tiles) + return x + + def redistribute_conv2d(self, x, name, n_tiles=1, size=None, n_out=None, padding="same"): + # pyro version + + # conv2d expects 4d input: [batch, channels, height, width] + input = self.reshape_input_2d(x, n_tiles=n_tiles) + # conv2d expects 4d weights: [out_channels, in_channels/groups, height, width] + if n_out is None: + n_out = x.shape[-1] + groups = x.shape[-1] + else: + groups = 1 + if size is None: + size = self.sliding_window_size + weights_shape = [n_out, int(n_out / groups)] + weights = pyro.sample( + f"{name}_weights", + dist.Dirichlet(self.ones_1d.expand([size * size])) + .expand(weights_shape) + .to_event(reinterpreted_batch_ndims=None), + ) # [self.n_factors, self.n_factors] + weights = rearrange(weights, "o g (s z) -> o g s z", s=size, z=size) + x = torch.nn.functional.conv2d( + input, + weights, + padding=padding, + groups=groups, + ) + x = self.reshape_input_2d_inverse(x, n_tiles=n_tiles) + return x - # =====================Gene expression level scaling m_g======================= # - # Explains difference in sensitivity for each gene between single cell and spatial technology - m_g_mean = pyro.sample( - "m_g_mean", - dist.Gamma( - self.m_g_mu_mean_var_ratio_hyp * self.m_g_mu_hyp, - self.m_g_mu_mean_var_ratio_hyp, - ) - .expand([1, 1]) - .to_event(2), - ) # (1, 1) + def learnable_neighbour_effect_conv2d_nn( + self, + x, + name, + n_tiles=1, + size=None, + n_out=None, + padding="valid", + use_weigted_cnn_weights=None, + ): + # pure pytorch version - m_g_alpha_e_inv = pyro.sample( - "m_g_alpha_e_inv", - dist.Exponential(self.m_g_alpha_hyp_mean).expand([1, 1]).to_event(2), - ) # (1, 1) - m_g_alpha_e = self.ones / m_g_alpha_e_inv.pow(2) + if use_weigted_cnn_weights is None: + use_weigted_cnn_weights = self.use_weigted_cnn_weights - m_g = pyro.sample( - "m_g", - dist.Gamma(m_g_alpha_e, m_g_alpha_e / m_g_mean).expand([1, self.n_vars]).to_event(2), # self.m_g_mu_hyp) - ) # (1, n_vars) + # dropout + x = self.dropout(x) - # =====================Cell abundances w_sf======================= # - # factorisation prior on w_sf models similarity in locations - # across cell types f and reflects the absolute scale of w_sf - with obs_plate as ind: - k = "n_s_cells_per_location" - n_s_cells_per_location = pyro.sample( - k, - dist.Gamma( - self.N_cells_per_location * self.N_cells_mean_var_ratio, - self.N_cells_mean_var_ratio, - ), + # conv2d expects 4d input: [batch, channels, height, width] + input = self.reshape_input_2d(x, n_tiles=n_tiles) + # conv2d expects 4d weights: [out_channels, in_channels/groups, height, width] + if n_out is None: + n_out = x.shape[-1] + groups = 
x.shape[-1] + else: + groups = 1 + if size is None: + size = self.sliding_window_size + if False: + if getattr(self.weights, name + "_g", None) is None: + # gene-wise weights + init_param = torch.normal( + torch.full( + size=[1, x.shape[-1], 1, 1], + fill_value=0.0, + device=input.device, + ), + torch.full( + size=[1, x.shape[-1], 1, 1], + fill_value=1.0 / np.sqrt(self.cell_state.shape[1]), + device=input.device, + ), + ) + deep_setattr( + self.weights, + name + "_g", + pyro.nn.PyroParam( + init_param.to(input.device).requires_grad_(True), + constraint=torch.distributions.constraints.positive, + ), + ) + weights_g = deep_getattr(self.weights, name + "_g").to(input.device) + input = input * weights_g + if use_weigted_cnn_weights and (groups == 1): + weights = self.cell_state / ( + self.cell_state.sum(0, keepdims=True) + torch.tensor(1e-4, device=input.device) ) - if ( - self.training_wo_observed - and not self.training_wo_initial - and getattr(self, f"init_val_{k}", None) is not None - ): - # pre-training Variational distribution to initial values - pyro.sample( - k + "_initial", - dist.Gamma( - self.init_alpha_tt, - self.init_alpha_tt / getattr(self, f"init_val_{k}")[ind], + n_repeats = np.ceil(n_out / self.cell_state.shape[0]) + weights = torch.tile(weights, (int(n_repeats), 1))[:n_out, :] + weights = torch.tile(weights.unsqueeze(-1).unsqueeze(-1), (1, 1, size, size)).detach() + if getattr(self.weights, name, None) is None: + init_param = torch.normal( + torch.full( + size=(n_out, self.cell_state.shape[1], size, size), + fill_value=0.0, + device=input.device, + ), + torch.full( + size=(n_out, self.cell_state.shape[1], size, size), + fill_value=1.0 / np.sqrt(size * size * self.cell_state.shape[1]), + device=input.device, + ), + ) + deep_setattr( + self.weights, + name, + pyro.nn.PyroParam( + init_param, # .to(input.device).requires_grad_(True), ), - obs=n_s_cells_per_location, - ) # (self.n_obs, self.n_groups) + ) + weights_ = deep_getattr(self.weights, name).to(input.device) + x = torch.nn.functional.conv2d( + input, + weights * weights_, + padding=padding, + groups=groups, + ) + else: + if getattr(self.weights, name, None) is None: + deep_setattr( + self.weights, + name, + PyroModule[torch.nn.Conv2d]( + in_channels=x.shape[-1], + out_channels=n_out, + kernel_size=size, + padding=padding, + groups=groups, + ).to(input.device), + ) + mod = deep_getattr(self.weights, name).to(input.device) + x = mod(input) + x = self.reshape_input_2d_inverse(x, n_tiles=n_tiles) + if getattr(self.weights, name + "_layer_norm", None) is None: + deep_setattr( + self.weights, + name + "_layer_norm", + PyroModule[torch.nn.LayerNorm](x.shape[-1], elementwise_affine=False).to(input.device), + ) + mod = deep_getattr(self.weights, name + "_layer_norm").to(input.device) + x = mod(x) + x = torch.nn.functional.softplus(x) + return x + + def learnable_neighbour_effect_conv2d_nn_layers(self, x, name, n_tiles=1, size=None, n_out=None, padding="valid"): + # pure pytorch version + # n_layers = 2 + x = self.learnable_neighbour_effect_conv2d_nn( + x=x, name=name, n_tiles=n_tiles, size=size, n_out=n_out, padding=padding + ) + for i in [2]: + x = x + self.learnable_neighbour_effect_conv2d_nn( + x=x, + name=f"{name}_{i}", + n_tiles=n_tiles, + size=size, + n_out=n_out, + padding=padding, + use_weigted_cnn_weights=False, + ) - k = "b_s_groups_per_location" - b_s_groups_per_location = pyro.sample( - k, - dist.Gamma(self.B_groups_per_location, self.ones), + return x + + def n_cells_per_location_prior(self, obs_plate): + if 
len(self.N_cells_per_location) == self.n_obs: + with obs_plate as ind: + N_cells_per_location = self.N_cells_per_location[ind, :] + if self.N_cells_per_location_alpha_prior is not None: + n_s_cells_per_location_prior = self.N_cells_per_location_alpha_prior + else: + N_cells_per_location = self.N_cells_per_location + if self.N_cells_per_location_alpha_prior is not None: + n_s_cells_per_location_prior = pyro.sample( # 1/2 + "n_s_cells_per_location_prior", + dist.Exponential( + self.N_cells_per_location_alpha_prior * self.ones, # 2 + ) + .expand([1, 1]) + .to_event(2), + ) + n_s_cells_per_location_prior = self.ones / n_s_cells_per_location_prior.pow(2) # 4 + if self.N_cells_per_location_alpha_prior is not None: + with obs_plate: + n_s_cells_per_location = pyro.sample( + "n_s_cells_per_location", + dist.Gamma( + n_s_cells_per_location_prior, + n_s_cells_per_location_prior / N_cells_per_location, + ), + ) + else: + with obs_plate: + # prior on number of cells per location + n_s_cells_per_location = pyro.sample( + "n_s_cells_per_location", + dist.Gamma( + N_cells_per_location * self.N_cells_mean_var_ratio, + self.N_cells_mean_var_ratio, + ), + ) + return n_s_cells_per_location + + def a_s_factors_per_location_prior(self, obs_plate): + if self.A_B_per_location_alpha_prior is not None: + a_s_factors_per_location_prior = pyro.sample( # 1/2 + "a_s_factors_per_location_prior", + dist.Exponential( + self.A_B_per_location_alpha_prior * self.ones, # 2 + ) + .expand([1, 1]) + .to_event(2), ) - if ( - self.training_wo_observed - and not self.training_wo_initial - and getattr(self, f"init_val_{k}", None) is not None - ): - # pre-training Variational distribution to initial values - pyro.sample( - k + "_initial", + a_s_factors_per_location_prior = self.ones / a_s_factors_per_location_prior.pow(2) # 4 + with obs_plate: + a_s_factors_per_location = pyro.sample( + "a_s_factors_per_location", dist.Gamma( - self.init_alpha_tt, - self.init_alpha_tt / getattr(self, f"init_val_{k}")[ind], + a_s_factors_per_location_prior, + a_s_factors_per_location_prior / self.A_factors_per_location, ), - obs=b_s_groups_per_location, - ) # (self.n_obs, self.n_groups) + ) + else: + with obs_plate: + # prior on number of cells per location + a_s_factors_per_location = pyro.sample( + "a_s_factors_per_location", + dist.Gamma( + self.A_factors_per_location * self.ones, + self.ones, + ), + ) + return a_s_factors_per_location + + def b_s_groups_per_location_prior(self, obs_plate): + if self.A_B_per_location_alpha_prior is not None: + b_s_groups_per_location_prior = pyro.sample( # 1/2 + "b_s_groups_per_location_prior", + dist.Exponential( + self.A_B_per_location_alpha_prior * self.ones, # 2 + ) + .expand([1, 1]) + .to_event(2), + ) + b_s_groups_per_location_prior = self.ones / b_s_groups_per_location_prior.pow(2) # 4 + with obs_plate: + b_s_groups_per_location = pyro.sample( + "b_s_groups_per_location", + dist.Gamma( + b_s_groups_per_location_prior, + b_s_groups_per_location_prior / self.B_groups_per_location, + ), + ) + else: + with obs_plate: + # prior on number of cells per location + b_s_groups_per_location = pyro.sample( + "b_s_groups_per_location", + dist.Gamma( + self.B_groups_per_location * self.ones, + self.ones, + ), + ) + return b_s_groups_per_location + + def factorisation_prior_on_w_sf(self, obs_plate): + # factorisation prior on w_sf models similarity in locations + # across cell types f and reflects the absolute scale of w_sf + n_s_cells_per_location = self.n_cells_per_location_prior(obs_plate) + b_s_groups_per_location = 
self.b_s_groups_per_location_prior(obs_plate) # cell group loadings shape = self.ones_1_n_groups * b_s_groups_per_location / self.n_groups_tensor rate = self.ones_1_n_groups / (n_s_cells_per_location / b_s_groups_per_location) - with obs_plate as ind: + with obs_plate: k = "z_sr_groups_factors" z_sr_groups_factors = pyro.sample( k, dist.Gamma(shape, rate), # .to_event(1)#.expand([self.n_groups]).to_event(1) ) # (n_obs, n_groups) - if ( - self.training_wo_observed - and not self.training_wo_initial - and getattr(self, f"init_val_{k}", None) is not None - ): - # pre-training Variational distribution to initial values - pyro.sample( - k + "_initial", - dist.Gamma( - self.init_alpha_tt, - self.init_alpha_tt / getattr(self, f"init_val_{k}")[ind], - ), - obs=z_sr_groups_factors, - ) # (self.n_obs, self.n_groups) - k_r_factors_per_groups = pyro.sample( "k_r_factors_per_groups", dist.Gamma(self.factors_per_groups, self.ones).expand([self.n_groups, 1]).to_event(2), @@ -351,31 +795,199 @@ def forward(self, x_data, idx, batch_index): dist.Gamma(c2f_shape, k_r_factors_per_groups).expand([self.n_groups, self.n_factors]).to_event(2), ) # (self.n_groups, self.n_factors) - with obs_plate as ind: - w_sf_mu = z_sr_groups_factors @ x_fr_group2fact + w_sf_mu = z_sr_groups_factors @ x_fr_group2fact - k = "w_sf" - w_sf = pyro.sample( + return w_sf_mu + + def proportion_factorisation_prior_on_w_sf_v1(self, obs_plate): + # factorisation prior on w_sf models similarity in locations + # across cell types f and reflects the absolute scale of w_sf + n_s_cells_per_location = self.n_cells_per_location_prior(obs_plate) + b_s_groups_per_location = self.b_s_groups_per_location_prior(obs_plate) + + # cell group loadings + shape = self.ones_1_n_groups * b_s_groups_per_location / self.n_groups_tensor + rate = self.ones_1_n_groups + with obs_plate: + k = "z_sr_groups_factors" + z_sr_groups_factors = pyro.sample( + k, + dist.Gamma(shape, rate), # .to_event(1)#.expand([self.n_groups]).to_event(1) + ) # (n_obs, n_groups) + + c2f_shape = self.factors_per_groups / self.n_factors_tensor + + x_fr_lambdas_group2fact = pyro.sample( + "x_fr_lambdas_group2fact", + dist.Gamma(c2f_shape, self.ones_1_n_groups.T).expand([self.n_groups, self.n_factors]).to_event(2), + ) # (self.n_groups, self.n_factors) + x_fr_weights_group2fact = pyro.sample( + "x_fr_weights_group2fact", + dist.Normal(self.zeros, self.ones_1_n_groups.T).expand([self.n_groups, self.n_factors]).to_event(2), + ) # (self.n_groups, self.n_factors) + + w_sf_mu = z_sr_groups_factors @ (x_fr_lambdas_group2fact * x_fr_weights_group2fact) + w_sf_mu = w_sf_mu * torch.tensor(100.0, device=w_sf_mu.device) + # print("w_sf_mu 1 mean", w_sf_mu.mean().item(), "w_sf_mu 1 std", w_sf_mu.std().item()) + # print("w_sf_mu 1 min", w_sf_mu.min().item(), "w_sf_mu 1 max", w_sf_mu.max().item()) + + w_sf_mu = torch.softmax(w_sf_mu, dim=-1) + # print("w_sf_mu sum dim -2", w_sf_mu.sum(dim=-2).shape, w_sf_mu.sum(dim=-2)) + # print("w_sf_mu sum dim -1", w_sf_mu.sum(dim=-1).shape, w_sf_mu.sum(dim=-1)) + # print("w_sf_mu mean", w_sf_mu.mean().item(), "w_sf_mu std", w_sf_mu.std().item()) + # print("w_sf_mu min", w_sf_mu.min().item(), "w_sf_mu max", w_sf_mu.max().item()) + w_sf_mu = w_sf_mu * n_s_cells_per_location + + return w_sf_mu + + def proportion_factorisation_prior_on_w_sf(self, obs_plate): + # factorisation prior on w_sf models similarity in locations + # across cell types f and reflects the absolute scale of w_sf + n_s_cells_per_location = self.n_cells_per_location_prior(obs_plate) + 
b_s_groups_per_location = self.b_s_groups_per_location_prior(obs_plate) + + # cell group loadings + shape = self.ones_1_n_groups * b_s_groups_per_location / self.n_groups_tensor + rate = self.ones_1_n_groups + with obs_plate: + k = "z_sr_groups_factors" + z_sr_groups_factors = pyro.sample( k, + dist.Gamma(shape, rate), # .to_event(1)#.expand([self.n_groups]).to_event(1) + ) # (n_obs, n_groups) + + c2f_shape = self.factors_per_groups / self.n_factors_tensor + + x_fr_lambdas_group2fact = pyro.sample( + "x_fr_lambdas_group2fact", + dist.Gamma(c2f_shape, self.ones_1_n_groups.T).expand([self.n_groups, self.n_factors]).to_event(2), + ) # (self.n_groups, self.n_factors) + + w_sf_mu = z_sr_groups_factors @ x_fr_lambdas_group2fact + w_sf_mu = w_sf_mu * torch.tensor(100.0, device=w_sf_mu.device) + # print("w_sf_mu 1 mean", w_sf_mu.mean().item(), "w_sf_mu 1 std", w_sf_mu.std().item()) + # print("w_sf_mu 1 min", w_sf_mu.min().item(), "w_sf_mu 1 max", w_sf_mu.max().item()) + + w_sf_mu = w_sf_mu / w_sf_mu.sum(dim=-1, keepdim=True) + # print("w_sf_mu sum dim -2", w_sf_mu.sum(dim=-2).shape, w_sf_mu.sum(dim=-2)) + # print("w_sf_mu sum dim -1", w_sf_mu.sum(dim=-1).shape, w_sf_mu.sum(dim=-1)) + # print("w_sf_mu mean", w_sf_mu.mean().item(), "w_sf_mu std", w_sf_mu.std().item()) + # print("w_sf_mu min", w_sf_mu.min().item(), "w_sf_mu max", w_sf_mu.max().item()) + w_sf_mu = w_sf_mu * n_s_cells_per_location + + return w_sf_mu, n_s_cells_per_location + + def independent_prior_on_w_sf(self, obs_plate): + n_s_cells_per_location = self.n_cells_per_location_prior(obs_plate) + a_s_factors_per_location = self.a_s_factors_per_location_prior(obs_plate) + + # cell group loadings + shape = self.ones_1_n_factors * a_s_factors_per_location / self.n_factors_tensor + rate = self.ones_1_n_factors / (n_s_cells_per_location / a_s_factors_per_location) + + with obs_plate: + w_sf = pyro.sample( + "prior_w_sf", dist.Gamma( - w_sf_mu * self.w_sf_mean_var_ratio_tensor, - self.w_sf_mean_var_ratio_tensor, + shape, + rate, ), ) # (self.n_obs, self.n_factors) - if ( - self.training_wo_observed - and not self.training_wo_initial - and getattr(self, f"init_val_{k}", None) is not None - ): - # pre-training Variational distribution to initial values - pyro.sample( - k + "_initial", - dist.Gamma( - self.init_alpha_tt, - self.init_alpha_tt / getattr(self, f"init_val_{k}")[ind], - ), - obs=w_sf, - ) # (self.n_obs, self.n_factors) + return w_sf + + def forward( + self, + x_data, + idx, + batch_index, + tiles: torch.Tensor = None, + tiles_unexpanded: torch.Tensor = None, + positions: torch.Tensor = None, + in_tissue: torch.Tensor = None, + normalising_factor_y_s: Optional[torch.Tensor] = None, + x_data_normalised: Optional[torch.Tensor] = None, + extra_categoricals: Optional[torch.Tensor] = None, + ): + if tiles_unexpanded is not None: + tiles_in_use = tiles.sum(0).bool() + obs_in_use = tiles_unexpanded[:, tiles_in_use].sum(1).bool() + idx = idx[obs_in_use] + batch_index = batch_index[obs_in_use] + if positions is not None: + positions = positions[obs_in_use] + # if self.sliding_window_size > 0: + # # remove observations that will not be included after convolution with padding='valid' + # idx = self.crop_according_to_valid_padding(idx.unsqueeze(-1)).squeeze(-1) + # batch_index = self.crop_according_to_valid_padding(batch_index) + # if positions is not None: + # positions = self.crop_according_to_valid_padding(positions) + obs2sample = one_hot(batch_index, self.n_batch).float() + if self.n_extra_categoricals is not None: + obs2extra_categoricals = 
torch.cat( + [ + one_hot( + extra_categoricals[:, i].view((extra_categoricals.shape[0], 1)), + n_cat, + ) + for i, n_cat in enumerate(self.n_extra_categoricals) + ], + dim=1, + ).float() + obs_plate = self.create_plates( + x_data, + idx, + batch_index, + tiles, + tiles_unexpanded, + positions, + in_tissue, + normalising_factor_y_s, + x_data_normalised, + extra_categoricals, + ) + if tiles is not None: + n_tiles = tiles.shape[1] + else: + n_tiles = 1 + if in_tissue is None: + in_tissue = self.ones_1d.expand((x_data.shape[0], 1)).bool() + + # =====================Gene expression level scaling m_g======================= # + # Explains difference in sensitivity for each gene between single cell and spatial technology + m_g_mean = pyro.sample( + "m_g_mean", + dist.Gamma( + self.m_g_mu_mean_var_ratio_hyp * self.m_g_mu_hyp, + self.m_g_mu_mean_var_ratio_hyp, + ) + .expand([1, 1]) + .to_event(2), + ) # (1, 1) + + m_g_alpha_e_inv = pyro.sample( + "m_g_alpha_e_inv", + dist.Exponential(self.m_g_alpha_hyp_mean).expand([1, 1]).to_event(2), + ) # (1, 1) + m_g_alpha_e = self.ones / m_g_alpha_e_inv.pow(2) + + m_g = pyro.sample( + "m_g", + dist.Gamma(m_g_alpha_e, m_g_alpha_e / m_g_mean).expand([1, self.n_vars]).to_event(2), # self.m_g_mu_hyp) + ) # (1, n_vars) + + # =====================Gene-specific multiplicative component ======================= # + # `y_{t, g}` per gene multiplicative effect that explains the difference + # in sensitivity between genes in each technology or covariate effect + if self.n_extra_categoricals is not None: + detection_tech_gene_tg = pyro.sample( + "detection_tech_gene_tg", + dist.Gamma( + self.ones * self.gene_tech_prior_alpha, + self.ones * self.gene_tech_prior_beta, + ) + .expand([np.sum(self.n_extra_categoricals), self.n_vars]) + .to_event(2), + ) # =====================Location-specific detection efficiency ======================= # # y_s with hierarchical mean prior @@ -401,27 +1013,122 @@ def forward(self, x_data, idx, batch_index): dist.Gamma(obs2sample @ detection_hyp_prior_alpha, beta), ) # (self.n_obs, 1) - if ( - self.training_wo_observed - and not self.training_wo_initial - and getattr(self, f"init_val_{k}", None) is not None - ): - # pre-training Variational distribution to initial values - pyro.sample( - k + "_initial", - dist.Gamma( - self.init_alpha_tt, - self.init_alpha_tt / getattr(self, f"init_val_{k}")[ind], - ), - obs=detection_y_s, - ) # (self.n_obs, 1) + if normalising_factor_y_s is not None: + detection_y_s = detection_y_s * normalising_factor_y_s + pyro.deterministic("total_detection_y_s", detection_y_s) + + # if (self.sliding_window_size > 0) and self.use_aggregated_detection_y_s: + # detection_y_s = self.aggregate_conv2d( + # detection_y_s, + # padding="same", + # mean=True, + # size=20, + # n_tiles=n_tiles, + # ) + # pyro.deterministic("aggregated_detection_y_s", detection_y_s) + + # =====================Cell abundances w_sf======================= # + if not self.use_independent_prior_on_w_sf: + n_s_cells_per_location = None + if self.use_proportion_factorisation_prior_on_w_sf: + w_sf_mu, n_s_cells_per_location = self.proportion_factorisation_prior_on_w_sf(obs_plate) + else: + w_sf_mu = self.factorisation_prior_on_w_sf(obs_plate) + if self.use_learnable_mean_var_ratio: + w_sf_mean_var_ratio_hyp = pyro.sample( + "w_sf_mean_var_ratio_hyp", + dist.Gamma(self.w_sf_mean_var_ratio_tensor, self.ones).expand([1, 1]).to_event(2), + ) + w_sf_mean_var_ratio = pyro.sample( + "w_sf_mean_var_ratio", + dist.Exponential(w_sf_mean_var_ratio_hyp).expand([1, 
self.n_factors]).to_event(2), + ) # (self.n_batch, self.n_vars) + w_sf_mean_var_ratio = self.ones / ( + w_sf_mean_var_ratio + torch.tensor(1.0 / 20.0, device=w_sf_mean_var_ratio.device) + ) + else: + w_sf_mean_var_ratio = self.w_sf_mean_var_ratio_tensor + if self.use_n_s_cells_per_location_limit: + with obs_plate: + k = "w_sf_proportion" + w_sf = pyro.sample( + k, + dist.Gamma( + w_sf_mu * w_sf_mean_var_ratio, + w_sf_mean_var_ratio, + ), + ) # (self.n_obs, self.n_factors) + w_sf = w_sf / w_sf.sum(dim=-1, keepdim=True) + w_sf = w_sf * n_s_cells_per_location + pyro.deterministic("w_sf", w_sf) + else: + with obs_plate: + k = "w_sf" + w_sf = pyro.sample( + k, + dist.Gamma( + w_sf_mu * w_sf_mean_var_ratio, + w_sf_mean_var_ratio, + ), + ) # (self.n_obs, self.n_factors) + elif self.use_independent_prior_on_w_sf: + w_sf_mu = self.independent_prior_on_w_sf(obs_plate) + with obs_plate: + k = "w_sf" + w_sf = pyro.deterministic(k, w_sf_mu) # (self.n_obs, self.n_factors) + + if self.use_cell_compartments: + with obs_plate: + k = "cell_compartment_w_sfk" + w_sfk = pyro.sample( + k, + dist.Dirichlet(self.ones_1d.expand((self.n_factors, self.n_cell_compartments))), + ) # ( self.n_factors, self.n_obs, self.n_cell_compartments) + w_sf = torch.einsum("sfk,sf->fsk", w_sfk, w_sf) + + if (self.sliding_window_size > 0) and self.use_aggregated_w_sf: + w_sf = self.redistribute_conv2d( + rearrange(w_sf, "f s k -> s (f k)"), + n_tiles=n_tiles, + name="redistribute", + padding="same", + ) + w_sf = rearrange(w_sf, "s (f k) -> f s k", f=self.n_factors, k=self.n_cell_compartments) + pyro.deterministic("aggregated_w_fsk", w_sf) + if self.use_per_cell_type_normalisation: + per_cell_type_normalisation_f = pyro.sample( + "per_cell_type_normalisation_f", + dist.Gamma(self.detection_cell_type_prior_alpha, self.detection_cell_type_prior_alpha) + .expand([self.n_factors, 1, self.n_cell_compartments]) + .to_event(3), + ) # [self.n_factors, 1, self.n_cell_compartments] + w_sf = w_sf * per_cell_type_normalisation_f + else: + if (self.sliding_window_size > 0) and self.use_aggregated_w_sf: + w_sf = self.redistribute_conv2d( + w_sf, + name="redistribute", + padding="same", + n_tiles=n_tiles, + ) + pyro.deterministic("aggregated_w_sf", w_sf) + if self.use_per_cell_type_normalisation: + per_cell_type_normalisation_f = pyro.sample( + "per_cell_type_normalisation_f", + dist.Gamma(self.detection_cell_type_prior_alpha, self.detection_cell_type_prior_alpha) + .expand([1, self.n_factors]) + .to_event(2), + ) # (1, self.n_factors) + w_sf = w_sf * per_cell_type_normalisation_f # =====================Gene-specific additive component ======================= # # per gene molecule contribution that cannot be explained by # cell state signatures (e.g. 
background, free-floating RNA) s_g_gene_add_alpha_hyp = pyro.sample( "s_g_gene_add_alpha_hyp", - dist.Gamma(self.ones * self.gene_add_alpha_hyp_prior_alpha, self.ones * self.gene_add_alpha_hyp_prior_beta), + dist.Gamma(self.ones * self.gene_add_alpha_hyp_prior_alpha, self.ones * self.gene_add_alpha_hyp_prior_beta) + .expand([1, 1]) + .to_event(2), ) s_g_gene_add_mean = pyro.sample( "s_g_gene_add_mean", @@ -448,7 +1155,9 @@ def forward(self, x_data, idx, batch_index): # =====================Gene-specific overdispersion ======================= # alpha_g_phi_hyp = pyro.sample( "alpha_g_phi_hyp", - dist.Gamma(self.ones * self.alpha_g_phi_hyp_prior_alpha, self.ones * self.alpha_g_phi_hyp_prior_beta), + dist.Gamma(self.ones * self.alpha_g_phi_hyp_prior_alpha, self.ones * self.alpha_g_phi_hyp_prior_beta) + .expand([1, 1]) + .to_event(2), ) alpha_g_inverse = pyro.sample( "alpha_g_inverse", @@ -458,29 +1167,87 @@ def forward(self, x_data, idx, batch_index): # =====================Expected expression ======================= # if not self.training_wo_observed: # expected expression - mu = ((w_sf @ self.cell_state) * m_g + (obs2sample @ s_g_gene_add)) * detection_y_s + if self.use_cell_compartments: + k = "cell_compartment_g_fgk" + cell_compartment_g_fgk = pyro.sample( + k, + dist.Dirichlet(self.ones_1d.expand((1, 1, self.n_cell_compartments))) + .expand([self.n_factors, self.n_vars]) + .to_event(reinterpreted_batch_ndims=None), + ) # ( self.n_factors, self.n_vars, self.n_cell_compartments) + mu = torch.einsum("fsk,fgk,fg->sg", w_sf, cell_compartment_g_fgk, self.cell_state) + else: + mu = w_sf @ self.cell_state + mu = (mu * m_g + (obs2sample @ s_g_gene_add)) * detection_y_s alpha = obs2sample @ (self.ones / alpha_g_inverse.pow(2)) # convert mean and overdispersion to total count and logits # total_count, logits = _convert_mean_disp_to_counts_logits( # mu, alpha, eps=self.eps # ) + if self.n_extra_categoricals is not None: + # gene-specific normalisation for covatiates + mu = mu * (obs2extra_categoricals @ detection_tech_gene_tg) + # =====================DATA likelihood ======================= # # Likelihood (sampling distribution) of data_target & add overdispersion via NegativeBinomial - if self.dropout_p != 0: - x_data = self.dropout(x_data) - with obs_plate: - pyro.sample( - "data_target", - dist.GammaPoisson(concentration=alpha, rate=alpha / mu), - # dist.NegativeBinomial(total_count=total_count, logits=logits), - obs=x_data, - ) + # if self.dropout_p != 0: + # x_data = self.dropout(x_data) + if not self.sliding_window_size_list_exist: + if self.sliding_window_size > 0: + x_data = self.aggregate_conv2d( + x_data, + padding="same", + n_tiles=n_tiles, + size=self.sliding_window_size, + ) + with obs_plate: + pyro.sample( + "data_target", + dist.GammaPoisson(concentration=alpha, rate=alpha / mu), + obs=x_data, + ) + else: + for i, size in enumerate(self.sliding_window_size_list): + if self.sliding_window_size_list[i] > 0: + mu_ = self.aggregate_conv2d( + mu, + padding="same", + n_tiles=n_tiles, + size=size, + ) + alpha_ = alpha * torch.tensor((self.sliding_window_size_list[i] ** 2) / 100, device=mu.device) + # alpha_g_size_effect = pyro.sample( + # f"alpha_g_size_{size}", + # dist.Gamma(self.ones + self.ones, self.ones + self.ones).to_event(2), + # ) + # alpha_ = alpha_ * alpha_g_size_effect + x_data_ = self.aggregate_conv2d( + x_data, + padding="same", + n_tiles=n_tiles, + size=size, + ) + else: + mu_ = mu + alpha_ = alpha * torch.tensor((1**2) / 100, device=mu.device) + x_data_ = x_data + with 
obs_plate, pyro.poutine.mask(mask=in_tissue): + pyro.sample( + f"data_target_{size}", + dist.GammaPoisson(concentration=alpha_, rate=alpha_ / mu_), + obs=x_data_, + ) # =====================Compute mRNA count from each factor in locations ======================= # with obs_plate: - mRNA = w_sf * (self.cell_state * m_g).sum(-1) - pyro.deterministic("u_sf_mRNA_factors", mRNA) + if not self.training: + if self.use_cell_compartments: + mRNA = torch.einsum("fsk,fgk,fg->fsk", w_sf, cell_compartment_g_fgk, self.cell_state * m_g) + pyro.deterministic("u_fsk_mRNA_factors", mRNA) + else: + mRNA = w_sf * (self.cell_state * m_g).sum(-1) + pyro.deterministic("u_sf_mRNA_factors", mRNA) def compute_expected( self, diff --git a/cell2location/models/_cellcomm_model.py b/cell2location/models/_cellcomm_model.py new file mode 100755 index 00000000..db0b74de --- /dev/null +++ b/cell2location/models/_cellcomm_model.py @@ -0,0 +1,350 @@ +import logging +from typing import List, Optional, Union + +import numpy as np +import pandas as pd +from anndata import AnnData +from pyro import clear_param_store +from pyro.infer import Trace_ELBO, TraceEnum_ELBO +from pyro.nn import PyroModule +from scvi import REGISTRY_KEYS +from scvi.data import AnnDataManager +from scvi.data.fields import ( + CategoricalJointObsField, + CategoricalObsField, + LayerField, + NumericalJointObsField, + NumericalObsField, + ObsmField, +) +from scvi.dataloaders import DeviceBackedDataSplitter +from scvi.model.base import BaseModelClass, PyroSampleMixin, PyroSviTrainMixin +from scvi.train import PyroTrainingPlan +from scvi.utils import setup_anndata_dsp + +from cell2location.dataloaders._defined_grid_dataloader import SpatialGridDataSplitter +from cell2location.models._cellcomm_module import CellCommModule +from cell2location.models.base._pyro_base_loc_module import Cell2locationBaseModule +from cell2location.models.base._pyro_mixin import ( + PltExportMixin, + QuantileMixin, + setup_pyro_model, +) + + +class CellCommModel(QuantileMixin, PyroSampleMixin, PyroSviTrainMixin, PltExportMixin, BaseModelClass): + r""" + Cell2location model. User-end model class. See Module class for description of the model (incl. math). + + Parameters + ---------- + adata + spatial AnnData object that has been registered via :func:`~scvi.data.setup_anndata`. + cell_state_df + pd.DataFrame with reference expression signatures for each gene (rows) in each cell type/population (columns). + **model_kwargs + Keyword args for :class:`~cell2location.models.LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelGeneAlphaPyroModel` + + Examples + -------- + TODO add example + >>> + """ + + def __init__( + self, + adata: AnnData, + receptor_abundance_df: pd.DataFrame, + model_class: Optional[PyroModule] = None, + on_load_batch_size: Optional[int] = None, + **model_kwargs, + ): + # in case any other model was created before that shares the same parameter names. 
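+        # (Pyro keeps parameters in a process-wide param store keyed by site name,
+        #  so values left behind by a previously constructed model that reuses the
+        #  same site names would otherwise be picked up silently instead of being
+        #  re-initialised; clearing the store below avoids that.)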
+ clear_param_store() + + super().__init__(adata) + + self.mi_ = [] + + if model_class is None: + model_class = CellCommModule + + self.receptor_abundance_ = receptor_abundance_df + self.n_factors_ = receptor_abundance_df.shape[1] + self.factor_names_ = receptor_abundance_df.columns.values + + if "tiles" in self.adata_manager.data_registry: + on_load_batch_size = 1 + self._data_splitter_cls = SpatialGridDataSplitter + logging.info("Updating data splitter to SpatialGridDataSplitter.") + self.module = Cell2locationBaseModule( + model=model_class, + n_obs=self.summary_stats["n_cells"], + n_vars=self.summary_stats["n_vars"], + n_factors=self.n_factors_, + n_batch=self.summary_stats["n_batch"], + receptor_abundance=self.receptor_abundance_.values.astype("float32"), + on_load_kwargs={ + "batch_size": on_load_batch_size, + "max_epochs": 1, + }, + **model_kwargs, + ) + self._model_summary_string = f'CellComm model with the following params: \nn_labels: {self.n_factors_} \nn_batch: {self.summary_stats["n_batch"]} ' + self.init_params_ = self._get_init_params(locals()) + + @classmethod + @setup_anndata_dsp.dedent + def setup_anndata( + cls, + adata: AnnData, + layer: Optional[str] = None, + signal_abundance_key: Optional[str] = None, + cell_abundance_key: Optional[str] = None, + cell_abundance_lvl2_key: Optional[str] = None, + batch_key: Optional[str] = None, + labels_key: Optional[str] = None, + position_key: Optional[str] = None, + tiles_key: Optional[str] = None, + tiles_unexpanded_key: Optional[str] = None, + in_tissue_key: Optional[str] = None, + categorical_covariate_keys: Optional[List[str]] = None, + continuous_covariate_keys: Optional[List[str]] = None, + **kwargs, + ): + """ + %(summary)s. + + Parameters + ---------- + %(param_layer)s + %(param_batch_key)s + %(param_labels_key)s + %(param_cat_cov_keys)s + %(param_cont_cov_keys)s + """ + setup_method_args = cls._get_setup_method_args(**locals()) + adata.obs["_indices"] = np.arange(adata.n_obs).astype("int64") + anndata_fields = [ + LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), + CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key), + CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key), + CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys), + NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys), + NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"), + ] + if signal_abundance_key is not None: + anndata_fields.append(ObsmField("signal_abundance", signal_abundance_key)) + if cell_abundance_key is not None: + anndata_fields.append(ObsmField("w_sf", cell_abundance_key)) + if cell_abundance_lvl2_key is not None: + anndata_fields.append(ObsmField("w_sf_lvl2", cell_abundance_lvl2_key)) + if position_key is not None: + anndata_fields.append(ObsmField("positions", position_key)) + if tiles_key is not None: + anndata_fields.append(ObsmField("tiles", tiles_key)) + if tiles_unexpanded_key is not None: + anndata_fields.append(ObsmField("tiles_unexpanded", tiles_unexpanded_key)) + if in_tissue_key is not None: + anndata_fields.append(NumericalObsField("in_tissue", in_tissue_key)) + + adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) + adata_manager.register_fields(adata, **kwargs) + cls.register_manager(adata_manager) + + def train( + self, + max_epochs: int = 30000, + batch_size: int = None, + train_size: float = 1, + lr: float = 0.002, + num_particles: int = 1, + scale_elbo: float = "auto", + accelerator: str = "auto", + device: 
Union[int, str] = "auto", + validation_size: Optional[float] = None, + shuffle_set_split: bool = True, + early_stopping: bool = False, + training_plan: Optional[PyroTrainingPlan] = None, + plan_kwargs: Optional[dict] = None, + datasplitter_kwargs: Optional[dict] = None, + **trainer_kwargs, + ): + """Train the model. + + Parameters + ---------- + max_epochs + Number of passes through the dataset. If `None`, defaults to + `np.min([round((20000 / n_cells) * 400), 400])` + %(param_use_gpu)s + %(param_accelerator)s + %(param_device)s + train_size + Size of training set in the range [0.0, 1.0]. + validation_size + Size of the test set. If `None`, defaults to 1 - `train_size`. If + `train_size + validation_size < 1`, the remaining cells belong to a test set. + shuffle_set_split + Whether to shuffle indices before splitting. If `False`, the val, train, and test set are split in the + sequential order of the data according to `validation_size` and `train_size` percentages. + batch_size + Minibatch size to use during training. If `None`, no minibatching occurs and all + data is copied to device (e.g., GPU). + early_stopping + Perform early stopping. Additional arguments can be passed in `**kwargs`. + See :class:`~scvi.train.Trainer` for further options. + lr + Optimiser learning rate (default optimiser is :class:`~pyro.optim.ClippedAdam`). + Specifying optimiser via plan_kwargs overrides this choice of lr. + training_plan + Training plan :class:`~scvi.train.PyroTrainingPlan`. + plan_kwargs + Keyword args for :class:`~scvi.train.PyroTrainingPlan`. Keyword arguments passed to + `train()` will overwrite values present in `plan_kwargs`, when appropriate. + **trainer_kwargs + Other keyword args for :class:`~scvi.train.Trainer`. + """ + # if max_epochs is None: + # max_epochs = get_max_epochs_heuristic(self.adata.n_obs, epochs_cap=1000) + if datasplitter_kwargs is None: + datasplitter_kwargs = dict() + + if issubclass(self._data_splitter_cls, SpatialGridDataSplitter): + self.module.model.n_tiles = batch_size + + plan_kwargs = plan_kwargs if isinstance(plan_kwargs, dict) else {} + if lr is not None and "optim" not in plan_kwargs.keys(): + plan_kwargs.update({"optim_kwargs": {"lr": lr}}) + if getattr(self.module.model, "discrete_variables", None) and (len(self.module.model.discrete_variables) > 0): + plan_kwargs["loss_fn"] = TraceEnum_ELBO(num_particles=num_particles) + else: + plan_kwargs["loss_fn"] = Trace_ELBO(num_particles=num_particles) + if scale_elbo != 1.0: + if scale_elbo == "auto": + scale_elbo = 1.0 / (self.summary_stats["n_cells"] * self.summary_stats["n_vars"]) + plan_kwargs["scale_elbo"] = scale_elbo + + if batch_size is None: + # use data splitter which moves data to GPU once + data_splitter = DeviceBackedDataSplitter( + self.adata_manager, + train_size=train_size, + validation_size=validation_size, + batch_size=batch_size, + accelerator=accelerator, + device=device, + ) + else: + data_splitter = self._data_splitter_cls( + self.adata_manager, + train_size=train_size, + validation_size=validation_size, + shuffle_set_split=shuffle_set_split, + batch_size=batch_size, + **datasplitter_kwargs, + ) + + if training_plan is None: + training_plan = self._training_plan_cls(self.module, **plan_kwargs) + + es = "early_stopping" + trainer_kwargs[es] = early_stopping if es not in trainer_kwargs.keys() else trainer_kwargs[es] + + if "callbacks" not in trainer_kwargs.keys(): + trainer_kwargs["callbacks"] = [] + + # Initialise pyro model with data + from copy import copy + + dl = copy(data_splitter) + 
dl.setup() + dl = dl.train_dataloader() + setup_pyro_model(dl, training_plan) + + runner = self._train_runner_cls( + self, + training_plan=training_plan, + data_splitter=data_splitter, + max_epochs=max_epochs, + accelerator=accelerator, + devices=device, + **trainer_kwargs, + ) + return runner() + + def export_posterior( + self, + adata, + sample_kwargs: Optional[dict] = None, + export_slot: str = "mod", + add_to_obsm: list = ["means", "stds", "q05", "q95"], + use_quantiles: bool = False, + ): + """ + Summarise posterior distribution and export results (cell abundance) to anndata object: + + 1. adata.obsm: Estimated cell abundance as pd.DataFrames for each posterior distribution summary `add_to_obsm`, + posterior mean, sd, 5% and 95% quantiles (['means', 'stds', 'q05', 'q95']). + If export to adata.obsm fails with error, results are saved to adata.obs instead. + 2. adata.uns: Posterior of all parameters, model name, date, + cell type names ('factor_names'), obs and var names. + + Parameters + ---------- + adata + anndata object where results should be saved + sample_kwargs + arguments for self.sample_posterior (generating and summarising posterior samples), namely: + num_samples - number of samples to use (Default = 1000). + batch_size - data batch size (keep low enough to fit on GPU, default 2048). + use_gpu - use gpu for generating samples? + export_slot + adata.uns slot where to export results + add_to_obsm + posterior distribution summary to export in adata.obsm (['means', 'stds', 'q05', 'q95']). + use_quantiles + compute quantiles directly (True, more memory efficient) or use samples (False, default). + If True, means and stds cannot be computed so are not exported and returned. + Returns + ------- + + """ + + sample_kwargs = sample_kwargs if isinstance(sample_kwargs, dict) else dict() + + # get posterior distribution summary + if use_quantiles: + add_to_obsm = [i for i in add_to_obsm if (i not in ["means", "stds"]) and ("q" in i)] + if len(add_to_obsm) == 0: + raise ValueError("No quantiles to export - please add add_to_obsm=['q05', 'q50', 'q95'].") + self.samples = dict() + for i in add_to_obsm: + q = float(f"0.{i[1:]}") + self.samples[f"post_sample_{i}"] = self.posterior_quantile(q=q, **sample_kwargs) + else: + # generate samples from posterior distributions for all parameters + # and compute mean, 5%/95% quantiles and standard deviation + self.samples = self.sample_posterior(**sample_kwargs) + + # export posterior distribution summary for all parameters and + # annotation (model, date, var, obs and cell type names) to anndata object + adata.uns[export_slot] = self._export2adata(self.samples) + + # add estimated cell abundance as dataframe to obsm in anndata + # first convert np.arrays to pd.DataFrames with cell type and observation names + # data frames contain mean, 5%/95% quantiles and standard deviation, denoted by a prefix + for k in add_to_obsm: + sample_df = self.sample2df_obs( + self.samples, + site_name="w_sf_cell_comm", + summary_name=k, + name_prefix="predicted_cell_abundance", + ) + try: + adata.obsm[f"{k}_cell_abundance_w_sf"] = sample_df.loc[adata.obs.index, :] + except ValueError: + # Catching weird error with obsm: `ValueError: value.index does not match parent’s axis 1 names` + adata.obs[sample_df.columns] = sample_df.loc[adata.obs.index, :] + + return adata diff --git a/cell2location/models/_cellcomm_module.py b/cell2location/models/_cellcomm_module.py new file mode 100755 index 00000000..ab1b0f82 --- /dev/null +++ b/cell2location/models/_cellcomm_module.py @@ 
-0,0 +1,497 @@ +from typing import Optional + +import numpy as np +import pyro +import pyro.distributions as dist +import torch +from einops import rearrange +from pyro.infer.autoguide.utils import deep_getattr, deep_setattr +from pyro.nn import PyroModule +from scipy.sparse import coo_matrix +from scvi import REGISTRY_KEYS +from scvi.nn import one_hot + +from cell2location.nn.CellCommunicationToEffectNN import CellCommunicationToTfActivityNN + + +class CellCommModule(PyroModule): + r""" + Cell2location models the elements of :math:`D` as Negative Binomial distributed, + given an unobserved gene expression level (rate) :math:`mu` and a gene- and batch-specific + over-dispersion parameter :math:`\alpha_{e,g}` which accounts for unexplained variance: + + .. math:: + D_{s,g} \sim \mathtt{NB}(\mu_{s,g}, \alpha_{e,g}) + + Here, :math:`w_{s,f}` denotes regression weight of each reference signature :math:`f` at location :math:`s`, which can be interpreted as the expected number of cells at location :math:`s` that express reference signature :math:`f`; + :math:`g_{f,g}` denotes the reference signatures of cell types :math:`f` of each gene :math:`g`, `cell_state_df` input ; + """ + use_pathway_interaction_effect = True + dropout_rate = 0.0 + min_distance = 25.0 + r_l_affinity_alpha_prior = 10.0 + record_sr_occupancy = False + use_spatial_receptor_info_remove_sp_signal = True + + def __init__( + self, + n_obs, + n_vars, + n_factors, + n_batch, + detection_mean=1 / 2, + detection_alpha=20.0, + m_g_gene_level_prior={"mean": 1, "mean_var_ratio": 1.0, "alpha_mean": 3.0}, + alpha_g_phi_hyp_prior={"alpha": 9.0, "beta": 3.0}, + detection_hyp_prior={"mean_alpha": 10.0}, + w_sf_mean_var_ratio=5.0, + init_vals: Optional[dict] = None, + init_alpha: float = 20.0, + dropout_p: float = 0.0, + receptor_abundance: Optional[np.ndarray] = None, + use_spatial_receptor_info: bool = False, + per_cell_type_normalisation: Optional[np.ndarray] = None, + signal_receptor_mask: Optional[np.ndarray] = None, + receptor_tf_mask: Optional[np.ndarray] = None, + use_learnable_mean_var_ratio: bool = False, + average_distance_prior: float = 50.0, + distances: Optional[coo_matrix] = None, + n_hidden: int = 256, + use_cell_abundance_normalisation: bool = True, + use_alpha_likelihood: bool = True, + use_normal_likelihood: bool = False, + fixed_w_sf_mean_var_ratio: Optional[float] = None, + use_non_negative_weights: bool = False, + n_pathways: int = 50, + use_diffusion_domain: bool = False, + use_global_cell_abundance_model: bool = False, + use_max_distance_threshold: bool = False, + ): + super().__init__() + + self.n_obs = n_obs + self.n_vars = n_vars + self.n_factors = n_factors + self.n_batch = n_batch + self.n_hidden = n_hidden + self.n_pathways = n_pathways + + self.m_g_gene_level_prior = m_g_gene_level_prior + + self.alpha_g_phi_hyp_prior = alpha_g_phi_hyp_prior + self.w_sf_mean_var_ratio = w_sf_mean_var_ratio + detection_hyp_prior["mean"] = detection_mean + detection_hyp_prior["alpha"] = detection_alpha + self.detection_hyp_prior = detection_hyp_prior + + self.dropout_p = dropout_p + if self.dropout_p is not None: + self.dropout = torch.nn.Dropout(p=self.dropout_p) + + if receptor_abundance is not None: + self.register_buffer("receptor_abundance", torch.tensor(receptor_abundance.astype("float32"))) + self.use_spatial_receptor_info = use_spatial_receptor_info + if per_cell_type_normalisation is not None: + self.register_buffer( + "per_cell_type_normalisation", torch.tensor(per_cell_type_normalisation.astype("float32")) + ) + 
self.signal_receptor_mask = signal_receptor_mask + self.receptor_tf_mask = receptor_tf_mask + + self.use_learnable_mean_var_ratio = use_learnable_mean_var_ratio + self.average_distance_prior = average_distance_prior + if distances is not None: + distances = coo_matrix(distances).astype("float32") + self.distances_scipy = distances + self.register_buffer( + "distances", + torch.sparse_coo_tensor( + torch.tensor(np.array([distances.row, distances.col])), + torch.tensor(distances.data.astype("float32")), + distances.shape, + ), + ) + self.use_cell_abundance_normalisation = use_cell_abundance_normalisation + self.use_alpha_likelihood = use_alpha_likelihood + self.use_normal_likelihood = use_normal_likelihood + self.fixed_w_sf_mean_var_ratio = fixed_w_sf_mean_var_ratio + self.use_non_negative_weights = use_non_negative_weights + self.use_diffusion_domain = use_diffusion_domain + self.use_global_cell_abundance_model = use_global_cell_abundance_model + self.use_max_distance_threshold = use_max_distance_threshold + + self.weights = PyroModule() + + if (init_vals is not None) & (type(init_vals) is dict): + self.np_init_vals = init_vals + for k in init_vals.keys(): + self.register_buffer(f"init_val_{k}", torch.tensor(init_vals[k])) + self.init_alpha = init_alpha + self.register_buffer("init_alpha_tt", torch.tensor(self.init_alpha)) + + self.register_buffer( + "detection_hyp_prior_alpha", + torch.tensor(self.detection_hyp_prior["alpha"]), + ) + self.register_buffer( + "detection_mean_hyp_prior_alpha", + torch.tensor(self.detection_hyp_prior["mean_alpha"]), + ) + self.register_buffer( + "detection_mean_hyp_prior_beta", + torch.tensor(self.detection_hyp_prior["mean_alpha"] / self.detection_hyp_prior["mean"]), + ) + + # compute hyperparameters from mean and sd + self.register_buffer("m_g_mu_hyp", torch.tensor(self.m_g_gene_level_prior["mean"])) + self.register_buffer( + "m_g_mu_mean_var_ratio_hyp", + torch.tensor(self.m_g_gene_level_prior["mean_var_ratio"]), + ) + + self.register_buffer("m_g_alpha_hyp_mean", torch.tensor(self.m_g_gene_level_prior["alpha_mean"])) + + self.register_buffer( + "alpha_g_phi_hyp_prior_alpha", + torch.tensor(self.alpha_g_phi_hyp_prior["alpha"]), + ) + self.register_buffer( + "alpha_g_phi_hyp_prior_beta", + torch.tensor(self.alpha_g_phi_hyp_prior["beta"]), + ) + + self.register_buffer("w_sf_mean_var_ratio_tensor", torch.tensor(self.w_sf_mean_var_ratio)) + + self.register_buffer("ones", torch.ones((1, 1))) + self.register_buffer("ones_1d", torch.ones(1)) + self.register_buffer("zeros", torch.zeros((1, 1))) + self.register_buffer("ones_1_n_factors", torch.ones((1, self.n_factors))) + self.register_buffer("ones_n_batch_1", torch.ones((self.n_batch, 1))) + self.register_buffer("eps", torch.tensor(1e-8)) + + @staticmethod + def _get_fn_args_from_batch(tensor_dict): + signal_abundance = tensor_dict["signal_abundance"] + w_sf = tensor_dict["w_sf"] + ind_x = tensor_dict["ind_x"].long().squeeze() + batch_index = tensor_dict[REGISTRY_KEYS.BATCH_KEY] + kwargs = {} + if "positions" in tensor_dict.keys(): + kwargs["positions"] = tensor_dict["positions"] + if "tiles" in tensor_dict.keys(): + kwargs["tiles"] = tensor_dict["tiles"] + if "tiles_unexpanded" in tensor_dict.keys(): + kwargs["tiles_unexpanded"] = tensor_dict["tiles_unexpanded"] + if "in_tissue" in tensor_dict.keys(): + kwargs["in_tissue"] = tensor_dict["in_tissue"].bool() + if "w_sf_lvl2" in tensor_dict.keys(): + kwargs["w_sf_lvl2"] = tensor_dict["w_sf_lvl2"] + return (signal_abundance, w_sf, ind_x, batch_index), kwargs + + def 
create_plates( + self, + signal_abundance, + w_sf, + idx, + batch_index, + tiles: torch.Tensor = None, + tiles_unexpanded: torch.Tensor = None, + positions: torch.Tensor = None, + in_tissue: torch.Tensor = None, + w_sf_lvl2: torch.Tensor = None, + ): + if tiles_unexpanded is not None: + tiles_in_use = (tiles.mean(0) > torch.tensor(0.99, device=tiles.device)).bool() + obs_in_use = (tiles_unexpanded[:, tiles_in_use].sum(1) > torch.tensor(0.0, device=tiles.device)).bool() + idx = idx[obs_in_use] + return pyro.plate("obs_plate", size=self.n_obs, dim=-2, subsample=idx) + + def list_obs_plate_vars(self): + """ + Create a dictionary with: + + 1. "name" - the name of observation/minibatch plate; + 2. "input" - indexes of model args to provide to encoder network when using amortised inference; + 3. "sites" - dictionary with + + * keys - names of variables that belong to the observation plate (used to recognise + and merge posterior samples for minibatch variables) + * values - the dimensions in non-plate axis of each variable (used to construct output + layer of encoder network when using amortised inference) + """ + input_transform = torch.log1p + n_in = self.n_vars + + return { + "name": "obs_plate", + "input": [0, 2], # expression data + (optional) batch index + "n_in": n_in, + "input_transform": [ + input_transform, + lambda x: x, + ], # how to transform input data before passing to NN + "input_normalisation": [ + False, + False, + ], # whether to normalise input data before passing to NN + "sites": {}, + } + + def get_cell_communication_module( + self, + name, + output_transform="softplus", + n_out: int = 1, + average_distance_prior: float = 50.0, + ): + # create module if it doesn't exist + if getattr(self.weights, name, None) is None: + deep_setattr( + self.weights, + name, + CellCommunicationToTfActivityNN( + name=name, + mode="signal_receptor_tf_effect_spatial", + output_transform=output_transform, + n_tfs=self.n_factors, + n_signals=self.signal_receptor_mask.shape[0], + n_receptors=self.signal_receptor_mask.shape[1], + n_out=n_out, + n_pathways=self.n_pathways, + signal_receptor_mask=self.signal_receptor_mask, # tells which receptors can bind which ligands + receptor_tf_mask=self.receptor_tf_mask, # tells which receptors can influence which TF (eg nuclear receptor = TF) + dropout_rate=self.dropout_rate, + use_horseshoe_prior=True, + use_gamma_horseshoe_prior=False, + weights_prior_tau=1.0, + use_pathway_interaction_effect=self.use_pathway_interaction_effect, + average_distance_prior=average_distance_prior, + use_non_negative_weights=self.use_non_negative_weights, + r_l_affinity_alpha_prior=self.r_l_affinity_alpha_prior, + use_global_cell_abundance_model=self.use_global_cell_abundance_model, + ), + ) + # get module + return deep_getattr(self.weights, name) + + def cell_comm_effect( + self, + signal_abundance, + receptor_abundance, + distances, + tiles, + obs_plate, + average_distance_prior=50.0, + obs_in_use=None, + w_sf=None, + use_diffusion_domain=False, + ): + # get module + module = self.get_cell_communication_module( + name="lr2abundance", + output_transform="softplus", + n_out=1, + average_distance_prior=average_distance_prior, + ) + # compute LR occupancy + if self.use_max_distance_threshold: + max_distance_threshold = average_distance_prior * 10.0 + else: + max_distance_threshold = None + bound_receptor_abundance_src = module.signal_receptor_occupancy_spatial( + signal_abundance, + receptor_abundance, + distances, + tiles, + obs_plate, + obs_in_use=obs_in_use, + w_sf=w_sf, + 
use_diffusion_domain=use_diffusion_domain, + max_distance_threshold=max_distance_threshold, + ) + # compute cell abundance prediction + w_sf_mu = module.signal_receptor_tf_effect_spatial( + bound_receptor_abundance_src, + ) + return w_sf_mu, bound_receptor_abundance_src + + def forward( + self, + signal_abundance, + w_sf, + idx, + batch_index, + tiles: torch.Tensor = None, + tiles_unexpanded: torch.Tensor = None, + positions: torch.Tensor = None, + in_tissue: torch.Tensor = None, + w_sf_lvl2: torch.Tensor = None, + ): + obs_plate = self.create_plates( + signal_abundance=signal_abundance, + w_sf=w_sf, + idx=idx, + batch_index=batch_index, + tiles=tiles, + tiles_unexpanded=tiles_unexpanded, + positions=positions, + in_tissue=in_tissue, + ) + obs_in_use = None + if tiles_unexpanded is not None: + tiles_in_use = (tiles.mean(0) > torch.tensor(0.99, device=tiles.device)).bool() + obs_in_use = (tiles_unexpanded[:, tiles_in_use].sum(1) > torch.tensor(0.0, device=tiles.device)).bool() + batch_index = batch_index[obs_in_use] + obs2sample = one_hot(batch_index, self.n_batch).float() + + if getattr(self, "distances", None) is not None: + distances = self.distances + elif positions is not None: + # compute distance using positions [observations, 2] + distances = ( + (positions.unsqueeze(1) - positions.unsqueeze(0)) # [observations, 1, 2] # [1, observations, 2] + .pow(2) + .sum(-1) + .sqrt() + ) + torch.tensor(self.min_distance, device=positions.device) + + # =====================Cell abundances w_sf======================= # + if self.use_spatial_receptor_info: + if self.use_spatial_receptor_info_remove_sp_signal: + receptor_abundance_ = torch.einsum( + "fr,cf,f -> cr", + self.receptor_abundance.T, + w_sf, + self.per_cell_type_normalisation, + ) + receptor_abundance_norm = torch.einsum( + "fr,r -> fr", + self.receptor_abundance.T, + self.receptor_abundance.sum(-1), + ) + receptor_abundance = torch.einsum( + "fr,cr -> cfr", + receptor_abundance_norm, + receptor_abundance_, + ) + else: + receptor_abundance = torch.einsum( + "fr,cf,f -> cfr", + self.receptor_abundance.T, + w_sf, + self.per_cell_type_normalisation, + ) + else: + receptor_abundance = self.receptor_abundance.T + w_sf_mu_cell_comm, bound_receptor_abundance_src = self.cell_comm_effect( + signal_abundance=signal_abundance, + receptor_abundance=receptor_abundance, + distances=distances, + tiles=tiles, + average_distance_prior=self.average_distance_prior, + obs_plate=obs_plate, + obs_in_use=obs_in_use, + w_sf=w_sf, + use_diffusion_domain=self.use_diffusion_domain, + ) + if not self.training and self.record_sr_occupancy: + with obs_plate: + # {sr pair, location * cell type} -> {sr pair, location, cell type} + bound_receptor_abundance_src = rearrange( + bound_receptor_abundance_src, + "r (c f) -> f c r", + f=self.receptor_abundance.shape[-1], + ).sum(-3) + if obs_in_use is not None: + bound_receptor_abundance_src = bound_receptor_abundance_src[obs_in_use, :] + pyro.deterministic( + "bound_receptor_abundance_sr_c", + bound_receptor_abundance_src, + ) + if self.fixed_w_sf_mean_var_ratio is not None: + w_sf_mean_var_ratio = torch.tensor(self.fixed_w_sf_mean_var_ratio, device=w_sf_mu_cell_comm.device) + else: + w_sf_mean_var_ratio_hyp = pyro.sample( + "w_sf_mean_var_ratio_hyp_lik", + dist.Gamma(self.w_sf_mean_var_ratio_tensor, self.ones).expand([1, 1]).to_event(2), + ) # prior mean 5.0 + w_sf_mean_var_ratio = pyro.sample( + "w_sf_mean_var_ratio_lik", + dist.Exponential(w_sf_mean_var_ratio_hyp).expand([1, self.n_factors]).to_event(2), + ) # 
(self.n_batch, self.n_vars) prior mean 0.2 + if self.use_normal_likelihood: + w_sf_mean_var_ratio = w_sf_mean_var_ratio + torch.tensor(1.0 / 50.0, device=w_sf_mean_var_ratio.device) + else: + w_sf_mean_var_ratio = self.ones / ( + w_sf_mean_var_ratio + torch.tensor(1.0 / 50.0, device=w_sf_mean_var_ratio.device) + ) + torch.tensor(5.0, device=w_sf_mean_var_ratio.device) + if tiles_unexpanded is not None: + w_sf_mu_cell_comm = w_sf_mu_cell_comm[obs_in_use] + w_sf = w_sf[obs_in_use] + if w_sf_lvl2 is not None: + w_sf_lvl2 = w_sf_lvl2[obs_in_use] + if self.use_cell_abundance_normalisation: + # =====================Location-specific detection efficiency ======================= # + # y_s with hierarchical mean prior + detection_mean_y_e = pyro.sample( + "detection_mean_y_e", + dist.Gamma( + self.ones * self.detection_mean_hyp_prior_alpha, + self.ones * self.detection_mean_hyp_prior_beta, + ) + .expand([self.n_batch, 1]) + .to_event(2), + ) + detection_hyp_prior_alpha = pyro.deterministic( + "detection_hyp_prior_alpha", + self.ones_n_batch_1 * self.detection_hyp_prior_alpha, + ) + + beta = (obs2sample @ detection_hyp_prior_alpha) / (obs2sample @ detection_mean_y_e) + with obs_plate: + k = "detection_y_s" + detection_y_s = pyro.sample( + k, + dist.Gamma(obs2sample @ detection_hyp_prior_alpha, beta), + ) # (self.n_obs, 1) + w_sf_mu_cell_comm = w_sf_mu_cell_comm * detection_y_s + + if w_sf_lvl2 is not None: + with obs_plate: + pyro.deterministic("w_sf_wo_lvl2_cell_comm", w_sf_mu_cell_comm) # (self.n_obs, self.n_factors) + w_sf_mu_cell_comm = w_sf_mu_cell_comm * w_sf_lvl2 + + with obs_plate: + k = "w_sf" + pyro.deterministic(f"{k}_cell_comm", w_sf_mu_cell_comm) # (self.n_obs, self.n_factors) + + if self.use_alpha_likelihood: + with obs_plate: + pyro.sample( + f"{k}_obs", + dist.Gamma( + w_sf_mean_var_ratio, + w_sf_mean_var_ratio / w_sf_mu_cell_comm, + ), + obs=w_sf, + ) # (self.n_obs, self.n_factors) + elif self.use_normal_likelihood: + with obs_plate: + pyro.sample( + f"{k}_obs", + dist.Normal( + w_sf_mu_cell_comm, + w_sf_mean_var_ratio, + ), + obs=w_sf, + ) # (self.n_obs, self.n_factors) + else: + with obs_plate: + pyro.sample( + f"{k}_obs", + dist.Gamma( + w_sf_mu_cell_comm * w_sf_mean_var_ratio, + w_sf_mean_var_ratio, + ), + obs=w_sf, + ) # (self.n_obs, self.n_factors) diff --git a/cell2location/models/base/_pyro_base_loc_module.py b/cell2location/models/base/_pyro_base_loc_module.py index e03b6bcb..923c644a 100755 --- a/cell2location/models/base/_pyro_base_loc_module.py +++ b/cell2location/models/base/_pyro_base_loc_module.py @@ -29,11 +29,11 @@ def __init__( amortised: bool = False, encoder_mode: Literal["single", "multiple", "single-multiple"] = "single", encoder_kwargs: Optional[dict] = None, - data_transform="log1p", create_autoguide_kwargs: Optional[dict] = None, + on_load_kwargs: Optional[dict] = None, **kwargs, ): - super().__init__() + super().__init__(on_load_kwargs=on_load_kwargs) self.hist = [] self._model = model(**kwargs) @@ -45,7 +45,6 @@ def __init__( model=self.model, amortised=self.is_amortised, encoder_kwargs=encoder_kwargs, - data_transform=data_transform, encoder_mode=encoder_mode, init_loc_fn=self.init_to_value, n_cat_list=[kwargs["n_batch"]], diff --git a/cell2location/models/base/_pyro_base_reference_module.py b/cell2location/models/base/_pyro_base_reference_module.py index 22413947..bd67b11a 100755 --- a/cell2location/models/base/_pyro_base_reference_module.py +++ b/cell2location/models/base/_pyro_base_reference_module.py @@ -1,4 +1,4 @@ -from typing import Literal +from 
typing import Literal, Optional from scvi.module.base import PyroBaseModuleClass @@ -12,7 +12,7 @@ def __init__( amortised: bool = False, encoder_mode: Literal["single", "multiple", "single-multiple"] = "single", encoder_kwargs=None, - data_transform="log1p", + create_autoguide_kwargs: Optional[dict] = None, **kwargs, ): """ @@ -36,14 +36,17 @@ def __init__( self._model = model(**kwargs) self._amortised = amortised + if create_autoguide_kwargs is None: + create_autoguide_kwargs = dict() + self._guide = self._create_autoguide( model=self.model, amortised=self.is_amortised, encoder_kwargs=encoder_kwargs, - data_transform=data_transform, encoder_mode=encoder_mode, init_loc_fn=self.init_to_value, n_cat_list=[kwargs["n_batch"]], + **create_autoguide_kwargs, ) self._get_fn_args_from_batch = self._model._get_fn_args_from_batch diff --git a/cell2location/models/base/_pyro_mixin.py b/cell2location/models/base/_pyro_mixin.py index 8eb7cd59..902e217a 100755 --- a/cell2location/models/base/_pyro_mixin.py +++ b/cell2location/models/base/_pyro_mixin.py @@ -1,4 +1,5 @@ import gc +import inspect import logging from datetime import date from functools import partial @@ -20,14 +21,28 @@ from scvi.model._utils import parse_device_args from scvi.module.base import PyroBaseModuleClass from scvi.train import PyroTrainingPlan as PyroTrainingPlan_scvi +from scvi.utils import track from ...distributions.AutoAmortisedNormalMessenger import ( AutoAmortisedHierarchicalNormalMessenger, + AutoNormalMessenger, ) logger = logging.getLogger(__name__) +def setup_pyro_model(dataloader, pl_module): + """Way to warmup Pyro Model and Guide in an automated way. + + Setup occurs before any device movement, so params are iniitalized on CPU. + """ + for tensors in dataloader: + tens = {k: t.to(pl_module.device) for k, t in tensors.items()} + args, kwargs = pl_module.module._get_fn_args_from_batch(tens) + pl_module.module.guide(*args, **kwargs) + break + + def init_to_value(site=None, values={}, init_fn=init_to_mean): if site is None: return partial(init_to_value, values=values) @@ -37,11 +52,99 @@ def init_to_value(site=None, values={}, init_fn=init_to_mean): return init_fn(site) +def expand_zeros_along_dim(tensor, size, dim): + shape = np.array(tensor.shape) + shape[dim] = size + return np.zeros(shape) + + +def complete_tensor_along_dim(tensor, indices, dim, value, mode="put"): + shape = value.shape + shape = np.ones(len(shape)) + shape[dim] = len(indices) + shape = shape.astype(int) + indices = indices.reshape(shape) + if mode == "take": + return np.take_along_axis(arr=tensor, indices=indices, axis=dim) + np.put_along_axis(arr=tensor, indices=indices, values=value, axis=dim) + return tensor + + +def _complete_full_tensors_using_plates( + means_global, + means, + plate_dict, + obs_plate_sites, + plate_indices, + plate_dim, + named_dims, +): + # complete full sized tensors with minibatch values given minibatch indices + for k in means_global.keys(): + # find which and how many plates contain this tensor + plates = [plate for plate in plate_dict.keys() if k in obs_plate_sites[plate].keys()] + if len(plates) == 1: + # if only one plate contains this tensor, complete it using the plate indices + if k in named_dims.keys(): + dim = named_dims[k] + else: + dim = plate_dim[plates[0]] + means_global[k] = complete_tensor_along_dim( + means_global[k], + plate_indices[plates[0]], + dim, + means[k], + ) + elif len(plates) == 2: + # subset data to index for plate 0 and fill index for plate 1 + if k in named_dims.keys() and (k in 
obs_plate_sites[list(plate_dict.keys())[0]].keys()): + dim0 = named_dims[k] + else: + dim0 = plate_dim[plates[0]] + means_global_k = complete_tensor_along_dim( + means_global[k], + plate_indices[plates[0]], + dim0, + means[k], + mode="take", + ) + if k in named_dims.keys() and (k in obs_plate_sites[list(plate_dict.keys())[1]].keys()): + dim1 = named_dims[k] + else: + dim1 = plate_dim[plates[1]] + means_global_k = complete_tensor_along_dim( + means_global_k, + plate_indices[plates[1]], + dim1, + means[k], + ) + # fill index for plate 0 in the full data + means_global[k] = complete_tensor_along_dim( + means_global[k], + plate_indices[plates[0]], + dim0, + means_global_k, + ) + # TODO add a test - observed variables should be identical if this code works correctly + # This code works correctly but the test needs to be added eventually + # np.allclose( + # samples['data_chromatin'].squeeze(-1).T, + # mod_reg.adata_manager.get_from_registry('X')[ + # :, ~mod_reg.adata_manager.get_from_registry('gene_bool').ravel() + # ].toarray() + # ) + else: + NotImplementedError( + f"Posterior sampling/mean/median/quantile not supported for variables with > 2 plates: {k} has {len(plates)}" + ) + return means_global + + class AutoGuideMixinModule: """ This mixin class provides methods for: - - initialising standard AutoNormal guides + - initialising standard AutoNormalMessenger guides - initialising amortised guides (AutoNormalEncoder) - initialising amortised guides with special additional inputs @@ -52,12 +155,11 @@ def _create_autoguide( model, amortised, encoder_kwargs, - data_transform, encoder_mode, init_loc_fn=init_to_mean(fallback=init_to_feasible), n_cat_list: list = [], encoder_instance=None, - guide_class=AutoNormal, + guide_class=AutoNormalMessenger, guide_kwargs: Optional[dict] = None, ): if guide_kwargs is None: @@ -83,44 +185,10 @@ def _create_autoguide( else: encoder_kwargs = encoder_kwargs if isinstance(encoder_kwargs, dict) else dict() n_hidden = encoder_kwargs["n_hidden"] if "n_hidden" in encoder_kwargs.keys() else 200 - if data_transform is None: - pass - elif isinstance(data_transform, np.ndarray): - # add extra info about gene clusters as input to NN - self.register_buffer("gene_clusters", torch.tensor(data_transform.astype("float32"))) - n_in = model.n_vars + data_transform.shape[1] - data_transform = self._data_transform_clusters() - elif data_transform == "log1p": - # use simple log1p transform - data_transform = torch.log1p - n_in = self.model.n_vars - elif ( - isinstance(data_transform, dict) - and "var_std" in list(data_transform.keys()) - and "var_mean" in list(data_transform.keys()) - ): - # use data transform by scaling - n_in = model.n_vars - self.register_buffer( - "var_mean", - torch.tensor(data_transform["var_mean"].astype("float32").reshape((1, n_in))), - ) - self.register_buffer( - "var_std", - torch.tensor(data_transform["var_std"].astype("float32").reshape((1, n_in))), - ) - data_transform = self._data_transform_scale() - else: - # use custom data transform - data_transform = data_transform - n_in = model.n_vars amortised_vars = model.list_obs_plate_vars() if len(amortised_vars["input"]) >= 2: encoder_kwargs["n_cat_list"] = n_cat_list - if data_transform is not None: - amortised_vars["input_transform"][0] = data_transform - if "n_in" in amortised_vars.keys(): - n_in = amortised_vars["n_in"] + n_in = amortised_vars["n_in"] if getattr(model, "discrete_variables", None) is not None: model = poutine.block(model, hide=model.discrete_variables) _guide = 
AutoAmortisedHierarchicalNormalMessenger( @@ -136,19 +204,6 @@ def _create_autoguide( ) return _guide - def _data_transform_clusters(self): - def _data_transform(x): - return torch.log1p(torch.cat([x, x @ self.gene_clusters], dim=1)) - - return _data_transform - - def _data_transform_scale(self): - def _data_transform(x): - # return (x - self.var_mean) / self.var_std - return x / self.var_std - - return _data_transform - class QuantileMixin: """ @@ -184,7 +239,92 @@ def optim_param(module_name, param_name): return optim_param - @torch.no_grad() + def _get_obs_plate_sites_v2( + self, + args: list, + kwargs: dict, + plate_name: str = None, + return_observed: bool = False, + return_deterministic: bool = True, + ): + """ + Automatically guess which model sites belong to observation/minibatch plate. + This function requires minibatch plate name specified in `self.module.list_obs_plate_vars["name"]`. + Parameters + ---------- + args + Arguments to the model. + kwargs + Keyword arguments to the model. + return_observed + Record samples of observed variables. + Returns + ------- + Dictionary with keys corresponding to site names and values to plate dimension. + """ + if plate_name is None: + plate_name = self.module.list_obs_plate_vars["name"] + + def try_trace(args, kwargs): + try: + trace_ = poutine.trace(self.module.guide).get_trace(*args, **kwargs) + trace_ = poutine.trace(poutine.replay(self.module.model, trace_)).get_trace(*args, **kwargs) + except ValueError: + # if sample is unsuccessful try again + trace_ = try_trace(args, kwargs) + return trace_ + + trace = try_trace(args, kwargs) + + # find plate dimension + obs_plate = { + name: { + fun.name: fun + for fun in site["cond_indep_stack"] + if (fun.name in plate_name) or (fun.name == plate_name) + } + for name, site in trace.nodes.items() + if ( + (site["type"] == "sample") # sample statement + and ( + ((not site.get("is_observed", True)) or return_observed) # don't save observed unless requested + or (site.get("infer", False).get("_deterministic", False) and return_deterministic) + ) # unless it is deterministic + and not isinstance(site.get("fn", None), poutine.subsample_messenger._Subsample) # don't save plates + ) + if any(f.name == plate_name for f in site["cond_indep_stack"]) + } + + return obs_plate + + def _get_dataloader( + self, + batch_size, + data_loader_indices, + dl_kwargs={}, + ): + if dl_kwargs is None: + dl_kwargs = dict() + signature_keys = list(inspect.signature(self._data_splitter_cls).parameters.keys()) + if "drop_last" in signature_keys: + dl_kwargs["drop_last"] = False + if "shuffle" in signature_keys: + dl_kwargs["shuffle_training"] = False + if "shuffle_set_split" in signature_keys: + dl_kwargs["shuffle_set_split"] = False + if "indices" in signature_keys: + dl_kwargs["indices"] = data_loader_indices + train_dl = self._data_splitter_cls( + self.adata_manager, + batch_size=batch_size, + train_size=1.0, + **dl_kwargs, + ) + train_dl.setup() + train_dl = train_dl.train_dataloader() + return train_dl + + @torch.inference_mode() def _posterior_quantile_minibatch( self, q: float = 0.5, @@ -192,8 +332,11 @@ def _posterior_quantile_minibatch( accelerator: str = "auto", device: Union[int, str] = "auto", use_median: bool = True, + return_observed: bool = False, exclude_vars: list = None, data_loader_indices=None, + show_progress: bool = True, + dl_kwargs: Optional[dict] = None, ): """ Compute median of the posterior distribution of each parameter, separating local (minibatch) variable @@ -228,136 +371,171 @@ def 
_posterior_quantile_minibatch( self.module.eval() - train_dl = AnnDataLoader(self.adata_manager, shuffle=False, batch_size=batch_size, indices=data_loader_indices) + if batch_size == self.adata_manager.adata.n_obs: + raise NotImplementedError("Please use batch_size < self.adata_manager.adata.n_obs") + + train_dl = self._get_dataloader( + batch_size=batch_size, + data_loader_indices=data_loader_indices, + dl_kwargs=dl_kwargs, + ) - # sample local parameters i = 0 - for tensor_dict in train_dl: + for tensor_dict in track( + train_dl, + style="tqdm", + description=f"Computing posterior quantile {q}, data batch: ", + disable=not show_progress, + ): args, kwargs = self.module._get_fn_args_from_batch(tensor_dict) args = [a.to(device) for a in args] kwargs = {k: v.to(device) for k, v in kwargs.items()} self.to_device(device) if i == 0: + minibatch_plate_names = self.module.list_obs_plate_vars["name"] + plates = self.module.model.create_plates(*args, **kwargs) + if not isinstance(plates, list): + plates = [plates] + # find plate indices & dim + plate_dict = { + plate.name: plate + for plate in plates + if ((plate.name in minibatch_plate_names) or (plate.name == minibatch_plate_names)) + } + plate_size = {name: plate.size for name, plate in plate_dict.items()} + if data_loader_indices is not None: + # set total plate size to the number of indices in DL not total number of observations + # `data_loader_indices=not` None option is not really used + plate_size = { + name: len(train_dl.indices) + for name, plate in plate_dict.items() + if plate.name == minibatch_plate_names + } + plate_dim = {name: plate.dim for name, plate in plate_dict.items()} + plate_indices = {name: plate.indices.detach().cpu().numpy() for name, plate in plate_dict.items()} # find plate sites - obs_plate_sites = self._get_obs_plate_sites(args, kwargs, return_observed=True) - if len(obs_plate_sites) == 0: - # if no local variables - don't sample - break - # find plate dimension - obs_plate_dim = list(obs_plate_sites.values())[0] + obs_plate_sites = { + plate: self._get_obs_plate_sites_v2(args, kwargs, plate_name=plate, return_observed=return_observed) + for plate in plate_dict.keys() + } if use_median and q == 0.5: - means = self.module.guide.median(*args, **kwargs) + # use median rather than quantile method + def try_median(args, kwargs): + try: + means_ = self.module.guide.median(*args, **kwargs) + except ValueError: + # if sample is unsuccessful try again + means_ = try_median(args, kwargs) + return means_ + + means = try_median(args, kwargs) else: - means = self.module.guide.quantiles([q], *args, **kwargs) + + def try_quantiles(args, kwargs): + try: + means_ = self.module.guide.quantiles([q], *args, **kwargs) + except ValueError: + # if sample is unsuccessful try again + means_ = try_quantiles(args, kwargs) + return means_ + + means = try_quantiles(args, kwargs) + valid_sites = self._get_valid_sites(args, kwargs, return_observed=return_observed) means = { - k: means[k].cpu().numpy() + k: means[k].detach().cpu().numpy() for k in means.keys() - if (k in obs_plate_sites) and (k not in exclude_vars) + if (k not in exclude_vars) and (k in valid_sites) } - + means_global = means.copy() + for plate in plate_dict.keys(): + # create full sized tensors according to plate size + means_global = { + k: ( + expand_zeros_along_dim( + means_global[k], + plate_size[plate], + self.module.model.named_dims[k] + if (k in getattr(self.module.model, "named_dims", dict()).keys()) + else plate_dim[plate], + ) + if k in obs_plate_sites[plate].keys() + 
else means_global[k] + ) + for k in means_global.keys() + } + # complete full sized tensors with minibatch values given minibatch indices + means_global = _complete_full_tensors_using_plates( + means_global=means_global, + means=means, + plate_dict=plate_dict, + obs_plate_sites=obs_plate_sites, + plate_indices=plate_indices, + plate_dim=plate_dim, + named_dims=getattr(self.module.model, "named_dims", dict()), + ) + if np.all([len(v) == 0 for v in obs_plate_sites.values()]): + # if no local variables - don't sample further - return results now + break else: if use_median and q == 0.5: - means_ = self.module.guide.median(*args, **kwargs) - else: - means_ = self.module.guide.quantiles([q], *args, **kwargs) - means_ = { - k: means_[k].cpu().numpy() - for k in means_.keys() - if (k in obs_plate_sites) and (k not in exclude_vars) - } - means = {k: np.concatenate([means[k], means_[k]], axis=obs_plate_dim) for k in means.keys()} - i += 1 - # sample global parameters - tensor_dict = next(iter(train_dl)) - args, kwargs = self.module._get_fn_args_from_batch(tensor_dict) - args = [a.to(device) for a in args] - kwargs = {k: v.to(device) for k, v in kwargs.items()} - self.to_device(device) + def try_median(args, kwargs): + try: + means_ = self.module.guide.median(*args, **kwargs) + except ValueError: + # if sample is unsuccessful try again + means_ = try_median(args, kwargs) + return means_ - if use_median and q == 0.5: - global_means = self.module.guide.median(*args, **kwargs) - else: - global_means = self.module.guide.quantiles([q], *args, **kwargs) - global_means = { - k: global_means[k].cpu().numpy() - for k in global_means.keys() - if (k not in obs_plate_sites) and (k not in exclude_vars) - } + means = try_median(args, kwargs) + else: - for k in global_means.keys(): - means[k] = global_means[k] + def try_quantiles(args, kwargs): + try: + means_ = self.module.guide.quantiles([q], *args, **kwargs) + except ValueError: + # if sample is unsuccessful try again + means_ = try_quantiles(args, kwargs) + return means_ - # quantile returns tensors with 0th dimension = 1 - if not (use_median and q == 0.5) and ( - not isinstance(self.module.guide, AutoAmortisedHierarchicalNormalMessenger) - ): - means = {k: means[k].squeeze(0) for k in means.keys()} + means = try_quantiles(args, kwargs) + valid_sites = self._get_valid_sites(args, kwargs, return_observed=return_observed) + means = { + k: means[k].detach().cpu().numpy() + for k in means.keys() + if (k not in exclude_vars) and (k in valid_sites) + } + # find plate indices & dim + plates = self.module.model.create_plates(*args, **kwargs) + if not isinstance(plates, list): + plates = [plates] + plate_dict = { + plate.name: plate + for plate in plates + if ((plate.name in minibatch_plate_names) or (plate.name == minibatch_plate_names)) + } + plate_indices = {name: plate.indices.detach().cpu().numpy() for name, plate in plate_dict.items()} + # TODO - is this correct to call this function again? 
find plate sites + obs_plate_sites = { + plate: self._get_obs_plate_sites_v2(args, kwargs, plate_name=plate, return_observed=return_observed) + for plate in plate_dict.keys() + } + # complete full sized tensors with minibatch values given minibatch indices + means_global = _complete_full_tensors_using_plates( + means_global=means_global, + means=means, + plate_dict=plate_dict, + obs_plate_sites=obs_plate_sites, + plate_indices=plate_indices, + plate_dim=plate_dim, + named_dims=getattr(self.module.model, "named_dims", dict()), + ) + i += 1 self.module.to(device) - return means - - @torch.no_grad() - def _posterior_quantile( - self, - q: float = 0.5, - batch_size: int = None, - accelerator: str = "auto", - device: Union[int, str] = "auto", - use_median: bool = True, - exclude_vars: list = None, - data_loader_indices=None, - ): - """ - Compute median of the posterior distribution of each parameter pyro models trained without amortised inference. - - Parameters - ---------- - q - Quantile to compute - use_gpu - Bool, use gpu? - use_median - Bool, when q=0.5 use median rather than quantile method of the guide - - Returns - ------- - dictionary {variable_name: posterior quantile} - - """ - - self.module.eval() - _, _, device = parse_device_args( - accelerator=accelerator, - devices=device, - return_device="torch", - validate_single_device=True, - ) - if batch_size is None: - batch_size = self.adata_manager.adata.n_obs - train_dl = AnnDataLoader(self.adata_manager, shuffle=False, batch_size=batch_size, indices=data_loader_indices) - # sample global parameters - tensor_dict = next(iter(train_dl)) - args, kwargs = self.module._get_fn_args_from_batch(tensor_dict) - args = [a.to(device) for a in args] - kwargs = {k: v.to(device) for k, v in kwargs.items()} - self.to_device(device) - - if use_median and q == 0.5: - means = self.module.guide.median(*args, **kwargs) - else: - means = self.module.guide.quantiles([q], *args, **kwargs) - means = {k: means[k].cpu().detach().numpy() for k in means.keys() if k not in exclude_vars} - - # quantile returns tensors with 0th dimension = 1 - if not (use_median and q == 0.5) and ( - not isinstance(self.module.guide, AutoAmortisedHierarchicalNormalMessenger) - ): - means = {k: means[k].squeeze(0) for k in means.keys()} - - return means + return means_global def posterior_quantile(self, exclude_vars: list = None, batch_size: int = None, **kwargs): """ @@ -385,10 +563,11 @@ def posterior_quantile(self, exclude_vars: list = None, batch_size: int = None, # median/quantiles in AutoNormal does not require minibatches batch_size = None - if batch_size is not None: - return self._posterior_quantile_minibatch(exclude_vars=exclude_vars, batch_size=batch_size, **kwargs) - else: - return self._posterior_quantile(exclude_vars=exclude_vars, batch_size=batch_size, **kwargs) + if batch_size is None: + from scvi import settings + + batch_size = settings.batch_size + return self._posterior_quantile_minibatch(exclude_vars=exclude_vars, batch_size=batch_size, **kwargs) class PltExportMixin: diff --git a/cell2location/models/reference/_reference_model.py b/cell2location/models/reference/_reference_model.py index 641416b3..5c2ab50d 100755 --- a/cell2location/models/reference/_reference_model.py +++ b/cell2location/models/reference/_reference_model.py @@ -33,8 +33,6 @@ class RegressionModel(QuantileMixin, PyroSampleMixin, PyroSviTrainMixin, PltExpo ---------- adata single-cell AnnData object that has been registered via :func:`~scvi.data.setup_anndata`. - use_gpu - Use the GPU? 
**model_kwargs Keyword args for :class:`~scvi.external.LocationModelLinearDependentWMultiExperimentModel` diff --git a/cell2location/models/reference/_reference_module.py b/cell2location/models/reference/_reference_module.py index a3af7984..61cadac5 100755 --- a/cell2location/models/reference/_reference_module.py +++ b/cell2location/models/reference/_reference_module.py @@ -176,8 +176,8 @@ def list_obs_plate_vars(self): } def forward(self, x_data, idx, batch_index, label_index, extra_categoricals): - obs2sample = one_hot(batch_index, self.n_batch) - obs2label = one_hot(label_index, self.n_factors) + obs2sample = one_hot(batch_index, self.n_batch).float() + obs2label = one_hot(label_index, self.n_factors).float() if self.n_extra_categoricals is not None: obs2extra_categoricals = torch.cat( [ @@ -188,7 +188,7 @@ def forward(self, x_data, idx, batch_index, label_index, extra_categoricals): for i, n_cat in enumerate(self.n_extra_categoricals) ], dim=1, - ) + ).float() obs_plate = self.create_plates(x_data, idx, batch_index, label_index, extra_categoricals) diff --git a/cell2location/nn/CellCommunicationToEffectNN.py b/cell2location/nn/CellCommunicationToEffectNN.py new file mode 100755 index 00000000..b0fb95a3 --- /dev/null +++ b/cell2location/nn/CellCommunicationToEffectNN.py @@ -0,0 +1,1271 @@ +import numpy as np +import pyro +import torch +from einops import rearrange +from pyro.infer.autoguide.utils import deep_getattr, deep_setattr +from pyro.nn import PyroModule +from scvi.nn import one_hot +from torch import nn as nn + +from ._mixins import CreateParameterMixin + + +class CellCommunicationToTfActivityNN( + PyroModule, + CreateParameterMixin, +): + """ + Defining a function that maps signal abundance in the microenvironment + to TF activity or cell abundance in target cells via receptors expressed by target cells. + + Parameters + ---------- + n_tfs + The number of TFs + n_signals + The number of signals + n_receptors + The number of receptors (both single- and multi-subunit). 
+ n_hidden + The number of nodes per hidden layer + dropout_rate + Dropout rate to apply to each of the hidden layers + use_layer_norm + Whether to have `LayerNorm` layers or not + use_activation + Whether to have layer activation or not + use_activation + Whether to have layer activation at last layer or not + bias + Whether to learn bias in linear layers or not + activation_fn + Which activation function to use + """ + + return_composite_matches = False + mask_binding_cooperativity = False + mask_regulatory_cooperativity = False + tf_features_reparameterised_regularisation = True + normalise_distance_weights = False # True False + use_bayesian_protein_effects = True + bayesian = False + use_sqrt_normalisation = True + + promoter_distance_prior = 50 + use_footprinting_masked_tn5_index = -1 + n_spatial_domains = 5 + + use_cached_effects = False + cached_effects = dict() + + def __init__( + self, + name: str, + n_tfs: int, + n_signals: int, + n_receptors: int, + n_out: int = 1, + n_pathways: int = 1, + mode: str = "signal_receptor_tf_effect", + signal_receptor_mask: np.ndarray = None, # tells which receptors can bind which ligands + receptor_tf_mask: np.ndarray = None, # tells which receptors can influence which TF (eg nuclear receptor = TF) + dropout_rate: float = 0.0, + activation_fn: nn.Module = nn.Softplus, + weights_prior={"shape": 1.0, "scale": 1.0, "informed_scale": 0.2}, + bias_prior={"mean": 0.0, "sigma": 1.0}, + mode_suffix: str = "_free", + use_horseshoe_prior: bool = True, + use_gamma_horseshoe_prior: bool = False, + weights_prior_tau: float = 1, + output_transform: str = "proportion", + use_unbound_concentration: bool = False, + use_pathway_interaction_effect: bool = True, + average_distance_prior: float = 50.0, + use_non_negative_weights: bool = False, + use_global_cell_abundance_model: bool = False, + r_l_affinity_alpha_prior: float = 10.0, + ): + super().__init__() + + self.name = name + self.name_prefix = name + self.mode = mode + self.n_tfs = n_tfs + self.n_signals = n_signals + self.n_receptors = n_receptors + self.n_out = n_out + self.n_pathways = n_pathways + self.dropout_rate = dropout_rate + + self.activation_fn = activation_fn + self.weights_prior = weights_prior + self.bias_prior = bias_prior + self.mode_suffix = mode_suffix + self.use_horseshoe_prior = use_horseshoe_prior + self.use_gamma_horseshoe_prior = use_gamma_horseshoe_prior + + self.output_transform = output_transform + self.use_unbound_concentration = use_unbound_concentration + + self.use_pathway_interaction_effect = use_pathway_interaction_effect + + self.average_distance_prior = average_distance_prior + + self.use_non_negative_weights = use_non_negative_weights + self.use_global_cell_abundance_model = use_global_cell_abundance_model + + self.r_l_affinity_alpha_prior = r_l_affinity_alpha_prior + + self.weights = PyroModule() + + if signal_receptor_mask is None: + signal_receptor_mask = np.ones((self.n_signals, self.n_receptors)) + assert signal_receptor_mask.shape == (self.n_signals, self.n_receptors), ( + f"signal_receptor_mask shape {signal_receptor_mask.shape} " + f"does not match n_signals {self.n_signals} and n_receptors {self.n_receptors}" + ) + from scipy.sparse import coo_matrix + + signal_receptor_mask = coo_matrix(signal_receptor_mask) + self.signal_receptor_mask_scipy = signal_receptor_mask + if receptor_tf_mask is not None: + self.register_buffer( + "receptor_tf_mask", + torch.tensor(np.asarray(receptor_tf_mask).astype("float32")), + ) + else: + self.receptor_tf_mask = None + + 
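For readers less familiar with the sparse-mask bookkeeping above: converting the dense ligand-receptor mask to a scipy COO matrix is what lets the module learn one affinity/effect weight per allowed (signal, receptor) pair rather than a dense n_signals x n_receptors matrix, with len(signal_receptor_mask_scipy.data) giving the number of pairs. A minimal, self-contained sketch with toy sizes (the names n_signals, n_receptors and toy_mask are illustrative, not module attributes):

import numpy as np
from scipy.sparse import coo_matrix

n_signals, n_receptors = 4, 3
toy_mask = np.zeros((n_signals, n_receptors), dtype="float32")
toy_mask[0, 0] = toy_mask[1, 0] = toy_mask[2, 1] = toy_mask[3, 2] = 1.0  # allowed pairs only

mask = coo_matrix(toy_mask)
# mask.row[i] / mask.col[i] index the i-th allowed (signal, receptor) pair;
# len(mask.data) is the number of pairs, i.e. the length of the per-pair weight vectors.
pairs = list(zip(mask.row.tolist(), mask.col.tolist()))
print(len(mask.data), pairs)  # 4 [(0, 0), (1, 0), (2, 1), (3, 2)]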
self.register_buffer("n_tfs_tensor", torch.tensor(float(n_tfs))) + self.register_buffer("n_signals_tensor", torch.tensor(float(n_signals))) + self.register_buffer("n_receptors_tensor", torch.tensor(float(n_receptors))) + self.register_buffer("ones", torch.ones(1)) + self.register_buffer("zeros", torch.zeros(1)) + self.register_buffer("ten", torch.tensor(10.0)) + self.register_buffer("weights_prior_shape", torch.tensor(float(self.weights_prior["shape"]))) + self.register_buffer("weights_prior_scale", torch.tensor(float(self.weights_prior["scale"]))) + self.register_buffer( + "weights_prior_informed_scale", + torch.tensor(float(self.weights_prior["informed_scale"])), + ) + self.register_buffer("bias_mean_prior", torch.tensor(float(self.bias_prior["mean"]))) + self.register_buffer("bias_sigma_prior", torch.tensor(float(self.bias_prior["sigma"]))) + self.register_buffer("weights_prior_tau", torch.tensor(float(weights_prior_tau))) + + def get_tf_effect( + self, + x, + name, + layer, + weights_shape, + remove_diagonal=None, + weights_prior_tau=None, + use_horseshoe_prior=None, + non_negative: bool = False, + upper_triangle: bool = False, + ): + if remove_diagonal is None: + remove_diagonal = self.remove_diagonal + + weights_name = f"{self.name}_{name}_layer_{layer}_protein2effect" + + zero_diag = self.ones.expand(weights_shape) + if weights_prior_tau is None: + weights_prior_tau = self.ones if not hasattr(self, "weights_prior_tau") else self.weights_prior_tau + if use_horseshoe_prior is None: + use_horseshoe_prior = False if not hasattr(self, "use_horseshoe_prior") else self.use_horseshoe_prior + weights = self.get_param( + x=x, + name=name, + layer=layer, + weights_shape=weights_shape, + bias_shape=[1], + random_init_scale=1.0, + bayesian=True, + use_non_negative_weights=non_negative, + weights_prior_tau=weights_prior_tau, + use_horseshoe_prior=use_horseshoe_prior and not non_negative, + ) + if upper_triangle: + if len(weights.shape) > 2: + weights = torch.triu(torch.ones((weights.shape[0], weights.shape[1]))).unsqueeze(-1) * weights + else: + weights = torch.triu(weights) + + if len(weights_shape) == 2: + # [n_in, n_out] + if remove_diagonal: + zero_diag = torch.ones(weights_shape[-2], weights_shape[-2], device=weights.device) + zero_diag = zero_diag.fill_diagonal_(0.0) + elif len(weights_shape) == 3 and (weights_shape[0] == weights_shape[1]): + if remove_diagonal: + zero_diag = torch.ones(weights_shape[-3], weights_shape[-2], device=weights.device) + zero_diag = zero_diag.fill_diagonal_(0.0) + zero_diag = zero_diag.unsqueeze(-1) + if not self.training: + pyro.deterministic(f"{weights_name}_total_effect", weights * zero_diag) + return weights * zero_diag + + def get_signal_distance_effect( + self, + x, + layer, + name, + weights_shape, + ): + # [n_out, n_in] + weights_name = f"{self.name}_{name}_layer_{layer}_protein2effect" + + sig_distance_effect = self.get_param( + x=x, + name=name, + layer=layer, + weights_shape=weights_shape, + bias_shape=[1], + random_init_scale=1 / np.sqrt(weights_shape[-1]), + bayesian=True, + use_non_negative_weights=False, + ) + + if not self.training: + pyro.deterministic(f"{weights_name}_total_effect", sig_distance_effect) + + return sig_distance_effect + + def get_dist_prior( + self, + layer, + name, + weights_shape, + prior_alpha=None, + prior_beta=None, + prior_fun=pyro.distributions.Gamma, + ): + # [n_out, n_in] + weights_name = f"{self.name}_{name}_layer_{layer}_protein2effect" + + if prior_alpha is None: + prior_alpha = 1.0 + if getattr(self, 
f"{name}prior_alpha", None) is None: + self.register_buffer(f"{name}prior_alpha", torch.tensor(float(prior_alpha))) + if prior_beta is None: + prior_beta = 1.0 + if getattr(self, f"{name}prior_beta", None) is None: + self.register_buffer(f"{name}prior_beta", torch.tensor(float(prior_beta))) + # Weights + if getattr(self.weights, weights_name, None) is None: + deep_setattr( + self.weights, + weights_name, + pyro.nn.PyroSample( + lambda prior: prior_fun( + getattr(self, f"{name}prior_alpha"), + getattr(self, f"{name}prior_beta"), + ) + .expand(weights_shape) + .to_event(len(weights_shape)), + ), + ) + + sig_distance_effect = deep_getattr(self.weights, weights_name) + if not self.training: + pyro.deterministic(f"{weights_name}_total_effect", sig_distance_effect) + + return sig_distance_effect + + def get_signal_receptor_effect( + self, + x, + layer, + weights_shape, + ): + # [n_out, n_in] + name = "signal_receptor_effect" + weights_name = f"{self.name}_{name}_layer_{layer}_protein2effect" + + if weights_shape is None: + weights_shape = [len(self.signal_receptor_mask_scipy.data)] + + rec_sig_effect = self.get_param( + x=x, + name=name, + layer=layer, + weights_shape=weights_shape, + bias_shape=[1], + # random_init_scale=1 / np.sqrt(self.n_signals), + bayesian=True, + weights_prior_shape=torch.tensor(self.r_l_affinity_alpha_prior, device=x.device), + weights_prior_rate=torch.tensor(self.r_l_affinity_alpha_prior, device=x.device), + # sample positive weights + use_non_negative_weights=True, + ) + + if not self.training: + pyro.deterministic(f"{weights_name}_total_effect", rec_sig_effect) + + return rec_sig_effect + + def get_signal_receptor_tf_effect( + self, + x, + layer, + weights_shape, + name="signal_receptor_tf_effect", + use_non_negative_weights=None, + ): + # [n_tf, n_signals, n_receptors] + weights_name = f"{self.name}_{name}_layer_{layer}_protein2effect" + + if use_non_negative_weights is None: + use_non_negative_weights = self.use_non_negative_weights + + tf_sig_rec_tf_effect = self.get_param( + x=x, + name=name, + layer=layer, + weights_shape=weights_shape, + bias_shape=[1], + random_init_scale=1.0, + bayesian=True, + # sample positive weights + use_non_negative_weights=use_non_negative_weights, + use_horseshoe_prior=not use_non_negative_weights, + ) + + if self.receptor_tf_mask is not None: + if self.n_out == 1: + tf_sig_rec_tf_effect = ( + tf_sig_rec_tf_effect * self.receptor_tf_mask.T[:, self.signal_receptor_mask_scipy.row] + ) + else: + tf_sig_rec_tf_effect = torch.einsum( + "hrf,rh->hrf", + tf_sig_rec_tf_effect, + self.receptor_tf_mask[self.signal_receptor_mask_scipy.row, :], + ) + + if not self.training: + pyro.deterministic(f"{weights_name}_total_effect", tf_sig_rec_tf_effect) + + return tf_sig_rec_tf_effect + + def inverse_sigmoid_lm(self, x, weight, bias, scaling): + # expand shapes correctly + if x.dim() == 2: + weight = weight.unsqueeze(-1) + bias = bias.unsqueeze(-1) + if scaling is not None: + scaling = scaling.unsqueeze(-1) + elif x.dim() == 3: + weight = weight.unsqueeze(-1).unsqueeze(-1) + bias = bias.unsqueeze(-1).unsqueeze(-1) + if scaling is not None: + scaling = scaling.unsqueeze(-1).unsqueeze(-1) + if scaling is None: + # compute sigmoid function + return self.ones - torch.sigmoid(x * weight + bias) + else: + # compute sigmoid function + return (self.ones - torch.sigmoid(x * weight + bias)) * scaling + + def gamma_pdf(self, x, concentration, rate, scaling=None): + # expand shapes correctly + if x.dim() == 2: + concentration = concentration.unsqueeze(-1) + rate = 
rate.unsqueeze(-1) + if scaling is not None: + scaling = scaling.unsqueeze(-1) + elif x.dim() == 3: + concentration = concentration.unsqueeze(-1).unsqueeze(-1) + rate = rate.unsqueeze(-1).unsqueeze(-1) + if scaling is not None: + scaling = scaling.unsqueeze(-1).unsqueeze(-1) + # compute gamma function + if scaling is None: + return ( + pyro.distributions.Gamma( + concentration=concentration, + rate=rate, + ) + .log_prob(x) + .exp() + ) + return ( + pyro.distributions.Gamma( + concentration=concentration, + rate=rate, + ) + .log_prob(x) + .exp() + * scaling + ) + + def inverse_sigmoid_distance_function_protein_features( + self, + distances, + layer, + weights_shape, + name, + mode, + average_distance_prior=None, + ): + if average_distance_prior is None: + average_distance_prior = self.average_distance_prior + + name_ = f"{name}DistanceFunctionScaling" + scaling = self.get_signal_distance_effect( + x=distances, + layer=layer, + name=name_, + weights_shape=weights_shape, + ) + scaling = torch.sigmoid(scaling / torch.tensor(5.0, device=distances.device)) + + # sigmoid function ================= + name_ = f"{name}DistanceWeights" # strictly positive + weight = self.get_signal_distance_effect( + x=distances, + layer=layer, + name=name_, + weights_shape=weights_shape, + ) + # strictly positive + weight = ( + nn.functional.softplus(weight) + # prior of ~1/50 (= 1 / (softplus(0) / 35)) + / torch.tensor(0.7, device=distances.device) + ) / torch.tensor(average_distance_prior, device=distances.device) + name_ = f"{name}DistanceBias" + bias = self.get_signal_distance_effect( + x=distances, + layer=layer, + name=name_, + weights_shape=weights_shape, + ) - ( + self.ones + self.ones + ) # prior of -2 + sigmoid_distance_function = self.inverse_sigmoid_lm(distances, weight, bias, scaling) + + # gamma function ================= + name_ = f"{name}DistanceGammaConcentration" # strictly positive + gamma_concentration = self.get_signal_distance_effect( + x=distances, + layer=layer, + name=name_, + weights_shape=weights_shape, + ) + # strictly positive + gamma_concentration = ( + nn.functional.softplus(gamma_concentration) + # 1 * 5 = (softplus(0) / 0.7) * 5 = 5 + / torch.tensor(0.7, device=distances.device) + ) * torch.tensor(5.0, device=distances.device) + name_ = f"{name}DistanceGammaDistance" # strictly positive + gamma_distance = self.get_signal_distance_effect( + x=distances, + layer=layer, + name=name_, + weights_shape=weights_shape, + ) + # strictly positive + gamma_distance = ( + nn.functional.softplus(gamma_distance) + # 1 * average_distance_prior = (softplus(0) / 0.7) * average_distance_prior = average_distance_prior + / torch.tensor(0.7, device=distances.device) + ) + gamma_distance = gamma_distance * torch.tensor(average_distance_prior, device=distances.device) + gamma_distance_function = self.gamma_pdf( + distances, + concentration=gamma_concentration, + rate=gamma_concentration / gamma_distance, + scaling=torch.ones(1, device=distances.device) - scaling, + ) + + return sigmoid_distance_function + gamma_distance_function + + def inverse_sigmoid_distance_function( + self, + distances, + layer, + weights_shape, + name, + mode, + average_distance_prior=None, + ): + if average_distance_prior is None: + average_distance_prior = self.average_distance_prior + + name_ = f"{name}DistanceFunctionScaling" + scaling = self.get_dist_prior( + layer=layer, + name=name_, + weights_shape=weights_shape, + prior_alpha=1.0, + prior_beta=1.0, + prior_fun=pyro.distributions.Beta, + ) + + # sigmoid function 
================= + name_ = f"{name}DistanceWeights" # strictly positive + weight = self.get_dist_prior( + layer=layer, + name=name_, + weights_shape=weights_shape, + prior_alpha=2.0, + prior_beta=2.0 / average_distance_prior, + prior_fun=pyro.distributions.Gamma, + ) + name_ = f"{name}DistanceBias" + bias = self.get_dist_prior( + layer=layer, + name=name_, + weights_shape=weights_shape, + prior_alpha=-2.0, + prior_beta=1.0, + prior_fun=pyro.distributions.Normal, + ) + sigmoid_distance_function = self.inverse_sigmoid_lm(distances, weight, bias, scaling) + + # gamma function ================= + name_ = f"{name}DistanceGammaConcentration" # strictly positive + gamma_concentration = self.get_dist_prior( + layer=layer, + name=name_, + weights_shape=weights_shape, + prior_alpha=2.0, + prior_beta=2.0 / 1.0, + prior_fun=pyro.distributions.Gamma, + ) + name_ = f"{name}DistanceGammaDistance" # strictly positive + gamma_distance = self.get_dist_prior( + layer=layer, + name=name_, + weights_shape=weights_shape, + prior_alpha=2.0, + prior_beta=2.0 / average_distance_prior, + prior_fun=pyro.distributions.Gamma, + ) + gamma_distance_function = self.gamma_pdf( + distances, + concentration=gamma_concentration, + rate=gamma_concentration / gamma_distance, + scaling=torch.tensor(1.0, device=distances.device) - scaling, + ) + + return sigmoid_distance_function + gamma_distance_function + + def inverse_sigmoid_signal_distance_function( + self, + distances, + layer, + weights_shape=None, + name="signal_distance_", + mode="independent_effect", + average_distance_prior=None, + ) -> torch.Tensor: + if weights_shape is None: + weights_shape = [self.n_signals] + if average_distance_prior is None: + average_distance_prior = self.average_distance_prior + # returns weights for [n_signals, n_distance_bins] + if distances.dim() == 1: + distances = distances.unsqueeze(-2) + elif distances.dim() == 2: + distances = distances.unsqueeze(-3) + return self.inverse_sigmoid_distance_function( + distances, + layer, + weights_shape=weights_shape, + name=name, + mode=mode, + average_distance_prior=average_distance_prior, + ) + + def forward(self, *args, **kwargs): + return getattr(self, self.mode)( + *args, + **kwargs, + ) + + def signal_receptor_tf_effect( + self, + bound_receptor_abundance_src: torch.Tensor, + use_cell_abundance_model: bool = False, + ): + layer = 0 + # optionally apply dropout ========== + if self.dropout_rate > 0.0: + if getattr(self.weights, f"{self.name}_layer_{layer}_dropout", None) is None: + deep_setattr( + self.weights, + f"{self.name}_layer_{layer}_dropout", + nn.Dropout(p=self.dropout_rate), + ) + dropout = deep_getattr(self.weights, f"{self.name}_layer_{layer}_dropout") + bound_receptor_abundance_src = dropout(bound_receptor_abundance_src) + + # 3. Computing effects a_{r,s,h} of ligand-receptor complexes x_{c,r,s} on active TF concentration. How? 
========== + # Basal TF active concentration ========== + name = "basal_TF_weights" + basal_tf_weights = self.get_tf_effect( + x=bound_receptor_abundance_src, + name=name, + layer=layer, + weights_shape=[self.n_tfs] if self.n_out == 1 else [self.n_tfs, self.n_out], + remove_diagonal=False, + ) + # Signal-receptor complex effect on TF active concentration ========== + if not (use_cell_abundance_model and self.use_global_cell_abundance_model): + tf_sig_rec_effect_hsr = self.get_signal_receptor_tf_effect( + x=bound_receptor_abundance_src, + layer=layer, + weights_shape=[self.n_tfs, len(self.signal_receptor_mask_scipy.data)] + if self.n_out == 1 + else [self.n_tfs, len(self.signal_receptor_mask_scipy.data), self.n_out], + name="signal_receptor_tf_effect", + ) + else: + tf_sig_rec_effect_hsr = self.get_signal_receptor_tf_effect( + x=bound_receptor_abundance_src, + layer=layer, + weights_shape=[len(self.signal_receptor_mask_scipy.data)] + if self.n_out == 1 + else [len(self.signal_receptor_mask_scipy.data), self.n_out], + name="signal_receptor_tf_effect", + ) + # normalise by sqrt of the number of predictors + tf_sig_rec_effect_hsr = tf_sig_rec_effect_hsr / torch.sqrt( + torch.tensor( + float(len(self.signal_receptor_mask_scipy.data)), + device=tf_sig_rec_effect_hsr.device, + ) + ) + + # print("tf_sig_rec_effect_hsr mean", tf_sig_rec_effect_hsr.mean()) + # print("tf_sig_rec_effect_hsr min", tf_sig_rec_effect_hsr.min()) + # print("tf_sig_rec_effect_hsr max", tf_sig_rec_effect_hsr.max()) + # print("bound_receptor_abundance_crs mean", bound_receptor_abundance_crs.mean()) + # print("bound_receptor_abundance_crs min", bound_receptor_abundance_crs.min()) + # print("bound_receptor_abundance_crs max", bound_receptor_abundance_crs.max()) + if self.n_out == 1: + if use_cell_abundance_model: + bound_rs = rearrange( + bound_receptor_abundance_src, + "r (c h) -> c h r", + h=self.n_tfs, + ) + if self.use_global_cell_abundance_model: + effect = torch.einsum("chr,r->ch", bound_rs, tf_sig_rec_effect_hsr) + else: + effect = torch.einsum("chr,hr->ch", bound_rs, tf_sig_rec_effect_hsr) + else: + effect = torch.einsum( + "cr,hr->ch", + bound_receptor_abundance_src.coalesce().values(), + tf_sig_rec_effect_hsr, + ) + effect_on_tf_abundance = ( + # independent term for basal TF active concentration + basal_tf_weights.unsqueeze(-2) + # Communication-dependent effect on TF active concentration + + effect + ) + else: + if use_cell_abundance_model: + bound_rs = rearrange( + bound_receptor_abundance_src, + "r (c h) -> c h r", + h=self.n_tfs, + ) + if self.use_global_cell_abundance_model: + effect = torch.einsum("chr,rf->fch", bound_rs, tf_sig_rec_effect_hsr) + else: + effect = torch.einsum("chr,hrf->fch", bound_rs, tf_sig_rec_effect_hsr) + else: + effect = torch.einsum( + "cr,hrf->fch", + bound_receptor_abundance_src.coalesce().values(), + tf_sig_rec_effect_hsr, + ) + effect_on_tf_abundance = ( + # independent term for basal TF active concentration + basal_tf_weights.T.unsqueeze(-2) + # Communication-dependent effect on TF active concentration + + effect + ) + + if self.n_pathways > 1: + effect_on_tf_abundance = effect_on_tf_abundance + self.signal_receptor_pathway_tf_effect( + bound_receptor_abundance_src=bound_receptor_abundance_src, + use_cell_abundance_model=use_cell_abundance_model, + ) + + # print("self.output_transform", self.output_transform) + # print("effect_on_tf_abundance mean", effect_on_tf_abundance.mean()) + # print("effect_on_tf_abundance min", effect_on_tf_abundance.min()) + # print("effect_on_tf_abundance 
max", effect_on_tf_abundance.max()) + if use_cell_abundance_model: + effect_on_tf_abundance = effect_on_tf_abundance / torch.tensor(10.0, device=effect_on_tf_abundance.device) + if self.output_transform == "softplus": + # apply softplus to ensure positive values + effect_on_tf_abundance = nn.functional.softplus( + effect_on_tf_abundance / torch.tensor(1.0, device=effect_on_tf_abundance.device) + ) / torch.tensor(0.7, device=effect_on_tf_abundance.device) + elif self.output_transform == "proportion1": + # The proportion of TF in the active form + # TODO # how do basal_tf_weights behave? average 2 before transform and the other effect is normalised to be small? + effect_on_tf_abundance = torch.sigmoid( + effect_on_tf_abundance / torch.tensor(1.0, device=effect_on_tf_abundance.device) + + torch.tensor(1.0, device=effect_on_tf_abundance.device) + ) + elif self.output_transform == "proportion0": + # The proportion of TF in the active form + # TODO # how do basal_tf_weights behave? average 2 before transform and the other effect is normalised to be small? + effect_on_tf_abundance = torch.sigmoid( + effect_on_tf_abundance / torch.tensor(1.0, device=effect_on_tf_abundance.device) + - torch.tensor(1.0, device=effect_on_tf_abundance.device) + ) + elif self.output_transform == "activity": + # NOTE that this does not use TF abundance info (similarly to independent TF activity term) + # This represents effect direction but not magnitude + effect_on_tf_abundance = torch.sigmoid( + effect_on_tf_abundance / torch.tensor(1.0, device=effect_on_tf_abundance.device) + ) * torch.tensor(2.0, device=effect_on_tf_abundance.device) - torch.tensor( + 1.0, device=effect_on_tf_abundance.device + ) + + # print(f"cell_comm_effect {self.name} mean ", tf_abundance.mean()) + # print(f"cell_comm_effect {self.name} min", tf_abundance.min()) + # print(f"cell_comm_effect {self.name} max", tf_abundance.max()) + + return effect_on_tf_abundance + + def apply_upper_triangle_diag(self, weights, name, n_out=1): + if (len(weights.shape) == 3) and (n_out == 1): + shape = (weights.shape[1], weights.shape[2]) + triu = torch.triu(torch.ones(shape, device=weights.device)).unsqueeze(-3) + weights = weights * triu + zero_diag = torch.ones(shape, device=weights.device) + zero_diag = zero_diag.fill_diagonal_(0.0) + zero_diag = zero_diag.unsqueeze(-3) + weights = weights * zero_diag + elif (len(weights.shape) == 4) and (n_out > 1): + raise NotImplementedError + elif (len(weights.shape) == 3) and (n_out > 1): + raise NotImplementedError + elif (len(weights.shape) == 2) and (n_out == 1): + shape = (weights.shape[0], weights.shape[1]) + triu = torch.triu(torch.ones(shape, device=weights.device)) + weights = weights * triu + zero_diag = torch.ones(shape, device=weights.device) + zero_diag = zero_diag.fill_diagonal_(0.0) + weights = weights * zero_diag + if not self.training: + pyro.deterministic(f"{name}_2_total_effect", weights) + return weights, zero_diag, triu + + def signal_receptor_pathway_tf_effect( + self, + bound_receptor_abundance_src: torch.Tensor, + use_cell_abundance_model: bool = False, + ): + layer = 0 + + # Basal pathway activity ========== + name = "basal_pathway_weights" + basal_pathway_weights = self.get_tf_effect( + x=bound_receptor_abundance_src, + name=name, + layer=layer, + weights_shape=[self.n_pathways] if self.n_out == 1 else [self.n_pathways, self.n_out], + remove_diagonal=False, + ) + # Signal-receptor complex effect on pathway activity ========== + pathway_sig_rec_effect_hsr = self.get_signal_receptor_tf_effect( + 
x=bound_receptor_abundance_src, + layer=layer, + weights_shape=[self.n_pathways, len(self.signal_receptor_mask_scipy.data)] + if self.n_out == 1 + else [ + self.n_pathways, + len(self.signal_receptor_mask_scipy.data), + self.n_out, + ], + name="signal_receptor_pathway_effect", + ) + # normalise by sqrt of the number of predictors + pathway_sig_rec_effect_hsr = pathway_sig_rec_effect_hsr / torch.sqrt( + torch.tensor( + float(len(self.signal_receptor_mask_scipy.data)), + device=pathway_sig_rec_effect_hsr.device, + ) + ) + + # print("pathway_sig_rec_effect_hsr mean", pathway_sig_rec_effect_hsr.mean()) + # print("pathway_sig_rec_effect_hsr min", pathway_sig_rec_effect_hsr.min()) + # print("pathway_sig_rec_effect_hsr max", pathway_sig_rec_effect_hsr.max()) + # print("bound_receptor_abundance_crs mean", bound_receptor_abundance_crs.mean()) + # print("bound_receptor_abundance_crs min", bound_receptor_abundance_crs.min()) + # print("bound_receptor_abundance_crs max", bound_receptor_abundance_crs.max()) + if self.n_out == 1: + effect_on_pathway_activity = ( + # independent term for basal pathway activity + basal_pathway_weights.unsqueeze(-2) + # Communication-dependent effect on pathway activity + + torch.einsum( + "rc,hr->ch", + bound_receptor_abundance_src, + pathway_sig_rec_effect_hsr, + ) + ) + else: + effect_on_pathway_activity = ( + # independent term for basal pathway activity + basal_pathway_weights.T.unsqueeze(-2) + # Communication-dependent effect on pathway activity + + torch.einsum( + "rc,hrf->fch", + bound_receptor_abundance_src, + pathway_sig_rec_effect_hsr, + ) + ) + # print("self.output_transform", self.output_transform) + # print("effect_on_tf_abundance mean", effect_on_tf_abundance.mean()) + # print("effect_on_tf_abundance min", effect_on_tf_abundance.min()) + # print("effect_on_tf_abundance max", effect_on_tf_abundance.max()) + + # apply softplus to ensure positive values + effect_on_pathway_activity = nn.functional.softplus( + effect_on_pathway_activity / torch.tensor(1.0, device=effect_on_pathway_activity.device) + ) / torch.tensor(0.7, device=effect_on_pathway_activity.device) + + # compute pathway effects on TFs + name = "pathway_tf_weights" + if not (use_cell_abundance_model and self.use_global_cell_abundance_model): + pathway_tf_weights = self.get_tf_effect( + x=bound_receptor_abundance_src, + name=name, + layer=layer, + weights_shape=[self.n_tfs, self.n_pathways] + if self.n_out == 1 + else [self.n_tfs, self.n_pathways * self.n_out], + remove_diagonal=False, + non_negative=self.use_non_negative_weights, + use_horseshoe_prior=True, + ) + else: + pathway_tf_weights = self.get_tf_effect( + x=bound_receptor_abundance_src, + name=name, + layer=layer, + weights_shape=[self.n_pathways] if self.n_out == 1 else [self.n_pathways * self.n_out], + remove_diagonal=False, + non_negative=self.use_non_negative_weights, + use_horseshoe_prior=True, + ) + # normalise by sqrt of the number of predictors + pathway_tf_weights = pathway_tf_weights / torch.sqrt( + torch.tensor( + float(self.n_pathways), + device=pathway_tf_weights.device, + ) + ) + if use_cell_abundance_model: + effect_on_pathway_activity = rearrange( + effect_on_pathway_activity, + "(c h) p -> c h p", + h=self.n_tfs, + ) + if self.n_out == 1: + if use_cell_abundance_model: + if self.use_global_cell_abundance_model: + effect_on_tf_activity = torch.einsum("chp,p->ch", effect_on_pathway_activity, pathway_tf_weights) + else: + effect_on_tf_activity = torch.einsum("chp,hp->ch", effect_on_pathway_activity, pathway_tf_weights) + else: 
+ effect_on_tf_activity = torch.einsum("cp,hp->ch", effect_on_pathway_activity, pathway_tf_weights) + else: + if self.use_global_cell_abundance_model: + pathway_tf_weights = rearrange(pathway_tf_weights, "(p f) -> p f", p=self.n_pathways, f=self.n_out) + else: + pathway_tf_weights = rearrange(pathway_tf_weights, "h (p f) -> h p f", p=self.n_pathways, f=self.n_out) + if use_cell_abundance_model: + if self.use_global_cell_abundance_model: + effect_on_tf_activity = torch.einsum("chp,pf->fch", effect_on_pathway_activity, pathway_tf_weights) + else: + effect_on_tf_activity = torch.einsum("chp,hpf->fch", effect_on_pathway_activity, pathway_tf_weights) + else: + effect_on_tf_activity = torch.einsum("cp,hpf->fch", effect_on_pathway_activity, pathway_tf_weights) + # including pathway interactions + if self.use_pathway_interaction_effect: + name = "pathway_interaction_tf_weights" + if not (use_cell_abundance_model and self.use_global_cell_abundance_model): + pathway_tf_weights = self.get_tf_effect( + x=bound_receptor_abundance_src, + name=name, + layer=layer, + weights_shape=[self.n_tfs, self.n_pathways * self.n_pathways] + if self.n_out == 1 + else [self.n_tfs, self.n_pathways * self.n_pathways * self.n_out], + remove_diagonal=False, + non_negative=self.use_non_negative_weights, + use_horseshoe_prior=True, + ) + else: + pathway_tf_weights = self.get_tf_effect( + x=bound_receptor_abundance_src, + name=name, + layer=layer, + weights_shape=[self.n_pathways * self.n_pathways] + if self.n_out == 1 + else [self.n_pathways * self.n_pathways * self.n_out], + remove_diagonal=False, + non_negative=self.use_non_negative_weights, + use_horseshoe_prior=True, + ) + if self.n_out == 1: + if self.use_global_cell_abundance_model: + pathway_tf_weights = rearrange( + pathway_tf_weights, + "(o p) -> o p", + p=self.n_pathways, + o=self.n_pathways, + ) + pathway_tf_weights, zero_diag, triu = self.apply_upper_triangle_diag(pathway_tf_weights, name) + else: + pathway_tf_weights = rearrange( + pathway_tf_weights, + "h (o p) -> h o p", + p=self.n_pathways, + o=self.n_pathways, + ) + pathway_tf_weights, zero_diag, triu = self.apply_upper_triangle_diag(pathway_tf_weights, name) + # normalise by sqrt of the number of predictors + pathway_tf_weights = pathway_tf_weights / torch.sqrt((zero_diag * triu).sum()) + if use_cell_abundance_model: + if self.use_global_cell_abundance_model: + effect_on_tf_activity = effect_on_tf_activity + torch.einsum( + "chp,op,cho->ch", + effect_on_pathway_activity, + pathway_tf_weights, + effect_on_pathway_activity, + ) + else: + effect_on_tf_activity = effect_on_tf_activity + torch.einsum( + "chp,hop,cho->ch", + effect_on_pathway_activity, + pathway_tf_weights, + effect_on_pathway_activity, + ) + else: + effect_on_tf_activity = effect_on_tf_activity + torch.einsum( + "cp,hop,co->ch", + effect_on_pathway_activity, + pathway_tf_weights, + effect_on_pathway_activity, + ) + else: + if self.use_global_cell_abundance_model: + pathway_tf_weights = rearrange( + pathway_tf_weights, + "(o p f) -> o p f", + p=self.n_pathways, + o=self.n_pathways, + f=self.n_out, + ) + pathway_tf_weights, zero_diag, triu = self.apply_upper_triangle_diag( + pathway_tf_weights, + name, + n_out=self.n_out, + ) + else: + pathway_tf_weights = rearrange( + pathway_tf_weights, + "h (o p f) -> h o p f", + p=self.n_pathways, + o=self.n_pathways, + f=self.n_out, + ) + pathway_tf_weights, zero_diag, triu = self.apply_upper_triangle_diag( + pathway_tf_weights, + name, + n_out=self.n_out, + ) + # normalise by sqrt of the number of 
predictors + pathway_tf_weights = pathway_tf_weights / torch.sqrt((zero_diag * triu).sum()) + if use_cell_abundance_model: + if self.use_global_cell_abundance_model: + effect_on_tf_activity = torch.einsum( + "chp,opf,cho->fch", + effect_on_pathway_activity, + pathway_tf_weights, + effect_on_pathway_activity, + ) + else: + effect_on_tf_activity = torch.einsum( + "chp,hopf,cho->fch", + effect_on_pathway_activity, + pathway_tf_weights, + effect_on_pathway_activity, + ) + else: + effect_on_tf_activity = torch.einsum( + "cp,hopf,co->fch", + effect_on_pathway_activity, + pathway_tf_weights, + effect_on_pathway_activity, + ) + + # print(f"effect_on_tf_activity pathway {self.name} mean ", effect_on_tf_activity.mean()) + # print(f"effect_on_tf_activity pathway {self.name} min", effect_on_tf_activity.min()) + # print(f"effect_on_tf_activity pathway {self.name} max", effect_on_tf_activity.max()) + + return effect_on_tf_activity + + def signal_receptor_occupancy( + self, + signal_abundance: torch.Tensor, + receptor_abundance: torch.Tensor, + distances: torch.Tensor = None, + skip_distance_effect: bool = False, + ): + layer = 0 + # optionally apply dropout ========== + if self.dropout_rate > 0.0: + if getattr(self.weights, f"{self.name}_layer_{layer}_dropout", None) is None: + deep_setattr( + self.weights, + f"{self.name}_layer_{layer}_dropout", + nn.Dropout(p=self.dropout_rate), + ) + dropout = deep_getattr(self.weights, f"{self.name}_layer_{layer}_dropout") + signal_abundance = dropout(signal_abundance) + receptor_abundance = dropout(receptor_abundance) + + # 1. Signal RNA -> signal protein conversion using distance function ============ + # a_{s, b} = f(signal_protein_features, distance_between_bins) + # w_{c,s} = sum_b w_{c,s,b} * a_{s, b} + if not skip_distance_effect: + signal_distance_effect_sb = self.inverse_sigmoid_signal_distance_function( + distances, + layer=layer, + ) + signal_abundance = torch.einsum("csb,sb->cs", signal_abundance, signal_distance_effect_sb) + + # 2. Computing bound receptor concentrations using learnable a_{r,s} affinity ============ + # a_{r,s} = f(receptor_features, signal_features) + sig_rec_affinity_rs = self.get_signal_receptor_effect( + x=signal_abundance, + layer=layer, + weights_shape=[len(self.signal_receptor_mask_scipy.data)], + ) + + # x_{c,r,s} = w_{c,s} * a_{r,s} + pair2signal = one_hot( + torch.tensor(self.signal_receptor_mask_scipy.col, device=signal_abundance.device).long().unsqueeze(-1), + self.signal_receptor_mask_scipy.shape[1], + ) + affinity_to_receptor_src = signal_abundance.T[ + self.signal_receptor_mask_scipy.row, : + ] * sig_rec_affinity_rs.unsqueeze(-1) + affinity_to_receptor_rc_sum = torch.mm(pair2signal.T, affinity_to_receptor_src) # ps,pc -> sc + # affinity_to_receptor_crs = torch.einsum( + # "cs,rs->crs", signal_abundance, sig_rec_affinity_rs + # ) + # Compute bound receptor abundance + # bound = total * proportion_of_signal_with_affinity + # TODO - which unbound term to use? 
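+        # Illustrative numbers for the competition below (using the first form,
+        # which adds a constant 1 to the denominator): if two signal-receptor pairs
+        # share one receptor with affinity-weighted signals 2 and 3, and the
+        # receptor abundance w_{c,r} is 4, the bound amounts are
+        # 4 * 2 / (2 + 3 + 1) = 1.33 and 4 * 3 / (2 + 3 + 1) = 2.0,
+        # i.e. signals compete for the same receptor pool.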
+ # x_{c,r,s} = w_{c,r} * (x_{c,r,s} / (sum_s x_{c,r,s} + unbound_r)) + # x_{c,r,s} = w_{c,r} * (x_{c,r,s} / (sum_s x_{c,r,s} + unbound_r * w_{c,r})) + unbound_r = self.get_signal_distance_effect( + x=signal_abundance, + layer=layer, + name="unbound_r", + weights_shape=[self.n_receptors], + ) + unbound_r = nn.functional.softplus( + unbound_r / torch.tensor(5.0, device=unbound_r.device) - torch.tensor(2.0, device=unbound_r.device) + ) + if not self.use_unbound_concentration: + proportion_of_signal_with_affinity_src = affinity_to_receptor_src / ( + affinity_to_receptor_rc_sum[self.signal_receptor_mask_scipy.col, :] + # + unbound_r[self.signal_receptor_mask_scipy.col].unsqueeze(-1) + + torch.tensor(1.0, device=signal_abundance.device) + ) + else: + proportion_of_signal_with_affinity_src = affinity_to_receptor_src / ( + affinity_to_receptor_rc_sum[self.signal_receptor_mask_scipy.col, :] + + unbound_r[self.signal_receptor_mask_scipy.col].unsqueeze(-1) + * receptor_abundance.T[self.signal_receptor_mask_scipy.col, :] + ) + bound_receptor_abundance_src = ( + proportion_of_signal_with_affinity_src * receptor_abundance.T[self.signal_receptor_mask_scipy.col, :] + ) + # bound_receptor_abundance_crs = torch.einsum( + # "cr,crs->crs", receptor_abundance, proportion_of_signal_with_affinity_src + # ) + # optionally apply dropout + # if self.dropout_rate > 0: + # bound_receptor_abundance_crs = dropout(bound_receptor_abundance_crs) + + return bound_receptor_abundance_src + + def diffusion_domain_function( + self, + signal_abundance: torch.Tensor, + w_sf: torch.Tensor, + ): + # Low dimensional diffusion limiter - for every signal limit where it can diffuse. + # Maybe this can lead to more reasonable distributions + # without requiring suppressing effects to get rid of the signal. 
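+        # Sketch of the factorisation implemented below (shapes inferred from the
+        # code): cell abundances w_sf (locations x cell types) are projected onto
+        # n_spatial_domains via y_fq ~ Gamma(1, 1) and normalised per location;
+        # each domain then gates every signal via y_qs ~ Beta(100, 1) (close to 1,
+        # so diffusion is only suppressed where the data require it), giving
+        # x_cs = normalise(w_sf @ y_fq) @ y_qs, which scales signal_abundance.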
+ n_signals = signal_abundance.shape[-1] + n_cell_types = w_sf.shape[-1] + + name = "diffusion_domain_function" + # x_cs = sum_q y_cq * y_qs + # y_qs ~ Beta(100, 1) + # y_cq = y_cq / sum_q y_cq + # y_cq = sum_f w_cf * y_fq # maybe w_cf is lvl3 but better lvl5 + # y_fq ~ Gamma(1, 1) + y_fq = self.get_dist_prior( + layer="", + name=f"{name}_y_fq", + weights_shape=[n_cell_types, self.n_spatial_domains], + prior_alpha=1.0, + prior_beta=1.0, + prior_fun=pyro.distributions.Gamma, + ) + y_cq = torch.einsum("cf,fq->cq", w_sf, y_fq) + y_cq = y_cq / y_cq.sum(dim=-1, keepdim=True) + y_qs = self.get_dist_prior( + layer="", + name=f"{name}_y_qs", + weights_shape=[self.n_spatial_domains, n_signals], + prior_alpha=100.0, + prior_beta=1.0, + prior_fun=pyro.distributions.Beta, + ) + x_cs = torch.einsum("cq,qs->cs", y_cq, y_qs) + signal_abundance = signal_abundance * x_cs + return signal_abundance + + def signal_receptor_occupancy_spatial( + self, + signal_abundance: torch.Tensor, + receptor_abundance: torch.Tensor, + distances: torch.Tensor = None, + tiles: torch.Tensor = None, + obs_plate=None, + obs_in_use=None, + w_sf: torch.Tensor = None, + use_diffusion_domain: bool = False, + max_distance_threshold: float = None, + ): + n_locations = signal_abundance.shape[-2] + n_signals = signal_abundance.shape[-1] + n_receptors = receptor_abundance.shape[-1] + n_cell_types = receptor_abundance.shape[-2] + + if distances.is_sparse: + if tiles is not None: + raise ValueError("tiles should be None when using sparse distances") + # with obs_plate as ind: + # pass + # indices0 = distances.coalesce().indices()[0, :] + if not self.training: + # make sure that the indices are correct + with obs_plate as ind: + assert torch.allclose( + ind, torch.arange(n_locations, device=ind.device) + ), "indices in obs_plate do not match the unshuffled order of locations" + # assert torch.allclose(distances.coalesce()[ind, ind], distances), \ + # 'indices in obs_plate do not match the unshuffled order of locations' + indices1 = distances.coalesce().indices()[1, :] + distances_ = distances.coalesce().values().float() + # indices = torch.logical_or(torch.isin(indices0, ind), torch.isin(indices1, ind)) + # indices0 = indices0[indices] + # use 1d tensor with propper indices mapping here to make sure that the indices are correct + # indices1 = indices1[indices] + # distances_ = distances_[indices] + + # 1. Signal RNA -> signal protein conversion using distance function ============ + signal_distance_effect_ss_b = self.inverse_sigmoid_signal_distance_function( + distances_, + layer="0", + name="signal_distance_spatial_", + ).T + target2row = one_hot( + torch.as_tensor(indices1, device=signal_abundance.device).long().unsqueeze(-1), + distances.shape[1], + ).T.float() + signal_abundance = torch.mm( + target2row, + signal_distance_effect_ss_b * signal_abundance[indices1, :], # target s to row # row to signal + ) + else: + # 1. 
Signal RNA -> signal protein conversion using distance function ============ + signal_distance_effect_ss_b = self.inverse_sigmoid_signal_distance_function( + distances, + layer="0", + name="signal_distance_spatial_", + ) + if tiles is not None: + tiles_mask = tiles @ tiles.T + signal_distance_effect_ss_b = torch.einsum("sop,op->sop", signal_distance_effect_ss_b, tiles_mask) + if max_distance_threshold is not None: + signal_distance_effect_ss_b = torch.einsum( + "sop,op->sop", + signal_distance_effect_ss_b, + (distances < torch.tensor(max_distance_threshold, device=distances.device)).float(), + ) + signal_abundance = torch.einsum( + "ps,sop->os", + signal_abundance, + signal_distance_effect_ss_b, + ) + + if use_diffusion_domain: + signal_abundance = self.diffusion_domain_function( + signal_abundance=signal_abundance, + w_sf=w_sf, + ) + + if not self.training: + with obs_plate: + if obs_in_use is not None: + pyro.deterministic( + "signal_abundance_local", + signal_abundance[obs_in_use, :], + ) + else: + pyro.deterministic( + "signal_abundance_local", + signal_abundance, + ) + + # 2. Computing bound receptor concentrations using learnable a_{r,s} affinity ============ + # first reshape inputs to be locations * cell type specific + # d_{c,s} -> d_{c,f,s} + signal_abundance = signal_abundance.unsqueeze(-2).expand([n_locations, n_cell_types, n_signals]) + signal_abundance = rearrange(signal_abundance, "c f s -> (c f) s", f=n_cell_types) + # g_{f,r} -> g_{c,f,r} + if receptor_abundance.dim() == 2: + receptor_abundance = receptor_abundance.unsqueeze(-3).expand([n_locations, n_cell_types, n_receptors]) + receptor_abundance = rearrange(receptor_abundance, "c f r -> (c f) r", f=n_cell_types) + + bound_receptor_abundance_src = self.signal_receptor_occupancy( + signal_abundance=signal_abundance, + receptor_abundance=receptor_abundance, + distances=distances, + skip_distance_effect=True, + ) + return bound_receptor_abundance_src + + def signal_receptor_tf_effect_spatial( + self, + bound_receptor_abundance_src: torch.Tensor, + ): + return self.signal_receptor_tf_effect( + bound_receptor_abundance_src=bound_receptor_abundance_src, + use_cell_abundance_model=True, + ) diff --git a/cell2location/nn/_mixins.py b/cell2location/nn/_mixins.py new file mode 100755 index 00000000..96153d6d --- /dev/null +++ b/cell2location/nn/_mixins.py @@ -0,0 +1,351 @@ +import pyro +import pyro.distributions as dist +import torch +from pyro.infer.autoguide.utils import deep_getattr, deep_setattr +from pyro.nn import PyroParam, PyroSample +from torch import nn as nn +from torch.distributions import constraints + + +class CreateParameterMixin: + def create_horseshoe_prior( + self, + name, + weights_shape, + weights_prior_scale=None, + weights_prior_tau=None, + scale_distribution=dist.HalfNormal, # TODO figure out which distribution to use HalfCauchy has mean=Inf so can't use it + ): + # Create scalar tau (like sd for horseshoe prior) ===================== + tau_name = f"{name}tau" + if getattr(self.weights, tau_name, None) is None: + if weights_prior_tau is None: + weights_prior_tau = self.weights_prior_tau + if getattr(self, f"{tau_name}_scale", None) is None: + self.register_buffer(f"{tau_name}_scale", weights_prior_tau) + deep_setattr( + self.weights, + tau_name, + PyroSample( + lambda prior: scale_distribution( + getattr(self, f"{tau_name}_scale"), + ) + .expand([1]) + .to_event(1), + ), + ) + tau = deep_getattr(self.weights, tau_name) + + # Create weights (like mean for horseshoe prior) ===================== + weights_name = 
f"{name}weights" + if getattr(self.weights, weights_name, None) is None: + deep_setattr( + self.weights, + weights_name, + PyroSample( + lambda prior: dist.Normal( + self.zeros, + self.ones, + ) + .expand(weights_shape) + .to_event(len(weights_shape)), + ), + ) + unscaled_weights = deep_getattr(self.weights, weights_name) + + if getattr(self, "use_gamma_horseshoe_prior", False): + # Create elementwise lambdas using Gamma distribution (like sd for horseshoe prior) ===================== + lambdas_name = f"{name}lambdas" + if getattr(self.weights, lambdas_name, None) is None: + if weights_prior_scale is None: + weights_prior_scale = self.weights_prior_scale + if getattr(self, f"{lambdas_name}_scale", None) is None: + self.register_buffer(f"{lambdas_name}_scale", weights_prior_scale) + deep_setattr( + self.weights, + lambdas_name, + PyroSample( + lambda prior: dist.Gamma( + tau, + getattr(self, f"{lambdas_name}_scale"), + ) + .expand(weights_shape) + .to_event(len(weights_shape)), + ), + ) + lambdas = deep_getattr(self.weights, lambdas_name) + else: + # Create elementwise lambdas (like sd for horseshoe prior) ===================== + lambdas_name = f"{name}lambdas" + if getattr(self.weights, lambdas_name, None) is None: + if weights_prior_scale is None: + weights_prior_scale = self.weights_prior_scale + if getattr(self, f"{lambdas_name}_scale", None) is None: + self.register_buffer(f"{lambdas_name}_scale", weights_prior_scale) + deep_setattr( + self.weights, + lambdas_name, + PyroSample( + lambda prior: scale_distribution( + getattr(self, f"{lambdas_name}_scale"), + ) + .expand(weights_shape) + .to_event(len(weights_shape)), + ), + ) + lambdas = deep_getattr(self.weights, lambdas_name) + lambdas = tau * lambdas + + weights = lambdas * unscaled_weights + if not self.training: + pyro.deterministic(f"{self.name_prefix}{name}", weights) + return weights + + def get_param( + self, + x: torch.Tensor, + name: str, + layer, + weights_shape: list, + random_init_scale: float = 1.0, + bayesian: bool = True, + use_non_negative_weights: bool = False, + bias_shape: list = [1], + skip_name: bool = False, + weights_prior_mean: torch.Tensor = None, + weights_prior_scale: torch.Tensor = None, + weights_prior_shape: torch.Tensor = None, + weights_prior_rate: torch.Tensor = None, + weights_prior_tau: torch.Tensor = None, + return_bias: bool = False, + use_horseshoe_prior: bool = False, + ): + # generate parameter names ========== + if skip_name: + weights_name = f"{name}_layer_{layer}_weights" + bias_name = f"{name}_layer_{layer}_bias" + else: + weights_name = f"{self.name}_{name}_layer_{layer}_weights" + bias_name = f"{self.name}_{name}_layer_{layer}_bias" + + # create parameters ========== + if not use_horseshoe_prior: + # register priors ========== + if weights_prior_mean is None: + weights_prior_mean = self.zeros + if getattr(self, f"{weights_name}_mean", None) is None: + self.register_buffer(f"{weights_name}_mean", weights_prior_mean) + if weights_prior_scale is None: + weights_prior_scale = self.weights_prior_scale + if getattr(self, f"{weights_name}_scale", None) is None: + self.register_buffer(f"{weights_name}_scale", weights_prior_scale) + if weights_prior_shape is None: + weights_prior_shape = self.weights_prior_shape + if getattr(self, f"{weights_name}_shape", None) is None: + self.register_buffer(f"{weights_name}_shape", weights_prior_shape) + if weights_prior_rate is None: + weights_prior_rate = self.ones + if getattr(self, f"{weights_name}_rate", None) is None: + 
self.register_buffer(f"{weights_name}_rate", weights_prior_rate) + # create parameters ========== + if getattr(self.weights, weights_name, None) is None: + if bayesian: + # generate bayesian variables + if use_non_negative_weights: + # define Gamma distributed weights and Normal bias + # positive effect of input on output + deep_setattr( + self.weights, + weights_name, + PyroSample( + lambda prior: dist.Gamma( + getattr(self, f"{weights_name}_shape"), + getattr(self, f"{weights_name}_rate"), + ) + .expand(weights_shape) + .to_event(len(weights_shape)) + ), + ) + else: + deep_setattr( + self.weights, + weights_name, + PyroSample( + lambda prior: dist.SoftLaplace( + getattr(self, f"{weights_name}_mean"), + getattr(self, f"{weights_name}_scale"), + ) + .expand(weights_shape) + .to_event(len(weights_shape)), + ), + ) + if return_bias: + # bias allows requiring signal from more than one input + deep_setattr( + self.weights, + bias_name, + PyroSample( + lambda prior: dist.Normal( + self.bias_mean_prior, + self.ones * self.bias_sigma_prior, + ) + .expand(bias_shape) + .to_event(len(bias_shape)), + ), + ) + else: + if use_non_negative_weights: + # initialise weights + init_param = torch.normal( + torch.full( + size=weights_shape, + fill_value=0.0, + device=x.device, + ), + torch.full( + size=weights_shape, + fill_value=random_init_scale, + device=x.device, + ), + ).abs() + deep_setattr( + self.weights, + weights_name, + PyroParam( + init_param.clone().detach().requires_grad_(True), + constraint=constraints.positive, + ), + ) + else: + # initialise weights + init_param = torch.normal( + torch.full( + size=weights_shape, + fill_value=0.0, + device=x.device, + ), + torch.full( + size=weights_shape, + fill_value=random_init_scale, + device=x.device, + ), + ) + deep_setattr( + self.weights, + weights_name, + PyroParam(init_param.clone().detach().requires_grad_(True)), + ) + if return_bias: + init_param = torch.normal( + torch.full( + size=bias_shape, + fill_value=0.0, + device=x.device, + ), + torch.full( + size=bias_shape, + fill_value=random_init_scale, + device=x.device, + ), + ) + deep_setattr( + self.weights, + bias_name, + PyroParam(init_param.clone().detach().requires_grad_(True)), + ) + # extract parameters ========== + weights = deep_getattr(self.weights, weights_name) + if return_bias: + bias = deep_getattr(self.weights, bias_name) + return weights, bias + return weights + else: + # create and extract parameters ========== + return self.create_horseshoe_prior( + name=weights_name, + weights_shape=weights_shape, + weights_prior_scale=weights_prior_scale, + weights_prior_tau=weights_prior_tau, + ) + + def get_layernorm(self, name, layer, norm_shape): + if getattr(self.weights, f"{self.name}_{name}_layer_{layer}_layer_norm", None) is None: + deep_setattr( + self.weights, + f"{self.name}_{name}_layer_{layer}_layer_norm", + nn.LayerNorm(norm_shape, elementwise_affine=False), + ) + layer_norm = deep_getattr(self.weights, f"{self.name}_{name}_layer_{layer}_layer_norm") + return layer_norm + + def get_activation(self, name, layer): + if getattr(self.weights, f"{self.name}_{name}_layer_{layer}_activation_fn", None) is None: + deep_setattr( + self.weights, + f"{self.name}_{name}_layer_{layer}_activation_fn", + self.activation_fn(), + ) + activation_fn = deep_getattr(self.weights, f"{self.name}_{name}_layer_{layer}_activation_fn") + return activation_fn + + def get_pool(self, name, layer, kernel_size, pool_class=torch.nn.MaxPool2d): + if getattr(self.weights, f"{self.name}_{name}_layer_{layer}_Pool", 
None) is None: + deep_setattr( + self.weights, + f"{self.name}_{name}_layer_{layer}_Pool", + pool_class(kernel_size), + ) + max_pool = deep_getattr(self.weights, f"{self.name}_{name}_layer_{layer}_Pool") + return max_pool + + def get_nn_weight(self, weights_name, weights_shape): + if not hasattr(self.weights, weights_name): + deep_setattr( + self.weights, + weights_name, + pyro.nn.PyroSample( + lambda prior: dist.SoftLaplace( + self.zeros, + self.ones, + ) + .expand(weights_shape) + .to_event(len(weights_shape)), + ), + ) + return deep_getattr(self.weights, weights_name) + + def get_nn_bias(self, bias_name, bias_shape): + if not hasattr(self.weights, bias_name): + deep_setattr( + self.weights, + bias_name, + pyro.nn.PyroSample( + lambda prior: dist.SoftLaplace( + self.zeros, + self.ones, + ) + .expand(bias_shape) + .to_event(len(bias_shape)), + ), + ) + return deep_getattr(self.weights, bias_name) + + def get_nn_layernorm(self, name, layer, norm_shape): + if getattr(self.weights, f"{name}_layer_{layer}_layer_norm", None) is None: + deep_setattr( + self.weights, + f"{name}_layer_{layer}_layer_norm", + torch.nn.LayerNorm(norm_shape, elementwise_affine=False), + ) + layer_norm = deep_getattr(self.weights, f"{name}_layer_{layer}_layer_norm") + return layer_norm + + def get_nn_activation(self, name, layer): + if getattr(self.weights, f"{name}_layer_{layer}_activation_fn", None) is None: + deep_setattr( + self.weights, + f"{name}_layer_{layer}_activation_fn", + torch.nn.Softplus(), + ) + activation_fn = deep_getattr(self.weights, f"{name}_layer_{layer}_activation_fn") + return activation_fn diff --git a/cell2location/nn/fclayers.py b/cell2location/nn/fclayers.py index 0c2e6290..ef3f51c4 100755 --- a/cell2location/nn/fclayers.py +++ b/cell2location/nn/fclayers.py @@ -2,7 +2,7 @@ from typing import Iterable import torch -from scvi.nn._utils import one_hot +from scvi.nn import one_hot from torch import nn as nn diff --git a/cell2location/nn/fclayers_context.py b/cell2location/nn/fclayers_context.py index 11e4771c..23dc3ab5 100755 --- a/cell2location/nn/fclayers_context.py +++ b/cell2location/nn/fclayers_context.py @@ -2,7 +2,7 @@ from typing import Iterable import torch -from scvi.nn._utils import one_hot +from scvi.nn import one_hot from torch import nn as nn diff --git a/cell2location/plt/plot_heatmap.py b/cell2location/plt/plot_heatmap.py index a632f33f..3dd9c698 100644 --- a/cell2location/plt/plot_heatmap.py +++ b/cell2location/plt/plot_heatmap.py @@ -19,6 +19,7 @@ def heatmap( title="", vmin=None, vmax=None, + vcenter=None, ): r"""Plot heatmap with row and column labels using plt.imshow @@ -38,9 +39,13 @@ def heatmap( array = np.array(array) if log: - plt.imshow(array, interpolation="nearest", cmap=cmap, norm=matplotlib.colors.LogNorm(vmin=vmin, vmax=vmax)) + norm = matplotlib.colors.LogNorm(vmin=vmin, vmax=vmax) else: - plt.imshow(array, interpolation="nearest", cmap=cmap) + if vcenter is None: + norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax) + else: + norm = matplotlib.colors.TwoSlopeNorm(vmin=vmin, vcenter=vcenter, vmax=vmax) + plt.imshow(array, interpolation="nearest", cmap=cmap, norm=norm) if cbar is True: plt.colorbar() @@ -163,6 +168,7 @@ def clustermap( array_size=None, vmin=None, vmax=None, + vcenter=None, ): r"""Plot heatmap with hierarchically clustered rows and columns using `cell2location.plt.plot_heatmap.heatmap()` and `cell2location.plt.plot_heatmap.dotplot()`. 
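The `vcenter` argument added above selects `matplotlib.colors.TwoSlopeNorm`, which is useful for signed matrices where zero should sit in the middle of a diverging colormap. A minimal sketch of the behaviour this enables (the `heatmap(array, ..., vcenter=0)` call is the intended entry point; the snippet below only reproduces the norm selection with standard matplotlib, and the toy data are assumptions):

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

# Toy signed matrix, e.g. effect estimates that can be positive or negative.
effects = np.random.default_rng(0).normal(0, 1, size=(10, 8))

# Mirrors the branch added in heatmap(): with vcenter set, values are mapped
# through a diverging TwoSlopeNorm so that 0 lands at the centre of the colormap.
norm = matplotlib.colors.TwoSlopeNorm(vmin=effects.min(), vcenter=0.0, vmax=effects.max())
plt.imshow(effects, interpolation="nearest", cmap="RdBu_r", norm=norm)
plt.colorbar()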
@@ -216,6 +222,7 @@ def clustermap( title=title, vmin=vmin, vmax=vmax, + vcenter=vcenter, ) elif fun_type == "dotplot": # plot dotplot diff --git a/cell2location/utils/__init__.py b/cell2location/utils/__init__.py index 7c7bccd8..9df3f330 100755 --- a/cell2location/utils/__init__.py +++ b/cell2location/utils/__init__.py @@ -2,7 +2,12 @@ import numpy as np -from ._spatial_knn import spatial_neighbours, sum_neighbours +from ._spatial_knn import ( + from_c2l_get_lr_abundance, + make_spatial_neighbours, + spatial_neighbours, + sum_neighbours, +) from .filtering import filter_genes @@ -14,11 +19,15 @@ def select_slide(adata, s, batch_key="sample"): :param batch_key: column in adata.obs listing experiment name for each location """ - slide = adata[adata.obs[batch_key].isin([s]), :].copy() + slide = adata[adata.obs[batch_key].isin([s]), :] s_keys = list(slide.uns["spatial"].keys()) s_spatial = np.array(s_keys)[[s in k for k in s_keys]][0] - slide.uns["spatial"] = {s_spatial: slide.uns["spatial"][s_spatial]} + spatial = {s_spatial: slide.uns["spatial"][s_spatial]} + del slide.uns["spatial"] + + slide = slide.copy() + slide.uns["spatial"] = spatial return slide @@ -46,4 +55,6 @@ def list_imported_modules(): "spatial_neighbours", "sum_neighbours", "list_imported_modules", + "make_spatial_neighbours", + "from_c2l_get_lr_abundance", ] diff --git a/cell2location/utils/_spatial_knn.py b/cell2location/utils/_spatial_knn.py index 7b95f955..082f06f9 100644 --- a/cell2location/utils/_spatial_knn.py +++ b/cell2location/utils/_spatial_knn.py @@ -1,10 +1,116 @@ import numpy as np +import pandas as pd +import scanpy as sc from scipy.sparse import coo_matrix from scipy.spatial import cKDTree from sklearn.neighbors import KDTree from umap.umap_ import fuzzy_simplicial_set +def from_c2l_get_lr_abundance( + adata_vis, + cell_state, + signal_bool, + receptor_bool, + receptor_bool_b, + signal_receptor_mask, + top_n: int = 20, + scale_receptor_abundance_by_m_g: bool = False, + post_sample_name: str = "post_sample_q05", + use_normalisation_by_y_s: bool = False, + use_normalisation_by_total: bool = False, + use_normalisation_per_signal: bool = False, + use_normalisation_per_receptor: bool = False, +): + if np.all(adata_vis.obs_names == adata_vis.uns["mod"]["obs_names"]): + obs_bool = np.ones_like(adata_vis.obs_names, dtype=bool) + obs_names = adata_vis.obs_names + else: + obs_names = np.intersect1d(adata_vis.obs_names, adata_vis.uns["mod"]["obs_names"]) + obs_bool = np.isin(adata_vis.uns["mod"]["obs_names"], obs_names) + assert np.all(adata_vis.uns["mod"]["obs_names"][obs_bool] == obs_names) + assert np.all(adata_vis.var_names.values.astype("str") == adata_vis.uns["mod"]["var_names"].astype("str")) + var_names = adata_vis.uns["mod"]["var_names"] + adata_vis = adata_vis[obs_names] + adata_vis = adata_vis[:, var_names] + cell_state = cell_state.loc[var_names, :] + adata_vis.obsm["w_sf"] = adata_vis.uns["mod"][post_sample_name]["w_sf"][obs_bool, :] + # normalisation + if use_normalisation_by_y_s: + normalisation = adata_vis.uns["mod"][post_sample_name]["detection_y_s"][obs_bool, :] + elif use_normalisation_by_total: + normalisation = np.asarray(adata_vis.X.sum(1)) / 10000.0 + else: + normalisation = 1.0 + adata_vis.obsm["signal_abundance"] = adata_vis.X[:, signal_bool].toarray() / normalisation + # normalise signal abundance + if use_normalisation_per_signal: + for i in range(adata_vis.obsm["signal_abundance"].shape[1]): + top_n_vals = np.sort(adata_vis.obsm["signal_abundance"][:, i])[::-1][:top_n] + 
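            # Scale each signal by the mean of its top_n largest values, so signal
+            # abundances are roughly on a 0-1 scale and robust to single outliers.
+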
adata_vis.obsm["signal_abundance"][:, i] = adata_vis.obsm["signal_abundance"][:, i] / top_n_vals.mean() + adata_vis.obsm["signal_abundance"][np.isnan(adata_vis.obsm["signal_abundance"])] = 0.0 + adata_vis.obsm["signal_abundance"] = pd.DataFrame( + adata_vis.obsm["signal_abundance"], + index=adata_vis.obs_names, + columns=signal_receptor_mask.index, + ) + if scale_receptor_abundance_by_m_g: + m_g = adata_vis.uns["mod"][post_sample_name]["m_g"].T + cell_state = cell_state * m_g + # get minimum of the two receptor subunits + receptor_abundance = np.minimum( + (cell_state).iloc[receptor_bool, :].values, (cell_state).iloc[receptor_bool_b, :].values + ) + receptor_abundance = pd.DataFrame( + receptor_abundance, + index=signal_receptor_mask.columns, + columns=adata_vis.uns["mod"]["factor_names"], + ) + # normalise receptor abundance + if not scale_receptor_abundance_by_m_g and use_normalisation_per_receptor: + receptor_abundance = (receptor_abundance.T / receptor_abundance.max(1)).T + + per_cell_type_normalisation = ( + 1.0 + / np.array( + [np.sort(adata_vis.obsm["w_sf"][:, i])[::-1][:top_n].mean() for i in range(adata_vis.obsm["w_sf"].shape[1])] + ) + ).astype("float32") + + return adata_vis, receptor_abundance, per_cell_type_normalisation + + +def get_lr_abundance(cell_state, d_sg, m_g, y_s, signal_bool, receptor_bool, receptor_bool_b): + # get lr abundance + signal_abundance = d_sg[:, signal_bool] / y_s + receptor_abundance = np.minimum((cell_state * m_g)[:, receptor_bool], (cell_state * m_g)[:, receptor_bool_b]) + + return signal_abundance, receptor_abundance + + +def make_spatial_neighbours( + adata_vis, + batch_key: str = "sample", + spatial_key: str = "spatial", + n_neighbors: int = 200, +): + # compute KNN using the coordinates stored in adata.obsm + sc.pp.neighbors(adata_vis, use_rep=spatial_key, metric="euclidean", n_neighbors=n_neighbors) + + from scipy.sparse import csr_matrix + + batch_id = csr_matrix(pd.get_dummies(adata_vis.obs[batch_key])) + batch_id = batch_id @ batch_id.T + + adata_vis.obsp["distances"] = csr_matrix( + adata_vis.obsp["distances"].astype("float32").multiply(batch_id.astype("float32")) + ) + adata_vis.obsp["connectivities"] = csr_matrix( + adata_vis.obsp["connectivities"].astype("float32").multiply(batch_id.astype("float32")) + ) + return adata_vis + + def get_sparse_matrix_from_indices_distances_umap(knn_indices, knn_dists, n_obs, n_neighbors): """ Copied out of scanpy.neighbors diff --git a/setup.cfg b/setup.cfg index 699326b7..5d7d6c19 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = cell2location -version = 0.1.4 +version = 0.1.5 description = cell2location: High-throughput spatial mapping of cell types long_description = file: README.md long_description_content_type = text/markdown @@ -22,6 +22,7 @@ install_requires = pandas scanpy opencv-python + einops [options.extras_require] dev = diff --git a/tests/test_cell2location.py b/tests/test_cell2location.py index f762c3f7..af70b1dd 100644 --- a/tests/test_cell2location.py +++ b/tests/test_cell2location.py @@ -1,4 +1,5 @@ import numpy as np +import pytest import torch from pyro.infer.autoguide import AutoHierarchicalNormalMessenger from scvi.data import synthetic_iid @@ -18,10 +19,20 @@ def export_posterior(model, dataset): dataset = model.export_posterior(dataset, use_quantiles=True, add_to_obsm=["q50", "q001"]) # quantile 0.50 dataset = model.export_posterior( - dataset, use_quantiles=True, add_to_obsm=["q50"], sample_kwargs={"batch_size": 10} + dataset, + use_quantiles=True, + 
add_to_obsm=["q50"], + sample_kwargs={"batch_size": 10}, + ) # quantile 0.50 + dataset = model.export_posterior( + dataset, + use_quantiles=True, + add_to_obsm=["q50"], + sample_kwargs={"batch_size": 10, "use_median": True}, ) # quantile 0.50 dataset = model.export_posterior(dataset, use_quantiles=True) # default dataset = model.export_posterior(dataset, use_quantiles=True, sample_kwargs={"batch_size": 10}) + return dataset def export_posterior_sc(model, dataset): @@ -29,8 +40,12 @@ def export_posterior_sc(model, dataset): dataset = model.export_posterior( dataset, use_quantiles=True, add_to_varm=["q50"], sample_kwargs={"batch_size": 10} ) # quantile 0.50 + dataset = model.export_posterior( + dataset, use_quantiles=True, add_to_varm=["q50"], sample_kwargs={"batch_size": 10, "use_median": True} + ) # quantile 0.50 dataset = model.export_posterior(dataset, use_quantiles=True) # default dataset = model.export_posterior(dataset, use_quantiles=True, sample_kwargs={"batch_size": 10}) + return dataset def test_cell2location(): @@ -79,9 +94,68 @@ def test_cell2location(): # export the estimated cell abundance (summary of the posterior distribution) # full data dataset = st_model.export_posterior(dataset, sample_kwargs={"num_samples": 10, "batch_size": st_model.adata.n_obs}) + assert "data_target" not in dataset.uns["mod"]["post_sample_means"].keys() + assert "u_sf_mRNA_factors" in dataset.uns["mod"]["post_sample_means"].keys() + assert dataset.uns["mod"]["post_sample_means"]["w_sf"].shape == (dataset.n_obs, dataset.obs["labels"].nunique()) # test quantile export - export_posterior(st_model, dataset) + dataset = export_posterior(st_model, dataset) + dataset = st_model.export_posterior( + dataset, + use_quantiles=True, + add_to_obsm=["q50", "q05", "q001"], + ) + assert "data_target" not in dataset.uns["mod"]["post_sample_q50"].keys() + assert "u_sf_mRNA_factors" in dataset.uns["mod"]["post_sample_q50"].keys() + assert "u_sf_mRNA_factors" in dataset.uns["mod"]["post_sample_q001"].keys() + assert dataset.uns["mod"]["post_sample_q50"]["w_sf"].shape == (dataset.n_obs, dataset.obs["labels"].nunique()) st_model.plot_QC(summary_name="q05") + # test correct indexing + dataset = st_model.export_posterior( + dataset, + sample_kwargs={ + "num_samples": 10, # "batch_size": st_model.adata.n_obs, + "return_observed": True, + }, + ) + assert np.allclose(dataset.X.astype("float32"), dataset.uns["mod"]["post_sample_means"]["data_target"]) + dataset = st_model.export_posterior( + dataset, + use_quantiles=True, + add_to_obsm=["q50", "q05", "q001"], + sample_kwargs={ + # "batch_size": st_model.adata.n_obs, + "return_observed": True, + }, + ) + # u_sf_mRNA_factors_full = dataset.uns["mod"]["post_sample_q50"]["u_sf_mRNA_factors"] + # u_sf_mRNA_factors_full_q05 = dataset.uns["mod"]["post_sample_q05"]["u_sf_mRNA_factors"] + assert np.allclose(dataset.X.astype("float32"), dataset.uns["mod"]["post_sample_q50"]["data_target"]) + assert np.allclose(dataset.X.astype("float32"), dataset.uns["mod"]["post_sample_q05"]["data_target"]) + dataset = st_model.export_posterior( + dataset, + sample_kwargs={ + "num_samples": 10, + "batch_size": 50, + "return_observed": True, + }, + ) + assert np.allclose(dataset.X.astype("float32"), dataset.uns["mod"]["post_sample_means"]["data_target"]) + dataset = st_model.export_posterior( + dataset, + use_quantiles=True, + add_to_obsm=["q50", "q05", "q001"], + sample_kwargs={ + "batch_size": 50, + "return_observed": True, + }, + ) + # u_sf_mRNA_factors_batch = 
dataset.uns["mod"]["post_sample_q50"]["u_sf_mRNA_factors"] + # u_sf_mRNA_factors_batch_q05 = dataset.uns["mod"]["post_sample_q05"]["u_sf_mRNA_factors"] + assert np.allclose(dataset.X.astype("float32"), dataset.uns["mod"]["post_sample_q50"]["data_target"]) + assert np.allclose(dataset.X.astype("float32"), dataset.uns["mod"]["post_sample_q05"]["data_target"]) + # TODO uncomment the test after fixing "batch_size": st_model.adata.n_obs bug + # assert np.allclose(u_sf_mRNA_factors_batch, u_sf_mRNA_factors_full) + # assert np.allclose(u_sf_mRNA_factors_batch_q05, u_sf_mRNA_factors_full_q05) ## minibatches of locations ## Cell2location.setup_anndata(dataset, batch_key="batch") st_model = Cell2location(dataset, cell_state_df=inf_aver, N_cells_per_location=30, detection_alpha=200) @@ -166,7 +240,6 @@ def test_cell2location(): batch_size=20, plan_kwargs={"n_aggressive_epochs": 1, "n_aggressive_steps": 5}, accelerator=accelerator, - use_gpu=use_gpu, ) # test hiding variables on the list var_list = ["locs.s_g_gene_add_alpha_e_inv"] @@ -190,7 +263,6 @@ def test_cell2location(): batch_size=20, plan_kwargs={"n_aggressive_epochs": 1, "n_aggressive_steps": 5}, accelerator=accelerator, - use_gpu=use_gpu, ) for k, v in st_model.module.guide.named_parameters(): k_in_vars = np.any([i in k for i in var_list]) @@ -319,3 +391,260 @@ def test_cell2location(): sample_key="batch", ) melt_signal_target_data_frame(weighted_avg_dict, distance_bins) + + +@pytest.mark.parametrize("sliding_window_size", [0, 4]) +@pytest.mark.parametrize("use_aggregated_w_sf", [False, True]) +@pytest.mark.parametrize("amortised", [False, True]) +@pytest.mark.parametrize("amortised_sliding_window_size", [0, 4]) +@pytest.mark.parametrize("n_tiles", [1, 2]) +@pytest.mark.parametrize("sliding_window_size_list", [None, [0, 4, 8]]) +@pytest.mark.parametrize("use_weigted_cnn_weights", [False]) +def test_cell2location_with_aggregation( + sliding_window_size, + use_aggregated_w_sf, + amortised, + amortised_sliding_window_size, + n_tiles, + sliding_window_size_list, + use_weigted_cnn_weights, +): + save_path = "./cell2location_model_test" + if torch.cuda.is_available(): + accelerator = "gpu" + else: + accelerator = "cpu" + data_size = 200 + dataset = synthetic_iid(batch_size=data_size * n_tiles, n_labels=5) + dataset.obsm["X_spatial"] = np.random.normal(0, 1, [dataset.n_obs, 2]) + RegressionModel.setup_anndata(dataset, labels_key="labels", batch_key="batch") + + # train regression model to get signatures of cell types + sc_model = RegressionModel(dataset) + # test minibatch training + sc_model.train(max_epochs=1, batch_size=100, accelerator=accelerator) + # export the estimated cell abundance (summary of the posterior distribution) + dataset = sc_model.export_posterior(dataset, sample_kwargs={"num_samples": 10}) + # test quantile export + export_posterior_sc(sc_model, dataset) + sc_model.plot_QC(summary_name="q05") + # export estimated expression in each cluster + if "means_per_cluster_mu_fg" in dataset.varm.keys(): + inf_aver = dataset.varm["means_per_cluster_mu_fg"][ + [f"means_per_cluster_mu_fg_{i}" for i in dataset.uns["mod"]["factor_names"]] + ].copy() + else: + inf_aver = dataset.var[[f"means_per_cluster_mu_fg_{i}" for i in dataset.uns["mod"]["factor_names"]]].copy() + inf_aver.columns = dataset.uns["mod"]["factor_names"] + ### test cell2location model with convolutions ### + use_tiles = (sliding_window_size > 0) or (amortised_sliding_window_size > 0) + tiles = [] + for i in range(n_tiles): + tiles = tiles + [f"tile{i}" for _ in 
range(data_size * 2)] + dataset.obs["tiles"] = tiles + Cell2location.setup_anndata( + dataset, + batch_key="batch", + position_key=None, + tiles_key="tiles" if use_tiles else None, + ) + ## full data ## + st_model = Cell2location( + dataset, + cell_state_df=inf_aver, + N_cells_per_location=30, + detection_alpha=200, + average_distance_prior=5.0, + sliding_window_size=sliding_window_size, + amortised_sliding_window_size=amortised_sliding_window_size, + sliding_window_size_list=sliding_window_size_list + if (sliding_window_size > 0) or (amortised_sliding_window_size > 0) + else None, + image_size=[20, 20], + use_aggregated_w_sf=use_aggregated_w_sf, + use_weigted_cnn_weights=use_weigted_cnn_weights, + use_independent_prior_on_w_sf=True, + amortised=amortised, + encoder_mode="multiple", + encoder_kwargs={ + "dropout_rate": 0.1, + "n_hidden": { + "multiple": 256, + "single": 256, + "n_s_cells_per_location": 10, + "b_s_groups_per_location": 10, + "a_s_factors_per_location": 10, + "z_sr_groups_factors": 64, + "w_sf": 256, + "prior_w_sf": 256, + "detection_y_s": 10, + }, + "use_batch_norm": False, + "use_layer_norm": True, + "n_layers": 1, + "activation_fn": torch.nn.ELU, + }, + ) + shuffle = False if (sliding_window_size > 0) or (amortised_sliding_window_size > 0) else True + batch_size = n_tiles if (sliding_window_size > 0) or (amortised_sliding_window_size > 0) else None + # test full data training + st_model.train( + max_epochs=1, + accelerator=accelerator, + shuffle_set_split=shuffle, + batch_size=batch_size, + # datasplitter_kwargs={"shuffle": shuffle, "shuffle_set_split": shuffle}, + ) + if (sliding_window_size > 0) or (amortised_sliding_window_size > 0): + st_model.module.model.n_tiles = 1 + # test save/load + st_model.save(save_path, overwrite=True, save_anndata=True) + st_model = Cell2location.load(save_path) + # export the estimated cell abundance (summary of the posterior distribution) + # full data + if (sliding_window_size > 0) or (amortised_sliding_window_size > 0): + dataset = st_model.export_posterior( + dataset, + sample_kwargs={ + "batch_size": 1, + "use_median": True, + }, + add_to_obsm=["q50"], + use_quantiles=True, + ) + else: + dataset = st_model.export_posterior( + dataset, + sample_kwargs={ + "num_samples": 10, + "batch_size": 100, + }, + ) + + +@pytest.mark.parametrize("use_cell_comm_prior_on_w_sf", [False, True]) +@pytest.mark.parametrize("use_cell_comm_likelihood_w_sf", [False, True]) +@pytest.mark.parametrize("amortised", [False, True]) +def test_cell2location_with_aggregation_cell_comm( + use_cell_comm_prior_on_w_sf, + use_cell_comm_likelihood_w_sf, + amortised, +): + if (use_cell_comm_prior_on_w_sf and use_cell_comm_likelihood_w_sf) or ( + not use_cell_comm_prior_on_w_sf and not use_cell_comm_likelihood_w_sf + ): + return None + save_path = "./cell2location_model_test" + if torch.cuda.is_available(): + accelerator = "gpu" + else: + accelerator = "cpu" + data_size = 200 + dataset = synthetic_iid(batch_size=data_size, n_labels=5) + dataset.obsm["X_spatial"] = np.random.normal(0, 1, [dataset.n_obs, 2]) + RegressionModel.setup_anndata(dataset, labels_key="labels", batch_key="batch") + + # train regression model to get signatures of cell types + sc_model = RegressionModel(dataset) + # test minibatch training + sc_model.train(max_epochs=1, batch_size=100, accelerator=accelerator) + # export the estimated cell abundance (summary of the posterior distribution) + dataset = sc_model.export_posterior(dataset, sample_kwargs={"num_samples": 10}) + # test quantile export + 
export_posterior_sc(sc_model, dataset) + sc_model.plot_QC(summary_name="q05") + # export estimated expression in each cluster + if "means_per_cluster_mu_fg" in dataset.varm.keys(): + inf_aver = dataset.varm["means_per_cluster_mu_fg"][ + [f"means_per_cluster_mu_fg_{i}" for i in dataset.uns["mod"]["factor_names"]] + ].copy() + else: + inf_aver = dataset.var[[f"means_per_cluster_mu_fg_{i}" for i in dataset.uns["mod"]["factor_names"]]].copy() + inf_aver.columns = dataset.uns["mod"]["factor_names"] + ### test cell2location model with cell comm terms ### + use_distance_fun = use_cell_comm_prior_on_w_sf | use_cell_comm_likelihood_w_sf + Cell2location.setup_anndata( + dataset, + batch_key="batch", + position_key=None if not use_distance_fun else "X_spatial", + ) + + signal_bool = np.random.choice([True, False], dataset.n_vars) + receptor_bool = np.random.choice([True, False], dataset.n_vars) + signal_receptor_mask = np.random.choice([True, False], [signal_bool.sum(), receptor_bool.sum()]) + receptor_tf_mask = None + distances = np.random.uniform(0, 100, [dataset.n_obs, dataset.n_obs]) * np.random.choice( + [True, False], [dataset.n_obs, dataset.n_obs] + ) + from scipy.sparse import coo_matrix + + distances = coo_matrix(distances) + + ## full data ## + st_model = Cell2location( + dataset, + cell_state_df=inf_aver, + N_cells_per_location=30, + detection_alpha=200, + signal_bool=signal_bool, + receptor_bool=receptor_bool, + signal_receptor_mask=signal_receptor_mask, + receptor_tf_mask=receptor_tf_mask, + distances=distances, + average_distance_prior=5.0, + use_cell_comm_prior_on_w_sf=use_cell_comm_prior_on_w_sf, + use_cell_comm_likelihood_w_sf=use_cell_comm_likelihood_w_sf, + use_independent_prior_on_w_sf=True, + amortised=amortised, + encoder_mode="multiple", + encoder_kwargs={ + "dropout_rate": 0.1, + "n_hidden": { + "multiple": 256, + "single": 256, + "n_s_cells_per_location": 10, + "b_s_groups_per_location": 10, + "a_s_factors_per_location": 10, + "z_sr_groups_factors": 64, + "w_sf": 256, + "prior_w_sf": 256, + "detection_y_s": 10, + }, + "use_batch_norm": False, + "use_layer_norm": True, + "n_layers": 1, + "activation_fn": torch.nn.ELU, + }, + ) + shuffle = True + batch_size = None + # test full data training + st_model.train( + max_epochs=1, + accelerator=accelerator, + shuffle_set_split=shuffle, + batch_size=batch_size, + # datasplitter_kwargs={"shuffle": shuffle, "shuffle_set_split": shuffle}, + ) + # test save/load + st_model.save(save_path, overwrite=True, save_anndata=True) + st_model = Cell2location.load(save_path) + # export the estimated cell abundance (summary of the posterior distribution) + # full data + batch_size = dataset.n_obs + dataset = st_model.export_posterior( + dataset, + sample_kwargs={ + "batch_size": batch_size, + "use_median": True, + }, + add_to_obsm=["q50"], + use_quantiles=True, + ) + dataset = st_model.export_posterior( + dataset, + sample_kwargs={ + "num_samples": 10, + "batch_size": batch_size, + }, + )
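For real data, the sparse `distances` argument exercised in the cell-communication test above could plausibly be built from spatial coordinates with the `make_spatial_neighbours` helper added in this diff. A minimal sketch under that assumption (the synthetic data and the `n_neighbors` value are illustrative only):

import numpy as np
from scipy.sparse import coo_matrix
from scvi.data import synthetic_iid

from cell2location.utils import make_spatial_neighbours

# Toy dataset with random 2D coordinates, mirroring the tests above.
adata_vis = synthetic_iid(batch_size=200, n_labels=5)
adata_vis.obsm["spatial"] = np.random.normal(0, 1, [adata_vis.n_obs, 2])

# Per-sample spatial KNN graph; distances/connectivities are stored in .obsp and
# zeroed between locations from different batches.
adata_vis = make_spatial_neighbours(adata_vis, batch_key="batch", spatial_key="spatial", n_neighbors=30)

# Sparse distances in the format passed to Cell2location(..., distances=...) above.
distances = coo_matrix(adata_vis.obsp["distances"])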