diff --git a/glow/information_bottleneck/estimator.py b/glow/information_bottleneck/estimator.py index eef4228..3e97630 100644 --- a/glow/information_bottleneck/estimator.py +++ b/glow/information_bottleneck/estimator.py @@ -3,6 +3,7 @@ import math import torch import glow.utils.hsic_utils as kernel_module +import numpy as np class Estimator: @@ -112,12 +113,13 @@ class EDGE(Estimator): """ - def __init__(self, hash_function, gpu=True, **kwargs): + def __init__(self, hash_function, U=10, gpu=True, **kwargs): super().__init__(gpu, **kwargs) self.hash_function = hash_function + self.U = U def g(self, x): - return x * torch.log(x) * (1 / math.log(10)) + return x * math.log(x) * (1 / math.log(2)) def criterion(self, x, y): """ @@ -127,38 +129,36 @@ def criterion(self, x, y): """ h = hash_module.get(self.hash_function, self.params_dict) num_samples = x.shape[0] - if "F" in self.params_dict.keys(): - F = self.params_dict["F"] * num_samples - else: - raise Exception( - "Cannot find argument for number of nodes of dependency graph in EDGE estimator" - ) - N = torch.zeros(F, 1) - M = torch.zeros(F, 1) - L = torch.zeros(F, F) - """ + edge_list = [] + N = {} + M = {} + L = {} + for k, x_k in enumerate(x): y_k = y[k] i = h(x_k) j = h(y_k) - N[i] = N[i] + 1 - M[j] = M[j] + 1 - L[i][j] = L[i][j] + 1 - """ - N = torch.nn.functional.one_hot(N.long().view(-1, 1), F) - N = torch.sum(N, dim=0) - - M = torch.nn.functional.one_hot(M.long().view(-1, 1), F) - M = torch.sum(M, dim=0) - - n = (1 / num_samples) * N - m = (1 / num_samples) * M - temp_matrix = torch.mm(N, torch.transpose(M, 0, 1)) - zero_matrix = torch.zeros(F, F) - w = torch.addcdiv(zero_matrix, num_samples, L, temp_matrix) - temp_matrix = torch.mm(n, torch.transpose(m, 0, 1)) - mut_info = torch.sum(temp_matrix * g_hat(w)) + if list(x_k.size()) == []: + i = i.item() + j = j.item() + else: + i = tuple(i.tolist()) + j = tuple(j.tolist()) + + N[i] = (N[i] + 1.0) if i in N else 1.0 + M[j] = (M[j] + 1.0) if j in M else 1.0 + L[i,j] = (L[i,j] + 1.0) if (i,j) in L else 1.0 + edge_list.append((i, j)) + + mut_info = 0.0 + + for i, j in edge_list: + wi = 1.0 * N[i] / num_samples + wj = 1.0 * M[j] / num_samples + wij = min(self.U, 1.0 * L[i,j] * num_samples / (N[i]*M[j])) + mut_info += wi * wj * self.g(wij) + return mut_info