forked from DCBIA-OrthoLab/ShapeAXI
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsaxi_train_v2.py
More file actions
151 lines (110 loc) · 6.23 KB
/
saxi_train_v2.py
File metadata and controls
151 lines (110 loc) · 6.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import argparse
import math
import os
import pandas as pd
import numpy as np
import torch
from shapeaxi.saxi_dataset import SaxiDataModule, SaxiDataModuleVF
from shapeaxi.saxi_transforms import TrainTransform, EvalTransform
from shapeaxi import saxi_nets
from shapeaxi import saxi_logger
import lightning as L
from lightning import Trainer
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.strategies import DDPStrategy
from lightning.pytorch.loggers import NeptuneLogger
def main(args):
    """Train a ShapeAXI network from CSV/parquet split files.

    Builds the model named by ``args.nn`` from :mod:`saxi_nets`, wraps the
    train/validation tables in a :class:`SaxiDataModule`, configures
    checkpointing / early-stopping / logging callbacks, and runs
    ``Trainer.fit``.

    Args:
        args: parsed ``argparse.Namespace`` (see the CLI definition in
            ``__main__``); the whole namespace is forwarded to the network
            constructor via ``**vars(args)``.
    """
    # Split files may be CSV or parquet; choose the reader from the extension.
    if os.path.splitext(args.csv_train)[1] == ".csv":
        df_train = pd.read_csv(args.csv_train)
        df_val = pd.read_csv(args.csv_valid)
    else:
        df_train = pd.read_parquet(args.csv_train)
        df_val = pd.read_parquet(args.csv_valid)

    # Resolve the network class by name and construct it from the CLI namespace.
    NN = getattr(saxi_nets, args.nn)
    model = NN(**vars(args))

    train_transform = TrainTransform(scale_factor=args.scale_factor)
    valid_transform = EvalTransform(scale_factor=args.scale_factor)

    # NOTE: the validation frame is reused as the test split.
    # FIX: num_workers now honors --num_workers instead of a hard-coded 4.
    lotus_data = SaxiDataModule(
        df_train, df_val, df_val,
        mount_point=args.mount_point,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        surf_column=args.surf_column,
        class_column=args.class_column,
        scalar_column=args.scalar_column,
        train_transform=train_transform,
        valid_transform=valid_transform,
        drop_last=False,
    )

    callbacks = []

    # Always track the best (lowest) validation loss.
    checkpoint_callback = ModelCheckpoint(
        dirpath=args.out,
        filename='{epoch}-{val_loss:.2f}',
        save_top_k=2,
        monitor='val_loss',
        save_last=True,
    )
    callbacks.append(checkpoint_callback)

    if args.monitor:
        # Optionally also checkpoint on a user-chosen metric (maximized).
        # FIX: the original pattern '{epoch}-{' + args.monitor + '}:.2f}'
        # produced mismatched braces (e.g. '{epoch}-{val_acc}:.2f}'), which
        # breaks Lightning's metric interpolation in the filename.
        checkpoint_callback_acc = ModelCheckpoint(
            dirpath=args.out,
            filename='{epoch}-{' + args.monitor + ':.2f}',
            save_top_k=2,
            monitor=args.monitor,
            save_last=True,
            mode='max',
        )
        callbacks.append(checkpoint_callback_acc)

    # FIX: honor --use_early_stopping; the callback was previously appended
    # unconditionally, making the flag a no-op.
    if args.use_early_stopping:
        early_stop_callback = EarlyStopping(
            monitor='val_loss',
            min_delta=0.00,
            patience=args.patience,
            verbose=True,
            mode="min",
        )
        callbacks.append(early_stop_callback)

    # Neptune experiment tracking is opt-in via --neptune_tags.
    logger_neptune = None
    if args.neptune_tags:
        logger_neptune = NeptuneLogger(
            project='ImageMindAnalytics/saxi',
            tags=args.neptune_tags,
            api_key=os.environ['NEPTUNE_API_TOKEN'],
            log_model_checkpoints=False,
        )

    # Image/metric logger class is resolved by name from saxi_logger and
    # attached as a callback.
    LOGGER = getattr(saxi_logger, args.logger)
    image_logger = LOGGER(log_steps=args.log_steps)
    callbacks.append(image_logger)

    trainer = Trainer(
        logger=logger_neptune,
        log_every_n_steps=args.log_steps,
        max_epochs=args.epochs,
        max_steps=args.steps,
        callbacks=callbacks,
        accelerator='gpu',
        devices=torch.cuda.device_count(),
        strategy='ddp',
    )
    # ckpt_path resumes training when --model is given (None starts fresh).
    trainer.fit(model, datamodule=lotus_data, ckpt_path=args.model)
if __name__ == '__main__':
    # conflict_handler='resolve' lets the network's add_model_specific_args
    # override options already declared here without raising.
    # FIX: description typo "Explainaiblity" corrected.
    parser = argparse.ArgumentParser(
        description='Shape Analysis Explainability and Interpretability train',
        conflict_handler='resolve',
    )

    hparams_group = parser.add_argument_group('Hyperparameters')
    hparams_group.add_argument('--epochs', help='Max number of epochs', type=int, default=200)
    hparams_group.add_argument('--patience', help='Max number of patience for early stopping', type=int, default=30)
    hparams_group.add_argument('--steps', help='Max number of steps per epoch', type=int, default=-1)
    hparams_group.add_argument('--batch_size', help='Batch size', type=int, default=2)
    # FIX: --monitor was declared twice (here and in the Output group); a
    # single declaration with the clearer help text is kept.
    hparams_group.add_argument('--monitor', help='Additional metric to monitor to save checkpoints', type=str, default=None)

    input_group = parser.add_argument_group('Input')
    input_group.add_argument('--nn', help='Type of neural network', type=str, default="USAEReconstruction")
    input_group.add_argument('--model', help='Model to continue training', type=str, default=None)
    input_group.add_argument('--mount_point', help='Dataset mount directory', type=str, default="./")
    input_group.add_argument('--num_workers', help='Number of workers for loading', type=int, default=4)
    input_group.add_argument('--csv_train', required=True, type=str, help='Train CSV')
    input_group.add_argument('--csv_valid', required=True, type=str, help='Valid CSV')
    input_group.add_argument('--surf_column', type=str, default='surf_path', help='Column name for the surface data')
    input_group.add_argument('--class_column', type=str, default=None, help='Column name for the class column')
    input_group.add_argument('--scalar_column', type=str, default=None, help='Column name for the scalar column')
    input_group.add_argument('--scale_factor', type=float, default=None, help='Use a common scale factor')

    output_group = parser.add_argument_group('Output')
    output_group.add_argument('--out', help='Output directory', type=str, default="./")
    output_group.add_argument('--use_early_stopping', help='Use early stopping criteria', type=int, default=0)

    log_group = parser.add_argument_group('Logging')
    log_group.add_argument('--neptune_tags', help='Neptune tags', type=str, nargs="+", default=None)
    # FIX: help text said "Neptune tags" (copy-paste error); this option names
    # a logger class in saxi_logger.
    log_group.add_argument('--logger', help='Logger class from saxi_logger to use', type=str, default="USAEReconstructionNeptuneLogger")
    log_group.add_argument('--log_steps', help='Log every N steps', type=int, default=20)

    # First pass: learn which network was requested, then let that network
    # register its own options before the final parse.
    args, unknownargs = parser.parse_known_args()

    NN = getattr(saxi_nets, args.nn)
    NN.add_model_specific_args(parser)

    args = parser.parse_args()

    main(args)