From 0ffcec77c8130dd693342ef4aaeddda4978cb29a Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Mon, 11 Jul 2022 15:42:16 +0100 Subject: [PATCH 1/2] option to train the model with residual factors --- cell2location/models/_cell2location_module.py | 84 +++++++++++++++++-- tests/test_cell2location.py | 4 + 2 files changed, 80 insertions(+), 8 deletions(-) diff --git a/cell2location/models/_cell2location_module.py b/cell2location/models/_cell2location_module.py index 5c22da79..f4b78a79 100755 --- a/cell2location/models/_cell2location_module.py +++ b/cell2location/models/_cell2location_module.py @@ -71,6 +71,7 @@ class LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelGen # training mode without observed data (just using priors) training_wo_observed = False training_wo_initial = False + use_residual_factors = False def __init__( self, @@ -80,6 +81,7 @@ def __init__( n_batch, cell_state_mat, n_groups: int = 50, + n_residual_factors: int = 10, detection_mean=1 / 2, detection_alpha=20.0, m_g_gene_level_prior={"mean": 1, "mean_var_ratio": 1.0, "alpha_mean": 3.0}, @@ -87,6 +89,7 @@ def __init__( A_factors_per_location=7.0, B_groups_per_location=7.0, N_cells_mean_var_ratio=1.0, + A_residual_factors_per_location=3.0, alpha_g_phi_hyp_prior={"alpha": 9.0, "beta": 3.0}, gene_add_alpha_hyp_prior={"alpha": 9.0, "beta": 3.0}, gene_add_mean_hyp_prior={ @@ -94,6 +97,11 @@ def __init__( "beta": 100.0, }, detection_hyp_prior={"mean_alpha": 10.0}, + factor_prior={ + "rate": 1.0, + "alpha": 1.0, + "states_per_gene": 10.0, + }, w_sf_mean_var_ratio=5.0, init_vals: Optional[dict] = None, init_alpha=20.0, @@ -107,6 +115,7 @@ def __init__( self.n_factors = n_factors self.n_batch = n_batch self.n_groups = n_groups + self.n_residual_factors = n_residual_factors self.m_g_gene_level_prior = m_g_gene_level_prior @@ -117,6 +126,7 @@ def __init__( detection_hyp_prior["mean"] = detection_mean detection_hyp_prior["alpha"] = detection_alpha self.detection_hyp_prior = detection_hyp_prior + self.factor_prior = factor_prior self.dropout_p = dropout_p if self.dropout_p is not None: @@ -159,8 +169,22 @@ def __init__( self.register_buffer("N_cells_per_location", torch.tensor(N_cells_per_location)) self.register_buffer("factors_per_groups", torch.tensor(factors_per_groups)) self.register_buffer("B_groups_per_location", torch.tensor(B_groups_per_location)) + self.register_buffer("A_residual_factors_per_location", torch.tensor(A_residual_factors_per_location)) self.register_buffer("N_cells_mean_var_ratio", torch.tensor(N_cells_mean_var_ratio)) + self.register_buffer( + "factor_states_per_gene", + torch.tensor(self.factor_prior["states_per_gene"]), + ) + self.register_buffer( + "factor_prior_alpha", + torch.tensor(self.factor_prior["alpha"]), + ) + self.register_buffer( + "factor_prior_beta", + torch.tensor(self.factor_prior["alpha"] / self.factor_prior["rate"]), + ) + self.register_buffer( "alpha_g_phi_hyp_prior_alpha", torch.tensor(self.alpha_g_phi_hyp_prior["alpha"]), @@ -190,9 +214,11 @@ def __init__( self.register_buffer("n_factors_tensor", torch.tensor(self.n_factors)) self.register_buffer("n_groups_tensor", torch.tensor(self.n_groups)) + self.register_buffer("n_residual_factors_tensor", torch.tensor(self.n_residual_factors)) self.register_buffer("ones", torch.ones((1, 1))) self.register_buffer("ones_1_n_groups", torch.ones((1, self.n_groups))) + self.register_buffer("ones_1_n_residual_factors", torch.ones((1, self.n_residual_factors))) self.register_buffer("ones_n_batch_1", torch.ones((self.n_batch, 
1))) self.register_buffer("eps", torch.tensor(1e-8)) @@ -237,6 +263,8 @@ def list_obs_plate_vars(self): "z_sr_groups_factors": self.n_groups, "w_sf": self.n_factors, "detection_y_s": 1, + "b_s_residual_factors_per_location": 1, + "w_sf_residual_factors": self.n_residual_factors, }, } @@ -457,10 +485,47 @@ def forward(self, x_data, idx, batch_index): dist.Exponential(alpha_g_phi_hyp).expand([self.n_batch, self.n_vars]).to_event(2), ) # (self.n_batch, self.n_vars) + # =======Residual cell abundances w_sf and loadings g_fg ======= # + if self.use_residual_factors: + with obs_plate as ind: + k = "b_s_residual_factors_per_location" + b_s_residual_factors_per_location = pyro.sample( + k, + dist.Gamma(self.A_residual_factors_per_location, self.ones), + ) + # location loadings + shape = self.ones_1_n_residual_factors * b_s_residual_factors_per_location / self.n_residual_factors_tensor + rate = self.ones_1_n_residual_factors / (n_s_cells_per_location / b_s_residual_factors_per_location) + with obs_plate as ind: + k = "w_sf_residual_factors" + w_sf_residual_factors = pyro.sample( + k, + dist.Gamma(shape, rate), + ) # (n_obs, n_groups) + # g_{f,g} + residual_factor_level_g = pyro.sample( + "residual_factor_level_g", + dist.Gamma(self.factor_prior_alpha, self.factor_prior_beta).expand([1, self.n_vars]).to_event(2), + obs=getattr(self, "fixed_val_residual_factor_level_g", None), + ) + g_fg_residual_factors = pyro.sample( + "g_fg_residual_factors", + dist.Gamma( + self.factor_states_per_gene / self.n_residual_factors_tensor, + self.ones / residual_factor_level_g, + ) + .expand([self.n_residual_factors, self.n_vars]) + .to_event(2), + obs=getattr(self, "fixed_val_g_fg_residual_factors", None), + ) + # =====================Expected expression ======================= # if not self.training_wo_observed: # expected expression - mu = ((w_sf @ self.cell_state) * m_g + (obs2sample @ s_g_gene_add)) * detection_y_s + mu_biol = w_sf @ self.cell_state + if self.use_residual_factors: + mu_biol = mu_biol + w_sf_residual_factors @ g_fg_residual_factors + mu = (mu_biol * m_g + (obs2sample @ s_g_gene_add)) * detection_y_s alpha = obs2sample @ (self.ones / alpha_g_inverse.pow(2)) # convert mean and overdispersion to total count and logits # total_count, logits = _convert_mean_disp_to_counts_logits( @@ -494,10 +559,12 @@ def compute_expected(self, samples, adata_manager, ind_x=None): ind_x = ind_x.astype(int) obs2sample = adata_manager.get_from_registry(REGISTRY_KEYS.BATCH_KEY) obs2sample = pd.get_dummies(obs2sample.flatten()).values[ind_x, :] - mu = ( - np.dot(samples["w_sf"][ind_x, :], self.cell_state_mat.T) * samples["m_g"] - + np.dot(obs2sample, samples["s_g_gene_add"]) - ) * samples["detection_y_s"][ind_x, :] + mu_biol = np.dot(samples["w_sf"][ind_x, :], self.cell_state_mat.T) + if self.use_residual_factors: + mu_biol = mu_biol + np.dot(samples["w_sf_residual_factors"][ind_x, :], samples["g_fg_residual_factors"]) + mu = (mu_biol * samples["m_g"] + np.dot(obs2sample, samples["s_g_gene_add"])) * samples["detection_y_s"][ + ind_x, : + ] alpha = np.dot(obs2sample, 1 / np.power(samples["alpha_g_inverse"], 2)) return {"mu": mu, "alpha": alpha, "ind_x": ind_x} @@ -536,9 +603,10 @@ def compute_expected_per_cell_type(self, samples, adata_manager, ind_x=None): # compute total expected expression obs2sample = adata_manager.get_from_registry(REGISTRY_KEYS.BATCH_KEY) obs2sample = pd.get_dummies(obs2sample.flatten()).values[ind_x, :] - mu = np.dot(samples["w_sf"][ind_x, :], self.cell_state_mat.T) * samples["m_g"] + np.dot( - 
obs2sample, samples["s_g_gene_add"] - ) + mu_biol = np.dot(samples["w_sf"][ind_x, :], self.cell_state_mat.T) + if self.use_residual_factors: + mu_biol = mu_biol + np.dot(samples["w_sf_residual_factors"][ind_x, :], samples["g_fg_residual_factors"]) + mu = mu_biol * samples["m_g"] + np.dot(obs2sample, samples["s_g_gene_add"]) # compute conditional expected expression per cell type mu_ct = [ diff --git a/tests/test_cell2location.py b/tests/test_cell2location.py index 0c53476f..b8b96b55 100644 --- a/tests/test_cell2location.py +++ b/tests/test_cell2location.py @@ -51,6 +51,10 @@ def test_cell2location(): st_model = Cell2location(dataset, cell_state_df=inf_aver, N_cells_per_location=30, detection_alpha=200) # test full data training st_model.train(max_epochs=1, use_gpu=use_gpu) + # test full data training with residual factors + st_model.module.model.use_residual_factors = True + st_model.train(max_epochs=1, use_gpu=use_gpu) + st_model.module.model.use_residual_factors = False # export the estimated cell abundance (summary of the posterior distribution) # full data dataset = st_model.export_posterior(dataset, sample_kwargs={"num_samples": 10, "batch_size": st_model.adata.n_obs}) From 5de57167fecf5a0a4d04c41cc8bac23804db0094 Mon Sep 17 00:00:00 2001 From: Vitalii Kleshchevnikov Date: Thu, 14 Jul 2022 14:28:27 +0100 Subject: [PATCH 2/2] train overdispersion again when using residual factors --- cell2location/models/_cell2location_module.py | 27 +++++++++++++------ 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/cell2location/models/_cell2location_module.py b/cell2location/models/_cell2location_module.py index f4b78a79..caad1921 100755 --- a/cell2location/models/_cell2location_module.py +++ b/cell2location/models/_cell2location_module.py @@ -476,14 +476,15 @@ def forward(self, x_data, idx, batch_index): ) # (self.n_batch, n_vars) # =====================Gene-specific overdispersion ======================= # - alpha_g_phi_hyp = pyro.sample( - "alpha_g_phi_hyp", - dist.Gamma(self.ones * self.alpha_g_phi_hyp_prior_alpha, self.ones * self.alpha_g_phi_hyp_prior_beta), - ) - alpha_g_inverse = pyro.sample( - "alpha_g_inverse", - dist.Exponential(alpha_g_phi_hyp).expand([self.n_batch, self.n_vars]).to_event(2), - ) # (self.n_batch, self.n_vars) + if not self.use_residual_factors: + alpha_g_phi_hyp = pyro.sample( + "alpha_g_phi_hyp", + dist.Gamma(self.ones * self.alpha_g_phi_hyp_prior_alpha, self.ones * self.alpha_g_phi_hyp_prior_beta), + ) + alpha_g_inverse = pyro.sample( + "alpha_g_inverse", + dist.Exponential(alpha_g_phi_hyp).expand([self.n_batch, self.n_vars]).to_event(2), + ) # (self.n_batch, self.n_vars) # =======Residual cell abundances w_sf and loadings g_fg ======= # if self.use_residual_factors: @@ -519,6 +520,16 @@ def forward(self, x_data, idx, batch_index): obs=getattr(self, "fixed_val_g_fg_residual_factors", None), ) + # Gene-specific overdispersion + alpha_g_phi_hyp = pyro.sample( + "alpha_residual_g_phi_hyp", + dist.Gamma(self.ones * self.alpha_g_phi_hyp_prior_alpha, self.ones * self.alpha_g_phi_hyp_prior_beta), + ) + alpha_g_inverse = pyro.sample( + "alpha_residual_g_inverse", + dist.Exponential(alpha_g_phi_hyp).expand([self.n_batch, self.n_vars]).to_event(2), + ) # (self.n_batch, self.n_vars) + # =====================Expected expression ======================= # if not self.training_wo_observed: # expected expression
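
Usage sketch (for reference only): a minimal example of enabling the new option, mirroring the test added in tests/test_cell2location.py. It assumes `dataset` (an AnnData object) and `inf_aver` (the per-cell-type signature DataFrame) are prepared as in the existing test fixture, and uses `use_gpu=False` purely for illustration; the flag is toggled on the wrapped Pyro model exactly as the test does, since `use_residual_factors` is exposed as a class attribute that defaults to False.

from cell2location.models import Cell2location

# Assumes `dataset` (AnnData) and `inf_aver` (cell-state signatures) are
# prepared as in the existing test setup of tests/test_cell2location.py.
st_model = Cell2location(
    dataset,
    cell_state_df=inf_aver,
    N_cells_per_location=30,
    detection_alpha=200,
)

# Enable the residual factors added by this patch; the attribute defaults
# to False, so the standard model is used unless it is switched on here.
st_model.module.model.use_residual_factors = True
st_model.train(max_epochs=1, use_gpu=False)

# Switch back to the standard model afterwards if needed.
st_model.module.model.use_residual_factors = False

With the flag enabled, the second patch also re-samples the gene-specific overdispersion under new site names ("alpha_residual_g_phi_hyp", "alpha_residual_g_inverse"), so overdispersion is trained again rather than reused from a run without residual factors.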