From bb6573176c5dc3435772f37d388271689fba2478 Mon Sep 17 00:00:00 2001 From: John Ho Date: Tue, 16 Jul 2024 08:00:25 -0400 Subject: [PATCH 1/6] enabled CPU inference --- unisal/model.py | 351 +++++++++------- unisal/models/MobileNetV2.py | 100 +++-- unisal/models/cgru.py | 165 +++++--- unisal/train.py | 785 +++++++++++++++++++---------------- 4 files changed, 825 insertions(+), 576 deletions(-) diff --git a/unisal/model.py b/unisal/model.py index bcb0244..3137f59 100644 --- a/unisal/model.py +++ b/unisal/model.py @@ -11,6 +11,11 @@ from .models.cgru import ConvGRU from .models.MobileNetV2 import MobileNetV2, InvertedResidual +DEFAULT_DEVICE = ( + torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +) +print(f"torch device set to: {DEFAULT_DEVICE}") + def get_model(): """Return the model class""" @@ -24,41 +29,50 @@ def forward(self, *input): raise NotImplementedError def save_weights(self, directory, name): - torch.save(self.state_dict(), directory / f'weights_{name}.pth') + torch.save(self.state_dict(), directory / f"weights_{name}.pth") def load_weights(self, directory, name): - self.load_state_dict(torch.load(directory / f'weights_{name}.pth')) + self.load_state_dict( + torch.load(directory / f"weights_{name}.pth", map_location=DEFAULT_DEVICE) + ) def load_best_weights(self, directory): - self.load_state_dict(torch.load(directory / f'weights_best.pth')) + self.load_state_dict( + torch.load(directory / f"weights_best.pth", map_location=DEFAULT_DEVICE) + ) def load_epoch_checkpoint(self, directory, epoch): """Load state_dict from a Trainer checkpoint at a specific epoch""" chkpnt = torch.load(directory / f"chkpnt_epoch{epoch:04d}.pth") - self.load_state_dict(chkpnt['model_state_dict']) + self.load_state_dict(chkpnt["model_state_dict"]) def load_checkpoint(self, file): """Load state_dict from a specific Trainer checkpoint""" """Load """ chkpnt = torch.load(file) - self.load_state_dict(chkpnt['model_state_dict']) + self.load_state_dict(chkpnt["model_state_dict"]) def load_last_chkpnt(self, directory): """Load state_dict from the last Trainer checkpoint""" - last_chkpnt = sorted(list(directory.glob('chkpnt_epoch*.pth')))[-1] + last_chkpnt = sorted(list(directory.glob("chkpnt_epoch*.pth")))[-1] self.load_checkpoint(last_chkpnt) # Set default backbone CNN kwargs default_cnn_cfg = { - 'widen_factor': 1., 'pretrained': True, 'input_channel': 32, - 'last_channel': 1280} + "widen_factor": 1.0, + "pretrained": True, + "input_channel": 32, + "last_channel": 1280, +} # Set default RNN kwargs default_rnn_cfg = { - 'kernel_size': (3, 3), 'gate_ksize': (3, 3), - 'dropout': (False, True, False), 'drop_prob': (0.2, 0.2, 0.2), - 'mobile': True, + "kernel_size": (3, 3), + "gate_ksize": (3, 3), + "dropout": (False, True, False), + "drop_prob": (0.2, 0.2, 0.2), + "mobile": True, } @@ -72,11 +86,11 @@ class DomainBatchNorm2d(nn.Module): def __init__(self, num_features, sources, momenta=None, **kwargs): """ - num_features: Number of channels - sources: List of sources - momenta: List of BatchNorm momenta corresponding to the sources. - Default is 0.1 for each source. - kwargs: Other BatchNorm kwargs + num_features: Number of channels + sources: List of sources + momenta: List of BatchNorm momenta corresponding to the sources. + Default is 0.1 for each source. 
+ kwargs: Other BatchNorm kwargs """ super().__init__() self.sources = sources @@ -85,13 +99,14 @@ def __init__(self, num_features, sources, momenta=None, **kwargs): if momenta is None: momenta = [0.1] * len(sources) self.momenta = momenta - if 'momentum' in kwargs: - del kwargs['momentum'] + if "momentum" in kwargs: + del kwargs["momentum"] # Instantiate the BN modules for src, mnt in zip(sources, self.momenta): - self.__setattr__(f"bn_{src}", nn.BatchNorm2d( - num_features, momentum=mnt, **kwargs)) + self.__setattr__( + f"bn_{src}", nn.BatchNorm2d(num_features, momentum=mnt, **kwargs) + ) # Prepare the self.this_source attribute that will be updated at runtime # by the model @@ -131,35 +146,37 @@ class UNISAL(BaseModel, utils.KwConfigClass): verbose: Verbosity level. """ - def __init__(self, - rnn_input_channels=256, rnn_hidden_channels=256, - cnn_cfg=None, - rnn_cfg=None, - res_rnn=True, - bypass_rnn=True, - drop_probs=(0.0, 0.6, 0.6), - gaussian_init='manual', - n_gaussians=16, - smoothing_ksize=41, - bn_momentum=0.01, - static_bn_momentum=0.1, - sources=('DHF1K', 'Hollywood', 'UCFSports', 'SALICON'), - ds_bn=True, - ds_adaptation=True, - ds_smoothing=True, - ds_gaussians=True, - verbose=1, - ): + def __init__( + self, + rnn_input_channels=256, + rnn_hidden_channels=256, + cnn_cfg=None, + rnn_cfg=None, + res_rnn=True, + bypass_rnn=True, + drop_probs=(0.0, 0.6, 0.6), + gaussian_init="manual", + n_gaussians=16, + smoothing_ksize=41, + bn_momentum=0.01, + static_bn_momentum=0.1, + sources=("DHF1K", "Hollywood", "UCFSports", "SALICON"), + ds_bn=True, + ds_adaptation=True, + ds_smoothing=True, + ds_gaussians=True, + verbose=1, + ): super().__init__() # Check inputs - assert(gaussian_init in ('random', 'manual')) + assert gaussian_init in ("random", "manual") # Bypass-RNN requires residual RNN connection if bypass_rnn: assert res_rnn # Manual Gaussian initialization generates 16 Gaussians - if n_gaussians > 0 and gaussian_init == 'manual': + if n_gaussians > 0 and gaussian_init == "manual": n_gaussians = 16 self.rnn_input_channels = rnn_input_channels @@ -190,64 +207,98 @@ def __init__(self, # Initialize Post-CNN module with optional dropout post_cnn = [ - ('inv_res', InvertedResidual( - self.cnn.out_channels + n_gaussians, - rnn_input_channels, 1, 1, bn_momentum=bn_momentum, - )) + ( + "inv_res", + InvertedResidual( + self.cnn.out_channels + n_gaussians, + rnn_input_channels, + 1, + 1, + bn_momentum=bn_momentum, + ), + ) ] if self.drop_probs[0] > 0: - post_cnn.insert(0, ( - 'dropout', nn.Dropout2d(self.drop_probs[0], inplace=False) - )) + post_cnn.insert( + 0, ("dropout", nn.Dropout2d(self.drop_probs[0], inplace=False)) + ) self.post_cnn = nn.Sequential(OrderedDict(post_cnn)) # Initialize Bypass-RNN if training on dynamic data - if sources != ('SALICON',) or not self.bypass_rnn: + if sources != ("SALICON",) or not self.bypass_rnn: self.rnn = ConvGRU( rnn_input_channels, hidden_channels=[rnn_hidden_channels], batchnorm=self.get_bn_module, - **self.rnn_cfg) - self.post_rnn = self.conv_1x1_bn( - rnn_hidden_channels, rnn_input_channels) + **self.rnn_cfg, + ) + self.post_rnn = self.conv_1x1_bn(rnn_hidden_channels, rnn_input_channels) # Initialize first upsampling module US1 - self.upsampling_1 = nn.Sequential(OrderedDict([ - ('us1', self.upsampling(2)), - ])) + self.upsampling_1 = nn.Sequential( + OrderedDict( + [ + ("us1", self.upsampling(2)), + ] + ) + ) # Number of channels at the 2x scale channels_2x = 128 # Initialize Skip-2x module self.skip_2x = self.make_skip_connection( - 
self.cnn.feat_2x_channels, channels_2x, 2, self.drop_probs[1]) + self.cnn.feat_2x_channels, channels_2x, 2, self.drop_probs[1] + ) # Initialize second upsampling module US2 - self.upsampling_2 = nn.Sequential(OrderedDict([ - ('inv_res', InvertedResidual( - rnn_input_channels + channels_2x, - channels_2x, 1, 2, batchnorm=self.get_bn_module)), - ('us2', self.upsampling(2)), - ])) + self.upsampling_2 = nn.Sequential( + OrderedDict( + [ + ( + "inv_res", + InvertedResidual( + rnn_input_channels + channels_2x, + channels_2x, + 1, + 2, + batchnorm=self.get_bn_module, + ), + ), + ("us2", self.upsampling(2)), + ] + ) + ) # Number of channels at the 4x scale channels_4x = 64 # Initialize Skip-4x module self.skip_4x = self.make_skip_connection( - self.cnn.feat_4x_channels, channels_4x, 2, self.drop_probs[2]) + self.cnn.feat_4x_channels, channels_4x, 2, self.drop_probs[2] + ) # Initialize Post-US2 module - self.post_upsampling_2= nn.Sequential(OrderedDict([ - ('inv_res', InvertedResidual( - channels_2x + channels_4x, channels_4x, 1, 2, - batchnorm=self.get_bn_module)), - ])) + self.post_upsampling_2 = nn.Sequential( + OrderedDict( + [ + ( + "inv_res", + InvertedResidual( + channels_2x + channels_4x, + channels_4x, + 1, + 2, + batchnorm=self.get_bn_module, + ), + ), + ] + ) + ) # Initialize domain-specific modules for source_str in self.sources: - source_str = f'_{source_str}'.lower() + source_str = f"_{source_str}".lower() # Initialize learned Gaussian priors parameters if n_gaussians > 0: @@ -255,24 +306,23 @@ def __init__(self, # Initialize Adaptation self.__setattr__( - 'adaptation' + (source_str if self.ds_adaptation else ''), - nn.Sequential(*[ - nn.Conv2d(channels_4x, 1, 1, bias=True) - ])) + "adaptation" + (source_str if self.ds_adaptation else ""), + nn.Sequential(*[nn.Conv2d(channels_4x, 1, 1, bias=True)]), + ) # Initialize Smoothing smoothing = nn.Conv2d( - 1, 1, kernel_size=smoothing_ksize, padding=0, bias=False) + 1, 1, kernel_size=smoothing_ksize, padding=0, bias=False + ) with torch.no_grad(): gaussian = self._make_gaussian_maps( - smoothing.weight.data, - torch.Tensor([[[0.5, -2]] * 2]) + smoothing.weight.data, torch.Tensor([[[0.5, -2]] * 2]) ) gaussian /= gaussian.sum() smoothing.weight.data = gaussian self.__setattr__( - 'smoothing' + (source_str if self.ds_smoothing else ''), - smoothing) + "smoothing" + (source_str if self.ds_smoothing else ""), smoothing + ) if self.verbose > 1: pprint.pprint(self.asdict(), width=1) @@ -292,52 +342,58 @@ def this_source(self, source): def get_bn_module(self, num_features, **kwargs): """Return BatchNorm class (domain-specific or domain-invariant).""" - momenta = [self.bn_momentum if src != 'SALICON' - else self.static_bn_momentum for src in self.sources] + momenta = [ + self.bn_momentum if src != "SALICON" else self.static_bn_momentum + for src in self.sources + ] if self.ds_bn: return DomainBatchNorm2d( - num_features, self.sources, momenta=momenta, **kwargs) + num_features, self.sources, momenta=momenta, **kwargs + ) else: return nn.BatchNorm2d(num_features, **kwargs) # @staticmethod def upsampling(self, factor): """Return upsampling module.""" - return nn.Sequential(*[ - nn.Upsample( - scale_factor=factor, mode='bilinear', align_corners=False), - ]) + return nn.Sequential( + *[ + nn.Upsample(scale_factor=factor, mode="bilinear", align_corners=False), + ] + ) - def set_gaussians(self, source_str, prefix='coarse_'): + def set_gaussians(self, source_str, prefix="coarse_"): """Set Gaussian parameters.""" - suffix = source_str if self.ds_gaussians 
else '' + suffix = source_str if self.ds_gaussians else "" self.__setattr__( - prefix + 'gaussians' + suffix, - self._initialize_gaussians(self.n_gaussians)) + prefix + "gaussians" + suffix, self._initialize_gaussians(self.n_gaussians) + ) def _initialize_gaussians(self, n_gaussians): """ Return initialized Gaussian parameters. Dimensions: [idx, y/x, mu/logstd]. """ - if self.gaussian_init == 'manual': - gaussians = torch.Tensor([ - list(product([0.25, 0.5, 0.75], repeat=2)) + - [(0.5, 0.25), (0.5, 0.5), (0.5, 0.75)] + - [(0.25, 0.5), (0.5, 0.5), (0.75, 0.5)] + - [(0.5, 0.5)], - [(-1.5, -1.5)] * 9 + [(0, -1.5)] * 3 + [(-1.5, 0)] * 3 + - [(0, 0)], - ]).permute(1, 2, 0) - - elif self.gaussian_init == 'random': + if self.gaussian_init == "manual": + gaussians = torch.Tensor( + [ + list(product([0.25, 0.5, 0.75], repeat=2)) + + [(0.5, 0.25), (0.5, 0.5), (0.5, 0.75)] + + [(0.25, 0.5), (0.5, 0.5), (0.75, 0.5)] + + [(0.5, 0.5)], + [(-1.5, -1.5)] * 9 + [(0, -1.5)] * 3 + [(-1.5, 0)] * 3 + [(0, 0)], + ] + ).permute(1, 2, 0) + + elif self.gaussian_init == "random": with torch.no_grad(): - gaussians = torch.stack([ - torch.randn( - n_gaussians, 2, dtype=torch.float) * .1 + 0.5, - torch.randn( - n_gaussians, 2, dtype=torch.float) * .2 - 1,], - dim=2) + gaussians = torch.stack( + [ + torch.randn(n_gaussians, 2, dtype=torch.float) * 0.1 + 0.5, + torch.randn(n_gaussians, 2, dtype=torch.float) * 0.2 - 1, + ], + dim=2, + ) else: raise NotImplementedError @@ -346,7 +402,7 @@ def _initialize_gaussians(self, n_gaussians): return gaussians @staticmethod - def _make_gaussian_maps(x, gaussians, size=None, scaling=6.): + def _make_gaussian_maps(x, gaussians, size=None, scaling=6.0): """Construct prior maps from Gaussian parameters.""" if size is None: size = x.shape[-2:] @@ -360,15 +416,18 @@ def _make_gaussian_maps(x, gaussians, size=None, scaling=6.): gaussian_maps = [] map_template = torch.ones(*size, dtype=dtype, device=device) meshgrids = torch.meshgrid( - [torch.linspace(0, 1, size[0], dtype=dtype, device=device), - torch.linspace(0, 1, size[1], dtype=dtype, device=device),]) + [ + torch.linspace(0, 1, size[0], dtype=dtype, device=device), + torch.linspace(0, 1, size[1], dtype=dtype, device=device), + ] + ) for gaussian_idx, yx_mu_logstd in enumerate(torch.unbind(gaussians)): map = map_template.clone() for mu_logstd, mgrid in zip(yx_mu_logstd, meshgrids): mu = mu_logstd[0] std = torch.exp(mu_logstd[1]) - map *= torch.exp(-((mgrid - mu) / std) ** 2 / 2) + map *= torch.exp(-(((mgrid - mu) / std) ** 2) / 2) map *= scaling gaussian_maps.append(map) @@ -377,27 +436,36 @@ def _make_gaussian_maps(x, gaussians, size=None, scaling=6.): gaussian_maps = gaussian_maps.unsqueeze(0).expand(bs, -1, -1, -1) return gaussian_maps - def _get_gaussian_maps(self, x, source_str, prefix='coarse_', **kwargs): + def _get_gaussian_maps(self, x, source_str, prefix="coarse_", **kwargs): """Return the constructed Gaussian prior maps.""" - suffix = source_str if self.ds_gaussians else '' + suffix = source_str if self.ds_gaussians else "" gaussians = self.__getattr__(prefix + "gaussians" + suffix) gaussian_maps = self._make_gaussian_maps(x, gaussians, **kwargs) return gaussian_maps # @classmethod - def make_skip_connection(self, input_channels, output_channels, expand_ratio, p, - inplace=False): + def make_skip_connection( + self, input_channels, output_channels, expand_ratio, p, inplace=False + ): """Return skip connection module.""" hidden_channels = round(input_channels * expand_ratio) - return nn.Sequential(OrderedDict([ - 
('expansion', self.conv_1x1_bn( - input_channels, hidden_channels)), - ('dropout', nn.Dropout2d(p, inplace=inplace)), - ('reduction', nn.Sequential(*[ - nn.Conv2d(hidden_channels, output_channels, 1), - self.get_bn_module(output_channels), - ])), - ])) + return nn.Sequential( + OrderedDict( + [ + ("expansion", self.conv_1x1_bn(input_channels, hidden_channels)), + ("dropout", nn.Dropout2d(p, inplace=inplace)), + ( + "reduction", + nn.Sequential( + *[ + nn.Conv2d(hidden_channels, output_channels, 1), + self.get_bn_module(output_channels), + ] + ), + ), + ] + ) + ) # @staticmethod def conv_1x1_bn(self, inp, oup): @@ -405,11 +473,18 @@ def conv_1x1_bn(self, inp, oup): return nn.Sequential( nn.Conv2d(inp, oup, 1, 1, 0, bias=False), self.get_bn_module(oup), - nn.ReLU6(inplace=True) + nn.ReLU6(inplace=True), ) - def forward(self, x, target_size=None, h0=None, return_hidden=False, - source='DHF1K', static=None): + def forward( + self, + x, + target_size=None, + h0=None, + return_hidden=False, + source="DHF1K", + static=None, + ): """ Forward pass. @@ -429,9 +504,9 @@ def forward(self, x, target_size=None, h0=None, return_hidden=False, self.this_source = source # Prepare other parameters - source_str = f'_{source.lower()}' + source_str = f"_{source.lower()}" if static is None: - static = x.shape[1] == 1 or self.sources == ('SALICON',) + static = x.shape[1] == 1 or self.sources == ("SALICON",) # Compute backbone CNN features and concatenate with Gaussian prior maps feat_seq_1x = [] @@ -461,8 +536,7 @@ def forward(self, x, target_size=None, h0=None, return_hidden=False, # Decoder output_seq = [] - for idx, im_feat in enumerate( - torch.unbind(feat_seq_1x, dim=1)): + for idx, im_feat in enumerate(torch.unbind(feat_seq_1x, dim=1)): if not (static and self.bypass_rnn): rnn_feat = rnn_feat_seq[:, idx, ...] 
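A note on the loader changes above: `load_weights` and `load_best_weights` now pass `map_location=DEFAULT_DEVICE`, but `load_epoch_checkpoint` and `load_checkpoint` still call `torch.load` without it, so a checkpoint saved on a GPU machine will still raise a deserialization error on a CPU-only host. Below is a minimal sketch of the device-safe pattern; the helper name is illustrative, not part of this patch.

import torch

# Resolve the device once, mirroring DEFAULT_DEVICE in unisal/model.py.
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


def load_checkpoint_cpu_safe(model, path, device=DEVICE):
    """Sketch: load a Trainer checkpoint on either CPU or GPU.

    map_location remaps tensors that were saved from CUDA memory onto
    `device`, so the checkpoint deserializes even without a GPU.
    """
    chkpnt = torch.load(path, map_location=device)
    model.load_state_dict(chkpnt["model_state_dict"])
    return model.to(device)

Passing a `torch.device` as `map_location` rewrites the storage location of every tensor at load time; without it, `torch.load` tries to restore each tensor to the device it was saved from.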
@@ -479,20 +553,19 @@ def forward(self, x, target_size=None, h0=None, return_hidden=False, im_feat = self.post_upsampling_2(im_feat) im_feat = self.__getattr__( - 'adaptation' + (source_str if self.ds_adaptation else ''))( - im_feat) + "adaptation" + (source_str if self.ds_adaptation else "") + )(im_feat) - im_feat = F.interpolate( - im_feat, size=x.shape[-2:], mode='nearest') + im_feat = F.interpolate(im_feat, size=x.shape[-2:], mode="nearest") - im_feat = F.pad(im_feat, [self.smoothing_ksize // 2] * 4, - mode='replicate') + im_feat = F.pad(im_feat, [self.smoothing_ksize // 2] * 4, mode="replicate") im_feat = self.__getattr__( - 'smoothing' + (source_str if self.ds_smoothing else ''))( - im_feat) + "smoothing" + (source_str if self.ds_smoothing else "") + )(im_feat) im_feat = F.interpolate( - im_feat, size=target_size, mode='bilinear', align_corners=False) + im_feat, size=target_size, mode="bilinear", align_corners=False + ) im_feat = utils.log_softmax(im_feat) output_seq.append(im_feat) diff --git a/unisal/models/MobileNetV2.py b/unisal/models/MobileNetV2.py index 4bc424d..8e87c16 100755 --- a/unisal/models/MobileNetV2.py +++ b/unisal/models/MobileNetV2.py @@ -11,7 +11,7 @@ def conv_bn(inp, oup, stride): return nn.Sequential( nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup), - nn.ReLU6(inplace=True) + nn.ReLU6(inplace=True), ) @@ -19,23 +19,32 @@ def conv_1x1_bn(inp, oup): return nn.Sequential( nn.Conv2d(inp, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup), - nn.ReLU6(inplace=True) + nn.ReLU6(inplace=True), ) class InvertedResidual(nn.Module): - def __init__(self, inp, oup, stride, expand_ratio, omit_stride=False, - no_res_connect=False, dropout=0., bn_momentum=0.1, - batchnorm=None): + def __init__( + self, + inp, + oup, + stride, + expand_ratio, + omit_stride=False, + no_res_connect=False, + dropout=0.0, + bn_momentum=0.1, + batchnorm=None, + ): super().__init__() self.out_channels = oup self.stride = stride self.omit_stride = omit_stride - self.use_res_connect = not no_res_connect and\ - self.stride == 1 and inp == oup + self.use_res_connect = not no_res_connect and self.stride == 1 and inp == oup self.dropout = dropout actual_stride = self.stride if not self.omit_stride else 1 if batchnorm is None: + def batchnorm(num_features): return nn.BatchNorm2d(num_features, momentum=bn_momentum) @@ -45,8 +54,15 @@ def batchnorm(num_features): if expand_ratio == 1: modules = [ # dw - nn.Conv2d(hidden_dim, hidden_dim, 3, actual_stride, 1, - groups=hidden_dim, bias=False), + nn.Conv2d( + hidden_dim, + hidden_dim, + 3, + actual_stride, + 1, + groups=hidden_dim, + bias=False, + ), batchnorm(hidden_dim), nn.ReLU6(inplace=True), # pw-linear @@ -63,8 +79,15 @@ def batchnorm(num_features): batchnorm(hidden_dim), nn.ReLU6(inplace=True), # dw - nn.Conv2d(hidden_dim, hidden_dim, 3, actual_stride, 1, - groups=hidden_dim, bias=False), + nn.Conv2d( + hidden_dim, + hidden_dim, + 3, + actual_stride, + 1, + groups=hidden_dim, + bias=False, + ), batchnorm(hidden_dim), nn.ReLU6(inplace=True), # pw-linear @@ -86,7 +109,7 @@ def _initialize_weights(self): for m in self.modules(): if isinstance(m, nn.Conv2d): n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + m.weight.data.normal_(0, math.sqrt(2.0 / n)) if m.bias is not None: m.bias.data.zero_() elif isinstance(m, nn.BatchNorm2d): @@ -99,8 +122,9 @@ def _initialize_weights(self): class MobileNetV2(nn.Module): - def __init__(self, widen_factor=1., pretrained=True, - last_channel=None, input_channel=32): + def __init__( + self, widen_factor=1.0, pretrained=True, last_channel=None, input_channel=32 + ): super().__init__() self.widen_factor = widen_factor self.pretrained = pretrained @@ -127,33 +151,45 @@ def __init__(self, widen_factor=1., pretrained=True, output_channel = int(c * widen_factor) for i in range(n): if i == 0: - self.features.append(block( - input_channel, output_channel, s, expand_ratio=t, - omit_stride=True)) + self.features.append( + block( + input_channel, + output_channel, + s, + expand_ratio=t, + omit_stride=True, + ) + ) else: - self.features.append(block( - input_channel, output_channel, 1, expand_ratio=t)) + self.features.append( + block(input_channel, output_channel, 1, expand_ratio=t) + ) input_channel = output_channel # building last several layers if self.last_channel is not None: - output_channel = int(self.last_channel * widen_factor)\ - if widen_factor > 1.0 else self.last_channel + output_channel = ( + int(self.last_channel * widen_factor) + if widen_factor > 1.0 + else self.last_channel + ) self.features.append(conv_1x1_bn(input_channel, output_channel)) # make it nn.Sequential self.features = nn.Sequential(*self.features) self.out_channels = output_channel - self.feat_1x_channels = int( - interverted_residual_setting[-1][1] * widen_factor) - self.feat_2x_channels = int( - interverted_residual_setting[-2][1] * widen_factor) - self.feat_4x_channels = int( - interverted_residual_setting[-4][1] * widen_factor) - self.feat_8x_channels = int( - interverted_residual_setting[-5][1] * widen_factor) + self.feat_1x_channels = int(interverted_residual_setting[-1][1] * widen_factor) + self.feat_2x_channels = int(interverted_residual_setting[-2][1] * widen_factor) + self.feat_4x_channels = int(interverted_residual_setting[-4][1] * widen_factor) + self.feat_8x_channels = int(interverted_residual_setting[-5][1] * widen_factor) if self.pretrained: state_dict = torch.load( - Path(__file__).resolve().parent / 'weights/mobilenet_v2.pth.tar') + Path(__file__).resolve().parent / "weights/mobilenet_v2.pth.tar", + map_location=( + torch.device("cuda:0") + if torch.cuda.is_available() + else torch.device("cpu") + ), + ) self.load_state_dict(state_dict, strict=False) else: self._initialize_weights() @@ -167,7 +203,7 @@ def forward(self, x): feat_4x = x.clone() elif idx == 14: feat_2x = x.clone() - if idx > 0 and hasattr(module, 'stride') and module.stride != 1: + if idx > 0 and hasattr(module, "stride") and module.stride != 1: x = x[..., ::2, ::2] return x, feat_2x, feat_4x @@ -176,7 +212,7 @@ def _initialize_weights(self): for m in self.modules(): if isinstance(m, nn.Conv2d): n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.data.normal_(0, math.sqrt(2. / n)) + m.weight.data.normal_(0, math.sqrt(2.0 / n)) if m.bias is not None: m.bias.data.zero_() elif isinstance(m, nn.BatchNorm2d): diff --git a/unisal/models/cgru.py b/unisal/models/cgru.py index bb2f976..128dd26 100644 --- a/unisal/models/cgru.py +++ b/unisal/models/cgru.py @@ -39,11 +39,25 @@ class ConvGRUCell(nn.Module): mobile: If True, MobileNet-style convolutions are used. 
""" - def __init__(self, input_ch, hidden_ch, kernel_size, gate_ksize=(1, 1), - bias=True, norm='', norm_momentum=0.1, affine_norm=True, - batchnorm=None, gain=1, drop_prob=(0., 0., 0.), - do_mode='recurrent', r_bias=0., z_bias=0., mobile=False, - **kwargs): + def __init__( + self, + input_ch, + hidden_ch, + kernel_size, + gate_ksize=(1, 1), + bias=True, + norm="", + norm_momentum=0.1, + affine_norm=True, + batchnorm=None, + gain=1, + drop_prob=(0.0, 0.0, 0.0), + do_mode="recurrent", + r_bias=0.0, + z_bias=0.0, + mobile=False, + **kwargs + ): super().__init__() self.input_ch = input_ch @@ -51,7 +65,7 @@ def __init__(self, input_ch, hidden_ch, kernel_size, gate_ksize=(1, 1), self.kernel_size = kernel_size self.gate_ksize = gate_ksize self.mobile = mobile - self.kwargs = {'init': 'xavier_uniform_'} + self.kwargs = {"init": "xavier_uniform_"} self.kwargs.update(kwargs) # Process normalization arguments @@ -61,11 +75,13 @@ def __init__(self, input_ch, hidden_ch, kernel_size, gate_ksize=(1, 1), self.batchnorm = batchnorm self.norm_kwargs = None if self.batchnorm is not None: - self.norm = 'batch' + self.norm = "batch" elif self.norm: self.norm_kwargs = { - 'affine': self.affine_norm, 'track_running_stats': True, - 'momentum': self.norm_momentum} + "affine": self.affine_norm, + "track_running_stats": True, + "momentum": self.norm_momentum, + } # Prepare normalization modules if self.norm: @@ -79,12 +95,12 @@ def __init__(self, input_ch, hidden_ch, kernel_size, gate_ksize=(1, 1), # Prepare dropout self.drop_prob = drop_prob self.do_mode = do_mode - if self.do_mode == 'recurrent': + if self.do_mode == "recurrent": # Prepare dropout masks if using recurrent dropout for idx, mask in self.yield_drop_masks(): self.register_buffer(self.mask_name(idx), mask) - elif self.do_mode != 'naive': - raise ValueError('Unknown dropout mode ', self.do_mode) + elif self.do_mode != "naive": + raise ValueError("Unknown dropout mode ", self.do_mode) # Instantiate the main weight matrices self.w_r = self._conv2d(self.input_ch, self.gate_ksize, bias=False) @@ -115,9 +131,11 @@ def __init__(self, input_ch, hidden_ch, kernel_size, gate_ksize=(1, 1), def set_weights(self): """Initialize the parameters""" + def gain_from_ksize(ksize): n = ksize[0] * ksize[1] * self.hidden_ch - return math.sqrt(2. 
/ n) + return math.sqrt(2.0 / n) + with torch.no_grad(): if not self.mobile: if self.gain < 0: @@ -125,7 +143,7 @@ def gain_from_ksize(ksize): gain_2 = gain_from_ksize(self.gate_ksize) else: gain_1 = gain_2 = self.gain - init_fn = getattr(init, self.kwargs['init']) + init_fn = getattr(init, self.kwargs["init"]) init_fn(self.w_r.weight, gain=gain_2) init_fn(self.u_r.weight, gain=gain_2) init_fn(self.w_z.weight, gain=gain_2) @@ -197,7 +215,7 @@ def forward(self, x, h_tm1): @staticmethod def mask_name(idx): - return 'drop_mask_{}'.format(idx) + return "drop_mask_{}".format(idx) def set_drop_masks(self): """Set the dropout masks for the current sequence""" @@ -210,27 +228,29 @@ def yield_drop_masks(self): n_channels = (self.input_ch, self.hidden_ch, self.hidden_ch) for idx, p in enumerate(self.drop_prob): if p > 0: - yield (idx, self.generate_do_mask( - p, n_masks[idx], n_channels[idx])) + yield (idx, self.generate_do_mask(p, n_masks[idx], n_channels[idx])) @staticmethod def generate_do_mask(p, n, ch): """Generate a dropout mask for recurrent dropout""" with torch.no_grad(): mask = Bernoulli(torch.full((n, ch), 1 - p)).sample() / (1 - p) - mask = mask.requires_grad_(False).cuda() + mask = ( + mask.requires_grad_(False).cuda() + if torch.cuda.is_available() + else mask.requires_grad_(False).cpu() + ) return mask def apply_dropout(self, x, idx, sub_idx): """Apply recurrent or naive dropout""" if self.training and self.drop_prob[idx] > 0 and idx != 2: - if self.do_mode == 'recurrent': + if self.do_mode == "recurrent": x = x.clone() * torch.reshape( - getattr(self, self.mask_name(idx)) - [sub_idx, :], (1, -1, 1, 1)) - elif self.do_mode == 'naive': - x = f.dropout2d( - x, self.drop_prob[idx], self.training, inplace=False) + getattr(self, self.mask_name(idx))[sub_idx, :], (1, -1, 1, 1) + ) + elif self.do_mode == "naive": + x = f.dropout2d(x, self.drop_prob[idx], self.training, inplace=False) else: x = x.clone() return x @@ -240,9 +260,9 @@ def get_norm_module(self, channels): norm_module = None if self.batchnorm is not None: norm_module = self.batchnorm(channels) - elif self.norm == 'instance': + elif self.norm == "instance": norm_module = nn.InstanceNorm2d(channels, **self.norm_kwargs) - elif self.norm == 'batch': + elif self.norm == "batch": norm_module = nn.BatchNorm2d(channels, **self.norm_kwargs) return norm_module @@ -253,24 +273,38 @@ def _conv2d(self, in_channels, kernel_size, bias=True): """ padding = tuple(k_size // 2 for k_size in kernel_size) if not self.mobile or kernel_size == (1, 1): - return nn.Conv2d(in_channels, self.hidden_ch, kernel_size, - padding=padding, bias=bias) + return nn.Conv2d( + in_channels, self.hidden_ch, kernel_size, padding=padding, bias=bias + ) else: - return nn.Sequential(OrderedDict([ - ('conv_dw', nn.Conv2d( - in_channels, in_channels, kernel_size=kernel_size, - padding=padding, groups=in_channels, bias=False)), - ('sep_bn', self.get_norm_module(in_channels)), - ('sep_relu', nn.ReLU6()), - ('conv_sep', nn.Conv2d( - in_channels, self.hidden_ch, 1, bias=bias)), - ])) + return nn.Sequential( + OrderedDict( + [ + ( + "conv_dw", + nn.Conv2d( + in_channels, + in_channels, + kernel_size=kernel_size, + padding=padding, + groups=in_channels, + bias=False, + ), + ), + ("sep_bn", self.get_norm_module(in_channels)), + ("sep_relu", nn.ReLU6()), + ( + "conv_sep", + nn.Conv2d(in_channels, self.hidden_ch, 1, bias=bias), + ), + ] + ) + ) def _init_hidden(self, input_, cuda=True): """Initialize the hidden state""" batch_size, _, height, width = input_.data.size() - prev_state = 
torch.zeros( - batch_size, self.hidden_ch, height, width) + prev_state = torch.zeros(batch_size, self.hidden_ch, height, width) if cuda: prev_state = prev_state.cuda() return prev_state @@ -278,10 +312,16 @@ def _init_hidden(self, input_, cuda=True): class ConvGRU(nn.Module): - def __init__(self, input_channels=None, hidden_channels=None, - kernel_size=(3, 3), gate_ksize=(1, 1), - dropout=(False, False, False), drop_prob=(0.5, 0.5, 0.5), - **kwargs): + def __init__( + self, + input_channels=None, + hidden_channels=None, + kernel_size=(3, 3), + gate_ksize=(1, 1), + dropout=(False, False, False), + drop_prob=(0.5, 0.5, 0.5), + **kwargs + ): """ Generates a multi-layer convolutional GRU. Preserves spatial dimensions across cells, only altering depth. @@ -313,8 +353,10 @@ def __init__(self, input_channels=None, hidden_channels=None, self.gate_ksize = self._extend_for_multilayer(gate_ksize) self.dropout = self._extend_for_multilayer(dropout) drop_prob = self._extend_for_multilayer(drop_prob) - self.drop_prob = [tuple(dp_ if do_ else 0. for dp_, do_ in zip(dp, do)) - for dp, do in zip(drop_prob, self.dropout)] + self.drop_prob = [ + tuple(dp_ if do_ else 0.0 for dp_, do_ in zip(dp, do)) + for dp, do in zip(drop_prob, self.dropout) + ] self.kwargs = kwargs cell_list = [] @@ -322,13 +364,19 @@ def __init__(self, input_channels=None, hidden_channels=None, if idx < self.num_layers - 1: # Switch output dropout off for hidden layers. # Otherwise it would confict with input dropout. - this_drop_prob = self.drop_prob[idx][:2] + (0.,) + this_drop_prob = self.drop_prob[idx][:2] + (0.0,) else: this_drop_prob = self.drop_prob[idx] - cell_list.append(ConvGRUCell( - self.input_channels[idx], self.hidden_channels[idx], - self.kernel_size[idx], drop_prob=this_drop_prob, - gate_ksize=self.gate_ksize[idx], **kwargs)) + cell_list.append( + ConvGRUCell( + self.input_channels[idx], + self.hidden_channels[idx], + self.kernel_size[idx], + drop_prob=this_drop_prob, + gate_ksize=self.gate_ksize[idx], + **kwargs + ) + ) self.cell_list = nn.ModuleList(cell_list) def forward(self, input_tensor, hidden=None): @@ -350,8 +398,7 @@ def forward(self, input_tensor, hidden=None): for t, x in enumerate(iterator): for layer_idx in range(self.num_layers): - if self.cell_list[layer_idx].do_mode == 'recurrent'\ - and t == 0: + if self.cell_list[layer_idx].do_mode == "recurrent" and t == 0: self.cell_list[layer_idx].set_drop_masks() (x, h) = self.cell_list[layer_idx](x, hidden[layer_idx]) hidden[layer_idx] = h.clone() @@ -362,14 +409,18 @@ def forward(self, input_tensor, hidden=None): @staticmethod def _check_kernel_size_consistency(kernel_size): - if not (isinstance(kernel_size, tuple) or - (isinstance(kernel_size, list) and - all([isinstance(elem, tuple) for elem in kernel_size]))): - raise ValueError('`kernel_size` must be tuple or list of tuples') + if not ( + isinstance(kernel_size, tuple) + or ( + isinstance(kernel_size, list) + and all([isinstance(elem, tuple) for elem in kernel_size]) + ) + ): + raise ValueError("`kernel_size` must be tuple or list of tuples") def _extend_for_multilayer(self, param): if not isinstance(param, list): param = [param] * self.num_layers else: - assert(len(param) == self.num_layers) + assert len(param) == self.num_layers return param diff --git a/unisal/train.py b/unisal/train.py index 3dc784c..506ec6a 100644 --- a/unisal/train.py +++ b/unisal/train.py @@ -88,47 +88,48 @@ class Trainer(utils.KwConfigClass): """ - phases = ('train', 'valid') - all_data_sources = ('DHF1K', 'Hollywood', 'UCFSports', 
'SALICON') - - def __init__(self, - num_epochs=16, - optim_algo='SGD', - momentum=0.9, - lr=0.04, - lr_scheduler='ExponentialLR', - lr_gamma=0.8, - weight_decay=1e-4, - cnn_weight_decay=1e-5, - grad_clip=2., - loss_metrics=('kld', 'nss', 'cc'), - loss_weights=(1, -0.1, -0.1), - data_sources=('DHF1K', 'Hollywood', 'UCFSports', 'SALICON'), - batch_size=4, - salicon_batch_size=32, - hollywood_batch_size=4, - ucfsports_batch_size=4, - salicon_weight=.5, - hollywood_weight=1., - ucfsports_weight=1., - data_cfg=None, - salicon_cfg=None, - hollywood_cfg=None, - ucfsports_cfg=None, - shuffle_datasets=True, - cnn_lr_factor=0.1, - train_cnn_after=2, - cnn_eval=True, - model_cfg=None, - prefix=None, - suffix='unisal', - num_workers=6, - chkpnt_warmup=3, - chkpnt_epochs=2, - tboard=True, - debug=False, - new_instance=True, - ): + phases = ("train", "valid") + all_data_sources = ("DHF1K", "Hollywood", "UCFSports", "SALICON") + + def __init__( + self, + num_epochs=16, + optim_algo="SGD", + momentum=0.9, + lr=0.04, + lr_scheduler="ExponentialLR", + lr_gamma=0.8, + weight_decay=1e-4, + cnn_weight_decay=1e-5, + grad_clip=2.0, + loss_metrics=("kld", "nss", "cc"), + loss_weights=(1, -0.1, -0.1), + data_sources=("DHF1K", "Hollywood", "UCFSports", "SALICON"), + batch_size=4, + salicon_batch_size=32, + hollywood_batch_size=4, + ucfsports_batch_size=4, + salicon_weight=0.5, + hollywood_weight=1.0, + ucfsports_weight=1.0, + data_cfg=None, + salicon_cfg=None, + hollywood_cfg=None, + ucfsports_cfg=None, + shuffle_datasets=True, + cnn_lr_factor=0.1, + train_cnn_after=2, + cnn_eval=True, + model_cfg=None, + prefix=None, + suffix="unisal", + num_workers=6, + chkpnt_warmup=3, + chkpnt_epochs=2, + tboard=True, + debug=False, + new_instance=True, + ): # Save training parameters self.num_epochs = num_epochs self.optim_algo = optim_algo @@ -158,8 +159,8 @@ def __init__(self, self.train_cnn_after = train_cnn_after self.cnn_eval = cnn_eval self.model_cfg = model_cfg or {} - if 'sources' not in self.model_cfg: - self.model_cfg['sources'] = data_sources + if "sources" not in self.model_cfg: + self.model_cfg["sources"] = data_sources # Create training directory. Uses env var TRAIN_DIR self.suffix = suffix @@ -171,7 +172,7 @@ def __init__(self, self.num_workers = num_workers self.chkpnt_warmup = chkpnt_warmup self.chkpnt_epochs = chkpnt_epochs - device = 'cuda:0' if torch.cuda.is_available() else 'cpu' + device = "cuda:0" if torch.cuda.is_available() else "cpu" self.device = torch.device(device) self.tboard = tboard self.debug = debug @@ -181,12 +182,12 @@ def __init__(self, self.num_epochs = 4 self.chkpnt_epochs = 2 self.chkpnt_warmup = 1 - self.suffix += '_debug' - self.data_cfg.update({'subset': 0.02}) - self.salicon_cfg.update({'subset': 0.02}) - self.hollywood_cfg.update({'subset': 0.04}) - self.ucfsports_cfg.update({'subset': 0.1}) - self.eval_subeset = 1. + self.suffix += "_debug" + self.data_cfg.update({"subset": 0.02}) + self.salicon_cfg.update({"subset": 0.02}) + self.hollywood_cfg.update({"subset": 0.04}) + self.ucfsports_cfg.update({"subset": 0.1}) + self.eval_subeset = 1.0 # Initialize properties etc. 
self.epoch = 0 @@ -218,7 +219,7 @@ def __init__(self, self.model.save_cfg(self.train_dir) self.copy_code() for src in self.data_sources: - self.get_dataset('train', source=src).save_cfg(self.train_dir) + self.get_dataset("train", source=src).save_cfg(self.train_dir) def fit(self): """ @@ -235,9 +236,10 @@ def fit(self): self.fit_epoch() # Save a checkpoint if applicable - if (self.epoch >= self.chkpnt_warmup - and (self.epoch + 1) % self.chkpnt_epochs == 0)\ - or self.epoch == self.num_epochs - 1: + if ( + self.epoch >= self.chkpnt_warmup + and (self.epoch + 1) % self.chkpnt_epochs == 0 + ) or self.epoch == self.num_epochs - 1: self.save_chkpnt() self.epoch += 1 @@ -254,9 +256,9 @@ def fit_epoch(self): # Perform LR decay self.scheduler.step(epoch=self.epoch) - lr = self.optimizer.param_groups[0]['lr'] + lr = self.optimizer.param_groups[0]["lr"] print(f"\nEpoch {self.epoch:3d}, lr {lr:.5f}") - self.add_scalar('conv/lr', lr, self.epoch) + self.add_scalar("conv/lr", lr, self.epoch) # Run the training and validation phase for self.phase in self.phases: @@ -269,29 +271,34 @@ def fit_phase(self): sources = self.data_sources # Prepare book keeping - running_losses = {src: 0. for src in sources} + running_losses = {src: 0.0 for src in sources} running_loss_summands = { - src: [0. for _ in self.loss_weights] for src in sources} + src: [0.0 for _ in self.loss_weights] for src in sources + } n_samples = {src: 0 for src in sources} # Shuffle the dataset batches - dataloaders = {src: self.get_dataloader(self.phase, src) - for src in sources} - all_batches = [src for src in chain.from_iterable(zip_longest( - *[[src for _ in range(len(dataloaders[src]))] for src in sources] - )) if src is not None] + dataloaders = {src: self.get_dataloader(self.phase, src) for src in sources} + all_batches = [ + src + for src in chain.from_iterable( + zip_longest( + *[[src for _ in range(len(dataloaders[src]))] for src in sources] + ) + ) + if src is not None + ] if self.shuffle_datasets: shuffle(all_batches) if self.epoch == 0: print(f"Number of batches: {len(all_batches)}") - print(", ".join(f"{src}: {len(dataloaders[src])}" - for src in sources)) + print(", ".join(f"{src}: {len(dataloaders[src])}" for src in sources)) # Set model train/eval mode - self.model.train(self.phase == 'train') + self.model.train(self.phase == "train") # Switch CNN gradients on/off and set CNN eval mode (for BN modules) - if self.phase == 'train': + if self.phase == "train": cnn_grad = self.epoch >= self.train_cnn_after for param in self._model.cnn.parameters(): param.requires_grad = cnn_grad @@ -304,20 +311,23 @@ def fit_phase(self): # Get the next batch sample = next(data_iters[src]) - target_size = (sample[-1][0][0].item(), - sample[-1][1][0].item()) + target_size = (sample[-1][0][0].item(), sample[-1][1][0].item()) sample = sample[:-1] # Fit/evaluate the batch loss, loss_summands, batch_size = self.fit_sample( - sample, grad_clip=self.grad_clip, target_size=target_size, - source='SALICON' if src == 'MIT1003' else src) + sample, + grad_clip=self.grad_clip, + target_size=target_size, + source="SALICON" if src == "MIT1003" else src, + ) # Book keeping running_losses[src] += loss * batch_size running_loss_summands[src] = [ r + l * batch_size - for r, l in zip(running_loss_summands[src], loss_summands)] + for r, l in zip(running_loss_summands[src], loss_summands) + ] n_samples[src] += batch_size # Book keeping and writing to TensorboardX @@ -325,30 +335,37 @@ def fit_phase(self): for src in sources_eval: phase_loss = running_losses[src] / 
n_samples[src] phase_loss_summands = [ - loss_ / n_samples[src] for loss_ in running_loss_summands[src]] - - print(f'{src:9s}: Phase: {self.phase}, loss: {phase_loss:.4f}, ' - + ", ".join(f"loss_{idx}: {loss_:.4f}" - for idx, loss_ in enumerate(phase_loss_summands))) + loss_ / n_samples[src] for loss_ in running_loss_summands[src] + ] + + print( + f"{src:9s}: Phase: {self.phase}, loss: {phase_loss:.4f}, " + + ", ".join( + f"loss_{idx}: {loss_:.4f}" + for idx, loss_ in enumerate(phase_loss_summands) + ) + ) - key = 'conv' if src == 'DHF1K' else src.lower() - self.add_scalar(f'{key}/loss/{self.phase}', phase_loss, self.epoch) + key = "conv" if src == "DHF1K" else src.lower() + self.add_scalar(f"{key}/loss/{self.phase}", phase_loss, self.epoch) for idx, loss_ in enumerate(phase_loss_summands): - self.add_scalar(f'{key}/loss_{idx}/{self.phase}', loss_, - self.epoch) - - if src == "DHF1K" and self.phase == 'valid' and\ - self.epoch >= self.chkpnt_warmup: - val_score = - phase_loss + self.add_scalar(f"{key}/loss_{idx}/{self.phase}", loss_, self.epoch) + + if ( + src == "DHF1K" + and self.phase == "valid" + and self.epoch >= self.chkpnt_warmup + ): + val_score = -phase_loss if self.best_val_score is None: self.best_val_score = val_score elif val_score > self.best_val_score: self.best_val_score = val_score self.is_best = True - self.model.save_weights(self.train_dir, 'best') - with open(self.train_dir / 'best_epoch.dat', 'w') as f: + self.model.save_weights(self.train_dir, "best") + with open(self.train_dir / "best_epoch.dat", "w") as f: f.write(str(self.epoch)) - with open(self.train_dir / 'best_val_loss.dat', 'w') as f: + with open(self.train_dir / "best_val_loss.dat", "w") as f: f.write(str(val_score)) else: self.is_best = False @@ -358,7 +375,7 @@ def fit_sample(self, sample, grad_clip=None, **kwargs): Take a sample containing a batch, and fit/evaluate the model """ - with torch.set_grad_enabled(self.phase == 'train'): + with torch.set_grad_enabled(self.phase == "train"): _, x, sal, fix = sample # Add temporal dimension to image data @@ -372,36 +389,38 @@ def fit_sample(self, sample, grad_clip=None, **kwargs): # Switch the gradients of unused modules off to prevent unnecessary # weight decay - if self.phase == 'train': + if self.phase == "train": # Switch the RNN gradients off if this is a image batch rnn_grad = x.shape[1] != 1 or not self.model.bypass_rnn - for param in chain(self._model.rnn.parameters(), - self._model.post_rnn.parameters()): + for param in chain( + self._model.rnn.parameters(), self._model.post_rnn.parameters() + ): param.requires_grad = rnn_grad # Switch the gradients of unused dataset-specific modules off for name, param in self.model.named_parameters(): for source in self.all_data_sources: if source.lower() in name.lower(): - param.requires_grad = source == kwargs['source'] + param.requires_grad = source == kwargs["source"] # Run forward pass pred_seq = self.model(x, **kwargs) # Compute the total loss loss_summands = self.loss_sequences( - pred_seq, sal, fix, metrics=self.loss_metrics) + pred_seq, sal, fix, metrics=self.loss_metrics + ) loss_summands = [l.mean(1).mean(0) for l in loss_summands] - loss = sum(weight * l for weight, l in - zip(self.loss_weights, loss_summands)) + loss = sum( + weight * l for weight, l in zip(self.loss_weights, loss_summands) + ) # Run backward pass and optimization step - if self.phase == 'train': + if self.phase == "train": self.optimizer.zero_grad() loss.backward() if grad_clip is not None: - torch.nn.utils.clip_grad_norm_( - 
self.model.parameters(), self.grad_clip) + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip) self.optimizer.step() return loss.item(), [l.item() for l in loss_summands], x.shape[0] @@ -414,19 +433,31 @@ def loss_sequences(pred_seq, sal_seq, fix_seq, metrics): losses = [] for this_metric in metrics: - if this_metric == 'kld': + if this_metric == "kld": losses.append(utils.kld_loss(pred_seq, sal_seq)) - if this_metric == 'nss': + if this_metric == "nss": losses.append(utils.nss(pred_seq.exp(), fix_seq)) - if this_metric == 'cc': + if this_metric == "cc": losses.append(utils.corr_coeff(pred_seq.exp(), sal_seq)) return losses - def run_inference(self, source, vid_nr, dataset=None, phase=None, - smooth_method=None, metrics=None, save_predictions=False, - return_predictions=False, seq_len_factor=0.5, - random_seed=27, n_aucs_maps=10, auc_portion=1.0, - model_domain=None, folder_suffix=None): + def run_inference( + self, + source, + vid_nr, + dataset=None, + phase=None, + smooth_method=None, + metrics=None, + save_predictions=False, + return_predictions=False, + seq_len_factor=0.5, + random_seed=27, + n_aucs_maps=10, + auc_portion=1.0, + model_domain=None, + folder_suffix=None, + ): if dataset is None: assert phase, "Must provide either dataset or phase" @@ -439,25 +470,26 @@ def run_inference(self, source, vid_nr, dataset=None, phase=None, target_size = dataset.target_size_dict[vid_nr] # Set the keyword arguments for the forward pass - model_kwargs = { - 'source': model_domain or source, - 'target_size': target_size} + model_kwargs = {"source": model_domain or source, "target_size": target_size} # Make sure that the model was trained on the selected domain - if model_kwargs['source'] not in self.model.sources: - print(f"\nWarning! Evaluation bn source {model_kwargs['source']} " - f"doesn't exist in model.\n Using {self.model.sources[0]}.") - model_kwargs['source'] = self.model.sources[0] + if model_kwargs["source"] not in self.model.sources: + print( + f"\nWarning! Evaluation bn source {model_kwargs['source']} " + f"doesn't exist in model.\n Using {self.model.sources[0]}." + ) + model_kwargs["source"] = self.model.sources[0] # Select static or dynamic forward pass for Bypass-RNN model_kwargs.update( - {'static': model_kwargs['source'] in ('SALICON', 'MIT300', 'MIT1003')}) + {"static": model_kwargs["source"] in ("SALICON", "MIT300", "MIT1003")} + ) # Set additional parameters - static_data = source in ('SALICON', 'MIT300', 'MIT1003') + static_data = source in ("SALICON", "MIT300", "MIT1003") if static_data: smooth_method = None - auc_portion = 1. + auc_portion = 1.0 n_images = 1 frame_modulo = 1 else: @@ -470,7 +502,7 @@ def run_inference(self, source, vid_nr, dataset=None, phase=None, torch.cuda.empty_cache() # Prepare the prediction and target tensors - results_size = (1, n_images, 1, *model_kwargs['target_size']) + results_size = (1, n_images, 1, *model_kwargs["target_size"]) pred_seq = torch.full(results_size, 0, dtype=torch.float) if metrics is not None: sal_seq = torch.full(results_size, 0, dtype=torch.float) @@ -508,7 +540,8 @@ def run_inference(self, source, vid_nr, dataset=None, phase=None, if metrics is not None: raise ValueError( "Labels needed for evaluation metrics but not provided" - "by dataset.") + "by dataset." 
+ ) frame_nrs, frame_seq = sample this_sal_seq, this_fix_seq = None, None if frame_seq.dim() == 3: @@ -528,13 +561,12 @@ def run_inference(self, source, vid_nr, dataset=None, phase=None, # Forward pass this_pred_seq, h0 = self.model( - this_frame_seq, h0=h0, return_hidden=True, - **model_kwargs) + this_frame_seq, h0=h0, return_hidden=True, **model_kwargs + ) # Insert the predictions into the prediction array this_pred_seq = this_pred_seq.cpu() - pred_seq[:, this_frame_idx_array, :, :, :] =\ - this_pred_seq + pred_seq[:, this_frame_idx_array, :, :, :] = this_pred_seq # Keep the training targets if scores are to be computed if metrics is not None: @@ -542,7 +574,7 @@ def run_inference(self, source, vid_nr, dataset=None, phase=None, fix_seq[:, frame_idx_array, :, :, :] = this_fix_seq # Assert non-empty predictions - assert(torch.min(pred_seq.exp().sum(-1).sum(-1)) > 0) + assert torch.min(pred_seq.exp().sum(-1).sum(-1)) > 0 # Optionally smooth the interleaved sequences if smooth_method is not None: @@ -561,13 +593,13 @@ def run_inference(self, source, vid_nr, dataset=None, phase=None, this_pred_dir.mkdir(exist_ok=True) # Construct a subfolder if applicable - if source == 'DHF1K': + if source == "DHF1K": this_pred_dir /= f"{vid_nr:04d}" - elif source == 'MIT1003': + elif source == "MIT1003": this_pred_dir /= self.mit1003_dir.stem - elif source == 'MIT300' and 'x_val_step' in self.salicon_cfg: + elif source == "MIT300" and "x_val_step" in self.salicon_cfg: this_pred_dir /= f"MIT300_xVal{self.salicon_cfg['x_val_step']}" - elif source in ('Hollywood', 'UCFSports'): + elif source in ("Hollywood", "UCFSports"): this_pred_dir /= dataset.get_annotation_dir(vid_nr).stem this_pred_dir.mkdir(exist_ok=True) @@ -575,15 +607,16 @@ def run_inference(self, source, vid_nr, dataset=None, phase=None, for frame_idx, smap in enumerate(torch.unbind(pred_seq, dim=1)): # Define the filename - if source == 'SALICON': + if source == "SALICON": filename = f"COCO_test2014_{vid_nr:012d}.png" - elif source == 'MIT300': + elif source == "MIT300": filename = dataset.samples[vid_nr][0] - elif source == 'MIT1003': - filename = dataset.all_image_files[vid_nr]['img'] - elif source in ('Hollywood', 'UCFSports'): + elif source == "MIT1003": + filename = dataset.all_image_files[vid_nr]["img"] + elif source in ("Hollywood", "UCFSports"): filename = dataset.get_data_file( - vid_nr, frame_idx + 1, 'frame').name + vid_nr, frame_idx + 1, "frame" + ).name else: filename = f"{frame_idx + 1:04d}.png" @@ -593,16 +626,13 @@ def run_inference(self, source, vid_nr, dataset=None, phase=None, smap = utils.to_numpy(smap) # Optinally save numpy file in addition to the image - if source == 'MIT300': - np.save( - this_pred_dir / filename.replace(".jpg", ""), - smap.copy()) + if source == "MIT300": + np.save(this_pred_dir / filename.replace(".jpg", ""), smap.copy()) # Save prediction as image smap = (smap / np.amax(smap) * 255).astype(np.uint8) pred_file = this_pred_dir / filename - cv2.imwrite(str(pred_file), smap, - [cv2.IMWRITE_JPEG_QUALITY, 100]) + cv2.imwrite(str(pred_file), smap, [cv2.IMWRITE_JPEG_QUALITY, 100]) # Optionally compute the scores if metrics is not None: @@ -610,7 +640,8 @@ def run_inference(self, source, vid_nr, dataset=None, phase=None, # Compute the KLD, NSS and CC metrics vid_scores = [] loss_sequences = self.loss_sequences( - pred_seq, sal_seq, fix_seq, metrics=metrics) + pred_seq, sal_seq, fix_seq, metrics=metrics + ) loss_sequences = [loss.numpy() for loss in loss_sequences] vid_scores += loss_sequences @@ -618,19 +649,17 @@ 
def other_maps(): """Sample reference maps for s-AUC""" while True: this_map = np.zeros(results_size[-2:]) - video_nrs = random.sample( - dataset.n_images_dict.keys(), n_aucs_maps) + video_nrs = random.sample(dataset.n_images_dict.keys(), n_aucs_maps) for map_idx, vid_nr in enumerate(video_nrs): - frame_nr = random.randint( - 1, dataset.n_images_dict[vid_nr]) + frame_nr = random.randint(1, dataset.n_images_dict[vid_nr]) if static_data: this_this_map = dataset.get_fixation_map(vid_nr) else: this_this_map = dataset.get_seq( - vid_nr, [frame_nr], 'fix').numpy()[0, 0, ...] + vid_nr, [frame_nr], "fix" + ).numpy()[0, 0, ...] this_this_map = cv2.resize( - this_this_map, tuple(target_size[::-1]), - cv2.INTER_NEAREST + this_this_map, tuple(target_size[::-1]), cv2.INTER_NEAREST ) this_map += this_this_map @@ -639,9 +668,13 @@ def other_maps(): # Compute the SIM, AUC-J, s-AUC metrics vid_scores += self.eval_sequences( - pred_seq, sal_seq, fix_seq, metrics, - other_maps=other_maps() if 'aucs' in metrics else None, - auc_portion=auc_portion) + pred_seq, + sal_seq, + fix_seq, + metrics, + other_maps=other_maps() if "aucs" in metrics else None, + auc_portion=auc_portion, + ) # Average the scores over the frames mean_scores = np.array([np.mean(scores) for scores in vid_scores]) @@ -654,17 +687,17 @@ def other_maps(): return pred_seq @staticmethod - def eval_sequences(pred_seq, sal_seq, fix_seq, metrics, - other_maps=None, auc_portion=1.): + def eval_sequences( + pred_seq, sal_seq, fix_seq, metrics, other_maps=None, auc_portion=1.0 + ): """ Compute SIM, AUC-J and s-AUC scores """ # process inputs - metrics = [metric for metric in metrics - if metric in ('sim', 'aucj', 'aucs')] - if 'aucs' in metrics: - assert(other_maps is not None) + metrics = [metric for metric in metrics if metric in ("sim", "aucj", "aucs")] + if "aucs" in metrics: + assert other_maps is not None # Preprocess sequences shape = pred_seq.shape @@ -677,8 +710,8 @@ def eval_sequences(pred_seq, sal_seq, fix_seq, metrics, # Optionally compute AUC-s for a subset of frames to reduce runtime if auc_portion < 1: auc_indices = set( - random.sample( - range(shape[1]), max(1, int(auc_portion * shape[1])))) + random.sample(range(shape[1]), max(1, int(auc_portion * shape[1]))) + ) else: auc_indices = set(list(range(shape[1]))) @@ -686,18 +719,17 @@ def eval_sequences(pred_seq, sal_seq, fix_seq, metrics, results = {metric: [] for metric in metrics} for idx, (pred, sal, fix) in enumerate(zip(pred_seq, sal_seq, fix_seq)): for this_metric in metrics: - if this_metric == 'sim': - results['sim'].append( - salience_metrics.similarity(pred, sal)) - if this_metric == 'aucj': + if this_metric == "sim": + results["sim"].append(salience_metrics.similarity(pred, sal)) + if this_metric == "aucj": if idx in auc_indices: - results['aucj'].append( - salience_metrics.auc_judd(pred, fix)) - if this_metric == 'aucs': + results["aucj"].append(salience_metrics.auc_judd(pred, fix)) + if this_metric == "aucs": if idx in auc_indices: other_map = next(other_maps) - results['aucs'].append(salience_metrics.auc_shuff_acl( - pred, fix, other_map)) + results["aucs"].append( + salience_metrics.auc_shuff_acl(pred, fix, other_map) + ) return [np.array(results[metric]) for metric in metrics] @property @@ -707,12 +739,21 @@ def pred_dir(self): pred_dir.mkdir(exist_ok=True, parents=True) return pred_dir - def score_model(self, subset=1, source='DHF1K', - metrics=('kld', 'nss', 'cc', 'sim', 'aucj', 'aucs'), - smooth_method=None, seq_len_factor=2, - random_seed=27, n_aucs_maps=10, 
auc_portion=0.5, - model_domain=None, phase=None, load_weights=True, - vid_nr_array=None): + def score_model( + self, + subset=1, + source="DHF1K", + metrics=("kld", "nss", "cc", "sim", "aucj", "aucs"), + smooth_method=None, + seq_len_factor=2, + random_seed=27, + n_aucs_maps=10, + auc_portion=0.5, + model_domain=None, + phase=None, + load_weights=True, + vid_nr_array=None, + ): """ Compute the evaluation scores of the model. @@ -727,16 +768,15 @@ def score_model(self, subset=1, source='DHF1K', # the last epoch try: self.model.load_best_weights(self.train_dir) - print('Best weights loaded') + print("Best weights loaded") except FileNotFoundError: - print('No best weights found') + print("No best weights found") self.model.load_last_chkpnt(self.train_dir) - print('Last checkpoint loaded') + print("Last checkpoint loaded") # Select the appropriate phase (see docstring) and get the dataset if phase is None: - phase = 'eval' if source in ('DHF1K', 'SALICON', 'MIT1003')\ - else 'test' + phase = "eval" if source in ("DHF1K", "SALICON", "MIT1003") else "test" dataset = self.get_dataset(phase, source) if vid_nr_array is None: @@ -745,22 +785,32 @@ def score_model(self, subset=1, source='DHF1K', # Iterate over the videos/images and compute the scores scores = [] - tmr = utils.Timer(f'Evaluating {len(vid_nr_array)} {source} videos') + tmr = utils.Timer(f"Evaluating {len(vid_nr_array)} {source} videos") random.seed(random_seed) with torch.no_grad(): for vid_idx, vid_nr in enumerate(vid_nr_array): this_scores = self.run_inference( - source, vid_nr, dataset=dataset, - smooth_method=smooth_method, metrics=metrics, - seq_len_factor=seq_len_factor, n_aucs_maps=n_aucs_maps, - auc_portion=auc_portion, model_domain=model_domain) + source, + vid_nr, + dataset=dataset, + smooth_method=smooth_method, + metrics=metrics, + seq_len_factor=seq_len_factor, + n_aucs_maps=n_aucs_maps, + auc_portion=auc_portion, + model_domain=model_domain, + ) scores.append(this_scores) if vid_idx == 0: - print(f' Nr. ( .../{len(vid_nr_array):4d}), ' + - ', '.join(f'{metric:5s}' for metric in metrics)) - print(f'{vid_nr:6d} ' + - f'({vid_idx + 1:4d}/{len(vid_nr_array):4d}), ' + - ', '.join(f'{score:.3f}' for score in this_scores)) + print( + f" Nr. ( .../{len(vid_nr_array):4d}), " + + ", ".join(f"{metric:5s}" for metric in metrics) + ) + print( + f"{vid_nr:6d} " + + f"({vid_idx + 1:4d}/{len(vid_nr_array):4d}), " + + ", ".join(f"{score:.3f}" for score in this_scores) + ) # Compute the average video scores tmr.finish() @@ -771,45 +821,46 @@ def score_model(self, subset=1, source='DHF1K', # which means that each videos contribution to the overall score is # weighted by its number of frames. 
The equivalent scores are denoted # below as weighted mean - num_frames_array = [ - dataset.n_images_dict[vid_nr] for vid_nr in vid_nr_array] + num_frames_array = [dataset.n_images_dict[vid_nr] for vid_nr in vid_nr_array] weighted_mean_scores = np.average(scores, 0, num_frames_array) # Print and save the scores print() print("Macro average (average of video averages) scores:") - print(', '.join(f'{metric:5s}' for metric in metrics)) - print(', '.join(f'{score:.3f}' for score in mean_scores)) + print(", ".join(f"{metric:5s}" for metric in metrics)) + print(", ".join(f"{score:.3f}" for score in mean_scores)) print() print("Weighted average (per-frame average) scores:") - print(', '.join(f'{metric:5s}' for metric in metrics)) - print(', '.join(f'{score:.3f}' for score in weighted_mean_scores)) + print(", ".join(f"{metric:5s}" for metric in metrics)) + print(", ".join(f"{score:.3f}" for score in weighted_mean_scores)) if subset == 1: - dest_dir = self.mit1003_dir if source == 'MIT1003' \ - else self.train_dir - if source in ('Hollywood', 'UCFSports'): + dest_dir = self.mit1003_dir if source == "MIT1003" else self.train_dir + if source in ("Hollywood", "UCFSports"): source += "_resized" - with open(dest_dir / f'{source}_eval_scores.json', 'w') as f: + with open(dest_dir / f"{source}_eval_scores.json", "w") as f: json.dump(scores, f, cls=utils.NumpyEncoder, indent=2) - with open(dest_dir / f'{source}_eval_mean_scores.json', 'w')\ - as f: + with open(dest_dir / f"{source}_eval_mean_scores.json", "w") as f: json.dump(mean_scores, f, cls=utils.NumpyEncoder, indent=2) - with open(dest_dir / - f'{source}_eval_weighted_mean_scores.json', 'w') as f: - json.dump(weighted_mean_scores, f, cls=utils.NumpyEncoder, - indent=2) + with open(dest_dir / f"{source}_eval_weighted_mean_scores.json", "w") as f: + json.dump(weighted_mean_scores, f, cls=utils.NumpyEncoder, indent=2) try: for score, metric in zip(mean_scores, metrics): - self.add_scalar(f'{source}_eval/{metric}', score, self.epoch) + self.add_scalar(f"{source}_eval/{metric}", score, self.epoch) self.export_scalars(suffix=f"_eval-{source}") except AttributeError: - print('add_scalar failed because tensorboard is closed.') + print("add_scalar failed because tensorboard is closed.") return metrics, mean_scores, scores - def generate_predictions(self, smooth_method=None, source='DHF1K', - phase='eval', load_weights=True, vid_nr_array=None, - **kwargs): + def generate_predictions( + self, + smooth_method=None, + source="DHF1K", + phase="eval", + load_weights=True, + vid_nr_array=None, + **kwargs, + ): """Generate predictions for submission and visualization""" if load_weights: @@ -817,49 +868,53 @@ def generate_predictions(self, smooth_method=None, source='DHF1K', # the last epoch try: self.model.load_best_weights(self.train_dir) - print('Best weights loaded') + print("Best weights loaded") except FileNotFoundError: - print('No best weights found') + print("No best weights found") self.model.load_last_chkpnt(self.train_dir) - print('Last checkpoint loaded') + print("Last checkpoint loaded") # Get the dataset dataset = self.get_dataset(phase, source) if vid_nr_array is None: # Get list of sample numbers - if source == 'MIT300': + if source == "MIT300": vid_nr_array = list(range(300)) else: vid_nr_array = list(dataset.n_images_dict.keys()) - tmr = utils.Timer(f'Predicting {len(vid_nr_array)} {source} videos') + tmr = utils.Timer(f"Predicting {len(vid_nr_array)} {source} videos") with torch.no_grad(): for vid_nr in vid_nr_array: self.run_inference( - source, 
vid_nr, dataset=dataset, - smooth_method=smooth_method, save_predictions=True, - **kwargs) + source, + vid_nr, + dataset=dataset, + smooth_method=smooth_method, + save_predictions=True, + **kwargs, + ) tmr.finish() def generate_predictions_from_path( - self, folder_path, is_video, source=None, load_weights=True, - **kwargs): + self, folder_path, is_video, source=None, load_weights=True, **kwargs + ): # Process inputs if source is None: - source = 'DHF1K' if is_video else 'SALICON' + source = "DHF1K" if is_video else "SALICON" - if source in ('MIT1003', 'MIT300'): + if source in ("MIT1003", "MIT300"): try: self.model.load_weights(self.train_dir, "ft_mit1003") print("Fine-tuned MIT1003 weights loaded") load_weights = False except: print("No MIT1003 fine-tuned weights found.") - source = 'SALICON' + source = "SALICON" - images_path = folder_path / 'images' + images_path = folder_path / "images" torch.cuda.empty_cache() if load_weights: @@ -867,23 +922,30 @@ def generate_predictions_from_path( # the last epoch try: self.model.load_best_weights(self.train_dir) - print('Best weights loaded') + print("Best weights loaded") except FileNotFoundError: - print('No best weights found') + print("No best weights found") self.model.load_last_chkpnt(self.train_dir) - print('Last checkpoint loaded') + print("Last checkpoint loaded") with torch.no_grad(): if is_video: - frame_modulo = 5 if source == 'DHF1K' else 4 + frame_modulo = 5 if source == "DHF1K" else 4 dataset = data.FolderVideoDataset( - images_path, source=source, frame_modulo=frame_modulo) - pred_dir = folder_path / 'saliency' + images_path, source=source, frame_modulo=frame_modulo + ) + pred_dir = folder_path / "saliency" pred_dir.mkdir(exist_ok=True) pred_seq = self.run_inference( - source, 0, dataset=dataset, phase=None, - return_predictions=True, folder_suffix=None, **kwargs) + source, + 0, + dataset=dataset, + phase=None, + return_predictions=True, + folder_suffix=None, + **kwargs, + ) # Iterate over the prediction frames for frame_idx, smap in enumerate(torch.unbind(pred_seq, dim=1)): @@ -897,18 +959,22 @@ def generate_predictions_from_path( filename = dataset.frame_files[frame_idx].name smap = (smap / np.amax(smap) * 255).astype(np.uint8) pred_file = pred_dir / filename - cv2.imwrite( - str(pred_file), smap, [cv2.IMWRITE_JPEG_QUALITY, 100]) + cv2.imwrite(str(pred_file), smap, [cv2.IMWRITE_JPEG_QUALITY, 100]) else: dataset = data.FolderImageDataset(images_path) - pred_dir = folder_path / 'saliency' + pred_dir = folder_path / "saliency" pred_dir.mkdir(exist_ok=True) for img_idx in range(len(dataset)): pred_seq = self.run_inference( - source, img_idx, dataset=dataset, phase=None, - return_predictions=True, **kwargs) + source, + img_idx, + dataset=dataset, + phase=None, + return_predictions=True, + **kwargs, + ) smap = pred_seq[:, 0, ...] 
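A minimal usage sketch of the folder-inference entry point reworked in these hunks, assuming an already configured `trainer` (a Trainer instance with trained weights in its train_dir) and a hypothetical folder example_clip/ whose input frames live under example_clip/images/:

    from pathlib import Path

    clip_dir = Path("example_clip")  # hypothetical; must contain an images/ subfolder
    trainer.generate_predictions_from_path(clip_dir, is_video=True)
    # Saliency maps are written to example_clip/saliency/, one file per input frame.

With is_video=False, the same call routes through FolderImageDataset and writes one saliency map per image in the folder instead.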
@@ -921,33 +987,32 @@ def generate_predictions_from_path( filename = dataset.image_files[img_idx].name smap = (smap / np.amax(smap) * 255).astype(np.uint8) pred_file = pred_dir / filename - cv2.imwrite( - str(pred_file), smap, [cv2.IMWRITE_JPEG_QUALITY, 100]) + cv2.imwrite(str(pred_file), smap, [cv2.IMWRITE_JPEG_QUALITY, 100]) def fine_tune_mit( - self, lr=0.01, num_epochs=8, lr_gamma=0.8, x_val_step=0, - train_cnn_after=0): + self, lr=0.01, num_epochs=8, lr_gamma=0.8, x_val_step=0, train_cnn_after=0 + ): """Fine tune the model with the MIT1003 dataset for MIT300 submission""" # Set the fine tuning parameters self.num_epochs = num_epochs self.lr = lr self.lr_gamma = lr_gamma - self.lr_scheduler = 'ExponentialLR' - self.optim_algo = 'SGD' + self.lr_scheduler = "ExponentialLR" + self.optim_algo = "SGD" self.momentum = 0.9 self.weight_decay = 1e-4 self.grad_clip = None - self.loss_weights = (1.,) - self.loss_metrics = ('kld',) - self.salicon_weight = 1. + self.loss_weights = (1.0,) + self.loss_metrics = ("kld",) + self.salicon_weight = 1.0 self.salicon_batch_size = 32 - self.data_sources = ('MIT1003',) + self.data_sources = ("MIT1003",) self.shuffle_datasets = True self.cnn_lr_factor = 0.1 self.train_cnn_after = train_cnn_after self.cnn_eval = True - self.salicon_cfg.update({'x_val_step': x_val_step}) + self.salicon_cfg.update({"x_val_step": x_val_step}) self.mit1003_finetuned = True self.num_workers = 4 @@ -956,9 +1021,9 @@ def fine_tune_mit( # the last epoch try: self.model.load_best_weights(self.train_dir) - print('Best weights loaded') + print("Best weights loaded") except FileNotFoundError: - print('No best weights found') + print("No best weights found") self.model.load_last_chkpnt(self.train_dir) # Run the fine tuning @@ -968,18 +1033,18 @@ def fine_tune_mit( self._model.to(self.device) while self.epoch < self.num_epochs: self.scheduler.step(epoch=self.epoch) - lr = self.optimizer.param_groups[0]['lr'] + lr = self.optimizer.param_groups[0]["lr"] print(f"\nEpoch {self.epoch:3d}, lr {lr:.5f}") for self.phase in self.phases: self.fit_phase() - val_loss = self.all_scalars['mit1003']['loss']['valid'][self.epoch] + val_loss = self.all_scalars["mit1003"]["loss"]["valid"][self.epoch] if math.isnan(val_loss): best_epoch = 0 best_val = 1000 break - val_score = - val_loss + val_score = -val_loss if self.best_val_score is None: self.best_val_score = val_score elif val_score > self.best_val_score: @@ -1004,16 +1069,17 @@ def get_dataset(self, phase, source="DHF1K"): else: dataset_cls_name = f"{source}Dataset" dataset_cls = getattr(data, dataset_cls_name) - if source in ('MIT300',): + if source in ("MIT300",): config = {} - elif source in ('MIT1003',): + elif source in ("MIT1003",): config = getattr(self, f"salicon_cfg") else: config = getattr(self, f"{source.lower()}_cfg") - self._datasets[source][phase] = dataset_cls( - phase=phase, **config) - print(f'{source:10s} {phase} dataset loaded with' - f' {len(self._datasets[source][phase])} samples') + self._datasets[source][phase] = dataset_cls(phase=phase, **config) + print( + f"{source:10s} {phase} dataset loaded with" + f" {len(self._datasets[source][phase])} samples" + ) return self._datasets[source][phase] def get_dataloader(self, phase, source="DHF1K"): @@ -1026,22 +1092,24 @@ def get_dataloader(self, phase, source="DHF1K"): if source == "DHF1K": batch_size = self.batch_size elif source in ("Hollywood", "UCFSports"): - batch_size = self.__getattribute__( - f"{source.lower()}_batch_size") - elif phase == 'valid' and source == 'MIT1003': + 
batch_size = self.__getattribute__(f"{source.lower()}_batch_size") + elif phase == "valid" and source == "MIT1003": batch_size = 8 - elif source in ('SALICON', 'MIT1003'): - batch_size = self.salicon_batch_size or len(dataset) //\ - len(self.get_dataloader(phase)) + elif source in ("SALICON", "MIT1003"): + batch_size = self.salicon_batch_size or len(dataset) // len( + self.get_dataloader(phase) + ) if batch_size > 8: batch_size -= batch_size % 2 batch_size = min(32, batch_size) else: - raise ValueError(f'Unknown dataset source {source}') + raise ValueError(f"Unknown dataset source {source}") print(f"{source}, phase {phase} batch size: {batch_size}") self._dataloaders[source][phase] = data.get_dataloader(source)( - dataset, batch_size=batch_size, - shuffle=phase == 'train', num_workers=self.num_workers, + dataset, + batch_size=batch_size, + shuffle=phase == "train", + num_workers=self.num_workers, drop_last=True, ) return self._dataloaders[source][phase] @@ -1065,7 +1133,7 @@ def measure_runtime(self): # Prepare the experiment self.model.eval() self.num_workers = 0 - dl = self.get_dataloader('test', 'DHF1K') + dl = self.get_dataloader("test", "DHF1K") sample = next(iter(dl)) _, x, _ = sample x = x.float().cuda() @@ -1076,10 +1144,9 @@ def measure_runtime(self): # Measure the average time to process single frames on the GPU times = [] for t_idx in range(1, x.shape[1]): - x_t = x[:1, t_idx:t_idx+1, ...].clone().contiguous() + x_t = x[:1, t_idx : t_idx + 1, ...].clone().contiguous() t0 = time.time() - output, h0 = self.model( - x_t, h0=h0, return_hidden=True, static=False) + output, h0 = self.model(x_t, h0=h0, return_hidden=True, static=False) dt = time.time() - t0 times.append(dt) print() @@ -1097,16 +1164,14 @@ def measure_runtime(self): times = [] torch.cuda.empty_cache() for t_idx in range(1, min(x.shape[1], 16)): - x_t = x[:1, t_idx:t_idx+1, ...].clone().contiguous() + x_t = x[:1, t_idx : t_idx + 1, ...].clone().contiguous() t0 = time.time() - output, h0 = self.model( - x_t, h0=h0, return_hidden=True, static=False) + output, h0 = self.model(x_t, h0=h0, return_hidden=True, static=False) dt = time.time() - t0 times.append(dt) print() dt = sum(times) / len(times) - print("Avg single-frame CPU time: " - f"{dt:.4f} s ({1 / dt:.1f} fps)") + print("Avg single-frame CPU time: " f"{dt:.4f} s ({1 / dt:.1f} fps)") dt = min(times) print(f"Min single-frame CPU time: {dt:.4f} s ({1 / dt:.1f} fps)") dt = max(times) @@ -1120,12 +1185,11 @@ def measure_model_size(self): net = model_cls(verbose=0, **this_model_cfg) dest = self.train_dir - torch.save(net, dest / 'net_full.pth') - file_size = (dest / 'net_full.pth').stat().st_size / 1e6 - print("All data sources net size: " - f"{file_size:.2f} MB") - torch.save(net.state_dict(), dest / 'net_full_state_dict.pth') - file_size = (dest / 'net_full_state_dict.pth').stat().st_size / 1e6 + torch.save(net, dest / "net_full.pth") + file_size = (dest / "net_full.pth").stat().st_size / 1e6 + print("All data sources net size: " f"{file_size:.2f} MB") + torch.save(net.state_dict(), dest / "net_full_state_dict.pth") + file_size = (dest / "net_full_state_dict.pth").stat().st_size / 1e6 print(f"All data sources net state dict size: {file_size:.2f} MB") def get_model_parameter_groups(self): @@ -1133,13 +1197,14 @@ def get_model_parameter_groups(self): Get parameter groups. Output CNN parameters separately with reduced LR and weight decay. 
""" + def parameters_except_cnn(): parameters = [] adaptation = [] for name, module in self.model.named_children(): - if name == 'cnn': + if name == "cnn": continue - elif 'adaptation' in name: + elif "adaptation" in name: adaptation += list(module.parameters()) else: parameters += list(module.parameters()) @@ -1148,33 +1213,42 @@ def parameters_except_cnn(): parameters, adaptation = parameters_except_cnn() for name, this_parameter in self.model.named_parameters(): - if 'gaussian' in name: + if "gaussian" in name: parameters.append(this_parameter) return [ - {'params': parameters + adaptation}, - {'params': self.model.cnn.parameters(), - 'lr': self.lr * self.cnn_lr_factor, - 'weight_decay': self.cnn_weight_decay, - }, + {"params": parameters + adaptation}, + { + "params": self.model.cnn.parameters(), + "lr": self.lr * self.cnn_lr_factor, + "weight_decay": self.cnn_weight_decay, + }, ] @property def optimizer(self): """Return the optimizer""" if self._optimizer is None: - if self.optim_algo == 'SGD': + if self.optim_algo == "SGD": self._optimizer = torch.optim.SGD( - self.get_model_parameter_groups(), lr=self.lr, - momentum=self.momentum, weight_decay=self.weight_decay) - elif self.optim_algo == 'Adam': + self.get_model_parameter_groups(), + lr=self.lr, + momentum=self.momentum, + weight_decay=self.weight_decay, + ) + elif self.optim_algo == "Adam": self._optimizer = torch.optim.Adam( - self.get_model_parameter_groups(), lr=self.lr, - weight_decay=self.weight_decay) - elif self.optim_algo == 'RMSprop': + self.get_model_parameter_groups(), + lr=self.lr, + weight_decay=self.weight_decay, + ) + elif self.optim_algo == "RMSprop": self._optimizer = torch.optim.RMSprop( - self.get_model_parameter_groups(), lr=self.lr, - weight_decay=self.weight_decay, momentum=self.momentum) + self.get_model_parameter_groups(), + lr=self.lr, + weight_decay=self.weight_decay, + momentum=self.momentum, + ) return self._optimizer @@ -1182,10 +1256,10 @@ def optimizer(self): def scheduler(self): """Return the learning rate scheduler""" if self._scheduler is None: - if self.lr_scheduler == 'ExponentialLR': + if self.lr_scheduler == "ExponentialLR": self._scheduler = torch.optim.lr_scheduler.ExponentialLR( - self.optimizer, gamma=self.lr_gamma, - last_epoch=self.epoch - 1) + self.optimizer, gamma=self.lr_gamma, last_epoch=self.epoch - 1 + ) else: raise ValueError(f"Unknown scheduler {self.lr_scheduler}") return self._scheduler @@ -1199,26 +1273,26 @@ def copy_code(self): """Make a copy of the code to facilitate leading of older models""" source = Path(inspect.getfile(Trainer)).parent.parent - destination = self.train_dir / 'code_copy' + destination = self.train_dir / "code_copy" tracked_files = [ - '.gitignore', - 'unisal/__init__.py', - 'unisal/data.py', - 'unisal/model.py', - 'unisal/models/MobileNetV2.py', - 'unisal/models/__init__.py', - 'unisal/models/cgru.py', - 'unisal/models/weights/mobilenet_v2.pth.tar', - 'unisal/salience_metrics.py', - 'unisal/train.py', - 'unisal/utils.py', - 'run.py', - 'unisal/dhf1k_n_images.dat', - 'unisal/cache/img_size_dict.json', - 'unisal/cache/train_hollywood_register.json', - 'unisal/cache/train_ucfsports_register.json', - 'unisal/cache/test_hollywood_register.json', - 'unisal/cache/test_ucfsports_register.json', + ".gitignore", + "unisal/__init__.py", + "unisal/data.py", + "unisal/model.py", + "unisal/models/MobileNetV2.py", + "unisal/models/__init__.py", + "unisal/models/cgru.py", + "unisal/models/weights/mobilenet_v2.pth.tar", + "unisal/salience_metrics.py", + "unisal/train.py", 
+ "unisal/utils.py", + "run.py", + "unisal/dhf1k_n_images.dat", + "unisal/cache/img_size_dict.json", + "unisal/cache/train_hollywood_register.json", + "unisal/cache/train_ucfsports_register.json", + "unisal/cache/test_hollywood_register.json", + "unisal/cache/test_ucfsports_register.json", ] for file in tracked_files: subdir = Path(file).parent @@ -1229,30 +1303,43 @@ def save_chkpnt(self): """Save model and trainer checkpoint""" print(f"Saving checkpoint at epoch {self.epoch}") chkpnt = { - 'model_state_dict': self.model.state_dict(), - 'optimizer_state_dict': self.optimizer.state_dict(), + "model_state_dict": self.model.state_dict(), + "optimizer_state_dict": self.optimizer.state_dict(), } - chkpnt.update({key: self.__dict__[key] for key in ( - 'epoch', 'best_epoch', 'best_val_score', 'all_scalars',)}) - torch.save(chkpnt, self.train_dir / f'chkpnt_epoch{self.epoch:04d}.pth') + chkpnt.update( + { + key: self.__dict__[key] + for key in ( + "epoch", + "best_epoch", + "best_val_score", + "all_scalars", + ) + } + ) + torch.save(chkpnt, self.train_dir / f"chkpnt_epoch{self.epoch:04d}.pth") def load_checkpoint(self, file): """Load model and trainer checkpoint""" chkpnt = torch.load(file) - self.model.load_state_dict(chkpnt['model_state_dict']) - self.optimizer.load_state_dict(chkpnt['optimizer_state_dict']) - self.__dict__.update({key: chkpnt[key] for key in ( - 'epoch', 'best_epoch', 'best_val_score', 'all_scalars')}) + self.model.load_state_dict(chkpnt["model_state_dict"]) + self.optimizer.load_state_dict(chkpnt["optimizer_state_dict"]) + self.__dict__.update( + { + key: chkpnt[key] + for key in ("epoch", "best_epoch", "best_val_score", "all_scalars") + } + ) self.epoch += 1 def load_last_chkpnt(self): """Load latest model checkpoint""" - last_chkpnt = sorted(list(self.train_dir.glob('chkpnt_epoch*.pth')))[-1] + last_chkpnt = sorted(list(self.train_dir.glob("chkpnt_epoch*.pth")))[-1] self.load_checkpoint(last_chkpnt) def add_scalar(self, key, value, epoch, this_tboard=True): """Add a scalar to self.all_scalars and TensorboardX""" - keys = key.split('/') + keys = key.split("/") this_dict = self.all_scalars for key_ in keys: if key_ not in this_dict: @@ -1267,7 +1354,7 @@ def add_scalar(self, key, value, epoch, this_tboard=True): def writer(self): """Return TensorboardX writer""" if self.tboard and self._writer is None: - if self.data_sources == ('MIT1003',): + if self.data_sources == ("MIT1003",): log_dir = self.mit1003_dir log_dir.mkdir(exist_ok=True) else: @@ -1279,34 +1366,36 @@ def writer(self): def mit1003_dir(self): """Return directory to fine tune on MIT1003""" if self.mit1003_finetuned: - mit1003_dir = self.train_dir / f"MIT1003_lr{self.lr:.4f}_" \ - f"lrGamma{self.lr_gamma:.2f}_nEpochs{self.num_epochs}_" \ - f"TrainCNNAfter{self.train_cnn_after}_" \ + mit1003_dir = ( + self.train_dir / f"MIT1003_lr{self.lr:.4f}_" + f"lrGamma{self.lr_gamma:.2f}_nEpochs{self.num_epochs}_" + f"TrainCNNAfter{self.train_cnn_after}_" f"xVal{self.salicon_cfg['x_val_step']}" + ) else: - mit1003_dir = self.train_dir / f"MIT1003_" \ - f"xVal{self.salicon_cfg['x_val_step']}" + mit1003_dir = ( + self.train_dir / f"MIT1003_" f"xVal{self.salicon_cfg['x_val_step']}" + ) mit1003_dir.mkdir(exist_ok=True) return mit1003_dir - def export_scalars(self, suffix=''): + def export_scalars(self, suffix=""): """Save self.all_scalars""" - if self.data_sources == ('MIT1003',): + if self.data_sources == ("MIT1003",): export_dir = self.mit1003_dir else: export_dir = self.train_dir - with open(export_dir / 
f'all_scalars{suffix}.json', 'w') as f:
-            json.dump(self.all_scalars, f, cls=utils.NumpyEncoder,
-                      indent=2)
+        with open(export_dir / f"all_scalars{suffix}.json", "w") as f:
+            json.dump(self.all_scalars, f, cls=utils.NumpyEncoder, indent=2)
 
     def get_configs(self):
         """Get configurations of trainer, dataset and model instances"""
         return {
-            'Trainer': self.asdict(),
-            'Dataset': self.get_dataset('train', 'DHF1K').asdict(),
-            'Model': self.model.asdict()
+            "Trainer": self.asdict(),
+            "Dataset": self.get_dataset("train", "DHF1K").asdict(),
+            "Model": self.model.asdict(),
         }
 
     @property
     def train_id(self):
-        return '/'.join(self.train_dir.parts[-2:])
+        return "/".join(self.train_dir.parts[-2:])

From 57849f2fb2f8d75a6bcdfd482d4fceeca6242ae5 Mon Sep 17 00:00:00 2001
From: ohjho
Date: Tue, 16 Jul 2024 15:20:02 -0400
Subject: [PATCH 2/6] added image preprocessing function and postprocessing
 function for saliency map

---
 unisal/data.py | 876 ++++++++++++++++++++++++++++---------------------
 1 file changed, 511 insertions(+), 365 deletions(-)

diff --git a/unisal/data.py b/unisal/data.py
index d4366f3..f91af36 100644
--- a/unisal/data.py
+++ b/unisal/data.py
@@ -1,4 +1,3 @@
-
 from pathlib import Path
 import os
 import random
@@ -7,8 +6,13 @@
 import copy
 
 import torch
-from torch.utils.data import Dataset, DataLoader, BatchSampler, RandomSampler, \
-    SequentialSampler
+from torch.utils.data import (
+    Dataset,
+    DataLoader,
+    BatchSampler,
+    RandomSampler,
+    SequentialSampler,
+)
 from torchvision import transforms
 import numpy as np
 import cv2
@@ -25,8 +29,7 @@
 if "SALICON_DATA_DIR" not in os.environ:
     os.environ["SALICON_DATA_DIR"] = str(default_data_dir / "SALICON")
 if "HOLLYWOOD_DATA_DIR" not in os.environ:
-    os.environ["HOLLYWOOD_DATA_DIR"] = str(
-        default_data_dir / "Hollywood2_actions")
+    os.environ["HOLLYWOOD_DATA_DIR"] = str(default_data_dir / "Hollywood2_actions")
 if "UCFSPORTS_DATA_DIR" not in os.environ:
     os.environ["UCFSPORTS_DATA_DIR"] = str(default_data_dir / "ucf-002")
 if "MIT300_DATA_DIR" not in os.environ:
@@ -40,77 +43,96 @@ def get_dataset():
     return DHF1KDataset
 
 
-def get_dataloader(src='DHF1K'):
-    if src in ('MIT1003',):
+def get_dataloader(src="DHF1K"):
+    if src in ("MIT1003",):
         return ImgSizeDataLoader
     return DataLoader
 
 
 class SALICONDataset(Dataset, utils.KwConfigClass):
 
-    source = 'SALICON'
+    source = "SALICON"
    dynamic = False
 
-    def __init__(self, phase='train', subset=None, verbose=1,
-                 out_size=(288, 384), target_size=(480, 640),
-                 preproc_cfg=None):
+    def __init__(
+        self,
+        phase="train",
+        subset=None,
+        verbose=1,
+        out_size=(288, 384),
+        target_size=(480, 640),
+        preproc_cfg=None,
+    ):
         self.phase = phase
-        self.train = phase == 'train'
+        self.train = phase == "train"
         self.subset = subset
         self.verbose = verbose
         self.out_size = out_size
         self.target_size = target_size
         self.preproc_cfg = {
-            'rgb_mean': (0.485, 0.456, 0.406),
-            'rgb_std': (0.229, 0.224, 0.225),
+            "rgb_mean": (0.485, 0.456, 0.406),
+            "rgb_std": (0.229, 0.224, 0.225),
         }
         if preproc_cfg is not None:
             self.preproc_cfg.update(preproc_cfg)
-        self.phase_str = 'val' if phase in ('valid', 'eval') else phase
+        self.phase_str = "val" if phase in ("valid", "eval") else phase
         self.file_stem = f"COCO_{self.phase_str}2014_"
         self.file_nr = "{:012d}"
 
         self.samples = self.prepare_samples()
         if self.subset is not None:
-            self.samples = self.samples[:int(len(self.samples) * subset)]
+            self.samples = self.samples[: int(len(self.samples) * subset)]
 
         # For compatibility with video datasets
         self.n_images_dict = {img_nr: 1 for img_nr in 
self.samples} - self.target_size_dict = { - img_nr: self.target_size for img_nr in self.samples} + self.target_size_dict = {img_nr: self.target_size for img_nr in self.samples} self.n_samples = len(self.samples) self.frame_modulo = 1 def get_map(self, img_nr): - map_file = self.dir / 'maps' / self.phase_str / ( - self.file_stem + self.file_nr.format(img_nr) + '.png') + map_file = ( + self.dir + / "maps" + / self.phase_str + / (self.file_stem + self.file_nr.format(img_nr) + ".png") + ) map = cv2.imread(str(map_file), cv2.IMREAD_GRAYSCALE) - assert(map is not None) + assert map is not None return map def get_img(self, img_nr): - img_file = self.dir / 'images' / ( - self.file_stem + self.file_nr.format(img_nr) + '.jpg') + img_file = ( + self.dir + / "images" + / (self.file_stem + self.file_nr.format(img_nr) + ".jpg") + ) img = cv2.imread(str(img_file)) - assert(img is not None) + assert img is not None return np.ascontiguousarray(img[:, :, ::-1]) def get_raw_fixations(self, img_nr): - raw_fix_file = self.dir / 'fixations' / self.phase_str / ( - self.file_stem + self.file_nr.format(img_nr) + '.mat') + raw_fix_file = ( + self.dir + / "fixations" + / self.phase_str + / (self.file_stem + self.file_nr.format(img_nr) + ".mat") + ) fix_data = scipy.io.loadmat(raw_fix_file) - fixations_array = [gaze[2] for gaze in fix_data['gaze'][:, 0]] - return fixations_array, fix_data['resolution'].tolist()[0] + fixations_array = [gaze[2] for gaze in fix_data["gaze"][:, 0]] + return fixations_array, fix_data["resolution"].tolist()[0] def process_raw_fixations(self, fixations_array, res): fix_map = np.zeros(res, dtype=np.uint8) for subject_fixations in fixations_array: - fix_map[subject_fixations[:, 1] - 1, subject_fixations[:, 0] - 1]\ - = 255 + fix_map[subject_fixations[:, 1] - 1, subject_fixations[:, 0] - 1] = 255 return fix_map def get_fixation_map(self, img_nr): - fix_map_file = self.dir / 'fixations' / self.phase_str / ( - self.file_stem + self.file_nr.format(img_nr) + '.png') + fix_map_file = ( + self.dir + / "fixations" + / self.phase_str + / (self.file_stem + self.file_nr.format(img_nr) + ".png") + ) if fix_map_file.exists(): fix_map = cv2.imread(str(fix_map_file), cv2.IMREAD_GRAYSCALE) else: @@ -125,30 +147,32 @@ def dir(self): def prepare_samples(self): samples = [] - for file in (self.dir / 'images').glob(self.file_stem + '*.jpg'): + for file in (self.dir / "images").glob(self.file_stem + "*.jpg"): samples.append(int(file.stem[-12:])) return sorted(samples) def __len__(self): return len(self.samples) - def preprocess(self, img, data='img'): + def preprocess(self, img, data="img"): transformations = [ transforms.ToPILImage(), ] - if data == 'img': - transformations.append(transforms.Resize( - self.out_size, interpolation=PIL.Image.LANCZOS)) + if data == "img": + transformations.append( + transforms.Resize(self.out_size, interpolation=PIL.Image.LANCZOS) + ) transformations.append(transforms.ToTensor()) - if data == 'img' and 'rgb_mean' in self.preproc_cfg: + if data == "img" and "rgb_mean" in self.preproc_cfg: transformations.append( transforms.Normalize( - self.preproc_cfg['rgb_mean'], self.preproc_cfg['rgb_std'])) - elif data == 'sal': + self.preproc_cfg["rgb_mean"], self.preproc_cfg["rgb_std"] + ) + ) + elif data == "sal": transformations.append(transforms.Lambda(utils.normalize_tensor)) - elif data == 'fix': - transformations.append( - transforms.Lambda(lambda fix: torch.gt(fix, 0.5))) + elif data == "fix": + transformations.append(transforms.Lambda(lambda fix: torch.gt(fix, 0.5))) processing = 
transforms.Compose(transformations) tensor = processing(img) @@ -156,14 +180,14 @@ def preprocess(self, img, data='img'): def get_data(self, img_nr): img = self.get_img(img_nr) - img = self.preprocess(img, data='img') - if self.phase == 'test': + img = self.preprocess(img, data="img") + if self.phase == "test": return [1], img, self.target_size sal = self.get_map(img_nr) - sal = self.preprocess(sal, data='sal') + sal = self.preprocess(sal, data="sal") fix = self.get_fixation_map(img_nr) - fix = self.preprocess(fix, data='fix') + fix = self.preprocess(fix, data="fix") return [1], img, sal, fix, self.target_size @@ -175,20 +199,20 @@ def __getitem__(self, item): class ImgSizeBatchSampler: def __init__(self, dataset, batch_size=1, shuffle=False, drop_last=False): - assert(isinstance(dataset, MIT1003Dataset)) + assert isinstance(dataset, MIT1003Dataset) self.batch_size = batch_size self.shuffle = shuffle self.drop_last = drop_last out_size_array = [ - dataset.size_dict[img_idx]['out_size'] - for img_idx in dataset.samples] + dataset.size_dict[img_idx]["out_size"] for img_idx in dataset.samples + ] self.out_size_set = sorted(list(set(out_size_array))) - self.sample_idx_dict = { - out_size: [] for out_size in self.out_size_set} + self.sample_idx_dict = {out_size: [] for out_size in self.out_size_set} for sample_idx, img_idx in enumerate(dataset.samples): - self.sample_idx_dict[dataset.size_dict[img_idx]['out_size']].append( - sample_idx) + self.sample_idx_dict[dataset.size_dict[img_idx]["out_size"]].append( + sample_idx + ) self.len = 0 self.n_batches_dict = {} @@ -198,9 +222,12 @@ def __init__(self, dataset, batch_size=1, shuffle=False, drop_last=False): self.n_batches_dict[out_size] = this_n_batches def __iter__(self): - batch_array = list(itertools.chain.from_iterable( - [out_size for _ in range(n_batches)] - for out_size, n_batches in self.n_batches_dict.items())) + batch_array = list( + itertools.chain.from_iterable( + [out_size for _ in range(n_batches)] + for out_size, n_batches in self.n_batches_dict.items() + ) + ) if not self.shuffle: random.seed(27) random.shuffle(batch_array) @@ -209,8 +236,8 @@ def __iter__(self): for sample_idx_array in this_sample_idx_dict.values(): random.shuffle(sample_idx_array) for out_size in batch_array: - this_indices = this_sample_idx_dict[out_size][:self.batch_size] - del this_sample_idx_dict[out_size][:self.batch_size] + this_indices = this_sample_idx_dict[out_size][: self.batch_size] + del this_sample_idx_dict[out_size][: self.batch_size] yield this_indices def __len__(self): @@ -219,8 +246,7 @@ def __len__(self): class ImgSizeDataLoader(DataLoader): - def __init__(self, dataset, batch_size=1, shuffle=False, drop_last=False, - **kwargs): + def __init__(self, dataset, batch_size=1, shuffle=False, drop_last=False, **kwargs): if batch_size == 1: if shuffle: sampler = RandomSampler(dataset) @@ -229,32 +255,34 @@ def __init__(self, dataset, batch_size=1, shuffle=False, drop_last=False, batch_sampler = BatchSampler(sampler, batch_size, drop_last) else: batch_sampler = ImgSizeBatchSampler( - dataset, batch_size=batch_size, shuffle=shuffle, - drop_last=drop_last) + dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last + ) super().__init__(dataset, batch_sampler=batch_sampler, **kwargs) class MIT300Dataset(Dataset, utils.KwConfigClass): - source = 'MIT300' + source = "MIT300" dynamic = False - def __init__(self, phase='test'): - assert(phase == 'test') + def __init__(self, phase="test"): + assert phase == "test" self.phase = phase self.train = 
False self.target_size = None self.preproc_cfg = { - 'rgb_mean': (0.485, 0.456, 0.406), - 'rgb_std': (0.229, 0.224, 0.225), + "rgb_mean": (0.485, 0.456, 0.406), + "rgb_std": (0.229, 0.224, 0.225), } self.samples, self.target_size_dict = self.load_data() def load_data(self): samples = [] target_size_dict = {} - file_list = list(self.dir.glob('*.jpg')) - file_list = sorted(file_list, key=lambda x: int(x.stem[1:min(4, len(x.stem))])) + file_list = list(self.dir.glob("*.jpg")) + file_list = sorted( + file_list, key=lambda x: int(x.stem[1 : min(4, len(x.stem))]) + ) for img_idx, file in enumerate(file_list): img = cv2.imread(str(file)) @@ -283,24 +311,25 @@ def load_data(self): @property def dir(self): - return Path(os.environ["MIT300_DATA_DIR"]) / 'BenchmarkIMAGES' + return Path(os.environ["MIT300_DATA_DIR"]) / "BenchmarkIMAGES" def __len__(self): return len(self.samples) - def preprocess(self, img, out_size, data='img'): - assert(data == 'img') + def preprocess(self, img, out_size, data="img"): + assert data == "img" transformations = [ transforms.ToPILImage(), - transforms.Resize( - out_size, interpolation=PIL.Image.LANCZOS), + transforms.Resize(out_size, interpolation=PIL.Image.LANCZOS), transforms.ToTensor(), ] - if 'rgb_mean' in self.preproc_cfg: + if "rgb_mean" in self.preproc_cfg: transformations.append( transforms.Normalize( - self.preproc_cfg['rgb_mean'], self.preproc_cfg['rgb_std'])) + self.preproc_cfg["rgb_mean"], self.preproc_cfg["rgb_std"] + ) + ) processing = transforms.Compose(transformations) tensor = processing(img) @@ -310,9 +339,9 @@ def get_data(self, item): img_name, out_size = self.samples[item] img_file = self.dir / img_name img = cv2.imread(str(img_file)) - assert(img is not None) + assert img is not None img = np.ascontiguousarray(img[:, :, ::-1]) - img = self.preprocess(img, out_size, data='img') + img = self.preprocess(img, out_size, data="img") return [1], img, self.target_size_dict[item] def __getitem__(self, item): @@ -321,19 +350,27 @@ def __getitem__(self, item): class MIT1003Dataset(Dataset, utils.KwConfigClass): - source = 'MIT1003' + source = "MIT1003" n_train_val_images = 1003 dynamic = False - def __init__(self, phase='train', subset=None, verbose=1, - preproc_cfg=None, n_x_val=10, x_val_step=0, x_val_seed=27): + def __init__( + self, + phase="train", + subset=None, + verbose=1, + preproc_cfg=None, + n_x_val=10, + x_val_step=0, + x_val_seed=27, + ): self.phase = phase - self.train = phase == 'train' + self.train = phase == "train" self.subset = subset self.verbose = verbose self.preproc_cfg = { - 'rgb_mean': (0.485, 0.456, 0.406), - 'rgb_std': (0.229, 0.224, 0.225), + "rgb_mean": (0.485, 0.456, 0.406), + "rgb_std": (0.229, 0.224, 0.225), } if preproc_cfg is not None: self.preproc_cfg.update(preproc_cfg) @@ -348,7 +385,7 @@ def __init__(self, phase='train', subset=None, verbose=1, self.samples = np.arange(0, n_images) else: print(f"X-Val step: {x_val_step}") - assert(self.x_val_step < self.n_x_val) + assert self.x_val_step < self.n_x_val samples = np.arange(0, n_images) if self.x_val_seed > 0: np.random.seed(self.x_val_seed) @@ -364,31 +401,31 @@ def __init__(self, phase='train', subset=None, verbose=1, self.all_image_files, self.size_dict = self.load_data() if self.subset is not None: - self.samples = self.samples[:int(len(self.samples) * subset)] + self.samples = self.samples[: int(len(self.samples) * subset)] # For compatibility with video datasets self.n_images_dict = {sample: 1 for sample in self.samples} self.target_size_dict = { - img_idx: 
self.size_dict[img_idx]['target_size'] - for img_idx in self.samples} + img_idx: self.size_dict[img_idx]["target_size"] for img_idx in self.samples + } self.n_samples = len(self.samples) self.frame_modulo = 1 def get_map(self, img_idx): - map_file = self.fix_dir / self.all_image_files[img_idx]['map'] + map_file = self.fix_dir / self.all_image_files[img_idx]["map"] map = cv2.imread(str(map_file), cv2.IMREAD_GRAYSCALE) - assert(map is not None) + assert map is not None return map def get_img(self, img_idx): - img_file = self.img_dir / self.all_image_files[img_idx]['img'] + img_file = self.img_dir / self.all_image_files[img_idx]["img"] img = cv2.imread(str(img_file)) - assert(img is not None) + assert img is not None return np.ascontiguousarray(img[:, :, ::-1]) def get_fixation_map(self, img_idx): - fix_map_file = self.fix_dir / self.all_image_files[img_idx]['pts'] + fix_map_file = self.fix_dir / self.all_image_files[img_idx]["pts"] fix_map = cv2.imread(str(fix_map_file), cv2.IMREAD_GRAYSCALE) - assert(fix_map is not None) + assert fix_map is not None return fix_map @property @@ -397,11 +434,11 @@ def dir(self): @property def fix_dir(self): - return self.dir / 'ALLFIXATIONMAPS' / 'ALLFIXATIONMAPS' + return self.dir / "ALLFIXATIONMAPS" / "ALLFIXATIONMAPS" @property def img_dir(self): - return self.dir / 'ALLSTIMULI' / 'ALLSTIMULI' + return self.dir / "ALLSTIMULI" / "ALLSTIMULI" def get_out_size_eval(self, img_size): ar = img_size[0] / img_size[1] @@ -445,69 +482,73 @@ def load_data(self): all_image_files = [] for img_file in sorted(self.img_dir.glob("*.jpeg")): - all_image_files.append({ - 'img': img_file.name, - 'map': img_file.stem + "_fixMap.jpg", - 'pts': img_file.stem + "_fixPts.jpg", - }) - assert((self.fix_dir / all_image_files[-1]['map']).exists()) - assert((self.fix_dir / all_image_files[-1]['pts']).exists()) + all_image_files.append( + { + "img": img_file.name, + "map": img_file.stem + "_fixMap.jpg", + "pts": img_file.stem + "_fixPts.jpg", + } + ) + assert (self.fix_dir / all_image_files[-1]["map"]).exists() + assert (self.fix_dir / all_image_files[-1]["pts"]).exists() size_dict_file = config_path / "img_size_dict.json" if size_dict_file.exists(): - with open(size_dict_file, 'r') as f: + with open(size_dict_file, "r") as f: size_dict = json.load(f) - size_dict = {int(img_idx): val for - img_idx, val in size_dict.items()} + size_dict = {int(img_idx): val for img_idx, val in size_dict.items()} else: size_dict = {} for img_idx in range(self.n_train_val_images): - img = cv2.imread( - str(self.img_dir / all_image_files[img_idx]['img'])) - size_dict[img_idx] = {'img_size': img.shape[:2]} - with open(size_dict_file, 'w') as f: + img = cv2.imread(str(self.img_dir / all_image_files[img_idx]["img"])) + size_dict[img_idx] = {"img_size": img.shape[:2]} + with open(size_dict_file, "w") as f: json.dump(size_dict, f) for img_idx in self.samples: - img_size = size_dict[img_idx]['img_size'] - if self.phase in ('train', 'valid'): + img_size = size_dict[img_idx]["img_size"] + if self.phase in ("train", "valid"): out_size = self.get_out_size_train(img_size) else: out_size = self.get_out_size_eval(img_size) - if self.phase in ('train', 'valid'): + if self.phase in ("train", "valid"): target_size = tuple(sz * 2 for sz in out_size) else: target_size = img_size - size_dict[img_idx].update({ - 'out_size': out_size, 'target_size': target_size}) + size_dict[img_idx].update( + {"out_size": out_size, "target_size": target_size} + ) return all_image_files, size_dict def __len__(self): return len(self.samples) - 
def preprocess(self, img, out_size=None, data='img'): + def preprocess(self, img, out_size=None, data="img"): transformations = [ transforms.ToPILImage(), ] - if data in ('img', 'sal'): - transformations.append(transforms.Resize( - out_size, interpolation=PIL.Image.LANCZOS)) + if data in ("img", "sal"): + transformations.append( + transforms.Resize(out_size, interpolation=PIL.Image.LANCZOS) + ) else: - transformations.append(transforms.Resize( - out_size, interpolation=PIL.Image.NEAREST)) + transformations.append( + transforms.Resize(out_size, interpolation=PIL.Image.NEAREST) + ) transformations.append(transforms.ToTensor()) - if data == 'img' and 'rgb_mean' in self.preproc_cfg: + if data == "img" and "rgb_mean" in self.preproc_cfg: transformations.append( transforms.Normalize( - self.preproc_cfg['rgb_mean'], self.preproc_cfg['rgb_std'])) - elif data == 'sal': + self.preproc_cfg["rgb_mean"], self.preproc_cfg["rgb_std"] + ) + ) + elif data == "sal": transformations.append(transforms.Lambda(utils.normalize_tensor)) - elif data == 'fix': - transformations.append( - transforms.Lambda(lambda fix: torch.gt(fix, 0.5))) + elif data == "fix": + transformations.append(transforms.Lambda(lambda fix: torch.gt(fix, 0.5))) processing = transforms.Compose(transformations) tensor = processing(img) @@ -515,16 +556,16 @@ def preprocess(self, img, out_size=None, data='img'): def get_data(self, img_idx): img = self.get_img(img_idx) - out_size = self.size_dict[img_idx]['out_size'] + out_size = self.size_dict[img_idx]["out_size"] target_size = self.target_size_dict[img_idx] - img = self.preprocess(img, out_size=out_size, data='img') - if self.phase == 'test': + img = self.preprocess(img, out_size=out_size, data="img") + if self.phase == "test": return [1], img, target_size sal = self.get_map(img_idx) - sal = self.preprocess(sal, target_size, data='sal') + sal = self.preprocess(sal, target_size, data="sal") fix = self.get_fixation_map(img_idx) - fix = self.preprocess(fix, target_size, data='fix') + fix = self.preprocess(fix, target_size, data="fix") return [1], img, sal, fix, target_size @@ -539,29 +580,42 @@ class DHF1KDataset(Dataset, utils.KwConfigClass): n_train_val_videos = 700 test_vid_nrs = (701, 1000) frame_rate = 30 - source = 'DHF1K' + source = "DHF1K" dynamic = True - def __init__(self, - seq_len=12, - frame_modulo=5, - max_seq_len=1e6, - preproc_cfg=None, - out_size=(224, 384), phase='train', target_size=(360, 640), - debug=False, val_size=100, n_x_val=3, x_val_step=2, - x_val_seed=0, seq_per_vid=1, subset=None, verbose=1, - n_images_file='dhf1k_n_images.dat', seq_per_vid_val=2, - sal_offset=None): + def __init__( + self, + seq_len=12, + frame_modulo=5, + max_seq_len=1e6, + preproc_cfg=None, + out_size=(224, 384), + phase="train", + target_size=(360, 640), + debug=False, + val_size=100, + n_x_val=3, + x_val_step=2, + x_val_seed=0, + seq_per_vid=1, + subset=None, + verbose=1, + n_images_file="dhf1k_n_images.dat", + seq_per_vid_val=2, + sal_offset=None, + ): self.phase = phase - self.train = phase == 'train' + self.train = phase == "train" if not self.train: preproc_cfg = {} elif preproc_cfg is None: preproc_cfg = {} - preproc_cfg.update({ - 'rgb_mean': (0.485, 0.456, 0.406), - 'rgb_std': (0.229, 0.224, 0.225), - }) + preproc_cfg.update( + { + "rgb_mean": (0.485, 0.456, 0.406), + "rgb_std": (0.229, 0.224, 0.225), + } + ) self.preproc_cfg = preproc_cfg self.out_size = out_size self.debug = debug @@ -585,46 +639,50 @@ def __init__(self, self.vid_nr_array = None # Evaluation - if phase in ('eval', 'test'): 
+ if phase in ("eval", "test"): self.seq_len = int(1e6) - if self.phase in ('test',): - self.vid_nr_array = list(range( - self.test_vid_nrs[0], self.test_vid_nrs[1] + 1)) + if self.phase in ("test",): + self.vid_nr_array = list( + range(self.test_vid_nrs[0], self.test_vid_nrs[1] + 1) + ) self.samples, self.target_size_dict = self.prepare_samples() return # Cross-validation split n_videos = self.n_train_val_videos - assert(self.val_size <= n_videos // self.n_x_val) - assert(self.x_val_step < self.n_x_val) + assert self.val_size <= n_videos // self.n_x_val + assert self.x_val_step < self.n_x_val vid_nr_array = np.arange(1, n_videos + 1) if self.x_val_seed > 0: np.random.seed(self.x_val_seed) np.random.shuffle(vid_nr_array) - val_start = (len(vid_nr_array) - self.val_size) //\ - (self.n_x_val - 1) * self.x_val_step + val_start = ( + (len(vid_nr_array) - self.val_size) // (self.n_x_val - 1) * self.x_val_step + ) vid_nr_array = vid_nr_array.tolist() if not self.train: - self.vid_nr_array =\ - vid_nr_array[val_start:val_start + self.val_size] + self.vid_nr_array = vid_nr_array[val_start : val_start + self.val_size] else: - del vid_nr_array[val_start:val_start + self.val_size] + del vid_nr_array[val_start : val_start + self.val_size] self.vid_nr_array = vid_nr_array if self.subset is not None: - self.vid_nr_array =\ - self.vid_nr_array[:int(len(self.vid_nr_array) * self.subset)] + self.vid_nr_array = self.vid_nr_array[ + : int(len(self.vid_nr_array) * self.subset) + ] self.samples, self.target_size_dict = self.prepare_samples() @property def n_images_dict(self): if self._n_images_dict is None: - with open(config_path.parent / self.n_images_file, 'r') as f: + with open(config_path.parent / self.n_images_file, "r") as f: self._n_images_dict = { - idx + 1: int(line) for idx, line in enumerate(f) - if idx + 1 in self.vid_nr_array} + idx + 1: int(line) + for idx, line in enumerate(f) + if idx + 1 in self.vid_nr_array + } return self._n_images_dict @property @@ -645,9 +703,8 @@ def prepare_samples(self): too_short = 0 too_long = 0 for vid_nr, n_images in self.n_images_dict.items(): - if self.phase in ('eval', 'test'): - samples += [ - (vid_nr, offset + 1) for offset in range(self.frame_modulo)] + if self.phase in ("eval", "test"): + samples += [(vid_nr, offset + 1) for offset in range(self.frame_modulo)] continue if n_images < self.clip_len: too_short += 1 @@ -655,51 +712,59 @@ def prepare_samples(self): if n_images // self.frame_modulo > self.max_seq_len: too_long += 1 continue - if self.phase == 'train': + if self.phase == "train": samples += [(vid_nr, None)] * self.seq_per_vid continue - elif self.phase == 'valid': + elif self.phase == "valid": x = n_images // (self.seq_per_vid_val * 2) - self.clip_len // 2 start = max(1, x) end = min(n_images - self.clip_len, n_images - x) samples += [ - (vid_nr, int(start)) for start in - np.linspace(start, end, self.seq_per_vid_val)] + (vid_nr, int(start)) + for start in np.linspace(start, end, self.seq_per_vid_val) + ] continue - if self.phase not in ('eval', 'test') and self.n_images_dict: + if self.phase not in ("eval", "test") and self.n_images_dict: n_loaded = len(self.n_images_dict) - too_short - too_long - print(f"{n_loaded} videos loaded " - f"({n_loaded / len(self.n_images_dict) * 100:.1f}%)") - print(f"{too_short} videos are too short " - f"({too_short / len(self.n_images_dict) * 100:.1f}%)") - print(f"{too_long} videos are too long " - f"({too_long / len(self.n_images_dict) * 100:.1f}%)") + print( + f"{n_loaded} videos loaded " + f"({n_loaded / 
len(self.n_images_dict) * 100:.1f}%)" + ) + print( + f"{too_short} videos are too short " + f"({too_short / len(self.n_images_dict) * 100:.1f}%)" + ) + print( + f"{too_long} videos are too long " + f"({too_long / len(self.n_images_dict) * 100:.1f}%)" + ) target_size_dict = { - vid_nr: self.target_size for vid_nr in self.n_images_dict.keys()} + vid_nr: self.target_size for vid_nr in self.n_images_dict.keys() + } return samples, target_size_dict def get_frame_nrs(self, vid_nr, start): n_images = self.n_images_dict[vid_nr] - if self.phase in ('eval', 'test'): + if self.phase in ("eval", "test"): return list(range(start, n_images + 1, self.frame_modulo)) return list(range(start, start + self.clip_len, self.frame_modulo)) def get_annotation_dir(self, vid_nr): - return self.dir / 'annotation' / f'{vid_nr:04d}' + return self.dir / "annotation" / f"{vid_nr:04d}" def get_data_file(self, vid_nr, f_nr, dkey): - if dkey == 'frame': - folder = 'images' - elif dkey == 'sal': - folder = 'maps' - elif dkey == 'fix': - folder = 'fixation' + if dkey == "frame": + folder = "images" + elif dkey == "sal": + folder = "maps" + elif dkey == "fix": + folder = "fixation" else: - raise ValueError(f'Unknown data key {dkey}') - return self.get_annotation_dir(vid_nr) / folder / f'{f_nr:04d}.png' + raise ValueError(f"Unknown data key {dkey}") + return self.get_annotation_dir(vid_nr) / folder / f"{f_nr:04d}.png" def load_data(self, vid_nr, f_nr, dkey): - read_flag = None if dkey == 'frame' else cv2.IMREAD_GRAYSCALE + read_flag = None if dkey == "frame" else cv2.IMREAD_GRAYSCALE data_file = self.get_data_file(vid_nr, f_nr, dkey) if read_flag is not None: data = cv2.imread(str(data_file), read_flag) @@ -707,10 +772,10 @@ def load_data(self, vid_nr, f_nr, dkey): data = cv2.imread(str(data_file)) if data is None: raise FileNotFoundError(data_file) - if dkey == 'frame': + if dkey == "frame": data = np.ascontiguousarray(data[:, :, ::-1]) - if dkey == 'sal' and self.train and self.sal_offset is not None: + if dkey == "sal" and self.train and self.sal_offset is not None: data += self.sal_offset data[0, 0] = 0 @@ -718,20 +783,22 @@ def load_data(self, vid_nr, f_nr, dkey): def preprocess_sequence(self, frame_seq, dkey, vid_nr): transformations = [] - if dkey == 'frame': + if dkey == "frame": transformations.append(transforms.ToPILImage()) - transformations.append(transforms.Resize( - self.out_size, interpolation=PIL.Image.LANCZOS)) + transformations.append( + transforms.Resize(self.out_size, interpolation=PIL.Image.LANCZOS) + ) transformations.append(transforms.ToTensor()) - if dkey == 'frame' and 'rgb_mean' in self.preproc_cfg: + if dkey == "frame" and "rgb_mean" in self.preproc_cfg: transformations.append( transforms.Normalize( - self.preproc_cfg['rgb_mean'], self.preproc_cfg['rgb_std'])) - elif dkey == 'sal': + self.preproc_cfg["rgb_mean"], self.preproc_cfg["rgb_std"] + ) + ) + elif dkey == "sal": transformations.append(transforms.Lambda(utils.normalize_tensor)) - elif dkey == 'fix': - transformations.append( - transforms.Lambda(lambda fix: torch.gt(fix, 0.5))) + elif dkey == "fix": + transformations.append(transforms.Lambda(lambda fix: torch.gt(fix, 0.5))) processing = transforms.Compose(transformations) @@ -751,12 +818,12 @@ def get_data(self, vid_nr, start): else: start = np.random.randint(1, max_start) frame_nrs = self.get_frame_nrs(vid_nr, start) - frame_seq = self.get_seq(vid_nr, frame_nrs, 'frame') + frame_seq = self.get_seq(vid_nr, frame_nrs, "frame") target_size = self.target_size_dict[vid_nr] - if self.phase == 'test' 
and self.source in ('DHF1K',): + if self.phase == "test" and self.source in ("DHF1K",): return frame_nrs, frame_seq, target_size - sal_seq = self.get_seq(vid_nr, frame_nrs, 'sal') - fix_seq = self.get_seq(vid_nr, frame_nrs, 'fix') + sal_seq = self.get_seq(vid_nr, frame_nrs, "sal") + fix_seq = self.get_seq(vid_nr, frame_nrs, "fix") return frame_nrs, frame_seq, sal_seq, fix_seq, target_size def __getitem__(self, item): @@ -767,115 +834,124 @@ def __getitem__(self, item): class HollywoodDataset(DHF1KDataset): - source = 'Hollywood' + source = "Hollywood" dynamic = True img_channels = 1 - n_videos = { - 'train': 747, - 'test': 884 - } + n_videos = {"train": 747, "test": 884} test_vid_nrs = (1, 884) frame_rate = 24 - def __init__(self, out_size=(224, 416), val_size=75, n_images_file=None, - seq_per_vid_val=1, register_file='hollywood_register.json', - phase='train', - frame_modulo=4, - seq_len=12, - **kwargs): + def __init__( + self, + out_size=(224, 416), + val_size=75, + n_images_file=None, + seq_per_vid_val=1, + register_file="hollywood_register.json", + phase="train", + frame_modulo=4, + seq_len=12, + **kwargs, + ): self.register = None - self.phase_str = 'test' if phase in ('eval', 'test') else 'train' + self.phase_str = "test" if phase in ("eval", "test") else "train" self.register_file = self.phase_str + "_" + register_file - super().__init__(out_size=out_size, val_size=val_size, - n_images_file=n_images_file, - seq_per_vid_val=seq_per_vid_val, - x_val_seed=42, phase=phase, target_size=out_size, - frame_modulo=frame_modulo, - seq_len=seq_len, - **kwargs) - if phase in ('eval', 'test'): - self.target_size_dict = self.get_register()['vid_size_dict'] + super().__init__( + out_size=out_size, + val_size=val_size, + n_images_file=n_images_file, + seq_per_vid_val=seq_per_vid_val, + x_val_seed=42, + phase=phase, + target_size=out_size, + frame_modulo=frame_modulo, + seq_len=seq_len, + **kwargs, + ) + if phase in ("eval", "test"): + self.target_size_dict = self.get_register()["vid_size_dict"] @property def n_images_dict(self): if self._n_images_dict is None: - self._n_images_dict = self.get_register()['n_images_dict'] - self._n_images_dict = {vid_nr: ni for vid_nr, ni - in self._n_images_dict.items() - if vid_nr // 100 in self.vid_nr_array} + self._n_images_dict = self.get_register()["n_images_dict"] + self._n_images_dict = { + vid_nr: ni + for vid_nr, ni in self._n_images_dict.items() + if vid_nr // 100 in self.vid_nr_array + } return self._n_images_dict def get_register(self): if self.register is None: register_file = config_path / self.register_file if register_file.exists(): - with open(config_path / register_file, 'r') as f: + with open(config_path / register_file, "r") as f: self.register = json.load(f) - for reg_key in ('n_images_dict', 'start_image_dict', - 'vid_size_dict'): + for reg_key in ("n_images_dict", "start_image_dict", "vid_size_dict"): self.register[reg_key] = { - int(key): val for key, val in - self.register[reg_key].items()} + int(key): val for key, val in self.register[reg_key].items() + } else: self.register = self.generate_register() - with open(config_path / register_file, 'w') as f: + with open(config_path / register_file, "w") as f: json.dump(self.register, f, indent=2) return self.register def generate_register(self): - n_shots = { - vid_nr: 0 for vid_nr in range(1, self.n_videos[self.phase_str] + 1)} + n_shots = {vid_nr: 0 for vid_nr in range(1, self.n_videos[self.phase_str] + 1)} n_images_dict = {} start_image_dict = {} vid_size_dict = {} - for folder in 
sorted(self.dir.glob('actionclip*')): + for folder in sorted(self.dir.glob("actionclip*")): name = folder.stem vid_nr_start = 10 + len(self.phase_str) - vid_nr = int(name[vid_nr_start:vid_nr_start + 5]) + vid_nr = int(name[vid_nr_start : vid_nr_start + 5]) shot_nr = int(name[-2:].replace("_", "")) n_shots[vid_nr] += 1 vid_nr_shot_nr = 100 * vid_nr + shot_nr - image_files = sorted((folder / 'images').glob('actionclip*.png')) + image_files = sorted((folder / "images").glob("actionclip*.png")) n_images_dict[vid_nr_shot_nr] = len(image_files) start_image_dict[vid_nr_shot_nr] = int(image_files[0].stem[-5:]) img = cv2.imread(str(image_files[0])) vid_size_dict[vid_nr_shot_nr] = tuple(img.shape[:2]) return dict( - n_shots=n_shots, n_images_dict=n_images_dict, - start_image_dict=start_image_dict, vid_size_dict=vid_size_dict) + n_shots=n_shots, + n_images_dict=n_images_dict, + start_image_dict=start_image_dict, + vid_size_dict=vid_size_dict, + ) def preprocess_sequence(self, frame_seq, dkey, vid_nr): - transformations = [ - transforms.ToPILImage() - ] + transformations = [transforms.ToPILImage()] - vid_size = self.register['vid_size_dict'][vid_nr] + vid_size = self.register["vid_size_dict"][vid_nr] if vid_size[0] != self.out_size[0]: - interpolation = PIL.Image.LANCZOS if dkey in ('frame', 'sal')\ - else PIL.Image.NEAREST - size = (self.out_size[0], - int(vid_size[1] * self.out_size[0] / vid_size[0])) - transformations.append( - transforms.Resize(size, interpolation=interpolation)) + interpolation = ( + PIL.Image.LANCZOS if dkey in ("frame", "sal") else PIL.Image.NEAREST + ) + size = (self.out_size[0], int(vid_size[1] * self.out_size[0] / vid_size[0])) + transformations.append(transforms.Resize(size, interpolation=interpolation)) transformations += [ transforms.CenterCrop(self.out_size), transforms.ToTensor(), ] - if dkey == 'frame' and 'rgb_mean' in self.preproc_cfg: + if dkey == "frame" and "rgb_mean" in self.preproc_cfg: transformations.append( transforms.Normalize( - self.preproc_cfg['rgb_mean'], self.preproc_cfg['rgb_std'])) - elif dkey == 'sal': + self.preproc_cfg["rgb_mean"], self.preproc_cfg["rgb_std"] + ) + ) + elif dkey == "sal": transformations.append(transforms.Lambda(utils.normalize_tensor)) - elif dkey == 'fix': - transformations.append( - transforms.Lambda(lambda fix: torch.gt(fix, 0.5))) + elif dkey == "fix": + transformations.append(transforms.Lambda(lambda fix: torch.gt(fix, 0.5))) processing = transforms.Compose(transformations) @@ -886,22 +962,23 @@ def preprocess_sequence(self, frame_seq, dkey, vid_nr): def preprocess_sequence_eval(self, frame_seq, dkey, vid_nr): transformations = [] - if dkey == 'frame': + if dkey == "frame": transformations.append(transforms.ToPILImage()) transformations.append( - transforms.Resize( - self.out_size, interpolation=PIL.Image.LANCZOS)) + transforms.Resize(self.out_size, interpolation=PIL.Image.LANCZOS) + ) transformations.append(transforms.ToTensor()) - if dkey == 'frame' and 'rgb_mean' in self.preproc_cfg: + if dkey == "frame" and "rgb_mean" in self.preproc_cfg: transformations.append( transforms.Normalize( - self.preproc_cfg['rgb_mean'], self.preproc_cfg['rgb_std'])) - elif dkey == 'sal': + self.preproc_cfg["rgb_mean"], self.preproc_cfg["rgb_std"] + ) + ) + elif dkey == "sal": transformations.append(transforms.Lambda(utils.normalize_tensor)) - elif dkey == 'fix': - transformations.append( - transforms.Lambda(lambda fix: torch.gt(fix, 0.5))) + elif dkey == "fix": + transformations.append(transforms.Lambda(lambda fix: torch.gt(fix, 0.5))) 
processing = transforms.Compose(transformations) @@ -915,37 +992,43 @@ def get_annotation_dir(self, vid_nr_shot_nr): return self.dir / f"actionclip{self.phase_str}{vid_nr:05d}_{shot_nr:1d}" def get_data_file(self, vid_nr_shot_nr, f_nr, dkey): - if dkey == 'frame': - folder = 'images' - elif dkey == 'sal': - folder = 'maps' - elif dkey == 'fix': - folder = 'fixation' + if dkey == "frame": + folder = "images" + elif dkey == "sal": + folder = "maps" + elif dkey == "fix": + folder = "fixation" else: - raise ValueError(f'Unknown data key {dkey}') + raise ValueError(f"Unknown data key {dkey}") vid_nr = vid_nr_shot_nr // 100 - f_nr += self.register['start_image_dict'][vid_nr_shot_nr] - 1 - return self.get_annotation_dir(vid_nr_shot_nr) / folder /\ - f'actionclip{self.phase_str}{vid_nr:05d}_{f_nr:05d}.png' + f_nr += self.register["start_image_dict"][vid_nr_shot_nr] - 1 + return ( + self.get_annotation_dir(vid_nr_shot_nr) + / folder + / f"actionclip{self.phase_str}{vid_nr:05d}_{f_nr:05d}.png" + ) def get_seq(self, vid_nr, frame_nrs, dkey): data_seq = [self.load_data(vid_nr, f_nr, dkey) for f_nr in frame_nrs] - preproc_fun = self.preprocess_sequence if self.phase \ - in ('train', 'valid') else self.preprocess_sequence_eval + preproc_fun = ( + self.preprocess_sequence + if self.phase in ("train", "valid") + else self.preprocess_sequence_eval + ) return preproc_fun(data_seq, dkey, vid_nr) @property def dir(self): if self._dir is None: - self._dir = Path(os.environ["HOLLYWOOD_DATA_DIR"]) /\ - ('training' if self.phase in ('train', 'valid') - else 'testing') + self._dir = Path(os.environ["HOLLYWOOD_DATA_DIR"]) / ( + "training" if self.phase in ("train", "valid") else "testing" + ) return self._dir class UCFSportsDataset(DHF1KDataset): - source = 'UCFSports' + source = "UCFSports" dynamic = True img_channels = 1 @@ -953,48 +1036,60 @@ class UCFSportsDataset(DHF1KDataset): test_vid_nrs = (1, 47) frame_rate = 24 - def __init__(self, out_size=(256, 384), val_size=10, n_images_file=None, - seq_per_vid_val=1, register_file='ucfsports_register.json', - phase='train', - frame_modulo=4, - seq_len=12, - **kwargs): - self.phase_str = 'test' if phase in ('eval', 'test') else 'train' + def __init__( + self, + out_size=(256, 384), + val_size=10, + n_images_file=None, + seq_per_vid_val=1, + register_file="ucfsports_register.json", + phase="train", + frame_modulo=4, + seq_len=12, + **kwargs, + ): + self.phase_str = "test" if phase in ("eval", "test") else "train" self.register_file = self.phase_str + "_" + register_file self.register = None - super().__init__(out_size=out_size, val_size=val_size, - n_images_file=n_images_file, - seq_per_vid_val=seq_per_vid_val, - x_val_seed=27, target_size=out_size, - frame_modulo=frame_modulo, phase=phase, - seq_len=seq_len, - **kwargs) - if phase in ('eval', 'test'): - self.target_size_dict = self.get_register()['vid_size_dict'] + super().__init__( + out_size=out_size, + val_size=val_size, + n_images_file=n_images_file, + seq_per_vid_val=seq_per_vid_val, + x_val_seed=27, + target_size=out_size, + frame_modulo=frame_modulo, + phase=phase, + seq_len=seq_len, + **kwargs, + ) + if phase in ("eval", "test"): + self.target_size_dict = self.get_register()["vid_size_dict"] @property def n_images_dict(self): if self._n_images_dict is None: - self._n_images_dict = self.get_register()['n_images_dict'] - self._n_images_dict = {vid_nr: ni for vid_nr, ni - in self._n_images_dict.items() - if vid_nr in self.vid_nr_array} + self._n_images_dict = self.get_register()["n_images_dict"] + 
self._n_images_dict = { + vid_nr: ni + for vid_nr, ni in self._n_images_dict.items() + if vid_nr in self.vid_nr_array + } return self._n_images_dict def get_register(self): if self.register is None: register_file = config_path / self.register_file if register_file.exists(): - with open(config_path / register_file, 'r') as f: + with open(config_path / register_file, "r") as f: self.register = json.load(f) - for reg_key in ('n_images_dict', 'vid_name_dict', - 'vid_size_dict'): + for reg_key in ("n_images_dict", "vid_name_dict", "vid_size_dict"): self.register[reg_key] = { - int(key): val for key, val in - self.register[reg_key].items()} + int(key): val for key, val in self.register[reg_key].items() + } else: self.register = self.generate_register() - with open(config_path / register_file, 'w') as f: + with open(config_path / register_file, "w") as f: json.dump(self.register, f, indent=2) return self.register @@ -1003,49 +1098,50 @@ def generate_register(self): vid_name_dict = {} vid_size_dict = {} - for vid_idx, folder in enumerate(sorted(self.dir.glob('*-*'))): + for vid_idx, folder in enumerate(sorted(self.dir.glob("*-*"))): vid_nr = vid_idx + 1 vid_name_dict[vid_nr] = folder.stem - image_files = list((folder / 'images').glob('*.png')) + image_files = list((folder / "images").glob("*.png")) n_images_dict[vid_nr] = len(image_files) img = cv2.imread(str(image_files[0])) vid_size_dict[vid_nr] = tuple(img.shape[:2]) return dict( - vid_name_dict=vid_name_dict, n_images_dict=n_images_dict, - vid_size_dict=vid_size_dict) + vid_name_dict=vid_name_dict, + n_images_dict=n_images_dict, + vid_size_dict=vid_size_dict, + ) def preprocess_sequence(self, frame_seq, dkey, vid_nr): - transformations = [ - transforms.ToPILImage() - ] + transformations = [transforms.ToPILImage()] - vid_size = self.register['vid_size_dict'][vid_nr] - interpolation = PIL.Image.LANCZOS if dkey in ('frame', 'sal')\ - else PIL.Image.NEAREST + vid_size = self.register["vid_size_dict"][vid_nr] + interpolation = ( + PIL.Image.LANCZOS if dkey in ("frame", "sal") else PIL.Image.NEAREST + ) out_size_ratio = self.out_size[1] / self.out_size[0] this_size_ratio = vid_size[1] / vid_size[0] if this_size_ratio < out_size_ratio: size = (int(self.out_size[1] / this_size_ratio), self.out_size[1]) else: size = (self.out_size[0], int(self.out_size[0] * this_size_ratio)) - transformations.append( - transforms.Resize(size, interpolation=interpolation)) + transformations.append(transforms.Resize(size, interpolation=interpolation)) transformations += [ transforms.CenterCrop(self.out_size), transforms.ToTensor(), ] - if dkey == 'frame' and 'rgb_mean' in self.preproc_cfg: + if dkey == "frame" and "rgb_mean" in self.preproc_cfg: transformations.append( transforms.Normalize( - self.preproc_cfg['rgb_mean'], self.preproc_cfg['rgb_std'])) - elif dkey == 'sal': + self.preproc_cfg["rgb_mean"], self.preproc_cfg["rgb_std"] + ) + ) + elif dkey == "sal": transformations.append(transforms.Lambda(utils.normalize_tensor)) - elif dkey == 'fix': - transformations.append( - transforms.Lambda(lambda fix: torch.gt(fix, 0.5))) + elif dkey == "fix": + transformations.append(transforms.Lambda(lambda fix: torch.gt(fix, 0.5))) processing = transforms.Compose(transformations) @@ -1056,30 +1152,33 @@ def preprocess_sequence(self, frame_seq, dkey, vid_nr): preprocess_sequence_eval = HollywoodDataset.preprocess_sequence_eval def get_annotation_dir(self, vid_nr): - vid_name = self.register['vid_name_dict'][vid_nr] + vid_name = self.register["vid_name_dict"][vid_nr] return self.dir 
 
     def get_data_file(self, vid_nr, f_nr, dkey):
-        if dkey == 'frame':
-            folder = 'images'
-        elif dkey == 'sal':
-            folder = 'maps'
-        elif dkey == 'fix':
-            folder = 'fixation'
+        if dkey == "frame":
+            folder = "images"
+        elif dkey == "sal":
+            folder = "maps"
+        elif dkey == "fix":
+            folder = "fixation"
         else:
-            raise ValueError(f'Unknown data key {dkey}')
-        vid_name = self.register['vid_name_dict'][vid_nr]
-        return self.get_annotation_dir(vid_nr) / folder /\
-            f"{vid_name[:-4]}_{vid_name[-3:]}_{f_nr:03d}.png"
+            raise ValueError(f"Unknown data key {dkey}")
+        vid_name = self.register["vid_name_dict"][vid_nr]
+        return (
+            self.get_annotation_dir(vid_nr)
+            / folder
+            / f"{vid_name[:-4]}_{vid_name[-3:]}_{f_nr:03d}.png"
+        )
 
     get_seq = HollywoodDataset.get_seq
 
     @property
     def dir(self):
         if self._dir is None:
-            self._dir = Path(os.environ["UCFSPORTS_DATA_DIR"]) /\
-                ('training' if self.phase in ('train', 'valid')
-                 else 'testing')
+            self._dir = Path(os.environ["UCFSPORTS_DATA_DIR"]) / (
+                "training" if self.phase in ("train", "valid") else "testing"
+            )
         return self._dir
@@ -1109,13 +1208,14 @@ def __init__(self, images_path, frame_modulo=None, source=None):
         self.images_path = images_path
         self.frame_modulo = frame_modulo or 5
         self.preproc_cfg = {
-            'rgb_mean': (0.485, 0.456, 0.406),
-            'rgb_std': (0.229, 0.224, 0.225),
+            "rgb_mean": (0.485, 0.456, 0.406),
+            "rgb_std": (0.229, 0.224, 0.225),
         }
 
         frame_files = sorted(list(images_path.glob("*")))
-        frame_files = [file for file in frame_files
-                       if file.suffix in ('.png', '.jpg', '.jpeg')]
+        frame_files = [
+            file for file in frame_files if file.suffix in (".png", ".jpg", ".jpeg")
+        ]
         self.frame_files = frame_files
         self.vid_nr_array = [0]
         self.n_images_dict = {0: len(frame_files)}
@@ -1124,13 +1224,13 @@ def __init__(self, images_path, frame_modulo=None, source=None):
         img_size = tuple(img.shape[:2])
         self.target_size_dict = {0: img_size}
 
-        if source == 'DHF1K' and img_size == (360, 640):
+        if source == "DHF1K" and img_size == (360, 640):
             self.out_size = (224, 384)
-        elif source == 'Hollywood':
+        elif source == "Hollywood":
             self.out_size = (224, 416)
-        elif source == 'UCFSports':
+        elif source == "UCFSports":
             self.out_size = (256, 384)
 
         else:
@@ -1147,13 +1247,16 @@ def load_frame(self, f_nr):
     def preprocess_sequence(self, frame_seq):
         transformations = []
         transformations.append(transforms.ToPILImage())
-        transformations.append(transforms.Resize(
-            self.out_size, interpolation=PIL.Image.LANCZOS))
+        transformations.append(
+            transforms.Resize(self.out_size, interpolation=PIL.Image.LANCZOS)
+        )
         transformations.append(transforms.ToTensor())
-        if 'rgb_mean' in self.preproc_cfg:
+        if "rgb_mean" in self.preproc_cfg:
             transformations.append(
                 transforms.Normalize(
-                    self.preproc_cfg['rgb_mean'], self.preproc_cfg['rgb_std']))
+                    self.preproc_cfg["rgb_mean"], self.preproc_cfg["rgb_std"]
+                )
+            )
         processing = transforms.Compose(transformations)
         tensor = [processing(img) for img in frame_seq]
         tensor = torch.stack(tensor)
@@ -1180,16 +1283,16 @@ def __init__(self, images_path):
         self.images_path = images_path
         self.frame_modulo = 1
         self.preproc_cfg = {
-            'rgb_mean': (0.485, 0.456, 0.406),
-            'rgb_std': (0.229, 0.224, 0.225),
+            "rgb_mean": (0.485, 0.456, 0.406),
+            "rgb_std": (0.229, 0.224, 0.225),
         }
 
         image_files = sorted(list(images_path.glob("*")))
-        image_files = [file for file in image_files
-                       if file.suffix in ('.png', '.jpg', '.jpeg')]
+        image_files = [
+            file for file in image_files if file.suffix in (".png", ".jpg", ".jpeg")
+        ]
         self.image_files = image_files
 
-        self.n_images_dict = {
-            img_idx: 1 for img_idx in range(len(self.image_files))}
+        self.n_images_dict = {img_idx: 1 for img_idx in range(len(self.image_files))}
         self.target_size_dict = {}
         self.out_size_dict = {}
@@ -1210,14 +1313,15 @@ def load_image(self, img_idx):
     def preprocess(self, img, out_size):
         transformations = [
             transforms.ToPILImage(),
-            transforms.Resize(
-                out_size, interpolation=PIL.Image.LANCZOS),
+            transforms.Resize(out_size, interpolation=PIL.Image.LANCZOS),
             transforms.ToTensor(),
         ]
-        if 'rgb_mean' in self.preproc_cfg:
+        if "rgb_mean" in self.preproc_cfg:
             transformations.append(
                 transforms.Normalize(
-                    self.preproc_cfg['rgb_mean'], self.preproc_cfg['rgb_std']))
+                    self.preproc_cfg["rgb_mean"], self.preproc_cfg["rgb_std"]
+                )
+            )
         processing = transforms.Compose(transformations)
         tensor = processing(img)
         return tensor
@@ -1225,7 +1329,7 @@ def preprocess(self, img, out_size):
     def get_data(self, img_idx):
         file = self.image_files[img_idx]
         img = cv2.imread(str(file))
-        assert (img is not None)
+        assert img is not None
         img = np.ascontiguousarray(img[:, :, ::-1])
         out_size = self.out_size_dict[img_idx]
         img = self.preprocess(img, out_size)
@@ -1236,3 +1340,45 @@ def __len__(self):
 
     def __getitem__(self, item):
         return self.get_data(item, 0)
+
+
+def im_preprocess(
+    img_rgb: np.ndarray,
+    preproc_cfg: dict = {
+        "rgb_mean": (0.485, 0.456, 0.406),
+        "rgb_std": (0.229, 0.224, 0.225),
+    },
+):
+    """Preprocess an image before inference and return a tensor of shape [c, h, w].
+    Args:
+        img_rgb: image as numpy array (e.g. cv2.imread(im_path)[:, :, ::-1])
+    """
+
+    def preprocess(img, out_size):
+        transformations = [
+            transforms.ToPILImage(),
+            transforms.Resize(out_size, interpolation=PIL.Image.LANCZOS),
+            transforms.ToTensor(),
+        ]
+        if "rgb_mean" in preproc_cfg:
+            transformations.append(
+                transforms.Normalize(preproc_cfg["rgb_mean"], preproc_cfg["rgb_std"])
+            )
+        processing = transforms.Compose(transformations)
+        tensor = processing(img)
+        return tensor
+
+    im = np.ascontiguousarray(img_rgb)
+    out_size = get_optimal_out_size(img_size=img_rgb.shape[:2])
+    return preprocess(im, out_size)
+
+
+def smap_postprocess(smap):
+    """Postprocess an output torch tensor into a uint8 numpy saliency map.
+    Args:
+        smap: a slice of the output torch tensor (e.g. pred_seq[:, 0, ...])
+    """
+    smap = smap.exp()
+    smap = torch.squeeze(smap)
+    smap = utils.to_numpy(smap)
+    return (smap / np.amax(smap) * 255).astype(np.uint8)
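
The two helpers above give a self-contained pre/post-processing path for single images. A minimal sketch of the round trip, not part of the patch itself: the image path is a placeholder, and get_optimal_out_size is assumed to be the existing helper in unisal/data.py that im_preprocess calls.

    import cv2
    import torch
    from unisal import data

    # Load as RGB (cv2 reads BGR); im_preprocess expects an RGB array
    img_rgb = cv2.imread("examples/cat.jpg")[:, :, ::-1]  # hypothetical path
    tensor = data.im_preprocess(img_rgb)  # float tensor of shape [c, h, w]

    # smap_postprocess expects log-scale model output; here a random tensor
    # stands in for a real prediction slice such as pred_seq[:, 0, ...]
    fake_pred = torch.randn(1, 1, *tensor.shape[1:])
    smap = data.smap_postprocess(fake_pred)
    print(smap.shape, smap.dtype)  # (h, w), uint8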
 
From a6258ae49564e1863456637231c687487f0895a4 Mon Sep 17 00:00:00 2001
From: ohjho
Date: Tue, 16 Jul 2024 15:22:02 -0400
Subject: [PATCH 3/6] added load_weights_from_path to simplify loading of model weights

---
 unisal/model.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/unisal/model.py b/unisal/model.py
index 3137f59..83f2d00 100644
--- a/unisal/model.py
+++ b/unisal/model.py
@@ -1,5 +1,5 @@
 from collections import OrderedDict
-import pprint
+import pprint, os
 from functools import partial
 from itertools import product
 
@@ -41,6 +41,15 @@ def load_best_weights(self, directory):
             torch.load(directory / f"weights_best.pth", map_location=DEFAULT_DEVICE)
         )
 
+    def load_weights_from_path(self, weights_path):
+        assert os.path.isfile(
+            weights_path
+        ), f"weights_path {weights_path} is not a valid file"
+        assert weights_path.endswith(
+            ".pth"
+        ), f"weights file must have .pth extension, got {os.path.splitext(weights_path)[-1]}"
+        self.load_state_dict(torch.load(weights_path, map_location=DEFAULT_DEVICE))
+
     def load_epoch_checkpoint(self, directory, epoch):
         """Load state_dict from a Trainer checkpoint at a specific epoch"""
         chkpnt = torch.load(directory / f"chkpnt_epoch{epoch:04d}.pth")
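
Compared with load_weights and load_best_weights, the new method takes an explicit file path and inherits the CPU-friendly map_location from DEFAULT_DEVICE. A sketch of the intended call; the path mirrors the repository's pretrained-weights layout and the default model config is assumed to match the saved state_dict:

    from unisal.model import UNISAL

    model = UNISAL()  # default config; assumed to match the saved weights
    # map_location follows DEFAULT_DEVICE, so this also works on CPU-only machines
    model.load_weights_from_path("training_runs/pretrained_unisal/weights_best.pth")
    model.eval()

Note that the .pth check uses str.endswith, so the argument should be a string rather than a pathlib.Path.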
 
From 0f6a50e26f6040d947811ab8f53ee4d05a1699f3 Mon Sep 17 00:00:00 2001
From: ohjho
Date: Fri, 19 Jul 2024 18:09:08 -0400
Subject: [PATCH 4/6] added new function predict_image()

---
 run.py | 75 ++++++++++++++++++++++++++++++++++++++--------------------
 1 file changed, 50 insertions(+), 25 deletions(-)

diff --git a/run.py b/run.py
index 731f804..5ef8cf0 100644
--- a/run.py
+++ b/run.py
@@ -1,13 +1,9 @@
 from pathlib import Path
-import os
-
-import fire
-
+import os, fire
 import unisal
 
 
-def train(eval_sources=('DHF1K', 'SALICON', 'UCFSports', 'Hollywood'),
-          **kwargs):
+def train(eval_sources=("DHF1K", "SALICON", "UCFSports", "Hollywood"), **kwargs):
     """Run training and evaluation."""
     trainer = unisal.train.Trainer(**kwargs)
     trainer.fit()
@@ -20,17 +16,17 @@ def train(eval_sources=('DHF1K', 'SALICON', 'UCFSports', 'Hollywood'),
 def load_trainer(train_id=None):
     """Instantiate Trainer class from saved kwargs."""
     if train_id is None:
-        train_id = 'pretrained_unisal'
+        train_id = "pretrained_unisal"
     print(f"Train ID: {train_id}")
     train_dir = Path(os.environ["TRAIN_DIR"])
     train_dir = train_dir / train_id
+    print(f"initializing trainer from {train_dir}...")
     return unisal.train.Trainer.init_from_cfg_dir(train_dir)
 
 
 def score_model(
-        train_id=None,
-        sources=('DHF1K', 'SALICON', 'UCFSports', 'Hollywood'),
-        **kwargs):
+    train_id=None, sources=("DHF1K", "SALICON", "UCFSports", "Hollywood"), **kwargs
+):
     """Compute the scores for a trained model."""
 
     trainer = load_trainer(train_id)
@@ -39,26 +35,27 @@ def score_model(
 
 
 def generate_predictions(
-        train_id=None,
-        sources=('DHF1K', 'SALICON', 'UCFSports', 'Hollywood',
-                 'MIT1003', 'MIT300'),
-        **kwargs):
+    train_id=None,
+    sources=("DHF1K", "SALICON", "UCFSports", "Hollywood", "MIT1003", "MIT300"),
+    **kwargs,
+):
     """Generate predictions with a trained model."""
 
     trainer = load_trainer(train_id)
     for source in sources:
 
         # Load fine-tuned weights for MIT datasets
-        if source in ('MIT1003', 'MIT300'):
+        if source in ("MIT1003", "MIT300"):
             trainer.model.load_weights(trainer.train_dir, "ft_mit1003")
-            trainer.salicon_cfg['x_val_step'] = 0
-            kwargs.update({'model_domain': 'SALICON', 'load_weights': False})
+            trainer.salicon_cfg["x_val_step"] = 0
+            kwargs.update({"model_domain": "SALICON", "load_weights": False})
 
         trainer.generate_predictions(source=source, **kwargs)
 
 
 def predictions_from_folder(
-        folder_path, is_video, source=None, train_id=None, model_domain=None):
+    folder_path, is_video, source=None, train_id=None, model_domain=None
+):
     """Generate predictions of files in a folder with a trained model."""
 
     # Allows us to call this function directly from command-line
 
     trainer = load_trainer(train_id)
 
     trainer.generate_predictions_from_path(
-        folder_path, is_video, source=source, model_domain=model_domain)
+        folder_path, is_video, source=source, model_domain=model_domain
+    )
 
 
 def predict_examples(train_id=None):
             continue
 
         source = example_folder.name
-        is_video = source not in ('SALICON', 'MIT1003')
+        is_video = source not in ("SALICON", "MIT1003")
 
-        print(f"\nGenerating predictions for {'video' if is_video else 'image'} "
-              f"folder\n{str(source)}")
+        print(
+            f"\nGenerating predictions for {'video' if is_video else 'image'} "
+            f"folder\n{str(source)}"
+        )
 
         if is_video:
             if not example_folder.is_dir():
                 continue
-            for video_folder in example_folder.glob('[!.]*'):  # ignore hidden files
+            for video_folder in example_folder.glob("[!.]*"):  # ignore hidden files
                 predictions_from_folder(
-                    video_folder, is_video, train_id=train_id, source=source)
+                    video_folder, is_video, train_id=train_id, source=source
+                )
 
         else:
             predictions_from_folder(
-                example_folder, is_video, train_id=train_id, source=source)
+                example_folder, is_video, train_id=train_id, source=source
+            )
+
+
+def predict_image(
+    image_path: str,
+    out_image_path: str = None,
+):
+    """A minimal working example of running inference on a single image."""
+    from PIL import Image
+    import numpy as np
+
+    assert os.path.isfile(
+        image_path
+    ), f"provided image_path ({image_path}) is not a valid file"
+    im = Image.open(image_path)
+    smap = unisal.demo.predict_image(img_rgb=np.array(im))
+
+    if out_image_path:
+        output_dir = os.path.dirname(out_image_path)
+        assert os.path.isdir(
+            output_dir
+        ), f"output directory {output_dir} does not exist"
+        Image.fromarray(smap).save(out_image_path)
+    return smap
 
 
 if __name__ == "__main__":
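
Because run.py dispatches through fire, the new predict_image doubles as a CLI entry point. A sketch of both call styles; the image paths and the TRAIN_DIR layout are assumptions:

    import os

    os.environ.setdefault("TRAIN_DIR", "./training_runs")  # assumed layout

    # From Python: run.py sits at the repo root, so it imports directly
    from run import predict_image

    smap = predict_image("examples/cat.jpg", out_image_path="examples/cat_sal.png")

    # From a shell, the fire equivalent would be roughly:
    #   python run.py predict_image --image_path examples/cat.jpg \
    #       --out_image_path examples/cat_sal.png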
 
From 0ca366219a28b06ba7a2ec0143a2cba0e96b8aea Mon Sep 17 00:00:00 2001
From: ohjho
Date: Fri, 19 Jul 2024 18:11:46 -0400
Subject: [PATCH 5/6] added new method inference() for trainer

---
 unisal/train.py | 138 +++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 137 insertions(+), 1 deletion(-)

diff --git a/unisal/train.py b/unisal/train.py
index 506ec6a..2f2e70c 100644
--- a/unisal/train.py
+++ b/unisal/train.py
@@ -16,7 +16,6 @@
 import torch
 import torch.nn.functional as F
 import cv2
-from tensorboardX import SummaryWriter
 import numpy as np
 
 from . import salience_metrics
@@ -686,6 +685,141 @@ def other_maps():
         if return_predictions:
             return pred_seq
 
+    def inference(
+        self,
+        img_rgb: np.ndarray,
+        source: str = "SALICON",  # images: ["SALICON", "MIT1003", "MIT300"]; videos: ["DHF1K", "UCFSports", "Hollywood"]
+        # vid_nr,  # image index, 0 for video
+        # dataset=None,
+        # phase=None,
+        smooth_method=None,
+        # metrics=None,
+        # save_predictions=False,
+        # return_predictions=False,
+        seq_len_factor=0.5,
+        random_seed=27,
+        # n_aucs_maps=10,
+        # auc_portion=1.0,
+        # model_domain=None,
+        # folder_suffix=None,
+    ):
+        """Run inference on an RGB image array and return the prediction tensor."""
+
+        # if dataset is None:
+        #     assert phase, "Must provide either dataset or phase"
+        #     dataset = self.get_dataset(phase, source)
+
+        if random_seed is not None:
+            random.seed(random_seed)
+
+        # Get the original resolution (h, w)
+        # target_size = dataset.target_size_dict[vid_nr]
+        target_size = img_rgb.shape[:2]
+
+        # Set the keyword arguments for the forward pass
+        model_kwargs = {"source": source, "target_size": target_size}
+
+        # Make sure that the model was trained on the selected domain
+        if model_kwargs["source"] not in self.model.sources:
+            print(
+                f"\nWarning! Evaluation bn source {model_kwargs['source']} "
+                f"doesn't exist in model.\n Using {self.model.sources[0]}."
+            )
+            model_kwargs["source"] = self.model.sources[0]
+
+        # Select static (image) or dynamic (video) forward pass for Bypass-RNN
+        model_kwargs.update(
+            {"static": model_kwargs["source"] in ("SALICON", "MIT300", "MIT1003")}
+        )
+
+        # Set additional parameters
+        static_data = source in ("SALICON", "MIT300", "MIT1003")
+        assert static_data, "trainer.inference currently only supports static data"
+        if static_data:
+            smooth_method = None
+            # auc_portion = 1.0
+            n_images = 1
+            frame_modulo = 1
+        else:
+            # video mode (unreachable until video input is supported)
+            n_images = dataset.n_images_dict[vid_nr]
+            frame_modulo = dataset.frame_modulo
+
+        # Prepare the model
+        self.model.to(self.device)
+        self.model.eval()
+        torch.cuda.empty_cache()
+
+        # Prepare the prediction and target tensors
+        results_size = (1, n_images, 1, *model_kwargs["target_size"])
+        pred_seq = torch.full(results_size, 0, dtype=torch.float)
+        sal_seq, fix_seq = None, None
+
+        # Define input sequence length
+        # seq_len = self.batch_size * self.get_dataset('train').seq_len * \
+        #     seq_len_factor
+        seq_len = int(12 * seq_len_factor)
+
+        # Iterate over different offsets to create the interleaved predictions
+        for offset in range(min(frame_modulo, n_images)):
+
+            # Get the data
+            if not static_data:
+                # video mode
+                sample = dataset.get_data(vid_nr, offset + 1)
+                sample = sample[:-1]
+            else:
+                # sample = dataset.get_data(vid_nr)
+                sample = [1], data.im_preprocess(img_rgb=img_rgb)
+
+            # Preprocess the data
+            if len(sample) >= 4:
+                # if len(sample) == 5:
+                #     sample = sample[:-1]
+                frame_nrs, frame_seq, this_sal_seq, this_fix_seq = sample
+                this_sal_seq = this_sal_seq.unsqueeze(0).float()
+                this_fix_seq = this_fix_seq.unsqueeze(0)
+                if frame_seq.dim() == 3:
+                    frame_seq = frame_seq.unsqueeze(0)
+                    this_sal_seq = this_sal_seq.unsqueeze(0)
+                    this_fix_seq = this_fix_seq.unsqueeze(0)
+            else:
+                frame_nrs, frame_seq = sample
+                this_sal_seq, this_fix_seq = None, None
+                if frame_seq.dim() == 3:
+                    frame_seq = frame_seq.unsqueeze(0)
+                frame_seq = frame_seq.unsqueeze(0).float()
+            frame_idx_array = [f_nr - 1 for f_nr in frame_nrs]
+            frame_seq = frame_seq.to(self.device)
+
+            # Run all sequences of the current offset
+            h0 = [None]
+            for start in range(0, len(frame_idx_array), seq_len):
+
+                # Select the frames
+                end = min(len(frame_idx_array), start + seq_len)
+                this_frame_seq = frame_seq[:, start:end, :, :, :]
+                this_frame_idx_array = frame_idx_array[start:end]
+
+                # Forward pass
+                this_pred_seq, h0 = self.model(
+                    this_frame_seq, h0=h0, return_hidden=True, **model_kwargs
+                )
+
+                # Insert the predictions into the prediction array
+                this_pred_seq = this_pred_seq.cpu()
+                pred_seq[:, this_frame_idx_array, :, :, :] = this_pred_seq
+
+        # Assert non-empty predictions
+        assert torch.min(pred_seq.exp().sum(-1).sum(-1)) > 0
+
+        # Optionally smooth the interleaved sequences
+        if smooth_method is not None:
+            pred_seq = pred_seq.numpy()
+            pred_seq = utils.smooth_sequence(pred_seq, smooth_method)
+            pred_seq = torch.from_numpy(pred_seq).float()
+        return pred_seq
+
     @staticmethod
     def eval_sequences(
         pred_seq, sal_seq, fix_seq, metrics, other_maps=None, auc_portion=1.0
@@ -1353,6 +1487,8 @@ def add_scalar(self, key, value, epoch, this_tboard=True):
     @property
     def writer(self):
         """Return TensorboardX writer"""
+        from tensorboardX import SummaryWriter
+
         if self.tboard and self._writer is None:
             if self.data_sources == ("MIT1003",):
                 log_dir = self.mit1003_dir
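
Trainer.inference returns the raw prediction tensor rather than writing files, which is what the demo module in the next patch builds on. A sketch of a direct call; the weights location is an assumption mirroring the pretrained layout, and str() is used because load_weights_from_path checks the extension with str.endswith:

    from pathlib import Path
    import cv2
    import unisal

    train_dir = Path("training_runs") / "pretrained_unisal"  # assumed location
    trainer = unisal.train.Trainer.init_from_cfg_dir(train_dir)
    trainer.model.load_weights_from_path(str(train_dir / "weights_best.pth"))

    img_rgb = cv2.imread("examples/cat.jpg")[:, :, ::-1]  # hypothetical image
    pred_seq = trainer.inference(img_rgb=img_rgb, source="SALICON")
    # pred_seq has shape (1, 1, 1, h, w) and holds log-scale saliency values
    smap = unisal.data.smap_postprocess(pred_seq[:, 0, ...])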
 
From 7a74dd8d13c78dd7550a5232c7094b3be683e0ba Mon Sep 17 00:00:00 2001
From: ohjho
Date: Fri, 19 Jul 2024 18:12:33 -0400
Subject: [PATCH 6/6] added new module to demo inference

---
 unisal/__init__.py |  2 +-
 unisal/demo.py     | 37 +++++++++++++++++++++++++++++++++++++
 2 files changed, 38 insertions(+), 1 deletion(-)
 create mode 100644 unisal/demo.py

diff --git a/unisal/__init__.py b/unisal/__init__.py
index 45d1627..0bdf348 100644
--- a/unisal/__init__.py
+++ b/unisal/__init__.py
@@ -1 +1 @@
-from . import train, data, model, models, utils
+from . import train, data, model, models, utils, demo
diff --git a/unisal/demo.py b/unisal/demo.py
new file mode 100644
index 0000000..9e42e81
--- /dev/null
+++ b/unisal/demo.py
@@ -0,0 +1,37 @@
+from pathlib import Path
+import os
+from . import train
+from . import data
+
+TRAINER_ZOO = {}
+PROJECT_DIR = os.path.dirname(os.path.realpath(__file__))
+
+
+def load_trainer(train_id: str = "pretrained_unisal"):
+    """Instantiate Trainer class from saved kwargs."""
+    train_dir = Path(os.environ["TRAIN_DIR"])
+    train_dir = train_dir / train_id
+    print(f"initializing trainer from {train_dir}...")
+    return train.Trainer.init_from_cfg_dir(train_dir)
+
+
+def get_trainer(model_path: str):
+    """Get trainer from memory if already loaded, else load it."""
+    global TRAINER_ZOO
+    if model_path not in TRAINER_ZOO:
+        trainer = load_trainer()
+        trainer.model.load_weights_from_path(model_path)
+        TRAINER_ZOO[model_path] = trainer
+    return TRAINER_ZOO[model_path]
+
+
+def predict_image(
+    img_rgb,
+    model_path: str = os.path.join(
+        PROJECT_DIR, "../training_runs/pretrained_unisal/weights_best.pth"
+    ),
+):
+    trainer = get_trainer(model_path)
+    pred_seq = trainer.inference(img_rgb=img_rgb)
+    smap = data.smap_postprocess(pred_seq[:, 0, ...])
+    return smap
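
With all six patches applied, single-image CPU inference reduces to a few lines. A sketch; it assumes TRAIN_DIR points at the directory containing pretrained_unisal and that the default model_path resolves to the bundled weights:

    import os
    import numpy as np
    from PIL import Image
    import unisal

    os.environ.setdefault("TRAIN_DIR", "training_runs")  # assumed layout
    img_rgb = np.array(Image.open("examples/cat.jpg").convert("RGB"))
    smap = unisal.demo.predict_image(img_rgb)
    Image.fromarray(smap).save("cat_saliency.png")

Repeated calls with the same model_path reuse the cached Trainer via TRAINER_ZOO, so the configuration and weights are only loaded once per process.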