diff --git a/run.py b/run.py
index 731f804..5ef8cf0 100644
--- a/run.py
+++ b/run.py
@@ -1,13 +1,9 @@
 from pathlib import Path
-import os
-
-import fire
-
+import os, fire
 import unisal
 
 
-def train(eval_sources=('DHF1K', 'SALICON', 'UCFSports', 'Hollywood'),
-          **kwargs):
+def train(eval_sources=("DHF1K", "SALICON", "UCFSports", "Hollywood"), **kwargs):
     """Run training and evaluation."""
     trainer = unisal.train.Trainer(**kwargs)
     trainer.fit()
@@ -20,17 +16,17 @@ def train(eval_sources=('DHF1K', 'SALICON', 'UCFSports', 'Hollywood'),
 def load_trainer(train_id=None):
     """Instantiate Trainer class from saved kwargs."""
     if train_id is None:
-        train_id = 'pretrained_unisal'
+        train_id = "pretrained_unisal"
     print(f"Train ID: {train_id}")
     train_dir = Path(os.environ["TRAIN_DIR"])
     train_dir = train_dir / train_id
+    print(f"Initializing trainer from {train_dir}...")
     return unisal.train.Trainer.init_from_cfg_dir(train_dir)
 
 
 def score_model(
-        train_id=None,
-        sources=('DHF1K', 'SALICON', 'UCFSports', 'Hollywood'),
-        **kwargs):
+    train_id=None, sources=("DHF1K", "SALICON", "UCFSports", "Hollywood"), **kwargs
+):
     """Compute the scores for a trained model."""
 
     trainer = load_trainer(train_id)
@@ -39,26 +35,27 @@ def score_model(
 
 
 def generate_predictions(
-        train_id=None,
-        sources=('DHF1K', 'SALICON', 'UCFSports', 'Hollywood',
-                 'MIT1003', 'MIT300'),
-        **kwargs):
+    train_id=None,
+    sources=("DHF1K", "SALICON", "UCFSports", "Hollywood", "MIT1003", "MIT300"),
+    **kwargs,
+):
     """Generate predictions with a trained model."""
 
     trainer = load_trainer(train_id)
     for source in sources:
 
         # Load fine-tuned weights for MIT datasets
-        if source in ('MIT1003', 'MIT300'):
+        if source in ("MIT1003", "MIT300"):
             trainer.model.load_weights(trainer.train_dir, "ft_mit1003")
-            trainer.salicon_cfg['x_val_step'] = 0
-            kwargs.update({'model_domain': 'SALICON', 'load_weights': False})
+            trainer.salicon_cfg["x_val_step"] = 0
+            kwargs.update({"model_domain": "SALICON", "load_weights": False})
 
         trainer.generate_predictions(source=source, **kwargs)
 
 
 def predictions_from_folder(
-        folder_path, is_video, source=None, train_id=None, model_domain=None):
+    folder_path, is_video, source=None, train_id=None, model_domain=None
+):
     """Generate predictions of files in a folder with a trained model."""
 
     # Allows us to call this function directly from command-line
@@ -67,7 +64,8 @@ def predictions_from_folder(
 
     trainer = load_trainer(train_id)
     trainer.generate_predictions_from_path(
-        folder_path, is_video, source=source, model_domain=model_domain)
+        folder_path, is_video, source=source, model_domain=model_domain
+    )
 
 
 def predict_examples(train_id=None):
@@ -76,21 +74,48 @@ def predict_examples(train_id=None):
             continue
 
         source = example_folder.name
-        is_video = source not in ('SALICON', 'MIT1003')
+        is_video = source not in ("SALICON", "MIT1003")
 
-        print(f"\nGenerating predictions for {'video' if is_video else 'image'} "
-              f"folder\n{str(source)}")
+        print(
+            f"\nGenerating predictions for {'video' if is_video else 'image'} "
+            f"folder\n{str(source)}"
+        )
 
         if is_video:
             if not example_folder.is_dir():
                 continue
-            for video_folder in example_folder.glob('[!.]*'):  # ignore hidden files
+            for video_folder in example_folder.glob("[!.]*"):  # ignore hidden files
                 predictions_from_folder(
-                    video_folder, is_video, train_id=train_id, source=source)
+                    video_folder, is_video, train_id=train_id, source=source
+                )
         else:
             predictions_from_folder(
-                example_folder, is_video, train_id=train_id, source=source)
+                example_folder, is_video, train_id=train_id, source=source
+            )
+
+
+def predict_image(
+    image_path: str,
+    out_image_path: str = None,
+):
+    """A minimal working example of running inference on a single image."""
+    from PIL import Image
+    import numpy as np
+
+    assert os.path.isfile(
+        image_path
+    ), f"provided image_path ({image_path}) is not a valid file"
+    im = Image.open(image_path)
+    smap = unisal.demo.predict_image(img_rgb=np.array(im))
+
+    if out_image_path:
+        output_dir = os.path.dirname(out_image_path)
+        assert os.path.isdir(
+            output_dir
+        ), f"output directory {output_dir} does not exist"
+        Image.fromarray(smap).save(out_image_path)
+    return smap
 
 
 if __name__ == "__main__":
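run.py dispatches its top-level functions through python-fire, so the new predict_image entry point is reachable from the shell as well as from Python. A minimal sketch (the paths are placeholders, and TRAIN_DIR must point at the directory that holds the pretrained trainer config):

    # shell
    #   python run.py predict_image --image_path=examples/img.jpg --out_image_path=out/sal.png

    # Python
    from run import predict_image

    smap = predict_image("examples/img.jpg")  # uint8 saliency map at the input resolution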
diff --git a/unisal/__init__.py b/unisal/__init__.py
index 45d1627..0bdf348 100644
--- a/unisal/__init__.py
+++ b/unisal/__init__.py
@@ -1 +1 @@
-from . import train, data, model, models, utils
+from . import train, data, model, models, utils, demo
diff --git a/unisal/data.py b/unisal/data.py
index d4366f3..f91af36 100644
--- a/unisal/data.py
+++ b/unisal/data.py
@@ -1,4 +1,3 @@
-
 from pathlib import Path
 import os
 import random
@@ -7,8 +6,13 @@ import copy
 
 import torch
-from torch.utils.data import Dataset, DataLoader, BatchSampler, RandomSampler, \
-    SequentialSampler
+from torch.utils.data import (
+    Dataset,
+    DataLoader,
+    BatchSampler,
+    RandomSampler,
+    SequentialSampler,
+)
 from torchvision import transforms
 import numpy as np
 import cv2
@@ -25,8 +29,7 @@ if "SALICON_DATA_DIR" not in os.environ:
     os.environ["SALICON_DATA_DIR"] = str(default_data_dir / "SALICON")
 
 if "HOLLYWOOD_DATA_DIR" not in os.environ:
-    os.environ["HOLLYWOOD_DATA_DIR"] = str(
-        default_data_dir / "Hollywood2_actions")
+    os.environ["HOLLYWOOD_DATA_DIR"] = str(default_data_dir / "Hollywood2_actions")
 
 if "UCFSPORTS_DATA_DIR" not in os.environ:
     os.environ["UCFSPORTS_DATA_DIR"] = str(default_data_dir / "ucf-002")
 
 if "MIT300_DATA_DIR" not in os.environ:
@@ -40,77 +43,96 @@ def get_dataset():
     return DHF1KDataset
 
 
-def get_dataloader(src='DHF1K'):
-    if src in ('MIT1003',):
+def get_dataloader(src="DHF1K"):
+    if src in ("MIT1003",):
         return ImgSizeDataLoader
     return DataLoader
 
 
 class SALICONDataset(Dataset, utils.KwConfigClass):
 
-    source = 'SALICON'
+    source = "SALICON"
     dynamic = False
 
-    def __init__(self, phase='train', subset=None, verbose=1,
-                 out_size=(288, 384), target_size=(480, 640),
-                 preproc_cfg=None):
+    def __init__(
+        self,
+        phase="train",
+        subset=None,
+        verbose=1,
+        out_size=(288, 384),
+        target_size=(480, 640),
+        preproc_cfg=None,
+    ):
         self.phase = phase
-        self.train = phase == 'train'
+        self.train = phase == "train"
         self.subset = subset
         self.verbose = verbose
         self.out_size = out_size
         self.target_size = target_size
         self.preproc_cfg = {
-            'rgb_mean': (0.485, 0.456, 0.406),
-            'rgb_std': (0.229, 0.224, 0.225),
+            "rgb_mean": (0.485, 0.456, 0.406),
+            "rgb_std": (0.229, 0.224, 0.225),
         }
         if preproc_cfg is not None:
             self.preproc_cfg.update(preproc_cfg)
-        self.phase_str = 'val' if phase in ('valid', 'eval') else phase
+        self.phase_str = "val" if phase in ("valid", "eval") else phase
         self.file_stem = f"COCO_{self.phase_str}2014_"
         self.file_nr = "{:012d}"
 
         self.samples = self.prepare_samples()
         if self.subset is not None:
-            self.samples = self.samples[:int(len(self.samples) * subset)]
+            self.samples = self.samples[: int(len(self.samples) * subset)]
 
         # For compatibility with video datasets
         self.n_images_dict = {img_nr: 1 for img_nr in self.samples}
-        self.target_size_dict = {
-            img_nr: self.target_size for img_nr in self.samples}
+        self.target_size_dict = {img_nr: self.target_size for img_nr in self.samples}
self.n_samples = len(self.samples) self.frame_modulo = 1 def get_map(self, img_nr): - map_file = self.dir / 'maps' / self.phase_str / ( - self.file_stem + self.file_nr.format(img_nr) + '.png') + map_file = ( + self.dir + / "maps" + / self.phase_str + / (self.file_stem + self.file_nr.format(img_nr) + ".png") + ) map = cv2.imread(str(map_file), cv2.IMREAD_GRAYSCALE) - assert(map is not None) + assert map is not None return map def get_img(self, img_nr): - img_file = self.dir / 'images' / ( - self.file_stem + self.file_nr.format(img_nr) + '.jpg') + img_file = ( + self.dir + / "images" + / (self.file_stem + self.file_nr.format(img_nr) + ".jpg") + ) img = cv2.imread(str(img_file)) - assert(img is not None) + assert img is not None return np.ascontiguousarray(img[:, :, ::-1]) def get_raw_fixations(self, img_nr): - raw_fix_file = self.dir / 'fixations' / self.phase_str / ( - self.file_stem + self.file_nr.format(img_nr) + '.mat') + raw_fix_file = ( + self.dir + / "fixations" + / self.phase_str + / (self.file_stem + self.file_nr.format(img_nr) + ".mat") + ) fix_data = scipy.io.loadmat(raw_fix_file) - fixations_array = [gaze[2] for gaze in fix_data['gaze'][:, 0]] - return fixations_array, fix_data['resolution'].tolist()[0] + fixations_array = [gaze[2] for gaze in fix_data["gaze"][:, 0]] + return fixations_array, fix_data["resolution"].tolist()[0] def process_raw_fixations(self, fixations_array, res): fix_map = np.zeros(res, dtype=np.uint8) for subject_fixations in fixations_array: - fix_map[subject_fixations[:, 1] - 1, subject_fixations[:, 0] - 1]\ - = 255 + fix_map[subject_fixations[:, 1] - 1, subject_fixations[:, 0] - 1] = 255 return fix_map def get_fixation_map(self, img_nr): - fix_map_file = self.dir / 'fixations' / self.phase_str / ( - self.file_stem + self.file_nr.format(img_nr) + '.png') + fix_map_file = ( + self.dir + / "fixations" + / self.phase_str + / (self.file_stem + self.file_nr.format(img_nr) + ".png") + ) if fix_map_file.exists(): fix_map = cv2.imread(str(fix_map_file), cv2.IMREAD_GRAYSCALE) else: @@ -125,30 +147,32 @@ def dir(self): def prepare_samples(self): samples = [] - for file in (self.dir / 'images').glob(self.file_stem + '*.jpg'): + for file in (self.dir / "images").glob(self.file_stem + "*.jpg"): samples.append(int(file.stem[-12:])) return sorted(samples) def __len__(self): return len(self.samples) - def preprocess(self, img, data='img'): + def preprocess(self, img, data="img"): transformations = [ transforms.ToPILImage(), ] - if data == 'img': - transformations.append(transforms.Resize( - self.out_size, interpolation=PIL.Image.LANCZOS)) + if data == "img": + transformations.append( + transforms.Resize(self.out_size, interpolation=PIL.Image.LANCZOS) + ) transformations.append(transforms.ToTensor()) - if data == 'img' and 'rgb_mean' in self.preproc_cfg: + if data == "img" and "rgb_mean" in self.preproc_cfg: transformations.append( transforms.Normalize( - self.preproc_cfg['rgb_mean'], self.preproc_cfg['rgb_std'])) - elif data == 'sal': + self.preproc_cfg["rgb_mean"], self.preproc_cfg["rgb_std"] + ) + ) + elif data == "sal": transformations.append(transforms.Lambda(utils.normalize_tensor)) - elif data == 'fix': - transformations.append( - transforms.Lambda(lambda fix: torch.gt(fix, 0.5))) + elif data == "fix": + transformations.append(transforms.Lambda(lambda fix: torch.gt(fix, 0.5))) processing = transforms.Compose(transformations) tensor = processing(img) @@ -156,14 +180,14 @@ def preprocess(self, img, data='img'): def get_data(self, img_nr): img = self.get_img(img_nr) - img 
= self.preprocess(img, data='img') - if self.phase == 'test': + img = self.preprocess(img, data="img") + if self.phase == "test": return [1], img, self.target_size sal = self.get_map(img_nr) - sal = self.preprocess(sal, data='sal') + sal = self.preprocess(sal, data="sal") fix = self.get_fixation_map(img_nr) - fix = self.preprocess(fix, data='fix') + fix = self.preprocess(fix, data="fix") return [1], img, sal, fix, self.target_size @@ -175,20 +199,20 @@ def __getitem__(self, item): class ImgSizeBatchSampler: def __init__(self, dataset, batch_size=1, shuffle=False, drop_last=False): - assert(isinstance(dataset, MIT1003Dataset)) + assert isinstance(dataset, MIT1003Dataset) self.batch_size = batch_size self.shuffle = shuffle self.drop_last = drop_last out_size_array = [ - dataset.size_dict[img_idx]['out_size'] - for img_idx in dataset.samples] + dataset.size_dict[img_idx]["out_size"] for img_idx in dataset.samples + ] self.out_size_set = sorted(list(set(out_size_array))) - self.sample_idx_dict = { - out_size: [] for out_size in self.out_size_set} + self.sample_idx_dict = {out_size: [] for out_size in self.out_size_set} for sample_idx, img_idx in enumerate(dataset.samples): - self.sample_idx_dict[dataset.size_dict[img_idx]['out_size']].append( - sample_idx) + self.sample_idx_dict[dataset.size_dict[img_idx]["out_size"]].append( + sample_idx + ) self.len = 0 self.n_batches_dict = {} @@ -198,9 +222,12 @@ def __init__(self, dataset, batch_size=1, shuffle=False, drop_last=False): self.n_batches_dict[out_size] = this_n_batches def __iter__(self): - batch_array = list(itertools.chain.from_iterable( - [out_size for _ in range(n_batches)] - for out_size, n_batches in self.n_batches_dict.items())) + batch_array = list( + itertools.chain.from_iterable( + [out_size for _ in range(n_batches)] + for out_size, n_batches in self.n_batches_dict.items() + ) + ) if not self.shuffle: random.seed(27) random.shuffle(batch_array) @@ -209,8 +236,8 @@ def __iter__(self): for sample_idx_array in this_sample_idx_dict.values(): random.shuffle(sample_idx_array) for out_size in batch_array: - this_indices = this_sample_idx_dict[out_size][:self.batch_size] - del this_sample_idx_dict[out_size][:self.batch_size] + this_indices = this_sample_idx_dict[out_size][: self.batch_size] + del this_sample_idx_dict[out_size][: self.batch_size] yield this_indices def __len__(self): @@ -219,8 +246,7 @@ def __len__(self): class ImgSizeDataLoader(DataLoader): - def __init__(self, dataset, batch_size=1, shuffle=False, drop_last=False, - **kwargs): + def __init__(self, dataset, batch_size=1, shuffle=False, drop_last=False, **kwargs): if batch_size == 1: if shuffle: sampler = RandomSampler(dataset) @@ -229,32 +255,34 @@ def __init__(self, dataset, batch_size=1, shuffle=False, drop_last=False, batch_sampler = BatchSampler(sampler, batch_size, drop_last) else: batch_sampler = ImgSizeBatchSampler( - dataset, batch_size=batch_size, shuffle=shuffle, - drop_last=drop_last) + dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last + ) super().__init__(dataset, batch_sampler=batch_sampler, **kwargs) class MIT300Dataset(Dataset, utils.KwConfigClass): - source = 'MIT300' + source = "MIT300" dynamic = False - def __init__(self, phase='test'): - assert(phase == 'test') + def __init__(self, phase="test"): + assert phase == "test" self.phase = phase self.train = False self.target_size = None self.preproc_cfg = { - 'rgb_mean': (0.485, 0.456, 0.406), - 'rgb_std': (0.229, 0.224, 0.225), + "rgb_mean": (0.485, 0.456, 0.406), + "rgb_std": (0.229, 
0.224, 0.225), } self.samples, self.target_size_dict = self.load_data() def load_data(self): samples = [] target_size_dict = {} - file_list = list(self.dir.glob('*.jpg')) - file_list = sorted(file_list, key=lambda x: int(x.stem[1:min(4, len(x.stem))])) + file_list = list(self.dir.glob("*.jpg")) + file_list = sorted( + file_list, key=lambda x: int(x.stem[1 : min(4, len(x.stem))]) + ) for img_idx, file in enumerate(file_list): img = cv2.imread(str(file)) @@ -283,24 +311,25 @@ def load_data(self): @property def dir(self): - return Path(os.environ["MIT300_DATA_DIR"]) / 'BenchmarkIMAGES' + return Path(os.environ["MIT300_DATA_DIR"]) / "BenchmarkIMAGES" def __len__(self): return len(self.samples) - def preprocess(self, img, out_size, data='img'): - assert(data == 'img') + def preprocess(self, img, out_size, data="img"): + assert data == "img" transformations = [ transforms.ToPILImage(), - transforms.Resize( - out_size, interpolation=PIL.Image.LANCZOS), + transforms.Resize(out_size, interpolation=PIL.Image.LANCZOS), transforms.ToTensor(), ] - if 'rgb_mean' in self.preproc_cfg: + if "rgb_mean" in self.preproc_cfg: transformations.append( transforms.Normalize( - self.preproc_cfg['rgb_mean'], self.preproc_cfg['rgb_std'])) + self.preproc_cfg["rgb_mean"], self.preproc_cfg["rgb_std"] + ) + ) processing = transforms.Compose(transformations) tensor = processing(img) @@ -310,9 +339,9 @@ def get_data(self, item): img_name, out_size = self.samples[item] img_file = self.dir / img_name img = cv2.imread(str(img_file)) - assert(img is not None) + assert img is not None img = np.ascontiguousarray(img[:, :, ::-1]) - img = self.preprocess(img, out_size, data='img') + img = self.preprocess(img, out_size, data="img") return [1], img, self.target_size_dict[item] def __getitem__(self, item): @@ -321,19 +350,27 @@ def __getitem__(self, item): class MIT1003Dataset(Dataset, utils.KwConfigClass): - source = 'MIT1003' + source = "MIT1003" n_train_val_images = 1003 dynamic = False - def __init__(self, phase='train', subset=None, verbose=1, - preproc_cfg=None, n_x_val=10, x_val_step=0, x_val_seed=27): + def __init__( + self, + phase="train", + subset=None, + verbose=1, + preproc_cfg=None, + n_x_val=10, + x_val_step=0, + x_val_seed=27, + ): self.phase = phase - self.train = phase == 'train' + self.train = phase == "train" self.subset = subset self.verbose = verbose self.preproc_cfg = { - 'rgb_mean': (0.485, 0.456, 0.406), - 'rgb_std': (0.229, 0.224, 0.225), + "rgb_mean": (0.485, 0.456, 0.406), + "rgb_std": (0.229, 0.224, 0.225), } if preproc_cfg is not None: self.preproc_cfg.update(preproc_cfg) @@ -348,7 +385,7 @@ def __init__(self, phase='train', subset=None, verbose=1, self.samples = np.arange(0, n_images) else: print(f"X-Val step: {x_val_step}") - assert(self.x_val_step < self.n_x_val) + assert self.x_val_step < self.n_x_val samples = np.arange(0, n_images) if self.x_val_seed > 0: np.random.seed(self.x_val_seed) @@ -364,31 +401,31 @@ def __init__(self, phase='train', subset=None, verbose=1, self.all_image_files, self.size_dict = self.load_data() if self.subset is not None: - self.samples = self.samples[:int(len(self.samples) * subset)] + self.samples = self.samples[: int(len(self.samples) * subset)] # For compatibility with video datasets self.n_images_dict = {sample: 1 for sample in self.samples} self.target_size_dict = { - img_idx: self.size_dict[img_idx]['target_size'] - for img_idx in self.samples} + img_idx: self.size_dict[img_idx]["target_size"] for img_idx in self.samples + } self.n_samples = len(self.samples) 
self.frame_modulo = 1 def get_map(self, img_idx): - map_file = self.fix_dir / self.all_image_files[img_idx]['map'] + map_file = self.fix_dir / self.all_image_files[img_idx]["map"] map = cv2.imread(str(map_file), cv2.IMREAD_GRAYSCALE) - assert(map is not None) + assert map is not None return map def get_img(self, img_idx): - img_file = self.img_dir / self.all_image_files[img_idx]['img'] + img_file = self.img_dir / self.all_image_files[img_idx]["img"] img = cv2.imread(str(img_file)) - assert(img is not None) + assert img is not None return np.ascontiguousarray(img[:, :, ::-1]) def get_fixation_map(self, img_idx): - fix_map_file = self.fix_dir / self.all_image_files[img_idx]['pts'] + fix_map_file = self.fix_dir / self.all_image_files[img_idx]["pts"] fix_map = cv2.imread(str(fix_map_file), cv2.IMREAD_GRAYSCALE) - assert(fix_map is not None) + assert fix_map is not None return fix_map @property @@ -397,11 +434,11 @@ def dir(self): @property def fix_dir(self): - return self.dir / 'ALLFIXATIONMAPS' / 'ALLFIXATIONMAPS' + return self.dir / "ALLFIXATIONMAPS" / "ALLFIXATIONMAPS" @property def img_dir(self): - return self.dir / 'ALLSTIMULI' / 'ALLSTIMULI' + return self.dir / "ALLSTIMULI" / "ALLSTIMULI" def get_out_size_eval(self, img_size): ar = img_size[0] / img_size[1] @@ -445,69 +482,73 @@ def load_data(self): all_image_files = [] for img_file in sorted(self.img_dir.glob("*.jpeg")): - all_image_files.append({ - 'img': img_file.name, - 'map': img_file.stem + "_fixMap.jpg", - 'pts': img_file.stem + "_fixPts.jpg", - }) - assert((self.fix_dir / all_image_files[-1]['map']).exists()) - assert((self.fix_dir / all_image_files[-1]['pts']).exists()) + all_image_files.append( + { + "img": img_file.name, + "map": img_file.stem + "_fixMap.jpg", + "pts": img_file.stem + "_fixPts.jpg", + } + ) + assert (self.fix_dir / all_image_files[-1]["map"]).exists() + assert (self.fix_dir / all_image_files[-1]["pts"]).exists() size_dict_file = config_path / "img_size_dict.json" if size_dict_file.exists(): - with open(size_dict_file, 'r') as f: + with open(size_dict_file, "r") as f: size_dict = json.load(f) - size_dict = {int(img_idx): val for - img_idx, val in size_dict.items()} + size_dict = {int(img_idx): val for img_idx, val in size_dict.items()} else: size_dict = {} for img_idx in range(self.n_train_val_images): - img = cv2.imread( - str(self.img_dir / all_image_files[img_idx]['img'])) - size_dict[img_idx] = {'img_size': img.shape[:2]} - with open(size_dict_file, 'w') as f: + img = cv2.imread(str(self.img_dir / all_image_files[img_idx]["img"])) + size_dict[img_idx] = {"img_size": img.shape[:2]} + with open(size_dict_file, "w") as f: json.dump(size_dict, f) for img_idx in self.samples: - img_size = size_dict[img_idx]['img_size'] - if self.phase in ('train', 'valid'): + img_size = size_dict[img_idx]["img_size"] + if self.phase in ("train", "valid"): out_size = self.get_out_size_train(img_size) else: out_size = self.get_out_size_eval(img_size) - if self.phase in ('train', 'valid'): + if self.phase in ("train", "valid"): target_size = tuple(sz * 2 for sz in out_size) else: target_size = img_size - size_dict[img_idx].update({ - 'out_size': out_size, 'target_size': target_size}) + size_dict[img_idx].update( + {"out_size": out_size, "target_size": target_size} + ) return all_image_files, size_dict def __len__(self): return len(self.samples) - def preprocess(self, img, out_size=None, data='img'): + def preprocess(self, img, out_size=None, data="img"): transformations = [ transforms.ToPILImage(), ] - if data in ('img', 'sal'): - 
transformations.append(transforms.Resize( - out_size, interpolation=PIL.Image.LANCZOS)) + if data in ("img", "sal"): + transformations.append( + transforms.Resize(out_size, interpolation=PIL.Image.LANCZOS) + ) else: - transformations.append(transforms.Resize( - out_size, interpolation=PIL.Image.NEAREST)) + transformations.append( + transforms.Resize(out_size, interpolation=PIL.Image.NEAREST) + ) transformations.append(transforms.ToTensor()) - if data == 'img' and 'rgb_mean' in self.preproc_cfg: + if data == "img" and "rgb_mean" in self.preproc_cfg: transformations.append( transforms.Normalize( - self.preproc_cfg['rgb_mean'], self.preproc_cfg['rgb_std'])) - elif data == 'sal': + self.preproc_cfg["rgb_mean"], self.preproc_cfg["rgb_std"] + ) + ) + elif data == "sal": transformations.append(transforms.Lambda(utils.normalize_tensor)) - elif data == 'fix': - transformations.append( - transforms.Lambda(lambda fix: torch.gt(fix, 0.5))) + elif data == "fix": + transformations.append(transforms.Lambda(lambda fix: torch.gt(fix, 0.5))) processing = transforms.Compose(transformations) tensor = processing(img) @@ -515,16 +556,16 @@ def preprocess(self, img, out_size=None, data='img'): def get_data(self, img_idx): img = self.get_img(img_idx) - out_size = self.size_dict[img_idx]['out_size'] + out_size = self.size_dict[img_idx]["out_size"] target_size = self.target_size_dict[img_idx] - img = self.preprocess(img, out_size=out_size, data='img') - if self.phase == 'test': + img = self.preprocess(img, out_size=out_size, data="img") + if self.phase == "test": return [1], img, target_size sal = self.get_map(img_idx) - sal = self.preprocess(sal, target_size, data='sal') + sal = self.preprocess(sal, target_size, data="sal") fix = self.get_fixation_map(img_idx) - fix = self.preprocess(fix, target_size, data='fix') + fix = self.preprocess(fix, target_size, data="fix") return [1], img, sal, fix, target_size @@ -539,29 +580,42 @@ class DHF1KDataset(Dataset, utils.KwConfigClass): n_train_val_videos = 700 test_vid_nrs = (701, 1000) frame_rate = 30 - source = 'DHF1K' + source = "DHF1K" dynamic = True - def __init__(self, - seq_len=12, - frame_modulo=5, - max_seq_len=1e6, - preproc_cfg=None, - out_size=(224, 384), phase='train', target_size=(360, 640), - debug=False, val_size=100, n_x_val=3, x_val_step=2, - x_val_seed=0, seq_per_vid=1, subset=None, verbose=1, - n_images_file='dhf1k_n_images.dat', seq_per_vid_val=2, - sal_offset=None): + def __init__( + self, + seq_len=12, + frame_modulo=5, + max_seq_len=1e6, + preproc_cfg=None, + out_size=(224, 384), + phase="train", + target_size=(360, 640), + debug=False, + val_size=100, + n_x_val=3, + x_val_step=2, + x_val_seed=0, + seq_per_vid=1, + subset=None, + verbose=1, + n_images_file="dhf1k_n_images.dat", + seq_per_vid_val=2, + sal_offset=None, + ): self.phase = phase - self.train = phase == 'train' + self.train = phase == "train" if not self.train: preproc_cfg = {} elif preproc_cfg is None: preproc_cfg = {} - preproc_cfg.update({ - 'rgb_mean': (0.485, 0.456, 0.406), - 'rgb_std': (0.229, 0.224, 0.225), - }) + preproc_cfg.update( + { + "rgb_mean": (0.485, 0.456, 0.406), + "rgb_std": (0.229, 0.224, 0.225), + } + ) self.preproc_cfg = preproc_cfg self.out_size = out_size self.debug = debug @@ -585,46 +639,50 @@ def __init__(self, self.vid_nr_array = None # Evaluation - if phase in ('eval', 'test'): + if phase in ("eval", "test"): self.seq_len = int(1e6) - if self.phase in ('test',): - self.vid_nr_array = list(range( - self.test_vid_nrs[0], self.test_vid_nrs[1] + 1)) + if self.phase 
in ("test",): + self.vid_nr_array = list( + range(self.test_vid_nrs[0], self.test_vid_nrs[1] + 1) + ) self.samples, self.target_size_dict = self.prepare_samples() return # Cross-validation split n_videos = self.n_train_val_videos - assert(self.val_size <= n_videos // self.n_x_val) - assert(self.x_val_step < self.n_x_val) + assert self.val_size <= n_videos // self.n_x_val + assert self.x_val_step < self.n_x_val vid_nr_array = np.arange(1, n_videos + 1) if self.x_val_seed > 0: np.random.seed(self.x_val_seed) np.random.shuffle(vid_nr_array) - val_start = (len(vid_nr_array) - self.val_size) //\ - (self.n_x_val - 1) * self.x_val_step + val_start = ( + (len(vid_nr_array) - self.val_size) // (self.n_x_val - 1) * self.x_val_step + ) vid_nr_array = vid_nr_array.tolist() if not self.train: - self.vid_nr_array =\ - vid_nr_array[val_start:val_start + self.val_size] + self.vid_nr_array = vid_nr_array[val_start : val_start + self.val_size] else: - del vid_nr_array[val_start:val_start + self.val_size] + del vid_nr_array[val_start : val_start + self.val_size] self.vid_nr_array = vid_nr_array if self.subset is not None: - self.vid_nr_array =\ - self.vid_nr_array[:int(len(self.vid_nr_array) * self.subset)] + self.vid_nr_array = self.vid_nr_array[ + : int(len(self.vid_nr_array) * self.subset) + ] self.samples, self.target_size_dict = self.prepare_samples() @property def n_images_dict(self): if self._n_images_dict is None: - with open(config_path.parent / self.n_images_file, 'r') as f: + with open(config_path.parent / self.n_images_file, "r") as f: self._n_images_dict = { - idx + 1: int(line) for idx, line in enumerate(f) - if idx + 1 in self.vid_nr_array} + idx + 1: int(line) + for idx, line in enumerate(f) + if idx + 1 in self.vid_nr_array + } return self._n_images_dict @property @@ -645,9 +703,8 @@ def prepare_samples(self): too_short = 0 too_long = 0 for vid_nr, n_images in self.n_images_dict.items(): - if self.phase in ('eval', 'test'): - samples += [ - (vid_nr, offset + 1) for offset in range(self.frame_modulo)] + if self.phase in ("eval", "test"): + samples += [(vid_nr, offset + 1) for offset in range(self.frame_modulo)] continue if n_images < self.clip_len: too_short += 1 @@ -655,51 +712,59 @@ def prepare_samples(self): if n_images // self.frame_modulo > self.max_seq_len: too_long += 1 continue - if self.phase == 'train': + if self.phase == "train": samples += [(vid_nr, None)] * self.seq_per_vid continue - elif self.phase == 'valid': + elif self.phase == "valid": x = n_images // (self.seq_per_vid_val * 2) - self.clip_len // 2 start = max(1, x) end = min(n_images - self.clip_len, n_images - x) samples += [ - (vid_nr, int(start)) for start in - np.linspace(start, end, self.seq_per_vid_val)] + (vid_nr, int(start)) + for start in np.linspace(start, end, self.seq_per_vid_val) + ] continue - if self.phase not in ('eval', 'test') and self.n_images_dict: + if self.phase not in ("eval", "test") and self.n_images_dict: n_loaded = len(self.n_images_dict) - too_short - too_long - print(f"{n_loaded} videos loaded " - f"({n_loaded / len(self.n_images_dict) * 100:.1f}%)") - print(f"{too_short} videos are too short " - f"({too_short / len(self.n_images_dict) * 100:.1f}%)") - print(f"{too_long} videos are too long " - f"({too_long / len(self.n_images_dict) * 100:.1f}%)") + print( + f"{n_loaded} videos loaded " + f"({n_loaded / len(self.n_images_dict) * 100:.1f}%)" + ) + print( + f"{too_short} videos are too short " + f"({too_short / len(self.n_images_dict) * 100:.1f}%)" + ) + print( + f"{too_long} videos are too long 
" + f"({too_long / len(self.n_images_dict) * 100:.1f}%)" + ) target_size_dict = { - vid_nr: self.target_size for vid_nr in self.n_images_dict.keys()} + vid_nr: self.target_size for vid_nr in self.n_images_dict.keys() + } return samples, target_size_dict def get_frame_nrs(self, vid_nr, start): n_images = self.n_images_dict[vid_nr] - if self.phase in ('eval', 'test'): + if self.phase in ("eval", "test"): return list(range(start, n_images + 1, self.frame_modulo)) return list(range(start, start + self.clip_len, self.frame_modulo)) def get_annotation_dir(self, vid_nr): - return self.dir / 'annotation' / f'{vid_nr:04d}' + return self.dir / "annotation" / f"{vid_nr:04d}" def get_data_file(self, vid_nr, f_nr, dkey): - if dkey == 'frame': - folder = 'images' - elif dkey == 'sal': - folder = 'maps' - elif dkey == 'fix': - folder = 'fixation' + if dkey == "frame": + folder = "images" + elif dkey == "sal": + folder = "maps" + elif dkey == "fix": + folder = "fixation" else: - raise ValueError(f'Unknown data key {dkey}') - return self.get_annotation_dir(vid_nr) / folder / f'{f_nr:04d}.png' + raise ValueError(f"Unknown data key {dkey}") + return self.get_annotation_dir(vid_nr) / folder / f"{f_nr:04d}.png" def load_data(self, vid_nr, f_nr, dkey): - read_flag = None if dkey == 'frame' else cv2.IMREAD_GRAYSCALE + read_flag = None if dkey == "frame" else cv2.IMREAD_GRAYSCALE data_file = self.get_data_file(vid_nr, f_nr, dkey) if read_flag is not None: data = cv2.imread(str(data_file), read_flag) @@ -707,10 +772,10 @@ def load_data(self, vid_nr, f_nr, dkey): data = cv2.imread(str(data_file)) if data is None: raise FileNotFoundError(data_file) - if dkey == 'frame': + if dkey == "frame": data = np.ascontiguousarray(data[:, :, ::-1]) - if dkey == 'sal' and self.train and self.sal_offset is not None: + if dkey == "sal" and self.train and self.sal_offset is not None: data += self.sal_offset data[0, 0] = 0 @@ -718,20 +783,22 @@ def load_data(self, vid_nr, f_nr, dkey): def preprocess_sequence(self, frame_seq, dkey, vid_nr): transformations = [] - if dkey == 'frame': + if dkey == "frame": transformations.append(transforms.ToPILImage()) - transformations.append(transforms.Resize( - self.out_size, interpolation=PIL.Image.LANCZOS)) + transformations.append( + transforms.Resize(self.out_size, interpolation=PIL.Image.LANCZOS) + ) transformations.append(transforms.ToTensor()) - if dkey == 'frame' and 'rgb_mean' in self.preproc_cfg: + if dkey == "frame" and "rgb_mean" in self.preproc_cfg: transformations.append( transforms.Normalize( - self.preproc_cfg['rgb_mean'], self.preproc_cfg['rgb_std'])) - elif dkey == 'sal': + self.preproc_cfg["rgb_mean"], self.preproc_cfg["rgb_std"] + ) + ) + elif dkey == "sal": transformations.append(transforms.Lambda(utils.normalize_tensor)) - elif dkey == 'fix': - transformations.append( - transforms.Lambda(lambda fix: torch.gt(fix, 0.5))) + elif dkey == "fix": + transformations.append(transforms.Lambda(lambda fix: torch.gt(fix, 0.5))) processing = transforms.Compose(transformations) @@ -751,12 +818,12 @@ def get_data(self, vid_nr, start): else: start = np.random.randint(1, max_start) frame_nrs = self.get_frame_nrs(vid_nr, start) - frame_seq = self.get_seq(vid_nr, frame_nrs, 'frame') + frame_seq = self.get_seq(vid_nr, frame_nrs, "frame") target_size = self.target_size_dict[vid_nr] - if self.phase == 'test' and self.source in ('DHF1K',): + if self.phase == "test" and self.source in ("DHF1K",): return frame_nrs, frame_seq, target_size - sal_seq = self.get_seq(vid_nr, frame_nrs, 'sal') - fix_seq = 
self.get_seq(vid_nr, frame_nrs, 'fix') + sal_seq = self.get_seq(vid_nr, frame_nrs, "sal") + fix_seq = self.get_seq(vid_nr, frame_nrs, "fix") return frame_nrs, frame_seq, sal_seq, fix_seq, target_size def __getitem__(self, item): @@ -767,115 +834,124 @@ def __getitem__(self, item): class HollywoodDataset(DHF1KDataset): - source = 'Hollywood' + source = "Hollywood" dynamic = True img_channels = 1 - n_videos = { - 'train': 747, - 'test': 884 - } + n_videos = {"train": 747, "test": 884} test_vid_nrs = (1, 884) frame_rate = 24 - def __init__(self, out_size=(224, 416), val_size=75, n_images_file=None, - seq_per_vid_val=1, register_file='hollywood_register.json', - phase='train', - frame_modulo=4, - seq_len=12, - **kwargs): + def __init__( + self, + out_size=(224, 416), + val_size=75, + n_images_file=None, + seq_per_vid_val=1, + register_file="hollywood_register.json", + phase="train", + frame_modulo=4, + seq_len=12, + **kwargs, + ): self.register = None - self.phase_str = 'test' if phase in ('eval', 'test') else 'train' + self.phase_str = "test" if phase in ("eval", "test") else "train" self.register_file = self.phase_str + "_" + register_file - super().__init__(out_size=out_size, val_size=val_size, - n_images_file=n_images_file, - seq_per_vid_val=seq_per_vid_val, - x_val_seed=42, phase=phase, target_size=out_size, - frame_modulo=frame_modulo, - seq_len=seq_len, - **kwargs) - if phase in ('eval', 'test'): - self.target_size_dict = self.get_register()['vid_size_dict'] + super().__init__( + out_size=out_size, + val_size=val_size, + n_images_file=n_images_file, + seq_per_vid_val=seq_per_vid_val, + x_val_seed=42, + phase=phase, + target_size=out_size, + frame_modulo=frame_modulo, + seq_len=seq_len, + **kwargs, + ) + if phase in ("eval", "test"): + self.target_size_dict = self.get_register()["vid_size_dict"] @property def n_images_dict(self): if self._n_images_dict is None: - self._n_images_dict = self.get_register()['n_images_dict'] - self._n_images_dict = {vid_nr: ni for vid_nr, ni - in self._n_images_dict.items() - if vid_nr // 100 in self.vid_nr_array} + self._n_images_dict = self.get_register()["n_images_dict"] + self._n_images_dict = { + vid_nr: ni + for vid_nr, ni in self._n_images_dict.items() + if vid_nr // 100 in self.vid_nr_array + } return self._n_images_dict def get_register(self): if self.register is None: register_file = config_path / self.register_file if register_file.exists(): - with open(config_path / register_file, 'r') as f: + with open(config_path / register_file, "r") as f: self.register = json.load(f) - for reg_key in ('n_images_dict', 'start_image_dict', - 'vid_size_dict'): + for reg_key in ("n_images_dict", "start_image_dict", "vid_size_dict"): self.register[reg_key] = { - int(key): val for key, val in - self.register[reg_key].items()} + int(key): val for key, val in self.register[reg_key].items() + } else: self.register = self.generate_register() - with open(config_path / register_file, 'w') as f: + with open(config_path / register_file, "w") as f: json.dump(self.register, f, indent=2) return self.register def generate_register(self): - n_shots = { - vid_nr: 0 for vid_nr in range(1, self.n_videos[self.phase_str] + 1)} + n_shots = {vid_nr: 0 for vid_nr in range(1, self.n_videos[self.phase_str] + 1)} n_images_dict = {} start_image_dict = {} vid_size_dict = {} - for folder in sorted(self.dir.glob('actionclip*')): + for folder in sorted(self.dir.glob("actionclip*")): name = folder.stem vid_nr_start = 10 + len(self.phase_str) - vid_nr = int(name[vid_nr_start:vid_nr_start + 5]) + 
vid_nr = int(name[vid_nr_start : vid_nr_start + 5]) shot_nr = int(name[-2:].replace("_", "")) n_shots[vid_nr] += 1 vid_nr_shot_nr = 100 * vid_nr + shot_nr - image_files = sorted((folder / 'images').glob('actionclip*.png')) + image_files = sorted((folder / "images").glob("actionclip*.png")) n_images_dict[vid_nr_shot_nr] = len(image_files) start_image_dict[vid_nr_shot_nr] = int(image_files[0].stem[-5:]) img = cv2.imread(str(image_files[0])) vid_size_dict[vid_nr_shot_nr] = tuple(img.shape[:2]) return dict( - n_shots=n_shots, n_images_dict=n_images_dict, - start_image_dict=start_image_dict, vid_size_dict=vid_size_dict) + n_shots=n_shots, + n_images_dict=n_images_dict, + start_image_dict=start_image_dict, + vid_size_dict=vid_size_dict, + ) def preprocess_sequence(self, frame_seq, dkey, vid_nr): - transformations = [ - transforms.ToPILImage() - ] + transformations = [transforms.ToPILImage()] - vid_size = self.register['vid_size_dict'][vid_nr] + vid_size = self.register["vid_size_dict"][vid_nr] if vid_size[0] != self.out_size[0]: - interpolation = PIL.Image.LANCZOS if dkey in ('frame', 'sal')\ - else PIL.Image.NEAREST - size = (self.out_size[0], - int(vid_size[1] * self.out_size[0] / vid_size[0])) - transformations.append( - transforms.Resize(size, interpolation=interpolation)) + interpolation = ( + PIL.Image.LANCZOS if dkey in ("frame", "sal") else PIL.Image.NEAREST + ) + size = (self.out_size[0], int(vid_size[1] * self.out_size[0] / vid_size[0])) + transformations.append(transforms.Resize(size, interpolation=interpolation)) transformations += [ transforms.CenterCrop(self.out_size), transforms.ToTensor(), ] - if dkey == 'frame' and 'rgb_mean' in self.preproc_cfg: + if dkey == "frame" and "rgb_mean" in self.preproc_cfg: transformations.append( transforms.Normalize( - self.preproc_cfg['rgb_mean'], self.preproc_cfg['rgb_std'])) - elif dkey == 'sal': + self.preproc_cfg["rgb_mean"], self.preproc_cfg["rgb_std"] + ) + ) + elif dkey == "sal": transformations.append(transforms.Lambda(utils.normalize_tensor)) - elif dkey == 'fix': - transformations.append( - transforms.Lambda(lambda fix: torch.gt(fix, 0.5))) + elif dkey == "fix": + transformations.append(transforms.Lambda(lambda fix: torch.gt(fix, 0.5))) processing = transforms.Compose(transformations) @@ -886,22 +962,23 @@ def preprocess_sequence(self, frame_seq, dkey, vid_nr): def preprocess_sequence_eval(self, frame_seq, dkey, vid_nr): transformations = [] - if dkey == 'frame': + if dkey == "frame": transformations.append(transforms.ToPILImage()) transformations.append( - transforms.Resize( - self.out_size, interpolation=PIL.Image.LANCZOS)) + transforms.Resize(self.out_size, interpolation=PIL.Image.LANCZOS) + ) transformations.append(transforms.ToTensor()) - if dkey == 'frame' and 'rgb_mean' in self.preproc_cfg: + if dkey == "frame" and "rgb_mean" in self.preproc_cfg: transformations.append( transforms.Normalize( - self.preproc_cfg['rgb_mean'], self.preproc_cfg['rgb_std'])) - elif dkey == 'sal': + self.preproc_cfg["rgb_mean"], self.preproc_cfg["rgb_std"] + ) + ) + elif dkey == "sal": transformations.append(transforms.Lambda(utils.normalize_tensor)) - elif dkey == 'fix': - transformations.append( - transforms.Lambda(lambda fix: torch.gt(fix, 0.5))) + elif dkey == "fix": + transformations.append(transforms.Lambda(lambda fix: torch.gt(fix, 0.5))) processing = transforms.Compose(transformations) @@ -915,37 +992,43 @@ def get_annotation_dir(self, vid_nr_shot_nr): return self.dir / f"actionclip{self.phase_str}{vid_nr:05d}_{shot_nr:1d}" def 
get_data_file(self, vid_nr_shot_nr, f_nr, dkey): - if dkey == 'frame': - folder = 'images' - elif dkey == 'sal': - folder = 'maps' - elif dkey == 'fix': - folder = 'fixation' + if dkey == "frame": + folder = "images" + elif dkey == "sal": + folder = "maps" + elif dkey == "fix": + folder = "fixation" else: - raise ValueError(f'Unknown data key {dkey}') + raise ValueError(f"Unknown data key {dkey}") vid_nr = vid_nr_shot_nr // 100 - f_nr += self.register['start_image_dict'][vid_nr_shot_nr] - 1 - return self.get_annotation_dir(vid_nr_shot_nr) / folder /\ - f'actionclip{self.phase_str}{vid_nr:05d}_{f_nr:05d}.png' + f_nr += self.register["start_image_dict"][vid_nr_shot_nr] - 1 + return ( + self.get_annotation_dir(vid_nr_shot_nr) + / folder + / f"actionclip{self.phase_str}{vid_nr:05d}_{f_nr:05d}.png" + ) def get_seq(self, vid_nr, frame_nrs, dkey): data_seq = [self.load_data(vid_nr, f_nr, dkey) for f_nr in frame_nrs] - preproc_fun = self.preprocess_sequence if self.phase \ - in ('train', 'valid') else self.preprocess_sequence_eval + preproc_fun = ( + self.preprocess_sequence + if self.phase in ("train", "valid") + else self.preprocess_sequence_eval + ) return preproc_fun(data_seq, dkey, vid_nr) @property def dir(self): if self._dir is None: - self._dir = Path(os.environ["HOLLYWOOD_DATA_DIR"]) /\ - ('training' if self.phase in ('train', 'valid') - else 'testing') + self._dir = Path(os.environ["HOLLYWOOD_DATA_DIR"]) / ( + "training" if self.phase in ("train", "valid") else "testing" + ) return self._dir class UCFSportsDataset(DHF1KDataset): - source = 'UCFSports' + source = "UCFSports" dynamic = True img_channels = 1 @@ -953,48 +1036,60 @@ class UCFSportsDataset(DHF1KDataset): test_vid_nrs = (1, 47) frame_rate = 24 - def __init__(self, out_size=(256, 384), val_size=10, n_images_file=None, - seq_per_vid_val=1, register_file='ucfsports_register.json', - phase='train', - frame_modulo=4, - seq_len=12, - **kwargs): - self.phase_str = 'test' if phase in ('eval', 'test') else 'train' + def __init__( + self, + out_size=(256, 384), + val_size=10, + n_images_file=None, + seq_per_vid_val=1, + register_file="ucfsports_register.json", + phase="train", + frame_modulo=4, + seq_len=12, + **kwargs, + ): + self.phase_str = "test" if phase in ("eval", "test") else "train" self.register_file = self.phase_str + "_" + register_file self.register = None - super().__init__(out_size=out_size, val_size=val_size, - n_images_file=n_images_file, - seq_per_vid_val=seq_per_vid_val, - x_val_seed=27, target_size=out_size, - frame_modulo=frame_modulo, phase=phase, - seq_len=seq_len, - **kwargs) - if phase in ('eval', 'test'): - self.target_size_dict = self.get_register()['vid_size_dict'] + super().__init__( + out_size=out_size, + val_size=val_size, + n_images_file=n_images_file, + seq_per_vid_val=seq_per_vid_val, + x_val_seed=27, + target_size=out_size, + frame_modulo=frame_modulo, + phase=phase, + seq_len=seq_len, + **kwargs, + ) + if phase in ("eval", "test"): + self.target_size_dict = self.get_register()["vid_size_dict"] @property def n_images_dict(self): if self._n_images_dict is None: - self._n_images_dict = self.get_register()['n_images_dict'] - self._n_images_dict = {vid_nr: ni for vid_nr, ni - in self._n_images_dict.items() - if vid_nr in self.vid_nr_array} + self._n_images_dict = self.get_register()["n_images_dict"] + self._n_images_dict = { + vid_nr: ni + for vid_nr, ni in self._n_images_dict.items() + if vid_nr in self.vid_nr_array + } return self._n_images_dict def get_register(self): if self.register is None: 
register_file = config_path / self.register_file if register_file.exists(): - with open(config_path / register_file, 'r') as f: + with open(config_path / register_file, "r") as f: self.register = json.load(f) - for reg_key in ('n_images_dict', 'vid_name_dict', - 'vid_size_dict'): + for reg_key in ("n_images_dict", "vid_name_dict", "vid_size_dict"): self.register[reg_key] = { - int(key): val for key, val in - self.register[reg_key].items()} + int(key): val for key, val in self.register[reg_key].items() + } else: self.register = self.generate_register() - with open(config_path / register_file, 'w') as f: + with open(config_path / register_file, "w") as f: json.dump(self.register, f, indent=2) return self.register @@ -1003,49 +1098,50 @@ def generate_register(self): vid_name_dict = {} vid_size_dict = {} - for vid_idx, folder in enumerate(sorted(self.dir.glob('*-*'))): + for vid_idx, folder in enumerate(sorted(self.dir.glob("*-*"))): vid_nr = vid_idx + 1 vid_name_dict[vid_nr] = folder.stem - image_files = list((folder / 'images').glob('*.png')) + image_files = list((folder / "images").glob("*.png")) n_images_dict[vid_nr] = len(image_files) img = cv2.imread(str(image_files[0])) vid_size_dict[vid_nr] = tuple(img.shape[:2]) return dict( - vid_name_dict=vid_name_dict, n_images_dict=n_images_dict, - vid_size_dict=vid_size_dict) + vid_name_dict=vid_name_dict, + n_images_dict=n_images_dict, + vid_size_dict=vid_size_dict, + ) def preprocess_sequence(self, frame_seq, dkey, vid_nr): - transformations = [ - transforms.ToPILImage() - ] + transformations = [transforms.ToPILImage()] - vid_size = self.register['vid_size_dict'][vid_nr] - interpolation = PIL.Image.LANCZOS if dkey in ('frame', 'sal')\ - else PIL.Image.NEAREST + vid_size = self.register["vid_size_dict"][vid_nr] + interpolation = ( + PIL.Image.LANCZOS if dkey in ("frame", "sal") else PIL.Image.NEAREST + ) out_size_ratio = self.out_size[1] / self.out_size[0] this_size_ratio = vid_size[1] / vid_size[0] if this_size_ratio < out_size_ratio: size = (int(self.out_size[1] / this_size_ratio), self.out_size[1]) else: size = (self.out_size[0], int(self.out_size[0] * this_size_ratio)) - transformations.append( - transforms.Resize(size, interpolation=interpolation)) + transformations.append(transforms.Resize(size, interpolation=interpolation)) transformations += [ transforms.CenterCrop(self.out_size), transforms.ToTensor(), ] - if dkey == 'frame' and 'rgb_mean' in self.preproc_cfg: + if dkey == "frame" and "rgb_mean" in self.preproc_cfg: transformations.append( transforms.Normalize( - self.preproc_cfg['rgb_mean'], self.preproc_cfg['rgb_std'])) - elif dkey == 'sal': + self.preproc_cfg["rgb_mean"], self.preproc_cfg["rgb_std"] + ) + ) + elif dkey == "sal": transformations.append(transforms.Lambda(utils.normalize_tensor)) - elif dkey == 'fix': - transformations.append( - transforms.Lambda(lambda fix: torch.gt(fix, 0.5))) + elif dkey == "fix": + transformations.append(transforms.Lambda(lambda fix: torch.gt(fix, 0.5))) processing = transforms.Compose(transformations) @@ -1056,30 +1152,33 @@ def preprocess_sequence(self, frame_seq, dkey, vid_nr): preprocess_sequence_eval = HollywoodDataset.preprocess_sequence_eval def get_annotation_dir(self, vid_nr): - vid_name = self.register['vid_name_dict'][vid_nr] + vid_name = self.register["vid_name_dict"][vid_nr] return self.dir / vid_name def get_data_file(self, vid_nr, f_nr, dkey): - if dkey == 'frame': - folder = 'images' - elif dkey == 'sal': - folder = 'maps' - elif dkey == 'fix': - folder = 'fixation' + if dkey == 
"frame": + folder = "images" + elif dkey == "sal": + folder = "maps" + elif dkey == "fix": + folder = "fixation" else: - raise ValueError(f'Unknown data key {dkey}') - vid_name = self.register['vid_name_dict'][vid_nr] - return self.get_annotation_dir(vid_nr) / folder /\ - f"{vid_name[:-4]}_{vid_name[-3:]}_{f_nr:03d}.png" + raise ValueError(f"Unknown data key {dkey}") + vid_name = self.register["vid_name_dict"][vid_nr] + return ( + self.get_annotation_dir(vid_nr) + / folder + / f"{vid_name[:-4]}_{vid_name[-3:]}_{f_nr:03d}.png" + ) get_seq = HollywoodDataset.get_seq @property def dir(self): if self._dir is None: - self._dir = Path(os.environ["UCFSPORTS_DATA_DIR"]) /\ - ('training' if self.phase in ('train', 'valid') - else 'testing') + self._dir = Path(os.environ["UCFSPORTS_DATA_DIR"]) / ( + "training" if self.phase in ("train", "valid") else "testing" + ) return self._dir @@ -1109,13 +1208,14 @@ def __init__(self, images_path, frame_modulo=None, source=None): self.images_path = images_path self.frame_modulo = frame_modulo or 5 self.preproc_cfg = { - 'rgb_mean': (0.485, 0.456, 0.406), - 'rgb_std': (0.229, 0.224, 0.225), + "rgb_mean": (0.485, 0.456, 0.406), + "rgb_std": (0.229, 0.224, 0.225), } frame_files = sorted(list(images_path.glob("*"))) - frame_files = [file for file in frame_files - if file.suffix in ('.png', '.jpg', '.jpeg')] + frame_files = [ + file for file in frame_files if file.suffix in (".png", ".jpg", ".jpeg") + ] self.frame_files = frame_files self.vid_nr_array = [0] self.n_images_dict = {0: len(frame_files)} @@ -1124,13 +1224,13 @@ def __init__(self, images_path, frame_modulo=None, source=None): img_size = tuple(img.shape[:2]) self.target_size_dict = {0: img_size} - if source == 'DHF1K' and img_size == (360, 640): + if source == "DHF1K" and img_size == (360, 640): self.out_size = (224, 384) - elif source == 'Hollywood': + elif source == "Hollywood": self.out_size = (224, 416) - elif source == 'UCFSports': + elif source == "UCFSports": self.out_size = (256, 384) else: @@ -1147,13 +1247,16 @@ def load_frame(self, f_nr): def preprocess_sequence(self, frame_seq): transformations = [] transformations.append(transforms.ToPILImage()) - transformations.append(transforms.Resize( - self.out_size, interpolation=PIL.Image.LANCZOS)) + transformations.append( + transforms.Resize(self.out_size, interpolation=PIL.Image.LANCZOS) + ) transformations.append(transforms.ToTensor()) - if 'rgb_mean' in self.preproc_cfg: + if "rgb_mean" in self.preproc_cfg: transformations.append( transforms.Normalize( - self.preproc_cfg['rgb_mean'], self.preproc_cfg['rgb_std'])) + self.preproc_cfg["rgb_mean"], self.preproc_cfg["rgb_std"] + ) + ) processing = transforms.Compose(transformations) tensor = [processing(img) for img in frame_seq] tensor = torch.stack(tensor) @@ -1180,16 +1283,16 @@ def __init__(self, images_path): self.images_path = images_path self.frame_modulo = 1 self.preproc_cfg = { - 'rgb_mean': (0.485, 0.456, 0.406), - 'rgb_std': (0.229, 0.224, 0.225), + "rgb_mean": (0.485, 0.456, 0.406), + "rgb_std": (0.229, 0.224, 0.225), } image_files = sorted(list(images_path.glob("*"))) - image_files = [file for file in image_files - if file.suffix in ('.png', '.jpg', '.jpeg')] + image_files = [ + file for file in image_files if file.suffix in (".png", ".jpg", ".jpeg") + ] self.image_files = image_files - self.n_images_dict = { - img_idx: 1 for img_idx in range(len(self.image_files))} + self.n_images_dict = {img_idx: 1 for img_idx in range(len(self.image_files))} self.target_size_dict = {} self.out_size_dict 
= {}
@@ -1210,14 +1313,15 @@ def load_image(self, img_idx):
     def preprocess(self, img, out_size):
         transformations = [
             transforms.ToPILImage(),
-            transforms.Resize(
-                out_size, interpolation=PIL.Image.LANCZOS),
+            transforms.Resize(out_size, interpolation=PIL.Image.LANCZOS),
             transforms.ToTensor(),
         ]
-        if 'rgb_mean' in self.preproc_cfg:
+        if "rgb_mean" in self.preproc_cfg:
             transformations.append(
                 transforms.Normalize(
-                    self.preproc_cfg['rgb_mean'], self.preproc_cfg['rgb_std']))
+                    self.preproc_cfg["rgb_mean"], self.preproc_cfg["rgb_std"]
+                )
+            )
         processing = transforms.Compose(transformations)
         tensor = processing(img)
         return tensor
@@ -1225,7 +1329,7 @@ def preprocess(self, img, out_size):
     def get_data(self, img_idx):
         file = self.image_files[img_idx]
         img = cv2.imread(str(file))
-        assert (img is not None)
+        assert img is not None
         img = np.ascontiguousarray(img[:, :, ::-1])
         out_size = self.out_size_dict[img_idx]
         img = self.preprocess(img, out_size)
@@ -1236,3 +1340,45 @@ def __len__(self):
 
     def __getitem__(self, item):
         return self.get_data(item, 0)
+
+
+def im_preprocess(
+    img_rgb: np.ndarray,
+    preproc_cfg: dict = {
+        "rgb_mean": (0.485, 0.456, 0.406),
+        "rgb_std": (0.229, 0.224, 0.225),
+    },
+):
+    """Preprocess an image before inference and return a tensor of shape [c, h, w].
+
+    Args:
+        img_rgb: image as a numpy array (e.g. cv2.imread(im_path)[:, :, ::-1])
+    """
+
+    def preprocess(img, out_size):
+        transformations = [
+            transforms.ToPILImage(),
+            transforms.Resize(out_size, interpolation=PIL.Image.LANCZOS),
+            transforms.ToTensor(),
+        ]
+        if "rgb_mean" in preproc_cfg:
+            transformations.append(
+                transforms.Normalize(preproc_cfg["rgb_mean"], preproc_cfg["rgb_std"])
+            )
+        processing = transforms.Compose(transformations)
+        tensor = processing(img)
+        return tensor
+
+    im = np.ascontiguousarray(img_rgb)
+    out_size = get_optimal_out_size(img_size=img_rgb.shape[:2])
+    return preprocess(im, out_size)
+
+
+def smap_postprocess(smap):
+    """Postprocess an output torch tensor into a numpy array.
+
+    Args:
+        smap: a slice of the output torch tensor (e.g. pred_seq[:, 0, ...])
+    """
+    smap = smap.exp()
+    smap = torch.squeeze(smap)
+    smap = utils.to_numpy(smap)
+    return (smap / np.amax(smap) * 255).astype(np.uint8)
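im_preprocess and smap_postprocess are self-contained, so they can be smoke-tested without a model. A minimal sketch (the random image and the log-softmaxed tensor are stand-ins for a real frame and a real network output; the model emits log-probabilities, hence the .exp() in smap_postprocess):

    import numpy as np
    import torch
    from unisal.data import im_preprocess, smap_postprocess

    img_rgb = (np.random.rand(360, 640, 3) * 255).astype(np.uint8)  # stand-in RGB frame
    tensor = im_preprocess(img_rgb=img_rgb)
    print(tensor.shape)  # [3, h, w], with (h, w) chosen by get_optimal_out_size

    logits = torch.randn(1, 1, 360, 640)
    log_probs = torch.log_softmax(logits.flatten(), dim=0).view_as(logits)
    smap = smap_postprocess(log_probs[:, 0, ...])
    print(smap.dtype, smap.shape, smap.max())  # uint8 (360, 640) 255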
diff --git a/unisal/demo.py b/unisal/demo.py
new file mode 100644
index 0000000..9e42e81
--- /dev/null
+++ b/unisal/demo.py
@@ -0,0 +1,37 @@
+from pathlib import Path
+import os
+from . import train
+from . import data
+
+TRAINER_ZOO = {}
+PROJECT_DIR = os.path.dirname(os.path.realpath(__file__))
+
+
+def load_trainer(train_id: str = "pretrained_unisal"):
+    """Instantiate Trainer class from saved kwargs."""
+    train_dir = Path(os.environ["TRAIN_DIR"])
+    train_dir = train_dir / train_id
+    print(f"Initializing trainer from {train_dir}...")
+    return train.Trainer.init_from_cfg_dir(train_dir)
+
+
+def get_trainer(model_path: str):
+    """Return the trainer from memory if already loaded, else load it."""
+    global TRAINER_ZOO
+    if model_path not in TRAINER_ZOO:
+        trainer = load_trainer()
+        trainer.model.load_weights_from_path(model_path)
+        TRAINER_ZOO[model_path] = trainer
+    return TRAINER_ZOO[model_path]
+
+
+def predict_image(
+    img_rgb,
+    model_path: str = os.path.join(
+        PROJECT_DIR, "../training_runs/pretrained_unisal/weights_best.pth"
+    ),
+):
+    trainer = get_trainer(model_path)
+    pred_seq = trainer.inference(img_rgb=img_rgb)
+    smap = data.smap_postprocess(pred_seq[:, 0, ...])
+    return smap
diff --git a/unisal/model.py b/unisal/model.py
index 3137f59..83f2d00 100644
--- a/unisal/model.py
+++ b/unisal/model.py
@@ -1,5 +1,5 @@
 from collections import OrderedDict
-import pprint
+import pprint, os
 from functools import partial
 from itertools import product
 
@@ -41,6 +41,15 @@ def load_best_weights(self, directory):
             torch.load(directory / f"weights_best.pth", map_location=DEFAULT_DEVICE)
         )
 
+    def load_weights_from_path(self, weights_path):
+        assert os.path.isfile(
+            weights_path
+        ), f"weights_path {weights_path} is not a valid file"
+        assert weights_path.endswith(
+            ".pth"
+        ), f"weights file must have a .pth extension, got {os.path.splitext(weights_path)[-1]}"
+        self.load_state_dict(torch.load(weights_path, map_location=DEFAULT_DEVICE))
+
     def load_epoch_checkpoint(self, directory, epoch):
         """Load state_dict from a Trainer checkpoint at a specific epoch"""
         chkpnt = torch.load(directory / f"chkpnt_epoch{epoch:04d}.pth")
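load_weights_from_path complements the directory-based loaders by validating an explicit checkpoint path, and demo.get_trainer builds a small per-path cache on top of it. A sketch (the checkpoint path is a placeholder; TRAIN_DIR must be set so load_trainer can find the saved config):

    import unisal

    weights = "training_runs/pretrained_unisal/weights_best.pth"  # placeholder path

    trainer = unisal.demo.load_trainer("pretrained_unisal")
    trainer.model.load_weights_from_path(weights)  # fails fast on a bad path or extension

    # get_trainer caches by checkpoint path, so repeated calls reuse one instance
    assert unisal.demo.get_trainer(weights) is unisal.demo.get_trainer(weights)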
diff --git a/unisal/train.py b/unisal/train.py
index 506ec6a..2f2e70c 100644
--- a/unisal/train.py
+++ b/unisal/train.py
@@ -16,7 +16,6 @@
 import torch
 import torch.nn.functional as F
 import cv2
-from tensorboardX import SummaryWriter
 import numpy as np
 
 from . import salience_metrics
@@ -686,6 +685,98 @@ def other_maps():
         if return_predictions:
             return pred_seq
 
+    def inference(
+        self,
+        img_rgb: np.ndarray,
+        source: str = "SALICON",  # images: SALICON, MIT1003, MIT300; videos: DHF1K, UCFSports, Hollywood
+        smooth_method=None,
+        seq_len_factor=0.5,
+        random_seed=27,
+    ):
+        """Run inference on a numpy array and return the predicted saliency."""
+
+        if random_seed is not None:
+            random.seed(random_seed)
+
+        # Get the original resolution (h, w)
+        target_size = img_rgb.shape[:2]
+
+        # Set the keyword arguments for the forward pass
+        model_kwargs = {"source": source, "target_size": target_size}
+
+        # Make sure that the model was trained on the selected domain
+        if model_kwargs["source"] not in self.model.sources:
+            print(
+                f"\nWarning! Evaluation bn source {model_kwargs['source']} "
+                f"doesn't exist in model.\n Using {self.model.sources[0]}."
+            )
+            model_kwargs["source"] = self.model.sources[0]
+
+        # Select static (image) or dynamic (video) forward pass for Bypass-RNN
+        model_kwargs.update(
+            {"static": model_kwargs["source"] in ("SALICON", "MIT300", "MIT1003")}
+        )
+
+        # Set additional parameters. Only static (image) sources are supported
+        # for now; video sources would need a dataset for frame bookkeeping.
+        static_data = source in ("SALICON", "MIT300", "MIT1003")
+        assert static_data, "trainer.inference currently only supports static data"
+        smooth_method = None  # no temporal smoothing for a single image
+        n_images = 1
+        frame_modulo = 1
+
+        # Prepare the model
+        self.model.to(self.device)
+        self.model.eval()
+        torch.cuda.empty_cache()
+
+        # Prepare the prediction tensor
+        results_size = (1, n_images, 1, *model_kwargs["target_size"])
+        pred_seq = torch.full(results_size, 0, dtype=torch.float)
+
+        # Define the input sequence length (12 is the training seq_len default)
+        seq_len = int(12 * seq_len_factor)
+
+        # Iterate over different offsets to create the interleaved predictions
+        for offset in range(min(frame_modulo, n_images)):
+
+            # Get and preprocess the data
+            frame_nrs, frame_seq = [1], data.im_preprocess(img_rgb=img_rgb)
+            if frame_seq.dim() == 3:
+                frame_seq = frame_seq.unsqueeze(0)
+            frame_seq = frame_seq.unsqueeze(0).float()
+            frame_idx_array = [f_nr - 1 for f_nr in frame_nrs]
+            frame_seq = frame_seq.to(self.device)
+
+            # Run all sequences of the current offset
+            h0 = [None]
+            for start in range(0, len(frame_idx_array), seq_len):
+
+                # Select the frames
+                end = min(len(frame_idx_array), start + seq_len)
+                this_frame_seq = frame_seq[:, start:end, :, :, :]
+                this_frame_idx_array = frame_idx_array[start:end]
+
+                # Forward pass
+                this_pred_seq, h0 = self.model(
+                    this_frame_seq, h0=h0, return_hidden=True, **model_kwargs
+                )
+
+                # Insert the predictions into the prediction array
+                this_pred_seq = this_pred_seq.cpu()
+                pred_seq[:, this_frame_idx_array, :, :, :] = this_pred_seq
+
+        # Assert non-empty predictions
+        assert torch.min(pred_seq.exp().sum(-1).sum(-1)) > 0
+
+        # Optionally smooth the interleaved sequences
+        if smooth_method is not None:
+            pred_seq = pred_seq.numpy()
+            pred_seq = utils.smooth_sequence(pred_seq, smooth_method)
+            pred_seq = torch.from_numpy(pred_seq).float()
+
+        return pred_seq
+
     @staticmethod
     def eval_sequences(
         pred_seq, sal_seq, fix_seq, metrics, other_maps=None, auc_portion=1.0
     ):
@@ -1353,6 +1444,8 @@ def add_scalar(self, key, value, epoch, this_tboard=True):
     @property
     def writer(self):
         """Return TensorboardX writer"""
+        from tensorboardX import SummaryWriter
+
         if self.tboard and self._writer is None:
             if self.data_sources == ("MIT1003",):
                 log_dir = self.mit1003_dir
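End to end, the inference path added here goes numpy image → Trainer.inference → data.smap_postprocess, which is what demo.predict_image wraps. A sketch of both routes (image and weights paths are placeholders; note also that tensorboardX is now imported lazily inside the writer property, so it is only required when training with tboard enabled):

    import numpy as np
    from PIL import Image
    import unisal

    img_rgb = np.array(Image.open("examples/some_image.jpg").convert("RGB"))

    # one-liner via the demo module (uses its default pretrained weights path)
    smap = unisal.demo.predict_image(img_rgb)

    # the same steps spelled out
    trainer = unisal.demo.get_trainer(
        "training_runs/pretrained_unisal/weights_best.pth"  # placeholder
    )
    pred_seq = trainer.inference(img_rgb=img_rgb, source="SALICON")
    smap2 = unisal.data.smap_postprocess(pred_seq[:, 0, ...])

    Image.fromarray(smap).save("saliency.png")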