diff --git a/funque_plus/feature_extractors/funque_feature_extractors.py b/funque_plus/feature_extractors/funque_feature_extractors.py
index 78c6e0f..10f6ef6 100755
--- a/funque_plus/feature_extractors/funque_feature_extractors.py
+++ b/funque_plus/feature_extractors/funque_feature_extractors.py
@@ -205,7 +205,7 @@ def __init__(self, use_cache: bool = True, sample_rate: Optional[int] = None) ->
         self.wavelet_levels = 2
         self.csf = 'nadenau_spat'
         self.wavelet = 'haar'
-        self.feat_names = [f'ms_ssim_cov_channel_y_levels_{self.wavelet_levels}', f'dlm_channel_y_scale_{self.wavelet_levels}', f'strred_scalar_channel_y_levels_{self.wavelet_levels}', f'mad_dis_channel_y_scale_{self.wavelet_levels}', f'sai_diff_channel_y_scale_{self.wavelet_levels}']
+        self.feat_names = [f'ms_ssim_cov_channel_y_levels_{self.wavelet_levels}', f'dlm_channel_y_scale_{self.wavelet_levels}', f'strred_scalar_channel_y_levels_{self.wavelet_levels}', f'strred_ldr_scalar_channel_y_levels_{self.wavelet_levels}', f'mad_dis_channel_y_scale_{self.wavelet_levels}', f'sai_diff_channel_y_scale_{self.wavelet_levels}']
 
     def _run_on_asset(self, asset_dict: Dict[str, Any]) -> Result:
         sample_interval = self._get_sample_interval(asset_dict)
@@ -271,12 +271,15 @@ def _run_on_asset(self, asset_dict: Dict[str, Any]) -> Result:
 
                         # STRRED features
                         (_, _, strred_scales) = pyr_features.strred_hv_pyr(pyr_ref, pyr_dis, prev_pyr_ref, prev_pyr_dis, block_size=1)
+                        (_, _, strred_ldr_scales) = pyr_features.strred_hv_pyr_low_dynamic_range(pyr_ref, pyr_dis, prev_pyr_ref, prev_pyr_dis, block_size=1)
                     else:
                         motion_val = 0
                         strred_scales = [0]*self.wavelet_levels
+                        strred_ldr_scales = [0]*self.wavelet_levels
 
                     feats_dict[f'mad_dis_channel_{channel_name}_scale_{self.wavelet_levels}'].append(motion_val)
                     feats_dict[f'strred_scalar_channel_{channel_name}_levels_{self.wavelet_levels}'].append(strred_scales[-1])
+                    feats_dict[f'strred_ldr_scalar_channel_{channel_name}_levels_{self.wavelet_levels}'].append(strred_ldr_scales[-1])
 
                     # TLVQM-like features
                     # Spatial activity - swap Haar H, V for Sobel H, V
diff --git a/funque_plus/features/funque_atoms/pyr_features.py b/funque_plus/features/funque_atoms/pyr_features.py
index 9137f12..2262a36 100644
--- a/funque_plus/features/funque_atoms/pyr_features.py
+++ b/funque_plus/features/funque_atoms/pyr_features.py
@@ -5,7 +5,7 @@
 from .gsm_utils import gsm_model, im2col
 from .filter_utils import filter_pyr
 from .rred_utils import rred_entropies_and_scales
-
+from .rred_utils import rred_entropies_and_scales_low_dynamic_range
 from pywt import dwt2
 
 
