6 changes: 3 additions & 3 deletions README.md
@@ -6,6 +6,9 @@ A web application for the ensemble is available at https://chebifier.hastingslab

## Installation

+Note: `chebai-graph` and its dependencies cannot be installed automatically. To use Graph Neural Networks, follow
+the installation instructions in the [chebai-graph repository](https://github.com/ChEB-AI/python-chebai-graph). All other dependencies are installed automatically.
+
You can get the package from PyPI:
```bash
pip install chebifier
@@ -21,9 +24,6 @@ cd python-chebifier
pip install -e .
```

-`chebai-graph` and its dependencies cannot be installed automatically. If you want to use Graph Neural Networks, follow
-the instructions in the [chebai-graph repository](https://github.com/ChEB-AI/python-chebai-graph).
-
## Usage

### Command Line Interface
43 changes: 10 additions & 33 deletions chebifier/ensemble/base_ensemble.py
@@ -3,8 +3,8 @@

import torch
import tqdm
-from chebai.preprocessing.datasets.chebi import ChEBIOver50
-from chebai.result.analyse_sem import PredictionSmoother, get_chebi_graph
+from chebifier.inconsistency_resolution import PredictionSmoother
+from chebifier.utils import load_chebi_graph, get_disjoint_files

from chebifier.check_env import check_package_installed
from chebifier.prediction_models.base_predictor import BasePredictor
@@ -21,32 +21,8 @@ def __init__(
# Deferred Import: To avoid circular import error
from chebifier.model_registry import MODEL_TYPES

-        self.chebi_dataset = ChEBIOver50(chebi_version=chebi_version)
-        self.chebi_dataset._download_required_data()  # download chebi if not already downloaded
-        self.chebi_graph = get_chebi_graph(self.chebi_dataset, None)
-        local_disjoint_files = [
-            os.path.join("data", "disjoint_chebi.csv"),
-            os.path.join("data", "disjoint_additional.csv"),
-        ]
-        self.disjoint_files = []
-        for file in local_disjoint_files:
-            if os.path.isfile(file):
-                self.disjoint_files.append(file)
-            else:
-                print(
-                    f"Disjoint axiom file {file} not found. Loading from huggingface instead..."
-                )
-                from chebifier.hugging_face import download_model_files
-
-                self.disjoint_files.append(
-                    download_model_files(
-                        {
-                            "repo_id": "chebai/chebifier",
-                            "repo_type": "dataset",
-                            "files": {"disjoint_file": os.path.basename(file)},
-                        }
-                    )["disjoint_file"]
-                )
+        self.chebi_graph = load_chebi_graph()
+        self.disjoint_files = get_disjoint_files()

self.models = []
self.positive_prediction_threshold = 0.5
@@ -72,7 +48,7 @@ def __init__(

if resolve_inconsistencies:
self.smoother = PredictionSmoother(
-                self.chebi_dataset,
+                self.chebi_graph,
label_names=None,
disjoint_files=self.disjoint_files,
)
@@ -203,10 +179,11 @@ def predict_smiles_list(
"Warning: No classes have been predicted for the given SMILES list."
)
# save predictions
-            torch.save(ordered_predictions, preds_file)
-            with open(predicted_classes_file, "w") as f:
-                for cls in predicted_classes:
-                    f.write(f"{cls}\n")
+            if load_preds_if_possible:
+                torch.save(ordered_predictions, preds_file)
+                with open(predicted_classes_file, "w") as f:
+                    for cls in predicted_classes:
+                        f.write(f"{cls}\n")
predicted_classes = {cls: i for i, cls in enumerate(predicted_classes)}
else:
print(
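This PR replaces the `chebai`-based setup with two helpers from `chebifier.utils`. A minimal sketch of the new initialization path (the three imported names are the ones introduced here; the surrounding usage is illustrative):

```python
from chebifier.utils import load_chebi_graph, get_disjoint_files
from chebifier.inconsistency_resolution import PredictionSmoother

# Prebuilt ChEBI graph, fetched from the chebai/chebifier Hugging Face dataset
# when no local file is given.
chebi_graph = load_chebi_graph()

# Local CSVs under data/ are used if present, otherwise downloaded.
disjoint_files = get_disjoint_files()

smoother = PredictionSmoother(chebi_graph, label_names=None, disjoint_files=disjoint_files)
```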
124 changes: 124 additions & 0 deletions chebifier/inconsistency_resolution.py
@@ -0,0 +1,124 @@
import csv
import os
import torch
from pathlib import Path


def get_disjoint_groups(disjoint_files):
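    """Parse disjointness axioms from .csv and .owl files into one flat list of
    groups of mutually disjoint ChEBI classes."""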
    if disjoint_files is None:
        disjoint_files = [os.path.join("data", "chebi-disjoints.owl")]
disjoint_pairs, disjoint_groups = [], []
for file in disjoint_files:
if isinstance(file, Path):
file = str(file)
if file.endswith(".csv"):
with open(file, "r") as f:
reader = csv.reader(f)
disjoint_pairs += [line for line in reader]
elif file.endswith(".owl"):
with open(file, "r") as f:
plaintext = f.read()
segments = plaintext.split("<")
left = None
for seg in segments:
if seg.startswith("rdf:Description ") or seg.startswith(
"owl:Class"
):
left = int(seg.split('rdf:about="&obo;CHEBI_')[1].split('"')[0])
elif seg.startswith("owl:disjointWith"):
right = int(
seg.split('rdf:resource="&obo;CHEBI_')[1].split('"')[0]
)
disjoint_pairs.append([left, right])

for seg in plaintext.split("<rdf:Description>"):
if "owl;AllDisjointClasses" in seg:
classes = seg.split('rdf:about="&obo;CHEBI_')[1:]
classes = [int(c.split('"')[0]) for c in classes]
disjoint_groups.append(classes)
else:
raise NotImplementedError(
"Unsupported disjoint file format: " + file.split(".")[-1]
)

disjoint_all = disjoint_pairs + disjoint_groups
    # one disjointness axiom is commented out in the owl file
    # (properly, we would parse the owl file and respect the comment markers, but for this one case, removing the pair afterwards suffices)
if [22729, 51880] in disjoint_all:
disjoint_all.remove([22729, 51880])
# print(f"Found {len(disjoint_all)} disjoint groups")
return disjoint_all


class PredictionSmoother:
"""Removes implication and disjointness violations from predictions"""

def __init__(self, chebi_graph, label_names=None, disjoint_files=None):
self.chebi_graph = chebi_graph
self.set_label_names(label_names)
self.disjoint_groups = get_disjoint_groups(disjoint_files)

def set_label_names(self, label_names):
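        """Precompute a boolean (1, n_labels, n_labels) mask whose entry (i, j)
        is True iff label j is label i itself or one of its subclasses."""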
if label_names is not None:
self.label_names = label_names
chebi_subgraph = self.chebi_graph.subgraph(self.label_names)
self.label_successors = torch.zeros(
(len(self.label_names), len(self.label_names)), dtype=torch.bool
)
for i, label in enumerate(self.label_names):
self.label_successors[i, i] = 1
for p in chebi_subgraph.successors(label):
if p in self.label_names:
self.label_successors[i, self.label_names.index(p)] = 1
self.label_successors = self.label_successors.unsqueeze(0)

def __call__(self, preds):
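        """Smooth a (n_samples, n_labels) score tensor in three passes: lift
        each class to the max over its subclasses (implications), damp all but
        the top class in each disjoint group, then take the min over
        superclasses to restore implications broken by step 2."""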
if preds.shape[1] == 0:
# no labels predicted
return preds
# preds shape: (n_samples, n_labels)
preds_sum_orig = torch.sum(preds)
# step 1: apply implications: for each class, set prediction to max of itself and all successors
preds = preds.unsqueeze(1)
preds_masked_succ = torch.where(self.label_successors, preds, 0)
# preds_masked_succ shape: (n_samples, n_labels, n_labels)

preds = preds_masked_succ.max(dim=2).values
if torch.sum(preds) != preds_sum_orig:
print(f"Preds change (step 1): {torch.sum(preds) - preds_sum_orig}")
preds_sum_orig = torch.sum(preds)
        # step 2: eliminate disjointness violations: within each group of disjoint classes, set all predictions except the maximum to 0.49 (unless already lower)
preds_bounded = torch.min(preds, torch.ones_like(preds) * 0.49)
for disj_group in self.disjoint_groups:
disj_group = [
self.label_names.index(g) for g in disj_group if g in self.label_names
]
if len(disj_group) > 1:
old_preds = preds[:, disj_group]
disj_max = torch.max(preds[:, disj_group], dim=1)
for i, row in enumerate(preds):
for l_ in range(len(preds[i])):
if l_ in disj_group and l_ != disj_group[disj_max.indices[i]]:
preds[i, l_] = preds_bounded[i, l_]
samples_changed = 0
for i, row in enumerate(preds[:, disj_group]):
if any(r != o for r, o in zip(row, old_preds[i])):
samples_changed += 1
if samples_changed != 0:
print(
f"disjointness group {[self.label_names[d] for d in disj_group]} changed {samples_changed} samples"
)
if torch.sum(preds) != preds_sum_orig:
print(f"Preds change (step 2): {torch.sum(preds) - preds_sum_orig}")
preds_sum_orig = torch.sum(preds)
# step 3: disjointness violation removal may have caused new implication inconsistencies -> set each prediction to min of predecessors
preds = preds.unsqueeze(1)
preds_masked_predec = torch.where(
torch.transpose(self.label_successors, 1, 2), preds, 1
)
preds = preds_masked_predec.min(dim=2).values
if torch.sum(preds) != preds_sum_orig:
print(f"Preds change (step 3): {torch.sum(preds) - preds_sum_orig}")
return preds
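To make the three steps concrete, a minimal sketch on a toy hierarchy (class names and scores are made up; only `PredictionSmoother` is from this PR):

```python
import networkx as nx
import torch

from chebifier.inconsistency_resolution import PredictionSmoother

# Toy hierarchy: "A" is the parent of "B"; edges run parent -> child,
# matching build_chebi_graph in chebifier/utils.py.
g = nx.DiGraph()
g.add_edge("A", "B")

smoother = PredictionSmoother(g, label_names=["A", "B"], disjoint_files=[])
preds = torch.tensor([[0.2, 0.9]])  # subclass B predicted, superclass A not
print(smoother(preds))  # tensor([[0.9000, 0.9000]]): step 1 lifts A to B's score
```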
2 changes: 1 addition & 1 deletion chebifier/prediction_models/c3p_predictor.py
@@ -39,7 +39,7 @@ def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list:
chebi_id
] = result.is_match
if result.is_match and self.chebi_graph is not None:
-for parent in list(self.chebi_graph.predecessors(int(chebi_id))):
+for parent in list(self.chebi_graph.predecessors(chebi_id)):
result_reformatted[smiles_list.index(result.input_smiles)][
str(parent)
] = 1
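Context for this one-liner: the graph built in `chebifier/utils.py` keeps ChEBI IDs as strings (see `term_callback` below), so the `int(...)` cast has to go; the same change is made in `chemlog_predictor.py`. Illustrative fragment (the ID is an arbitrary example):

```python
# Nodes are string IDs; predecessors(...) are the superclasses
# (the graph is transitively closed).
parents = list(chebi_graph.predecessors("35476"))
```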
4 changes: 2 additions & 2 deletions chebifier/prediction_models/chemlog_predictor.py
@@ -63,7 +63,7 @@ def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list:
sample_additions = dict()
for cls in sample:
if sample[cls] == 1:
-successors = list(self.chebi_graph.predecessors(int(cls)))
+successors = list(self.chebi_graph.predecessors(cls))
if successors:
for succ in successors:
sample_additions[str(succ)] = 1
@@ -114,7 +114,7 @@ def predict_smiles(self, smiles: str) -> Optional[dict]:
indirect_pos_labels = [
str(pr)
for label in pos_labels
-for pr in self.chebi_graph.predecessors(int(label))
+for pr in self.chebi_graph.predecessors(label)
]
pos_labels = list(set(pos_labels + indirect_pos_labels))
return {
131 changes: 131 additions & 0 deletions chebifier/utils.py
@@ -0,0 +1,131 @@
import os

import networkx as nx
import requests
import fastobo
from chebifier.hugging_face import download_model_files
import pickle


def load_chebi_graph(filename=None):
"""Load ChEBI graph from Hugging Face (if filename is None) or local file"""
if filename is None:
print("Loading ChEBI graph from Hugging Face...")
file = download_model_files(
{
"repo_id": "chebai/chebifier",
"repo_type": "dataset",
"files": {"f": "chebi_graph.pkl"},
}
)["f"]
else:
print(f"Loading ChEBI graph from local {filename}...")
file = filename
    with open(file, "rb") as f:
        return pickle.load(f)


def term_callback(doc):
"""Similar to the chebai function, but reduced to the necessary fields. Also, ChEBI IDs are strings"""
parents = []
name = None
smiles = None
for clause in doc:
if isinstance(clause, fastobo.term.PropertyValueClause):
t = clause.property_value
if str(t.relation) == "http://purl.obolibrary.org/obo/chebi/smiles":
assert smiles is None
smiles = t.value
# in older chebi versions, smiles strings are synonyms
# e.g. synonym: "[F-].[Na+]" RELATED SMILES [ChEBI]
elif isinstance(clause, fastobo.term.SynonymClause):
if "SMILES" in clause.raw_value():
assert smiles is None
smiles = clause.raw_value().split('"')[1]
elif isinstance(clause, fastobo.term.IsAClause):
chebi_id = str(clause.term)
chebi_id = chebi_id[chebi_id.index(":") + 1 :]
parents.append(chebi_id)
elif isinstance(clause, fastobo.term.NameClause):
name = str(clause.name)

        if isinstance(clause, fastobo.term.IsObsoleteClause):
            if clause.obsolete:
                # skip term documents that are marked as obsolete
                return False
chebi_id = str(doc.id)
chebi_id = chebi_id[chebi_id.index(":") + 1 :]
return {
"id": chebi_id,
"parents": parents,
"name": name,
"smiles": smiles,
}
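
# Example (illustrative, not part of the PR): applied to a stanza like
#   [Term]
#   id: CHEBI:33521
#   name: metal atom
#   is_a: CHEBI:33250
# term_callback returns
#   {"id": "33521", "parents": ["33250"], "name": "metal atom", "smiles": None}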


def build_chebi_graph(chebi_version=241):
"""Creates a networkx graph for the ChEBI hierarchy. Usually, you don't want to call this function directly, but rather use the `load_chebi_graph` function."""
chebi_path = os.path.join("data", f"chebi_v{chebi_version}", "chebi.obo")
os.makedirs(os.path.join("data", f"chebi_v{chebi_version}"), exist_ok=True)
if not os.path.exists(chebi_path):
url = f"http://purl.obolibrary.org/obo/chebi/{chebi_version}/chebi.obo"
r = requests.get(url, allow_redirects=True)
open(chebi_path, "wb").write(r.content)
with open(chebi_path, encoding="utf-8") as chebi:
chebi = "\n".join(line for line in chebi if not line.startswith("xref:"))

elements = []
for term_doc in fastobo.loads(chebi):
if (
term_doc
and isinstance(term_doc.id, fastobo.id.PrefixedIdent)
and term_doc.id.prefix == "CHEBI"
):
term_dict = term_callback(term_doc)
if term_dict:
elements.append(term_dict)

g = nx.DiGraph()
for n in elements:
g.add_node(n["id"], **n)

# Only take the edges which connect the existing nodes, to avoid internal creation of obsolete nodes
# https://github.com/ChEB-AI/python-chebai/pull/55#issuecomment-2386654142
g.add_edges_from(
[(p, q["id"]) for q in elements for p in q["parents"] if g.has_node(p)]
)
return nx.transitive_closure_dag(g)


def get_disjoint_files():
"""Gets local disjointness files if they are present in the right location, otherwise downloads them from Hugging Face."""
local_disjoint_files = [
os.path.join("data", "disjoint_chebi.csv"),
os.path.join("data", "disjoint_additional.csv"),
]
disjoint_files = []
for file in local_disjoint_files:
if os.path.isfile(file):
disjoint_files.append(file)
else:
print(
f"Disjoint axiom file {file} not found. Loading from huggingface instead..."
)

disjoint_files.append(
download_model_files(
{
"repo_id": "chebai/chebifier",
"repo_type": "dataset",
"files": {"disjoint_file": os.path.basename(file)},
}
)["disjoint_file"]
)
return disjoint_files


if __name__ == "__main__":
# chebi_graph = build_chebi_graph(chebi_version=241)
# save the graph to a file
# pickle.dump(chebi_graph, open("chebi_graph.pkl", "wb"))
chebi_graph = load_chebi_graph()
print(chebi_graph)
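A quick sketch of the module in use (the ChEBI ID is only an example and assumed to be present):

```python
from chebifier.utils import load_chebi_graph

g = load_chebi_graph()  # fetches chebi_graph.pkl from the chebai/chebifier dataset

# The graph is transitively closed, so predecessors(...) yields all superclasses.
parents = list(g.predecessors("33521"))
print(g.nodes["33521"].get("name"))
```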