From a88e1e37a7c646be986ceb55ffb58f05331ada87 Mon Sep 17 00:00:00 2001 From: Jean El Khoury Date: Mon, 11 Jul 2022 16:16:07 +0200 Subject: [PATCH] fix kl div --- src/probabilistic_dag_model/probabilistic_dag_autoencoder.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/probabilistic_dag_model/probabilistic_dag_autoencoder.py b/src/probabilistic_dag_model/probabilistic_dag_autoencoder.py index 6de4fbf..fcb63ba 100644 --- a/src/probabilistic_dag_model/probabilistic_dag_autoencoder.py +++ b/src/probabilistic_dag_model/probabilistic_dag_autoencoder.py @@ -93,7 +93,10 @@ def ELBO_loss(self, X_pred, X): if self.regr > 0: kl_loss = torch.nn.KLDivLoss(reduction='mean') - regularizer = kl_loss(self.probabilistic_dag.edge_log_params, self.prior_p * torch.ones_like(self.probabilistic_dag.edge_log_params)) + logsigmoid = torch.nn.LogSigmoid() + ones = torch.ones_like(self.probabilistic_dag.edge_log_params) + regularizer = kl_loss(logsigmoid(self.probabilistic_dag.edge_log_params), self.prior_p * ones) #KLDivLoss expects log-probabilities as first argument + regularizer += kl_loss(torch.log(ones - torch.sigmoid(self.probabilistic_dag.edge_log_params)), (1 - self.prior_p) * ones) ELBO_loss = ELBO_loss + self.regr * regularizer return ELBO_loss