forked from juntaoJianggavin/RWKV-UNet
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: train.py
More file actions
90 lines (84 loc) · 3.84 KB
/
train.py
File metadata and controls
90 lines (84 loc) · 3.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import argparse
import logging
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from rwkv_unet import RWKV_UNet
from trainer import trainer_synapse, trainer_acdc
# Command-line options for RWKV-UNet training.
# NOTE: --root_path, --list_dir and --num_classes are overwritten in the
# __main__ block from a per-dataset table keyed by --dataset, so values
# passed for them on the command line currently have no effect.
parser = argparse.ArgumentParser()
parser.add_argument('--root_path', type=str,
                    default='./data/Synapse/train_npz', help='root dir for data')
# `choices` rejects unsupported datasets up front instead of failing later
# with a KeyError when the per-dataset table is looked up.
parser.add_argument('--dataset', type=str, choices=['Synapse', 'ACDC'],
                    default='Synapse', help='dataset to train on')
parser.add_argument('--list_dir', type=str,
                    default='./lists/lists_Synapse', help='list dir')
parser.add_argument('--num_classes', type=int,
                    default=9, help='output channel of network')
parser.add_argument('--max_iterations', type=int,
                    default=30000, help='maximum iteration number to train')
parser.add_argument('--max_epochs', type=int,
                    default=30, help='maximum epoch number to train')
parser.add_argument('--batch_size', type=int,
                    default=24, help='batch_size per gpu')
# Not read in this file; the GPU count used here comes from
# torch.cuda.device_count().
parser.add_argument('--n_gpu', type=int, default=1, help='total gpu')
parser.add_argument('--deterministic', type=int, default=1,
                    help='whether use deterministic training')
parser.add_argument('--base_lr', type=float, default=0.001,
                    help='segmentation network learning rate')
parser.add_argument('--img_size', type=int,
                    default=224, help='input patch size of network input')
parser.add_argument('--seed', type=int,
                    default=1234, help='random seed')
parser.add_argument('--pretrained_path', type=str,
                    default='net_B.pth',
                    help='pretrained weights passed to RWKV_UNet')
args = parser.parse_args()
def _set_determinism(seed, deterministic):
    """Configure cuDNN and seed the Python, NumPy and PyTorch RNGs.

    Args:
        seed: Seed applied to ``random``, NumPy, the PyTorch CPU generator
            and every visible CUDA device.
        deterministic: Truthy to force deterministic cuDNN kernels; falsy to
            enable cuDNN benchmarking (faster, but not reproducible).
    """
    cudnn.benchmark = not deterministic
    cudnn.deterministic = bool(deterministic)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all GPUs, not just the current one: the model may be wrapped in
    # nn.DataParallel when more than one device is visible.
    torch.cuda.manual_seed_all(seed)


def _build_snapshot_path(args):
    """Return the output directory name for this run.

    Starts from ``./outputs/<args.exp>/exp1``. Batch size and image size are
    always appended; the other options are appended only when they differ
    from a reference value.
    """
    path = "./outputs/{}/{}".format(args.exp, 'exp1')
    if args.is_pretrain:
        path += '_pretrain'
    if args.max_iterations != 30000:
        # NOTE(review): `[0:2]` keeps only the first two digits (5000 -> '50k',
        # 100000 -> '10k'). Kept unchanged so existing run directories -- and
        # presumably an evaluation script that rebuilds this path -- still
        # match; TODO confirm before changing.
        path += '_' + str(args.max_iterations)[0:2] + 'k'
    if args.max_epochs != 30:
        path += '_epo' + str(args.max_epochs)
    path += '_bs' + str(args.batch_size)
    # NOTE(review): compares against 0.01 although the --base_lr default is
    # 0.001, so the default learning rate is always appended. Kept for path
    # compatibility.
    if args.base_lr != 0.01:
        path += '_lr' + str(args.base_lr)
    path += '_' + str(args.img_size)
    if args.seed != 1234:
        path += '_s' + str(args.seed)
    return path


if __name__ == "__main__":
    _set_determinism(args.seed, args.deterministic)

    dataset_name = args.dataset
    # Per-dataset settings. These overwrite the --root_path, --list_dir and
    # --num_classes command-line values.
    dataset_config = {
        'ACDC': {
            'root_path': './data/ACDC',
            'list_dir': None,
            'num_classes': 4,
        },
        'Synapse': {
            'root_path': './data/Synapse/train_npz',
            'list_dir': './lists/lists_Synapse',
            'num_classes': 9,
        },
    }
    if dataset_name not in dataset_config:
        # Report a readable CLI error instead of a bare KeyError.
        parser.error(
            f"unsupported --dataset {dataset_name!r}; "
            f"choose from {', '.join(sorted(dataset_config))}"
        )
    config = dataset_config[dataset_name]
    args.num_classes = config['num_classes']
    args.root_path = config['root_path']
    args.list_dir = config['list_dir']
    args.is_pretrain = True
    args.exp = 'rwkv' + dataset_name + str(args.img_size)

    snapshot_path = _build_snapshot_path(args)
    os.makedirs(snapshot_path, exist_ok=True)

    # The weights path comes from --pretrained_path (a bare `pretrained_path`
    # is not defined anywhere in this module).
    net = RWKV_UNet(in_channels=1, img_size=args.img_size,
                    num_classes=args.num_classes,
                    pretrained_path=args.pretrained_path)
    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs")
        net = nn.DataParallel(net)
    net = net.to('cuda')

    trainer = {'Synapse': trainer_synapse, 'ACDC': trainer_acdc}
    trainer[dataset_name](args, net, snapshot_path)