diff --git a/src/probabilistic_dag_model/probabilistic_dag_autoencoder.py b/src/probabilistic_dag_model/probabilistic_dag_autoencoder.py
index 6de4fbf..fcb63ba 100644
--- a/src/probabilistic_dag_model/probabilistic_dag_autoencoder.py
+++ b/src/probabilistic_dag_model/probabilistic_dag_autoencoder.py
@@ -93,7 +93,10 @@ def ELBO_loss(self, X_pred, X):
 
         if self.regr > 0:
             kl_loss = torch.nn.KLDivLoss(reduction='mean')
-            regularizer = kl_loss(self.probabilistic_dag.edge_log_params, self.prior_p * torch.ones_like(self.probabilistic_dag.edge_log_params))
+            logsigmoid = torch.nn.LogSigmoid()
+            ones = torch.ones_like(self.probabilistic_dag.edge_log_params)
+            regularizer = kl_loss(logsigmoid(self.probabilistic_dag.edge_log_params), self.prior_p * ones)  # KLDivLoss expects log-probabilities as its input (first) argument
+            regularizer += kl_loss(logsigmoid(-self.probabilistic_dag.edge_log_params), (1 - self.prior_p) * ones)  # log(1 - sigmoid(x)) == logsigmoid(-x); avoids unstable torch.log(1 - sigmoid(x))
         ELBO_loss = ELBO_loss + self.regr * regularizer
         return ELBO_loss
 