diff --git a/chebai/result/analyse_sem.py b/chebai/result/analyse_sem.py index 19276a66..0682ab4e 100644 --- a/chebai/result/analyse_sem.py +++ b/chebai/result/analyse_sem.py @@ -217,10 +217,16 @@ def set_label_names(self, label_names): self.label_successors = self.label_successors.unsqueeze(0) def __call__(self, preds): + if preds.shape[1] == 0: + # no labels predicted + return preds + # preds shape: (n_samples, n_labels) preds_sum_orig = torch.sum(preds) # step 1: apply implications: for each class, set prediction to max of itself and all successors preds = preds.unsqueeze(1) preds_masked_succ = torch.where(self.label_successors, preds, 0) + # preds_masked_succ shape: (n_samples, n_labels, n_labels) + preds = preds_masked_succ.max(dim=2).values if torch.sum(preds) != preds_sum_orig: print(f"Preds change (step 1): {torch.sum(preds) - preds_sum_orig}") diff --git a/pyproject.toml b/pyproject.toml index e82d062f..a2aa136a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "chebai" -version = "1.0.2" +version = "1.0.3" description = "ChEBai is a deep learning library designed for the integration of deep learning methods with chemical ontologies, particularly ChEBI." authors = [ { name = "MGlauer", email = "martin.glauer@ovgu.de" }