From 9caba91136e0445c08e8f86328c3df18968768df Mon Sep 17 00:00:00 2001 From: Manvi07 Date: Fri, 23 Oct 2020 12:58:06 +0530 Subject: [PATCH 1/2] Improved EDGE Estimator criterion --- glow/information_bottleneck/estimator.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/glow/information_bottleneck/estimator.py b/glow/information_bottleneck/estimator.py index eef4228..042cfea 100644 --- a/glow/information_bottleneck/estimator.py +++ b/glow/information_bottleneck/estimator.py @@ -137,15 +137,15 @@ def criterion(self, x, y): N = torch.zeros(F, 1) M = torch.zeros(F, 1) L = torch.zeros(F, F) - """ + for k, x_k in enumerate(x): y_k = y[k] - i = h(x_k) - j = h(y_k) - N[i] = N[i] + 1 - M[j] = M[j] + 1 - L[i][j] = L[i][j] + 1 - """ + i = int(h(x_k)) + j = int(h(y_k)) + N[i] = N[i] + 1.0 + M[j] = M[j] + 1.0 + L[i][j] = L[i][j] + 1.0 + N = torch.nn.functional.one_hot(N.long().view(-1, 1), F) N = torch.sum(N, dim=0) @@ -154,11 +154,12 @@ def criterion(self, x, y): n = (1 / num_samples) * N m = (1 / num_samples) * M - temp_matrix = torch.mm(N, torch.transpose(M, 0, 1)) + + temp_matrix = torch.mm(N, torch.transpose(M, 0, 1)).type(torch.FloatTensor) zero_matrix = torch.zeros(F, F) - w = torch.addcdiv(zero_matrix, num_samples, L, temp_matrix) + w = torch.addcdiv(zero_matrix, L, temp_matrix, value=num_samples) temp_matrix = torch.mm(n, torch.transpose(m, 0, 1)) - mut_info = torch.sum(temp_matrix * g_hat(w)) + mut_info = torch.sum(temp_matrix * self.g(w)) return mut_info From cbefebdbb7141607c7aa9918db7a4afbd0534db1 Mon Sep 17 00:00:00 2001 From: Manvi07 Date: Mon, 2 Nov 2020 23:12:42 +0530 Subject: [PATCH 2/2] Added EDGE criterion --- glow/information_bottleneck/estimator.py | 59 ++++++++++++------------ 1 file changed, 29 insertions(+), 30 deletions(-) diff --git a/glow/information_bottleneck/estimator.py b/glow/information_bottleneck/estimator.py index 042cfea..3e97630 100644 --- a/glow/information_bottleneck/estimator.py +++ b/glow/information_bottleneck/estimator.py @@ -3,6 +3,7 @@ import math import torch import glow.utils.hsic_utils as kernel_module +import numpy as np class Estimator: @@ -112,12 +113,13 @@ class EDGE(Estimator): """ - def __init__(self, hash_function, gpu=True, **kwargs): + def __init__(self, hash_function, U=10, gpu=True, **kwargs): super().__init__(gpu, **kwargs) self.hash_function = hash_function + self.U = U def g(self, x): - return x * torch.log(x) * (1 / math.log(10)) + return x * math.log(x) * (1 / math.log(2)) def criterion(self, x, y): """ @@ -127,39 +129,36 @@ def criterion(self, x, y): """ h = hash_module.get(self.hash_function, self.params_dict) num_samples = x.shape[0] - if "F" in self.params_dict.keys(): - F = self.params_dict["F"] * num_samples - else: - raise Exception( - "Cannot find argument for number of nodes of dependency graph in EDGE estimator" - ) - N = torch.zeros(F, 1) - M = torch.zeros(F, 1) - L = torch.zeros(F, F) + edge_list = [] + N = {} + M = {} + L = {} for k, x_k in enumerate(x): y_k = y[k] - i = int(h(x_k)) - j = int(h(y_k)) - N[i] = N[i] + 1.0 - M[j] = M[j] + 1.0 - L[i][j] = L[i][j] + 1.0 + i = h(x_k) + j = h(y_k) + if list(x_k.size()) == []: + i = i.item() + j = j.item() + else: + i = tuple(i.tolist()) + j = tuple(j.tolist()) + + N[i] = (N[i] + 1.0) if i in N else 1.0 + M[j] = (M[j] + 1.0) if j in M else 1.0 + L[i,j] = (L[i,j] + 1.0) if (i,j) in L else 1.0 + edge_list.append((i, j)) + + mut_info = 0.0 + + for i, j in edge_list: + wi = 1.0 * N[i] / num_samples + wj = 1.0 * M[j] / num_samples + wij = min(self.U, 1.0 * L[i,j] * num_samples / (N[i]*M[j])) + mut_info += wi * wj * self.g(wij) - N = torch.nn.functional.one_hot(N.long().view(-1, 1), F) - N = torch.sum(N, dim=0) - - M = torch.nn.functional.one_hot(M.long().view(-1, 1), F) - M = torch.sum(M, dim=0) - - n = (1 / num_samples) * N - m = (1 / num_samples) * M - - temp_matrix = torch.mm(N, torch.transpose(M, 0, 1)).type(torch.FloatTensor) - zero_matrix = torch.zeros(F, F) - w = torch.addcdiv(zero_matrix, L, temp_matrix, value=num_samples) - temp_matrix = torch.mm(n, torch.transpose(m, 0, 1)) - mut_info = torch.sum(temp_matrix * self.g(w)) return mut_info