@@ -434,6 +434,66 @@ def strred_hv_pyr(pyr_ref, pyr_dist, prev_pyr_ref, prev_pyr_dist, block_size=3,
         return (srred_vals, trred_vals, strred_vals)
     else:
         return (srred_vals, trred_vals, strred_vals), (spat_vals, temp_vals, spat_temp_vals)
+
+def strred_hv_pyr_low_dynamic_range(pyr_ref, pyr_dist, prev_pyr_ref, prev_pyr_dist, block_size=3, single=False, full=False):
+    """
+    Compute SRRED, TRRED and STRRED values per pyramid level using
+    rred_entropies_and_scales_low_dynamic_range on the detail subbands.
+
+    Temporal terms use the difference between current and previous detail
+    subbands; they are all zero if either previous pyramid is None.
+
+    Returns (srred_vals, trred_vals, strred_vals), each the cumulative mean
+    over levels of the per-level values. With full=True, additionally returns
+    the raw per-level tuple (spat_vals, temp_vals, spat_temp_vals).
+    """
+    # Pyramids are assumed to have the structure
+    # ([A1, ..., An], [(H1, V1, D1), ..., (Hn, Vn, Dn)])
+    _, details_ref = pyr_ref
+    _, details_dist = pyr_dist
+    assert len(details_ref) == len(details_dist), 'Both wavelet pyramids must be of the same height'
+    n_levels = len(details_ref)
+    spat_gsm_ref_details = [tuple([rred_entropies_and_scales_low_dynamic_range(subband, block_size) for subband in level]) for level in details_ref]
+    spat_gsm_dist_details = [tuple([rred_entropies_and_scales_low_dynamic_range(subband, block_size) for subband in level]) for level in details_dist]
+    compute_temporal = (prev_pyr_ref is not None and prev_pyr_dist is not None)
+    if compute_temporal:
+        _, prev_details_ref = prev_pyr_ref
+        _, prev_details_dist = prev_pyr_dist
+        temp_gsm_ref_details = [tuple([rred_entropies_and_scales_low_dynamic_range(subband - prev_subband, block_size) for subband, prev_subband in zip(level, prev_level)])
+                                for level, prev_level in zip(details_ref, prev_details_ref)]
+        temp_gsm_dist_details = [tuple([rred_entropies_and_scales_low_dynamic_range(subband - prev_subband, block_size) for subband, prev_subband in zip(level, prev_level)])
+                                 for level, prev_level in zip(details_dist, prev_details_dist)]
+
+    agg = lambda x: np.abs(np.mean(x)) if single else np.mean(np.abs(x))
+
+    spat_vals = np.array([
+        np.mean([agg(scale_ref * entropy_ref - scale_dist * entropy_dist) for (entropy_ref, scale_ref), (entropy_dist, scale_dist) in zip(level_ref, level_dist)])
+        for level_ref, level_dist in zip(spat_gsm_ref_details, spat_gsm_dist_details)
+    ])
+
+    if compute_temporal:
+        temp_vals = np.array([
+            np.mean([
+                agg(spat_scale_ref * temp_scale_ref * entropy_ref - spat_scale_dist * temp_scale_dist * entropy_dist)
+                for (_, spat_scale_ref), (_, spat_scale_dist), (entropy_ref, temp_scale_ref), (entropy_dist, temp_scale_dist)
+                in zip(spat_level_ref, spat_level_dist, temp_level_ref, temp_level_dist)
+            ])
+            for spat_level_ref, spat_level_dist, temp_level_ref, temp_level_dist in zip(spat_gsm_ref_details, spat_gsm_dist_details, temp_gsm_ref_details, temp_gsm_dist_details)
+        ])
+        spat_temp_vals = spat_vals * temp_vals
+    else:
+        temp_vals = np.zeros_like(spat_vals)
+        spat_temp_vals = np.zeros_like(spat_vals)
+
+    norm_factors = np.arange(1, n_levels+1)
+    srred_vals = np.cumsum(spat_vals) / norm_factors
+    trred_vals = np.cumsum(temp_vals) / norm_factors
+    strred_vals = np.cumsum(spat_temp_vals) / norm_factors
+
+    if not full:
+        return (srred_vals, trred_vals, strred_vals)
+    else:
+        return (srred_vals, trred_vals, strred_vals), (spat_vals, temp_vals, spat_temp_vals)
 
 def blur_edge_pyr(pyr_ref, pyr_dis, mode='both'):
     if mode not in ['blur', 'edge', 'both']:
diff --git a/funque_plus/features/funque_atoms/rred_utils.py b/funque_plus/features/funque_atoms/rred_utils.py
index f906d27..a533a3b 100644
--- a/funque_plus/features/funque_atoms/rred_utils.py
+++ b/funque_plus/features/funque_atoms/rred_utils.py
@@ -39,4 +39,52 @@ def rred_entropies_and_scales(subband, block_size=3):
             entropies = entropies + np.log(s*lamda[j]+sigma_nsq) + np.log(2*np.pi*np.exp(1))
 
         scales = np.log(1 + s)
 
+    return entropies, scales
+
+def rred_entropies_and_scales_low_dynamic_range(subband, block_size=3):
+    """
+    Compute entropies and scales of a subband, using sigma_nsq = 0.01 as
+    the additive variance term.
+
+    For block_size == 1, local variances are computed over k x k (k = 9)
+    windows of the reflect-padded subband via integral_image. Otherwise, a
+    GSM model is fit (complex_gsm_model for complex subbands, gsm_model for
+    real ones) and entropies are accumulated over its eigenvalues.
+
+    Returns the tuple (entropies, scales).
+    """
+    sigma_nsq = 0.01
+    tol = 1e-10
+
+    if block_size == 1:
+        entr_const = np.log(2*np.pi*np.exp(1))
+        k = 9
+        k_norm = k**2
+        x_pad = np.pad(subband, int((k - 1)/2), mode='reflect')
+        int_1_x = integral_image(x_pad)
+        int_2_x = integral_image(x_pad*x_pad)
+        mu_x = (int_1_x[:-k, :-k] - int_1_x[:-k, k:] - int_1_x[k:, :-k] + int_1_x[k:, k:])/k_norm
+        var_x = (int_2_x[:-k, :-k] - int_2_x[:-k, k:] - int_2_x[k:, :-k] + int_2_x[k:, k:])/k_norm - mu_x**2
+        var_x = np.clip(var_x, 0, None)
+        entropies = np.log(var_x + sigma_nsq) + entr_const
+        scales = np.log(0.1 + var_x)
+    else:
+        if np.iscomplexobj(subband):
+            s, cov, rel = complex_gsm_model(subband, block_size)
+            cov_x = 0.5*(np.real(cov) + np.real(rel))
+            cov_y = 0.5*(np.real(cov) - np.real(rel))
+            cov_xy = 0.5*(np.imag(rel) - np.imag(cov))
+            cov_real = np.block([[cov_x, cov_xy], [cov_xy.T, cov_y]])
+            lamda, _ = np.linalg.eigh(cov_real)
+            lamda[lamda < tol] = tol
+        else:
+            s, lamda, cov = gsm_model(subband, block_size)
+
+        n_eigs = (2 if np.iscomplexobj(subband) else 1)*block_size*block_size
+
+        entropies = np.zeros_like(s)
+        for j in range(n_eigs):
+            entropies = entropies + np.log(s*lamda[j]+sigma_nsq) + np.log(2*np.pi*np.exp(1))
+        scales = np.log(1 + s)
+    return entropies, scales