-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathutil.py
More file actions
56 lines (46 loc) · 1.42 KB
/
util.py
File metadata and controls
56 lines (46 loc) · 1.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
"""\
This contains some utility functions, mainly for handling the model
"""
import math
from model import GNN_FULL_CLASS
import torch
def chunk_into_n(lst, n):
    """Split *lst* into ``n`` contiguous chunks.

    Each chunk has ``ceil(len(lst) / n)`` elements; if ``n`` exceeds
    ``len(lst)``, the trailing chunks are empty lists.
    """
    chunk_size = math.ceil(len(lst) / n)
    return [lst[i * chunk_size:(i + 1) * chunk_size] for i in range(n)]
def init_model(NO_MP, lr, wd):
    """Build a fresh model, optimizer, and loss criterion.

    Parameters
    ----------
    NO_MP : number of message-passing steps forwarded to ``GNN_FULL_CLASS``.
    lr : learning rate for the AdamW optimizer.
    wd : weight decay for the AdamW optimizer.

    Returns
    -------
    (model, optimizer, criterion) tuple.
    """
    model = GNN_FULL_CLASS(NO_MP)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    # L1 (MAE) loss; more robust to outlier targets than MSE.
    criterion = torch.nn.L1Loss()
    return model, optimizer, criterion
def train(model, criterion, optimizer, loader):
    """Run one training epoch over *loader*.

    Parameters
    ----------
    model : callable taking a batch and returning predictions shaped like ``batch.y``.
    criterion : loss function applied to (predictions, labels).
    optimizer : torch optimizer wrapping ``model``'s parameters.
    loader : iterable of batches, each exposing a ``.y`` label tensor.

    Returns
    -------
    Mean loss over all batches.

    Raises
    ------
    ZeroDivisionError if *loader* is empty.
    """
    loss_sum = 0
    for batch in loader:
        # Forward pass; mask out non-finite labels (NaN/inf) so they
        # don't poison the loss. Compute the mask once per batch.
        labels = batch.y
        predictions = model(batch)
        finite = torch.isfinite(labels)
        loss = criterion(predictions[finite], labels[finite])
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_sum += loss.item()
    return loss_sum / len(loader)
def evaluate(model, criterion, loader):
    """Compute the mean loss over *loader* without updating the model.

    Parameters
    ----------
    model : callable taking a batch and returning predictions shaped like ``batch.y``.
    criterion : loss function applied to (predictions, labels).
    loader : iterable of batches, each exposing a ``.y`` label tensor.

    Returns
    -------
    Mean loss over all batches.

    Raises
    ------
    ZeroDivisionError if *loader* is empty.
    """
    loss_sum = 0
    with torch.no_grad():  # no gradients needed for evaluation
        for batch in loader:
            # Forward pass; mask out non-finite labels (NaN/inf).
            # Compute the mask once per batch.
            labels = batch.y
            predictions = model(batch)
            finite = torch.isfinite(labels)
            loss = criterion(predictions[finite], labels[finite])
            loss_sum += loss.item()
    return loss_sum / len(loader)