Skip to content

Commit 7560fec

Browse files
committed
add data classes for ablation readers
1 parent f0bb097 commit 7560fec

File tree

3 files changed

+68
-14
lines changed

3 files changed

+68
-14
lines changed

chebai_graph/preprocessing/datasets/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
from .chebi import (
2+
ChEBI50_Atom_WGNOnly_GraphProp,
3+
ChEBI50_NFGE_NGN_GraphProp,
4+
ChEBI50_NFGE_WGN_GraphProp,
5+
ChEBI50_WFGE_NGN_GraphProp,
6+
ChEBI50_WFGE_WGN_GraphProp,
27
ChEBI50GraphData,
3-
ChEBI50GraphFGAugmentorReader,
48
ChEBI50GraphProperties,
59
)
610
from .pubchem import PubChemGraphProperties
@@ -10,4 +14,9 @@
1014
"ChEBI50GraphProperties",
1115
"ChEBI50GraphData",
1216
"PubChemGraphProperties",
17+
"ChEBI50_Atom_WGNOnly_GraphProp",
18+
"ChEBI50_NFGE_NGN_GraphProp",
19+
"ChEBI50_NFGE_WGN_GraphProp",
20+
"ChEBI50_WFGE_NGN_GraphProp",
21+
"ChEBI50_WFGE_WGN_GraphProp",
1322
]

chebai_graph/preprocessing/datasets/chebi.py

Lines changed: 57 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,11 @@
2222
MolecularProperty,
2323
)
2424
from chebai_graph.preprocessing.reader import (
25-
GraphFGAugmentorReader,
25+
AtomFGReader_NoFGEdges_WithGraphNode,
26+
AtomFGReader_WithFGEdges_NoGraphNode,
27+
AtomFGReader_WithFGEdges_WithGraphNode,
28+
AtomReader_WithGraphNodeOnly,
29+
AtomsFGReader_NoFGEdges_NoGraphNode,
2630
GraphPropertyReader,
2731
GraphReader,
2832
)
@@ -178,20 +182,11 @@ def _merge_props_into_base(self, row):
178182
)
179183
else:
180184
molecule_attr = torch.cat([molecule_attr, property_values], dim=1)
181-
182-
is_atom_node = (
183-
geom_data.is_atom_node if hasattr(geom_data, "is_atom_node") else None
184-
)
185-
is_graph_node = (
186-
geom_data.is_graph_node if hasattr(geom_data, "is_graph_node") else None
187-
)
188185
return GeomData(
189186
x=x,
190187
edge_index=geom_data.edge_index,
191188
edge_attr=edge_attr,
192189
molecule_attr=molecule_attr,
193-
is_atom_node=is_atom_node,
194-
is_graph_node=is_graph_node,
195190
)
196191

197192
def load_processed_data_from_file(self, filename):
@@ -249,5 +244,55 @@ class ChEBI50GraphPropertiesPartial(ChEBI50GraphProperties, ChEBIOverXPartial):
249244
pass
250245

251246

252-
class ChEBI50GraphFGAugmentorReader(GraphPropertiesMixIn, ChEBIOver50):
253-
READER = GraphFGAugmentorReader
247+
class AugGraphPropMixIn_NoGraphNode(GraphPropertiesMixIn, ABC):
248+
READER = None
249+
250+
def _merge_props_into_base(self, row):
251+
data = super()._merge_props_into_base(row)
252+
geom_data = row["features"]
253+
assert isinstance(geom_data, GeomData) and isinstance(data, GeomData)
254+
255+
is_atom_node = geom_data.is_atom_node
256+
assert is_atom_node is not None, "is_atom_node must be set in the geom_data"
257+
data.is_atom_node = is_atom_node
258+
return data
259+
260+
261+
class AugGraphPropMixIn_WithGraphNode(AugGraphPropMixIn_NoGraphNode, ABC):
262+
READER = None
263+
264+
def _merge_props_into_base(self, row):
265+
data = super()._merge_props_into_base(row)
266+
return self._add_graph_node_mask(data, row)
267+
268+
def _add_graph_node_mask(self, data: GeomData, row) -> GeomData:
269+
"""
270+
Add a mask for graph nodes to the data.
271+
This is used to distinguish between atom nodes and graph nodes.
272+
"""
273+
geom_data = row["features"]
274+
assert isinstance(geom_data, GeomData) and isinstance(data, GeomData)
275+
is_graph_node = geom_data.is_graph_node
276+
assert is_graph_node is not None, "is_graph_node must be set in the geom_data"
277+
data.is_graph_node = is_graph_node
278+
return data
279+
280+
281+
class ChEBI50_WFGE_WGN_GraphProp(AugGraphPropMixIn_WithGraphNode, ChEBIOver50):
282+
READER = AtomFGReader_WithFGEdges_WithGraphNode
283+
284+
285+
class ChEBI50_NFGE_WGN_GraphProp(AugGraphPropMixIn_WithGraphNode, ChEBIOver50):
286+
READER = AtomFGReader_NoFGEdges_WithGraphNode
287+
288+
289+
class ChEBI50_WFGE_NGN_GraphProp(AugGraphPropMixIn_NoGraphNode, ChEBIOver50):
290+
READER = AtomFGReader_WithFGEdges_NoGraphNode
291+
292+
293+
class ChEBI50_NFGE_NGN_GraphProp(AugGraphPropMixIn_NoGraphNode, ChEBIOver50):
294+
READER = AtomsFGReader_NoFGEdges_NoGraphNode
295+
296+
297+
class ChEBI50_Atom_WGNOnly_GraphProp(AugGraphPropMixIn_WithGraphNode, ChEBIOver50):
298+
READER = AtomReader_WithGraphNodeOnly

configs/data/chebi50_augmented_baseline.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
class_path: chebai_graph.preprocessing.datasets.ChEBI50GraphFGAugmentorReader
1+
class_path: chebai_graph.preprocessing.datasets.ChEBI50_WFGE_WGN_GraphProp
22
init_args:
33
properties:
44
- chebai_graph.preprocessing.properties.AugAtomType

0 commit comments

Comments
 (0)