train.py
import yaml
import argparse
from torch import optim
from torch.utils.data import DataLoader
from data.datasets import E2EDataset, collate_fn
from utils.data.label_encoder import LabelEncoder
from utils.data.transforms import ImageTransform
from utils.trainers import Trainer
from models import LPICR
# ---------- Parse the arguments ---------- #
parser = argparse.ArgumentParser()
parser.add_argument('-n', '--run_name', type=str, required=True, help="Run name; Trainer will create a directory with this name and save the final model and checkpoints in it.")
parser.add_argument('-c', '--config', type=str, required=True, help="Path to the config file.")
parser.add_argument('-r', '--resume', type=str, default=None, help="Path to a pre-trained checkpoint; training will resume from it.")
parser.add_argument('--wandb', help="Log progress to WandB.", action='store_true')
args = parser.parse_args()
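# Example invocations (paths and names below are illustrative placeholders):
#   python train.py --run_name my_run --config configs/train.yaml --wandb
#   python train.py -n my_run -c configs/train.yaml -r runs/my_run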
# ---------- Load the config file ---------- #
with open(args.config, 'r') as f:
    config = yaml.safe_load(f)
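# Sketch of the expected config layout. The keys are exactly the ones this
# script reads; the values are illustrative placeholders, not known defaults:
#
#   train:
#     train_denoiser: false
#     learning_rate: 0.001
#     batch_size: 32
#     epochs: 50
#   datasets:
#     train_datasets: <path or list of paths>
#     test_datasets: <path or list of paths>
#     make_2_row: false
#     downsample: 1
#   image_transform_config: { ... }  # kwargs for ImageTransform (presumably incl. image_height / image_width)
#   label_encoder: { ... }           # kwargs for LabelEncoder
#   model_config: { ... }            # kwargs for LPICR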
# ---------- Initialize LabelEncoder & Model ---------- #
train_denoiser = config['train']['train_denoiser']  # Whether to train the denoiser or the main model
if args.resume:
    # Resuming: restore the encoder, model, and transform from the checkpoint
    label_encoder = LabelEncoder.load(args.resume)
    model = LPICR.load(args.resume)
    image_transform = ImageTransform.load(args.resume)
else:
    image_transform = ImageTransform(**config['image_transform_config'])
    label_encoder = LabelEncoder(**config['label_encoder'])
    model = LPICR(
        vocab_size=label_encoder.vocab_size,
        image_shape=(4, image_transform.image_height, image_transform.image_width),  # (channels, height, width)
        **config['model_config']
    )
optimizer = optim.Adam(model.parameters(),
                       lr=config['train']['learning_rate'],
                       weight_decay=0.0001)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
                                                    factor=0.1,
                                                    patience=2,
                                                    verbose=True)
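# ReduceLROnPlateau scales the learning rate by `factor` (0.1) after `patience`
# (2) evaluations without improvement in the monitored metric; the Trainer is
# assumed to call lr_scheduler.step(<metric>) once per epoch.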
# ---------- Datasets ---------- #
train_dataset = E2EDataset(
    data_dir=config['datasets']['train_datasets'],
    label_encoder=label_encoder if not train_denoiser else None,  # No labels needed for denoiser training
    make_2_row=config['datasets']['make_2_row']
)
test_dataset = E2EDataset(
    data_dir=config['datasets']['test_datasets'],
    label_encoder=label_encoder if not train_denoiser else None,
    make_2_row=config['datasets']['make_2_row']
)
print(f"Train dataset size: {len(train_dataset)} | Test dataset size: {len(test_dataset)}")
# Initialize image_transform's normalizer with the training set's mean & std values
mean_std = train_dataset.get_mean_std()
image_transform.init_normalizer(mean_std)
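# Both dataloaders share this transform, so the test set is normalized with
# statistics computed on the training set (avoids leaking test statistics).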
# ---------- DataLoaders ---------- #
_collate_fn = collate_fn(transform=image_transform,
                         downsample=config['datasets']['downsample'],
                         return_labels=not train_denoiser)
train_dataloader = DataLoader(train_dataset,
                              shuffle=True,
                              batch_size=config['train']['batch_size'],
                              collate_fn=_collate_fn)
test_dataloader = DataLoader(test_dataset,
                             shuffle=False,
                             batch_size=config['train']['batch_size'],
                             collate_fn=_collate_fn)
# ---------- Trainer ---------- #
trainer = Trainer(run_name=args.run_name,
                  model=model,
                  train_dataloader=train_dataloader,
                  test_dataloader=test_dataloader,
                  optimizer=optimizer,
                  lr_scheduler=lr_scheduler,
                  label_encoder=label_encoder if not train_denoiser else None,
                  image_transform=image_transform,
                  device='cuda',
                  training_denoiser=train_denoiser,
                  log_to_wandb=args.wandb,
                  )
trainer.train(config['train']['epochs'])
trainer.save("model.pth")