forked from Yasoz/PLGF
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: main.py
More file actions
126 lines (106 loc) · 4.37 KB
/
main.py
File metadata and controls
126 lines (106 loc) · 4.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import torch.nn.functional as F
import time, math, os
from util.PLGF import PLGF
from util.metrics import get_MSE, get_MAE, get_MAPE
from util.dataset import get_dataloader
from util.focalloss import FocalL1Loss
def set_seed(seed=42):
    """Seed every RNG used in this script for reproducible runs.

    Seeds torch (CPU and all CUDA devices) and numpy, then forces
    deterministic cuDNN kernels — trading some speed for run-to-run
    reproducibility.

    Args:
        seed: Integer seed applied to all RNGs. Defaults to 42.
    """
    # NOTE: the original called torch.manual_seed(seed) twice; once suffices.
    torch.manual_seed(seed)
    # No-ops (deferred) when CUDA is unavailable, so safe on CPU-only hosts.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    # Deterministic conv algorithms; disable benchmark autotuning, which
    # would otherwise pick kernels nondeterministically.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def train_model(model, train_loader, val_loader, num_epochs, lr, device,
                checkpoint_path='best_model.pth'):
    """Train `model` with Adam + FocalL1Loss, optionally validating each epoch.

    Uses gradient clipping (max-norm 1.0), ReduceLROnPlateau scheduling on the
    validation loss, and early stopping with patience 12. The best model (by
    validation loss) is checkpointed to `checkpoint_path` and reloaded before
    returning. When `val_loader` is None, no checkpointing, scheduling, or
    early stopping occurs.

    Args:
        model: Module called as model(X, ext).
        train_loader: Yields (X, Y, ext) batches for training.
        val_loader: Yields (X, Y, ext) batches for validation, or None.
        num_epochs: Maximum number of epochs.
        lr: Adam learning rate (weight decay fixed at 1e-4).
        device: Device to train on.
        checkpoint_path: Where the best checkpoint is written. Defaults to
            'best_model.pth' (backward compatible with the original behavior).

    Returns:
        (train_losses, val_losses): per-epoch average losses; val_losses is
        empty when val_loader is None.
    """
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    criterion = FocalL1Loss()
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=8)
    best_val_loss = float('inf')
    patience_counter = 0
    patience = 12
    train_losses = []
    val_losses = []
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        for X, Y, ext in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            X, Y, ext = X.to(device), Y.to(device), ext.to(device)
            optimizer.zero_grad()
            output = model(X, ext)
            loss = criterion(output, Y)
            loss.backward()
            # Clip to stabilize training against exploding gradients.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            # Weight by batch size so the epoch average is per-sample.
            train_loss += loss.item() * X.size(0)
        train_loss /= len(train_loader.dataset)
        train_losses.append(train_loss)
        val_loss = 0
        if val_loader is not None:
            model.eval()
            with torch.no_grad():
                for X, Y, ext in val_loader:
                    X, Y, ext = X.to(device), Y.to(device), ext.to(device)
                    output = model(X, ext)
                    total_loss = criterion(output, Y)
                    val_loss += total_loss.item() * X.size(0)
            val_loss /= len(val_loader.dataset)
            val_losses.append(val_loss)
            scheduler.step(val_loss)
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                torch.save(model.state_dict(), checkpoint_path)
            else:
                patience_counter += 1
            print(f"Epoch {epoch+1}, Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}")
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break
        else:
            print(f"Epoch {epoch+1}, Train Loss: {train_loss:.6f}")
    if val_loader is not None:
        # map_location keeps the load robust regardless of where the
        # checkpoint tensors were saved.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        print("Loaded best model based on validation loss")
    return train_losses, val_losses
def test_model(model, test_loader, device='cuda', desc="Test"):
    """Evaluate `model` on `test_loader` and report MSE, MAE, and MAPE.

    Runs the model in eval mode without gradients, gathers predictions and
    ground truth on the CPU, prints a summary line, and returns the metrics.

    Args:
        model: Module called as model(X, ext).
        test_loader: Yields (X, Y, ext) batches.
        device: Device inputs are moved to before the forward pass.
        desc: Label used in the printed summary.

    Returns:
        (mse, mae, mape) computed over the full concatenated test set.
    """
    model.eval()
    pred_batches, target_batches = [], []
    with torch.no_grad():
        for inputs, targets, external in test_loader:
            inputs = inputs.to(device)
            external = external.to(device)
            pred_batches.append(model(inputs, external).cpu().numpy())
            target_batches.append(targets.numpy())
    all_preds = np.concatenate(pred_batches, axis=0)
    all_targets = np.concatenate(target_batches, axis=0)
    mse = get_MSE(all_preds, all_targets)
    mae = get_MAE(all_preds, all_targets)
    mape = get_MAPE(all_preds, all_targets)
    print(f"{desc}: MSE={mse:.6f}, MAE={mae:.6f}, MAPE={mape:.6f}")
    return mse, mae, mape
def main():
    """Entry point: build data loaders, train PLGF, and report test metrics."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Training / data configuration.
    num_epochs = 200
    batch_size = 16
    lr = 3e-4
    datapath = "xxxxxxx/Urbanflow" # change to your data path
    dataset_name = "TaxiBJ"
    task_id = 1

    # One loader per split, built identically except for the mode.
    loaders = {
        mode: get_dataloader(datapath, dataset=dataset_name,
                             batch_size=batch_size, mode=mode, task_id=task_id)
        for mode in ('train', 'valid', 'test')
    }

    model = PLGF(base_channels=128, num_layers=4).to(device)
    train_model(model, loaders['train'], loaders['valid'], num_epochs, lr, device)
    print("Final Test:")
    test_model(model, loaders['test'], device=device, desc="Test")
# Script entry point: seed all RNGs first so the run is reproducible.
if __name__ == "__main__":
    set_seed()
    main()