Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion funque_plus/feature_extractors/funque_feature_extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def __init__(self, use_cache: bool = True, sample_rate: Optional[int] = None) ->
self.wavelet_levels = 2
self.csf = 'nadenau_spat'
self.wavelet = 'haar'
self.feat_names = [f'ms_ssim_cov_channel_y_levels_{self.wavelet_levels}', f'dlm_channel_y_scale_{self.wavelet_levels}', f'strred_scalar_channel_y_levels_{self.wavelet_levels}', f'mad_dis_channel_y_scale_{self.wavelet_levels}', f'sai_diff_channel_y_scale_{self.wavelet_levels}']
self.feat_names = [f'ms_ssim_cov_channel_y_levels_{self.wavelet_levels}', f'dlm_channel_y_scale_{self.wavelet_levels}', f'strred_scalar_channel_y_levels_{self.wavelet_levels}', f'strred_ldr_scalar_channel_y_levels_{self.wavelet_levels}', f'mad_dis_channel_y_scale_{self.wavelet_levels}', f'sai_diff_channel_y_scale_{self.wavelet_levels}']

def _run_on_asset(self, asset_dict: Dict[str, Any]) -> Result:
sample_interval = self._get_sample_interval(asset_dict)
Expand Down Expand Up @@ -271,12 +271,14 @@ def _run_on_asset(self, asset_dict: Dict[str, Any]) -> Result:

# STRRED features
(_, _, strred_scales) = pyr_features.strred_hv_pyr(pyr_ref, pyr_dis, prev_pyr_ref, prev_pyr_dis, block_size=1)
(_, _, strred_ldr_scales) = pyr_features.strred_hv_pyr_low_dynamic_range(pyr_ref, pyr_dis, prev_pyr_ref, prev_pyr_dis, block_size=1)
else:
motion_val = 0
strred_scales = [0]*self.wavelet_levels

feats_dict[f'mad_dis_channel_{channel_name}_scale_{self.wavelet_levels}'].append(motion_val)
feats_dict[f'strred_scalar_channel_{channel_name}_levels_{self.wavelet_levels}'].append(strred_scales[-1])
feats_dict[f'strred_ldr_scalar_channel_{channel_name}_levels_{self.wavelet_levels}'].append(strred_ldr_scales[-1])

# TLVQM-like features
# Spatial activity - swap Haar H, V for Sobel H, V
Expand Down
51 changes: 50 additions & 1 deletion funque_plus/features/funque_atoms/pyr_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .gsm_utils import gsm_model, im2col
from .filter_utils import filter_pyr
from .rred_utils import rred_entropies_and_scales

from .rred_utils import rred_entropies_and_scales_low_dynamic_range

from pywt import dwt2

