Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,16 @@ python run.py train
By default, this function computes the scores of the DHF1K and SALICON validation sets
and the Hollywood-2 and UCF Sports test sets after the training is finished.
The training data and scores are saved in the `training_runs` folder.
Alternatively, the training path can be overwritten with the environment variable `TRAIN_DIR`.
Alternatively, the training path can be overwritten with the environment variable `TRAIN_DIR`.


### Finetuning

To finetune the model with the MIT1003 dataset for the MIT300 benchmark
```bash
python run.py train_finetune_mit
```


### Scoring
Any trained model can be scored with:
Expand Down
26 changes: 26 additions & 0 deletions run.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,20 @@
"""
UNISAL Training and Evaluation Scripts

WandB Integration:
To enable WandB logging, use the following parameters:
- use_wandb=True: Enable WandB logging
- wandb_project="your_project_name": Set WandB project name (default: "unisal")
- wandb_entity="your_entity": Set WandB entity/username (optional)

Examples:
# Regular training with WandB
python run.py train --use_wandb=True --wandb_project="unisal_experiment"

# Fine-tuning with WandB
python run.py train_finetune_mit --use_wandb=True --wandb_project="unisal_finetune"
"""

from pathlib import Path
import os

Expand All @@ -16,6 +33,15 @@ def train(eval_sources=('DHF1K', 'SALICON', 'UCFSports', 'Hollywood'),
trainer.export_scalars()
trainer.writer.close()

def train_finetune_mit(eval_sources=('MIT300',),
**kwargs):
"""Run training and evaluation."""
trainer = unisal.train.Trainer(**kwargs)
trainer.fine_tune_mit()
for source in eval_sources:
trainer.score_model(source=source)
trainer.export_scalars()
trainer.writer.close()

def load_trainer(train_id=None):
"""Instantiate Trainer class from saved kwargs."""
Expand Down
19 changes: 15 additions & 4 deletions unisal/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def get_map(self, img_nr):
return map

def get_img(self, img_nr):
img_file = self.dir / 'images' / (
img_file = self.dir / 'images' / self.phase_str / (
self.file_stem + self.file_nr.format(img_nr) + '.jpg')
img = cv2.imread(str(img_file))
assert(img is not None)
Expand Down Expand Up @@ -125,7 +125,7 @@ def dir(self):

def prepare_samples(self):
samples = []
for file in (self.dir / 'images').glob(self.file_stem + '*.jpg'):
for file in (self.dir / 'images' / self.phase_str).glob(self.file_stem + '*.jpg'):
samples.append(int(file.stem[-12:]))
return sorted(samples)

Expand Down Expand Up @@ -249,6 +249,9 @@ def __init__(self, phase='test'):
'rgb_std': (0.229, 0.224, 0.225),
}
self.samples, self.target_size_dict = self.load_data()
# For compatibility with video datasets
self.n_images_dict = {img_idx: 1 for img_idx in range(len(self.samples))}
self.n_samples = len(self.samples)

def load_data(self):
samples = []
Expand Down Expand Up @@ -363,6 +366,14 @@ def __init__(self, phase='train', subset=None, verbose=1,
self.samples = samples

self.all_image_files, self.size_dict = self.load_data()

# Adjust samples to match actual number of available images
actual_n_images = len(self.all_image_files)
if actual_n_images < self.n_train_val_images:
print(f"Warning: Expected {self.n_train_val_images} images but found {actual_n_images}")
# Filter samples to only include valid indices
self.samples = [s for s in self.samples if s < actual_n_images]

if self.subset is not None:
self.samples = self.samples[:int(len(self.samples) * subset)]
# For compatibility with video datasets
Expand Down Expand Up @@ -397,11 +408,11 @@ def dir(self):

@property
def fix_dir(self):
return self.dir / 'ALLFIXATIONMAPS' / 'ALLFIXATIONMAPS'
return self.dir / 'ALLFIXATIONMAPS'

@property
def img_dir(self):
return self.dir / 'ALLSTIMULI' / 'ALLSTIMULI'
return self.dir / 'ALLSTIMULI'

def get_out_size_eval(self, img_size):
ar = img_size[0] / img_size[1]
Expand Down
Loading