-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference.py
More file actions
40 lines (31 loc) · 1.42 KB
/
inference.py
File metadata and controls
40 lines (31 loc) · 1.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import pandas as pd
import torch
from tqdm import tqdm
from utils import LABEL_TO_STR, load_images, load_model_and_preprocessing
def apply_model(model_name: str, data_path: str):
    """Load a stored model by name/id and run it over every image in a folder.

    Args:
        model_name: Most recent with model_name or WandB id of the model.
        data_path: Path to the folder containing the images.
            Each image path should contain the label.

    Returns:
        model_id: The id of the model.
        results_df: A dataframe containing the results.
    """
    # Resolve the checkpoint plus its matching preprocessing pipeline.
    loaded_id, loaded_model, loaded_preprocessing = load_model_and_preprocessing(model_name)
    # Inference mode: disable dropout / switch batch-norm to running stats.
    loaded_model.eval()
    return use_model(loaded_id, loaded_model, data_path, preprocessing=loaded_preprocessing)
def use_model(model_id, model: torch.nn.Module, data_path: str, preprocessing=None):
    """Run ``model`` on every image under ``data_path`` and collect the outputs.

    Args:
        model_id: Identifier of the model; passed through unchanged.
        model: Network used for inference (caller is expected to have
            called ``.eval()`` already — see ``apply_model``).
        data_path: Path to the folder containing the images.
        preprocessing: Optional callable applied to each image tensor
            before the forward pass.

    Returns:
        model_id: The same id that was passed in.
        results_df: DataFrame with one row per image: the file path plus
            one column per label (order follows ``LABEL_TO_STR``).
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    columns = ["file"] + [LABEL_TO_STR[i] for i in range(len(LABEL_TO_STR))]
    # Collect rows in a plain list and build the DataFrame once at the end:
    # appending via results_df.loc[len(results_df)] re-allocates the frame
    # per image (accidental O(n^2)).
    rows = []
    path_tensor_pairs = load_images([data_path])
    # Inference only — skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for path, image in tqdm(path_tensor_pairs, desc=f"Inference of {data_path}"):
            image = image.to(device)
            if preprocessing:
                image = preprocessing(image)
            image = image.unsqueeze(0)  # add batch dimension
            output = model(image).tolist()[0]
            rows.append([path] + output)
    results_df = pd.DataFrame(rows, columns=columns)
    return model_id, results_df