Expand Down Expand Up @@ -434,6 +434,55 @@ def strred_hv_pyr(pyr_ref, pyr_dist, prev_pyr_ref, prev_pyr_dist, block_size=3,
return (srred_vals, trred_vals, strred_vals)
else:
return (srred_vals, trred_vals, strred_vals), (spat_vals, temp_vals, spat_temp_vals)

def strred_hv_pyr_low_dynamic_range(pyr_ref, pyr_dist, prev_pyr_ref, prev_pyr_dist, block_size=3, single=False, full=False):
    """Compute spatial, temporal and spatio-temporal RRED values per pyramid level.

    Same pipeline shape as a regular STRRED computation, but every subband is
    summarised with ``rred_entropies_and_scales_low_dynamic_range``.

    Pyramids are assumed to have the structure
    ``([A1, ..., An], [(H1, V1, D1), ..., (Hn, Vn, Dn)])``; only the detail
    subbands (second element) are used here.

    Args:
        pyr_ref: Wavelet pyramid of the reference frame.
        pyr_dist: Wavelet pyramid of the distorted frame.
        prev_pyr_ref: Pyramid of the previous reference frame, or ``None``.
        prev_pyr_dist: Pyramid of the previous distorted frame, or ``None``.
        block_size: Forwarded unchanged to the entropy/scale estimator.
        single: If true, aggregate a difference map as ``|mean(x)|``;
            otherwise as ``mean(|x|)``.
        full: If true, additionally return the un-accumulated per-level values.

    Returns:
        ``(srred_vals, trred_vals, strred_vals)`` where entry ``i`` is the
        running mean of the per-level values over levels ``0..i``. When
        ``full`` is true, returns that tuple followed by
        ``(spat_vals, temp_vals, spat_temp_vals)``, the raw per-level values.
        Temporal and spatio-temporal values are all zeros when either previous
        pyramid is ``None``.

    Raises:
        AssertionError: If the two pyramids have a different number of levels.
    """
    # The approximation subbands are part of the pyramid layout but unused here.
    _, details_ref = pyr_ref
    _, details_dist = pyr_dist
    assert len(details_ref) == len(details_dist), 'Both wavelet pyramids must be of the same height'
    n_levels = len(details_ref)

    def _level_stats(levels):
        # One (entropies, scales) pair per subband, grouped per level.
        return [
            tuple(rred_entropies_and_scales_low_dynamic_range(subband, block_size) for subband in level)
            for level in levels
        ]

    def _diff_level_stats(levels, prev_levels):
        # Same as _level_stats, but on the frame difference of each subband.
        return [
            tuple(
                rred_entropies_and_scales_low_dynamic_range(subband - prev_subband, block_size)
                for subband, prev_subband in zip(level, prev_level)
            )
            for level, prev_level in zip(levels, prev_levels)
        ]

    def _agg(x):
        # Collapse a difference map to a scalar.
        if single:
            return np.abs(np.mean(x))
        return np.mean(np.abs(x))

    spat_gsm_ref_details = _level_stats(details_ref)
    spat_gsm_dist_details = _level_stats(details_dist)

    compute_temporal = (prev_pyr_ref is not None and prev_pyr_dist is not None)
    if compute_temporal:
        _, prev_details_ref = prev_pyr_ref
        _, prev_details_dist = prev_pyr_dist
        temp_gsm_ref_details = _diff_level_stats(details_ref, prev_details_ref)
        temp_gsm_dist_details = _diff_level_stats(details_dist, prev_details_dist)

    # Spatial term: scale-weighted entropy difference, averaged over the
    # subbands of each level.
    spat_vals = np.array([
        np.mean([
            _agg(scale_ref * entropy_ref - scale_dist * entropy_dist)
            for (entropy_ref, scale_ref), (entropy_dist, scale_dist) in zip(level_ref, level_dist)
        ])
        for level_ref, level_dist in zip(spat_gsm_ref_details, spat_gsm_dist_details)
    ])

    if compute_temporal:
        # Temporal term: entropies of the frame difference, weighted by both
        # the spatial scale and the temporal scale.
        temp_vals = np.array([
            np.mean([
                _agg(spat_scale_ref * temp_scale_ref * entropy_ref - spat_scale_dist * temp_scale_dist * entropy_dist)
                for (_, spat_scale_ref), (_, spat_scale_dist), (entropy_ref, temp_scale_ref), (entropy_dist, temp_scale_dist)
                in zip(spat_level_ref, spat_level_dist, temp_level_ref, temp_level_dist)
            ])
            for spat_level_ref, spat_level_dist, temp_level_ref, temp_level_dist
            in zip(spat_gsm_ref_details, spat_gsm_dist_details, temp_gsm_ref_details, temp_gsm_dist_details)
        ])
        spat_temp_vals = spat_vals * temp_vals
    else:
        temp_vals = np.zeros_like(spat_vals)
        spat_temp_vals = np.zeros_like(spat_vals)

    # Running mean across levels: cumulative sum divided by the level count.
    norm_factors = np.arange(1, n_levels+1)
    srred_vals = np.cumsum(spat_vals) / norm_factors
    trred_vals = np.cumsum(temp_vals) / norm_factors
    strred_vals = np.cumsum(spat_temp_vals) / norm_factors

    if not full:
        return (srred_vals, trred_vals, strred_vals)
    return (srred_vals, trred_vals, strred_vals), (spat_vals, temp_vals, spat_temp_vals)

def blur_edge_pyr(pyr_ref, pyr_dis, mode='both'):
if mode not in ['blur', 'edge', 'both']:
Expand Down
38 changes: 38 additions & 0 deletions funque_plus/features/funque_atoms/rred_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,42 @@ def rred_entropies_and_scales(subband, block_size=3):
entropies = entropies + np.log(s*lamda[j]+sigma_nsq) + np.log(2*np.pi*np.exp(1))
scales = np.log(1 + s)

return entropies, scales

def rred_entropies_and_scales_low_dynamic_range(subband, block_size=3):
    """Return per-location entropies and log-scales of a subband.

    Args:
        subband: 2-D subband array. Complex input is supported only on the
            ``block_size != 1`` path.
        block_size: ``1`` selects the local-variance path (9x9 window, computed
            with ``integral_image``); any other value selects the GSM-model
            path (``gsm_model`` / ``complex_gsm_model``).

    Returns:
        Tuple ``(entropies, scales)``.
    """
    sigma_nsq = 0.01  # noise variance added inside every log
    tol = 1e-10  # floor for eigenvalues on the complex path
    # Loop-invariant additive constant log(2*pi*e), shared by both paths.
    entr_const = np.log(2*np.pi*np.exp(1))

    if block_size == 1:
        k = 9
        k_norm = k**2
        x_pad = np.pad(subband, (k - 1)//2, mode='reflect')
        # NOTE(review): the corner arithmetic below presumes integral_image
        # returns a summed-area table one row and one column larger than its
        # input; confirm against its definition.
        int_1_x = integral_image(x_pad)
        int_2_x = integral_image(x_pad*x_pad)
        mu_x = (int_1_x[:-k, :-k] - int_1_x[:-k, k:] - int_1_x[k:, :-k] + int_1_x[k:, k:])/k_norm
        var_x = (int_2_x[:-k, :-k] - int_2_x[:-k, k:] - int_2_x[k:, :-k] + int_2_x[k:, k:])/k_norm - mu_x**2
        # E[x^2] - E[x]^2 can dip slightly below zero numerically.
        var_x = np.clip(var_x, 0, None)
        entropies = np.log(var_x + sigma_nsq) + entr_const
        # Only this path uses the 0.1 offset; the GSM path below uses 1.
        scales = np.log(0.1 + var_x)
    else:
        is_complex = np.iscomplexobj(subband)
        if is_complex:
            s, cov, rel = complex_gsm_model(subband, block_size)
            # Build a real block covariance from the two complex moment
            # matrices returned by complex_gsm_model.
            cov_x = 0.5*(np.real(cov) + np.real(rel))
            cov_y = 0.5*(np.real(cov) - np.real(rel))
            cov_xy = 0.5*(np.imag(rel) - np.imag(cov))
            cov_real = np.block([[cov_x, cov_xy], [cov_xy.T, cov_y]])
            lamda, _ = np.linalg.eigh(cov_real)
            lamda[lamda < tol] = tol
        else:
            s, lamda, _ = gsm_model(subband, block_size)

        # The stacked real/imaginary covariance has twice as many eigenvalues.
        n_eigs = (2 if is_complex else 1)*block_size*block_size

        entropies = np.zeros_like(s)
        for j in range(n_eigs):
            entropies = entropies + np.log(s*lamda[j]+sigma_nsq) + entr_const
        scales = np.log(1 + s)

    return entropies, scales