From c660906897ed662ee571702f16e2a3140e7749e3 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Tue, 16 Nov 2021 23:47:51 +0000 Subject: [PATCH 1/3] support for optional extra covariates in the spatial model --- cell2location/models/_cell2location_model.py | 5 ++ cell2location/models/_cell2location_module.py | 80 ++++++++++++++++--- tests/test_cell2location.py | 17 ++++ 3 files changed, 93 insertions(+), 9 deletions(-) diff --git a/cell2location/models/_cell2location_model.py b/cell2location/models/_cell2location_model.py index 21c70437..70531112 100755 --- a/cell2location/models/_cell2location_model.py +++ b/cell2location/models/_cell2location_model.py @@ -74,6 +74,11 @@ def __init__( self.cell_state_df_ = cell_state_df self.n_factors_ = cell_state_df.shape[1] self.factor_names_ = cell_state_df.columns.values + # annotations for extra categorical covariates + if "extra_categoricals" in self.adata.uns["_scvi"].keys(): + self.extra_categoricals_ = self.adata.uns["_scvi"]["extra_categoricals"] + self.n_extra_categoricals_ = self.adata.uns["_scvi"]["extra_categoricals"]["n_cats_per_key"] + model_kwargs["n_extra_categoricals"] = self.n_extra_categoricals_ if not detection_mean_per_sample: # compute expected change in sensitivity (m_g in V1 or y_s in V2) diff --git a/cell2location/models/_cell2location_module.py b/cell2location/models/_cell2location_module.py index ab705ee8..68ebc2b4 100755 --- a/cell2location/models/_cell2location_module.py +++ b/cell2location/models/_cell2location_module.py @@ -80,6 +80,7 @@ def __init__( n_factors, n_batch, cell_state_mat, + n_extra_categoricals=None, n_groups: int = 50, detection_mean=1 / 2, detection_alpha=200.0, @@ -95,6 +96,7 @@ def __init__( "beta": 100.0, }, detection_hyp_prior={"mean_alpha": 10.0}, + gene_tech_prior={"mean": 1, "alpha": 200}, w_sf_mean_var_ratio=5.0, init_vals: Optional[dict] = None, init_alpha=3.0, @@ -107,6 +109,7 @@ def __init__( self.n_factors = n_factors self.n_batch = n_batch self.n_groups = n_groups + self.n_extra_categoricals = n_extra_categoricals self.m_g_gene_level_prior = m_g_gene_level_prior @@ -117,6 +120,7 @@ def __init__( detection_hyp_prior["mean"] = detection_mean detection_hyp_prior["alpha"] = detection_alpha self.detection_hyp_prior = detection_hyp_prior + self.gene_tech_prior = gene_tech_prior if (init_vals is not None) & (type(init_vals) is dict): self.np_init_vals = init_vals @@ -127,6 +131,7 @@ def __init__( factors_per_groups = A_factors_per_location / B_groups_per_location + # normalisation priors self.register_buffer( "detection_hyp_prior_alpha", torch.tensor(self.detection_hyp_prior["alpha"]), @@ -139,6 +144,14 @@ def __init__( "detection_mean_hyp_prior_beta", torch.tensor(self.detection_hyp_prior["mean_alpha"] / self.detection_hyp_prior["mean"]), ) + self.register_buffer( + "gene_tech_prior_alpha", + torch.tensor(self.gene_tech_prior["alpha"]), + ) + self.register_buffer( + "gene_tech_prior_beta", + torch.tensor(self.gene_tech_prior["alpha"] / self.gene_tech_prior["mean"]), + ) # compute hyperparameters from mean and sd self.register_buffer("m_g_mu_hyp", torch.tensor(self.m_g_gene_level_prior["mean"])) @@ -193,11 +206,26 @@ def __init__( self.register_buffer("eps", torch.tensor(1e-8)) @staticmethod - def _get_fn_args_from_batch(tensor_dict): + def _get_fn_args_from_batch_no_cat(tensor_dict): + x_data = tensor_dict[_CONSTANTS.X_KEY] + ind_x = tensor_dict["ind_x"].long().squeeze() + batch_index = tensor_dict[_CONSTANTS.BATCH_KEY] + return (x_data, ind_x, batch_index, batch_index), {} + + @staticmethod + def _get_fn_args_from_batch_cat(tensor_dict): x_data = tensor_dict[_CONSTANTS.X_KEY] ind_x = tensor_dict["ind_x"].long().squeeze() batch_index = tensor_dict[_CONSTANTS.BATCH_KEY] - return (x_data, ind_x, batch_index), {} + extra_categoricals = tensor_dict[_CONSTANTS.CAT_COVS_KEY] + return (x_data, ind_x, batch_index, extra_categoricals), {} + + @property + def _get_fn_args_from_batch(self): + if self.n_extra_categoricals is not None: + return self._get_fn_args_from_batch_cat + else: + return self._get_fn_args_from_batch_no_cat def create_plates(self, x_data, idx, batch_index): return pyro.plate("obs_plate", size=self.n_obs, dim=-2, subsample=idx) @@ -229,9 +257,20 @@ def list_obs_plate_vars(self): }, } - def forward(self, x_data, idx, batch_index): + def forward(self, x_data, idx, batch_index, extra_categoricals): obs2sample = one_hot(batch_index, self.n_batch) + if self.n_extra_categoricals is not None: + obs2extra_categoricals = torch.cat( + [ + one_hot( + extra_categoricals[:, i].view((extra_categoricals.shape[0], 1)), + n_cat, + ) + for i, n_cat in enumerate(self.n_extra_categoricals) + ], + dim=1, + ) obs_plate = self.create_plates(x_data, idx, batch_index) @@ -258,6 +297,20 @@ def forward(self, x_data, idx, batch_index): dist.Gamma(m_g_alpha_e, m_g_alpha_e / m_g_mean).expand([1, self.n_vars]).to_event(2), # self.m_g_mu_hyp) ) # (1, n_vars) + # =====================Gene-specific multiplicative component ======================= # + # `y_{t, g}` per gene multiplicative effect that explains the difference + # in sensitivity between genes in each technology or covariate effect + if self.n_extra_categoricals is not None: + detection_tech_gene_tg = pyro.sample( + "detection_tech_gene_tg", + dist.Gamma( + self.ones * self.gene_tech_prior_alpha, + self.ones * self.gene_tech_prior_beta, + ) + .expand([np.sum(self.n_extra_categoricals), self.n_vars]) + .to_event(2), + ) + # =====================Cell abundances w_sf======================= # # factorisation prior on w_sf models similarity in locations # across cell types f and reflects the absolute scale of w_sf @@ -388,19 +441,17 @@ def forward(self, x_data, idx, batch_index): if not self.training_wo_observed: # expected expression mu = ((w_sf @ self.cell_state) * m_g + (obs2sample @ s_g_gene_add)) * detection_y_s + if self.n_extra_categoricals is not None: + # gene-specific normalisation for covatiates + mu = mu * (obs2extra_categoricals @ detection_tech_gene_tg) alpha = obs2sample @ (self.ones / alpha_g_inverse.pow(2)) - # convert mean and overdispersion to total count and logits - # total_count, logits = _convert_mean_disp_to_counts_logits( - # mu, alpha, eps=self.eps - # ) # =====================DATA likelihood ======================= # - # Likelihood (sampling distribution) of data_target & add overdispersion via NegativeBinomial + # Likelihood (sampling distribution) of observed RNA counts with obs_plate: pyro.sample( "data_target", dist.GammaPoisson(concentration=alpha, rate=alpha / mu), - # dist.NegativeBinomial(total_count=total_count, logits=logits), obs=x_data, ) @@ -419,10 +470,21 @@ def compute_expected(self, samples, adata, ind_x=None): ind_x = ind_x.astype(int) obs2sample = get_from_registry(adata, _CONSTANTS.BATCH_KEY) obs2sample = pd.get_dummies(obs2sample.flatten()).values[ind_x, :] + if self.n_extra_categoricals is not None: + extra_categoricals = get_from_registry(adata, _CONSTANTS.CAT_COVS_KEY) + obs2extra_categoricals = np.concatenate( + [ + pd.get_dummies(extra_categoricals.iloc[ind_x, i]) + for i, n_cat in enumerate(self.n_extra_categoricals) + ], + axis=1, + ) mu = ( np.dot(samples["w_sf"][ind_x, :], self.cell_state_mat.T) * samples["m_g"] + np.dot(obs2sample, samples["s_g_gene_add"]) ) * samples["detection_y_s"][ind_x, :] + if self.n_extra_categoricals is not None: + mu = mu * np.dot(obs2extra_categoricals, samples["detection_tech_gene_tg"]) alpha = np.dot(obs2sample, 1 / np.power(samples["alpha_g_inverse"], 2)) return {"mu": mu, "alpha": alpha, "ind_x": ind_x} diff --git a/tests/test_cell2location.py b/tests/test_cell2location.py index 02e4f587..9d33bbd1 100644 --- a/tests/test_cell2location.py +++ b/tests/test_cell2location.py @@ -120,3 +120,20 @@ def test_cell2location(): # export the estimated cell abundance (summary of the posterior distribution) # full data st_model.export_posterior(dataset, sample_kwargs={"num_samples": 10, "batch_size": st_model.adata.n_obs}) + + ## Model with extra categorical covariates ## + dataset_sp = dataset.copy() + Cell2location.setup_anndata( + dataset_sp, labels_key="labels", batch_key="batch", categorical_covariate_keys=["labels"] + ) + st_model = Cell2location( + dataset_sp, + cell_state_df=inf_aver, + N_cells_per_location=30, + detection_alpha=200, + ) + # test full data training + st_model.train(max_epochs=1) + # export the estimated cell abundance (summary of the posterior distribution) + # full data + st_model.export_posterior(dataset_sp, sample_kwargs={"num_samples": 10, "batch_size": st_model.adata.n_obs}) From df1d508763dbd0f7024fcf280a244a3fce86a0b2 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Tue, 16 Nov 2021 23:53:07 +0000 Subject: [PATCH 2/3] updated math description --- cell2location/models/_cell2location_module.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cell2location/models/_cell2location_module.py b/cell2location/models/_cell2location_module.py index 68ebc2b4..7f38ca31 100755 --- a/cell2location/models/_cell2location_module.py +++ b/cell2location/models/_cell2location_module.py @@ -28,13 +28,14 @@ class LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelGen as a linear function of expression signatures of reference cell types :math:`g_{f,g}`: .. math:: - \mu_{s,g} = (m_{g} \left (\sum_{f} {w_{s,f} \: g_{f,g}} \right) + s_{e,g}) y_{s} + \mu_{s,g} = (m_{g} \left (\sum_{f} {w_{s,f} \: g_{f,g}} \right) + s_{e,g}) y_{s} y_{t,g} Here, :math:`w_{s,f}` denotes regression weight of each reference signature :math:`f` at location :math:`s`, which can be interpreted as the expected number of cells at location :math:`s` that express reference signature :math:`f`; :math:`g_{f,g}` denotes the reference signatures of cell types :math:`f` of each gene :math:`g`, `cell_state_df` input ; :math:`m_{g}` denotes a gene-specific scaling parameter which adjusts for global differences in sensitivity between technologies (platform effect); :math:`y_{s}` denotes a location/observation-specific scaling parameter which adjusts for differences in sensitivity between observations and batches; :math:`s_{e,g}` is additive component that account for gene- and location-specific shift, such as due to contaminating or free-floating RNA. + :math:`y_{t,g}` denotes per gene :math:`g` multiplicative detection efficiency normalisation for each covariate :math:`t` To account for the similarity of location patterns across cell types, :math:`w_{s,f}` is modelled using another layer of decomposition (factorization) using :math:`r={1, .., R}` groups of cell types, From b99040c5fa210d465ab13742d4a559562a97cc96 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Tue, 16 Nov 2021 23:58:29 +0000 Subject: [PATCH 3/3] updated create_plates --- cell2location/models/_cell2location_module.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cell2location/models/_cell2location_module.py b/cell2location/models/_cell2location_module.py index 7f38ca31..9ba39261 100755 --- a/cell2location/models/_cell2location_module.py +++ b/cell2location/models/_cell2location_module.py @@ -228,7 +228,7 @@ def _get_fn_args_from_batch(self): else: return self._get_fn_args_from_batch_no_cat - def create_plates(self, x_data, idx, batch_index): + def create_plates(self, x_data, idx, batch_index, extra_categoricals): return pyro.plate("obs_plate", size=self.n_obs, dim=-2, subsample=idx) def list_obs_plate_vars(self): @@ -273,7 +273,7 @@ def forward(self, x_data, idx, batch_index, extra_categoricals): dim=1, ) - obs_plate = self.create_plates(x_data, idx, batch_index) + obs_plate = self.create_plates(x_data, idx, batch_index, extra_categoricals) # =====================Gene expression level scaling m_g======================= # # Explains difference in sensitivity for each gene between single cell and spatial technology