-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_model.py
More file actions
147 lines (101 loc) · 5.61 KB
/
train_model.py
File metadata and controls
147 lines (101 loc) · 5.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import os
from pytorch_lightning import Trainer
import models.CollateFn
from helpers.dataset_helpers import create_image_processor
from models import DeformableDetr
from coco_cached import CachedDETRDataset
from coco_detection import CocoDetection
from pytorch_lightning.loggers import TensorBoardLogger
import torch
from torch.utils.data import DataLoader
import helpers.const
from pytorch_lightning.callbacks import EarlyStopping
import torchvision.transforms.v2 as T
# Allow TF32 matmuls: trades a little float32 precision for speed on recent GPUs.
torch.set_float32_matmul_precision('high')
def check_batch_and_model(model, train_dataloader, train_dataset):
    """
    Smoke-test the data pipeline and model wiring on a single batch.

    Prints the batch keys, the shape and target of the first dataset sample,
    and the logits shape of one forward pass, so problems surface before a
    full training run is launched.

    :param model: Callable accepting ``pixel_values`` and ``pixel_mask``
                  keyword arguments and returning an object with ``logits``.
    :param train_dataloader: Iterable yielding dict batches containing
                             'pixel_values' and 'pixel_mask'.
    :param train_dataset: Indexable dataset yielding ``(pixel_values, target)``.
    :return: None (output goes to stdout only)
    """
    batch = next(iter(train_dataloader))
    # Bug fix: ``batch.keys`` printed the bound method object; call it to
    # actually show the batch keys.
    print(batch.keys())
    pixel_values, target = train_dataset[0]
    print(pixel_values.shape)
    print(target)
    outputs = model(pixel_values=batch['pixel_values'], pixel_mask=batch['pixel_mask'])
    print(outputs.logits.shape)
# %%
def get_dataloader(caching, model, dataset_root, batch_size, small_dataset=False, train_transforms=None, val_transforms=None):
    """
    Build the train/val DataLoaders and the image processor.

    :param caching: Set True if cached files should be used
    :param model: If not None a test on the model is performed to make sure everything worked as expected
    :param dataset_root: Root directory holding "cached_data" / "training" folders.
    :param batch_size: Batch size for both loaders.
    :param small_dataset: If True, load the reduced "*_small.json" COCO annotations.
    :param train_transforms: Augmentations for the training set (non-cached path only).
    :param val_transforms: Augmentations for the validation set (non-cached path only).
    :return: DataLoaders for train and val, and the image_processor
    """
    image_processor = create_image_processor()
    if caching:
        # Cached path: samples were pre-processed to disk, so no augmentations apply.
        train_dataset = CachedDETRDataset(os.path.join(dataset_root, "cached_data", "train"))
        val_dataset = CachedDETRDataset(os.path.join(dataset_root, "cached_data", "val"))
        col_for_loader = models.CollateFn.collate_fn_cached
    else:
        train_dataset_name = "coco_train_small.json" if small_dataset else "coco_train.json"
        val_dataset_name = "coco_val_small.json" if small_dataset else "coco_val.json"
        train_dataset = CocoDetection(img_folder=os.path.join(dataset_root, "training", "extracted_frames", "train"),
                                      coco_json_filename=train_dataset_name,
                                      processor=image_processor, augmentations=train_transforms, train=True)
        # Bug fix: ``val_transforms`` was accepted but silently ignored; pass it
        # through (its default of None keeps the previous behavior).
        val_dataset = CocoDetection(img_folder=os.path.join(dataset_root, 'training', 'extracted_frames', 'val'),
                                    processor=image_processor, augmentations=val_transforms,
                                    coco_json_filename=val_dataset_name, train=False)
        col_for_loader = models.CollateFn.collate_fn
    print("Number of training examples:", len(train_dataset))
    print("Number of validation examples:", len(val_dataset))
    train_dataloader = DataLoader(train_dataset, collate_fn=col_for_loader, batch_size=batch_size, shuffle=True,
                                  num_workers=helpers.const.num_workers,
                                  persistent_workers=True,
                                  pin_memory=True)
    val_dataloader = DataLoader(val_dataset, collate_fn=col_for_loader, batch_size=batch_size,
                                num_workers=helpers.const.num_workers,
                                persistent_workers=True,
                                pin_memory=True)
    if model is not None:
        # Bug fix: removed a stray trailing comma that wrapped this call in a
        # throwaway one-element tuple.
        check_batch_and_model(model, train_dataloader, train_dataset)
    return train_dataloader, val_dataloader, image_processor
def run_training(training_name, caching, small_dataset, early_stopping_patience, dataset_root):
    """
    Run Training.
    Logs can be accessed using TensorBoard (written under "tb_logs").
    :param training_name: TensorBoard run name; "default" is used when falsy.
    :param caching: If True, train from pre-cached dataset files.
    :param small_dataset: If True, use the reduced COCO annotation files.
    :param early_stopping_patience: Epochs without val_loss improvement before stopping.
    :param dataset_root: Root directory containing the dataset folders.
    :return: The fitted pytorch_lightning Trainer.
    """
    logging_dir = "tb_logs"
    # Stop once val_loss fails to improve by at least 1e-3 for `patience` epochs.
    early_stop_callback = EarlyStopping(
        monitor='val_loss',
        min_delta=1e-3,
        patience=early_stopping_patience,
        verbose=False,
        mode='min'
    )
    # Photometric-only augmentations plus a horizontal flip; presumably geometric
    # transforms beyond the flip would require adjusting the boxes — TODO confirm
    # how CocoDetection applies these to targets.
    train_transforms = T.Compose([
        T.RandomHorizontalFlip(p=0.2),
        T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.02),
        T.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 1.5)),
    ])
    model = DeformableDetr.DeformableDetr(lr=5e-5, lr_backbone=5e-6, weight_decay=1e-4, id2label=helpers.const.id2label)
    # image_processor is unused here but part of get_dataloader's return contract.
    train_dataloader, val_dataloader, image_processor = get_dataloader(
        caching=caching, model=None, small_dataset=small_dataset, dataset_root=dataset_root,
        batch_size=helpers.const.batch_size, train_transforms=train_transforms, val_transforms=None)
    trainer = Trainer(max_epochs=helpers.const.max_epochs, callbacks=[early_stop_callback], precision="16-mixed",
                      accelerator="gpu", devices=helpers.const.nr_devices, strategy=helpers.const.strategy,
                      logger=TensorBoardLogger(logging_dir, name=training_name or "default"),
                      accumulate_grad_batches=4)
    print("Baseline memory consumption")
    print(torch.cuda.memory_summary())
    trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)
    # Bug fix: this summary runs after fit(), so it is not the baseline —
    # the duplicated label made the two reports indistinguishable.
    print("Post-training memory consumption")
    print(torch.cuda.memory_summary())
    return trainer
if __name__ == "__main__":
    import torch.multiprocessing

    # Prefer the TMPDIR environment variable as the dataset root; fall back to "".
    dataset_root = os.environ.get("TMPDIR", "")
    # Force "spawn" start method for the DataLoader worker processes.
    torch.multiprocessing.set_start_method("spawn", force=True)
    trainer = run_training(
        "small_smmall_boxes_small_image",
        False,
        helpers.const.use_small,
        helpers.const.patience,
        dataset_root,
    )
    trainer.save_checkpoint(helpers.const.save_checkpoint_name)