-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
executable file
·136 lines (103 loc) · 5.01 KB
/
main.py
File metadata and controls
executable file
·136 lines (103 loc) · 5.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import argparse
import torch
import numpy as np
import random
import time
from torchinfo import summary
from src.data import load_data
from src.methods.deep_network import MLP, CNN, Trainer
from src.methods.dummy_methods import DummyClassifier
from src.utils import normalize_fn, append_bias_term, accuracy_fn, macrof1_fn, get_n_classes
def set_seed(seed):
    """Seed every relevant RNG (random, NumPy, PyTorch) for reproducibility.

    Arguments:
        seed (int): the seed value applied to all generators.
    """
    # Apply the same seed to each framework's generator in turn.
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    # Trade cuDNN autotuning speed for bit-for-bit deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
def main(args):
    """Run the full train/evaluate pipeline for the selected network.

    Loads the data, normalizes it, optionally carves out a validation
    split, reshapes inputs for the chosen architecture, trains the model,
    and reports accuracy / macro-F1 on the train and validation/test sets.

    Arguments:
        args (Namespace): arguments parsed from the command line (see the
            parser at the end of this file); values are accessed as
            "args.argument".

    Raises:
        ValueError: if args.nn_type is not one of 'mlp' | 'cnn'.
    """
    ## 1. First, we load our data.
    start_timer = time.time()
    set_seed(2025)
    xtrain, xtest, ytrain, ytest = load_data()

    ## 2. Normalize using statistics from the training set only.
    # Single global mean/std over all pixels; epsilon guards against a
    # zero standard deviation.
    means = np.mean(xtrain)
    stds = np.std(xtrain) + 1e-8
    xtrain = normalize_fn(xtrain, means, stds)
    xtest = normalize_fn(xtest, means, stds)

    # Unless evaluating on the real test set, hold out 20% of the training
    # data as a validation split (reusing the xtest/ytest names for it).
    if not args.test:
        n_train = xtrain.shape[0]
        n_val = n_train // 5
        indices = np.random.permutation(n_train)
        train_idx, val_idx = indices[n_val:], indices[:n_val]
        xtest, ytest = xtrain[val_idx], ytrain[val_idx]
        xtrain, ytrain = xtrain[train_idx], ytrain[train_idx]

    ## 3. Build the model, reshaping the inputs to what it expects.
    # (Reshaping only touches x, so n_classes can be computed first.)
    n_classes = get_n_classes(ytrain)
    if args.nn_type == "mlp":
        # Flatten each image into a single feature vector.
        xtrain = xtrain.reshape(xtrain.shape[0], -1)
        xtest = xtest.reshape(xtest.shape[0], -1)
        model = MLP(input_size=28*28*3, n_classes=n_classes)
    elif args.nn_type == "cnn":
        # NHWC -> NCHW, the layout PyTorch conv layers expect
        # (assumes data loads as (N, H, W, C) — confirm against load_data).
        xtrain = xtrain.transpose(0, 3, 1, 2)
        xtest = xtest.transpose(0, 3, 1, 2)
        model = CNN(input_channels=3, n_classes=n_classes)
    else:
        # Fail fast with a clear message instead of hitting a NameError
        # on `model` below when an unsupported architecture is requested.
        raise ValueError(f"Unknown nn_type: {args.nn_type!r} (expected 'mlp' or 'cnn')")
    summary(model)

    # Trainer object wrapping the optimization loop.
    method_obj = Trainer(model, lr=args.lr, epochs=args.max_iters, batch_size=args.nn_batch_size)

    ## 4. Train and evaluate the method.
    # Fit (:=train) the method on the training data.
    preds_train = method_obj.fit(xtrain, ytrain)
    # Predict on unseen data.
    preds = method_obj.predict(xtest)

    ## Report results: performance on train and valid/test sets.
    acc = accuracy_fn(preds_train, ytrain)
    macrof1 = macrof1_fn(preds_train, ytrain)
    print(f"\nTrain set: accuracy = {acc:.3f}% - F1-score = {macrof1:.6f}")

    ## As there are no test dataset labels, this reports validation
    # accuracy; true test performance comes from submitting predictions
    # to the AIcrowd competition.
    acc = accuracy_fn(preds, ytest)
    macrof1 = macrof1_fn(preds, ytest)
    print(f"Validation set: accuracy = {acc:.3f}% - F1-score = {macrof1:.6f}")

    print(f"Time taken in s: {int(time.time() - start_timer)}")
if __name__ == '__main__':
    # Command-line interface: every argument falls back to the default
    # declared below when it is not supplied on the terminal.
    cli = argparse.ArgumentParser()

    # MS2 arguments — feel free to add more if you need!
    cli.add_argument('--data', default="dataset", type=str, help="path to your dataset")
    cli.add_argument('--nn_type', default="mlp",
                     help="which network architecture to use, it can be 'mlp' | 'transformer' | 'cnn'")
    cli.add_argument('--nn_batch_size', type=int, default=64, help="batch size for NN training")
    cli.add_argument('--device', type=str, default="mps",
                     help="Device to use for the training, it can be 'cpu' | 'cuda' | 'mps'")
    cli.add_argument('--lr', type=float, default=2e-4, help="learning rate for methods with learning rate")
    cli.add_argument('--max_iters', type=int, default=50, help="max iters for methods which are iterative")
    cli.add_argument('--test', action="store_true",
                     help="train on whole training data and evaluate on the test data, otherwise use a validation set")

    # The parsed namespace keeps each value accessible as an attribute,
    # e.g. "args.data".
    args = cli.parse_args()
    main(args)