-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
109 lines (88 loc) · 4.14 KB
/
main.py
File metadata and controls
109 lines (88 loc) · 4.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import fire
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import numpy as np
import torchvision
import torchvision.transforms as transforms
#from torch.utils import set_seeds
import model.resnet as resnet
import train
def preprocess():
    """Convert the DICOM training images into 224x224 float16 .npy files.

    Reads ``stage_2_train_labels.csv``, de-duplicates by patient, resizes each
    scan, and saves it under ``output/<train|val>/<label>/<patientId>.npy``.
    The first 24000 patients go to the train split, the remainder to val.
    """
    # BUG FIX: these packages were used but never imported anywhere in the
    # module, so the original raised NameError on the first line. They are
    # only needed here, so they are imported locally.
    import pandas as pd
    import pydicom
    import cv2
    from tqdm import tqdm

    labels = pd.read_csv("stage_2_train_labels.csv")
    # One row per bounding box means patients can repeat; keep one row each.
    labels = labels.drop_duplicates("patientId")
    root_path = Path("stage_2_train_images")
    save_path = Path("output")
    for c, patient_id in enumerate(tqdm(labels.patientId)):
        dcm_path = (root_path / patient_id).with_suffix(".dcm")
        # Scale raw pixel values into [0, 1] before resizing.
        dcm = pydicom.read_file(dcm_path).pixel_array / 255
        # float16 halves disk usage; load_file() casts back to float32.
        dcm_array = cv2.resize(dcm, (224, 224)).astype(np.float16)
        label = labels.Target.iloc[c]
        # Simple positional split: first 24000 scans train, the rest val.
        train_or_val = "train" if c < 24000 else "val"
        current_save_path = save_path / train_or_val / str(label)
        current_save_path.mkdir(parents=True, exist_ok=True)
        np.save(current_save_path / patient_id, dcm_array)
def load_file(path):
    """Read a .npy array from *path* and return it cast to float32."""
    array = np.load(path)
    return array.astype(np.float32)
def get_model(model_name, device):
    """Instantiate the requested architecture and move it to *device*.

    Args:
        model_name: architecture key; currently only 'Resnet' is supported.
        device: torch.device the model should live on.

    Returns:
        The instantiated model on *device*.

    Raises:
        ValueError: for an unrecognized *model_name* (the original silently
        returned None, which crashed later at ``model.parameters()``).
    """
    if model_name == 'Resnet':
        # Residual blocks stacked [3, 4, 6, 3] per stage (ResNet-34-style).
        return resnet.ResNet(resnet.ResidualBlock, [3, 4, 6, 3]).to(device)
    raise ValueError(f"Unknown model_name: {model_name!r}")
def main(task='Hemorrhage',
         train_cfg='config/train_Hemorrhage.json',
         model_name='Resnet',
         train_file="datasets/Hemorrhage/train/",
         test_file="datasets/Hemorrhage/val/",
         save_dir='save_model/resnet',
         mode='train'):
    """Preprocess the data, then train and evaluate the classifier.

    Args:
        task: task label (currently informational; not read below).
        train_cfg: path to the JSON training configuration.
        model_name: architecture key understood by get_model().
        train_file: DatasetFolder root of preprocessed training .npy files.
        test_file: DatasetFolder root of preprocessed validation .npy files.
        save_dir: directory the Trainer writes checkpoints to.
        mode: 'train' runs training before evaluation; any other value
            evaluates only.
    """
    cfg = train.Config.from_json(train_cfg)
    # NOTE(review): this re-runs the full DICOM-to-npy conversion on every
    # invocation — confirm whether it should be skipped when output exists.
    preprocess()
    #set_seeds(cfg.seed)

    # 0.49 / 0.248 are the normalization mean/std shared by both splits.
    train_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(0.49, 0.248),
        transforms.RandomAffine(degrees=(-5, 5), translate=(0, 0.05), scale=(0.9, 1.1)),
        transforms.RandomResizedCrop((224, 224), scale=(0.35, 1))
    ])
    val_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(0.49, 0.248)
    ])

    train_dataset = torchvision.datasets.DatasetFolder(
        train_file, loader=load_file, extensions="npy", transform=train_transforms)
    val_dataset = torchvision.datasets.DatasetFolder(
        test_file, loader=load_file, extensions="npy", transform=val_transforms)
    train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=cfg.batch_size, shuffle=False)

    criterion = nn.CrossEntropyLoss()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = get_model(model_name, device)
    optimizer = torch.optim.SGD(model.parameters(), lr=cfg.lr,
                                momentum=cfg.momentum, weight_decay=cfg.weight_decay)
    schedule = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.6, last_epoch=-1)
    trainer = train.Trainer(cfg, model, train_loader, val_loader,
                            optimizer, schedule, save_dir, device)

    if mode == 'train':
        def get_loss(model, batch, global_step):
            # BUG FIX: batches come off a CPU DataLoader; move them to the
            # model's device. The original only moved the final loss, which
            # is a no-op fix and fails when the model is on CUDA.
            images, labels = batch
            images, labels = images.to(device), labels.to(device)
            output = model(images)
            return criterion(output, labels)
        trainer.train(get_loss)

    def evaluate(model, batch):
        # BUG FIX: this was misspelled 'evalute' while trainer.eval() was
        # passed 'evaluate', raising NameError before evaluation ever ran.
        images, labels = batch
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        # Predicted class = argmax over logits; result holds per-sample 0/1.
        _, prediction = torch.max(outputs.data, 1)
        result = (prediction == labels).float()
        return result.mean(), result

    results = trainer.eval(evaluate)
    total_accuracy = torch.cat(results).mean().item()
    print("Accuracy:", total_accuracy)
if __name__=='__main__':
    # Expose main() as a CLI via python-fire: each keyword argument of main()
    # becomes a --flag on the command line.
    fire.Fire(main)