|
22 | 22 | MolecularProperty, |
23 | 23 | ) |
24 | 24 | from chebai_graph.preprocessing.reader import ( |
25 | | - GraphFGAugmentorReader, |
| 25 | + AtomFGReader_NoFGEdges_WithGraphNode, |
| 26 | + AtomFGReader_WithFGEdges_NoGraphNode, |
| 27 | + AtomFGReader_WithFGEdges_WithGraphNode, |
| 28 | + AtomReader_WithGraphNodeOnly, |
| 29 | + AtomsFGReader_NoFGEdges_NoGraphNode, |
26 | 30 | GraphPropertyReader, |
27 | 31 | GraphReader, |
28 | 32 | ) |
@@ -178,20 +182,11 @@ def _merge_props_into_base(self, row): |
178 | 182 | ) |
179 | 183 | else: |
180 | 184 | molecule_attr = torch.cat([molecule_attr, property_values], dim=1) |
181 | | - |
182 | | - is_atom_node = ( |
183 | | - geom_data.is_atom_node if hasattr(geom_data, "is_atom_node") else None |
184 | | - ) |
185 | | - is_graph_node = ( |
186 | | - geom_data.is_graph_node if hasattr(geom_data, "is_graph_node") else None |
187 | | - ) |
188 | 185 | return GeomData( |
189 | 186 | x=x, |
190 | 187 | edge_index=geom_data.edge_index, |
191 | 188 | edge_attr=edge_attr, |
192 | 189 | molecule_attr=molecule_attr, |
193 | | - is_atom_node=is_atom_node, |
194 | | - is_graph_node=is_graph_node, |
195 | 190 | ) |
196 | 191 |
|
197 | 192 | def load_processed_data_from_file(self, filename): |
@@ -249,5 +244,55 @@ class ChEBI50GraphPropertiesPartial(ChEBI50GraphProperties, ChEBIOverXPartial): |
249 | 244 | pass |
250 | 245 |
|
251 | 246 |
|
252 | | -class ChEBI50GraphFGAugmentorReader(GraphPropertiesMixIn, ChEBIOver50): |
253 | | - READER = GraphFGAugmentorReader |
| 247 | +class AugGraphPropMixIn_NoGraphNode(GraphPropertiesMixIn, ABC): |
| 248 | + READER = None |
| 249 | + |
| 250 | + def _merge_props_into_base(self, row): |
| 251 | + data = super()._merge_props_into_base(row) |
| 252 | + geom_data = row["features"] |
| 253 | + assert isinstance(geom_data, GeomData) and isinstance(data, GeomData) |
| 254 | + |
| 255 | + is_atom_node = geom_data.is_atom_node |
| 256 | + assert is_atom_node is not None, "is_atom_node must be set in the geom_data" |
| 257 | + data.is_atom_node = is_atom_node |
| 258 | + return data |
| 259 | + |
| 260 | + |
| 261 | +class AugGraphPropMixIn_WithGraphNode(AugGraphPropMixIn_NoGraphNode, ABC): |
| 262 | + READER = None |
| 263 | + |
| 264 | + def _merge_props_into_base(self, row): |
| 265 | + data = super()._merge_props_into_base(row) |
| 266 | + return self._add_graph_node_mask(data, row) |
| 267 | + |
| 268 | + def _add_graph_node_mask(self, data: GeomData, row) -> GeomData: |
| 269 | + """ |
| 270 | + Add a mask for graph nodes to the data. |
| 271 | + This is used to distinguish between atom nodes and graph nodes. |
| 272 | + """ |
| 273 | + geom_data = row["features"] |
| 274 | + assert isinstance(geom_data, GeomData) and isinstance(data, GeomData) |
| 275 | + is_graph_node = geom_data.is_graph_node |
| 276 | + assert is_graph_node is not None, "is_graph_node must be set in the geom_data" |
| 277 | + data.is_graph_node = is_graph_node |
| 278 | + return data |
| 279 | + |
| 280 | + |
| 281 | +class ChEBI50_WFGE_WGN_GraphProp(AugGraphPropMixIn_WithGraphNode, ChEBIOver50): |
| 282 | + READER = AtomFGReader_WithFGEdges_WithGraphNode |
| 283 | + |
| 284 | + |
| 285 | +class ChEBI50_NFGE_WGN_GraphProp(AugGraphPropMixIn_WithGraphNode, ChEBIOver50): |
| 286 | + READER = AtomFGReader_NoFGEdges_WithGraphNode |
| 287 | + |
| 288 | + |
| 289 | +class ChEBI50_WFGE_NGN_GraphProp(AugGraphPropMixIn_NoGraphNode, ChEBIOver50): |
| 290 | + READER = AtomFGReader_WithFGEdges_NoGraphNode |
| 291 | + |
| 292 | + |
| 293 | +class ChEBI50_NFGE_NGN_GraphProp(AugGraphPropMixIn_NoGraphNode, ChEBIOver50): |
| 294 | + READER = AtomsFGReader_NoFGEdges_NoGraphNode |
| 295 | + |
| 296 | + |
| 297 | +class ChEBI50_Atom_WGNOnly_GraphProp(AugGraphPropMixIn_WithGraphNode, ChEBIOver50): |
| 298 | + READER = AtomReader_WithGraphNodeOnly |
0 commit comments