diff --git a/README.md b/README.md
index 979711c..03dfe3a 100644
--- a/README.md
+++ b/README.md
@@ -105,7 +105,16 @@ python run.py train
 By default, this function computes the scores of the DHF1K and SALICON validation sets and the Hollywood-2 and UCF Sports test sets after the training is finished.
 The training data and scores are saved in the `training_runs` folder.
-Alternatively, the training path can be overwritten with the environment variable `TRAIN_DIR`.
+Alternatively, the training path can be overwritten with the environment variable `TRAIN_DIR`.
+
+
+### Finetuning
+
+To fine-tune the model on the MIT1003 dataset for the MIT300 benchmark, run:
+```bash
+python run.py train_finetune_mit
+```
+
 
 ### Scoring
 
 Any trained model can be scored with:
diff --git a/run.py b/run.py
index 731f804..d3419e5 100644
--- a/run.py
+++ b/run.py
@@ -1,3 +1,20 @@
+"""
+UNISAL Training and Evaluation Scripts
+
+WandB Integration:
+    To enable WandB logging, use the following parameters:
+    - use_wandb=True: Enable WandB logging
+    - wandb_project="your_project_name": Set WandB project name (default: "unisal")
+    - wandb_entity="your_entity": Set WandB entity/username (optional)
+
+Examples:
+    # Regular training with WandB
+    python run.py train --use_wandb=True --wandb_project="unisal_experiment"
+
+    # Fine-tuning with WandB
+    python run.py train_finetune_mit --use_wandb=True --wandb_project="unisal_finetune"
+"""
+
 from pathlib import Path
 import os
 
@@ -16,6 +33,15 @@ def train(eval_sources=('DHF1K', 'SALICON', 'UCFSports', 'Hollywood'),
     trainer.export_scalars()
     trainer.writer.close()
 
+def train_finetune_mit(eval_sources=('MIT300',),
+                       **kwargs):
+    """Fine-tune the model on MIT1003 and evaluate it for the MIT300 benchmark."""
+    trainer = unisal.train.Trainer(**kwargs)
+    trainer.fine_tune_mit()
+    for source in eval_sources:
+        trainer.score_model(source=source)
+    trainer.export_scalars()
+    trainer.writer.close()
 
 def load_trainer(train_id=None):
     """Instantiate Trainer class from saved kwargs."""
diff --git a/unisal/data.py b/unisal/data.py
index d4366f3..d881b53 100644
--- a/unisal/data.py
+++ b/unisal/data.py
@@ -88,7 +88,7 @@ def get_map(self, img_nr):
         return map
 
     def get_img(self, img_nr):
-        img_file = self.dir / 'images' / (
+        img_file = self.dir / 'images' / self.phase_str / (
             self.file_stem + self.file_nr.format(img_nr) + '.jpg')
         img = cv2.imread(str(img_file))
         assert(img is not None)
@@ -125,7 +125,7 @@ def dir(self):
 
     def prepare_samples(self):
         samples = []
-        for file in (self.dir / 'images').glob(self.file_stem + '*.jpg'):
+        for file in (self.dir / 'images' / self.phase_str).glob(self.file_stem + '*.jpg'):
             samples.append(int(file.stem[-12:]))
         return sorted(samples)
 
@@ -249,6 +249,9 @@ def __init__(self, phase='test'):
             'rgb_std': (0.229, 0.224, 0.225),
         }
         self.samples, self.target_size_dict = self.load_data()
+        # For compatibility with video datasets
+        self.n_images_dict = {img_idx: 1 for img_idx in range(len(self.samples))}
+        self.n_samples = len(self.samples)
 
     def load_data(self):
         samples = []
@@ -363,6 +366,14 @@ def __init__(self, phase='train', subset=None, verbose=1,
         self.samples = samples
         self.all_image_files, self.size_dict = self.load_data()
+
+        # Adjust samples to match the actual number of available images
+        actual_n_images = len(self.all_image_files)
+        if actual_n_images < self.n_train_val_images:
+            print(f"Warning: Expected {self.n_train_val_images} images but found {actual_n_images}")
+            # Filter samples to only include valid indices
+            self.samples = [s for s in self.samples if s < actual_n_images]
+
         if self.subset is not None:
             self.samples = \
                 self.samples[:int(len(self.samples) * subset)]
         # For compatibility with video datasets
@@ -397,11 +408,11 @@ def dir(self):
 
     @property
     def fix_dir(self):
-        return self.dir / 'ALLFIXATIONMAPS' / 'ALLFIXATIONMAPS'
+        return self.dir / 'ALLFIXATIONMAPS'
 
     @property
     def img_dir(self):
-        return self.dir / 'ALLSTIMULI' / 'ALLSTIMULI'
+        return self.dir / 'ALLSTIMULI'
 
     def get_out_size_eval(self, img_size):
         ar = img_size[0] / img_size[1]
diff --git a/unisal/train.py b/unisal/train.py
index 506ec6a..ff49ccb 100644
--- a/unisal/train.py
+++ b/unisal/train.py
@@ -129,6 +129,9 @@ def __init__(
         tboard=True,
         debug=False,
         new_instance=True,
+        use_wandb=False,
+        wandb_project="unisal",
+        wandb_entity=None,
     ):
         # Save training parameters
         self.num_epochs = num_epochs
@@ -173,9 +176,13 @@ def __init__(
         self.chkpnt_warmup = chkpnt_warmup
         self.chkpnt_epochs = chkpnt_epochs
         device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        # device = "cuda:1" if torch.cuda.is_available() else "cpu"
         self.device = torch.device(device)
         self.tboard = tboard
         self.debug = debug
+        self.use_wandb = use_wandb
+        self.wandb_project = wandb_project
+        self.wandb_entity = wandb_entity
 
         if debug:
             self.num_workers = 0
@@ -225,6 +232,9 @@ def fit(self):
         """
         Train the model
         """
+
+        # Initialize WandB logging
+        self.init_wandb()
 
         # Print information about the trainer class to the terminal
         # pprint.pprint(self.asdict(), width=1)
@@ -246,6 +256,9 @@ def fit(self):
 
         # Save the training data (losses, etc.)
         self.export_scalars()
+
+        # Finish WandB logging
+        self.finish_wandb()
 
         return self.best_val_score
 
@@ -350,6 +363,16 @@ def fit_phase(self):
             self.add_scalar(f"{key}/loss/{self.phase}", phase_loss, self.epoch)
             for idx, loss_ in enumerate(phase_loss_summands):
                 self.add_scalar(f"{key}/loss_{idx}/{self.phase}", loss_, self.epoch)
+
+            # Log to WandB
+            wandb_metrics = {
+                f"{key}/loss/{self.phase}": phase_loss,
+                f"epoch": self.epoch,
+                f"lr": self.optimizer.param_groups[0]["lr"] if hasattr(self, '_optimizer') else None
+            }
+            for idx, loss_ in enumerate(phase_loss_summands):
+                wandb_metrics[f"{key}/loss_{idx}/{self.phase}"] = loss_
+            self.log_wandb(wandb_metrics)
 
             if (
                 src == "DHF1K"
@@ -506,7 +529,7 @@ def run_inference(
         pred_seq = torch.full(results_size, 0, dtype=torch.float)
         if metrics is not None:
             sal_seq = torch.full(results_size, 0, dtype=torch.float)
-            fix_seq = torch.full(results_size, 0, dtype=torch.uint8)
+            fix_seq = torch.full(results_size, 0, dtype=torch.bool)
         else:
             sal_seq, fix_seq = None, None
 
@@ -649,7 +672,7 @@ def other_maps():
             """Sample reference maps for s-AUC"""
             while True:
                 this_map = np.zeros(results_size[-2:])
-                video_nrs = random.sample(dataset.n_images_dict.keys(), n_aucs_maps)
+                video_nrs = random.sample(list(dataset.n_images_dict.keys()), n_aucs_maps)
                 for map_idx, vid_nr in enumerate(video_nrs):
                     frame_nr = random.randint(1, dataset.n_images_dict[vid_nr])
                     if static_data:
@@ -659,7 +682,7 @@ def other_maps():
                             vid_nr, [frame_nr], "fix"
                         ).numpy()[0, 0, ...]
                     this_this_map = cv2.resize(
-                        this_this_map, tuple(target_size[::-1]), cv2.INTER_NEAREST
+                        this_this_map.astype(np.float32), tuple(target_size[::-1]), cv2.INTER_NEAREST
                     )
                     this_map += this_this_map
 
@@ -764,20 +787,39 @@ def score_model(
         """
 
         if load_weights:
-            # Load the best weights, if available, otherwise the weights of
-            # the last epoch
-            try:
-                self.model.load_best_weights(self.train_dir)
-                print("Best weights loaded")
-            except FileNotFoundError:
-                print("No best weights found")
-                self.model.load_last_chkpnt(self.train_dir)
-                print("Last checkpoint loaded")
+            # Load weights based on training mode
+            if hasattr(self, 'mit1003_finetuned') and self.mit1003_finetuned:
+                # Load fine-tuned MIT1003 weights if available
+                try:
+                    self.model.load_weights(self.train_dir, "ft_mit1003")
+                    print("MIT1003 fine-tuned weights loaded")
+                except FileNotFoundError:
+                    print("No fine-tuned weights found, loading best weights")
+                    try:
+                        self.model.load_best_weights(self.train_dir)
+                        print("Best weights loaded")
+                    except FileNotFoundError:
+                        print("No best weights found")
+                        self.model.load_last_chkpnt(self.train_dir)
+                        print("Last checkpoint loaded")
+            else:
+                # Load the best weights for regular training
+                try:
+                    self.model.load_best_weights(self.train_dir)
+                    print("Best weights loaded")
+                except FileNotFoundError:
+                    print("No best weights found")
+                    self.model.load_last_chkpnt(self.train_dir)
+                    print("Last checkpoint loaded")
 
         # Select the appropriate phase (see docstring) and get the dataset
         if phase is None:
             phase = "eval" if source in ("DHF1K", "SALICON", "MIT1003") else "test"
         dataset = self.get_dataset(phase, source)
+
+        # MIT300 doesn't have ground truth labels, so disable metrics
+        if source == "MIT300":
+            metrics = None
 
         if vid_nr_array is None:
             # Get list of sample numbers
@@ -802,38 +844,58 @@
             )
             scores.append(this_scores)
             if vid_idx == 0:
+                if metrics is not None:
+                    print(
+                        f" Nr. ( .../{len(vid_nr_array):4d}), "
+                        + ", ".join(f"{metric:5s}" for metric in metrics)
+                    )
+                else:
+                    print(f" Nr. ( .../{len(vid_nr_array):4d}), Inference only")
+
+            if metrics is not None:
                 print(
-                    f" Nr. ( .../{len(vid_nr_array):4d}), "
-                    + ", ".join(f"{metric:5s}" for metric in metrics)
+                    f"{vid_nr:6d} "
+                    + f"({vid_idx + 1:4d}/{len(vid_nr_array):4d}), "
+                    + ", ".join(f"{score:.3f}" for score in this_scores)
                 )
-            print(
-                f"{vid_nr:6d} "
-                + f"({vid_idx + 1:4d}/{len(vid_nr_array):4d}), "
-                + ", ".join(f"{score:.3f}" for score in this_scores)
-            )
+            else:
+                print(f"{vid_nr:6d} ({vid_idx + 1:4d}/{len(vid_nr_array):4d}), Inference completed")
 
         # Compute the average video scores
         tmr.finish()
         scores = np.array(scores)
-        mean_scores = scores.mean(0)
-
-        # In previous literature, scores were computed across all video frames,
-        # which means that each videos contribution to the overall score is
-        # weighted by its number of frames. The equivalent scores are denoted
-        # below as weighted mean
-        num_frames_array = [dataset.n_images_dict[vid_nr] for vid_nr in vid_nr_array]
-        weighted_mean_scores = np.average(scores, 0, num_frames_array)
-
-        # Print and save the scores
-        print()
-        print("Macro average (average of video averages) scores:")
-        print(", ".join(f"{metric:5s}" for metric in metrics))
-        print(", ".join(f"{score:.3f}" for score in mean_scores))
-        print()
-        print("Weighted average (per-frame average) scores:")
-        print(", ".join(f"{metric:5s}" for metric in metrics))
-        print(", ".join(f"{score:.3f}" for score in weighted_mean_scores))
-        if subset == 1:
+
+        if metrics is not None:
+            if len(scores) > 0:
+                mean_scores = scores.mean(0)
+            else:
+                mean_scores = np.array([np.nan] * len(metrics))
+
+            # In previous literature, scores were computed across all video frames,
+            # which means that each videos contribution to the overall score is
+            # weighted by its number of frames. The equivalent scores are denoted
+            # below as weighted mean
+            num_frames_array = [dataset.n_images_dict[vid_nr] for vid_nr in vid_nr_array]
+            if len(scores) > 0 and num_frames_array and sum(num_frames_array) > 0:
+                weighted_mean_scores = np.average(scores, 0, num_frames_array)
+            else:
+                weighted_mean_scores = np.array([np.nan] * len(metrics))
+
+            # Print and save the scores
+            print()
+            print("Macro average (average of video averages) scores:")
+            print(", ".join(f"{metric:5s}" for metric in metrics))
+            print(", ".join(f"{score:.3f}" for score in mean_scores))
+            print()
+            print("Weighted average (per-frame average) scores:")
+            print(", ".join(f"{metric:5s}" for metric in metrics))
+            print(", ".join(f"{score:.3f}" for score in weighted_mean_scores))
+        else:
+            print()
+            print("Inference completed without evaluation metrics (no ground truth available)")
+            mean_scores = None
+            weighted_mean_scores = None
+        if subset == 1 and metrics is not None:
             dest_dir = self.mit1003_dir if source == "MIT1003" else self.train_dir
             if source in ("Hollywood", "UCFSports"):
                 source += "_resized"
@@ -1017,16 +1079,37 @@ def fine_tune_mit(
 
         self.num_workers = 4
 
-        # Load the best weights, if available, otherwise the weights of
-        # the last epoch
+        # Copy pretrained weights to current train directory for evaluation
+        pretrained_dir = Path(os.environ["TRAIN_DIR"]) / "pretrained_unisal"
+        pretrained_weights = pretrained_dir / "weights_best.pth"
+        current_weights = self.train_dir / "weights_best.pth"
+
+        if pretrained_weights.exists() and not current_weights.exists():
+            import shutil
+            shutil.copy2(pretrained_weights, current_weights)
+            print(f"Copied pretrained weights from {pretrained_weights} to {current_weights}")
+
+        # Load the pretrained weights from pretrained_unisal directory
         try:
-            self.model.load_best_weights(self.train_dir)
-            print("Best weights loaded")
+            self.model.load_best_weights(pretrained_dir)
+            print("Pretrained best weights loaded")
         except FileNotFoundError:
-            print("No best weights found")
-            self.model.load_last_chkpnt(self.train_dir)
+            try:
+                self.model.load_last_chkpnt(pretrained_dir)
+                print("Pretrained last checkpoint loaded")
+            except FileNotFoundError:
+                print("No pretrained weights found, trying current train_dir")
+                try:
+                    self.model.load_best_weights(self.train_dir)
+                    print("Best weights loaded from current train_dir")
+                except FileNotFoundError:
+                    print("No best weights found")
+                    self.model.load_last_chkpnt(self.train_dir)
 
         # Run the fine tuning
+        # Initialize WandB logging for fine-tuning
+        self.init_wandb()
+        # pprint.pprint(self.asdict(), width=1)
 
         best_epoch = None
         best_val = None
@@ -1047,14 +1130,24 @@
             val_score = -val_loss
             if self.best_val_score is None:
                 self.best_val_score = val_score
+                self.model.save_weights(self.train_dir, "ft_mit1003")
+                print(f"Initial MIT1003 fine-tuned weights saved at epoch {self.epoch}")
+                best_epoch = self.epoch
+                best_val = val_loss
             elif val_score > self.best_val_score:
                 self.best_val_score = val_score
+                self.model.save_weights(self.train_dir, "ft_mit1003")
+                print(f"New best MIT1003 fine-tuned weights saved at epoch {self.epoch}")
                 best_epoch = self.epoch
                 best_val = val_loss
 
             self.epoch += 1
 
         self.export_scalars()
+
+        # Finish WandB logging
+        self.finish_wandb()
+
         return best_val, best_epoch
 
     def get_dataset(self, phase, source="DHF1K"):
@@ -1399,3 +1492,82 @@ def get_configs(self):
     @property
     def train_id(self):
         return "/".join(self.train_dir.parts[-2:])
+
+    def init_wandb(self):
+        """Initialize Weights & Biases logging"""
+        if not self.use_wandb:
+            return
+
+        try:
+            import wandb
+
+            # Create run config from training parameters
+            config = {
+                "num_epochs": self.num_epochs,
+                "lr": self.lr,
+                "batch_size": self.batch_size,
+                "optimizer": self.optim_algo,
+                "momentum": self.momentum,
+                "weight_decay": self.weight_decay,
+                "lr_scheduler": self.lr_scheduler,
+                "lr_gamma": self.lr_gamma,
+                "data_sources": self.data_sources,
+                "loss_metrics": self.loss_metrics,
+                "loss_weights": self.loss_weights,
+                "model_cfg": self.model_cfg,
+            }
+
+            # Add fine-tuning specific config if applicable
+            if hasattr(self, 'mit1003_finetuned') and self.mit1003_finetuned:
+                config["fine_tuning"] = "MIT1003"
+                config["train_cnn_after"] = getattr(self, 'train_cnn_after', None)
+
+            # Clean train_id for WandB (remove forbidden characters)
+            clean_id = self.train_id
+            forbidden_chars = ":;,#?/'"
+            for char in forbidden_chars:
+                clean_id = clean_id.replace(char, "_")
+
+            # Initialize wandb
+            wandb.init(
+                project=self.wandb_project,
+                entity=self.wandb_entity,
+                config=config,
+                name=self.train_id,
+                resume="allow",
+                id=clean_id
+            )
+
+            print(f"WandB initialized for project: {self.wandb_project}")
+
+        except ImportError:
+            print("WandB not installed. Install with: pip install wandb")
+            self.use_wandb = False
+        except Exception as e:
+            print(f"WandB initialization failed: {e}")
+            self.use_wandb = False
+
+    def log_wandb(self, metrics_dict, step=None):
+        """Log metrics to WandB"""
+        if not self.use_wandb:
+            return
+
+        try:
+            import wandb
+            if step is not None:
+                wandb.log(metrics_dict, step=step)
+            else:
+                wandb.log(metrics_dict)
+        except Exception as e:
+            print(f"WandB logging failed: {e}")
+
+    def finish_wandb(self):
+        """Finish WandB run"""
+        if not self.use_wandb:
+            return
+
+        try:
+            import wandb
+            wandb.finish()
+        except Exception as e:
+            print(f"WandB finish failed: {e}")
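For reviewers who want to try the branch, here is a minimal sketch of how the new fine-tuning path could be driven from Python instead of the `run.py` CLI. The `Trainer` keyword arguments and method names are taken from the diff above; the surrounding script (environment setup, project name) is illustrative only and assumes `TRAIN_DIR` is set and contains `pretrained_unisal/weights_best.pth`, as expected by the updated `fine_tune_mit`.

```python
# Illustrative sketch only; mirrors run.py's train_finetune_mit from this diff.
# Assumes the TRAIN_DIR environment variable is set and that
# $TRAIN_DIR/pretrained_unisal/weights_best.pth exists.
import unisal

trainer = unisal.train.Trainer(
    use_wandb=True,                   # new flag added in this diff
    wandb_project="unisal_finetune",  # example project name, not a fixed default
)
best_val, best_epoch = trainer.fine_tune_mit()  # saves "ft_mit1003" weights on improvement
trainer.score_model(source="MIT300")            # MIT300 has no ground truth, so inference only
trainer.export_scalars()
trainer.writer.close()
```

Per the diff, `fine_tune_mit` returns the best validation loss and the epoch it occurred at, and `score_model` skips metric computation for MIT300 because no ground-truth fixations are available.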