diff --git a/MNIST.py b/MNIST.py new file mode 100644 index 0000000..a1c5212 --- /dev/null +++ b/MNIST.py @@ -0,0 +1,30 @@ +import gzip +import pickle + +import h5py +import numpy as np + + +def preprocess_MNIST( + filename="mnist.pkl.gz", + binary_threshold=0.3, + out_dir="dataset", +): + with gzip.open(filename, "rb") as f: + training_data, validation_data, test_data = pickle.load(f, +encoding="latin1") + + names = ["MNIST_train.h5", "MNIST_val.h5", "MNIST_test.h5"] + datasets = [training_data, validation_data, test_data] + for dataset, name in zip(datasets, names): + curr_data = np.array(dataset[0]) + curr_data = (curr_data > binary_threshold).astype("float") + curr_labels = np.array(dataset[1]) + + with h5py.File(name, "w") as f: + f["samples"] = curr_data + f["labels"] = curr_labels + + +if __name__ == "__main__": + preprocess_MNIST() \ No newline at end of file diff --git a/clusterbm b/clusterbm new file mode 160000 index 0000000..c4375aa --- /dev/null +++ b/clusterbm @@ -0,0 +1 @@ +Subproject commit c4375aac469329e605026a12f7f1082adebd5a53 diff --git a/mnist.pkl.gz b/mnist.pkl.gz new file mode 100644 index 0000000..6a73954 Binary files /dev/null and b/mnist.pkl.gz differ diff --git a/rbms/bernoulli_bernoulli/classes.py b/rbms/bernoulli_bernoulli/classes.py index dd75ac8..af5da35 100644 --- a/rbms/bernoulli_bernoulli/classes.py +++ b/rbms/bernoulli_bernoulli/classes.py @@ -25,6 +25,8 @@ def __init__( weight_matrix: Tensor, vbias: Tensor, hbias: Tensor, + K1: Tensor, + K2: Tensor, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): @@ -46,8 +48,13 @@ def __init__( self.device = device self.dtype = dtype self.weight_matrix = weight_matrix.to(device=self.device, dtype=self.dtype) + self.w_norm_0 = torch.norm(weight_matrix) self.vbias = vbias.to(device=self.device, dtype=self.dtype) + self.v_norm_0 = torch.norm(vbias) self.hbias = hbias.to(device=self.device, dtype=self.dtype) + self.K1 = K1.to(device=self.device, dtype=self.dtype) + self.K2 = K2.to(device=self.device, dtype=self.dtype) + self.K2_norm_0 = torch.norm(K2) self.name = "BBRBM" def __add__(self, other): @@ -104,7 +111,7 @@ def compute_energy_visibles(self, v: Tensor) -> Tensor: weight_matrix=self.weight_matrix, ) - def compute_gradient(self, data, chains, centered=True): + def compute_gradient(self, data, chains, use_fields, centered=True): _compute_gradient( v_data=data["visible"], mh_data=data["hidden_mag"], @@ -145,7 +152,7 @@ def init_chains(self, num_samples, weights=None, start_v=None): ) @staticmethod - def init_parameters(num_hiddens, dataset, device, dtype, var_init=0.0001): + def init_parameters(num_hiddens, dataset, device, dtype, beta, use_fields, var_init=0.0001): data = dataset.data # Convert to torch Tensor if necessary if isinstance(data, np.ndarray): @@ -156,14 +163,21 @@ def init_parameters(num_hiddens, dataset, device, dtype, var_init=0.0001): device=device, dtype=dtype, var_init=var_init, + beta=beta, ) - return BBRBM(weight_matrix=weight_matrix, vbias=vbias, hbias=hbias) + num_visible = len(data[0,:]) + K1 = torch.randn_like(weight_matrix, device=device, dtype=dtype)/np.sqrt(float(num_hiddens)) + K2 = torch.randn_like(weight_matrix, device=device, dtype=dtype)/np.sqrt(float(num_hiddens)) + + return BBRBM(weight_matrix=weight_matrix, vbias=vbias, hbias=hbias, K1=K1, K2=K2) def named_parameters(self): return { "weight_matrix": self.weight_matrix, "vbias": self.vbias, "hbias": self.hbias, + "K1": self.K1, + "K2": self.K2, } def num_hiddens(self): @@ -210,6 +224,8 @@ def 
set_named_parameters(named_params: dict[str, Tensor]) -> Self: weight_matrix=named_params.pop("weight_matrix"), vbias=named_params.pop("vbias"), hbias=named_params.pop("hbias"), + K1=named_params.pop("K1"), + K2=named_params.pop("K2") ) if len(named_params.keys()) > 0: raise ValueError( @@ -227,4 +243,62 @@ def to( self.weight_matrix = self.weight_matrix.to(device=self.device, dtype=self.dtype) self.vbias = self.vbias.to(device=self.device, dtype=self.dtype) self.hbias = self.hbias.to(device=self.device, dtype=self.dtype) + self.K1 = self.K1.to(device=self.device, dtype=self.dtype) + self.K2 = self.K2.to(device=self.device, dtype=self.dtype) return self + + + # ───────────────────────── 1st-order PL (single visible site) ───────────────────────── + def compute_loss_PL1(self, data, l, use_fields=True, use_hfield=True): + x = data # [M,N] entries ∈{0,1} + + # mean-field hidden expectation ⟨h_a⟩ ≃ tanh(λ⋅pre) + h_pre = torch.einsum("ia,mi->ma", self.K1, x) # Wᵀx + if use_hfield: + h_pre = h_pre + self.hbias + h = torch.tanh(l * h_pre) # ±1 hidden → tanh + + F = torch.einsum("ja,ma->mj", self.K1, h) # local field on each visible + if use_fields: + F = F + self.vbias + + logZ = F.softplus(l * F) if hasattr(F, "softplus") else F # F.softplus → log(1+e^{λF}) + e_i = -x * F + (1. / l) * logZ # −log P(x_i|rest)/λ + return e_i.mean() + + + # ───────────────────────── 2nd-order PL (visible–hidden pair) ───────────────────────── + def compute_loss_PL2(self, data, l, use_fields=True, use_hfield=True): + x = data # [M,N] + # hidden mean-field + h_pre = torch.einsum("ia,mi->ma", self.K2, x) + if use_fields and use_hfield: + h_pre = h_pre + self.hbias + h = torch.tanh(l * h_pre) # [M,A] + + # leave-one-out effective fields + b = torch.einsum("ja,ma->mj", self.K2, h) # vis-field from h + c = torch.einsum("ja,mj->ma", self.K2, x) # hid-field from x + if use_fields: + b = b + self.vbias + if use_hfield: + c = c + self.hbias + + j_term = torch.einsum("ja,ma->mja", self.K2, h) # W_ja h_a + a_term = torch.einsum("ja,mj->mja", self.K2, x) # W_ja x_j + b_i_eff = b.unsqueeze(2) - j_term # b̂_j|¬a [M,N,1] + c_a_eff = c.unsqueeze(1) - a_term # ĉ_a|¬j [M,1,A] + + # observed energy E_obs = −(W_ja x_j h_a + b̂_j x_j + ĉ_a h_a) + w_ai = torch.einsum("ma,ja,mj->mja", h, self.K2, x) # W_ja x_j h_a + h_ai = b_i_eff * x.unsqueeze(2) + c_a_eff * h.unsqueeze(1) + + # partition Z_{ja} over x_j∈{0,1}, h_a∈{±1} + z0 = torch.exp(-l * c_a_eff) # (x=0,h=-1) + z1 = torch.exp( l * c_a_eff) # (x=0,h=+1) + z2 = torch.exp( l * (b_i_eff - self.K2 - c_a_eff)) # (x=1,h=-1) + z3 = torch.exp( l * (b_i_eff + self.K2 + c_a_eff)) # (x=1,h=+1) + Z_ai = z0 + z1 + z2 + z3 + + e_ij = -w_ai - h_ai + (1. 
/ l) * torch.log(Z_ai + 1e-9) # −log P(x_j,h_a|rest)/λ + return e_ij.mean() diff --git a/rbms/bernoulli_bernoulli/implement.py b/rbms/bernoulli_bernoulli/implement.py index 7221f3a..9ea6fde 100644 --- a/rbms/bernoulli_bernoulli/implement.py +++ b/rbms/bernoulli_bernoulli/implement.py @@ -159,6 +159,7 @@ def _init_parameters( device: torch.device, dtype: torch.dtype, var_init: float = 1e-4, + beta: float=1., ) -> Tuple[Tensor, Tensor, Tensor]: _, num_visibles = data.shape eps = 1e-4 @@ -168,7 +169,7 @@ def _init_parameters( ) frequencies = data.mean(0) frequencies = torch.clamp(frequencies, min=eps, max=(1.0 - eps)) - vbias = (torch.log(frequencies) - torch.log(1.0 - frequencies)).to( + vbias = 1/beta*(torch.log(frequencies) - torch.log(1.0 - frequencies)).to( device=device, dtype=dtype ) hbias = torch.zeros(num_hiddens, device=device, dtype=dtype) diff --git a/rbms/bernoulli_bernoulli_BM/__init__.py b/rbms/bernoulli_bernoulli_BM/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/rbms/bernoulli_bernoulli_BM/bm.py b/rbms/bernoulli_bernoulli_BM/bm.py new file mode 100644 index 0000000..427d511 --- /dev/null +++ b/rbms/bernoulli_bernoulli_BM/bm.py @@ -0,0 +1,330 @@ +import rbms +from rbms.classes import RBM + +from rbms.bernoulli_bernoulli_BM.implement import ( + _compute_energy, + _compute_energy_hiddens, + _compute_energy_visibles, + _compute_gradient, + _init_chains, + _init_parameters, + _sample_hiddens, + _sample_visibles, +) + +from typing import List, Optional, Self +import torch.nn.functional as F + + +import numpy as np +import torch +from torch import Tensor + +class BBBM(RBM): + """Parameters of the Bernoulli-Bernoulli RBM""" + + def __init__( + self, + weight_matrix: Tensor, + vbias: Tensor, + hbias: Tensor, + K1: Tensor, + K2: Tensor, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + """Initialize the parameters of the Bernoulli-Bernoulli RBM. + + Args: + weight_matrix (Tensor): The weight matrix of the RBM. + vbias (Tensor): The visible bias of the RBM. + hbias (Tensor): The hidden bias of the RBM. + device (Optional[torch.device], optional): The device for the parameters. + Defaults to the device of `weight_matrix`. + dtype (Optional[torch.dtype], optional): The data type for the parameters. + Defaults to the data type of `weight_matrix`. 
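            K1 (Tensor): Coupling matrix used by the first-order pseudolikelihood
                loss `compute_loss_PL1` (editor's note: this argument was added by
                the change but left undocumented; the description is inferred from
                how the class uses it).
            K2 (Tensor): Coupling matrix used by the second-order pseudolikelihood
                loss `compute_loss_PL2`; its initial norm is kept as `K2_norm_0`
                so that `normalize_K2` can rescale it during training.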
+ """ + if device is None: + device = weight_matrix.device + if dtype is None: + dtype = weight_matrix.dtype + self.device = device + self.dtype = dtype + self.weight_matrix = weight_matrix.to(device=self.device, dtype=self.dtype) + self.w_norm_0 = torch.norm(weight_matrix) + self.vbias = vbias.to(device=self.device, dtype=self.dtype) + self.v_norm_0 = torch.norm(vbias) + self.hbias = hbias.to(device=self.device, dtype=self.dtype) + self.name = "BBBM" + self.N = len(weight_matrix[0]) + self.K1 = K1.to(device=self.device, dtype=self.dtype) + self.K2 = K2.to(device=self.device, dtype=self.dtype) + self.K2_norm_0 = torch.norm(K2) + self.mask = torch.ones_like(self.weight_matrix, device=self.device) # Shape [N, N] + self.mask.fill_diagonal_(0) # Set diagonal to 0 + + def __add__(self, other): + return BBBM( + weight_matrix=self.weight_matrix + other.weight_matrix, + vbias=self.vbias + other.vbias, + hbias=self.hbias + other.hbias, + ) + + def __mul__(self, other): + return BBBM( + weight_matrix=self.weight_matrix * other, + vbias=self.vbias * other, + hbias=self.hbias * other, + ) + + def clone( + self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None + ): + if device is None: + device = self.device + if dtype is None: + dtype = self.dtype + return BBBM( + weight_matrix=self.weight_matrix.clone(), + vbias=self.vbias.clone(), + hbias=self.hbias.clone(), + device=device, + dtype=dtype, + ) + + def compute_energy(self, v: Tensor, h: Tensor) -> Tensor: + return _compute_energy( + v=v, + vbias=self.vbias, + weight_matrix=self.weight_matrix, + ) + + def compute_energy_hiddens(self, h: Tensor) -> Tensor: + return _compute_energy_hiddens( + h=h, + vbias=self.vbias, + hbias=self.hbias, + weight_matrix=self.weight_matrix, + ) + + def compute_energy_visibles(self, v: Tensor) -> Tensor: + return _compute_energy_visibles( + v=v, + vbias=self.vbias, + hbias=self.hbias, + weight_matrix=self.weight_matrix, + ) + + def compute_gradient(self, data, chains, use_fields, centered=True): + _compute_gradient( + v_data=data["visible"], + v_chain=chains["visible"], + weight_matrix=self.weight_matrix, + vbias=self.vbias, + use_fields=use_fields, + centered=centered, + ) + + def independent_model(self): + return BBBM( + weight_matrix=torch.zeros_like(self.weight_matrix), + vbias=self.vbias, + hbias=torch.zeros_like(self.hbias), + ) + + def init_chains(self, num_samples, weights=None, start_v=None): + visible = _init_chains( + num_samples=num_samples, + weight_matrix=self.weight_matrix, + hbias=self.hbias, + start_v=start_v, + ) + if weights is None: + weights = torch.ones( + visible.shape[0], device=visible.device, dtype=visible.dtype + ) + return dict( + visible=visible, + hidden=visible, + visible_mag=visible, + hidden_mag=visible, + weights=weights, + ) + + @staticmethod + def init_parameters(num_hiddens, dataset, device, dtype, beta, use_fields, var_init=0.1): + data = dataset.data + # Convert to torch Tensor if necessary + if isinstance(data, np.ndarray): + data = torch.from_numpy(dataset.data).to(device=device, dtype=dtype) + weight_matrix, vbias = _init_parameters( + data=data, + device=device, + dtype=dtype, + var_init=var_init, + beta=beta + ) + num_visible = len(data[0,:]) + if use_fields==False: + vbias = torch.zeros_like(weight_matrix[0], device=device, dtype=dtype) + + hbias = torch.zeros_like(weight_matrix[0], device=device, dtype=dtype) + K1 = torch.randn_like(weight_matrix, device=device, dtype=dtype)/np.sqrt(float(num_visible)) + K2 = torch.randn_like(weight_matrix, 
device=device, dtype=dtype)/np.sqrt(float(num_visible)) + K1 = (K1.T+K1)/2 + K2 = (K2.T+K2)/2 + K1.fill_diagonal_(0.0) + K2.fill_diagonal_(0.0) + + return BBBM(weight_matrix=weight_matrix, vbias=vbias, hbias=hbias, K1=K1, K2=K2) + + def named_parameters(self): + return { + "weight_matrix": self.weight_matrix, + "vbias": self.vbias, + "hbias": self.hbias, + "K1": self.K1, + "K2": self.K2, + } + + def num_hiddens(self): + return self.hbias.shape[0] + + def num_visibles(self): + return self.weight_matrix.shape[0] + + def parameters(self) -> List[Tensor]: + return [self.weight_matrix, self.vbias, self.hbias, self.K1, self.K2] + + def ref_log_z(self): + return ( + self.num_visibles() * np.log(2) + ).item() + + def sample_hiddens(self, chains: dict[str, Tensor], beta=1) -> dict[str, Tensor]: + chains["hidden"], chains["hidden_mag"] = _sample_hiddens( + v=chains["visible"], + weight_matrix=self.weight_matrix, + hbias=self.hbias, + beta=beta, + ) + return chains + + def sample_visibles(self, chains: dict[str, Tensor], beta=1) -> dict[str, Tensor]: + chains["visible"], chains["visible_mag"] = _sample_visibles( + v=chains["visible"], + h=chains["hidden"], + weight_matrix=self.weight_matrix, + vbias=self.vbias, + beta=beta, + ) + return chains + + @staticmethod + def set_named_parameters(named_params: dict[str, Tensor]) -> Self: + names = ["vbias", "hbias", "weight_matrix", "K1", "K2"] + for k in names: + if k not in named_params.keys(): + raise ValueError( + f"""Dictionary params missing key '{k}'\n Provided keys : {named_params.keys()}\n Expected keys: {names}""" + ) + params = BBBM( + weight_matrix=named_params.pop("weight_matrix"), + vbias=named_params.pop("vbias"), + hbias=named_params.pop("hbias"), + K1=named_params.pop("K1"), + K2=named_params.pop("K2") + ) + if len(named_params.keys()) > 0: + raise ValueError( + f"Too many keys in params dictionary. 
Remaining keys: {named_params.keys()}" + ) + return params + + def to( + self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None + ): + if device is not None: + self.device = device + if dtype is not None: + self.dtype = dtype + self.weight_matrix = self.weight_matrix.to(device=self.device, dtype=self.dtype) + self.vbias = self.vbias.to(device=self.device, dtype=self.dtype) + self.hbias = self.hbias.to(device=self.device, dtype=self.dtype) + return self + + def Z_i_mu_func(self, h, l): + Z_i_mu = 2*torch.cosh(l*h) + return Z_i_mu + + def compute_pseudolikelihood_matrix(self, matrix, data, l): + h = torch.einsum('ij,mj->mi', matrix*self.mask, data) + x_J_x = torch.einsum('mi,mi->mi', data, h) + energy_i_mu = -x_J_x + (1 / l) * torch.log(self.Z_i_mu_func(h,l)+1e-9) + PL = energy_i_mu.mean() + return PL + + def compute_loss_PL1(self, data, l, use_fields, use_hfield=False): + x = data # [M,N] each entry ∈{0,1} + h = torch.einsum('ij,mj->mi', self.K1*self.mask.to(self.K1.device), x) # local field Σ_j K_ij x_j + if use_fields: # optional visible bias + h = h + self.vbias.unsqueeze(0) + + logZ = F.softplus(l*h) # log(1+e^{λ h}) – numerically stable + e_i = -x*h + (1./l)*logZ # −log P(x_i|x_{¬i}) / λ + return e_i.mean() # scalar loss + + + # ---------- second–order pseudolikelihood (pairwise) ---------- + def compute_loss_PL2(self, data, l, use_fields, use_hfield=False): + x = data # [M,N] + K = self.K2*self.mask # interaction matrix with zero diagonal + + # global fields for every unit + h = torch.einsum('ik,mk->mi', K, x) # [M,N] + diff_term = torch.einsum('ik,mk->mik', K, x) # [M,N,N] + + if use_fields: + h = h + self.vbias.unsqueeze(0) # add biases only once + + # “leave-one-out’’ effective fields + h_i_eff = h.unsqueeze(2) - diff_term # h_i − K_ij x_j [M,N,N] + h_j_eff = h.unsqueeze(1) - diff_term # h_j − K_ij x_i [M,N,N] + + x_i = x.unsqueeze(2) # x_i [M,N,1] + x_j = x.unsqueeze(1) # x_j [M,1,N] + + # energy part actually observed: −(K_ij x_i x_j + h_i_eff x_i + h_j_eff x_j) + E_pair = K * x_i * x_j + E_field = h_i_eff * x_i + h_j_eff * x_j + + # partition function for the {0,1}×{0,1} pair + Z_xx = ( + 1.0 # (0,0) + + torch.exp(l * h_i_eff) # (1,0) + + torch.exp(l * h_j_eff) # (0,1) + + torch.exp(l * (K + h_i_eff + h_j_eff)) # (1,1) + ) + + e_ij = -E_pair - E_field + (1./l)*torch.log(Z_xx + 1e-9) # −log P(x_i,x_j|rest)/λ + return e_ij.mean() + + + def normalize_w(self): + with torch.no_grad(): + norm = torch.norm(self.weight_matrix.data) + self.weight_matrix.data = self.weight_matrix.data * self.w_norm_0 / (norm+1e-9) + + def normalize_K2(self): + with torch.no_grad(): + norm = torch.norm(self.K2.data) + self.K2.data = self.K2.data * self.K2_norm_0 / (norm+1e-9) + + def normalize_v(self): + with torch.no_grad(): + norm = torch.norm(self.vbias.data) + self.vbias.data = self.vbias.data * self.v_norm_0 / (norm+1e-9) + + + + \ No newline at end of file diff --git a/rbms/bernoulli_bernoulli_BM/implement.py b/rbms/bernoulli_bernoulli_BM/implement.py new file mode 100644 index 0000000..b3b026a --- /dev/null +++ b/rbms/bernoulli_bernoulli_BM/implement.py @@ -0,0 +1,153 @@ +from typing import Optional, Tuple + +import torch +from torch import Tensor +from torch.nn.functional import softmax + + +@torch.jit.script +def _sample_hiddens( + v: Tensor, weight_matrix: Tensor, hbias: Tensor, beta: float = 1.0 +) -> Tuple[Tensor, Tensor]: + dtype = torch.float32 + return v, v #nothing happens to visible after dummy hidden passage + + +@torch.jit.script +def _sample_visibles( + v: 
Tensor, h: Tensor, weight_matrix: Tensor, vbias: Tensor, beta: float = 1.0 +) -> Tuple[Tensor, Tensor]: + x = v.clone() + num_chains, num_units = x.shape + for i in range(num_units): #it has to be sequential + h_i = torch.einsum("j,bj->b",weight_matrix[i],x) #diagonal is already zero + probs = torch.sigmoid(beta*(h_i+vbias[i])) + x[:,i] = torch.bernoulli(probs) + return x, x + + +@torch.jit.script +def _compute_energy( + v: Tensor, + vbias: Tensor, + weight_matrix: Tensor, +) -> Tensor: + x = v.clone() + fields = (x * vbias).sum(dim=1) + interaction = 0.5 * torch.einsum("bi,ij,bj->b", x, weight_matrix, x) + + return -fields - interaction + + +@torch.jit.script +def _compute_energy_visibles( + v: Tensor, vbias: Tensor, hbias: Tensor, weight_matrix: Tensor +) -> Tensor: + x = v.clone() + fields = (x * vbias).sum(dim=1) + interaction = 0.5 * torch.einsum("bi,ij,bj->b", x, weight_matrix, x) + + return -fields - interaction + + +@torch.jit.script +def _compute_energy_hiddens( + h: Tensor, vbias: Tensor, hbias: Tensor, weight_matrix: Tensor +) -> Tensor: + dtype = torch.float32 + return torch.tensor([0], device=weight_matrix.device, dtype=dtype) + + +@torch.jit.script +def _compute_gradient( + v_data: Tensor, + v_chain: Tensor, + weight_matrix: Tensor, + vbias: Tensor, + use_fields: bool, + centered: bool = False, +) -> None: + dtype = torch.float32 + + x_data = v_data + x_model = v_chain + + if centered: + pass + else: + pass + + # Empirical averages + data_mean = x_data.mean(dim=0) # shape (num_units,) + data_corr = (x_data.unsqueeze(2) * x_data.unsqueeze(1)).mean(dim=0) # shape (num_units, num_units) + # Model (chain) averages + model_mean = x_model.mean(dim=0) + model_corr = (x_model.unsqueeze(2) * x_model.unsqueeze(1)).mean(dim=0) + # Bias gradient + grad_weight_matrix = data_corr - model_corr + grad_weight_matrix.fill_diagonal_(0.0) + if use_fields==True: + grad_vbias = data_mean - model_mean + else: + grad_vbias = torch.tensor([0], device=weight_matrix.device, dtype=dtype) + #grad_hbias = torch.tensor([0], device=weight_matrix.device, dtype=weight_matrix.type) + + # Attach to the parameters + weight_matrix.grad.set_(grad_weight_matrix) + vbias.grad.set_(grad_vbias) + #hbias.grad.set_(grad_hbias) + +@torch.jit.script +def _init_chains( + num_samples: int, + weight_matrix: Tensor, + hbias: Tensor, + start_v: Optional[Tensor] = None, +): + num_visibles = weight_matrix.shape[0] + device = weight_matrix.device + dtype = torch.float32 + # Handle negative number of samples + if num_samples <= 0: + if start_v is not None: + num_samples = start_v.shape[0] + else: + raise ValueError(f"Got negative num_samples arg: {num_samples}") + + if start_v is None: + # Dummy mean visible + v = torch.randint(low=0, high=2, size=(num_samples, num_visibles), device=device, dtype=torch.float32) + else: + v = start_v.to(device=device, dtype=dtype) + return v + + +def _init_parameters( + data: Tensor, + device: torch.device, + dtype: torch.dtype, + var_init: float = 1e-2, + beta: float=1. 
+) -> Tuple[Tensor, Tensor, Tensor]: + _, num_visibles = data.shape + eps = 1e-4 + weight_matrix = ( + torch.randn(size=(num_visibles, num_visibles), device=device, dtype=dtype) + * var_init + ) + weight_matrix.fill_diagonal_(0.0) + weight_matrix = 0.5*(weight_matrix+weight_matrix.T) + ''' + frequencies = data.mean(0) + frequencies = torch.clamp(frequencies, min=eps, max=(1.0 - eps)) + vbias = (torch.log(frequencies) - torch.log(1.0 - frequencies)).to( + device=device, dtype=dtype + ) + hbias = torch.zeros(num_hiddens, device=device, dtype=dtype) + ''' + frequencies = data.mean(0) + frequencies = torch.clamp(frequencies, min=eps, max=(1.0 - eps)) + vbias = 1/beta*(torch.log(frequencies) - torch.log(1.0 - frequencies)).to( + device=device, dtype=dtype + ) + return weight_matrix, vbias diff --git a/rbms/classes.py b/rbms/classes.py index 50dc3d6..a1851f0 100644 --- a/rbms/classes.py +++ b/rbms/classes.py @@ -6,7 +6,6 @@ from rbms.dataset.dataset_class import RBMDataset - class RBM(ABC): """An abstract class representing the parameters of a RBM.""" diff --git a/rbms/ising/__init__.py b/rbms/ising/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/rbms/ising/implement.py b/rbms/ising/implement.py new file mode 100644 index 0000000..cf54b7f --- /dev/null +++ b/rbms/ising/implement.py @@ -0,0 +1,197 @@ +from typing import Optional, Tuple + +import torch +from torch import Tensor +from torch.nn.functional import softmax +from rbms.custom_fn import log2cosh + + +@torch.jit.script +def _sample_hiddens( + v: Tensor, weight_matrix: Tensor, hbias: Tensor, beta: float = 1.0 +) -> Tuple[Tensor, Tensor]: + tmp = beta * (hbias + (v @ weight_matrix)) + # mh = torch.exp(tmp) / (2 * torch.cosh(tmp)) + mh = torch.tanh(tmp) # Because of Ising + h = torch.bernoulli(0.5 * (mh+ 1)) * 2 - 1 + return h, mh + + +# @torch.jit.script +def _sample_visibles( + h: Tensor, weight_matrix: Tensor, vbias: Tensor, beta: float = 1.0 +) -> Tuple[Tensor, Tensor]: + tmp = beta * (vbias + (h @ weight_matrix.T)) + mv = torch.tanh(tmp) # Because of Ising + v = torch.bernoulli(0.5 * ( mv+ 1)) * 2 - 1 + return v, mv + + +@torch.jit.script +def _compute_energy( + v: Tensor, + h: Tensor, + vbias: Tensor, + hbias: Tensor, + weight_matrix: Tensor, +) -> Tensor: + fields = torch.tensordot(vbias, v, dims=[[0], [1]]) + torch.tensordot( + hbias, h, dims=[[0], [1]] + ) + interaction = torch.multiply( + v, torch.tensordot(h, weight_matrix, dims=[[1], [1]]) + ).sum(1) + + return -fields - interaction + + +@torch.jit.script +def _compute_energy_visibles( + v: Tensor, vbias: Tensor, hbias: Tensor, weight_matrix: Tensor +) -> Tensor: + field = v @ vbias + exponent = hbias + (v @ weight_matrix) + + log_term = log2cosh(exponent) + # log_term = torch.where( + # exponent < 10, torch.log(1.0 + torch.exp(exponent)), exponent + # ) + return -field - log_term.sum(1) + + +@torch.jit.script +def _compute_energy_hiddens( + h: Tensor, vbias: Tensor, hbias: Tensor, weight_matrix: Tensor +) -> Tensor: + field = h @ hbias + exponent = vbias + (h @ weight_matrix.T) + # log_term = torch.where( + # exponent < 10, torch.log(1.0 + torch.exp(exponent)), exponent + # ) + log_term = log2cosh(exponent) + + return -field - log_term.sum(1) + + +@torch.jit.script +def _compute_gradient( + v_data: Tensor, + mh_data: Tensor, + w_data: Tensor, + v_chain: Tensor, + h_chain: Tensor, + w_chain: Tensor, + vbias: Tensor, + hbias: Tensor, + weight_matrix: Tensor, + centered: bool = True, +) -> None: + w_data = w_data.view(-1, 1) + w_chain = w_chain.view(-1, 1) + # 
Turn the weights of the chains into normalized weights + chain_weights = softmax(-w_chain, dim=0) + w_data_norm = w_data.sum() + + # Averages over data and generated samples + v_data_mean = (v_data * w_data).sum(0) / w_data_norm + torch.clamp_(v_data_mean, min=-1+1e-7, max=(1.0 - 1e-7)) + h_data_mean = (mh_data * w_data).sum(0) / w_data_norm + v_gen_mean = (v_chain * chain_weights).sum(0) + torch.clamp_(v_gen_mean, min=-1+1e-7, max=(1.0 - 1e-7)) + h_gen_mean = (h_chain * chain_weights).sum(0) + + if centered: + # Centered variables + v_data_centered = v_data - v_data_mean + h_data_centered = mh_data - h_data_mean + v_gen_centered = v_chain - v_data_mean + h_gen_centered = h_chain - h_data_mean + + # Gradient + grad_weight_matrix = ( + (v_data_centered * w_data).T @ h_data_centered + ) / w_data_norm - ((v_gen_centered * chain_weights).T @ h_gen_centered) + grad_vbias = v_data_mean - v_gen_mean - (grad_weight_matrix @ h_data_mean) + grad_hbias = h_data_mean - h_gen_mean - (v_data_mean @ grad_weight_matrix) + else: + v_data_centered = v_data + h_data_centered = mh_data + v_gen_centered = v_chain + h_gen_centered = h_chain + + # Gradient + grad_weight_matrix = ((v_data * w_data).T @ mh_data) / w_data_norm - ( + (v_chain * chain_weights).T @ h_chain + ) + # grad_weight_matrix = (v_data.T @ mh_data) / v_data.shape[0] - (v_chain.T @ h_chain ) / v_chain.shape[0] + grad_vbias = v_data_mean - v_gen_mean + grad_hbias = h_data_mean - h_gen_mean + + # Attach to the parameters + + weight_matrix.grad.set_(grad_weight_matrix) + vbias.grad.set_(grad_vbias) + hbias.grad.set_(grad_hbias) + + +@torch.jit.script +def _init_chains( + num_samples: int, + weight_matrix: Tensor, + hbias: Tensor, + start_v: Optional[Tensor] = None, +): + num_visibles, _ = weight_matrix.shape + device = weight_matrix.device + dtype = weight_matrix.dtype + # Handle negative number of samples + if num_samples <= 0: + if start_v is not None: + num_samples = start_v.shape[0] + else: + raise ValueError(f"Got negative num_samples arg: {num_samples}") + + if start_v is None: + # Dummy mean visible + mv = ( + torch.ones(size=(num_samples, num_visibles), device=device, dtype=dtype) / 2 + ) + v = torch.bernoulli(mv) * 2 - 1 + else: + # Dummy mean visible + mv = torch.ones_like(start_v, device=device, dtype=dtype) / 2 + v = start_v.to(device=device, dtype=dtype) + + # Initialize chains + + h, mh = _sample_hiddens(v=v, weight_matrix=weight_matrix, hbias=hbias) + return v, h, mv, mh + + +def _init_parameters( + num_hiddens: int, + data: Tensor, + device: torch.device, + dtype: torch.dtype, + var_init: float = 1e-4, + beta: float=1., +) -> Tuple[Tensor, Tensor, Tensor]: + _, num_visibles = data.shape + eps = 1e-4 + weight_matrix = ( + torch.randn(size=(num_visibles, num_hiddens), device=device, dtype=dtype) + * var_init + ) + spin_means = data.mean(0) + spin_means = torch.clamp(spin_means,-0.95,0.95) + ''' + frequencies = torch.clamp(frequencies, min=eps, max=(1.0 - eps)) + + vbias = (torch.log(frequencies) - torch.log(1.0 - frequencies)).to( + device=device, dtype=dtype + ) + ''' + vbias = 1/beta*torch.atanh(spin_means) + + hbias = torch.zeros(num_hiddens, device=device, dtype=dtype) + return vbias, hbias, weight_matrix diff --git a/rbms/ising/ising_rbm.py b/rbms/ising/ising_rbm.py new file mode 100644 index 0000000..f544b26 --- /dev/null +++ b/rbms/ising/ising_rbm.py @@ -0,0 +1,310 @@ +from typing import List, Optional, Self + +import numpy as np +import torch +from torch import Tensor + +from rbms.ising.implement import ( + _compute_energy, + 
_compute_energy_hiddens, + _compute_energy_visibles, + _compute_gradient, + _init_chains, + _init_parameters, + _sample_hiddens, + _sample_visibles, +) +from rbms.classes import RBM + + +class IsingRBM(RBM): + """Parameters of the Bernoulli-Bernoulli RBM""" + + def __init__( + self, + weight_matrix: Tensor, + vbias: Tensor, + hbias: Tensor, + K1: Tensor, + K2: Tensor, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + """Initialize the parameters of the Bernoulli-Bernoulli RBM. + + Args: + weight_matrix (Tensor): The weight matrix of the RBM. + vbias (Tensor): The visible bias of the RBM. + hbias (Tensor): The hidden bias of the RBM. + device (Optional[torch.device], optional): The device for the parameters. + Defaults to the device of `weight_matrix`. + dtype (Optional[torch.dtype], optional): The data type for the parameters. + Defaults to the data type of `weight_matrix`. + """ + if device is None: + device = weight_matrix.device + if dtype is None: + dtype = weight_matrix.dtype + self.device = device + self.dtype = dtype + self.weight_matrix = weight_matrix.to(device=self.device, dtype=self.dtype) + self.w_norm_0 = torch.norm(weight_matrix) + self.vbias = vbias.to(device=self.device, dtype=self.dtype) + self.v_norm_0 = torch.norm(vbias) + self.hbias = hbias.to(device=self.device, dtype=self.dtype) + self.K1 = K1.to(device=self.device, dtype=self.dtype) + self.K2 = K2.to(device=self.device, dtype=self.dtype) + self.K2_norm_0 = torch.norm(K2) + self.name = "IsingRBM" + + def __add__(self, other): + return IsingRBM( + weight_matrix=self.weight_matrix + other.weight_matrix, + vbias=self.vbias + other.vbias, + hbias=self.hbias + other.hbias, + ) + + def __mul__(self, other): + return IsingRBM( + weight_matrix=self.weight_matrix * other, + vbias=self.vbias * other, + hbias=self.hbias * other, + ) + + def clone( + self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None + ): + if device is None: + device = self.device + if dtype is None: + dtype = self.dtype + return IsingRBM( + weight_matrix=self.weight_matrix.clone(), + vbias=self.vbias.clone(), + hbias=self.hbias.clone(), + device=device, + dtype=dtype, + ) + + def compute_energy(self, v: Tensor, h: Tensor) -> Tensor: + return _compute_energy( + v=v, + h=h, + vbias=self.vbias, + hbias=self.hbias, + weight_matrix=self.weight_matrix, + ) + + def compute_energy_hiddens(self, h: Tensor) -> Tensor: + return _compute_energy_hiddens( + h=h, + vbias=self.vbias, + hbias=self.hbias, + weight_matrix=self.weight_matrix, + ) + + def compute_energy_visibles(self, v: Tensor) -> Tensor: + return _compute_energy_visibles( + v=v, + vbias=self.vbias, + hbias=self.hbias, + weight_matrix=self.weight_matrix, + ) + + def compute_gradient(self, data, chains, use_fields, centered=True): + _compute_gradient( + v_data=data["visible"], + mh_data=data["hidden_mag"], + w_data=data["weights"], + v_chain=chains["visible"], + h_chain=chains["hidden"], + w_chain=chains["weights"], + vbias=self.vbias, + hbias=self.hbias, + weight_matrix=self.weight_matrix, + centered=False + ) + + def independent_model(self): + return IsingRBM( + weight_matrix=torch.zeros_like(self.weight_matrix), + vbias=self.vbias, + hbias=torch.zeros_like(self.hbias), + ) + + def init_chains(self, num_samples, weights=None, start_v=None): + visible, hidden, mean_visible, mean_hidden = _init_chains( + num_samples=num_samples, + weight_matrix=self.weight_matrix, + hbias=self.hbias, + start_v=start_v, + ) + if weights is None: + weights = 
torch.ones( + visible.shape[0], device=visible.device, dtype=visible.dtype + ) + return dict( + visible=visible, + hidden=hidden, + visible_mag=mean_visible, + hidden_mag=mean_hidden, + weights=weights, + ) + + @staticmethod + def init_parameters(num_hiddens, dataset, device, dtype, beta, use_fields, var_init=0.0001): + data = dataset.data + # Convert to torch Tensor if necessary + if isinstance(data, np.ndarray): + data = torch.from_numpy(dataset.data).to(device=device, dtype=dtype) + vbias, hbias, weight_matrix = _init_parameters( + num_hiddens=num_hiddens, + data=data, + device=device, + dtype=dtype, + var_init=var_init, + beta=beta + ) + num_visible = len(data[0,:]) + K1 = torch.randn_like(weight_matrix, device=device, dtype=dtype)/np.sqrt(float(num_hiddens)) + K2 = torch.randn_like(weight_matrix, device=device, dtype=dtype)/np.sqrt(float(num_hiddens)) + + return IsingRBM(weight_matrix=weight_matrix, vbias=vbias, hbias=hbias, K1=K1, K2=K2) + + def named_parameters(self): + return { + "weight_matrix": self.weight_matrix, + "vbias": self.vbias, + "hbias": self.hbias, + "K1": self.K1, + "K2": self.K2, + } + + def num_hiddens(self): + return self.hbias.shape[0] + + def num_visibles(self): + return self.vbias.shape[0] + + def parameters(self) -> List[Tensor]: + return [self.weight_matrix, self.vbias, self.hbias, self.K1, self.K2] + + def ref_log_z(self): + return ( + torch.log1p(torch.exp(self.vbias)).sum() + self.num_hiddens() * np.log(2) + ).item() + + def sample_hiddens(self, chains: dict[str, Tensor], beta=1) -> dict[str, Tensor]: + chains["hidden"], chains["hidden_mag"] = _sample_hiddens( + v=chains["visible"], + weight_matrix=self.weight_matrix, + hbias=self.hbias, + beta=beta, + ) + return chains + + def sample_visibles(self, chains: dict[str, Tensor], beta=1) -> dict[str, Tensor]: + chains["visible"], chains["visible_mag"] = _sample_visibles( + h=chains["hidden"], + weight_matrix=self.weight_matrix, + vbias=self.vbias, + beta=beta, + ) + return chains + + @staticmethod + def set_named_parameters(named_params: dict[str, Tensor]) -> Self: + names = ["vbias", "hbias", "weight_matrix", "K1", "K2"] + for k in names: + if k not in named_params.keys(): + raise ValueError( + f"""Dictionary params missing key '{k}'\n Provided keys : {named_params.keys()}\n Expected keys: {names}""" + ) + params = IsingRBM( + weight_matrix=named_params.pop("weight_matrix"), + vbias=named_params.pop("vbias"), + hbias=named_params.pop("hbias"), + K1=named_params.pop("K1"), + K2=named_params.pop("K2") + ) + if len(named_params.keys()) > 0: + raise ValueError( + f"Too many keys in params dictionary. 
Remaining keys: {named_params.keys()}" + ) + return params + + def to( + self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None + ): + if device is not None: + self.device = device + if dtype is not None: + self.dtype = dtype + self.weight_matrix = self.weight_matrix.to(device=self.device, dtype=self.dtype) + self.vbias = self.vbias.to(device=self.device, dtype=self.dtype) + self.hbias = self.hbias.to(device=self.device, dtype=self.dtype) + self.K1 = self.K1.to(device=self.device, dtype=self.dtype) + self.K2 = self.K2.to(device=self.device, dtype=self.dtype) + return self + + def compute_loss_PL1(self, data, l, use_fields=True, use_hfield=True): + x = data # [M, N] + + with torch.no_grad(): + h_pre = torch.einsum("ia,mi->ma", self.K1, x) + if use_hfield: + h_pre = h_pre + self.hbias + h = torch.tanh(l * h_pre) + + F = torch.einsum("ja,ma->mj", self.K1, h) + if use_fields: + F = F + self.vbias + + xF = torch.einsum("mi,mi->mi", x, F) + Z_i = 2*torch.cosh(l * F) + e_i = -xF + (1.0 / l) * torch.log(Z_i + 1e-9) + + return e_i.mean() + + def compute_loss_PL2(self, data, l, use_fields, use_hfield): + x = data + with torch.no_grad(): + if use_fields == True and use_hfield ==True: + h = torch.tanh(l*(self.hbias+torch.einsum("ia,mi->ma", self.K2, x))) + else: + h = torch.tanh(l*(torch.einsum("ia,mi->ma", self.K2, x))) + b = torch.einsum("ja,ma->mj",self.K2, h) + c = torch.einsum("ja,mj->ma",self.K2, x) + + if use_fields == True: + b = b+self.vbias + if use_hfield==True: + c = c+self.hbias + + j_term = torch.einsum("ja,ma->mja", self.K2, h) + a_term = torch.einsum("ja,mj->mja", self.K2, x) + + b_i_eff = b.unsqueeze(2)-j_term #[M,N,1] + c_a_eff = c.unsqueeze(1)-a_term #[M,1,N] + + w_ai = torch.einsum("ma,ja,mj->mja", h, self.K2, x) + h_ai = b_i_eff*x.unsqueeze(2)+c_a_eff*h.unsqueeze(1) + Z_ai = 2.*(torch.exp(l*self.K2)*torch.cosh(l*b_i_eff+l*c_a_eff)+torch.exp(-l*self.K2)*torch.cosh(l*b_i_eff-l*c_a_eff)) + + e_ij = -w_ai-h_ai+1./l*torch.log(Z_ai+1e-9) + return e_ij.mean() + + def normalize_w(self): + with torch.no_grad(): + norm = torch.norm(self.weight_matrix.data) + self.weight_matrix.data = self.weight_matrix.data * self.w_norm_0 / (norm+1e-9) + + def normalize_K2(self): + with torch.no_grad(): + norm = torch.norm(self.K2.data) + self.K2.data = self.K2.data * self.K2_norm_0 / (norm+1e-9) + + def normalize_v(self): + with torch.no_grad(): + norm = torch.norm(self.vbias.data) + self.vbias.data = self.vbias.data * self.v_norm_0 / (norm+1e-9) \ No newline at end of file diff --git a/rbms/isingBM/bm.py b/rbms/isingBM/bm.py new file mode 100644 index 0000000..9509a72 --- /dev/null +++ b/rbms/isingBM/bm.py @@ -0,0 +1,366 @@ +import rbms +from rbms.classes import RBM + +from rbms.isingBM.implement import ( + _compute_energy, + _compute_energy_hiddens, + _compute_energy_visibles, + _compute_gradient, + _init_chains, + _init_parameters, + _sample_hiddens, + _sample_visibles, +) + +from typing import List, Optional, Self + +import numpy as np +import torch +from torch import Tensor + +class IsingBM(RBM): + """Parameters of the Bernoulli-Bernoulli RBM""" + + def __init__( + self, + weight_matrix: Tensor, + vbias: Tensor, + hbias: Tensor, + K1: Tensor, + K2: Tensor, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + """Initialize the parameters of the Bernoulli-Bernoulli RBM. + + Args: + weight_matrix (Tensor): The weight matrix of the RBM. + vbias (Tensor): The visible bias of the RBM. + hbias (Tensor): The hidden bias of the RBM. 
+ device (Optional[torch.device], optional): The device for the parameters. + Defaults to the device of `weight_matrix`. + dtype (Optional[torch.dtype], optional): The data type for the parameters. + Defaults to the data type of `weight_matrix`. + """ + if device is None: + device = weight_matrix.device + if dtype is None: + dtype = weight_matrix.dtype + self.device = device + self.dtype = dtype + self.weight_matrix = weight_matrix.to(device=self.device, dtype=self.dtype) + self.w_norm_0 = torch.norm(weight_matrix) + self.vbias = vbias.to(device=self.device, dtype=self.dtype) + self.v_norm_0 = torch.norm(vbias) + self.hbias = hbias.to(device=self.device, dtype=self.dtype) + self.name = "IsingBM" + self.N = len(weight_matrix[0]) + self.K1 = K1.to(device=self.device, dtype=self.dtype) + self.K2 = K2.to(device=self.device, dtype=self.dtype) + self.K2_norm_0 = torch.norm(K2) + self.mask = torch.ones_like(self.weight_matrix, device=self.device) # Shape [N, N] + self.mask.fill_diagonal_(0) # Set diagonal to 0 + + def __add__(self, other): + return IsingBM( + weight_matrix=self.weight_matrix + other.weight_matrix, + vbias=self.vbias + other.vbias, + hbias=self.hbias + other.hbias, + ) + + def __mul__(self, other): + return IsingBM( + weight_matrix=self.weight_matrix * other, + vbias=self.vbias * other, + hbias=self.hbias * other, + ) + + def clone( + self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None + ): + if device is None: + device = self.device + if dtype is None: + dtype = self.dtype + return IsingBM( + weight_matrix=self.weight_matrix.clone(), + vbias=self.vbias.clone(), + hbias=self.hbias.clone(), + device=device, + dtype=dtype, + ) + + def compute_energy(self, v: Tensor, h: Tensor) -> Tensor: + return _compute_energy( + v=v, + vbias=self.vbias, + weight_matrix=self.weight_matrix, + ) + + def compute_energy_hiddens(self, h: Tensor) -> Tensor: + return _compute_energy_hiddens( + h=h, + vbias=self.vbias, + hbias=self.hbias, + weight_matrix=self.weight_matrix, + ) + + def compute_energy_visibles(self, v: Tensor) -> Tensor: + return _compute_energy_visibles( + v=v, + vbias=self.vbias, + hbias=self.hbias, + weight_matrix=self.weight_matrix, + ) + + def compute_gradient(self, data, chains, use_fields, centered=True): + _compute_gradient( + v_data=data["visible"], + v_chain=chains["visible"], + weight_matrix=self.weight_matrix, + vbias=self.vbias, + use_fields=use_fields, + centered=centered, + ) + + def independent_model(self): + return IsingBM( + weight_matrix=torch.zeros_like(self.weight_matrix), + vbias=self.vbias, + hbias=torch.zeros_like(self.hbias), + ) + + def init_chains(self, num_samples, weights=None, start_v=None): + visible = _init_chains( + num_samples=num_samples, + weight_matrix=self.weight_matrix, + hbias=self.hbias, + start_v=start_v, + ) + if weights is None: + weights = torch.ones( + visible.shape[0], device=visible.device, dtype=visible.dtype + ) + return dict( + visible=visible, + hidden=visible, + visible_mag=visible, + hidden_mag=visible, + weights=weights, + ) + + @staticmethod + def init_parameters(num_hiddens, dataset, device, dtype, beta, use_fields, var_init=0.1): + data = dataset.data + # Convert to torch Tensor if necessary + if isinstance(data, np.ndarray): + data = torch.from_numpy(dataset.data).to(device=device, dtype=dtype) + weight_matrix, vbias = _init_parameters( + data=data, + device=device, + dtype=dtype, + var_init=var_init, + beta=beta + ) + num_visible = len(data[0,:]) + if use_fields==False: + vbias = 
torch.zeros_like(weight_matrix[0], device=device, dtype=dtype) + + hbias = torch.zeros_like(weight_matrix[0], device=device, dtype=dtype) + K1 = torch.randn_like(weight_matrix, device=device, dtype=dtype)/np.sqrt(float(num_visible)) + K2 = torch.randn_like(weight_matrix, device=device, dtype=dtype)/np.sqrt(float(num_visible)) + K1 = (K1.T+K1)/2 + K2 = (K2.T+K2)/2 + K1.fill_diagonal_(0.0) + K2.fill_diagonal_(0.0) + + return IsingBM(weight_matrix=weight_matrix, vbias=vbias, hbias=hbias, K1=K1, K2=K2) + + def named_parameters(self): + return { + "weight_matrix": self.weight_matrix, + "vbias": self.vbias, + "hbias": self.hbias, + "K1": self.K1, + "K2": self.K2, + } + + def num_hiddens(self): + return self.hbias.shape[0] + + def num_visibles(self): + return self.weight_matrix.shape[0] + + def parameters(self) -> List[Tensor]: + return [self.weight_matrix, self.vbias, self.hbias, self.K1, self.K2] + + def ref_log_z(self): + return ( + self.num_visibles() * np.log(2) + ).item() + + def sample_hiddens(self, chains: dict[str, Tensor], beta=1) -> dict[str, Tensor]: + chains["hidden"], chains["hidden_mag"] = _sample_hiddens( + v=chains["visible"], + weight_matrix=self.weight_matrix, + hbias=self.hbias, + beta=beta, + ) + return chains + + def sample_visibles(self, chains: dict[str, Tensor], beta=1) -> dict[str, Tensor]: + chains["visible"], chains["visible_mag"] = _sample_visibles( + v=chains["visible"], + h=chains["hidden"], + weight_matrix=self.weight_matrix, + vbias=self.vbias, + beta=beta, + ) + return chains + + @staticmethod + def set_named_parameters(named_params: dict[str, Tensor]) -> Self: + names = ["vbias", "hbias", "weight_matrix", "K1", "K2"] + for k in names: + if k not in named_params.keys(): + raise ValueError( + f"""Dictionary params missing key '{k}'\n Provided keys : {named_params.keys()}\n Expected keys: {names}""" + ) + params = IsingBM( + weight_matrix=named_params.pop("weight_matrix"), + vbias=named_params.pop("vbias"), + hbias=named_params.pop("hbias"), + K1=named_params.pop("K1"), + K2=named_params.pop("K2") + ) + if len(named_params.keys()) > 0: + raise ValueError( + f"Too many keys in params dictionary. 
Remaining keys: {named_params.keys()}" + ) + return params + + def to( + self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None + ): + if device is not None: + self.device = device + if dtype is not None: + self.dtype = dtype + self.weight_matrix = self.weight_matrix.to(device=self.device, dtype=self.dtype) + self.vbias = self.vbias.to(device=self.device, dtype=self.dtype) + self.hbias = self.hbias.to(device=self.device, dtype=self.dtype) + return self + + def Z_i_mu_func(self, h, l): + Z_i_mu = 2*torch.cosh(l*h) + return Z_i_mu + + def compute_pseudolikelihood_matrix(self, matrix, data, l): + h = torch.einsum('ij,mj->mi', matrix*self.mask, data) + x_J_x = torch.einsum('mi,mi->mi', data, h) + energy_i_mu = -x_J_x + (1 / l) * torch.log(self.Z_i_mu_func(h,l)+1e-9) + PL = energy_i_mu.mean() + return PL + + ''' + def compute_pseudolikelihood_K1(self, data, l): + h = torch.einsum('ij,mj->mi', self.K_matrix*self.mask, data) + x_J_x = torch.einsum('mi,mi->mi', data, h) + energy_i_mu = -x_J_x + (1 / l) * torch.log(self.Z_i_mu_func(h,l)) + PL = energy_i_mu.mean() + return PL + ''' + def compute_gradient_PL1(self, data): + pass + + ''' + def compute_gradient_PL2(self, data, l): + x = data["visible"] + h = torch.einsum("ij,mj->mi",self.K2*self.mask, x) + print("h", h.mean()) + + #h_a = h.unsqueeze(2)+h.unsqueeze(1) + #h_b = h.unsqueeze(2)-h.unsqueeze(1) + #i_term = torch.einsum("ij,mi->mij", self.K2, x) + #j_term = torch.einsum("ij,mj->mij", self.K2, x) + # + #h_a = h_a - i_term - j_term + #h_b = h_b - i_term + j_term + # + #grad_term_1 = torch.exp(self.K2)*torch.cosh(h_a) + #grad_term_2 = torch.exp(-self.K2)*torch.cosh(h_b) + # + #print("grad1", grad_term_1.mean()) + #print("grad2", grad_term_2.mean()) + + #grad_K2 = ((grad_term_1-grad_term_2)/(grad_term_1+grad_term_2+1e-9)).mean(0) + + + i_term = torch.einsum("ij,mi->mij", self.K2, x) + j_term = torch.einsum("ij,mj->mij", self.K2, x) + h_i_eff = h.unsqueeze(2)-i_term + h_j_eff = h.unsqueeze(1)-j_term + + data_corr = (x.unsqueeze(2) * x.unsqueeze(1)).mean(dim=0) + + grad_K2 = data_corr - torch.tanh(l*self.K2+0.5*torch.log(torch.cosh(h_i_eff+h_j_eff)+1e-9)-torch.log(torch.cosh(h_i_eff-h_j_eff)+1e-9)).mean(0) + + grad_K2 = (grad_K2+grad_K2.T)/2 + + print(grad_K2.mean()) + + grad_K2.fill_diagonal_(0.0) + + + self.K2.grad.set_(grad_K2) + ''' + + def compute_loss_PL1(self, data, l, use_fields, use_hfield=False): + # [M, N] + x=data + J_x = torch.einsum('ij,mj->mi', self.K1 * self.mask.to(self.K1.device), x) # [M, d] + y_i_mu = torch.absolute(J_x) # Taking the norm over the last dimension -> [M,N] + x_J_x = torch.einsum('mi,mi->mi', x, J_x) # [M, N] + Z_i_mu = 2*torch.cosh(l*y_i_mu) + # Compute the energy term for each mu: - dot_product + lam^-1 * log(Z_i_mu) + # Compute the energy term for each mu: - dot_product + lam^-1 * log(Z_i_mu) + e_i = -x_J_x + (1 / l) * torch.log(Z_i_mu+1e-9) # [M,N] + + return e_i.mean() + + def compute_loss_PL2(self, data, l, use_fields, use_hfield=False): + x = data#["visible"] + h = torch.einsum("ik,mk->mi",self.K2*self.mask, x) + diff_term = torch.einsum("ik,mk->mik", self.K2*self.mask, x) + #j_term = torch.einsum("ik,mk->mi", self.K2*self.mask, x) + if use_fields == True: + #fields_x = torch.einsum("i,mi->mi", self.vbias, x) + h = h+self.vbias.unsqueeze(0) + h_i_eff = h.unsqueeze(2)-diff_term #[M,N,1] + h_j_eff = h.unsqueeze(1)-diff_term #[M,1,N] + + J_xx = torch.einsum("mi,ij,mj->mij", x, self.K2*self.mask, x) + h_xx = h_i_eff*x.unsqueeze(2)+h_j_eff*x.unsqueeze(1) + Z_xx = 
2.*(torch.exp(l*self.K2*self.mask)*torch.cosh(l*h_i_eff+l*h_j_eff)+torch.exp(-l*self.K2*self.mask)*torch.cosh(l*h_i_eff-l*h_j_eff)) + + e_ij = -J_xx-h_xx+1./l*torch.log(Z_xx+(1-self.mask)+1e-9) + return e_ij.mean() + + def normalize_w(self): + with torch.no_grad(): + norm = torch.norm(self.weight_matrix.data) + self.weight_matrix.data = self.weight_matrix.data * self.w_norm_0 / (norm+1e-9) + + def normalize_K2(self): + with torch.no_grad(): + norm = torch.norm(self.K2.data) + self.K2.data = self.K2.data * self.K2_norm_0 / (norm+1e-9) + + def normalize_v(self): + with torch.no_grad(): + norm = torch.norm(self.vbias.data) + self.vbias.data = self.vbias.data * self.v_norm_0 / (norm+1e-9) + + + + \ No newline at end of file diff --git a/rbms/isingBM/implement.py b/rbms/isingBM/implement.py new file mode 100644 index 0000000..3c15afd --- /dev/null +++ b/rbms/isingBM/implement.py @@ -0,0 +1,151 @@ +from typing import Optional, Tuple + +import torch +from torch import Tensor +from torch.nn.functional import softmax + + +@torch.jit.script +def _sample_hiddens( + v: Tensor, weight_matrix: Tensor, hbias: Tensor, beta: float = 1.0 +) -> Tuple[Tensor, Tensor]: + dtype = torch.float32 + return v, v #nothing happens to visible after dummy hidden passage + + +@torch.jit.script +def _sample_visibles( + v: Tensor, h: Tensor, weight_matrix: Tensor, vbias: Tensor, beta: float = 1.0 +) -> Tuple[Tensor, Tensor]: + x = v.clone() + num_chains, num_units = x.shape + for i in range(num_units): #it has to be sequential + h_i = torch.einsum("j,bj->b",weight_matrix[i],x) #diagonal is already zero + probs = torch.sigmoid(beta*(h_i+vbias[i])) + x[:,i] = torch.bernoulli(probs)*2-1 + return x, x + + +@torch.jit.script +def _compute_energy( + v: Tensor, + vbias: Tensor, + weight_matrix: Tensor, +) -> Tensor: + x = v.clone() + fields = (x * vbias).sum(dim=1) + interaction = 0.5 * torch.einsum("bi,ij,bj->b", x, weight_matrix, x) + + return -fields - interaction + + +@torch.jit.script +def _compute_energy_visibles( + v: Tensor, vbias: Tensor, hbias: Tensor, weight_matrix: Tensor +) -> Tensor: + x = v.clone() + fields = (x * vbias).sum(dim=1) + interaction = 0.5 * torch.einsum("bi,ij,bj->b", x, weight_matrix, x) + + return -fields - interaction + + +@torch.jit.script +def _compute_energy_hiddens( + h: Tensor, vbias: Tensor, hbias: Tensor, weight_matrix: Tensor +) -> Tensor: + dtype = torch.float32 + return torch.tensor([0], device=weight_matrix.device, dtype=dtype) + + +@torch.jit.script +def _compute_gradient( + v_data: Tensor, + v_chain: Tensor, + weight_matrix: Tensor, + vbias: Tensor, + use_fields: bool, + centered: bool = False, +) -> None: + dtype = torch.float32 + + x_data = v_data + x_model = v_chain + + if centered: + pass + else: + pass + + # Empirical averages + data_mean = x_data.mean(dim=0) # shape (num_units,) + data_corr = (x_data.unsqueeze(2) * x_data.unsqueeze(1)).mean(dim=0) # shape (num_units, num_units) + # Model (chain) averages + model_mean = x_model.mean(dim=0) + model_corr = (x_model.unsqueeze(2) * x_model.unsqueeze(1)).mean(dim=0) + # Bias gradient + grad_weight_matrix = data_corr - model_corr + grad_weight_matrix.fill_diagonal_(0.0) + if use_fields==True: + grad_vbias = data_mean - model_mean + else: + grad_vbias = torch.tensor([0], device=weight_matrix.device, dtype=dtype) + #grad_hbias = torch.tensor([0], device=weight_matrix.device, dtype=weight_matrix.type) + + # Attach to the parameters + weight_matrix.grad.set_(grad_weight_matrix) + vbias.grad.set_(grad_vbias) + #hbias.grad.set_(grad_hbias) + 
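# ---------------------------------------------------------------------------
# Editor's illustrative sketch (not part of the change itself): the first-order
# pseudolikelihood energy used by IsingBM.compute_loss_PL1 in rbms/isingBM/bm.py
# can be checked against the exact single-spin conditional.  For x_i in {-1,+1}
# with local field h_i = sum_j K_ij x_j (zero diagonal), the conditional is
#     P(x_i | x_rest) = exp(l * x_i * h_i) / (2 * cosh(l * h_i)),
# so  -log P(x_i | x_rest) / l = -x_i * h_i + (1/l) * log(2 * cosh(l * h_i)),
# which is exactly e_i in compute_loss_PL1 (the code takes |h_i| first, which is
# equivalent because cosh is even).  A minimal numerical check, assuming a small
# random symmetric coupling matrix K with zero diagonal:
#
#     import torch
#     torch.manual_seed(0)
#     N, l = 5, 1.0
#     K = torch.randn(N, N); K = 0.5 * (K + K.T); K.fill_diagonal_(0.0)
#     x = torch.randint(0, 2, (1, N)).float() * 2 - 1    # +-1 spins, shape [1, N]
#     h = x @ K.T                                        # local fields sum_j K_ij x_j
#     e_i = -x * h + torch.log(2 * torch.cosh(l * h)) / l
#     cond = torch.exp(l * x * h) / (2 * torch.cosh(l * h))
#     assert torch.allclose(e_i, -torch.log(cond) / l)
# ---------------------------------------------------------------------------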
+@torch.jit.script +def _init_chains( + num_samples: int, + weight_matrix: Tensor, + hbias: Tensor, + start_v: Optional[Tensor] = None, +): + num_visibles = weight_matrix.shape[0] + device = weight_matrix.device + dtype = torch.float32 + # Handle negative number of samples + if num_samples <= 0: + if start_v is not None: + num_samples = start_v.shape[0] + else: + raise ValueError(f"Got negative num_samples arg: {num_samples}") + + if start_v is None: + # Dummy mean visible + v = torch.randint(low=0, high=2, size=(num_samples, num_visibles), device=device, dtype=torch.float32)*2-1 + else: + v = start_v.to(device=device, dtype=dtype) + return v + + +def _init_parameters( + data: Tensor, + device: torch.device, + dtype: torch.dtype, + var_init: float = 1e-2, + beta: float=1. +) -> Tuple[Tensor, Tensor, Tensor]: + _, num_visibles = data.shape + eps = 1e-4 + weight_matrix = ( + torch.randn(size=(num_visibles, num_visibles), device=device, dtype=dtype) + * var_init + ) + weight_matrix.fill_diagonal_(0.0) + weight_matrix = 0.5*(weight_matrix+weight_matrix.T) + ''' + frequencies = data.mean(0) + frequencies = torch.clamp(frequencies, min=eps, max=(1.0 - eps)) + vbias = (torch.log(frequencies) - torch.log(1.0 - frequencies)).to( + device=device, dtype=dtype + ) + hbias = torch.zeros(num_hiddens, device=device, dtype=dtype) + ''' + spin_means = data.mean(0) + spin_means = torch.clamp(spin_means,-0.95,0.95) + vbias = 1/beta*torch.atanh(spin_means) + return weight_matrix, vbias diff --git a/rbms/map_model.py b/rbms/map_model.py index b556d85..4ef3db7 100644 --- a/rbms/map_model.py +++ b/rbms/map_model.py @@ -1,5 +1,6 @@ from rbms.bernoulli_bernoulli.classes import BBRBM +from rbms.ising.ising_rbm import IsingRBM from rbms.classes import RBM from rbms.potts_bernoulli.classes import PBRBM -map_model: dict[str, RBM] = {"BBRBM": BBRBM, "PBRBM": PBRBM} +map_model: dict[str, RBM] = {"BBRBM": BBRBM, "PBRBM": PBRBM, "IsingRBM":IsingRBM} diff --git a/rbms/metrics/aats.py b/rbms/metrics/aats.py index 5e5f7d1..5573688 100644 --- a/rbms/metrics/aats.py +++ b/rbms/metrics/aats.py @@ -50,8 +50,10 @@ def compute_aats( closest = distance_matrix.argmin(axis=1) n = int(closest.shape[0] / 2) - # for a true sample, proba that the closest is in the set of true samples - aa_truth = (closest[:n] >= n).sum() / n + # for a true sample, proba t + + + .sum() / n # for a fake sample, proba that the closest is in the set of fake samples aa_syn = (closest[n:] >= n).sum() / n diff --git a/rbms/pseudolikelihood/__init__.py b/rbms/pseudolikelihood/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/rbms/pseudolikelihood/train_PL.py b/rbms/pseudolikelihood/train_PL.py new file mode 100644 index 0000000..35448eb --- /dev/null +++ b/rbms/pseudolikelihood/train_PL.py @@ -0,0 +1,338 @@ +import time +from typing import Tuple + +import numpy as np +import torch +from torch import Tensor +from torch.optim import SGD, Adam, AdamW +from torch.utils.data import Subset + +from rbms.classes import RBM +from rbms.bernoulli_bernoulli_BM.bm import BBBM +from rbms.ising.ising_rbm import IsingRBM +from rbms.dataset.dataset_class import RBMDataset +from rbms.io import save_model +from rbms.map_model import map_model +from rbms.potts_bernoulli.classes import PBRBM +from rbms.potts_bernoulli.utils import ensure_zero_sum_gauge +from rbms.sampling.gibbs import sample_state +from rbms.training.utils import create_machine, setup_training +from rbms.utils import check_file_existence, log_to_csv + + +def step_PL2( + batch: Tuple[Tensor, 
Tensor], + params: RBM, + l: float, +)-> dict: + v_data, w_data = batch + curr_batch = params.init_chains( + num_samples=v_data.shape[0], + weights=w_data, + start_v=v_data, + ) + params.compute_gradient_PL2(data=curr_batch, l=l) + logs = {} + return logs + + +def train_PL1( + dataset: RBMDataset, + test_dataset: RBMDataset, + model_type: str, + args: dict, + dtype: torch.dtype, + checkpoints: np.ndarray, + map_model: dict[str, RBM] = map_model, +) -> None: + """Train the Bernoulli-Bernoulli RBM model. + + Args: + dataset (RBMDataset): The training dataset. + test_dataset (RBMDataset): The test dataset (not used). + model_type (str): Type of RBM used (BBRBM or PBRBM) + args (dict): A dictionary of training arguments. + dtype (torch.dtype): The data type for the parameters. + checkpoints (np.ndarray): An array of checkpoints for saving model states. + """ + filename = args["filename"] + if not (args["overwrite"]): + check_file_existence(filename) + + if args["gibbs_steps_init"]: #MODIFICATION DONE DUE TO BM SLOW DYNAMICS + gibbs_steps_init = args["gibbs_steps_init"] + else: + gibbs_steps_init = 1000 + + num_visibles = dataset.get_num_visibles() + + # Create a first archive with the initialized model + if not (args["restore"]): + params = map_model[model_type].init_parameters( + num_hiddens=args["num_hiddens"], + dataset=dataset, + device=args["device"], + dtype=dtype, + beta=args["beta"], + use_fields=args["use_fields"] + ) + create_machine( + filename=filename, + params=params, + num_visibles=num_visibles, + num_hiddens=args["num_hiddens"], + num_chains=args["num_chains"], + batch_size=args["batch_size"], + gibbs_steps=args["gibbs_steps"], + learning_rate=args["learning_rate"], + log=args["log"], + flags=["checkpoint"], + gibbs_steps_init=gibbs_steps_init + ) + + ( + params, + parallel_chains, + args, + learning_rate, + num_updates, + start, + elapsed_time, + log_filename, + pbar, + ) = setup_training(args, map_model=map_model) + + use_fields = args["use_fields"] + use_hfield = args["use_hfield"] + normalize_fields = args["normalize_fields"] + normalize_K1 = args["normalize_K1"] + if use_fields==False: + params.vbias = torch.zeros_like(params.vbias) + + params.hbias = torch.zeros_like(params.hbias) + + params.K1 = torch.nn.Parameter(params.K1) ############### + params.vbias = torch.nn.Parameter(params.vbias) + params.hbias = torch.nn.Parameter(params.hbias) + + + # for p in params.parameters(): + # p.grad = torch.zeros_like(p) + + optimizer = AdamW(params.parameters(), lr=learning_rate)#, weight_decay=0.001)#, maximize=True) + lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args["lr_arr"], gamma=args["lr_factor"]) + + for k, v in args.items(): + print(f"{k} : {v}") + + logs = {} + + # Continue the training + #with torch.no_grad(): + + max_norm=50 + + for idx in range(num_updates + 1, args["num_updates"] + 1): + rand_idx = torch.randperm(len(dataset))[: args["batch_size"]] + batch = (dataset.data[rand_idx], dataset.weights[rand_idx]) + #optimizer.zero_grad(set_to_none=False) + #logs = step_PL2(batch, params, l=args["lambda"]) + + + loss = params.compute_loss_PL1(batch[0], l=args["lambda"], use_fields=use_fields, use_hfield=use_hfield) + + + if (args["verbose"]==True) and (idx%10 == 1): + print("Update: ", idx, "Loss:", loss.item(), "lr:", lr_scheduler.get_last_lr(), "J_norm:", torch.norm(params.K1).item(), "v_norm:", torch.norm(params.vbias).item(), "h_norm:", torch.norm(params.hbias).item()) + + if (torch.isnan(loss).any() == False) and (torch.isinf(loss).any() == 
+            optimizer.zero_grad()
+            loss.backward()
+            torch.nn.utils.clip_grad_norm_(params.parameters(), max_norm)
+            optimizer.step()
+        # else:
+        #     for param_group in optimizer.param_groups:
+        #         param_group['lr'] *= 0.1
+        #     lr_scheduler.base_lrs = [group['lr'] for group in optimizer.param_groups]
+        lr_scheduler.step()
+
+        if normalize_K1:
+            params.normalize_K1()
+        if normalize_fields:
+            params.normalize_v()
+
+        # if isinstance(params, BM):
+        #     with torch.no_grad():
+        #         params.K2.data.fill_diagonal_(0.)
+        if isinstance(params, PBRBM):
+            ensure_zero_sum_gauge(params)
+
+        # Save the current model if necessary
+        if idx in checkpoints:
+            curr_time = time.time() - start
+            save_model(
+                filename=args["filename"],
+                params=params,
+                chains=parallel_chains,
+                num_updates=idx,
+                time=curr_time + elapsed_time,
+                flags=["checkpoint"],
+            )
+            if args["log"]:
+                log_to_csv(logs, log_file=log_filename)
+        # Update the progress bar
+        pbar.update(1)
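For orientation between the two trainers, the sketch below lists the hyperparameter keys that train_PL1 (and setup_training) reads above. It is illustrative only: every key is referenced in the code, but the concrete values are placeholders chosen for the example, not defaults taken from this patch, and the dictionary is not necessarily exhaustive.

# Illustrative hyperparameter dictionary for train_PL1 (values are placeholders).
args = {
    "filename": "BBRBM_PL1.h5",      # output HDF5 archive
    "overwrite": True,
    "restore": False,
    "num_hiddens": 100,
    "num_chains": 2000,
    "batch_size": 500,
    "gibbs_steps": 100,
    "gibbs_steps_init": 1000,        # initial chain equilibration (slow BM dynamics)
    "learning_rate": 1e-3,
    "num_updates": 10_000,
    "lr_arr": [5_000, 8_000],        # MultiStepLR milestones
    "lr_factor": 0.1,                # MultiStepLR gamma
    "beta": 1.0,
    "lambda": 1.0,                   # scale l of the pseudo-likelihood loss
    "use_fields": True,
    "use_hfield": True,
    "normalize_fields": False,
    "normalize_K1": False,           # train_PL2 reads "normalize_K2" instead
    "start_as": None,                # or "K1"/"K2" to warm-start from "init_model"
    "init_model": None,
    "device": "cpu",
    "log": False,
    "verbose": True,
}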
+
+
+def train_PL2(
+    dataset: RBMDataset,
+    test_dataset: RBMDataset,
+    model_type: str,
+    args: dict,
+    dtype: torch.dtype,
+    checkpoints: np.ndarray,
+    map_model: dict[str, RBM] = map_model,
+) -> None:
+    """Train an RBM by minimizing the second-order pseudo-likelihood loss (PL2).
+
+    Args:
+        dataset (RBMDataset): The training dataset.
+        test_dataset (RBMDataset): The test dataset (not used).
+        model_type (str): Type of RBM used (BBRBM or PBRBM).
+        args (dict): A dictionary of training arguments.
+        dtype (torch.dtype): The data type for the parameters.
+        checkpoints (np.ndarray): An array of checkpoints for saving model states.
+    """
+    filename = args["filename"]
+    if not args["overwrite"]:
+        check_file_existence(filename)
+
+    if args["gibbs_steps_init"]:  # modification: configurable because of slow BM dynamics
+        gibbs_steps_init = args["gibbs_steps_init"]
+    else:
+        gibbs_steps_init = 1000
+
+    num_visibles = dataset.get_num_visibles()
+
+    # Create a first archive with the initialized model
+    if not args["restore"]:
+        params = map_model[model_type].init_parameters(
+            num_hiddens=args["num_hiddens"],
+            dataset=dataset,
+            device=args["device"],
+            dtype=dtype,
+            beta=args["beta"],
+            use_fields=args["use_fields"],
+        )
+        create_machine(
+            filename=filename,
+            params=params,
+            num_visibles=num_visibles,
+            num_hiddens=args["num_hiddens"],
+            num_chains=args["num_chains"],
+            batch_size=args["batch_size"],
+            gibbs_steps=args["gibbs_steps"],
+            learning_rate=args["learning_rate"],
+            log=args["log"],
+            flags=["checkpoint"],
+            gibbs_steps_init=gibbs_steps_init,
+        )
+
+    (
+        params,
+        parallel_chains,
+        args,
+        learning_rate,
+        num_updates,
+        start,
+        elapsed_time,
+        log_filename,
+        pbar,
+    ) = setup_training(args, map_model=map_model)
+
+    use_fields = args["use_fields"]
+    use_hfield = args["use_hfield"]
+    normalize_fields = args["normalize_fields"]
+    normalize_K2 = args["normalize_K2"]
+    if not use_fields:
+        params.vbias = torch.zeros_like(params.vbias)
+        params.hbias = torch.zeros_like(params.hbias)
+
+    # Expose K2 and the biases to the optimizer
+    params.K2 = torch.nn.Parameter(params.K2)
+    params.vbias = torch.nn.Parameter(params.vbias)
+    params.hbias = torch.nn.Parameter(params.hbias)
+
+    # for p in params.parameters():
+    #     p.grad = torch.zeros_like(p)
+
+    optimizer = Adam(params.parameters(), lr=learning_rate)  # optionally: weight_decay=0.001, maximize=True
+    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args["lr_arr"], gamma=args["lr_factor"])
+
+    for k, v in args.items():
+        print(f"{k} : {v}")
+
+    logs = {}
+
+    # Continue the training
+    # with torch.no_grad():
+    max_norm = 50  # gradient-clipping threshold
+
+    for idx in range(num_updates + 1, args["num_updates"] + 1):
+        rand_idx = torch.randperm(len(dataset))[: args["batch_size"]]
+        batch = (dataset.data[rand_idx], dataset.weights[rand_idx])
+        # optimizer.zero_grad(set_to_none=False)
+        # logs = step_PL2(batch, params, l=args["lambda"])
+
+        loss = params.compute_loss_PL2(batch[0], l=args["lambda"], use_fields=use_fields, use_hfield=use_hfield)
+
+        if args["verbose"] and (idx % 10 == 1):
+            print("Update:", idx, "Loss:", loss.item(), "lr:", lr_scheduler.get_last_lr(), "J_norm:", torch.norm(params.K2).item(), "v_norm:", torch.norm(params.vbias).item(), "h_norm:", torch.norm(params.hbias).item())
+
+        if torch.isfinite(loss):
+            optimizer.zero_grad()
+            loss.backward()
+            torch.nn.utils.clip_grad_norm_(params.parameters(), max_norm)
+            optimizer.step()
+        # else:
+        #     for param_group in optimizer.param_groups:
+        #         param_group['lr'] *= 0.1
+        #     lr_scheduler.base_lrs = [group['lr'] for group in optimizer.param_groups]
+        lr_scheduler.step()
+
+        if normalize_K2:
+            params.normalize_K2()
+        if normalize_fields:
+            params.normalize_v()
+
+        # if isinstance(params, BM):
+        #     with torch.no_grad():
+        #         params.K2.data.fill_diagonal_(0.)
+        if isinstance(params, PBRBM):
+            ensure_zero_sum_gauge(params)
+
+        # Save the current model if necessary
+        if idx in checkpoints:
+            curr_time = time.time() - start
+            save_model(
+                filename=args["filename"],
+                params=params,
+                chains=parallel_chains,
+                num_updates=idx,
+                time=curr_time + elapsed_time,
+                flags=["checkpoint"],
+            )
+            if args["log"]:
+                log_to_csv(logs, log_file=log_filename)
+        # Update the progress bar
+        pbar.update(1)
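Both training loops above share the same optimization skeleton: compute the PL loss, skip the step if it is not finite, clip the gradient norm, and step the optimizer. A minimal sketch of that shared step follows, assuming params is a model whose K1 (or K2) and biases have already been wrapped in torch.nn.Parameter as done above; the helper name pl_update and its argument names are illustrative and not part of the patch.

import torch

def pl_update(params, visible_batch, pl_loss, optimizer, lam, max_norm=50.0,
              use_fields=True, use_hfield=True):
    # pl_loss is either params.compute_loss_PL1 or params.compute_loss_PL2
    loss = pl_loss(visible_batch, l=lam, use_fields=use_fields, use_hfield=use_hfield)
    if torch.isfinite(loss):  # skip the update on NaN/inf losses, as in the loops above
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(params.parameters(), max_norm)
        optimizer.step()
    return loss.item()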
diff --git a/rbms/training/pcd.py b/rbms/training/pcd.py
index bac2ba7..a4011da 100644
--- a/rbms/training/pcd.py
+++ b/rbms/training/pcd.py
@@ -5,6 +5,8 @@
 import torch
 from torch import Tensor
 from torch.optim import SGD
+from torch.utils.data import Subset
+import copy
 
 from rbms.classes import RBM
 from rbms.dataset.dataset_class import RBMDataset
@@ -23,6 +25,7 @@ def fit_batch_pcd(
     params: RBM,
     gibbs_steps: int,
     beta: float,
+    use_fields: bool = False,
     centered: bool = True,
 ) -> Tuple[dict[str, Tensor], dict]:
     """Sample the RBM and compute the gradient.
@@ -51,7 +54,7 @@ def fit_batch_pcd(
         params=params,
         beta=beta,
     )
-    params.compute_gradient(data=curr_batch, chains=parallel_chains, centered=centered)
+    params.compute_gradient(data=curr_batch, chains=parallel_chains, centered=centered, use_fields=use_fields)
     logs = {}
     return parallel_chains, logs
@@ -78,6 +81,11 @@ def train(
     filename = args["filename"]
     if not (args["overwrite"]):
         check_file_existence(filename)
+
+    if args["gibbs_steps_init"]:  # modification: configurable because of slow BM dynamics
+        gibbs_steps_init = args["gibbs_steps_init"]
+    else:
+        gibbs_steps_init = 1000
 
     num_visibles = dataset.get_num_visibles()
@@ -88,6 +96,8 @@ def train(
             dataset=dataset,
             device=args["device"],
             dtype=dtype,
+            beta=args["beta"],
+            use_fields=args["use_fields"],
         )
         create_machine(
             filename=filename,
@@ -100,8 +110,8 @@ def train(
             learning_rate=args["learning_rate"],
             log=args["log"],
             flags=["checkpoint"],
+            gibbs_steps_init=gibbs_steps_init,
         )
-
     (
         params,
         parallel_chains,
@@ -113,7 +123,8 @@ def train(
         log_filename,
         pbar,
     ) = setup_training(args, map_model=map_model)
-
+
+    optimizer = SGD(params.parameters(), lr=learning_rate, maximize=True)
 
     for k, v in args.items():
@@ -125,6 +136,10 @@ def train(
         rand_idx = torch.randperm(len(dataset))[: args["batch_size"]]
         batch = (dataset.data[rand_idx], dataset.weights[rand_idx])
+        if args["verbose"] and (idx % 10 == 1):
+            print("Update:", idx, "lr:", args["learning_rate"], "J_norm:", torch.norm(params.weight_matrix).item(), "v_norm:", torch.norm(params.vbias).item(), "h_norm:", torch.norm(params.hbias).item())
+
+        optimizer.zero_grad(set_to_none=False)
         parallel_chains, logs = fit_batch_pcd(
             batch=batch,
@@ -132,6 +147,7 @@ def train(
             params=params,
             gibbs_steps=args["gibbs_steps"],
             beta=args["beta"],
+            use_fields=args["use_fields"],
         )
         optimizer.step()
         if isinstance(params, PBRBM):
diff --git a/rbms/training/utils.py b/rbms/training/utils.py
index 2681526..f492f64 100644
--- a/rbms/training/utils.py
+++ b/rbms/training/utils.py
@@ -7,6 +7,7 @@
 import torch
 from torch import Tensor
 from tqdm import tqdm
+import copy
 
 from rbms.classes import RBM
 from rbms.const import LOG_FILE_HEADER
@@ -59,6 +60,15 @@ def setup_training(
         ascii="-#",
     )
     pbar.set_description("Training RBM")
+
+    with torch.no_grad():
+        if args["start_as"] == "K2":
+            params = copy.deepcopy(args["init_model"])
+            params.weight_matrix.data = params.K2.data.clone()
+        elif args["start_as"] == "K1":
+            params = copy.deepcopy(args["init_model"])
+            params.weight_matrix.data = params.K1.data.clone()
 
     # Initialize gradients for the parameters
     for p in params.parameters():
@@ -91,6 +101,7 @@ def create_machine(
     learning_rate: float,
     log: bool,
     flags: List[str],
+    gibbs_steps_init: int = 1000,
 ) -> None:
     """Create a RBM and save it to a new file.
@@ -108,7 +119,7 @@ def create_machine(
     # Permanent chains
     parallel_chains = params.init_chains(num_samples=num_chains)
     parallel_chains = sample_state(
-        gibbs_steps=1000, chains=parallel_chains, params=params
+        gibbs_steps=gibbs_steps_init, chains=parallel_chains, params=params
     )
     with h5py.File(filename, "w") as file_model:
         hyperparameters = file_model.create_group("hyperparameters")
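The start_as / init_model branch added to setup_training above implies the following warm-start pattern for a subsequent PCD run. The snippet is a sketch under the assumption that pl_params is a model produced by one of the pseudo-likelihood trainers; the variable names are illustrative and not part of the patch.

# Illustrative warm start: reuse PL1-trained couplings as the PCD weight matrix.
args["start_as"] = "K1"          # or "K2" for the second-order couplings
args["init_model"] = pl_params   # model returned or reloaded from a previous PL run
# setup_training then deep-copies init_model and sets
# params.weight_matrix.data = params.K1.data.clone() before the PCD loop starts.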
diff --git a/rbms/utils.py b/rbms/utils.py
index 0bdd6ef..9373376 100644
--- a/rbms/utils.py
+++ b/rbms/utils.py
@@ -42,6 +42,66 @@ def get_eigenvalues_history(filename: str):
     return gradient_updates, eigenvalues
 
 
+def get_eigenvalues_history_PL1(filename: str):
+    """Extract the history of singular values of the K1 coupling matrix.
+
+    Args:
+        filename (str): Path to the HDF5 training archive.
+
+    Returns:
+        tuple: A tuple containing two elements:
+            - gradient_updates (np.ndarray): Array of gradient update steps.
+            - eigenvalues (np.ndarray): Singular values of K1 along training.
+    """
+    with h5py.File(filename, "r") as f:
+        gradient_updates = []
+        eigenvalues = []
+        for key in f.keys():
+            if "update" in key:
+                K1 = f[key]["params"]["K1"][()]
+                K1 = K1.reshape(-1, K1.shape[-1])
+                eig = np.linalg.svd(K1, compute_uv=False)
+                eigenvalues.append(eig.reshape(*eig.shape, 1))
+                gradient_updates.append(int(key.split("_")[1]))
+
+    # Sort the results
+    sorting = np.argsort(gradient_updates)
+    gradient_updates = np.array(gradient_updates)[sorting]
+    eigenvalues = np.array(np.hstack(eigenvalues).T)[sorting]
+
+    return gradient_updates, eigenvalues
+
+
+def get_eigenvalues_history_PL2(filename: str):
+    """Extract the history of singular values of the K2 coupling matrix.
+
+    Args:
+        filename (str): Path to the HDF5 training archive.
+
+    Returns:
+        tuple: A tuple containing two elements:
+            - gradient_updates (np.ndarray): Array of gradient update steps.
+            - eigenvalues (np.ndarray): Singular values of K2 along training.
+    """
+    with h5py.File(filename, "r") as f:
+        gradient_updates = []
+        eigenvalues = []
+        for key in f.keys():
+            if "update" in key:
+                K2 = f[key]["params"]["K2"][()]
+                K2 = K2.reshape(-1, K2.shape[-1])
+                eig = np.linalg.svd(K2, compute_uv=False)
+                eigenvalues.append(eig.reshape(*eig.shape, 1))
+                gradient_updates.append(int(key.split("_")[1]))
+
+    # Sort the results
+    sorting = np.argsort(gradient_updates)
+    gradient_updates = np.array(gradient_updates)[sorting]
+    eigenvalues = np.array(np.hstack(eigenvalues).T)[sorting]
+
+    return gradient_updates, eigenvalues
+
 
 def get_saved_updates(filename: str) -> np.ndarray:
     """