-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
42 lines (31 loc) · 1.24 KB
/
train.py
File metadata and controls
42 lines (31 loc) · 1.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from torch import optim, nn
from torch.cuda import amp
from config import *
from train_utils import *
from unet import UNet
from data import CarvanaDataset
def train() -> None:
    """Run the complete training procedure for the UNet model.

    Steps, in order:

    1. Obtain the training and validation loaders from
       ``CarvanaDataset.get_dataloaders``.
    2. Create ``UNet(in_ch=3, out_ch=1)``, a ``BCEWithLogitsLoss`` criterion,
       an Adam optimizer (``lr=LEARNING_RATE``) and an AMP ``GradScaler``.
    3. When ``LOAD_MODEL`` is truthy, replace the model with the result of
       ``load_checkpoint`` (reading ``MODEL_RESTORE_PATH``) and call
       ``check_accuracy`` once on the validation loader.
    4. For each of ``NUM_EPOCHS`` epochs: print a ``"<epoch>/<total>:"``
       progress line, call ``train_fn`` once, write a checkpoint holding the
       model and optimizer state dicts to ``MODEL_SAVE_PATH``, call
       ``check_accuracy`` and finally ``save_preds_as_imgs`` with
       ``PRED_EX_SAVE_PATH``.

    The upper-case settings and the helper functions are expected to be in
    module scope (presumably supplied by the ``config`` and ``train_utils``
    star imports).

    Returns:
        None. The function is used for its side effects only.
    """
    loaders = CarvanaDataset.get_dataloaders(
        IMAGE_HEIGHT,
        IMAGE_WIDTH,
        TRAIN_IMG_DIR,
        TRAIN_MASK_DIR,
        VAL_IMG_DIR,
        VAL_MASK_DIR,
        BATCH_SIZE,
        PIN_MEMORY,
        NUM_WORKERS,
    )
    train_batches, val_batches = loaders

    net = UNet(in_ch=3, out_ch=1)
    criterion = nn.BCEWithLogitsLoss()
    adam = optim.Adam(net.parameters(), lr=LEARNING_RATE)
    grad_scaler = amp.GradScaler()

    if LOAD_MODEL:
        # NOTE(review): only the model is handed to load_checkpoint, so the
        # optimizer state stored by save_checkpoint below is not restored in
        # this function - confirm that this is intended.
        net = load_checkpoint(net, filepath=MODEL_RESTORE_PATH)
        # NOTE(review): this baseline evaluation is assumed to belong to the
        # restore branch - confirm against the intended control flow.
        check_accuracy(val_batches, net, device=DEVICE)

    for epoch_no in range(1, NUM_EPOCHS + 1):
        print(f'{epoch_no}/{NUM_EPOCHS}:')

        # One pass over the training loader.
        train_fn(train_batches, net, adam, criterion, grad_scaler, device=DEVICE)

        # Persist model and optimizer state after every epoch; each save
        # targets the same MODEL_SAVE_PATH.
        checkpoint = {
            'state_dict': net.state_dict(),
            'optimizer': adam.state_dict(),
        }
        save_checkpoint(checkpoint, filepath=MODEL_SAVE_PATH)

        # Evaluate on the validation loader, then dump example predictions.
        check_accuracy(val_batches, net, device=DEVICE)
        save_preds_as_imgs(val_batches, net, device=DEVICE, folder=PRED_EX_SAVE_PATH)
if __name__ == '__main__':
    # Start training only when this file is executed as a script, so that
    # importing the module has no side effects.
    train()