Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 50 additions & 25 deletions run.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
from pathlib import Path
import os

import fire

import os, fire
import unisal


def train(eval_sources=('DHF1K', 'SALICON', 'UCFSports', 'Hollywood'),
**kwargs):
def train(eval_sources=("DHF1K", "SALICON", "UCFSports", "Hollywood"), **kwargs):
"""Run training and evaluation."""
trainer = unisal.train.Trainer(**kwargs)
trainer.fit()
Expand All @@ -20,17 +16,17 @@ def train(eval_sources=('DHF1K', 'SALICON', 'UCFSports', 'Hollywood'),
def load_trainer(train_id=None):
"""Instantiate Trainer class from saved kwargs."""
if train_id is None:
train_id = 'pretrained_unisal'
train_id = "pretrained_unisal"
print(f"Train ID: {train_id}")
train_dir = Path(os.environ["TRAIN_DIR"])
train_dir = train_dir / train_id
print(f"initalizing trainer from {train_dir}...")
return unisal.train.Trainer.init_from_cfg_dir(train_dir)


def score_model(
train_id=None,
sources=('DHF1K', 'SALICON', 'UCFSports', 'Hollywood'),
**kwargs):
train_id=None, sources=("DHF1K", "SALICON", "UCFSports", "Hollywood"), **kwargs
):
"""Compute the scores for a trained model."""

trainer = load_trainer(train_id)
Expand All @@ -39,26 +35,27 @@ def score_model(


def generate_predictions(
train_id=None,
sources=('DHF1K', 'SALICON', 'UCFSports', 'Hollywood',
'MIT1003', 'MIT300'),
**kwargs):
train_id=None,
sources=("DHF1K", "SALICON", "UCFSports", "Hollywood", "MIT1003", "MIT300"),
**kwargs,
):
"""Generate predictions with a trained model."""

trainer = load_trainer(train_id)
for source in sources:

# Load fine-tuned weights for MIT datasets
if source in ('MIT1003', 'MIT300'):
if source in ("MIT1003", "MIT300"):
trainer.model.load_weights(trainer.train_dir, "ft_mit1003")
trainer.salicon_cfg['x_val_step'] = 0
kwargs.update({'model_domain': 'SALICON', 'load_weights': False})
trainer.salicon_cfg["x_val_step"] = 0
kwargs.update({"model_domain": "SALICON", "load_weights": False})

trainer.generate_predictions(source=source, **kwargs)


def predictions_from_folder(
folder_path, is_video, source=None, train_id=None, model_domain=None):
folder_path, is_video, source=None, train_id=None, model_domain=None
):
"""Generate predictions of files in a folder with a trained model."""

# Allows us to call this function directly from command-line
Expand All @@ -67,7 +64,8 @@ def predictions_from_folder(

trainer = load_trainer(train_id)
trainer.generate_predictions_from_path(
folder_path, is_video, source=source, model_domain=model_domain)
folder_path, is_video, source=source, model_domain=model_domain
)


def predict_examples(train_id=None):
Expand All @@ -76,21 +74,48 @@ def predict_examples(train_id=None):
continue

source = example_folder.name
is_video = source not in ('SALICON', 'MIT1003')
is_video = source not in ("SALICON", "MIT1003")

print(f"\nGenerating predictions for {'video' if is_video else 'image'} "
f"folder\n{str(source)}")
print(
f"\nGenerating predictions for {'video' if is_video else 'image'} "
f"folder\n{str(source)}"
)

if is_video:
if not example_folder.is_dir():
continue
for video_folder in example_folder.glob('[!.]*'): # ignore hidden files
for video_folder in example_folder.glob("[!.]*"): # ignore hidden files
predictions_from_folder(
video_folder, is_video, train_id=train_id, source=source)
video_folder, is_video, train_id=train_id, source=source
)

else:
predictions_from_folder(
example_folder, is_video, train_id=train_id, source=source)
example_folder, is_video, train_id=train_id, source=source
)


def predict_image(
image_path: str,
out_image_path: str = None,
):
"""a minimal working example of running inference on a single image"""
from PIL import Image
import numpy as np

assert os.path.isfile(
image_path
), f"provide image_path ({image_path}) is not a valid file"
im = Image.open(image_path)
smap = unisal.demo.predict_image(img_rgb=np.array(im))

if out_image_path:
output_dir = os.path.dirname(out_image_path)
assert os.path.isdir(
output_dir
), f"output directory {output_dir} does not exist"
Image.fromarray(smap).save(out_image_path)
return smap


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion unisal/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from . import train, data, model, models, utils
from . import train, data, model, models, utils, demo
Loading