-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtab_train.py
More file actions
121 lines (90 loc) · 3.46 KB
/
tab_train.py
File metadata and controls
121 lines (90 loc) · 3.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import math
import numpy as np
import toml
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from utils.dataset import CensusDataset
from utils.models import TabTransformer
def cus_scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, enable_gqa=False):
    """Pure-PyTorch scaled dot-product attention.

    Written as a drop-in stand-in for
    ``torch.nn.functional.scaled_dot_product_attention`` using only basic
    tensor ops (matmul / softmax / dropout).

    Args:
        query: Tensor whose last two dimensions are (target_len, head_dim).
        key: Tensor whose last two dimensions are (source_len, head_dim).
        value: Tensor whose last two dimensions are (source_len, value_dim).
        attn_mask: Optional mask broadcastable to the attention-score shape.
            A boolean mask keeps positions that are ``True`` and blocks the
            rest; any other dtype is added to the scores.
        dropout_p: Dropout probability applied to the attention weights.
            Dropout is applied whenever this is > 0, so callers must pass 0.0
            when they want deterministic (evaluation) behavior.
        is_causal: If True, each query position may only attend to key
            positions at the same or an earlier index. If ``attn_mask`` is
            also given, both restrictions are applied.
        scale: Optional multiplier for the scores; defaults to
            ``1 / sqrt(head_dim)`` when None.
        enable_gqa: If True, key/value heads (dimension -3) are repeated so
            that their count matches the number of query heads.

    Returns:
        Tensor: the attention weights multiplied by ``value``.
    """
    if enable_gqa:
        # Grouped-query attention: expand key/value heads to match the query.
        key = key.repeat_interleave(query.size(-3) // key.size(-3), dim=-3)
        value = value.repeat_interleave(query.size(-3) // value.size(-3), dim=-3)

    scale_factor = 1.0 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_weights = torch.matmul(query, key.transpose(-2, -1)) * scale_factor

    if is_causal:
        # Lower-triangular mask: position i may see positions 0..i only.
        causal_mask = torch.ones(
            query.size(-2), key.size(-2), dtype=torch.bool, device=query.device
        ).tril()
        attn_weights = attn_weights.masked_fill(~causal_mask, float("-inf"))

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # Boolean masks select positions; adding them would only shift
            # the scores by 0/1 instead of blocking anything.
            attn_weights = attn_weights.masked_fill(~attn_mask, float("-inf"))
        else:
            # Out-of-place add so a mask that broadcasts to a larger shape
            # than the scores still works.
            attn_weights = attn_weights + attn_mask

    attn_weights = torch.softmax(attn_weights, dim=-1)
    if dropout_p > 0.0:
        attn_weights = torch.nn.functional.dropout(attn_weights, p=dropout_p)
    return torch.matmul(attn_weights, value)
from contextlib import contextmanager
@contextmanager
def use_cus_scaled_dot_product_attention():
    """Temporarily swap in the custom attention implementation.

    While the ``with`` body runs, the attribute
    ``torch.nn.functional.scaled_dot_product_attention`` points at
    ``cus_scaled_dot_product_attention``. The previous attribute value is put
    back on exit, including when the body raises.
    """
    functional = torch.nn.functional
    saved_impl = functional.scaled_dot_product_attention
    functional.scaled_dot_product_attention = cus_scaled_dot_product_attention
    try:
        yield
    finally:
        # Always restore, so a failure inside the block cannot leave the
        # patched function installed.
        functional.scaled_dot_product_attention = saved_impl
def census_data(path="./dataset/census"):
    """Load the census dataset from a comma-separated text file.

    The first line of the file is treated as a header and skipped. In every
    following non-blank line, all fields except the last are parsed as integer
    features, and the last field is the label: 0 stays 0, any other integer
    becomes 1.

    Args:
        path: Location of the data file. Defaults to ``./dataset/census``.

    Returns:
        tuple: ``(X, Y, input_shape, nb_classes)`` where ``X`` is a float
        array of feature rows, ``Y`` is a float array of 0/1 labels,
        ``input_shape`` is the fixed tuple ``(None, 13)`` and ``nb_classes``
        is 2.

    Raises:
        ValueError: If a field in a data row is not an integer.
    """
    features = []
    labels = []
    with open(path, "r") as ins:
        next(ins, None)  # skip the header line
        for raw_line in ins:
            line = raw_line.strip()
            if not line:
                # A blank line (e.g. trailing newline padding) carries no
                # sample; parsing it would fail on int('').
                continue
            fields = line.split(',')
            features.append([int(field) for field in fields[:-1]])
            labels.append(0 if int(fields[-1]) == 0 else 1)

    X = np.array(features, dtype=float)
    Y = np.array(labels, dtype=float)
    input_shape = (None, 13)
    nb_classes = 2
    return X, Y, input_shape, nb_classes
# --- Train a TabTransformer on the census data, then evaluate it. ---
# Hyper-parameters and model arguments come from the TOML config file.
config = toml.load("./config/census.toml")

X, Y, input_shape, nb_classes = census_data()

# 80/20 split in file order (the first 80% of rows are used for training).
split_idx = int(len(X) * 0.8)
X_train = X[:split_idx]
Y_train = Y[:split_idx]
X_test = X[split_idx:]
Y_test = Y[split_idx:]

train_dataset = CensusDataset(X_train, Y_train)
test_dataset = CensusDataset(X_test, Y_test)
train_loader = DataLoader(train_dataset, batch_size=config["train"]["batch_size"], shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=config["train"]["batch_size"], shuffle=False)

# Use the GPU when one is present, but still run on CPU-only machines.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = TabTransformer(**config["model"]).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=config["train"]["lr"], weight_decay=config["train"]["weight_decay"])

# Training loop.
for epoch in range(config["train"]["num_epochs"]):
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
    # Reports the loss of the last batch of the epoch.
    print(f"Epoch {epoch} loss: {loss.item()}")

# Evaluation on the held-out split.
model.eval()
test_loss = 0
correct = 0
total = 0
# Gradients are not needed for evaluation; no_grad avoids building the
# autograd graph for every test batch.
with use_cus_scaled_dot_product_attention(), torch.no_grad():
    for batch_idx, (data, target) in enumerate(test_loader):
        data, target = data.to(device), target.to(device)
        output = model(data)
        loss = criterion(output, target)
        # criterion returns the batch MEAN; weight it by the batch size so
        # that dividing by the sample count below yields the true average.
        test_loss += loss.item() * target.size(0)
        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()

# max(..., 1) keeps an empty test split from raising ZeroDivisionError.
test_loss /= max(total, 1)
test_accuracy = 100. * correct / max(total, 1)
print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")