|
10 | 10 | from rdkit.Chem.Draw import rdMolDraw2D |
11 | 11 | from torch import Tensor |
12 | 12 |
|
13 | | -from chebai_graph.preprocessing.properties.constants import * |
| 13 | +from chebai_graph.preprocessing.properties import constants as k |
14 | 14 | from chebai_graph.preprocessing.reader import GraphFGAugmentorReader |
15 | 15 |
|
16 | 16 | matplotlib.use("TkAgg") |
17 | 17 |
|
18 | 18 | EDGE_COLOR_MAP = { |
19 | | - WITHIN_ATOMS_EDGE: "#1f77b4", |
20 | | - ATOM_FG_EDGE: "#9467bd", |
21 | | - WITHIN_FG_EDGE: "#ff7f0e", |
22 | | - FG_GRAPHNODE_EDGE: "#2ca02c", |
| 19 | + k.WITHIN_ATOMS_EDGE: "#1f77b4", |
| 20 | + k.ATOM_FG_EDGE: "#9467bd", |
| 21 | + k.WITHIN_FG_EDGE: "#ff7f0e", |
| 22 | + k.FG_GRAPHNODE_EDGE: "#2ca02c", |
23 | 23 | } |
24 | 24 |
|
25 | 25 | NODE_COLOR_MAP = { |
@@ -90,22 +90,22 @@ def _create_graph( |
90 | 90 | src_nodes, tgt_nodes = edge_index.tolist() |
91 | 91 | with_atom_edges = { |
92 | 92 | f"{bond.GetBeginAtomIdx()}_{bond.GetEndAtomIdx()}" |
93 | | - for bond in augmented_graph_edges[WITHIN_ATOMS_EDGE].GetBonds() |
| 93 | + for bond in augmented_graph_edges[k.WITHIN_ATOMS_EDGE].GetBonds() |
94 | 94 | } |
95 | | - atom_fg_edges = set(augmented_graph_edges[ATOM_FG_EDGE]) |
96 | | - within_fg_edges = set(augmented_graph_edges[WITHIN_FG_EDGE]) |
97 | | - fg_graph_edges = set(augmented_graph_edges[FG_GRAPHNODE_EDGE]) |
| 95 | + atom_fg_edges = set(augmented_graph_edges[k.ATOM_FG_EDGE]) |
| 96 | + within_fg_edges = set(augmented_graph_edges[k.WITHIN_FG_EDGE]) |
| 97 | + fg_graph_edges = set(augmented_graph_edges[k.FG_GRAPHNODE_EDGE]) |
98 | 98 |
|
99 | 99 | for src, tgt in zip(src_nodes, tgt_nodes): |
100 | 100 | undirected_edge_set = {f"{src}_{tgt}", f"{tgt}_{src}"} |
101 | 101 | if undirected_edge_set & with_atom_edges: |
102 | | - edge_type = WITHIN_ATOMS_EDGE |
| 102 | + edge_type = k.WITHIN_ATOMS_EDGE |
103 | 103 | elif undirected_edge_set & atom_fg_edges: |
104 | | - edge_type = ATOM_FG_EDGE |
| 104 | + edge_type = k.ATOM_FG_EDGE |
105 | 105 | elif undirected_edge_set & within_fg_edges: |
106 | | - edge_type = WITHIN_FG_EDGE |
| 106 | + edge_type = k.WITHIN_FG_EDGE |
107 | 107 | elif undirected_edge_set & fg_graph_edges: |
108 | | - edge_type = FG_GRAPHNODE_EDGE |
| 108 | + edge_type = k.FG_GRAPHNODE_EDGE |
109 | 109 | else: |
110 | 110 | raise ValueError("Unexpected edge type") |
111 | 111 | G.add_edge(src, tgt, edge_type=edge_type, edge_color=EDGE_COLOR_MAP[edge_type]) |
@@ -266,10 +266,10 @@ def _draw_3d(G: nx.Graph, mol: Mol) -> None: |
266 | 266 |
|
267 | 267 | # Collect edges by type |
268 | 268 | edge_type_to_edges = { |
269 | | - WITHIN_ATOMS_EDGE: [], |
270 | | - ATOM_FG_EDGE: [], |
271 | | - WITHIN_FG_EDGE: [], |
272 | | - FG_GRAPHNODE_EDGE: [], |
| 269 | + k.WITHIN_ATOMS_EDGE: [], |
| 270 | + k.ATOM_FG_EDGE: [], |
| 271 | + k.WITHIN_FG_EDGE: [], |
| 272 | + k.FG_GRAPHNODE_EDGE: [], |
273 | 273 | } |
274 | 274 | for src, tgt, data in G.edges(data=True): |
275 | 275 | edge_type_to_edges[data["edge_type"]].append((src, tgt)) |
|
0 commit comments