Skip to content

Commit 416a395

Browse files
committed
adapt aug props for ablation data classes
1 parent 7560fec commit 416a395

File tree

4 files changed

+66
-48
lines changed

4 files changed

+66
-48
lines changed

chebai_graph/preprocessing/properties/augmented_properties.py

Lines changed: 57 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import sys
12
from abc import ABC
23
from typing import Dict, List, Optional
34

@@ -13,6 +14,15 @@
1314
from . import properties as pr
1415
from .base import AtomProperty, BondProperty, FrozenPropertyAlias
1516

17+
# For Python 3.7+, the standard dict type preserves insertion order and is iterated over in that same order
18+
# https://docs.python.org/3/whatsnew/3.7.html#summary-release-highlights
19+
# https://mail.python.org/pipermail/python-dev/2017-December/151283.html
20+
assert sys.version_info >= (
21+
3,
22+
7,
23+
), "This code requires Python 3.7 or higher."
24+
# Order preservation is necessary to create `prop_list`
25+
1626

1727
# --------------------- Atom Properties -----------------------------
1828
class AugmentedAtomProperty(AtomProperty, ABC):
@@ -24,9 +34,7 @@ def get_property_value(self, augmented_mol: Dict):
2434
f"Key `{self.MAIN_KEY}` should be present in augmented molecule dict"
2535
)
2636

27-
missing_keys = {"atom_nodes", "fg_nodes", "graph_node"} - augmented_mol[
28-
self.MAIN_KEY
29-
].keys()
37+
missing_keys = {"atom_nodes"} - augmented_mol[self.MAIN_KEY].keys()
3038
if missing_keys:
3139
raise KeyError(f"Missing keys {missing_keys} in augmented molecule nodes")
3240

@@ -35,26 +43,29 @@ def get_property_value(self, augmented_mol: Dict):
3543
raise TypeError(
3644
f'augmented_mol["{self.MAIN_KEY}"]["atom_nodes"] must be an instance of rdkit.Chem.Mol'
3745
)
38-
3946
prop_list = [self.get_atom_value(atom) for atom in atom_molecule.GetAtoms()]
4047

41-
fg_nodes = augmented_mol[self.MAIN_KEY]["fg_nodes"]
42-
graph_node = augmented_mol[self.MAIN_KEY]["graph_node"]
43-
if not isinstance(fg_nodes, dict) or not isinstance(graph_node, dict):
44-
raise TypeError(
45-
f'augmented_mol["{self.MAIN_KEY}"](["fg_nodes"]/["graph_node"]) must be an instance of dict '
46-
f"containing its properties"
47-
)
48+
if "fg_nodes" in augmented_mol[self.MAIN_KEY]:
49+
fg_nodes = augmented_mol[self.MAIN_KEY]["fg_nodes"]
50+
if not isinstance(fg_nodes, dict):
51+
raise TypeError(
52+
f'augmented_mol["{self.MAIN_KEY}"](["fg_nodes"]) must be an instance of dict '
53+
f"containing its properties"
54+
)
55+
prop_list.extend([self.get_atom_value(atom) for atom in fg_nodes.values()])
56+
57+
if "graph_node" in augmented_mol[self.MAIN_KEY]:
58+
graph_node = augmented_mol[self.MAIN_KEY]["graph_node"]
59+
if not isinstance(graph_node, dict):
60+
raise TypeError(
61+
f'augmented_mol["{self.MAIN_KEY}"](["graph_node"]) must be an instance of dict '
62+
f"containing its properties"
63+
)
64+
prop_list.append(self.get_atom_value(graph_node))
4865

49-
# For python 3.7+, the standard dict type preserves insertion order, and is iterated over in same order
50-
# https://docs.python.org/3/whatsnew/3.7.html#summary-release-highlights
51-
# https://mail.python.org/pipermail/python-dev/2017-December/151283.html
52-
prop_list.extend([self.get_atom_value(atom) for atom in fg_nodes.values()])
53-
prop_list.append(self.get_atom_value(graph_node))
5466
assert (
5567
len(prop_list) == augmented_mol[self.MAIN_KEY]["num_nodes"]
5668
), "Number of property values should be equal to number of nodes"
57-
5869
return prop_list
5970

