Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 95 additions & 16 deletions cell2location/models/_cell2location_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ class LocationModelLinearDependentWMultiExperimentLocationBackgroundNormLevelGen
# training mode without observed data (just using priors)
training_wo_observed = False
training_wo_initial = False
use_residual_factors = False

def __init__(
self,
Expand All @@ -80,20 +81,27 @@ def __init__(
n_batch,
cell_state_mat,
n_groups: int = 50,
n_residual_factors: int = 10,
detection_mean=1 / 2,
detection_alpha=20.0,
m_g_gene_level_prior={"mean": 1, "mean_var_ratio": 1.0, "alpha_mean": 3.0},
N_cells_per_location=8.0,
A_factors_per_location=7.0,
B_groups_per_location=7.0,
N_cells_mean_var_ratio=1.0,
A_residual_factors_per_location=3.0,
alpha_g_phi_hyp_prior={"alpha": 9.0, "beta": 3.0},
gene_add_alpha_hyp_prior={"alpha": 9.0, "beta": 3.0},
gene_add_mean_hyp_prior={
"alpha": 1.0,
"beta": 100.0,
},
detection_hyp_prior={"mean_alpha": 10.0},
factor_prior={
"rate": 1.0,
"alpha": 1.0,
"states_per_gene": 10.0,
},
w_sf_mean_var_ratio=5.0,
init_vals: Optional[dict] = None,
init_alpha=20.0,
Expand All @@ -107,6 +115,7 @@ def __init__(
self.n_factors = n_factors
self.n_batch = n_batch
self.n_groups = n_groups
self.n_residual_factors = n_residual_factors

self.m_g_gene_level_prior = m_g_gene_level_prior

Expand All @@ -117,6 +126,7 @@ def __init__(
detection_hyp_prior["mean"] = detection_mean
detection_hyp_prior["alpha"] = detection_alpha
self.detection_hyp_prior = detection_hyp_prior
self.factor_prior = factor_prior

self.dropout_p = dropout_p
if self.dropout_p is not None:
Expand Down Expand Up @@ -159,8 +169,22 @@ def __init__(
self.register_buffer("N_cells_per_location", torch.tensor(N_cells_per_location))
self.register_buffer("factors_per_groups", torch.tensor(factors_per_groups))
self.register_buffer("B_groups_per_location", torch.tensor(B_groups_per_location))
self.register_buffer("A_residual_factors_per_location", torch.tensor(A_residual_factors_per_location))
self.register_buffer("N_cells_mean_var_ratio", torch.tensor(N_cells_mean_var_ratio))

self.register_buffer(
"factor_states_per_gene",
torch.tensor(self.factor_prior["states_per_gene"]),
)
self.register_buffer(
"factor_prior_alpha",
torch.tensor(self.factor_prior["alpha"]),
)
self.register_buffer(
"factor_prior_beta",
torch.tensor(self.factor_prior["alpha"] / self.factor_prior["rate"]),
)

self.register_buffer(
"alpha_g_phi_hyp_prior_alpha",
torch.tensor(self.alpha_g_phi_hyp_prior["alpha"]),
Expand Down Expand Up @@ -190,9 +214,11 @@ def __init__(

self.register_buffer("n_factors_tensor", torch.tensor(self.n_factors))
self.register_buffer("n_groups_tensor", torch.tensor(self.n_groups))
self.register_buffer("n_residual_factors_tensor", torch.tensor(self.n_residual_factors))

self.register_buffer("ones", torch.ones((1, 1)))
self.register_buffer("ones_1_n_groups", torch.ones((1, self.n_groups)))
self.register_buffer("ones_1_n_residual_factors", torch.ones((1, self.n_residual_factors)))
self.register_buffer("ones_n_batch_1", torch.ones((self.n_batch, 1)))
self.register_buffer("eps", torch.tensor(1e-8))

Expand Down Expand Up @@ -237,6 +263,8 @@ def list_obs_plate_vars(self):
"z_sr_groups_factors": self.n_groups,
"w_sf": self.n_factors,
"detection_y_s": 1,
"b_s_residual_factors_per_location": 1,
"w_sf_residual_factors": self.n_residual_factors,
},
}

Expand Down Expand Up @@ -448,19 +476,67 @@ def forward(self, x_data, idx, batch_index):
) # (self.n_batch, n_vars)

# =====================Gene-specific overdispersion ======================= #
alpha_g_phi_hyp = pyro.sample(
"alpha_g_phi_hyp",
dist.Gamma(self.ones * self.alpha_g_phi_hyp_prior_alpha, self.ones * self.alpha_g_phi_hyp_prior_beta),
)
alpha_g_inverse = pyro.sample(
"alpha_g_inverse",
dist.Exponential(alpha_g_phi_hyp).expand([self.n_batch, self.n_vars]).to_event(2),
) # (self.n_batch, self.n_vars)
if not self.use_residual_factors:
alpha_g_phi_hyp = pyro.sample(
"alpha_g_phi_hyp",
dist.Gamma(self.ones * self.alpha_g_phi_hyp_prior_alpha, self.ones * self.alpha_g_phi_hyp_prior_beta),
)
alpha_g_inverse = pyro.sample(
"alpha_g_inverse",
dist.Exponential(alpha_g_phi_hyp).expand([self.n_batch, self.n_vars]).to_event(2),
) # (self.n_batch, self.n_vars)

# =======Residual cell abundances w_sf and loadings g_fg ======= #
if self.use_residual_factors:
with obs_plate as ind:
k = "b_s_residual_factors_per_location"
b_s_residual_factors_per_location = pyro.sample(
k,
dist.Gamma(self.A_residual_factors_per_location, self.ones),
)
# location loadings
shape = self.ones_1_n_residual_factors * b_s_residual_factors_per_location / self.n_residual_factors_tensor
rate = self.ones_1_n_residual_factors / (n_s_cells_per_location / b_s_residual_factors_per_location)
with obs_plate as ind:
k = "w_sf_residual_factors"
w_sf_residual_factors = pyro.sample(
k,
dist.Gamma(shape, rate),
) # (n_obs, n_residual_factors)
# gene loadings g_{f,g}: per-gene level (1, n_vars), then per-factor loadings (n_residual_factors, n_vars)
residual_factor_level_g = pyro.sample(
"residual_factor_level_g",
dist.Gamma(self.factor_prior_alpha, self.factor_prior_beta).expand([1, self.n_vars]).to_event(2),
obs=getattr(self, "fixed_val_residual_factor_level_g", None),
)
g_fg_residual_factors = pyro.sample(
"g_fg_residual_factors",
dist.Gamma(
self.factor_states_per_gene / self.n_residual_factors_tensor,
self.ones / residual_factor_level_g,
)
.expand([self.n_residual_factors, self.n_vars])
.to_event(2),
obs=getattr(self, "fixed_val_g_fg_residual_factors", None),
)

# Gene-specific overdispersion
alpha_g_phi_hyp = pyro.sample(
"alpha_residual_g_phi_hyp",
dist.Gamma(self.ones * self.alpha_g_phi_hyp_prior_alpha, self.ones * self.alpha_g_phi_hyp_prior_beta),
)
alpha_g_inverse = pyro.sample(
"alpha_residual_g_inverse",
dist.Exponential(alpha_g_phi_hyp).expand([self.n_batch, self.n_vars]).to_event(2),
) # (self.n_batch, self.n_vars)

# =====================Expected expression ======================= #
if not self.training_wo_observed:
# expected expression
mu = ((w_sf @ self.cell_state) * m_g + (obs2sample @ s_g_gene_add)) * detection_y_s
mu_biol = w_sf @ self.cell_state
if self.use_residual_factors:
mu_biol = mu_biol + w_sf_residual_factors @ g_fg_residual_factors
mu = (mu_biol * m_g + (obs2sample @ s_g_gene_add)) * detection_y_s
alpha = obs2sample @ (self.ones / alpha_g_inverse.pow(2))
# convert mean and overdispersion to total count and logits
# total_count, logits = _convert_mean_disp_to_counts_logits(
Expand Down Expand Up @@ -494,10 +570,12 @@ def compute_expected(self, samples, adata_manager, ind_x=None):
ind_x = ind_x.astype(int)
obs2sample = adata_manager.get_from_registry(REGISTRY_KEYS.BATCH_KEY)
obs2sample = pd.get_dummies(obs2sample.flatten()).values[ind_x, :]
mu = (
np.dot(samples["w_sf"][ind_x, :], self.cell_state_mat.T) * samples["m_g"]
+ np.dot(obs2sample, samples["s_g_gene_add"])
) * samples["detection_y_s"][ind_x, :]
mu_biol = np.dot(samples["w_sf"][ind_x, :], self.cell_state_mat.T)
if self.use_residual_factors:
mu_biol = mu_biol + np.dot(samples["w_sf_residual_factors"][ind_x, :], samples["g_fg_residual_factors"])
mu = (mu_biol * samples["m_g"] + np.dot(obs2sample, samples["s_g_gene_add"])) * samples["detection_y_s"][
ind_x, :
]
alpha = np.dot(obs2sample, 1 / np.power(samples["alpha_g_inverse"], 2))

return {"mu": mu, "alpha": alpha, "ind_x": ind_x}
Expand Down Expand Up @@ -536,9 +614,10 @@ def compute_expected_per_cell_type(self, samples, adata_manager, ind_x=None):
# compute total expected expression
obs2sample = adata_manager.get_from_registry(REGISTRY_KEYS.BATCH_KEY)
obs2sample = pd.get_dummies(obs2sample.flatten()).values[ind_x, :]
mu = np.dot(samples["w_sf"][ind_x, :], self.cell_state_mat.T) * samples["m_g"] + np.dot(
obs2sample, samples["s_g_gene_add"]
)
mu_biol = np.dot(samples["w_sf"][ind_x, :], self.cell_state_mat.T)
if self.use_residual_factors:
mu_biol = mu_biol + np.dot(samples["w_sf_residual_factors"][ind_x, :], samples["g_fg_residual_factors"])
mu = mu_biol * samples["m_g"] + np.dot(obs2sample, samples["s_g_gene_add"])

# compute conditional expected expression per cell type
mu_ct = [
Expand Down
4 changes: 4 additions & 0 deletions tests/test_cell2location.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ def test_cell2location():
st_model = Cell2location(dataset, cell_state_df=inf_aver, N_cells_per_location=30, detection_alpha=200)
# test full data training
st_model.train(max_epochs=1, use_gpu=use_gpu)
# test full data training with residual factors
st_model.module.model.use_residual_factors = True
st_model.train(max_epochs=1, use_gpu=use_gpu)
st_model.module.model.use_residual_factors = False
# export the estimated cell abundance (summary of the posterior distribution)
# full data
dataset = st_model.export_posterior(dataset, sample_kwargs={"num_samples": 10, "batch_size": st_model.adata.n_obs})
Expand Down