From bb6573176c5dc3435772f37d388271689fba2478 Mon Sep 17 00:00:00 2001 From: John Ho Date: Tue, 16 Jul 2024 08:00:25 -0400 Subject: [PATCH] enabled CPU inference --- unisal/model.py | 351 +++++++++------- unisal/models/MobileNetV2.py | 100 +++-- unisal/models/cgru.py | 165 +++++--- unisal/train.py | 785 +++++++++++++++++++---------------- 4 files changed, 825 insertions(+), 576 deletions(-) diff --git a/unisal/model.py b/unisal/model.py index bcb0244..3137f59 100644 --- a/unisal/model.py +++ b/unisal/model.py @@ -11,6 +11,11 @@ from .models.cgru import ConvGRU from .models.MobileNetV2 import MobileNetV2, InvertedResidual +DEFAULT_DEVICE = ( + torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") +) +print(f"torch device set to: {DEFAULT_DEVICE}") + def get_model(): """Return the model class""" @@ -24,41 +29,50 @@ def forward(self, *input): raise NotImplementedError def save_weights(self, directory, name): - torch.save(self.state_dict(), directory / f'weights_{name}.pth') + torch.save(self.state_dict(), directory / f"weights_{name}.pth") def load_weights(self, directory, name): - self.load_state_dict(torch.load(directory / f'weights_{name}.pth')) + self.load_state_dict( + torch.load(directory / f"weights_{name}.pth", map_location=DEFAULT_DEVICE) + ) def load_best_weights(self, directory): - self.load_state_dict(torch.load(directory / f'weights_best.pth')) + self.load_state_dict( + torch.load(directory / f"weights_best.pth", map_location=DEFAULT_DEVICE) + ) def load_epoch_checkpoint(self, directory, epoch): """Load state_dict from a Trainer checkpoint at a specific epoch""" chkpnt = torch.load(directory / f"chkpnt_epoch{epoch:04d}.pth") - self.load_state_dict(chkpnt['model_state_dict']) + self.load_state_dict(chkpnt["model_state_dict"]) def load_checkpoint(self, file): """Load state_dict from a specific Trainer checkpoint""" """Load """ chkpnt = torch.load(file) - self.load_state_dict(chkpnt['model_state_dict']) + self.load_state_dict(chkpnt["model_state_dict"]) def load_last_chkpnt(self, directory): """Load state_dict from the last Trainer checkpoint""" - last_chkpnt = sorted(list(directory.glob('chkpnt_epoch*.pth')))[-1] + last_chkpnt = sorted(list(directory.glob("chkpnt_epoch*.pth")))[-1] self.load_checkpoint(last_chkpnt) # Set default backbone CNN kwargs default_cnn_cfg = { - 'widen_factor': 1., 'pretrained': True, 'input_channel': 32, - 'last_channel': 1280} + "widen_factor": 1.0, + "pretrained": True, + "input_channel": 32, + "last_channel": 1280, +} # Set default RNN kwargs default_rnn_cfg = { - 'kernel_size': (3, 3), 'gate_ksize': (3, 3), - 'dropout': (False, True, False), 'drop_prob': (0.2, 0.2, 0.2), - 'mobile': True, + "kernel_size": (3, 3), + "gate_ksize": (3, 3), + "dropout": (False, True, False), + "drop_prob": (0.2, 0.2, 0.2), + "mobile": True, } @@ -72,11 +86,11 @@ class DomainBatchNorm2d(nn.Module): def __init__(self, num_features, sources, momenta=None, **kwargs): """ - num_features: Number of channels - sources: List of sources - momenta: List of BatchNorm momenta corresponding to the sources. - Default is 0.1 for each source. - kwargs: Other BatchNorm kwargs + num_features: Number of channels + sources: List of sources + momenta: List of BatchNorm momenta corresponding to the sources. + Default is 0.1 for each source. + kwargs: Other BatchNorm kwargs """ super().__init__() self.sources = sources @@ -85,13 +99,14 @@ def __init__(self, num_features, sources, momenta=None, **kwargs): if momenta is None: momenta = [0.1] * len(sources) self.momenta = momenta - if 'momentum' in kwargs: - del kwargs['momentum'] + if "momentum" in kwargs: + del kwargs["momentum"] # Instantiate the BN modules for src, mnt in zip(sources, self.momenta): - self.__setattr__(f"bn_{src}", nn.BatchNorm2d( - num_features, momentum=mnt, **kwargs)) + self.__setattr__( + f"bn_{src}", nn.BatchNorm2d(num_features, momentum=mnt, **kwargs) + ) # Prepare the self.this_source attribute that will be updated at runtime # by the model @@ -131,35 +146,37 @@ class UNISAL(BaseModel, utils.KwConfigClass): verbose: Verbosity level. """ - def __init__(self, - rnn_input_channels=256, rnn_hidden_channels=256, - cnn_cfg=None, - rnn_cfg=None, - res_rnn=True, - bypass_rnn=True, - drop_probs=(0.0, 0.6, 0.6), - gaussian_init='manual', - n_gaussians=16, - smoothing_ksize=41, - bn_momentum=0.01, - static_bn_momentum=0.1, - sources=('DHF1K', 'Hollywood', 'UCFSports', 'SALICON'), - ds_bn=True, - ds_adaptation=True, - ds_smoothing=True, - ds_gaussians=True, - verbose=1, - ): + def __init__( + self, + rnn_input_channels=256, + rnn_hidden_channels=256, + cnn_cfg=None, + rnn_cfg=None, + res_rnn=True, + bypass_rnn=True, + drop_probs=(0.0, 0.6, 0.6), + gaussian_init="manual", + n_gaussians=16, + smoothing_ksize=41, + bn_momentum=0.01, + static_bn_momentum=0.1, + sources=("DHF1K", "Hollywood", "UCFSports", "SALICON"), + ds_bn=True, + ds_adaptation=True, + ds_smoothing=True, + ds_gaussians=True, + verbose=1, + ): super().__init__() # Check inputs - assert(gaussian_init in ('random', 'manual')) + assert gaussian_init in ("random", "manual") # Bypass-RNN requires residual RNN connection if bypass_rnn: assert res_rnn # Manual Gaussian initialization generates 16 Gaussians - if n_gaussians > 0 and gaussian_init == 'manual': + if n_gaussians > 0 and gaussian_init == "manual": n_gaussians = 16 self.rnn_input_channels = rnn_input_channels @@ -190,64 +207,98 @@ def __init__(self, # Initialize Post-CNN module with optional dropout post_cnn = [ - ('inv_res', InvertedResidual( - self.cnn.out_channels + n_gaussians, - rnn_input_channels, 1, 1, bn_momentum=bn_momentum, - )) + ( + "inv_res", + InvertedResidual( + self.cnn.out_channels + n_gaussians, + rnn_input_channels, + 1, + 1, + bn_momentum=bn_momentum, + ), + ) ] if self.drop_probs[0] > 0: - post_cnn.insert(0, ( - 'dropout', nn.Dropout2d(self.drop_probs[0], inplace=False) - )) + post_cnn.insert( + 0, ("dropout", nn.Dropout2d(self.drop_probs[0], inplace=False)) + ) self.post_cnn = nn.Sequential(OrderedDict(post_cnn)) # Initialize Bypass-RNN if training on dynamic data - if sources != ('SALICON',) or not self.bypass_rnn: + if sources != ("SALICON",) or not self.bypass_rnn: self.rnn = ConvGRU( rnn_input_channels, hidden_channels=[rnn_hidden_channels], batchnorm=self.get_bn_module, - **self.rnn_cfg) - self.post_rnn = self.conv_1x1_bn( - rnn_hidden_channels, rnn_input_channels) + **self.rnn_cfg, + ) + self.post_rnn = self.conv_1x1_bn(rnn_hidden_channels, rnn_input_channels) # Initialize first upsampling module US1 - self.upsampling_1 = nn.Sequential(OrderedDict([ - ('us1', self.upsampling(2)), - ])) + self.upsampling_1 = nn.Sequential( + OrderedDict( + [ + ("us1", self.upsampling(2)), + ] + ) + ) # Number of channels at the 2x scale channels_2x = 128 # Initialize Skip-2x module self.skip_2x = self.make_skip_connection( - self.cnn.feat_2x_channels, channels_2x, 2, self.drop_probs[1]) + self.cnn.feat_2x_channels, channels_2x, 2, self.drop_probs[1] + ) # Initialize second upsampling module US2 - self.upsampling_2 = nn.Sequential(OrderedDict([ - ('inv_res', InvertedResidual( - rnn_input_channels + channels_2x, - channels_2x, 1, 2, batchnorm=self.get_bn_module)), - ('us2', self.upsampling(2)), - ])) + self.upsampling_2 = nn.Sequential( + OrderedDict( + [ + ( + "inv_res", + InvertedResidual( + rnn_input_channels + channels_2x, + channels_2x, + 1, + 2, + batchnorm=self.get_bn_module, + ), + ), + ("us2", self.upsampling(2)), + ] + ) + ) # Number of channels at the 4x scale channels_4x = 64 # Initialize Skip-4x module self.skip_4x = self.make_skip_connection( - self.cnn.feat_4x_channels, channels_4x, 2, self.drop_probs[2]) + self.cnn.feat_4x_channels, channels_4x, 2, self.drop_probs[2] + ) # Initialize Post-US2 module - self.post_upsampling_2= nn.Sequential(OrderedDict([ - ('inv_res', InvertedResidual( - channels_2x + channels_4x, channels_4x, 1, 2, - batchnorm=self.get_bn_module)), - ])) + self.post_upsampling_2 = nn.Sequential( + OrderedDict( + [ + ( + "inv_res", + InvertedResidual( + channels_2x + channels_4x, + channels_4x, + 1, + 2, + batchnorm=self.get_bn_module, + ), + ), + ] + ) + ) # Initialize domain-specific modules for source_str in self.sources: - source_str = f'_{source_str}'.lower() + source_str = f"_{source_str}".lower() # Initialize learned Gaussian priors parameters if n_gaussians > 0: @@ -255,24 +306,23 @@ def __init__(self, # Initialize Adaptation self.__setattr__( - 'adaptation' + (source_str if self.ds_adaptation else ''), - nn.Sequential(*[ - nn.Conv2d(channels_4x, 1, 1, bias=True) - ])) + "adaptation" + (source_str if self.ds_adaptation else ""), + nn.Sequential(*[nn.Conv2d(channels_4x, 1, 1, bias=True)]), + ) # Initialize Smoothing smoothing = nn.Conv2d( - 1, 1, kernel_size=smoothing_ksize, padding=0, bias=False) + 1, 1, kernel_size=smoothing_ksize, padding=0, bias=False + ) with torch.no_grad(): gaussian = self._make_gaussian_maps( - smoothing.weight.data, - torch.Tensor([[[0.5, -2]] * 2]) + smoothing.weight.data, torch.Tensor([[[0.5, -2]] * 2]) ) gaussian /= gaussian.sum() smoothing.weight.data = gaussian self.__setattr__( - 'smoothing' + (source_str if self.ds_smoothing else ''), - smoothing) + "smoothing" + (source_str if self.ds_smoothing else ""), smoothing + ) if self.verbose > 1: pprint.pprint(self.asdict(), width=1) @@ -292,52 +342,58 @@ def this_source(self, source): def get_bn_module(self, num_features, **kwargs): """Return BatchNorm class (domain-specific or domain-invariant).""" - momenta = [self.bn_momentum if src != 'SALICON' - else self.static_bn_momentum for src in self.sources] + momenta = [ + self.bn_momentum if src != "SALICON" else self.static_bn_momentum + for src in self.sources + ] if self.ds_bn: return DomainBatchNorm2d( - num_features, self.sources, momenta=momenta, **kwargs) + num_features, self.sources, momenta=momenta, **kwargs + ) else: return nn.BatchNorm2d(num_features, **kwargs) # @staticmethod def upsampling(self, factor): """Return upsampling module.""" - return nn.Sequential(*[ - nn.Upsample( - scale_factor=factor, mode='bilinear', align_corners=False), - ]) + return nn.Sequential( + *[ + nn.Upsample(scale_factor=factor, mode="bilinear", align_corners=False), + ] + ) - def set_gaussians(self, source_str, prefix='coarse_'): + def set_gaussians(self, source_str, prefix="coarse_"): """Set Gaussian parameters.""" - suffix = source_str if self.ds_gaussians else '' + suffix = source_str if self.ds_gaussians else "" self.__setattr__( - prefix + 'gaussians' + suffix, - self._initialize_gaussians(self.n_gaussians)) + prefix + "gaussians" + suffix, self._initialize_gaussians(self.n_gaussians) + ) def _initialize_gaussians(self, n_gaussians): """ Return initialized Gaussian parameters. Dimensions: [idx, y/x, mu/logstd]. """ - if self.gaussian_init == 'manual': - gaussians = torch.Tensor([ - list(product([0.25, 0.5, 0.75], repeat=2)) + - [(0.5, 0.25), (0.5, 0.5), (0.5, 0.75)] + - [(0.25, 0.5), (0.5, 0.5), (0.75, 0.5)] + - [(0.5, 0.5)], - [(-1.5, -1.5)] * 9 + [(0, -1.5)] * 3 + [(-1.5, 0)] * 3 + - [(0, 0)], - ]).permute(1, 2, 0) - - elif self.gaussian_init == 'random': + if self.gaussian_init == "manual": + gaussians = torch.Tensor( + [ + list(product([0.25, 0.5, 0.75], repeat=2)) + + [(0.5, 0.25), (0.5, 0.5), (0.5, 0.75)] + + [(0.25, 0.5), (0.5, 0.5), (0.75, 0.5)] + + [(0.5, 0.5)], + [(-1.5, -1.5)] * 9 + [(0, -1.5)] * 3 + [(-1.5, 0)] * 3 + [(0, 0)], + ] + ).permute(1, 2, 0) + + elif self.gaussian_init == "random": with torch.no_grad(): - gaussians = torch.stack([ - torch.randn( - n_gaussians, 2, dtype=torch.float) * .1 + 0.5, - torch.randn( - n_gaussians, 2, dtype=torch.float) * .2 - 1,], - dim=2) + gaussians = torch.stack( + [ + torch.randn(n_gaussians, 2, dtype=torch.float) * 0.1 + 0.5, + torch.randn(n_gaussians, 2, dtype=torch.float) * 0.2 - 1, + ], + dim=2, + ) else: raise NotImplementedError @@ -346,7 +402,7 @@ def _initialize_gaussians(self, n_gaussians): return gaussians @staticmethod - def _make_gaussian_maps(x, gaussians, size=None, scaling=6.): + def _make_gaussian_maps(x, gaussians, size=None, scaling=6.0): """Construct prior maps from Gaussian parameters.""" if size is None: size = x.shape[-2:] @@ -360,15 +416,18 @@ def _make_gaussian_maps(x, gaussians, size=None, scaling=6.): gaussian_maps = [] map_template = torch.ones(*size, dtype=dtype, device=device) meshgrids = torch.meshgrid( - [torch.linspace(0, 1, size[0], dtype=dtype, device=device), - torch.linspace(0, 1, size[1], dtype=dtype, device=device),]) + [ + torch.linspace(0, 1, size[0], dtype=dtype, device=device), + torch.linspace(0, 1, size[1], dtype=dtype, device=device), + ] + ) for gaussian_idx, yx_mu_logstd in enumerate(torch.unbind(gaussians)): map = map_template.clone() for mu_logstd, mgrid in zip(yx_mu_logstd, meshgrids): mu = mu_logstd[0] std = torch.exp(mu_logstd[1]) - map *= torch.exp(-((mgrid - mu) / std) ** 2 / 2) + map *= torch.exp(-(((mgrid - mu) / std) ** 2) / 2) map *= scaling gaussian_maps.append(map) @@ -377,27 +436,36 @@ def _make_gaussian_maps(x, gaussians, size=None, scaling=6.): gaussian_maps = gaussian_maps.unsqueeze(0).expand(bs, -1, -1, -1) return gaussian_maps - def _get_gaussian_maps(self, x, source_str, prefix='coarse_', **kwargs): + def _get_gaussian_maps(self, x, source_str, prefix="coarse_", **kwargs): """Return the constructed Gaussian prior maps.""" - suffix = source_str if self.ds_gaussians else '' + suffix = source_str if self.ds_gaussians else "" gaussians = self.__getattr__(prefix + "gaussians" + suffix) gaussian_maps = self._make_gaussian_maps(x, gaussians, **kwargs) return gaussian_maps # @classmethod - def make_skip_connection(self, input_channels, output_channels, expand_ratio, p, - inplace=False): + def make_skip_connection( + self, input_channels, output_channels, expand_ratio, p, inplace=False + ): """Return skip connection module.""" hidden_channels = round(input_channels * expand_ratio) - return nn.Sequential(OrderedDict([ - ('expansion', self.conv_1x1_bn( - input_channels, hidden_channels)), - ('dropout', nn.Dropout2d(p, inplace=inplace)), - ('reduction', nn.Sequential(*[ - nn.Conv2d(hidden_channels, output_channels, 1), - self.get_bn_module(output_channels), - ])), - ])) + return nn.Sequential( + OrderedDict( + [ + ("expansion", self.conv_1x1_bn(input_channels, hidden_channels)), + ("dropout", nn.Dropout2d(p, inplace=inplace)), + ( + "reduction", + nn.Sequential( + *[ + nn.Conv2d(hidden_channels, output_channels, 1), + self.get_bn_module(output_channels), + ] + ), + ), + ] + ) + ) # @staticmethod def conv_1x1_bn(self, inp, oup): @@ -405,11 +473,18 @@ def conv_1x1_bn(self, inp, oup): return nn.Sequential( nn.Conv2d(inp, oup, 1, 1, 0, bias=False), self.get_bn_module(oup), - nn.ReLU6(inplace=True) + nn.ReLU6(inplace=True), ) - def forward(self, x, target_size=None, h0=None, return_hidden=False, - source='DHF1K', static=None): + def forward( + self, + x, + target_size=None, + h0=None, + return_hidden=False, + source="DHF1K", + static=None, + ): """ Forward pass. @@ -429,9 +504,9 @@ def forward(self, x, target_size=None, h0=None, return_hidden=False, self.this_source = source # Prepare other parameters - source_str = f'_{source.lower()}' + source_str = f"_{source.lower()}" if static is None: - static = x.shape[1] == 1 or self.sources == ('SALICON',) + static = x.shape[1] == 1 or self.sources == ("SALICON",) # Compute backbone CNN features and concatenate with Gaussian prior maps feat_seq_1x = [] @@ -461,8 +536,7 @@ def forward(self, x, target_size=None, h0=None, return_hidden=False, # Decoder output_seq = [] - for idx, im_feat in enumerate( - torch.unbind(feat_seq_1x, dim=1)): + for idx, im_feat in enumerate(torch.unbind(feat_seq_1x, dim=1)): if not (static and self.bypass_rnn): rnn_feat = rnn_feat_seq[:, idx, ...] @@ -479,20 +553,19 @@ def forward(self, x, target_size=None, h0=None, return_hidden=False, im_feat = self.post_upsampling_2(im_feat) im_feat = self.__getattr__( - 'adaptation' + (source_str if self.ds_adaptation else ''))( - im_feat) + "adaptation" + (source_str if self.ds_adaptation else "") + )(im_feat) - im_feat = F.interpolate( - im_feat, size=x.shape[-2:], mode='nearest') + im_feat = F.interpolate(im_feat, size=x.shape[-2:], mode="nearest") - im_feat = F.pad(im_feat, [self.smoothing_ksize // 2] * 4, - mode='replicate') + im_feat = F.pad(im_feat, [self.smoothing_ksize // 2] * 4, mode="replicate") im_feat = self.__getattr__( - 'smoothing' + (source_str if self.ds_smoothing else ''))( - im_feat) + "smoothing" + (source_str if self.ds_smoothing else "") + )(im_feat) im_feat = F.interpolate( - im_feat, size=target_size, mode='bilinear', align_corners=False) + im_feat, size=target_size, mode="bilinear", align_corners=False + ) im_feat = utils.log_softmax(im_feat) output_seq.append(im_feat) diff --git a/unisal/models/MobileNetV2.py b/unisal/models/MobileNetV2.py index 4bc424d..8e87c16 100755 --- a/unisal/models/MobileNetV2.py +++ b/unisal/models/MobileNetV2.py @@ -11,7 +11,7 @@ def conv_bn(inp, oup, stride): return nn.Sequential( nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup), - nn.ReLU6(inplace=True) + nn.ReLU6(inplace=True), ) @@ -19,23 +19,32 @@ def conv_1x1_bn(inp, oup): return nn.Sequential( nn.Conv2d(inp, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup), - nn.ReLU6(inplace=True) + nn.ReLU6(inplace=True), ) class InvertedResidual(nn.Module): - def __init__(self, inp, oup, stride, expand_ratio, omit_stride=False, - no_res_connect=False, dropout=0., bn_momentum=0.1, - batchnorm=None): + def __init__( + self, + inp, + oup, + stride, + expand_ratio, + omit_stride=False, + no_res_connect=False, + dropout=0.0, + bn_momentum=0.1, + batchnorm=None, + ): super().__init__() self.out_channels = oup self.stride = stride self.omit_stride = omit_stride - self.use_res_connect = not no_res_connect and\ - self.stride == 1 and inp == oup + self.use_res_connect = not no_res_connect and self.stride == 1 and inp == oup self.dropout = dropout actual_stride = self.stride if not self.omit_stride else 1 if batchnorm is None: + def batchnorm(num_features): return nn.BatchNorm2d(num_features, momentum=bn_momentum) @@ -45,8 +54,15 @@ def batchnorm(num_features): if expand_ratio == 1: modules = [ # dw - nn.Conv2d(hidden_dim, hidden_dim, 3, actual_stride, 1, - groups=hidden_dim, bias=False), + nn.Conv2d( + hidden_dim, + hidden_dim, + 3, + actual_stride, + 1, + groups=hidden_dim, + bias=False, + ), batchnorm(hidden_dim), nn.ReLU6(inplace=True), # pw-linear @@ -63,8 +79,15 @@ def batchnorm(num_features): batchnorm(hidden_dim), nn.ReLU6(inplace=True), # dw - nn.Conv2d(hidden_dim, hidden_dim, 3, actual_stride, 1, - groups=hidden_dim, bias=False), + nn.Conv2d( + hidden_dim, + hidden_dim, + 3, + actual_stride, + 1, + groups=hidden_dim, + bias=False, + ), batchnorm(hidden_dim), nn.ReLU6(inplace=True), # pw-linear @@ -86,7 +109,7 @@ def _initialize_weights(self): for m in self.modules(): if isinstance(m, nn.Conv2d): n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.data.normal_(0, math.sqrt(2. / n)) + m.weight.data.normal_(0, math.sqrt(2.0 / n)) if m.bias is not None: m.bias.data.zero_() elif isinstance(m, nn.BatchNorm2d): @@ -99,8 +122,9 @@ def _initialize_weights(self): class MobileNetV2(nn.Module): - def __init__(self, widen_factor=1., pretrained=True, - last_channel=None, input_channel=32): + def __init__( + self, widen_factor=1.0, pretrained=True, last_channel=None, input_channel=32 + ): super().__init__() self.widen_factor = widen_factor self.pretrained = pretrained @@ -127,33 +151,45 @@ def __init__(self, widen_factor=1., pretrained=True, output_channel = int(c * widen_factor) for i in range(n): if i == 0: - self.features.append(block( - input_channel, output_channel, s, expand_ratio=t, - omit_stride=True)) + self.features.append( + block( + input_channel, + output_channel, + s, + expand_ratio=t, + omit_stride=True, + ) + ) else: - self.features.append(block( - input_channel, output_channel, 1, expand_ratio=t)) + self.features.append( + block(input_channel, output_channel, 1, expand_ratio=t) + ) input_channel = output_channel # building last several layers if self.last_channel is not None: - output_channel = int(self.last_channel * widen_factor)\ - if widen_factor > 1.0 else self.last_channel + output_channel = ( + int(self.last_channel * widen_factor) + if widen_factor > 1.0 + else self.last_channel + ) self.features.append(conv_1x1_bn(input_channel, output_channel)) # make it nn.Sequential self.features = nn.Sequential(*self.features) self.out_channels = output_channel - self.feat_1x_channels = int( - interverted_residual_setting[-1][1] * widen_factor) - self.feat_2x_channels = int( - interverted_residual_setting[-2][1] * widen_factor) - self.feat_4x_channels = int( - interverted_residual_setting[-4][1] * widen_factor) - self.feat_8x_channels = int( - interverted_residual_setting[-5][1] * widen_factor) + self.feat_1x_channels = int(interverted_residual_setting[-1][1] * widen_factor) + self.feat_2x_channels = int(interverted_residual_setting[-2][1] * widen_factor) + self.feat_4x_channels = int(interverted_residual_setting[-4][1] * widen_factor) + self.feat_8x_channels = int(interverted_residual_setting[-5][1] * widen_factor) if self.pretrained: state_dict = torch.load( - Path(__file__).resolve().parent / 'weights/mobilenet_v2.pth.tar') + Path(__file__).resolve().parent / "weights/mobilenet_v2.pth.tar", + map_location=( + torch.device("cuda:0") + if torch.cuda.is_available() + else torch.device("cpu") + ), + ) self.load_state_dict(state_dict, strict=False) else: self._initialize_weights() @@ -167,7 +203,7 @@ def forward(self, x): feat_4x = x.clone() elif idx == 14: feat_2x = x.clone() - if idx > 0 and hasattr(module, 'stride') and module.stride != 1: + if idx > 0 and hasattr(module, "stride") and module.stride != 1: x = x[..., ::2, ::2] return x, feat_2x, feat_4x @@ -176,7 +212,7 @@ def _initialize_weights(self): for m in self.modules(): if isinstance(m, nn.Conv2d): n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - m.weight.data.normal_(0, math.sqrt(2. / n)) + m.weight.data.normal_(0, math.sqrt(2.0 / n)) if m.bias is not None: m.bias.data.zero_() elif isinstance(m, nn.BatchNorm2d): diff --git a/unisal/models/cgru.py b/unisal/models/cgru.py index bb2f976..128dd26 100644 --- a/unisal/models/cgru.py +++ b/unisal/models/cgru.py @@ -39,11 +39,25 @@ class ConvGRUCell(nn.Module): mobile: If True, MobileNet-style convolutions are used. """ - def __init__(self, input_ch, hidden_ch, kernel_size, gate_ksize=(1, 1), - bias=True, norm='', norm_momentum=0.1, affine_norm=True, - batchnorm=None, gain=1, drop_prob=(0., 0., 0.), - do_mode='recurrent', r_bias=0., z_bias=0., mobile=False, - **kwargs): + def __init__( + self, + input_ch, + hidden_ch, + kernel_size, + gate_ksize=(1, 1), + bias=True, + norm="", + norm_momentum=0.1, + affine_norm=True, + batchnorm=None, + gain=1, + drop_prob=(0.0, 0.0, 0.0), + do_mode="recurrent", + r_bias=0.0, + z_bias=0.0, + mobile=False, + **kwargs + ): super().__init__() self.input_ch = input_ch @@ -51,7 +65,7 @@ def __init__(self, input_ch, hidden_ch, kernel_size, gate_ksize=(1, 1), self.kernel_size = kernel_size self.gate_ksize = gate_ksize self.mobile = mobile - self.kwargs = {'init': 'xavier_uniform_'} + self.kwargs = {"init": "xavier_uniform_"} self.kwargs.update(kwargs) # Process normalization arguments @@ -61,11 +75,13 @@ def __init__(self, input_ch, hidden_ch, kernel_size, gate_ksize=(1, 1), self.batchnorm = batchnorm self.norm_kwargs = None if self.batchnorm is not None: - self.norm = 'batch' + self.norm = "batch" elif self.norm: self.norm_kwargs = { - 'affine': self.affine_norm, 'track_running_stats': True, - 'momentum': self.norm_momentum} + "affine": self.affine_norm, + "track_running_stats": True, + "momentum": self.norm_momentum, + } # Prepare normalization modules if self.norm: @@ -79,12 +95,12 @@ def __init__(self, input_ch, hidden_ch, kernel_size, gate_ksize=(1, 1), # Prepare dropout self.drop_prob = drop_prob self.do_mode = do_mode - if self.do_mode == 'recurrent': + if self.do_mode == "recurrent": # Prepare dropout masks if using recurrent dropout for idx, mask in self.yield_drop_masks(): self.register_buffer(self.mask_name(idx), mask) - elif self.do_mode != 'naive': - raise ValueError('Unknown dropout mode ', self.do_mode) + elif self.do_mode != "naive": + raise ValueError("Unknown dropout mode ", self.do_mode) # Instantiate the main weight matrices self.w_r = self._conv2d(self.input_ch, self.gate_ksize, bias=False) @@ -115,9 +131,11 @@ def __init__(self, input_ch, hidden_ch, kernel_size, gate_ksize=(1, 1), def set_weights(self): """Initialize the parameters""" + def gain_from_ksize(ksize): n = ksize[0] * ksize[1] * self.hidden_ch - return math.sqrt(2. / n) + return math.sqrt(2.0 / n) + with torch.no_grad(): if not self.mobile: if self.gain < 0: @@ -125,7 +143,7 @@ def gain_from_ksize(ksize): gain_2 = gain_from_ksize(self.gate_ksize) else: gain_1 = gain_2 = self.gain - init_fn = getattr(init, self.kwargs['init']) + init_fn = getattr(init, self.kwargs["init"]) init_fn(self.w_r.weight, gain=gain_2) init_fn(self.u_r.weight, gain=gain_2) init_fn(self.w_z.weight, gain=gain_2) @@ -197,7 +215,7 @@ def forward(self, x, h_tm1): @staticmethod def mask_name(idx): - return 'drop_mask_{}'.format(idx) + return "drop_mask_{}".format(idx) def set_drop_masks(self): """Set the dropout masks for the current sequence""" @@ -210,27 +228,29 @@ def yield_drop_masks(self): n_channels = (self.input_ch, self.hidden_ch, self.hidden_ch) for idx, p in enumerate(self.drop_prob): if p > 0: - yield (idx, self.generate_do_mask( - p, n_masks[idx], n_channels[idx])) + yield (idx, self.generate_do_mask(p, n_masks[idx], n_channels[idx])) @staticmethod def generate_do_mask(p, n, ch): """Generate a dropout mask for recurrent dropout""" with torch.no_grad(): mask = Bernoulli(torch.full((n, ch), 1 - p)).sample() / (1 - p) - mask = mask.requires_grad_(False).cuda() + mask = ( + mask.requires_grad_(False).cuda() + if torch.cuda.is_available() + else mask.requires_grad_(False).cpu() + ) return mask def apply_dropout(self, x, idx, sub_idx): """Apply recurrent or naive dropout""" if self.training and self.drop_prob[idx] > 0 and idx != 2: - if self.do_mode == 'recurrent': + if self.do_mode == "recurrent": x = x.clone() * torch.reshape( - getattr(self, self.mask_name(idx)) - [sub_idx, :], (1, -1, 1, 1)) - elif self.do_mode == 'naive': - x = f.dropout2d( - x, self.drop_prob[idx], self.training, inplace=False) + getattr(self, self.mask_name(idx))[sub_idx, :], (1, -1, 1, 1) + ) + elif self.do_mode == "naive": + x = f.dropout2d(x, self.drop_prob[idx], self.training, inplace=False) else: x = x.clone() return x @@ -240,9 +260,9 @@ def get_norm_module(self, channels): norm_module = None if self.batchnorm is not None: norm_module = self.batchnorm(channels) - elif self.norm == 'instance': + elif self.norm == "instance": norm_module = nn.InstanceNorm2d(channels, **self.norm_kwargs) - elif self.norm == 'batch': + elif self.norm == "batch": norm_module = nn.BatchNorm2d(channels, **self.norm_kwargs) return norm_module @@ -253,24 +273,38 @@ def _conv2d(self, in_channels, kernel_size, bias=True): """ padding = tuple(k_size // 2 for k_size in kernel_size) if not self.mobile or kernel_size == (1, 1): - return nn.Conv2d(in_channels, self.hidden_ch, kernel_size, - padding=padding, bias=bias) + return nn.Conv2d( + in_channels, self.hidden_ch, kernel_size, padding=padding, bias=bias + ) else: - return nn.Sequential(OrderedDict([ - ('conv_dw', nn.Conv2d( - in_channels, in_channels, kernel_size=kernel_size, - padding=padding, groups=in_channels, bias=False)), - ('sep_bn', self.get_norm_module(in_channels)), - ('sep_relu', nn.ReLU6()), - ('conv_sep', nn.Conv2d( - in_channels, self.hidden_ch, 1, bias=bias)), - ])) + return nn.Sequential( + OrderedDict( + [ + ( + "conv_dw", + nn.Conv2d( + in_channels, + in_channels, + kernel_size=kernel_size, + padding=padding, + groups=in_channels, + bias=False, + ), + ), + ("sep_bn", self.get_norm_module(in_channels)), + ("sep_relu", nn.ReLU6()), + ( + "conv_sep", + nn.Conv2d(in_channels, self.hidden_ch, 1, bias=bias), + ), + ] + ) + ) def _init_hidden(self, input_, cuda=True): """Initialize the hidden state""" batch_size, _, height, width = input_.data.size() - prev_state = torch.zeros( - batch_size, self.hidden_ch, height, width) + prev_state = torch.zeros(batch_size, self.hidden_ch, height, width) if cuda: prev_state = prev_state.cuda() return prev_state @@ -278,10 +312,16 @@ def _init_hidden(self, input_, cuda=True): class ConvGRU(nn.Module): - def __init__(self, input_channels=None, hidden_channels=None, - kernel_size=(3, 3), gate_ksize=(1, 1), - dropout=(False, False, False), drop_prob=(0.5, 0.5, 0.5), - **kwargs): + def __init__( + self, + input_channels=None, + hidden_channels=None, + kernel_size=(3, 3), + gate_ksize=(1, 1), + dropout=(False, False, False), + drop_prob=(0.5, 0.5, 0.5), + **kwargs + ): """ Generates a multi-layer convolutional GRU. Preserves spatial dimensions across cells, only altering depth. @@ -313,8 +353,10 @@ def __init__(self, input_channels=None, hidden_channels=None, self.gate_ksize = self._extend_for_multilayer(gate_ksize) self.dropout = self._extend_for_multilayer(dropout) drop_prob = self._extend_for_multilayer(drop_prob) - self.drop_prob = [tuple(dp_ if do_ else 0. for dp_, do_ in zip(dp, do)) - for dp, do in zip(drop_prob, self.dropout)] + self.drop_prob = [ + tuple(dp_ if do_ else 0.0 for dp_, do_ in zip(dp, do)) + for dp, do in zip(drop_prob, self.dropout) + ] self.kwargs = kwargs cell_list = [] @@ -322,13 +364,19 @@ def __init__(self, input_channels=None, hidden_channels=None, if idx < self.num_layers - 1: # Switch output dropout off for hidden layers. # Otherwise it would confict with input dropout. - this_drop_prob = self.drop_prob[idx][:2] + (0.,) + this_drop_prob = self.drop_prob[idx][:2] + (0.0,) else: this_drop_prob = self.drop_prob[idx] - cell_list.append(ConvGRUCell( - self.input_channels[idx], self.hidden_channels[idx], - self.kernel_size[idx], drop_prob=this_drop_prob, - gate_ksize=self.gate_ksize[idx], **kwargs)) + cell_list.append( + ConvGRUCell( + self.input_channels[idx], + self.hidden_channels[idx], + self.kernel_size[idx], + drop_prob=this_drop_prob, + gate_ksize=self.gate_ksize[idx], + **kwargs + ) + ) self.cell_list = nn.ModuleList(cell_list) def forward(self, input_tensor, hidden=None): @@ -350,8 +398,7 @@ def forward(self, input_tensor, hidden=None): for t, x in enumerate(iterator): for layer_idx in range(self.num_layers): - if self.cell_list[layer_idx].do_mode == 'recurrent'\ - and t == 0: + if self.cell_list[layer_idx].do_mode == "recurrent" and t == 0: self.cell_list[layer_idx].set_drop_masks() (x, h) = self.cell_list[layer_idx](x, hidden[layer_idx]) hidden[layer_idx] = h.clone() @@ -362,14 +409,18 @@ def forward(self, input_tensor, hidden=None): @staticmethod def _check_kernel_size_consistency(kernel_size): - if not (isinstance(kernel_size, tuple) or - (isinstance(kernel_size, list) and - all([isinstance(elem, tuple) for elem in kernel_size]))): - raise ValueError('`kernel_size` must be tuple or list of tuples') + if not ( + isinstance(kernel_size, tuple) + or ( + isinstance(kernel_size, list) + and all([isinstance(elem, tuple) for elem in kernel_size]) + ) + ): + raise ValueError("`kernel_size` must be tuple or list of tuples") def _extend_for_multilayer(self, param): if not isinstance(param, list): param = [param] * self.num_layers else: - assert(len(param) == self.num_layers) + assert len(param) == self.num_layers return param diff --git a/unisal/train.py b/unisal/train.py index 3dc784c..506ec6a 100644 --- a/unisal/train.py +++ b/unisal/train.py @@ -88,47 +88,48 @@ class Trainer(utils.KwConfigClass): """ - phases = ('train', 'valid') - all_data_sources = ('DHF1K', 'Hollywood', 'UCFSports', 'SALICON') - - def __init__(self, - num_epochs=16, - optim_algo='SGD', - momentum=0.9, - lr=0.04, - lr_scheduler='ExponentialLR', - lr_gamma=0.8, - weight_decay=1e-4, - cnn_weight_decay=1e-5, - grad_clip=2., - loss_metrics=('kld', 'nss', 'cc'), - loss_weights=(1, -0.1, -0.1), - data_sources=('DHF1K', 'Hollywood', 'UCFSports', 'SALICON'), - batch_size=4, - salicon_batch_size=32, - hollywood_batch_size=4, - ucfsports_batch_size=4, - salicon_weight=.5, - hollywood_weight=1., - ucfsports_weight=1., - data_cfg=None, - salicon_cfg=None, - hollywood_cfg=None, - ucfsports_cfg=None, - shuffle_datasets=True, - cnn_lr_factor=0.1, - train_cnn_after=2, - cnn_eval=True, - model_cfg=None, - prefix=None, - suffix='unisal', - num_workers=6, - chkpnt_warmup=3, - chkpnt_epochs=2, - tboard=True, - debug=False, - new_instance=True, - ): + phases = ("train", "valid") + all_data_sources = ("DHF1K", "Hollywood", "UCFSports", "SALICON") + + def __init__( + self, + num_epochs=16, + optim_algo="SGD", + momentum=0.9, + lr=0.04, + lr_scheduler="ExponentialLR", + lr_gamma=0.8, + weight_decay=1e-4, + cnn_weight_decay=1e-5, + grad_clip=2.0, + loss_metrics=("kld", "nss", "cc"), + loss_weights=(1, -0.1, -0.1), + data_sources=("DHF1K", "Hollywood", "UCFSports", "SALICON"), + batch_size=4, + salicon_batch_size=32, + hollywood_batch_size=4, + ucfsports_batch_size=4, + salicon_weight=0.5, + hollywood_weight=1.0, + ucfsports_weight=1.0, + data_cfg=None, + salicon_cfg=None, + hollywood_cfg=None, + ucfsports_cfg=None, + shuffle_datasets=True, + cnn_lr_factor=0.1, + train_cnn_after=2, + cnn_eval=True, + model_cfg=None, + prefix=None, + suffix="unisal", + num_workers=6, + chkpnt_warmup=3, + chkpnt_epochs=2, + tboard=True, + debug=False, + new_instance=True, + ): # Save training parameters self.num_epochs = num_epochs self.optim_algo = optim_algo @@ -158,8 +159,8 @@ def __init__(self, self.train_cnn_after = train_cnn_after self.cnn_eval = cnn_eval self.model_cfg = model_cfg or {} - if 'sources' not in self.model_cfg: - self.model_cfg['sources'] = data_sources + if "sources" not in self.model_cfg: + self.model_cfg["sources"] = data_sources # Create training directory. Uses env var TRAIN_DIR self.suffix = suffix @@ -171,7 +172,7 @@ def __init__(self, self.num_workers = num_workers self.chkpnt_warmup = chkpnt_warmup self.chkpnt_epochs = chkpnt_epochs - device = 'cuda:0' if torch.cuda.is_available() else 'cpu' + device = "cuda:0" if torch.cuda.is_available() else "cpu" self.device = torch.device(device) self.tboard = tboard self.debug = debug @@ -181,12 +182,12 @@ def __init__(self, self.num_epochs = 4 self.chkpnt_epochs = 2 self.chkpnt_warmup = 1 - self.suffix += '_debug' - self.data_cfg.update({'subset': 0.02}) - self.salicon_cfg.update({'subset': 0.02}) - self.hollywood_cfg.update({'subset': 0.04}) - self.ucfsports_cfg.update({'subset': 0.1}) - self.eval_subeset = 1. + self.suffix += "_debug" + self.data_cfg.update({"subset": 0.02}) + self.salicon_cfg.update({"subset": 0.02}) + self.hollywood_cfg.update({"subset": 0.04}) + self.ucfsports_cfg.update({"subset": 0.1}) + self.eval_subeset = 1.0 # Initialize properties etc. self.epoch = 0 @@ -218,7 +219,7 @@ def __init__(self, self.model.save_cfg(self.train_dir) self.copy_code() for src in self.data_sources: - self.get_dataset('train', source=src).save_cfg(self.train_dir) + self.get_dataset("train", source=src).save_cfg(self.train_dir) def fit(self): """ @@ -235,9 +236,10 @@ def fit(self): self.fit_epoch() # Save a checkpoint if applicable - if (self.epoch >= self.chkpnt_warmup - and (self.epoch + 1) % self.chkpnt_epochs == 0)\ - or self.epoch == self.num_epochs - 1: + if ( + self.epoch >= self.chkpnt_warmup + and (self.epoch + 1) % self.chkpnt_epochs == 0 + ) or self.epoch == self.num_epochs - 1: self.save_chkpnt() self.epoch += 1 @@ -254,9 +256,9 @@ def fit_epoch(self): # Perform LR decay self.scheduler.step(epoch=self.epoch) - lr = self.optimizer.param_groups[0]['lr'] + lr = self.optimizer.param_groups[0]["lr"] print(f"\nEpoch {self.epoch:3d}, lr {lr:.5f}") - self.add_scalar('conv/lr', lr, self.epoch) + self.add_scalar("conv/lr", lr, self.epoch) # Run the training and validation phase for self.phase in self.phases: @@ -269,29 +271,34 @@ def fit_phase(self): sources = self.data_sources # Prepare book keeping - running_losses = {src: 0. for src in sources} + running_losses = {src: 0.0 for src in sources} running_loss_summands = { - src: [0. for _ in self.loss_weights] for src in sources} + src: [0.0 for _ in self.loss_weights] for src in sources + } n_samples = {src: 0 for src in sources} # Shuffle the dataset batches - dataloaders = {src: self.get_dataloader(self.phase, src) - for src in sources} - all_batches = [src for src in chain.from_iterable(zip_longest( - *[[src for _ in range(len(dataloaders[src]))] for src in sources] - )) if src is not None] + dataloaders = {src: self.get_dataloader(self.phase, src) for src in sources} + all_batches = [ + src + for src in chain.from_iterable( + zip_longest( + *[[src for _ in range(len(dataloaders[src]))] for src in sources] + ) + ) + if src is not None + ] if self.shuffle_datasets: shuffle(all_batches) if self.epoch == 0: print(f"Number of batches: {len(all_batches)}") - print(", ".join(f"{src}: {len(dataloaders[src])}" - for src in sources)) + print(", ".join(f"{src}: {len(dataloaders[src])}" for src in sources)) # Set model train/eval mode - self.model.train(self.phase == 'train') + self.model.train(self.phase == "train") # Switch CNN gradients on/off and set CNN eval mode (for BN modules) - if self.phase == 'train': + if self.phase == "train": cnn_grad = self.epoch >= self.train_cnn_after for param in self._model.cnn.parameters(): param.requires_grad = cnn_grad @@ -304,20 +311,23 @@ def fit_phase(self): # Get the next batch sample = next(data_iters[src]) - target_size = (sample[-1][0][0].item(), - sample[-1][1][0].item()) + target_size = (sample[-1][0][0].item(), sample[-1][1][0].item()) sample = sample[:-1] # Fit/evaluate the batch loss, loss_summands, batch_size = self.fit_sample( - sample, grad_clip=self.grad_clip, target_size=target_size, - source='SALICON' if src == 'MIT1003' else src) + sample, + grad_clip=self.grad_clip, + target_size=target_size, + source="SALICON" if src == "MIT1003" else src, + ) # Book keeping running_losses[src] += loss * batch_size running_loss_summands[src] = [ r + l * batch_size - for r, l in zip(running_loss_summands[src], loss_summands)] + for r, l in zip(running_loss_summands[src], loss_summands) + ] n_samples[src] += batch_size # Book keeping and writing to TensorboardX @@ -325,30 +335,37 @@ def fit_phase(self): for src in sources_eval: phase_loss = running_losses[src] / n_samples[src] phase_loss_summands = [ - loss_ / n_samples[src] for loss_ in running_loss_summands[src]] - - print(f'{src:9s}: Phase: {self.phase}, loss: {phase_loss:.4f}, ' - + ", ".join(f"loss_{idx}: {loss_:.4f}" - for idx, loss_ in enumerate(phase_loss_summands))) + loss_ / n_samples[src] for loss_ in running_loss_summands[src] + ] + + print( + f"{src:9s}: Phase: {self.phase}, loss: {phase_loss:.4f}, " + + ", ".join( + f"loss_{idx}: {loss_:.4f}" + for idx, loss_ in enumerate(phase_loss_summands) + ) + ) - key = 'conv' if src == 'DHF1K' else src.lower() - self.add_scalar(f'{key}/loss/{self.phase}', phase_loss, self.epoch) + key = "conv" if src == "DHF1K" else src.lower() + self.add_scalar(f"{key}/loss/{self.phase}", phase_loss, self.epoch) for idx, loss_ in enumerate(phase_loss_summands): - self.add_scalar(f'{key}/loss_{idx}/{self.phase}', loss_, - self.epoch) - - if src == "DHF1K" and self.phase == 'valid' and\ - self.epoch >= self.chkpnt_warmup: - val_score = - phase_loss + self.add_scalar(f"{key}/loss_{idx}/{self.phase}", loss_, self.epoch) + + if ( + src == "DHF1K" + and self.phase == "valid" + and self.epoch >= self.chkpnt_warmup + ): + val_score = -phase_loss if self.best_val_score is None: self.best_val_score = val_score elif val_score > self.best_val_score: self.best_val_score = val_score self.is_best = True - self.model.save_weights(self.train_dir, 'best') - with open(self.train_dir / 'best_epoch.dat', 'w') as f: + self.model.save_weights(self.train_dir, "best") + with open(self.train_dir / "best_epoch.dat", "w") as f: f.write(str(self.epoch)) - with open(self.train_dir / 'best_val_loss.dat', 'w') as f: + with open(self.train_dir / "best_val_loss.dat", "w") as f: f.write(str(val_score)) else: self.is_best = False @@ -358,7 +375,7 @@ def fit_sample(self, sample, grad_clip=None, **kwargs): Take a sample containing a batch, and fit/evaluate the model """ - with torch.set_grad_enabled(self.phase == 'train'): + with torch.set_grad_enabled(self.phase == "train"): _, x, sal, fix = sample # Add temporal dimension to image data @@ -372,36 +389,38 @@ def fit_sample(self, sample, grad_clip=None, **kwargs): # Switch the gradients of unused modules off to prevent unnecessary # weight decay - if self.phase == 'train': + if self.phase == "train": # Switch the RNN gradients off if this is a image batch rnn_grad = x.shape[1] != 1 or not self.model.bypass_rnn - for param in chain(self._model.rnn.parameters(), - self._model.post_rnn.parameters()): + for param in chain( + self._model.rnn.parameters(), self._model.post_rnn.parameters() + ): param.requires_grad = rnn_grad # Switch the gradients of unused dataset-specific modules off for name, param in self.model.named_parameters(): for source in self.all_data_sources: if source.lower() in name.lower(): - param.requires_grad = source == kwargs['source'] + param.requires_grad = source == kwargs["source"] # Run forward pass pred_seq = self.model(x, **kwargs) # Compute the total loss loss_summands = self.loss_sequences( - pred_seq, sal, fix, metrics=self.loss_metrics) + pred_seq, sal, fix, metrics=self.loss_metrics + ) loss_summands = [l.mean(1).mean(0) for l in loss_summands] - loss = sum(weight * l for weight, l in - zip(self.loss_weights, loss_summands)) + loss = sum( + weight * l for weight, l in zip(self.loss_weights, loss_summands) + ) # Run backward pass and optimization step - if self.phase == 'train': + if self.phase == "train": self.optimizer.zero_grad() loss.backward() if grad_clip is not None: - torch.nn.utils.clip_grad_norm_( - self.model.parameters(), self.grad_clip) + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip) self.optimizer.step() return loss.item(), [l.item() for l in loss_summands], x.shape[0] @@ -414,19 +433,31 @@ def loss_sequences(pred_seq, sal_seq, fix_seq, metrics): losses = [] for this_metric in metrics: - if this_metric == 'kld': + if this_metric == "kld": losses.append(utils.kld_loss(pred_seq, sal_seq)) - if this_metric == 'nss': + if this_metric == "nss": losses.append(utils.nss(pred_seq.exp(), fix_seq)) - if this_metric == 'cc': + if this_metric == "cc": losses.append(utils.corr_coeff(pred_seq.exp(), sal_seq)) return losses - def run_inference(self, source, vid_nr, dataset=None, phase=None, - smooth_method=None, metrics=None, save_predictions=False, - return_predictions=False, seq_len_factor=0.5, - random_seed=27, n_aucs_maps=10, auc_portion=1.0, - model_domain=None, folder_suffix=None): + def run_inference( + self, + source, + vid_nr, + dataset=None, + phase=None, + smooth_method=None, + metrics=None, + save_predictions=False, + return_predictions=False, + seq_len_factor=0.5, + random_seed=27, + n_aucs_maps=10, + auc_portion=1.0, + model_domain=None, + folder_suffix=None, + ): if dataset is None: assert phase, "Must provide either dataset or phase" @@ -439,25 +470,26 @@ def run_inference(self, source, vid_nr, dataset=None, phase=None, target_size = dataset.target_size_dict[vid_nr] # Set the keyword arguments for the forward pass - model_kwargs = { - 'source': model_domain or source, - 'target_size': target_size} + model_kwargs = {"source": model_domain or source, "target_size": target_size} # Make sure that the model was trained on the selected domain - if model_kwargs['source'] not in self.model.sources: - print(f"\nWarning! Evaluation bn source {model_kwargs['source']} " - f"doesn't exist in model.\n Using {self.model.sources[0]}.") - model_kwargs['source'] = self.model.sources[0] + if model_kwargs["source"] not in self.model.sources: + print( + f"\nWarning! Evaluation bn source {model_kwargs['source']} " + f"doesn't exist in model.\n Using {self.model.sources[0]}." + ) + model_kwargs["source"] = self.model.sources[0] # Select static or dynamic forward pass for Bypass-RNN model_kwargs.update( - {'static': model_kwargs['source'] in ('SALICON', 'MIT300', 'MIT1003')}) + {"static": model_kwargs["source"] in ("SALICON", "MIT300", "MIT1003")} + ) # Set additional parameters - static_data = source in ('SALICON', 'MIT300', 'MIT1003') + static_data = source in ("SALICON", "MIT300", "MIT1003") if static_data: smooth_method = None - auc_portion = 1. + auc_portion = 1.0 n_images = 1 frame_modulo = 1 else: @@ -470,7 +502,7 @@ def run_inference(self, source, vid_nr, dataset=None, phase=None, torch.cuda.empty_cache() # Prepare the prediction and target tensors - results_size = (1, n_images, 1, *model_kwargs['target_size']) + results_size = (1, n_images, 1, *model_kwargs["target_size"]) pred_seq = torch.full(results_size, 0, dtype=torch.float) if metrics is not None: sal_seq = torch.full(results_size, 0, dtype=torch.float) @@ -508,7 +540,8 @@ def run_inference(self, source, vid_nr, dataset=None, phase=None, if metrics is not None: raise ValueError( "Labels needed for evaluation metrics but not provided" - "by dataset.") + "by dataset." + ) frame_nrs, frame_seq = sample this_sal_seq, this_fix_seq = None, None if frame_seq.dim() == 3: @@ -528,13 +561,12 @@ def run_inference(self, source, vid_nr, dataset=None, phase=None, # Forward pass this_pred_seq, h0 = self.model( - this_frame_seq, h0=h0, return_hidden=True, - **model_kwargs) + this_frame_seq, h0=h0, return_hidden=True, **model_kwargs + ) # Insert the predictions into the prediction array this_pred_seq = this_pred_seq.cpu() - pred_seq[:, this_frame_idx_array, :, :, :] =\ - this_pred_seq + pred_seq[:, this_frame_idx_array, :, :, :] = this_pred_seq # Keep the training targets if scores are to be computed if metrics is not None: @@ -542,7 +574,7 @@ def run_inference(self, source, vid_nr, dataset=None, phase=None, fix_seq[:, frame_idx_array, :, :, :] = this_fix_seq # Assert non-empty predictions - assert(torch.min(pred_seq.exp().sum(-1).sum(-1)) > 0) + assert torch.min(pred_seq.exp().sum(-1).sum(-1)) > 0 # Optionally smooth the interleaved sequences if smooth_method is not None: @@ -561,13 +593,13 @@ def run_inference(self, source, vid_nr, dataset=None, phase=None, this_pred_dir.mkdir(exist_ok=True) # Construct a subfolder if applicable - if source == 'DHF1K': + if source == "DHF1K": this_pred_dir /= f"{vid_nr:04d}" - elif source == 'MIT1003': + elif source == "MIT1003": this_pred_dir /= self.mit1003_dir.stem - elif source == 'MIT300' and 'x_val_step' in self.salicon_cfg: + elif source == "MIT300" and "x_val_step" in self.salicon_cfg: this_pred_dir /= f"MIT300_xVal{self.salicon_cfg['x_val_step']}" - elif source in ('Hollywood', 'UCFSports'): + elif source in ("Hollywood", "UCFSports"): this_pred_dir /= dataset.get_annotation_dir(vid_nr).stem this_pred_dir.mkdir(exist_ok=True) @@ -575,15 +607,16 @@ def run_inference(self, source, vid_nr, dataset=None, phase=None, for frame_idx, smap in enumerate(torch.unbind(pred_seq, dim=1)): # Define the filename - if source == 'SALICON': + if source == "SALICON": filename = f"COCO_test2014_{vid_nr:012d}.png" - elif source == 'MIT300': + elif source == "MIT300": filename = dataset.samples[vid_nr][0] - elif source == 'MIT1003': - filename = dataset.all_image_files[vid_nr]['img'] - elif source in ('Hollywood', 'UCFSports'): + elif source == "MIT1003": + filename = dataset.all_image_files[vid_nr]["img"] + elif source in ("Hollywood", "UCFSports"): filename = dataset.get_data_file( - vid_nr, frame_idx + 1, 'frame').name + vid_nr, frame_idx + 1, "frame" + ).name else: filename = f"{frame_idx + 1:04d}.png" @@ -593,16 +626,13 @@ def run_inference(self, source, vid_nr, dataset=None, phase=None, smap = utils.to_numpy(smap) # Optinally save numpy file in addition to the image - if source == 'MIT300': - np.save( - this_pred_dir / filename.replace(".jpg", ""), - smap.copy()) + if source == "MIT300": + np.save(this_pred_dir / filename.replace(".jpg", ""), smap.copy()) # Save prediction as image smap = (smap / np.amax(smap) * 255).astype(np.uint8) pred_file = this_pred_dir / filename - cv2.imwrite(str(pred_file), smap, - [cv2.IMWRITE_JPEG_QUALITY, 100]) + cv2.imwrite(str(pred_file), smap, [cv2.IMWRITE_JPEG_QUALITY, 100]) # Optionally compute the scores if metrics is not None: @@ -610,7 +640,8 @@ def run_inference(self, source, vid_nr, dataset=None, phase=None, # Compute the KLD, NSS and CC metrics vid_scores = [] loss_sequences = self.loss_sequences( - pred_seq, sal_seq, fix_seq, metrics=metrics) + pred_seq, sal_seq, fix_seq, metrics=metrics + ) loss_sequences = [loss.numpy() for loss in loss_sequences] vid_scores += loss_sequences @@ -618,19 +649,17 @@ def other_maps(): """Sample reference maps for s-AUC""" while True: this_map = np.zeros(results_size[-2:]) - video_nrs = random.sample( - dataset.n_images_dict.keys(), n_aucs_maps) + video_nrs = random.sample(dataset.n_images_dict.keys(), n_aucs_maps) for map_idx, vid_nr in enumerate(video_nrs): - frame_nr = random.randint( - 1, dataset.n_images_dict[vid_nr]) + frame_nr = random.randint(1, dataset.n_images_dict[vid_nr]) if static_data: this_this_map = dataset.get_fixation_map(vid_nr) else: this_this_map = dataset.get_seq( - vid_nr, [frame_nr], 'fix').numpy()[0, 0, ...] + vid_nr, [frame_nr], "fix" + ).numpy()[0, 0, ...] this_this_map = cv2.resize( - this_this_map, tuple(target_size[::-1]), - cv2.INTER_NEAREST + this_this_map, tuple(target_size[::-1]), cv2.INTER_NEAREST ) this_map += this_this_map @@ -639,9 +668,13 @@ def other_maps(): # Compute the SIM, AUC-J, s-AUC metrics vid_scores += self.eval_sequences( - pred_seq, sal_seq, fix_seq, metrics, - other_maps=other_maps() if 'aucs' in metrics else None, - auc_portion=auc_portion) + pred_seq, + sal_seq, + fix_seq, + metrics, + other_maps=other_maps() if "aucs" in metrics else None, + auc_portion=auc_portion, + ) # Average the scores over the frames mean_scores = np.array([np.mean(scores) for scores in vid_scores]) @@ -654,17 +687,17 @@ def other_maps(): return pred_seq @staticmethod - def eval_sequences(pred_seq, sal_seq, fix_seq, metrics, - other_maps=None, auc_portion=1.): + def eval_sequences( + pred_seq, sal_seq, fix_seq, metrics, other_maps=None, auc_portion=1.0 + ): """ Compute SIM, AUC-J and s-AUC scores """ # process inputs - metrics = [metric for metric in metrics - if metric in ('sim', 'aucj', 'aucs')] - if 'aucs' in metrics: - assert(other_maps is not None) + metrics = [metric for metric in metrics if metric in ("sim", "aucj", "aucs")] + if "aucs" in metrics: + assert other_maps is not None # Preprocess sequences shape = pred_seq.shape @@ -677,8 +710,8 @@ def eval_sequences(pred_seq, sal_seq, fix_seq, metrics, # Optionally compute AUC-s for a subset of frames to reduce runtime if auc_portion < 1: auc_indices = set( - random.sample( - range(shape[1]), max(1, int(auc_portion * shape[1])))) + random.sample(range(shape[1]), max(1, int(auc_portion * shape[1]))) + ) else: auc_indices = set(list(range(shape[1]))) @@ -686,18 +719,17 @@ def eval_sequences(pred_seq, sal_seq, fix_seq, metrics, results = {metric: [] for metric in metrics} for idx, (pred, sal, fix) in enumerate(zip(pred_seq, sal_seq, fix_seq)): for this_metric in metrics: - if this_metric == 'sim': - results['sim'].append( - salience_metrics.similarity(pred, sal)) - if this_metric == 'aucj': + if this_metric == "sim": + results["sim"].append(salience_metrics.similarity(pred, sal)) + if this_metric == "aucj": if idx in auc_indices: - results['aucj'].append( - salience_metrics.auc_judd(pred, fix)) - if this_metric == 'aucs': + results["aucj"].append(salience_metrics.auc_judd(pred, fix)) + if this_metric == "aucs": if idx in auc_indices: other_map = next(other_maps) - results['aucs'].append(salience_metrics.auc_shuff_acl( - pred, fix, other_map)) + results["aucs"].append( + salience_metrics.auc_shuff_acl(pred, fix, other_map) + ) return [np.array(results[metric]) for metric in metrics] @property @@ -707,12 +739,21 @@ def pred_dir(self): pred_dir.mkdir(exist_ok=True, parents=True) return pred_dir - def score_model(self, subset=1, source='DHF1K', - metrics=('kld', 'nss', 'cc', 'sim', 'aucj', 'aucs'), - smooth_method=None, seq_len_factor=2, - random_seed=27, n_aucs_maps=10, auc_portion=0.5, - model_domain=None, phase=None, load_weights=True, - vid_nr_array=None): + def score_model( + self, + subset=1, + source="DHF1K", + metrics=("kld", "nss", "cc", "sim", "aucj", "aucs"), + smooth_method=None, + seq_len_factor=2, + random_seed=27, + n_aucs_maps=10, + auc_portion=0.5, + model_domain=None, + phase=None, + load_weights=True, + vid_nr_array=None, + ): """ Compute the evaluation scores of the model. @@ -727,16 +768,15 @@ def score_model(self, subset=1, source='DHF1K', # the last epoch try: self.model.load_best_weights(self.train_dir) - print('Best weights loaded') + print("Best weights loaded") except FileNotFoundError: - print('No best weights found') + print("No best weights found") self.model.load_last_chkpnt(self.train_dir) - print('Last checkpoint loaded') + print("Last checkpoint loaded") # Select the appropriate phase (see docstring) and get the dataset if phase is None: - phase = 'eval' if source in ('DHF1K', 'SALICON', 'MIT1003')\ - else 'test' + phase = "eval" if source in ("DHF1K", "SALICON", "MIT1003") else "test" dataset = self.get_dataset(phase, source) if vid_nr_array is None: @@ -745,22 +785,32 @@ def score_model(self, subset=1, source='DHF1K', # Iterate over the videos/images and compute the scores scores = [] - tmr = utils.Timer(f'Evaluating {len(vid_nr_array)} {source} videos') + tmr = utils.Timer(f"Evaluating {len(vid_nr_array)} {source} videos") random.seed(random_seed) with torch.no_grad(): for vid_idx, vid_nr in enumerate(vid_nr_array): this_scores = self.run_inference( - source, vid_nr, dataset=dataset, - smooth_method=smooth_method, metrics=metrics, - seq_len_factor=seq_len_factor, n_aucs_maps=n_aucs_maps, - auc_portion=auc_portion, model_domain=model_domain) + source, + vid_nr, + dataset=dataset, + smooth_method=smooth_method, + metrics=metrics, + seq_len_factor=seq_len_factor, + n_aucs_maps=n_aucs_maps, + auc_portion=auc_portion, + model_domain=model_domain, + ) scores.append(this_scores) if vid_idx == 0: - print(f' Nr. ( .../{len(vid_nr_array):4d}), ' + - ', '.join(f'{metric:5s}' for metric in metrics)) - print(f'{vid_nr:6d} ' + - f'({vid_idx + 1:4d}/{len(vid_nr_array):4d}), ' + - ', '.join(f'{score:.3f}' for score in this_scores)) + print( + f" Nr. ( .../{len(vid_nr_array):4d}), " + + ", ".join(f"{metric:5s}" for metric in metrics) + ) + print( + f"{vid_nr:6d} " + + f"({vid_idx + 1:4d}/{len(vid_nr_array):4d}), " + + ", ".join(f"{score:.3f}" for score in this_scores) + ) # Compute the average video scores tmr.finish() @@ -771,45 +821,46 @@ def score_model(self, subset=1, source='DHF1K', # which means that each videos contribution to the overall score is # weighted by its number of frames. The equivalent scores are denoted # below as weighted mean - num_frames_array = [ - dataset.n_images_dict[vid_nr] for vid_nr in vid_nr_array] + num_frames_array = [dataset.n_images_dict[vid_nr] for vid_nr in vid_nr_array] weighted_mean_scores = np.average(scores, 0, num_frames_array) # Print and save the scores print() print("Macro average (average of video averages) scores:") - print(', '.join(f'{metric:5s}' for metric in metrics)) - print(', '.join(f'{score:.3f}' for score in mean_scores)) + print(", ".join(f"{metric:5s}" for metric in metrics)) + print(", ".join(f"{score:.3f}" for score in mean_scores)) print() print("Weighted average (per-frame average) scores:") - print(', '.join(f'{metric:5s}' for metric in metrics)) - print(', '.join(f'{score:.3f}' for score in weighted_mean_scores)) + print(", ".join(f"{metric:5s}" for metric in metrics)) + print(", ".join(f"{score:.3f}" for score in weighted_mean_scores)) if subset == 1: - dest_dir = self.mit1003_dir if source == 'MIT1003' \ - else self.train_dir - if source in ('Hollywood', 'UCFSports'): + dest_dir = self.mit1003_dir if source == "MIT1003" else self.train_dir + if source in ("Hollywood", "UCFSports"): source += "_resized" - with open(dest_dir / f'{source}_eval_scores.json', 'w') as f: + with open(dest_dir / f"{source}_eval_scores.json", "w") as f: json.dump(scores, f, cls=utils.NumpyEncoder, indent=2) - with open(dest_dir / f'{source}_eval_mean_scores.json', 'w')\ - as f: + with open(dest_dir / f"{source}_eval_mean_scores.json", "w") as f: json.dump(mean_scores, f, cls=utils.NumpyEncoder, indent=2) - with open(dest_dir / - f'{source}_eval_weighted_mean_scores.json', 'w') as f: - json.dump(weighted_mean_scores, f, cls=utils.NumpyEncoder, - indent=2) + with open(dest_dir / f"{source}_eval_weighted_mean_scores.json", "w") as f: + json.dump(weighted_mean_scores, f, cls=utils.NumpyEncoder, indent=2) try: for score, metric in zip(mean_scores, metrics): - self.add_scalar(f'{source}_eval/{metric}', score, self.epoch) + self.add_scalar(f"{source}_eval/{metric}", score, self.epoch) self.export_scalars(suffix=f"_eval-{source}") except AttributeError: - print('add_scalar failed because tensorboard is closed.') + print("add_scalar failed because tensorboard is closed.") return metrics, mean_scores, scores - def generate_predictions(self, smooth_method=None, source='DHF1K', - phase='eval', load_weights=True, vid_nr_array=None, - **kwargs): + def generate_predictions( + self, + smooth_method=None, + source="DHF1K", + phase="eval", + load_weights=True, + vid_nr_array=None, + **kwargs, + ): """Generate predictions for submission and visualization""" if load_weights: @@ -817,49 +868,53 @@ def generate_predictions(self, smooth_method=None, source='DHF1K', # the last epoch try: self.model.load_best_weights(self.train_dir) - print('Best weights loaded') + print("Best weights loaded") except FileNotFoundError: - print('No best weights found') + print("No best weights found") self.model.load_last_chkpnt(self.train_dir) - print('Last checkpoint loaded') + print("Last checkpoint loaded") # Get the dataset dataset = self.get_dataset(phase, source) if vid_nr_array is None: # Get list of sample numbers - if source == 'MIT300': + if source == "MIT300": vid_nr_array = list(range(300)) else: vid_nr_array = list(dataset.n_images_dict.keys()) - tmr = utils.Timer(f'Predicting {len(vid_nr_array)} {source} videos') + tmr = utils.Timer(f"Predicting {len(vid_nr_array)} {source} videos") with torch.no_grad(): for vid_nr in vid_nr_array: self.run_inference( - source, vid_nr, dataset=dataset, - smooth_method=smooth_method, save_predictions=True, - **kwargs) + source, + vid_nr, + dataset=dataset, + smooth_method=smooth_method, + save_predictions=True, + **kwargs, + ) tmr.finish() def generate_predictions_from_path( - self, folder_path, is_video, source=None, load_weights=True, - **kwargs): + self, folder_path, is_video, source=None, load_weights=True, **kwargs + ): # Process inputs if source is None: - source = 'DHF1K' if is_video else 'SALICON' + source = "DHF1K" if is_video else "SALICON" - if source in ('MIT1003', 'MIT300'): + if source in ("MIT1003", "MIT300"): try: self.model.load_weights(self.train_dir, "ft_mit1003") print("Fine-tuned MIT1003 weights loaded") load_weights = False except: print("No MIT1003 fine-tuned weights found.") - source = 'SALICON' + source = "SALICON" - images_path = folder_path / 'images' + images_path = folder_path / "images" torch.cuda.empty_cache() if load_weights: @@ -867,23 +922,30 @@ def generate_predictions_from_path( # the last epoch try: self.model.load_best_weights(self.train_dir) - print('Best weights loaded') + print("Best weights loaded") except FileNotFoundError: - print('No best weights found') + print("No best weights found") self.model.load_last_chkpnt(self.train_dir) - print('Last checkpoint loaded') + print("Last checkpoint loaded") with torch.no_grad(): if is_video: - frame_modulo = 5 if source == 'DHF1K' else 4 + frame_modulo = 5 if source == "DHF1K" else 4 dataset = data.FolderVideoDataset( - images_path, source=source, frame_modulo=frame_modulo) - pred_dir = folder_path / 'saliency' + images_path, source=source, frame_modulo=frame_modulo + ) + pred_dir = folder_path / "saliency" pred_dir.mkdir(exist_ok=True) pred_seq = self.run_inference( - source, 0, dataset=dataset, phase=None, - return_predictions=True, folder_suffix=None, **kwargs) + source, + 0, + dataset=dataset, + phase=None, + return_predictions=True, + folder_suffix=None, + **kwargs, + ) # Iterate over the prediction frames for frame_idx, smap in enumerate(torch.unbind(pred_seq, dim=1)): @@ -897,18 +959,22 @@ def generate_predictions_from_path( filename = dataset.frame_files[frame_idx].name smap = (smap / np.amax(smap) * 255).astype(np.uint8) pred_file = pred_dir / filename - cv2.imwrite( - str(pred_file), smap, [cv2.IMWRITE_JPEG_QUALITY, 100]) + cv2.imwrite(str(pred_file), smap, [cv2.IMWRITE_JPEG_QUALITY, 100]) else: dataset = data.FolderImageDataset(images_path) - pred_dir = folder_path / 'saliency' + pred_dir = folder_path / "saliency" pred_dir.mkdir(exist_ok=True) for img_idx in range(len(dataset)): pred_seq = self.run_inference( - source, img_idx, dataset=dataset, phase=None, - return_predictions=True, **kwargs) + source, + img_idx, + dataset=dataset, + phase=None, + return_predictions=True, + **kwargs, + ) smap = pred_seq[:, 0, ...] @@ -921,33 +987,32 @@ def generate_predictions_from_path( filename = dataset.image_files[img_idx].name smap = (smap / np.amax(smap) * 255).astype(np.uint8) pred_file = pred_dir / filename - cv2.imwrite( - str(pred_file), smap, [cv2.IMWRITE_JPEG_QUALITY, 100]) + cv2.imwrite(str(pred_file), smap, [cv2.IMWRITE_JPEG_QUALITY, 100]) def fine_tune_mit( - self, lr=0.01, num_epochs=8, lr_gamma=0.8, x_val_step=0, - train_cnn_after=0): + self, lr=0.01, num_epochs=8, lr_gamma=0.8, x_val_step=0, train_cnn_after=0 + ): """Fine tune the model with the MIT1003 dataset for MIT300 submission""" # Set the fine tuning parameters self.num_epochs = num_epochs self.lr = lr self.lr_gamma = lr_gamma - self.lr_scheduler = 'ExponentialLR' - self.optim_algo = 'SGD' + self.lr_scheduler = "ExponentialLR" + self.optim_algo = "SGD" self.momentum = 0.9 self.weight_decay = 1e-4 self.grad_clip = None - self.loss_weights = (1.,) - self.loss_metrics = ('kld',) - self.salicon_weight = 1. + self.loss_weights = (1.0,) + self.loss_metrics = ("kld",) + self.salicon_weight = 1.0 self.salicon_batch_size = 32 - self.data_sources = ('MIT1003',) + self.data_sources = ("MIT1003",) self.shuffle_datasets = True self.cnn_lr_factor = 0.1 self.train_cnn_after = train_cnn_after self.cnn_eval = True - self.salicon_cfg.update({'x_val_step': x_val_step}) + self.salicon_cfg.update({"x_val_step": x_val_step}) self.mit1003_finetuned = True self.num_workers = 4 @@ -956,9 +1021,9 @@ def fine_tune_mit( # the last epoch try: self.model.load_best_weights(self.train_dir) - print('Best weights loaded') + print("Best weights loaded") except FileNotFoundError: - print('No best weights found') + print("No best weights found") self.model.load_last_chkpnt(self.train_dir) # Run the fine tuning @@ -968,18 +1033,18 @@ def fine_tune_mit( self._model.to(self.device) while self.epoch < self.num_epochs: self.scheduler.step(epoch=self.epoch) - lr = self.optimizer.param_groups[0]['lr'] + lr = self.optimizer.param_groups[0]["lr"] print(f"\nEpoch {self.epoch:3d}, lr {lr:.5f}") for self.phase in self.phases: self.fit_phase() - val_loss = self.all_scalars['mit1003']['loss']['valid'][self.epoch] + val_loss = self.all_scalars["mit1003"]["loss"]["valid"][self.epoch] if math.isnan(val_loss): best_epoch = 0 best_val = 1000 break - val_score = - val_loss + val_score = -val_loss if self.best_val_score is None: self.best_val_score = val_score elif val_score > self.best_val_score: @@ -1004,16 +1069,17 @@ def get_dataset(self, phase, source="DHF1K"): else: dataset_cls_name = f"{source}Dataset" dataset_cls = getattr(data, dataset_cls_name) - if source in ('MIT300',): + if source in ("MIT300",): config = {} - elif source in ('MIT1003',): + elif source in ("MIT1003",): config = getattr(self, f"salicon_cfg") else: config = getattr(self, f"{source.lower()}_cfg") - self._datasets[source][phase] = dataset_cls( - phase=phase, **config) - print(f'{source:10s} {phase} dataset loaded with' - f' {len(self._datasets[source][phase])} samples') + self._datasets[source][phase] = dataset_cls(phase=phase, **config) + print( + f"{source:10s} {phase} dataset loaded with" + f" {len(self._datasets[source][phase])} samples" + ) return self._datasets[source][phase] def get_dataloader(self, phase, source="DHF1K"): @@ -1026,22 +1092,24 @@ def get_dataloader(self, phase, source="DHF1K"): if source == "DHF1K": batch_size = self.batch_size elif source in ("Hollywood", "UCFSports"): - batch_size = self.__getattribute__( - f"{source.lower()}_batch_size") - elif phase == 'valid' and source == 'MIT1003': + batch_size = self.__getattribute__(f"{source.lower()}_batch_size") + elif phase == "valid" and source == "MIT1003": batch_size = 8 - elif source in ('SALICON', 'MIT1003'): - batch_size = self.salicon_batch_size or len(dataset) //\ - len(self.get_dataloader(phase)) + elif source in ("SALICON", "MIT1003"): + batch_size = self.salicon_batch_size or len(dataset) // len( + self.get_dataloader(phase) + ) if batch_size > 8: batch_size -= batch_size % 2 batch_size = min(32, batch_size) else: - raise ValueError(f'Unknown dataset source {source}') + raise ValueError(f"Unknown dataset source {source}") print(f"{source}, phase {phase} batch size: {batch_size}") self._dataloaders[source][phase] = data.get_dataloader(source)( - dataset, batch_size=batch_size, - shuffle=phase == 'train', num_workers=self.num_workers, + dataset, + batch_size=batch_size, + shuffle=phase == "train", + num_workers=self.num_workers, drop_last=True, ) return self._dataloaders[source][phase] @@ -1065,7 +1133,7 @@ def measure_runtime(self): # Prepare the experiment self.model.eval() self.num_workers = 0 - dl = self.get_dataloader('test', 'DHF1K') + dl = self.get_dataloader("test", "DHF1K") sample = next(iter(dl)) _, x, _ = sample x = x.float().cuda() @@ -1076,10 +1144,9 @@ def measure_runtime(self): # Measure the average time to process single frames on the GPU times = [] for t_idx in range(1, x.shape[1]): - x_t = x[:1, t_idx:t_idx+1, ...].clone().contiguous() + x_t = x[:1, t_idx : t_idx + 1, ...].clone().contiguous() t0 = time.time() - output, h0 = self.model( - x_t, h0=h0, return_hidden=True, static=False) + output, h0 = self.model(x_t, h0=h0, return_hidden=True, static=False) dt = time.time() - t0 times.append(dt) print() @@ -1097,16 +1164,14 @@ def measure_runtime(self): times = [] torch.cuda.empty_cache() for t_idx in range(1, min(x.shape[1], 16)): - x_t = x[:1, t_idx:t_idx+1, ...].clone().contiguous() + x_t = x[:1, t_idx : t_idx + 1, ...].clone().contiguous() t0 = time.time() - output, h0 = self.model( - x_t, h0=h0, return_hidden=True, static=False) + output, h0 = self.model(x_t, h0=h0, return_hidden=True, static=False) dt = time.time() - t0 times.append(dt) print() dt = sum(times) / len(times) - print("Avg single-frame CPU time: " - f"{dt:.4f} s ({1 / dt:.1f} fps)") + print("Avg single-frame CPU time: " f"{dt:.4f} s ({1 / dt:.1f} fps)") dt = min(times) print(f"Min single-frame CPU time: {dt:.4f} s ({1 / dt:.1f} fps)") dt = max(times) @@ -1120,12 +1185,11 @@ def measure_model_size(self): net = model_cls(verbose=0, **this_model_cfg) dest = self.train_dir - torch.save(net, dest / 'net_full.pth') - file_size = (dest / 'net_full.pth').stat().st_size / 1e6 - print("All data sources net size: " - f"{file_size:.2f} MB") - torch.save(net.state_dict(), dest / 'net_full_state_dict.pth') - file_size = (dest / 'net_full_state_dict.pth').stat().st_size / 1e6 + torch.save(net, dest / "net_full.pth") + file_size = (dest / "net_full.pth").stat().st_size / 1e6 + print("All data sources net size: " f"{file_size:.2f} MB") + torch.save(net.state_dict(), dest / "net_full_state_dict.pth") + file_size = (dest / "net_full_state_dict.pth").stat().st_size / 1e6 print(f"All data sources net state dict size: {file_size:.2f} MB") def get_model_parameter_groups(self): @@ -1133,13 +1197,14 @@ def get_model_parameter_groups(self): Get parameter groups. Output CNN parameters separately with reduced LR and weight decay. """ + def parameters_except_cnn(): parameters = [] adaptation = [] for name, module in self.model.named_children(): - if name == 'cnn': + if name == "cnn": continue - elif 'adaptation' in name: + elif "adaptation" in name: adaptation += list(module.parameters()) else: parameters += list(module.parameters()) @@ -1148,33 +1213,42 @@ def parameters_except_cnn(): parameters, adaptation = parameters_except_cnn() for name, this_parameter in self.model.named_parameters(): - if 'gaussian' in name: + if "gaussian" in name: parameters.append(this_parameter) return [ - {'params': parameters + adaptation}, - {'params': self.model.cnn.parameters(), - 'lr': self.lr * self.cnn_lr_factor, - 'weight_decay': self.cnn_weight_decay, - }, + {"params": parameters + adaptation}, + { + "params": self.model.cnn.parameters(), + "lr": self.lr * self.cnn_lr_factor, + "weight_decay": self.cnn_weight_decay, + }, ] @property def optimizer(self): """Return the optimizer""" if self._optimizer is None: - if self.optim_algo == 'SGD': + if self.optim_algo == "SGD": self._optimizer = torch.optim.SGD( - self.get_model_parameter_groups(), lr=self.lr, - momentum=self.momentum, weight_decay=self.weight_decay) - elif self.optim_algo == 'Adam': + self.get_model_parameter_groups(), + lr=self.lr, + momentum=self.momentum, + weight_decay=self.weight_decay, + ) + elif self.optim_algo == "Adam": self._optimizer = torch.optim.Adam( - self.get_model_parameter_groups(), lr=self.lr, - weight_decay=self.weight_decay) - elif self.optim_algo == 'RMSprop': + self.get_model_parameter_groups(), + lr=self.lr, + weight_decay=self.weight_decay, + ) + elif self.optim_algo == "RMSprop": self._optimizer = torch.optim.RMSprop( - self.get_model_parameter_groups(), lr=self.lr, - weight_decay=self.weight_decay, momentum=self.momentum) + self.get_model_parameter_groups(), + lr=self.lr, + weight_decay=self.weight_decay, + momentum=self.momentum, + ) return self._optimizer @@ -1182,10 +1256,10 @@ def optimizer(self): def scheduler(self): """Return the learning rate scheduler""" if self._scheduler is None: - if self.lr_scheduler == 'ExponentialLR': + if self.lr_scheduler == "ExponentialLR": self._scheduler = torch.optim.lr_scheduler.ExponentialLR( - self.optimizer, gamma=self.lr_gamma, - last_epoch=self.epoch - 1) + self.optimizer, gamma=self.lr_gamma, last_epoch=self.epoch - 1 + ) else: raise ValueError(f"Unknown scheduler {self.lr_scheduler}") return self._scheduler @@ -1199,26 +1273,26 @@ def copy_code(self): """Make a copy of the code to facilitate leading of older models""" source = Path(inspect.getfile(Trainer)).parent.parent - destination = self.train_dir / 'code_copy' + destination = self.train_dir / "code_copy" tracked_files = [ - '.gitignore', - 'unisal/__init__.py', - 'unisal/data.py', - 'unisal/model.py', - 'unisal/models/MobileNetV2.py', - 'unisal/models/__init__.py', - 'unisal/models/cgru.py', - 'unisal/models/weights/mobilenet_v2.pth.tar', - 'unisal/salience_metrics.py', - 'unisal/train.py', - 'unisal/utils.py', - 'run.py', - 'unisal/dhf1k_n_images.dat', - 'unisal/cache/img_size_dict.json', - 'unisal/cache/train_hollywood_register.json', - 'unisal/cache/train_ucfsports_register.json', - 'unisal/cache/test_hollywood_register.json', - 'unisal/cache/test_ucfsports_register.json', + ".gitignore", + "unisal/__init__.py", + "unisal/data.py", + "unisal/model.py", + "unisal/models/MobileNetV2.py", + "unisal/models/__init__.py", + "unisal/models/cgru.py", + "unisal/models/weights/mobilenet_v2.pth.tar", + "unisal/salience_metrics.py", + "unisal/train.py", + "unisal/utils.py", + "run.py", + "unisal/dhf1k_n_images.dat", + "unisal/cache/img_size_dict.json", + "unisal/cache/train_hollywood_register.json", + "unisal/cache/train_ucfsports_register.json", + "unisal/cache/test_hollywood_register.json", + "unisal/cache/test_ucfsports_register.json", ] for file in tracked_files: subdir = Path(file).parent @@ -1229,30 +1303,43 @@ def save_chkpnt(self): """Save model and trainer checkpoint""" print(f"Saving checkpoint at epoch {self.epoch}") chkpnt = { - 'model_state_dict': self.model.state_dict(), - 'optimizer_state_dict': self.optimizer.state_dict(), + "model_state_dict": self.model.state_dict(), + "optimizer_state_dict": self.optimizer.state_dict(), } - chkpnt.update({key: self.__dict__[key] for key in ( - 'epoch', 'best_epoch', 'best_val_score', 'all_scalars',)}) - torch.save(chkpnt, self.train_dir / f'chkpnt_epoch{self.epoch:04d}.pth') + chkpnt.update( + { + key: self.__dict__[key] + for key in ( + "epoch", + "best_epoch", + "best_val_score", + "all_scalars", + ) + } + ) + torch.save(chkpnt, self.train_dir / f"chkpnt_epoch{self.epoch:04d}.pth") def load_checkpoint(self, file): """Load model and trainer checkpoint""" chkpnt = torch.load(file) - self.model.load_state_dict(chkpnt['model_state_dict']) - self.optimizer.load_state_dict(chkpnt['optimizer_state_dict']) - self.__dict__.update({key: chkpnt[key] for key in ( - 'epoch', 'best_epoch', 'best_val_score', 'all_scalars')}) + self.model.load_state_dict(chkpnt["model_state_dict"]) + self.optimizer.load_state_dict(chkpnt["optimizer_state_dict"]) + self.__dict__.update( + { + key: chkpnt[key] + for key in ("epoch", "best_epoch", "best_val_score", "all_scalars") + } + ) self.epoch += 1 def load_last_chkpnt(self): """Load latest model checkpoint""" - last_chkpnt = sorted(list(self.train_dir.glob('chkpnt_epoch*.pth')))[-1] + last_chkpnt = sorted(list(self.train_dir.glob("chkpnt_epoch*.pth")))[-1] self.load_checkpoint(last_chkpnt) def add_scalar(self, key, value, epoch, this_tboard=True): """Add a scalar to self.all_scalars and TensorboardX""" - keys = key.split('/') + keys = key.split("/") this_dict = self.all_scalars for key_ in keys: if key_ not in this_dict: @@ -1267,7 +1354,7 @@ def add_scalar(self, key, value, epoch, this_tboard=True): def writer(self): """Return TensorboardX writer""" if self.tboard and self._writer is None: - if self.data_sources == ('MIT1003',): + if self.data_sources == ("MIT1003",): log_dir = self.mit1003_dir log_dir.mkdir(exist_ok=True) else: @@ -1279,34 +1366,36 @@ def writer(self): def mit1003_dir(self): """Return directory to fine tune on MIT1003""" if self.mit1003_finetuned: - mit1003_dir = self.train_dir / f"MIT1003_lr{self.lr:.4f}_" \ - f"lrGamma{self.lr_gamma:.2f}_nEpochs{self.num_epochs}_" \ - f"TrainCNNAfter{self.train_cnn_after}_" \ + mit1003_dir = ( + self.train_dir / f"MIT1003_lr{self.lr:.4f}_" + f"lrGamma{self.lr_gamma:.2f}_nEpochs{self.num_epochs}_" + f"TrainCNNAfter{self.train_cnn_after}_" f"xVal{self.salicon_cfg['x_val_step']}" + ) else: - mit1003_dir = self.train_dir / f"MIT1003_" \ - f"xVal{self.salicon_cfg['x_val_step']}" + mit1003_dir = ( + self.train_dir / f"MIT1003_" f"xVal{self.salicon_cfg['x_val_step']}" + ) mit1003_dir.mkdir(exist_ok=True) return mit1003_dir - def export_scalars(self, suffix=''): + def export_scalars(self, suffix=""): """Save self.all_scalars""" - if self.data_sources == ('MIT1003',): + if self.data_sources == ("MIT1003",): export_dir = self.mit1003_dir else: export_dir = self.train_dir - with open(export_dir / f'all_scalars{suffix}.json', 'w') as f: - json.dump(self.all_scalars, f, cls=utils.NumpyEncoder, - indent=2) + with open(export_dir / f"all_scalars{suffix}.json", "w") as f: + json.dump(self.all_scalars, f, cls=utils.NumpyEncoder, indent=2) def get_configs(self): """Get configurations of trainer, dataset and model instances""" return { - 'Trainer': self.asdict(), - 'Dataset': self.get_dataset('train', 'DHF1K').asdict(), - 'Model': self.model.asdict() + "Trainer": self.asdict(), + "Dataset": self.get_dataset("train", "DHF1K").asdict(), + "Model": self.model.asdict(), } @property def train_id(self): - return '/'.join(self.train_dir.parts[-2:]) + return "/".join(self.train_dir.parts[-2:])