forked from CRIPAC-DIG/GRACE
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
140 lines (113 loc) · 4.41 KB
/
train.py
File metadata and controls
140 lines (113 loc) · 4.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import argparse
import os.path as osp
import random
from time import perf_counter as t
import yaml
from yaml import SafeLoader
import torch
import torch_geometric.transforms as T
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.datasets import Planetoid, CitationFull
from torch_geometric.utils import dropout_adj
from torch_geometric.nn import GCNConv
from model import Encoder, Model, drop_feature
from eval import label_classification
def train(model: Model, x, edge_index):
    """Run one contrastive (GRACE) training step and return the scalar loss.

    Two stochastically augmented views of the graph — independent edge
    dropout plus independent feature masking — are encoded, and the
    model's contrastive loss pulls the paired embeddings together.
    Relies on the module-level ``optimizer`` and the ``drop_*_rate_*``
    globals configured in the main script.
    """
    model.train()
    optimizer.zero_grad()

    # Each view gets its own edge-dropout mask; dropout_adj returns
    # (edge_index, edge_attr), and only the edge index is needed.
    view1_edges, _ = dropout_adj(edge_index, p=drop_edge_rate_1)
    view2_edges, _ = dropout_adj(edge_index, p=drop_edge_rate_2)

    # Independent feature masking per view.
    view1_feats = drop_feature(x, drop_feature_rate_1)
    view2_feats = drop_feature(x, drop_feature_rate_2)

    emb1 = model(view1_feats, view1_edges)
    emb2 = model(view2_feats, view2_edges)

    loss = model.loss(emb1, emb2, batch_size=0)
    loss.backward()
    optimizer.step()
    return loss.item()
def test(model: Model, x, edge_index, y, final=False):
    """Evaluate learned embeddings with the downstream classification probe.

    Args:
        model: Trained GRACE model; its forward pass produces node embeddings.
        x: Node feature matrix.
        edge_index: Graph connectivity.
        y: Node labels for the downstream classifier.
        final: Kept for interface compatibility with callers; currently unused.

    Returns:
        The result of ``label_classification`` (evaluation metrics).
    """
    model.eval()
    # Inference only: disable autograd so no computation graph is built.
    with torch.no_grad():
        z = model(x, edge_index)
    # Fix: return the metrics instead of silently discarding them.
    return label_classification(z, y, ratio=0.1)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="DBLP")
    parser.add_argument("--gpu_id", type=int, default=0)
    parser.add_argument("--config", type=str, default="config.yaml")
    parser.add_argument(
        "--data_path", type=str, default=None, help="Path to custom .pt graph"
    )
    parser.add_argument(
        "--save_path",
        type=str,
        default="grace_model.pt",
        help="Where to save the checkpoint",
    )
    args = parser.parse_args()

    # Validate with an explicit error: `assert` is stripped under `python -O`.
    if args.gpu_id not in range(0, 8):
        raise ValueError(f"--gpu_id must be in [0, 8), got {args.gpu_id}")

    # Bug fix: only configure CUDA when it is actually available, so the
    # script no longer crashes on CPU-only machines.
    if torch.cuda.is_available():
        torch.cuda.set_device(args.gpu_id)

    # Close the config file deterministically instead of leaking the handle.
    with open(args.config) as config_file:
        config = yaml.load(config_file, Loader=SafeLoader)[args.dataset]

    torch.manual_seed(config["seed"])
    # NOTE(review): hard-coded seed — config["seed"] only seeds torch.
    random.seed(12345)

    learning_rate = config["learning_rate"]
    num_hidden = config["num_hidden"]
    num_proj_hidden = config["num_proj_hidden"]
    activation = {"relu": F.relu, "prelu": nn.PReLU()}[config["activation"]]
    base_model = {"GCNConv": GCNConv}[config["base_model"]]
    num_layers = config["num_layers"]
    drop_edge_rate_1 = config["drop_edge_rate_1"]
    drop_edge_rate_2 = config["drop_edge_rate_2"]
    drop_feature_rate_1 = config["drop_feature_rate_1"]
    drop_feature_rate_2 = config["drop_feature_rate_2"]
    tau = config["tau"]
    num_epochs = config["num_epochs"]
    weight_decay = config["weight_decay"]

    def get_dataset(path, name, custom_path=None):
        """Return a dataset: either a custom saved graph or a PyG benchmark.

        Args:
            path: Root directory for the benchmark download/cache.
            name: One of Cora / CiteSeer / PubMed / DBLP.
            custom_path: Optional path to a .pt file holding a dict with
                keys "x", "edge_index", "y" (as produced by the export step).

        Returns:
            An indexable dataset whose element 0 is a PyG ``Data`` object.
        """
        if custom_path:
            print(f"Loading custom graph from {custom_path}")
            # SECURITY: torch.load unpickles arbitrary objects — only load
            # checkpoints you created yourself.
            data_dict = torch.load(custom_path)
            from torch_geometric.data import Data

            return [
                Data(
                    x=data_dict["x"],
                    edge_index=data_dict["edge_index"],
                    y=data_dict["y"],
                )
            ]

        # Explicit error instead of `assert` (stripped under -O).
        if name not in ("Cora", "CiteSeer", "PubMed", "DBLP"):
            raise ValueError(f"Unsupported dataset: {name}")
        # DBLP lives in CitationFull under the lowercase name "dblp".
        if name == "DBLP":
            return CitationFull(path, "dblp", T.NormalizeFeatures())
        return Planetoid(path, name, T.NormalizeFeatures())

    path = osp.join(osp.expanduser("~"), "datasets", args.dataset)
    dataset = get_dataset(path, args.dataset, custom_path=args.data_path)
    data = dataset[0]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    data = data.to(device)

    # Infer the input width from the feature matrix already in hand
    # (avoids re-indexing dataset[0]; works for both custom and benchmark).
    num_features = data.x.shape[1]

    encoder = Encoder(
        num_features,
        num_hidden,
        activation,
        base_model=base_model,
        k=num_layers,
    ).to(device)
    model = Model(encoder, num_hidden, num_proj_hidden, tau).to(device)
    optimizer = torch.optim.Adam(
        model.parameters(), lr=learning_rate, weight_decay=weight_decay
    )

    # Training loop with per-epoch wall-clock timing.
    start = t()
    prev = start
    for epoch in range(1, num_epochs + 1):
        loss = train(model, data.x, data.edge_index)
        now = t()
        print(
            f"(T) | Epoch={epoch:03d}, loss={loss:.4f}, "
            f"this epoch {now - prev:.4f}, total {now - start:.4f}"
        )
        prev = now

    print("=== Final ===")
    test(model, data.x, data.edge_index, data.y, final=True)

    torch.save(model.state_dict(), args.save_path)
    print(f"Model saved to {args.save_path}")