run.py
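"""Train a binary classifier with either plain ERM or the CFix trainer.

The script takes one positional argument: the path to a config file read by
utils.config.load_config. Keys referenced below include data_root,
checkpoint_root, dataset, target_attr, trainer ('erm' or 'cfix'), desc,
batch_size, feature_dim, lr, weight_decay, scheduler_T_max, and max_epoch;
the 'cfix' branch additionally reads k, self_supervised, t, and beta.
"""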
import os
import sys
from pathlib import Path

import torch
import torch.optim as optim

from factory.factory import ModelFactory, TrainerFactory, DataLoaderFactory
from utils.config import load_config
from utils.transforms import *


def main():
    # Path to the experiment config file, passed on the command line.
    config_path = Path(sys.argv[1])
    config = load_config(config_path)

    # SETUP
    DATA_ROOT = config['data_root']
    n_cpus = 4  # int(os.cpu_count())
    task_id = 0
    start_epoch = 1
    torch.backends.cudnn.benchmark = True
    checkpoint_dir = os.path.join(config["checkpoint_root"], config["dataset"],
                                  config["target_attr"], config["trainer"],
                                  config["desc"], str(task_id))
    # Create the checkpoint directory if it does not exist.
    os.makedirs(checkpoint_dir, exist_ok=True)
    config["checkpoint_dir"] = checkpoint_dir
    # ERM
    if config["trainer"] == 'erm':
        config["arch"] = 'ResNet18'

        # DATASET
        loaders = DataLoaderFactory.create(config["dataset"],
                                           root=DATA_ROOT,
                                           batch_size=config["batch_size"],
                                           num_workers=n_cpus, configs=config)
        num_classes = 2

        # MODEL ERM
        model_args = {
            "name": config["arch"],
            "feature_dim": config["feature_dim"],
            "num_classes": num_classes,
        }

        # Model setup and optimizer config. torch.compile's default backend
        # requires GPUs of compute capability >= 7.0, hence the guard.
        model = ModelFactory.create(**model_args).cuda()
        if torch.cuda.get_device_capability()[0] >= 7:
            model = torch.compile(model)
        optimizer = optim.Adam(model.parameters(), lr=config["lr"],
                               weight_decay=config["weight_decay"])
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config["scheduler_T_max"])
        trainer = TrainerFactory.create('erm', config, model, loaders, optimizer, num_classes)
        trainer.set_checkpoint_dir(checkpoint_dir)

        # ERM TRAINING
        for epoch in range(start_epoch, config["max_epoch"] + 1):
            trainer.train(epoch=epoch)
            trainer.test(epoch=epoch)
            scheduler.step()

        with open(os.path.join(checkpoint_dir, 'results.txt'), 'w') as f:
            f.write('best_accuracy:' + str(trainer.test_best_acc) + "\n")
            f.write('best_avg_accuracy:' + str(trainer.test_avg_acc) + "\n")
            f.write('best_worst_accuracy:' + str(trainer.test_worst_acc) + "\n")
    # CFIX
    if config["trainer"] == 'cfix':
        config["arch"] = 'ResNet18Cfix'

        # DATASET BASE
        loaders = DataLoaderFactory.create(config["dataset"],
                                           root=DATA_ROOT,
                                           batch_size=config["batch_size"],
                                           num_workers=n_cpus, configs=config)
        num_classes = 2
        model_args = {
            "name": config['arch'],
            "feature_dim": config['feature_dim'],
            "num_classes": num_classes,
            "pseudo_dim": config['k'],
            "self_supervised": config['self_supervised'],
            "config": config,
            "eval": False,
        }

        # Model setup and optimizer config
        model = ModelFactory.create(**model_args).cuda()
        if torch.cuda.get_device_capability()[0] >= 7:
            model = torch.compile(model)
        optimizer = optim.Adam(model.parameters(), lr=config["lr"],
                               weight_decay=config["weight_decay"])
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config["scheduler_T_max"])
        trainer = TrainerFactory.create('cfix', config, model, loaders, optimizer,
                                        num_classes, config['t'], config['beta'])
        trainer.set_checkpoint_dir(checkpoint_dir)

        for epoch in range(start_epoch, config['max_epoch'] + 1):
            # Clustering runs once, before the first training epoch.
            if epoch == 1:
                trainer.clustering(epoch=epoch)
            trainer.train(epoch=epoch)
            trainer.test(epoch=epoch)
            scheduler.step()

        with open(os.path.join(checkpoint_dir, 'results.txt'), 'w') as f:
            f.write('best_avg_accuracy:' + str(trainer.test_best_avg_acc) + "\n")
            f.write('best_worst_accuracy:' + str(trainer.test_best_worst) + "\n")


if __name__ == "__main__":
    main()
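# Example invocation (the config file name here is hypothetical):
#   python run.py configs/experiment.yaml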