diff --git a/cell2location/models/_cell2location_model.py b/cell2location/models/_cell2location_model.py
index 68320d38..07308b78 100755
--- a/cell2location/models/_cell2location_model.py
+++ b/cell2location/models/_cell2location_model.py
@@ -82,6 +82,11 @@ def __init__(
         self.cell_state_df_ = cell_state_df
         self.n_factors_ = cell_state_df.shape[1]
         self.factor_names_ = cell_state_df.columns.values
+        # annotations for extra categorical covariates
+        if "extra_categoricals" in self.adata.uns["_scvi"].keys():
+            self.extra_categoricals_ = self.adata.uns["_scvi"]["extra_categoricals"]
+            self.n_extra_categoricals_ = self.adata.uns["_scvi"]["extra_categoricals"]["n_cats_per_key"]
+            model_kwargs["n_extra_categoricals"] = self.n_extra_categoricals_

         if not detection_mean_per_sample:
             # compute expected change in sensitivity (m_g in V1 or y_s in V2)
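Note on the hunk above: `n_cats_per_key` is the per-covariate category count that `setup_anndata` records in the scvi registry; the module uses it to size the one-hot design matrix. As a minimal sketch, the equivalent computation from a plain `obs` table looks like this (the covariate names here are hypothetical, not from this PR):

```python
import pandas as pd

# Toy obs table with two extra categorical covariates (hypothetical names)
obs = pd.DataFrame(
    {
        "sample_technology": ["visium", "visium", "slide_seq"],
        "assay_batch": ["b1", "b2", "b1"],
    }
)
# Equivalent of the registry's n_cats_per_key: one category count per covariate
n_cats_per_key = [obs[key].nunique() for key in ["sample_technology", "assay_batch"]]
print(n_cats_per_key)  # [2, 2]
```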
diff --git a/cell2location/models/_cell2location_module.py b/cell2location/models/_cell2location_module.py
old mode 100755
new mode 100644
index 5c22da79..d772dd12
--- a/cell2location/models/_cell2location_module.py
+++ b/cell2location/models/_cell2location_module.py
@@ -27,13 +27,14 @@ class LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelGen
     as a linear function of expression signatures of reference cell types :math:`g_{f,g}`:

     .. math::
-        \mu_{s,g} = (m_{g} \left (\sum_{f} {w_{s,f} \: g_{f,g}} \right) + s_{e,g}) y_{s}
+        \mu_{s,g} = (m_{g} \left (\sum_{f} {w_{s,f} \: g_{f,g}} \right) + s_{e,g}) y_{s} y_{t,g}

     Here, :math:`w_{s,f}` denotes the regression weight of each reference signature :math:`f` at location :math:`s`, which can be interpreted as the expected number of cells at location :math:`s` that express reference signature :math:`f`;
     :math:`g_{f,g}` denotes the reference signatures of cell types :math:`f` for each gene :math:`g` (the `cell_state_df` input);
     :math:`m_{g}` denotes a gene-specific scaling parameter which adjusts for global differences in sensitivity between technologies (platform effect);
     :math:`y_{s}` denotes a location/observation-specific scaling parameter which adjusts for differences in sensitivity between observations and batches;
     :math:`s_{e,g}` is an additive component that accounts for gene- and location-specific shifts, such as those due to contaminating or free-floating RNA.
+    :math:`y_{t,g}` denotes a multiplicative detection efficiency normalisation for each gene :math:`g` and each extra categorical covariate :math:`t`.

     To account for the similarity of location patterns across cell types, :math:`w_{s,f}` is modelled using
     another layer of decomposition (factorization) using :math:`r={1, .., R}` groups of cell types,
@@ -79,6 +80,7 @@ def __init__(
         n_factors,
         n_batch,
         cell_state_mat,
+        n_extra_categoricals=None,
         n_groups: int = 50,
         detection_mean=1 / 2,
         detection_alpha=20.0,
@@ -94,6 +96,7 @@
             "beta": 100.0,
         },
         detection_hyp_prior={"mean_alpha": 10.0},
+        gene_tech_prior={"mean": 1, "alpha": 200},
         w_sf_mean_var_ratio=5.0,
         init_vals: Optional[dict] = None,
         init_alpha=20.0,
@@ -107,6 +110,7 @@
         self.n_factors = n_factors
         self.n_batch = n_batch
         self.n_groups = n_groups
+        self.n_extra_categoricals = n_extra_categoricals

         self.m_g_gene_level_prior = m_g_gene_level_prior

@@ -117,6 +121,7 @@
         detection_hyp_prior["mean"] = detection_mean
         detection_hyp_prior["alpha"] = detection_alpha
         self.detection_hyp_prior = detection_hyp_prior
+        self.gene_tech_prior = gene_tech_prior

         self.dropout_p = dropout_p
         if self.dropout_p is not None:
@@ -131,6 +136,7 @@

         factors_per_groups = A_factors_per_location / B_groups_per_location

+        # normalisation priors
        self.register_buffer(
            "detection_hyp_prior_alpha",
            torch.tensor(self.detection_hyp_prior["alpha"]),
@@ -143,6 +149,14 @@
            "detection_mean_hyp_prior_beta",
            torch.tensor(self.detection_hyp_prior["mean_alpha"] / self.detection_hyp_prior["mean"]),
        )
+        self.register_buffer(
+            "gene_tech_prior_alpha",
+            torch.tensor(self.gene_tech_prior["alpha"]),
+        )
+        self.register_buffer(
+            "gene_tech_prior_beta",
+            torch.tensor(self.gene_tech_prior["alpha"] / self.gene_tech_prior["mean"]),
+        )

         # compute hyperparameters from mean and sd
         self.register_buffer("m_g_mu_hyp", torch.tensor(self.m_g_gene_level_prior["mean"]))
@@ -197,13 +211,28 @@
         self.register_buffer("eps", torch.tensor(1e-8))

     @staticmethod
-    def _get_fn_args_from_batch(tensor_dict):
-        x_data = tensor_dict[REGISTRY_KEYS.X_KEY]
+    def _get_fn_args_from_batch_no_cat(tensor_dict):
+        x_data = tensor_dict[REGISTRY_KEYS.X_KEY]
         ind_x = tensor_dict["ind_x"].long().squeeze()
-        batch_index = tensor_dict[REGISTRY_KEYS.BATCH_KEY]
-        return (x_data, ind_x, batch_index), {}
+        batch_index = tensor_dict[REGISTRY_KEYS.BATCH_KEY]
+        return (x_data, ind_x, batch_index, batch_index), {}
+
+    @staticmethod
+    def _get_fn_args_from_batch_cat(tensor_dict):
+        x_data = tensor_dict[REGISTRY_KEYS.X_KEY]
+        ind_x = tensor_dict["ind_x"].long().squeeze()
+        batch_index = tensor_dict[REGISTRY_KEYS.BATCH_KEY]
+        extra_categoricals = tensor_dict[REGISTRY_KEYS.CAT_COVS_KEY]
+        return (x_data, ind_x, batch_index, extra_categoricals), {}
+
+    @property
+    def _get_fn_args_from_batch(self):
+        if self.n_extra_categoricals is not None:
+            return self._get_fn_args_from_batch_cat
+        else:
+            return self._get_fn_args_from_batch_no_cat

-    def create_plates(self, x_data, idx, batch_index):
+    def create_plates(self, x_data, idx, batch_index, extra_categoricals):
         return pyro.plate("obs_plate", size=self.n_obs, dim=-2, subsample=idx)

     def list_obs_plate_vars(self):
@@ -240,11 +269,22 @@
             },
         }

-    def forward(self, x_data, idx, batch_index):
+    def forward(self, x_data, idx, batch_index, extra_categoricals):
         obs2sample = one_hot(batch_index, self.n_batch)
+        if self.n_extra_categoricals is not None:
+            obs2extra_categoricals = torch.cat(
+                [
+                    one_hot(
+                        extra_categoricals[:, i].view((extra_categoricals.shape[0], 1)),
+                        n_cat,
+                    )
+                    for i, n_cat in enumerate(self.n_extra_categoricals)
+                ],
+                dim=1,
+            )

-        obs_plate = self.create_plates(x_data, idx, batch_index)
+        obs_plate = self.create_plates(x_data, idx, batch_index, extra_categoricals)

         # =====================Gene expression level scaling m_g======================= #
         # Explains difference in sensitivity for each gene between single cell and spatial technology
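To make the new `obs2extra_categoricals` design matrix concrete, here is a self-contained sketch of the same encode-and-concatenate step, substituting `torch.nn.functional.one_hot` for the scvi `one_hot` helper the module uses (toy shapes, not PR code):

```python
import torch
import torch.nn.functional as F

# Integer category codes per observation: two covariates with 3 and 2 categories
extra_categoricals = torch.tensor([[0, 1], [2, 0], [1, 1]])
n_extra_categoricals = [3, 2]

# One-hot encode each covariate column, then concatenate column-wise,
# giving one block of indicator columns per covariate
obs2extra_categoricals = torch.cat(
    [
        F.one_hot(extra_categoricals[:, i], num_classes=n_cat).float()
        for i, n_cat in enumerate(n_extra_categoricals)
    ],
    dim=1,
)
print(obs2extra_categoricals.shape)  # (n_obs, sum(n_cats_per_key)) == (3, 5)
```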
@@ -269,6 +309,20 @@ def forward(self, x_data, idx, batch_index):
             dist.Gamma(m_g_alpha_e, m_g_alpha_e / m_g_mean).expand([1, self.n_vars]).to_event(2),  # self.m_g_mu_hyp)
         )  # (1, n_vars)

+        # =====================Gene-specific multiplicative component ======================= #
+        # `y_{t,g}` per-gene multiplicative effect that explains differences in
+        # sensitivity between genes across technologies or other categorical covariates
+        if self.n_extra_categoricals is not None:
+            detection_tech_gene_tg = pyro.sample(
+                "detection_tech_gene_tg",
+                dist.Gamma(
+                    self.ones * self.gene_tech_prior_alpha,
+                    self.ones * self.gene_tech_prior_beta,
+                )
+                .expand([np.sum(self.n_extra_categoricals), self.n_vars])
+                .to_event(2),
+            )
+
         # =====================Cell abundances w_sf======================= #
         # factorisation prior on w_sf models similarity in locations
         # across cell types f and reflects the absolute scale of w_sf
@@ -461,21 +515,20 @@ def forward(self, x_data, idx, batch_index):
         if not self.training_wo_observed:
             # expected expression
             mu = ((w_sf @ self.cell_state) * m_g + (obs2sample @ s_g_gene_add)) * detection_y_s
+            if self.n_extra_categoricals is not None:
+                # gene-specific normalisation for covariates
+                mu = mu * (obs2extra_categoricals @ detection_tech_gene_tg)
             alpha = obs2sample @ (self.ones / alpha_g_inverse.pow(2))
-            # convert mean and overdispersion to total count and logits
-            # total_count, logits = _convert_mean_disp_to_counts_logits(
-            #     mu, alpha, eps=self.eps
-            # )

             # =====================DATA likelihood ======================= #
             # Likelihood (sampling distribution) of data_target & add overdispersion via NegativeBinomial
             if self.dropout_p != 0:
                 x_data = self.dropout(x_data)
+
             with obs_plate:
                 pyro.sample(
                     "data_target",
                     dist.GammaPoisson(concentration=alpha, rate=alpha / mu),
-                    # dist.NegativeBinomial(total_count=total_count, logits=logits),
                     obs=x_data,
                 )
@@ -494,10 +547,21 @@ def compute_expected(self, samples, adata_manager, ind_x=None):
         ind_x = ind_x.astype(int)
         obs2sample = adata_manager.get_from_registry(REGISTRY_KEYS.BATCH_KEY)
         obs2sample = pd.get_dummies(obs2sample.flatten()).values[ind_x, :]
+        if self.n_extra_categoricals is not None:
+            extra_categoricals = adata_manager.get_from_registry(REGISTRY_KEYS.CAT_COVS_KEY)
+            obs2extra_categoricals = np.concatenate(
+                [
+                    pd.get_dummies(extra_categoricals.iloc[ind_x, i])
+                    for i, n_cat in enumerate(self.n_extra_categoricals)
+                ],
+                axis=1,
+            )
         mu = (
             np.dot(samples["w_sf"][ind_x, :], self.cell_state_mat.T) * samples["m_g"]
             + np.dot(obs2sample, samples["s_g_gene_add"])
         ) * samples["detection_y_s"][ind_x, :]
+        if self.n_extra_categoricals is not None:
+            mu = mu * np.dot(obs2extra_categoricals, samples["detection_tech_gene_tg"])
         alpha = np.dot(obs2sample, 1 / np.power(samples["alpha_g_inverse"], 2))

         return {"mu": mu, "alpha": alpha, "ind_x": ind_x}
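The new `gene_tech_prior` follows the same mean/alpha convention as the detection priors above: a Gamma with concentration `alpha` and rate `alpha / mean`, so the prior mean equals `mean` and the variance `mean**2 / alpha` tightens as `alpha` grows. A quick numerical check with the defaults:

```python
import torch
import pyro.distributions as dist

gene_tech_prior = {"mean": 1.0, "alpha": 200.0}
d = dist.Gamma(
    torch.tensor(gene_tech_prior["alpha"]),                            # concentration
    torch.tensor(gene_tech_prior["alpha"] / gene_tech_prior["mean"]),  # rate
)
print(d.mean)      # tensor(1.)  -> detection efficiency centred on "no effect"
print(d.variance)  # tensor(0.0050) = mean**2 / alpha -> tight prior around 1
```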
diff --git a/tests/test_cell2location.py b/tests/test_cell2location.py
index 0c53476f..7e6b45c2 100644
--- a/tests/test_cell2location.py
+++ b/tests/test_cell2location.py
@@ -162,3 +162,20 @@ def test_cell2location():
     # export the estimated cell abundance (summary of the posterior distribution)
     # full data
     st_model.export_posterior(dataset, sample_kwargs={"num_samples": 10, "batch_size": st_model.adata.n_obs})
+
+    ## Model with extra categorical covariates ##
+    dataset_sp = dataset.copy()
+    Cell2location.setup_anndata(
+        dataset_sp, labels_key="labels", batch_key="batch", categorical_covariate_keys=["labels"]
+    )
+    st_model = Cell2location(
+        dataset_sp,
+        cell_state_df=inf_aver,
+        N_cells_per_location=30,
+        detection_alpha=200,
+    )
+    # test full data training
+    st_model.train(max_epochs=1)
+    # export the estimated cell abundance (summary of the posterior distribution)
+    # full data
+    st_model.export_posterior(dataset_sp, sample_kwargs={"num_samples": 10, "batch_size": st_model.adata.n_obs})
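A possible manual follow-up to the new test (not part of this PR): inspecting the learned per-covariate detection efficiency after export. This sketch assumes `export_posterior` stores posterior means under `adata.uns["mod"]["post_sample_means"]` keyed by variable name, as it does for the model's other latent variables:

```python
# Hypothetical inspection step -- relies on the uns["mod"] layout assumed above
adata_out = st_model.export_posterior(
    dataset_sp, sample_kwargs={"num_samples": 10, "batch_size": st_model.adata.n_obs}
)
y_tg = adata_out.uns["mod"]["post_sample_means"].get("detection_tech_gene_tg")
if y_tg is not None:
    print(y_tg.shape)  # expected (sum(n_cats_per_key), n_genes)
```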