-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluate_model_performance.py
More file actions
39 lines (28 loc) · 1.88 KB
/
evaluate_model_performance.py
File metadata and controls
39 lines (28 loc) · 1.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import os
from transformers import DeformableDetrForObjectDetection
import helpers.const
from coco_detection import CocoDetection
from helpers.evaluation_functions import evaluate_result
from train_model import get_dataloader
from models import DeformableDetr
def main(
    *,
    checkpoint_path: str = "prod/after_training_augtest.ckpt",
    export_dir: str = "finetuned_detr",
    small_dataset: bool = True,
) -> None:
    """Export a fine-tuned Deformable DETR checkpoint and evaluate it on the val split.

    The checkpoint is loaded through ``DeformableDetr.load_from_checkpoint``.
    Its inner ``model`` is written to ``export_dir`` with ``save_pretrained``.
    It is then reloaded as a plain ``DeformableDetrForObjectDetection`` in eval
    mode, and that reloaded model is what ``evaluate_result`` receives.

    Args:
        checkpoint_path: Checkpoint file passed to ``load_from_checkpoint``.
        export_dir: Directory the Hugging Face weights/config are written to
            and reloaded from.
        small_dataset: Evaluate on ``coco_val_small.json`` instead of
            ``coco_val.json``; also forwarded to ``get_dataloader``.
    """
    import torch.multiprocessing

    # Dataset files are resolved under $TMPDIR when it is set; otherwise the
    # empty string makes the paths below relative to the working directory.
    tmpdir = os.environ.get("TMPDIR", "")
    # NOTE(review): presumably "spawn" is needed by the dataloader workers
    # (e.g. CUDA in subprocesses) — confirm. force=True replaces any start
    # method that was already set.
    torch.multiprocessing.set_start_method("spawn", force=True)

    # lr / lr_backbone / weight_decay are constructor overrides for
    # DeformableDetr; presumably optimizer settings that evaluation does not
    # use — confirm before changing them.
    lm = DeformableDetr.DeformableDetr.load_from_checkpoint(
        checkpoint_path=checkpoint_path,
        lr=1e-4,
        lr_backbone=1e-5,
        weight_decay=1e-4,
        id2label=helpers.const.id2label,
        assign=True,
    )
    # save_pretrained writes config.json next to the weights, so no separate
    # config.save_pretrained call is needed.
    lm.model.save_pretrained(export_dir)
    # Drop the wrapper so its copy of the weights can be freed before the
    # exported model is loaded a second time below.
    del lm
    model = DeformableDetrForObjectDetection.from_pretrained(export_dir)
    model.eval()

    # Only the image processor is needed here; the returned dataloaders are
    # not used by the evaluation.
    _, _, image_processor = get_dataloader(
        caching=False,
        model=model,
        small_dataset=small_dataset,
        dataset_root=tmpdir,
        batch_size=helpers.const.batch_size,
    )

    val_dataset_name = "coco_val_small.json" if small_dataset else "coco_val.json"
    val_dataset = CocoDetection(
        img_folder=os.path.join(tmpdir, "training", "extracted_frames", "val"),
        processor=image_processor,
        coco_json_filename=val_dataset_name,
        train=False,
    )
    evaluate_result(model, val_dataset, image_processor)


if __name__ == "__main__":
    main()