6071
def _check_modify_atom_prop_value(self, atom: Chem.rdchem.Atom | Dict, prop: str):
@@ -228,7 +239,7 @@ def get_property_value(self, augmented_mol: Dict) -> List:
228239
f"Key `{self.MAIN_KEY}` should be present in augmented molecule dict"
229240
)
230241

231-
missing_keys = k.EDGE_LEVELS - augmented_mol[self.MAIN_KEY].keys()
242+
missing_keys = {k.WITHIN_ATOMS_EDGE} - augmented_mol[self.MAIN_KEY].keys()
232243
if missing_keys:
233244
raise KeyError(f"Missing keys {missing_keys} in augmented molecule nodes")
234245

@@ -237,31 +248,38 @@ def get_property_value(self, augmented_mol: Dict) -> List:
237248
raise TypeError(
238249
f'augmented_mol["{self.MAIN_KEY}"]["{k.WITHIN_ATOMS_EDGE}"] must be an instance of rdkit.Chem.Mol'
239250
)
240-
241251
prop_list = [self.get_bond_value(bond) for bond in atom_molecule.GetBonds()]
242252

243-
fg_atom_edges = augmented_mol[self.MAIN_KEY][k.ATOM_FG_EDGE]
244-
fg_edges = augmented_mol[self.MAIN_KEY][k.WITHIN_FG_EDGE]
245-
fg_graph_node_edges = augmented_mol[self.MAIN_KEY][k.FG_GRAPHNODE_EDGE]
246-
247-
if (
248-
not isinstance(fg_atom_edges, dict)
249-
or not isinstance(fg_edges, dict)
250-
or not isinstance(fg_graph_node_edges, dict)
251-
):
252-
raise TypeError(
253-
f'augmented_mol["{self.MAIN_KEY}"](["{k.ATOM_FG_EDGE}"]/["{k.WITHIN_FG_EDGE}"]/["{k.FG_GRAPHNODE_EDGE}"]) '
254-
f"must be an instance of dict containing its properties"
253+
if k.ATOM_FG_EDGE in augmented_mol[self.MAIN_KEY]:
254+
fg_atom_edges = augmented_mol[self.MAIN_KEY][k.ATOM_FG_EDGE]
255+
if not isinstance(fg_atom_edges, dict):
256+
raise TypeError(
257+
f"augmented_mol['{self.MAIN_KEY}'](['{k.ATOM_FG_EDGE}'])"
258+
f"must be an instance of dict containing its properties"
259+
)
260+
prop_list.extend(
261+
[self.get_bond_value(bond) for bond in fg_atom_edges.values()]
255262
)
256263

257-
# For python 3.7+, the standard dict type preserves insertion order, and is iterated over in same order
258-
# https://docs.python.org/3/whatsnew/3.7.html#summary-release-highlights
259-
# https://mail.python.org/pipermail/python-dev/2017-December/151283.html
260-
prop_list.extend([self.get_bond_value(bond) for bond in fg_atom_edges.values()])
261-
prop_list.extend([self.get_bond_value(bond) for bond in fg_edges.values()])
262-
prop_list.extend(
263-
[self.get_bond_value(bond) for bond in fg_graph_node_edges.values()]
264-
)
264+
if k.WITHIN_FG_EDGE in augmented_mol[self.MAIN_KEY]:
265+
fg_edges = augmented_mol[self.MAIN_KEY][k.WITHIN_FG_EDGE]
266+
if not isinstance(fg_edges, dict):
267+
raise TypeError(
268+
f"augmented_mol['{self.MAIN_KEY}'](['{k.WITHIN_FG_EDGE}'])"
269+
f"must be an instance of dict containing its properties"
270+
)
271+
prop_list.extend([self.get_bond_value(bond) for bond in fg_edges.values()])
272+
273+
if k.TO_GRAPHNODE_EDGE in augmented_mol[self.MAIN_KEY]:
274+
fg_graph_node_edges = augmented_mol[self.MAIN_KEY][k.TO_GRAPHNODE_EDGE]
275+
if not isinstance(fg_graph_node_edges, dict):
276+
raise TypeError(
277+
f"augmented_mol['{self.MAIN_KEY}'](['{k.TO_GRAPHNODE_EDGE}'])"
278+
f"must be an instance of dict containing its properties"
279+
)
280+
prop_list.extend(
281+
[self.get_bond_value(bond) for bond in fg_graph_node_edges.values()]
282+
)
265283

266284
num_directed_edges = augmented_mol[self.MAIN_KEY][k.NUM_EDGES] // 2
267285
assert (

chebai_graph/preprocessing/properties/constants.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,6 @@
88
WITHIN_ATOMS_EDGE = "within_atoms_lvl"
99
WITHIN_FG_EDGE = "within_fg_lvl"
1010
ATOM_FG_EDGE = "atom_fg_lvl"
11-
FG_GRAPHNODE_EDGE = "fg_graphNode_lvl"
12-
EDGE_LEVELS = {WITHIN_ATOMS_EDGE, WITHIN_FG_EDGE, ATOM_FG_EDGE, FG_GRAPHNODE_EDGE}
11+
TO_GRAPHNODE_EDGE = "fg_graphNode_lvl"
12+
EDGE_LEVELS = {WITHIN_ATOMS_EDGE, WITHIN_FG_EDGE, ATOM_FG_EDGE, TO_GRAPHNODE_EDGE}
1313
NUM_EDGES = "num_undirected_edges"

chebai_graph/preprocessing/reader/augmented_reader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,7 @@ def _add_graph_node_and_edges_to_nodes(
598598
self._construct_nodes_to_graph_node_structure(nodes_ids)
599599
)
600600

601-
augmented_struct["edge_info"][k.FG_GRAPHNODE_EDGE] = nodes_to_graph_edges
601+
augmented_struct["edge_info"][k.TO_GRAPHNODE_EDGE] = nodes_to_graph_edges
602602
augmented_struct["edge_info"][k.NUM_EDGES] += len(nodes_to_graph_edges)
603603
assert (
604604
self._idx_of_edge == augmented_struct["edge_info"][k.NUM_EDGES]
@@ -643,7 +643,7 @@ def _construct_nodes_to_graph_node_structure(
643643
graph_edge_index[0].append(self._idx_of_node)
644644
graph_edge_index[1].append(fg_id)
645645
graph_to_nodes_edges[f"{self._idx_of_node}_{fg_id}"] = {
646-
k.EDGE_LEVEL: k.FG_GRAPHNODE_EDGE
646+
k.EDGE_LEVEL: k.TO_GRAPHNODE_EDGE
647647
}
648648
self._idx_of_edge += 1
649649
self._idx_of_node += 1

chebai_graph/preprocessing/utils/visualize_augmented_molecule.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
k.WITHIN_ATOMS_EDGE: "#1f77b4",
2626
k.ATOM_FG_EDGE: "#9467bd",
2727
k.WITHIN_FG_EDGE: "#ff7f0e",
28-
k.FG_GRAPHNODE_EDGE: "#2ca02c",
28+
k.TO_GRAPHNODE_EDGE: "#2ca02c",
2929
}
3030

3131
NODE_COLOR_MAP = {
@@ -119,8 +119,8 @@ def _create_graph(
119119
else set()
120120
)
121121
fg_graph_edges = (
122-
set(augmented_graph_edges[k.FG_GRAPHNODE_EDGE])
123-
if k.FG_GRAPHNODE_EDGE in augmented_graph_edges
122+
set(augmented_graph_edges[k.TO_GRAPHNODE_EDGE])
123+
if k.TO_GRAPHNODE_EDGE in augmented_graph_edges
124124
else set()
125125
)
126126

@@ -133,7 +133,7 @@ def _create_graph(
133133
elif undirected_edge_set & within_fg_edges:
134134
edge_type = k.WITHIN_FG_EDGE
135135
elif undirected_edge_set & fg_graph_edges:
136-
edge_type = k.FG_GRAPHNODE_EDGE
136+
edge_type = k.TO_GRAPHNODE_EDGE
137137
else:
138138
raise ValueError("Unexpected edge type")
139139
G.add_edge(src, tgt, edge_type=edge_type, edge_color=EDGE_COLOR_MAP[edge_type])
@@ -318,7 +318,7 @@ def _draw_3d(G: nx.Graph, mol: Mol) -> None:
318318
k.WITHIN_ATOMS_EDGE: [],
319319
k.ATOM_FG_EDGE: [],
320320
k.WITHIN_FG_EDGE: [],
321-
k.FG_GRAPHNODE_EDGE: [],
321+
k.TO_GRAPHNODE_EDGE: [],
322322
}
323323
for src, tgt, data in G.edges(data=True):
324324
edge_type_to_edges[data["edge_type"]].append((src, tgt))

0 commit comments

Comments
 (0)