diff --git a/rbms/ising_gaussian/__init__.py b/rbms/ising_gaussian/__init__.py
new file mode 100644
index 0000000..b20d9a5
--- /dev/null
+++ b/rbms/ising_gaussian/__init__.py
@@ -0,0 +1,3 @@
+# ruff: noqa
+from rbms.ising_gaussian.classes import IGRBM
+from rbms.ising_gaussian.functional import *
\ No newline at end of file
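
Note: the model implemented by the files below (as encoded in `_compute_energy` in implement.py) has ±1 visibles, Gaussian hiddens with fixed variance 1/Nv, and energy

    E(v, h) = -v·vbias - h·hbias - v^T W h + (Nv / 2) * ||h||^2

so p(h|v) is Gaussian with mean (hbias + v^T W)/Nv and variance 1/Nv, and p(v_i = +1 | h) = sigmoid(2 * (vbias + W h)_i), i.e. magnetization tanh(vbias + W h). The bias gradients are deliberately zeroed: only the couplings W are trained.
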
diff --git a/rbms/ising_gaussian/classes.py b/rbms/ising_gaussian/classes.py
new file mode 100644
index 0000000..3ba5042
--- /dev/null
+++ b/rbms/ising_gaussian/classes.py
@@ -0,0 +1,239 @@
+from typing import List, Optional
+
+import numpy as np
+import torch
+from torch import Tensor
+
+from rbms.ising_gaussian.implement import (
+    _compute_energy,
+    _compute_energy_hiddens,
+    _compute_energy_visibles,
+    _compute_gradient,
+    _init_chains,
+    _init_parameters,
+    _sample_hiddens,
+    _sample_visibles,
+)
+from rbms.classes import RBM
+
+
+class IGRBM(RBM):
+    """Ising-Gaussian RBM with ±1 visibles and fixed hidden variance 1/Nv; the biases are kept fixed (not trained)."""
+
+    def __init__(
+        self,
+        weight_matrix: Tensor,
+        vbias: Tensor,
+        hbias: Tensor,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        if device is None:
+            device = weight_matrix.device
+        if dtype is None:
+            dtype = weight_matrix.dtype
+        self.device, self.dtype = device, dtype
+
+        self.weight_matrix = weight_matrix.to(device=self.device, dtype=self.dtype)
+        self.vbias = vbias.to(device=self.device, dtype=self.dtype)
+        self.hbias = hbias.to(device=self.device, dtype=self.dtype)
+
+        # Log-normalization of the Nh Gaussian hiddens with variance 1/Nv:
+        # const = 0.5 * Nh * (log(Nv) - log(2 * pi)).
+        num_visibles, num_hiddens = self.weight_matrix.shape
+        log_two_pi = torch.log(torch.tensor(2.0 * torch.pi, dtype=dtype, device=device))
+        self.const = (
+            0.5
+            * float(num_hiddens)
+            * (
+                torch.log(torch.tensor(float(num_visibles), dtype=dtype, device=device))
+                - log_two_pi
+            )
+        )
+        self.name = "IGRBM"
+
+    def __add__(self, other):
+        out = IGRBM(
+            weight_matrix=self.weight_matrix + other.weight_matrix,
+            vbias=self.vbias + other.vbias,
+            hbias=self.hbias + other.hbias,
+            device=self.device,
+            dtype=self.dtype,
+        )
+        return out
+
+    def __mul__(self, other):
+        out = IGRBM(
+            weight_matrix=self.weight_matrix * other,
+            vbias=self.vbias * other,
+            hbias=self.hbias * other,
+            device=self.device,
+            dtype=self.dtype,
+        )
+        return out
+
+    def clone(
+        self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
+    ):
+        if device is None:
+            device = self.device
+        if dtype is None:
+            dtype = self.dtype
+        return IGRBM(
+            weight_matrix=self.weight_matrix.clone(),
+            vbias=self.vbias.clone(),
+            hbias=self.hbias.clone(),
+            device=device,
+            dtype=dtype,
+        )
+
+    def compute_energy(self, v: Tensor, h: Tensor) -> Tensor:
+        return _compute_energy(
+            v=v, h=h, vbias=self.vbias, hbias=self.hbias, weight_matrix=self.weight_matrix
+        )
+
+    def compute_energy_hiddens(self, h: Tensor) -> Tensor:
+        return _compute_energy_hiddens(
+            h=h, vbias=self.vbias, hbias=self.hbias, weight_matrix=self.weight_matrix
+        )
+
+    def compute_energy_visibles(self, v: Tensor) -> Tensor:
+        return _compute_energy_visibles(
+            v=v,
+            vbias=self.vbias,
+            hbias=self.hbias,
+            weight_matrix=self.weight_matrix,
+            const=self.const,
+        )
+
+    def compute_gradient(self, data, chains, centered=True, lambda_l1=0.0, lambda_l2=0.0):
+        _compute_gradient(
+            v_data=data["visible"],
+            h_data=data["hidden_mag"],
+            w_data=data["weights"],
+            v_chain=chains["visible"],
+            h_chain=chains["hidden_mag"],
+            w_chain=chains["weights"],
+            vbias=self.vbias,
+            hbias=self.hbias,
+            weight_matrix=self.weight_matrix,
+            centered=centered,
+            lambda_l1=lambda_l1,
+            lambda_l2=lambda_l2,
+        )
+
+    def independent_model(self):
+        return IGRBM(
+            weight_matrix=torch.zeros_like(self.weight_matrix),
+            vbias=self.vbias,
+            hbias=torch.zeros_like(self.hbias),
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def init_chains(self, num_samples, weights=None, start_v=None):
+        visible, hidden, mean_visible, mean_hidden = _init_chains(
+            num_samples=num_samples,
+            weight_matrix=self.weight_matrix,
+            hbias=self.hbias,
+            start_v=start_v,
+        )
+        if weights is None:
+            weights = torch.ones(
+                visible.shape[0], device=visible.device, dtype=visible.dtype
+            )
+        return dict(
+            visible=visible,
+            hidden=hidden,
+            visible_mag=mean_visible,
+            hidden_mag=mean_hidden,
+            weights=weights,
+        )
+
+    @staticmethod
+    def init_parameters(num_hiddens, dataset, device, dtype, var_init=0.0001):
+        data = dataset.data
+        if isinstance(data, np.ndarray):
+            data = torch.from_numpy(data).to(device=device, dtype=dtype)
+        vbias, hbias, weight_matrix = _init_parameters(
+            num_hiddens=num_hiddens,
+            data=data,
+            device=device,
+            dtype=dtype,
+            var_init=var_init,
+        )
+        return IGRBM(
+            weight_matrix=weight_matrix,
+            vbias=vbias,
+            hbias=hbias,
+            device=device,
+            dtype=dtype,
+        )
+
+    def named_parameters(self):
+        return {
+            "weight_matrix": self.weight_matrix,
+            "vbias": self.vbias,
+            "hbias": self.hbias,
+        }
+
+    def num_hiddens(self):
+        return self.hbias.shape[0]
+
+    def num_visibles(self):
+        return self.vbias.shape[0]
+
+    def parameters(self) -> List[Tensor]:
+        return [self.weight_matrix, self.vbias, self.hbias]
+
+    def ref_log_z(self):
+        # log Z of the independent model (weight_matrix = 0). For ±1 visibles the
+        # single-site sum is 2*cosh(vbias), and log(2*cosh(x)) = |x| + log1p(exp(-2|x|)).
+        K = self.num_hiddens()
+        abs_vbias = torch.abs(self.vbias)
+        logZ_v = (abs_vbias + torch.log1p(torch.exp(-2.0 * abs_vbias))).sum()
+        quad = 0.5 * torch.dot(self.hbias, self.hbias) / float(self.num_visibles())
+        log_norm = 0.5 * K * np.log(2.0 * np.pi) - 0.5 * K * np.log(float(self.num_visibles()))
+        return (logZ_v + quad + log_norm).item()
+
+    def sample_hiddens(self, chains: dict[str, Tensor], beta: float = 1.0) -> dict[str, Tensor]:
+        chains["hidden"], chains["hidden_mag"] = _sample_hiddens(
+            v=chains["visible"],
+            weight_matrix=self.weight_matrix,
+            hbias=self.hbias,
+            beta=beta,
+        )
+        return chains
+
+    def sample_visibles(self, chains: dict[str, Tensor], beta: float = 1.0) -> dict[str, Tensor]:
+        chains["visible"], chains["visible_mag"] = _sample_visibles(
+            h=chains["hidden"],
+            weight_matrix=self.weight_matrix,
+            vbias=self.vbias,
+            beta=beta,
+        )
+        return chains
+
+    @staticmethod
+    def set_named_parameters(named_params: dict[str, Tensor]) -> "IGRBM":
+        names = ["vbias", "hbias", "weight_matrix"]
+        for k in names:
+            if k not in named_params:
+                raise ValueError(
+                    f"Dictionary params missing key '{k}'.\n"
+                    f"Provided keys: {list(named_params.keys())}\n"
+                    f"Expected keys: {names}"
+                )
+        params = IGRBM(
+            weight_matrix=named_params.pop("weight_matrix"),
+            vbias=named_params.pop("vbias"),
+            hbias=named_params.pop("hbias"),
+        )
+        if len(named_params) > 0:
+            raise ValueError(
+                f"Too many keys in params dictionary. Remaining keys: {list(named_params.keys())}"
+            )
+        return params
+
+    def to(
+        self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
+    ) -> "IGRBM":
+        if device is not None:
+            self.device = device
+        if dtype is not None:
+            self.dtype = dtype
+        self.weight_matrix = self.weight_matrix.to(device=self.device, dtype=self.dtype)
+        self.vbias = self.vbias.to(device=self.device, dtype=self.dtype)
+        self.hbias = self.hbias.to(device=self.device, dtype=self.dtype)
+        return self
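
A quick smoke test of the class above (a minimal sketch; the tensor shapes and the number of Gibbs sweeps are arbitrary):

import torch

from rbms.ising_gaussian.classes import IGRBM

num_visibles, num_hiddens = 20, 5
model = IGRBM(
    weight_matrix=0.01 * torch.randn(num_visibles, num_hiddens),
    vbias=torch.zeros(num_visibles),
    hbias=torch.zeros(num_hiddens),
)

chains = model.init_chains(num_samples=128)
for _ in range(10):  # a few alternating Gibbs sweeps
    chains = model.sample_hiddens(chains)
    chains = model.sample_visibles(chains)

print(model.compute_energy_visibles(chains["visible"]).shape)  # torch.Size([128])
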
Remaining keys: {named_params.keys()}" + ) + return params + + def to( + self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None + ) -> "IGRBM": + if device is not None: + self.device = device + if dtype is not None: + self.dtype = dtype + self.weight_matrix = self.weight_matrix.to(device=self.device, dtype=self.dtype) + self.vbias = self.vbias.to(device=self.device, dtype=self.dtype) + self.hbias = self.hbias.to(device=self.device, dtype=self.dtype) + return self diff --git a/rbms/ising_gaussian/functional.py b/rbms/ising_gaussian/functional.py new file mode 100644 index 0000000..5d837ff --- /dev/null +++ b/rbms/ising_gaussian/functional.py @@ -0,0 +1,135 @@ +from typing import Optional + +import numpy as np +import torch +from torch import Tensor + +from rbms.ising_gaussian.classes import IGRBM +from rbms.ising_gaussian.implement import ( + _compute_energy, + _compute_energy_hiddens, + _compute_energy_visibles, + _compute_gradient, + _init_chains, + _init_parameters, + _sample_hiddens, + _sample_visibles, +) +from rbms.dataset.dataset_class import RBMDataset + + +def sample_hiddens( + chains: dict[str, Tensor], params: IGRBM, beta: float = 1.0 +) -> dict[str, Tensor]: + chains["hidden"], chains["hidden_mag"] = _sample_hiddens( + v=chains["visible"], + weight_matrix=params.weight_matrix, + hbias=params.hbias, + beta=beta, + ) + return chains + + +def sample_visibles( + chains: dict[str, Tensor], params: IGRBM, beta: float = 1.0 +) -> dict[str, Tensor]: + chains["visible"], chains["visible_mag"] = _sample_visibles( + h=chains["hidden"], + weight_matrix=params.weight_matrix, + vbias=params.vbias, + beta=beta, + ) + return chains + + +def compute_energy(v: Tensor, h: Tensor, params: IGRBM) -> Tensor: + return _compute_energy( + v=v, + h=h, + vbias=params.vbias, + hbias=params.hbias, + weight_matrix=params.weight_matrix, + ) + + +def compute_energy_visibles(v: Tensor, params: IGRBM) -> Tensor: + return _compute_energy_visibles( + v=v, + vbias=params.vbias, + hbias=params.hbias, + weight_matrix=params.weight_matrix, + ) + + +def compute_energy_hiddens(h: Tensor, params: IGRBM) -> Tensor: + return _compute_energy_hiddens( + h=h, + vbias=params.vbias, + hbias=params.hbias, + weight_matrix=params.weight_matrix, + ) + + +def compute_gradient( + data: dict[str, Tensor], + chains: dict[str, Tensor], + params: IGRBM, + centered: bool, + lambda_l1: float = 0.0, + lambda_l2: float = 0.0, +) -> None: + _compute_gradient( + v_data=data["visible"], + mh_data=data["hidden_mag"], + w_data=data["weights"], + v_chain=chains["visible"], + h_chain=chains["hidden"], + w_chain=chains["weights"], + vbias=params.vbias, + hbias=params.hbias, + weight_matrix=params.weight_matrix, + centered=centered, + lambda_l1=lambda_l1, + lambda_l2=lambda_l2, + ) + + +def init_chains( + num_samples: int, + params: IGRBM, + weights: Optional[Tensor] = None, + start_v: Optional[Tensor] = None, +) -> dict[str, Tensor]: + visible, hidden, mean_visible, mean_hidden = _init_chains( + num_samples=num_samples, + weight_matrix=params.weight_matrix, + hbias=params.hbias, + start_v=start_v, + ) + if weights is None: + weights = torch.ones(visible.shape[0], device=visible.device, dtype=visible.dtype) + return dict( + visible=visible, + hidden=hidden, + visible_mag=mean_visible, + hidden_mag=mean_hidden, + weights=weights, + ) + + +def init_parameters( + num_hiddens: int, + dataset: RBMDataset, + device: torch.device, + dtype: torch.dtype, + var_init: float = 1e-4, +) -> IGRBM: + data = dataset.data + if 
diff --git a/rbms/ising_gaussian/implement.py b/rbms/ising_gaussian/implement.py
new file mode 100644
index 0000000..0d440b1
--- /dev/null
+++ b/rbms/ising_gaussian/implement.py
@@ -0,0 +1,179 @@
+import math
+from typing import Optional, Tuple
+
+import torch
+from torch import Tensor
+from torch.nn.functional import softmax
+
+
+@torch.jit.script
+def _sample_hiddens(
+    v: Tensor, weight_matrix: Tensor, hbias: Tensor, beta: float = 1.0
+) -> Tuple[Tensor, Tensor]:
+    # p(h|v) ~ exp(beta * (h * (hbias + v @ W) - 0.5 * Nv * ||h||^2)), i.e. a
+    # Gaussian with mean (hbias + v @ W) / Nv and variance 1 / (beta * Nv).
+    num_visibles = float(weight_matrix.shape[0])
+    mh = (hbias + (v @ weight_matrix)) / num_visibles
+    h = mh + torch.randn_like(mh) / math.sqrt(beta * num_visibles)
+    return h, mh
+
+
+@torch.jit.script
+def _sample_visibles(
+    h: Tensor, weight_matrix: Tensor, vbias: Tensor, beta: float = 1.0
+) -> Tuple[Tensor, Tensor]:
+    # ±1 visibles: p(v = +1 | h) = sigmoid(2 * field), magnetization tanh(field).
+    effective_field = beta * (vbias + (h @ weight_matrix.T))
+    mv = torch.tanh(effective_field)
+    v = 2 * torch.bernoulli(torch.sigmoid(2 * effective_field)) - 1
+    return v, mv
+
+
+@torch.jit.script
+def _compute_energy(
+    v: Tensor,
+    h: Tensor,
+    vbias: Tensor,
+    hbias: Tensor,
+    weight_matrix: Tensor,
+) -> Tensor:
+    # E(v, h) = -v·vbias - h·hbias - v^T W h + 0.5 * Nv * ||h||^2
+    fields = torch.tensordot(vbias, v, dims=[[0], [1]]) + torch.tensordot(
+        hbias, h, dims=[[0], [1]]
+    )
+    interaction = torch.multiply(
+        v, torch.tensordot(h, weight_matrix, dims=[[1], [1]])
+    ).sum(1)
+    quad = 0.5 * float(weight_matrix.shape[0]) * (h * h).sum(1)
+    return -fields - interaction + quad
+
+
+@torch.jit.script
+def _compute_energy_visibles(
+    v: Tensor, vbias: Tensor, hbias: Tensor, weight_matrix: Tensor, const: Tensor
+) -> Tensor:
+    # Free energy F(v) = -log int exp(-E(v, h)) dh (Gaussian integral over h).
+    field = v @ vbias
+    t = hbias + (v @ weight_matrix)
+    quad_term = 0.5 * (t * t).sum(1) / float(weight_matrix.shape[0])
+    return -field - quad_term + const
+
+
+@torch.jit.script
+def _compute_energy_hiddens(
+    h: Tensor, vbias: Tensor, hbias: Tensor, weight_matrix: Tensor
+) -> Tensor:
+    field = h @ hbias
+    exponent = vbias + (h @ weight_matrix.T)
+    # Marginal over ±1 visibles: sum_v exp(v * x) = 2 * cosh(x), and
+    # log(2 * cosh(x)) = |x| + log1p(exp(-2 * |x|)) is numerically stable.
+    abs_exponent = torch.abs(exponent)
+    log_term = abs_exponent + torch.log1p(torch.exp(-2.0 * abs_exponent))
+    quad = 0.5 * float(weight_matrix.shape[0]) * (h * h).sum(1)
+    return -field - log_term.sum(1) + quad
+
+
+@torch.jit.script
+def _compute_gradient(
+    v_data: Tensor,
+    h_data: Tensor,
+    w_data: Tensor,
+    v_chain: Tensor,
+    h_chain: Tensor,
+    w_chain: Tensor,
+    vbias: Tensor,
+    hbias: Tensor,
+    weight_matrix: Tensor,
+    centered: bool,
+    lambda_l1: float = 0.0,
+    lambda_l2: float = 0.0,
+) -> None:
+    w_data = w_data.view(-1, 1)
+    w_chain = w_chain.view(-1, 1)
+    chain_weights = softmax(-w_chain, dim=0)
+    w_data_norm = w_data.sum()
+
+    v_data_mean = (v_data * w_data).sum(0) / w_data_norm
+    # ±1 visibles: keep the magnetization away from the boundaries of [-1, 1].
+    torch.clamp_(v_data_mean, min=-1.0 + 1e-4, max=1.0 - 1e-4)
+    h_data_mean = (h_data * w_data).sum(0) / w_data_norm
+
+    if centered:
+        v_data_centered = v_data - v_data_mean
+        h_data_centered = h_data - h_data_mean
+        v_gen_centered = v_chain - v_data_mean
+        h_gen_centered = h_chain - h_data_mean
+
+        grad_weight_matrix = (
+            (v_data_centered * w_data).T @ h_data_centered
+        ) / w_data_norm - ((v_gen_centered * chain_weights).T @ h_gen_centered)
+    else:
+        grad_weight_matrix = ((v_data * w_data).T @ h_data) / w_data_norm - (
+            (v_chain * chain_weights).T @ h_chain
+        )
+
+    # The biases are not trained: their gradients stay at zero.
+    grad_vbias = torch.zeros(vbias.shape[0], device=vbias.device, dtype=vbias.dtype)
+    grad_hbias = torch.zeros(hbias.shape[0], device=hbias.device, dtype=hbias.dtype)
+
+    if lambda_l1 > 0:
+        grad_weight_matrix -= lambda_l1 * torch.sign(weight_matrix)
+        grad_vbias -= lambda_l1 * torch.sign(vbias)
+        grad_hbias -= lambda_l1 * torch.sign(hbias)
+
+    if lambda_l2 > 0:
+        grad_weight_matrix -= 2 * lambda_l2 * weight_matrix
+        grad_vbias -= 2 * lambda_l2 * vbias
+        grad_hbias -= 2 * lambda_l2 * hbias
+
+    weight_matrix.grad.set_(grad_weight_matrix)
+    vbias.grad.set_(grad_vbias)
+    hbias.grad.set_(grad_hbias)
+
+
+@torch.jit.script
+def _init_chains(
+    num_samples: int,
+    weight_matrix: Tensor,
+    hbias: Tensor,
+    start_v: Optional[Tensor] = None,
+) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+    device = weight_matrix.device
+    dtype = weight_matrix.dtype
+    if num_samples <= 0:
+        if start_v is not None:
+            num_samples = start_v.shape[0]
+        else:
+            raise ValueError(f"Got non-positive num_samples arg: {num_samples}")
+
+    if start_v is None:
+        # Uniform ±1 initialization (Bernoulli(1/2) spins): zero magnetization.
+        p = (
+            torch.ones(
+                size=(num_samples, weight_matrix.shape[0]), device=device, dtype=dtype
+            )
+            / 2
+        )
+        v = 2 * torch.bernoulli(p) - 1
+        mv = torch.zeros_like(v)
+    else:
+        mv = torch.zeros_like(start_v, device=device, dtype=dtype)
+        v = start_v.to(device=device, dtype=dtype)
+
+    h, mh = _sample_hiddens(v=v, weight_matrix=weight_matrix, hbias=hbias)
+    return v, h, mv, mh
+
+
+def _init_parameters(
+    num_hiddens: int,
+    data: Tensor,
+    device: torch.device,
+    dtype: torch.dtype,
+    var_init: float = 1e-6,
+):
+    _, num_visibles = data.shape
+    # Small random couplings; biases start at zero and stay fixed during training.
+    weight_matrix = (
+        torch.randn(size=(num_visibles, num_hiddens), device=device, dtype=dtype)
+        * var_init
+    )
+    vbias = torch.zeros(num_visibles, device=device, dtype=dtype)
+    hbias = torch.zeros(num_hiddens, device=device, dtype=dtype)
+    return vbias, hbias, weight_matrix
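
The two samplers can be checked against their own returned means; a throwaway consistency test (sample sizes and tolerances are arbitrary):

import torch

from rbms.ising_gaussian.implement import _sample_hiddens, _sample_visibles

torch.manual_seed(0)
num_visibles, num_hiddens, batch = 10, 4, 200_000
weight_matrix = 0.1 * torch.randn(num_visibles, num_hiddens)

# Empirical mean of the ±1 visible samples should match the tanh magnetization.
v, mv = _sample_visibles(torch.randn(batch, num_hiddens), weight_matrix, torch.zeros(num_visibles))
assert torch.allclose(v.mean(0), mv.mean(0), atol=1e-2)

# The hidden conditional should have variance 1/Nv at beta = 1.
h, mh = _sample_hiddens(torch.ones(batch, num_visibles), weight_matrix, torch.zeros(num_hiddens))
assert torch.allclose(h.var(0), torch.full((num_hiddens,), 1.0 / num_visibles), atol=2e-3)
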
diff --git a/rbms/map_model.py b/rbms/map_model.py
index f9d5d2f..c2a236b 100644
--- a/rbms/map_model.py
+++ b/rbms/map_model.py
@@ -1,5 +1,7 @@
 from rbms.bernoulli_bernoulli.classes import BBRBM
 from rbms.classes import EBM
 from rbms.potts_bernoulli.classes import PBRBM
+from rbms.ising_gaussian.classes import IGRBM
+from rbms.bernoulli_gaussian.classes import BGRBM
 
-map_model: dict[str, EBM] = {"BBRBM": BBRBM, "PBRBM": PBRBM}
+map_model: dict[str, EBM] = {"BBRBM": BBRBM, "PBRBM": PBRBM, "BGRBM": BGRBM, "IGRBM": IGRBM}
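
Finally, a rough end-to-end sketch of one persistent-CD-style update through the new `map_model` entry (random chains stand in for a real positive phase; the `.grad` buffers must be allocated beforehand because `_compute_gradient` writes them via `grad.set_`):

import torch

from rbms.map_model import map_model

IGRBM = map_model["IGRBM"]
model = IGRBM(
    weight_matrix=0.01 * torch.randn(20, 5),
    vbias=torch.zeros(20),
    hbias=torch.zeros(5),
)
for p in model.parameters():
    p.grad = torch.zeros_like(p)  # buffers for grad.set_

data = model.init_chains(num_samples=512)  # stand-in for a real data batch
data = model.sample_hiddens(data)          # fills data["hidden_mag"]
chains = model.init_chains(num_samples=512)
for _ in range(10):                        # negative phase: k Gibbs steps
    chains = model.sample_hiddens(chains)
    chains = model.sample_visibles(chains)

model.compute_gradient(data=data, chains=chains, centered=True)
lr = 0.01
with torch.no_grad():
    for p in model.parameters():
        p += lr * p.grad  # ascent: the gradient is (data term) - (model term)
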