From b6c5ddf2b0c95f41e60cba33d86088780642bb4d Mon Sep 17 00:00:00 2001 From: Francill66 <129966239+Francill66@users.noreply.github.com> Date: Fri, 12 Dec 2025 12:03:40 +0100 Subject: [PATCH 1/7] Binary (0-1) Bernoulli-Gaussian RBM --- rbms/bernoulli_gaussian/classes.py | 214 ++++++++++++++++++++++++++ rbms/bernoulli_gaussian/functional.py | 127 +++++++++++++++ rbms/bernoulli_gaussian/implement.py | 195 +++++++++++++++++++++++ 3 files changed, 536 insertions(+) create mode 100644 rbms/bernoulli_gaussian/classes.py create mode 100644 rbms/bernoulli_gaussian/functional.py create mode 100644 rbms/bernoulli_gaussian/implement.py diff --git a/rbms/bernoulli_gaussian/classes.py b/rbms/bernoulli_gaussian/classes.py new file mode 100644 index 0000000..cf688e0 --- /dev/null +++ b/rbms/bernoulli_gaussian/classes.py @@ -0,0 +1,214 @@ +from typing import List, Optional + +import numpy as np +import torch +from torch import Tensor + +# use the Bernoulli-Gaussian backend +from rbms.bernoulli_gaussian.implement import ( + _compute_energy, + _compute_energy_hiddens, + _compute_energy_visibles, + _compute_gradient, + _init_chains, + _init_parameters, + _sample_hiddens, + _sample_visibles, +) +from rbms.classes import RBM + + +class BGRBM(RBM): + """Bernoulli-Gaussian RBM with fixed hidden variance = 1/Nv (precision γ = Nv).""" + + def __init__( + self, + weight_matrix: Tensor, + vbias: Tensor, + hbias: Tensor, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + if device is None: device = weight_matrix.device + if dtype is None: dtype = weight_matrix.dtype + self.device, self.dtype = device, dtype + + self.weight_matrix = weight_matrix.to(device=self.device, dtype=self.dtype) + self.vbias = vbias.to(device=self.device, dtype=self.dtype) + self.hbias = hbias.to(device=self.device, dtype=self.dtype) + + Nv = int(self.vbias.numel()) + Nh = int(self.hbias.numel()) + + self.name = "BGRBM" + + def __add__(self, other): + # keep fixed variance policy; recompute eta from resulting vbias size + out = BGRBM( + weight_matrix=self.weight_matrix + other.weight_matrix, + vbias=self.vbias + other.vbias, + hbias=self.hbias + other.hbias, + device=self.device, + dtype=self.dtype, + ) + return out + + def __mul__(self, other): + # scalar multiply trained params only; variance stays fixed to Nv + out = BGRBM( + weight_matrix=self.weight_matrix * other, + vbias=self.vbias * other, + hbias=self.hbias * other, + device=self.device, + dtype=self.dtype, + ) + return out + + def clone( + self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None + ): + if device is None: device = self.device + if dtype is None: dtype = self.dtype + return BGRBM( + weight_matrix=self.weight_matrix.clone(), + vbias=self.vbias.clone(), + hbias=self.hbias.clone(), + device=device, + dtype=dtype, + ) + + def compute_energy(self, v: Tensor, h: Tensor) -> Tensor: + return _compute_energy( + v=v, h=h, vbias=self.vbias, hbias=self.hbias, weight_matrix=self.weight_matrix + ) + + def compute_energy_hiddens(self, h: Tensor) -> Tensor: + return _compute_energy_hiddens( + h=h, vbias=self.vbias, hbias=self.hbias, weight_matrix=self.weight_matrix + ) + + def compute_energy_visibles(self, v: Tensor) -> Tensor: + return _compute_energy_visibles( + v=v, vbias=self.vbias, hbias=self.hbias, weight_matrix=self.weight_matrix + ) + + def compute_gradient( + self, data, chains, centered=True, lambda_l1=0.0, lambda_l2=0.0 + ): + # backend should ignore grads on eta or treat it as const; we pass it for conditionals + 
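+        # (_compute_gradient fills each parameter's .grad in place and returns None;
+        # the positive phase uses data statistics, the negative phase the parallel
+        # chains reweighted with softmax(-weights); see implement._compute_gradient)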
_compute_gradient( + v_data=data["visible"], + h_data=data["hidden"], + w_data=data["weights"], + v_chain=chains["visible"], + h_chain=chains["hidden"], + w_chain=chains["weights"], + vbias=self.vbias, + hbias=self.hbias, + weight_matrix=self.weight_matrix, + centered=centered, + lambda_l1=lambda_l1, + lambda_l2=lambda_l2, + ) + + def independent_model(self): + return BGRBM( + weight_matrix=torch.zeros_like(self.weight_matrix), + vbias=self.vbias, + hbias=torch.zeros_like(self.hbias), + device=self.device, + dtype=self.dtype, + ) + + def init_chains(self, num_samples, weights=None, start_v=None): + visible, hidden, mean_visible, mean_hidden = _init_chains( + num_samples=num_samples, + weight_matrix=self.weight_matrix, + hbias=self.hbias, + start_v=start_v, + ) + if weights is None: + weights = torch.ones( + visible.shape[0], device=visible.device, dtype=visible.dtype + ) + return dict( + visible=visible, + hidden=hidden, + visible_mag=mean_visible, + hidden_mag=mean_hidden, + weights=weights, + ) + + @staticmethod + def init_parameters(num_hiddens, dataset, device, dtype, var_init=0.0001): + data = dataset.data + if isinstance(data, np.ndarray): + data = torch.from_numpy(dataset.data).to(device=device, dtype=dtype) + vbias, hbias, weight_matrix = _init_parameters( + num_hiddens=num_hiddens, data=data, device=device, dtype=dtype, var_init=var_init + ) + return BGRBM(weight_matrix=weight_matrix, vbias=vbias, hbias=hbias, device=device, dtype=dtype) + + def named_parameters(self): + return {"weight_matrix": self.weight_matrix, "vbias": self.vbias, "hbias": self.hbias} + + def num_hiddens(self): + return self.hbias.shape[0] + + def num_visibles(self): + return self.vbias.shape[0] + + def parameters(self) -> List[Tensor]: + # keep trainables only + return [self.weight_matrix, self.vbias, self.hbias] + + def ref_log_z(self): + K = self.num_hiddens() + Nv = self.num_visibles() + logZ_v = torch.log1p(torch.exp(self.vbias)).sum() + inv_gamma = 1.0 / float(Nv) + quad = 0.5 * inv_gamma * torch.dot(self.hbias, self.hbias) + log_norm = 0.5 * K * np.log(2.0 * np.pi) - 0.5 * K * np.log(float(Nv)) + return (logZ_v + quad + log_norm).item() + + def sample_hiddens(self, chains: dict[str, Tensor], beta=1) -> dict[str, Tensor]: + chains["hidden"], chains["hidden_mag"] = _sample_hiddens( + v=chains["visible"], + weight_matrix=self.weight_matrix, + hbias=self.hbias, + beta=beta, + ) + return chains + + def sample_visibles(self, chains: dict[str, Tensor], beta=1) -> dict[str, Tensor]: + chains["visible"], chains["visible_mag"] = _sample_visibles( + h=chains["hidden"], weight_matrix=self.weight_matrix, vbias=self.vbias, beta=beta + ) + return chains + + @staticmethod + def set_named_parameters(named_params: dict[str, Tensor]) -> "BGRBM": + names = ["vbias", "hbias", "weight_matrix"] + for k in names: + if k not in named_params: + raise ValueError( + f"""Dictionary params missing key '{k}'\n Provided keys : {named_params.keys()}\n Expected keys: {names}""" + ) + params = BGRBM( + weight_matrix=named_params.pop("weight_matrix"), + vbias=named_params.pop("vbias"), + hbias=named_params.pop("hbias"), + ) + if len(named_params) > 0: + raise ValueError(f"Too many keys in params dictionary. 
Remaining keys: {named_params.keys()}") + return params + + def to( + self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None + ) -> "BGRBM": + if device is not None: self.device = device + if dtype is not None: self.dtype = dtype + self.weight_matrix = self.weight_matrix.to(device=self.device, dtype=self.dtype) + self.vbias = self.vbias.to(device=self.device, dtype=self.dtype) + self.hbias = self.hbias.to(device=self.device, dtype=self.dtype) + return self diff --git a/rbms/bernoulli_gaussian/functional.py b/rbms/bernoulli_gaussian/functional.py new file mode 100644 index 0000000..0310277 --- /dev/null +++ b/rbms/bernoulli_gaussian/functional.py @@ -0,0 +1,127 @@ +from typing import Optional + +import numpy as np +import torch +from torch import Tensor + +# --- switched to Bernoulli–Gaussian (Gaussian hidden) --- +from rbms.bernoulli_gaussian.classes import BGRBM +from rbms.bernoulli_gaussian.implement import ( + _compute_energy, + _compute_energy_hiddens, + _compute_energy_visibles, + _compute_gradient, + _init_chains, + _init_parameters, + _sample_hiddens, + _sample_visibles, +) +from rbms.dataset.dataset_class import RBMDataset + + +def sample_hiddens(chains: dict[str, Tensor], params: BGRBM, beta: float = 1.0) -> dict[str, Tensor]: + """Sample h|v(Gaussian hidden with fixed var = 1/Nv)""" + chains["hidden"], chains["hidden_mag"] = _sample_hiddens( + v=chains["visible"], + weight_matrix=params.weight_matrix, + hbias=params.hbias, + beta=beta, + ) + return chains + + +def sample_visibles(chains: dict[str, Tensor], params: BGRBM, beta: float = 1.0) -> dict[str, Tensor]: + """Sample v|h Bernoulli""" + chains["visible"], chains["visible_mag"] = _sample_visibles( + h=chains["hidden"], + weight_matrix=params.weight_matrix, + vbias=params.vbias, + beta=beta, + ) + return chains + + +def compute_energy(v: Tensor, h: Tensor, params: BGRBM) -> Tensor: + return _compute_energy( + v=v, + h=h, + vbias=params.vbias, + hbias=params.hbias, + weight_matrix=params.weight_matrix, + ) + + +def compute_energy_visibles(v: Tensor, params: BGRBM) -> Tensor: + """Marginalized energy over h""" + return _compute_energy_visibles( + v=v, + vbias=params.vbias, + hbias=params.hbias, + weight_matrix=params.weight_matrix, + ) + + +def compute_energy_hiddens(h: Tensor, params: BGRBM) -> Tensor: + """Energy marginalized over v""" + return _compute_energy_hiddens( + h=h, + vbias=params.vbias, + hbias=params.hbias, + weight_matrix=params.weight_matrix, + ) + + +def compute_gradient( + data: dict[str, Tensor], + chains: dict[str, Tensor], + params: BGRBM, + centered: bool, + lambda_l1: float = 0.0, + lambda_l2: float = 0.0, +) -> None: + _compute_gradient( + v_data=data["visible"], + mh_data=data["hidden_mag"], # use conditional mean for positive phase + w_data=data["weights"], + v_chain=chains["visible"], + h_chain=chains["hidden"], # negative phase from chain samples + w_chain=chains["weights"], + vbias=params.vbias, + hbias=params.hbias, + weight_matrix=params.weight_matrix, + centered=centered, + lambda_l1=lambda_l1, + lambda_l2=lambda_l2, + ) + + +def init_chains( + num_samples: int, + params: BGRBM, + weights: Optional[Tensor] = None, + start_v: Optional[Tensor] = None, +) -> dict[str, Tensor]: + visible, hidden, mean_visible, mean_hidden = _init_chains( + num_samples=num_samples, + weight_matrix=params.weight_matrix, + hbias=params.hbias, + start_v=start_v, + ) + if weights is None: + weights = torch.ones(visible.shape[0], device=visible.device, dtype=visible.dtype) + return 
dict(visible=visible, hidden=hidden, visible_mag=mean_visible, hidden_mag=mean_hidden, weights=weights) + + +def init_parameters( + num_hiddens: int, + dataset: RBMDataset, + device: torch.device, + dtype: torch.dtype, + var_init: float = 1e-4, +) -> BGRBM: + data = dataset.data + if isinstance(data, np.ndarray): data = torch.from_numpy(dataset.data).to(device=device, dtype=dtype) + vbias, hbias, weight_matrix = _init_parameters( + num_hiddens=num_hiddens, data=data, device=device, dtype=dtype, var_init=var_init + ) + return BGRBM(weight_matrix=weight_matrix, vbias=vbias, hbias=hbias, device=device, dtype=dtype) diff --git a/rbms/bernoulli_gaussian/implement.py b/rbms/bernoulli_gaussian/implement.py new file mode 100644 index 0000000..b59e42b --- /dev/null +++ b/rbms/bernoulli_gaussian/implement.py @@ -0,0 +1,195 @@ +from typing import Optional, Tuple + +import torch +from torch import Tensor +from torch.nn.functional import softmax + + +@torch.jit.script +def _sample_hiddens( + v: Tensor, weight_matrix: Tensor, hbias: Tensor, beta: float = 1.0 +) -> Tuple[Tensor, Tensor]: + mh = (hbias + (v @ weight_matrix)) + h = torch.normal(mean=mh, std=torch.tensor(1.0).to(weight_matrix.device)) + return h, mh + + +@torch.jit.script +def _sample_visibles( + h: Tensor, weight_matrix: Tensor, vbias: Tensor, beta: float = 1.0 +) -> Tuple[Tensor, Tensor]: + mv = torch.sigmoid(vbias + h @ weight_matrix.T) + v = torch.bernoulli(mv) + return v, mv + + +@torch.jit.script +def _compute_energy( + v: Tensor, + h: Tensor, + vbias: Tensor, + hbias: Tensor, + weight_matrix: Tensor, +) -> Tensor: + fields = torch.tensordot(vbias, v, dims=[[0], [1]]) + torch.tensordot( + hbias, h, dims=[[0], [1]] + ) + interaction = torch.multiply( + v, torch.tensordot(h, weight_matrix, dims=[[1], [1]]) + ).sum(1) + Nv = weight_matrix.shape[0] + gamma = float(Nv) + quad = 0.5 * gamma * (h * h).sum(1) + return -fields - interaction + quad + + +@torch.jit.script +def _compute_energy_visibles( + v: Tensor, vbias: Tensor, hbias: Tensor, weight_matrix: Tensor +) -> Tensor: + field = v @ vbias # (B,) + t = hbias + (v @ weight_matrix) # (B,K) + Nv = weight_matrix.shape[0] + K = weight_matrix.shape[1] + inv_gamma = 1.0 / float(Nv) + quad_term = 0.5 * inv_gamma * (t * t).sum(1) # (B,) + dtype = v.dtype + device = v.device + log_two_pi = torch.log(torch.tensor(2.0 * torch.pi, dtype=dtype, device=device)) + const = 0.5 * float(K) * (torch.log(torch.tensor(float(Nv), dtype=dtype, device=device)) - log_two_pi) + + return -field - quad_term + const # (B,) + + +@torch.jit.script +def _compute_energy_hiddens( + h: Tensor, vbias: Tensor, hbias: Tensor, weight_matrix: Tensor +) -> Tensor: + field = h @ hbias # (B,) + exponent = vbias + (h @ weight_matrix.T) # (B,V) + log_term = torch.where(exponent < 10, torch.log1p(torch.exp(exponent)), exponent) + Nv = weight_matrix.shape[0] + gamma = float(Nv) + quad = 0.5 * gamma * (h * h).sum(1) # (B,) + return -field - log_term.sum(1) + quad + + +@torch.jit.script +def _compute_gradient( + v_data: Tensor, + h_data: Tensor, + w_data: Tensor, + v_chain: Tensor, + h_chain: Tensor, + w_chain: Tensor, + vbias: Tensor, + hbias: Tensor, + weight_matrix: Tensor, + centered: bool, + lambda_l1: float = 0.0, + lambda_l2: float = 0.0, +) -> None: + + w_data = w_data.view(-1, 1) + w_chain = w_chain.view(-1, 1) + chain_weights = softmax(-w_chain, dim=0) + w_data_norm = w_data.sum() + + v_data_mean = (v_data * w_data).sum(0) / w_data_norm + torch.clamp_(v_data_mean, min=1e-4, max=(1.0 - 1e-4)) + h_data_mean = (h_data * 
w_data).sum(0) / w_data_norm + v_gen_mean = (v_chain * chain_weights).sum(0) + torch.clamp_(v_gen_mean, min=1e-4, max=(1.0 - 1e-4)) + h_gen_mean = (h_chain * chain_weights).sum(0) + + if centered: + # Centered variables + v_data_centered = v_data - v_data_mean + h_data_centered = h_data - h_data_mean + v_gen_centered = v_chain - v_data_mean + h_gen_centered = h_chain - h_data_mean + + # Gradient + grad_weight_matrix = ( + (v_data_centered * w_data).T @ h_data_centered + ) / w_data_norm - ((v_gen_centered * chain_weights).T @ h_gen_centered) + grad_vbias = v_data_mean - v_gen_mean - (grad_weight_matrix @ h_data_mean) + grad_hbias = h_data_mean - h_gen_mean - (v_data_mean @ grad_weight_matrix) + else: + v_data_centered = v_data + h_data_centered = h_data + v_gen_centered = v_chain + h_gen_centered = h_chain + + # Gradient: h_data instead of mh_data + grad_weight_matrix = ((v_data * w_data).T @ h_data) / w_data_norm - ( + (v_chain * chain_weights).T @ h_chain + ) + grad_vbias = v_data_mean - v_gen_mean + grad_hbias = h_data_mean - h_gen_mean + + if lambda_l1 > 0: + grad_weight_matrix -= lambda_l1 * torch.sign(weight_matrix) + grad_vbias -= lambda_l1 * torch.sign(vbias) + grad_hbias -= lambda_l1 * torch.sign(hbias) + + if lambda_l2 > 0: + grad_weight_matrix -= 2 * lambda_l2 * weight_matrix + grad_vbias -= 2 * lambda_l2 * vbias + grad_hbias -= 2 * lambda_l2 * hbias + + # Attach to the parameters + weight_matrix.grad.set_(grad_weight_matrix) + vbias.grad.set_(grad_vbias) + hbias.grad.set_(grad_hbias) + + +@torch.jit.script +def _init_chains( + num_samples: int, + weight_matrix: Tensor, + hbias: Tensor, + start_v: Optional[Tensor] = None, +): + num_visibles, num_hiddens = weight_matrix.shape + device = weight_matrix.device + dtype = weight_matrix.dtype + if num_samples <= 0: + if start_v is not None: + num_samples = start_v.shape[0] + else: + raise ValueError(f"Got negative num_samples arg: {num_samples}") + + if start_v is None: + mv = ( + torch.ones(size=(num_samples, num_visibles), device=device, dtype=dtype) / 2 + ) + v = torch.bernoulli(mv) + else: + mv = torch.zeros_like(start_v, device=device, dtype=dtype) + v = start_v.to(device=device, dtype=dtype) + + h, mh = _sample_hiddens(v=v, weight_matrix=weight_matrix, hbias=hbias) + return v, h, mv, mh + + +def _init_parameters( + num_hiddens: int, + data: Tensor, + device: torch.device, + dtype: torch.dtype, + var_init: float = 1e-6, +): + _, num_visibles = data.shape + eps = 1e-4 + weight_matrix = ( + torch.randn(size=(num_visibles, num_hiddens), device=device, dtype=dtype) + * var_init + ) + frequencies = data.mean(0) + frequencies = torch.clamp(frequencies, min=eps, max=(1.0 - eps)) + vbias = (torch.log(frequencies) - torch.log(1.0 - frequencies)).to( + device=device, dtype=dtype + ) + hbias = torch.zeros(num_hiddens, device=device, dtype=dtype) + return vbias, hbias, weight_matrix From 673d7eded63142658f17f215aca042ea4e6a62df Mon Sep 17 00:00:00 2001 From: Francill66 <129966239+Francill66@users.noreply.github.com> Date: Wed, 17 Dec 2025 19:24:59 +0100 Subject: [PATCH 2/7] Update after revision requested --- rbms/bernoulli_gaussian/classes.py | 77 +++++++++++++++++++-------- rbms/bernoulli_gaussian/functional.py | 28 +++++++--- rbms/bernoulli_gaussian/implement.py | 44 ++++++--------- 3 files changed, 92 insertions(+), 57 deletions(-) diff --git a/rbms/bernoulli_gaussian/classes.py b/rbms/bernoulli_gaussian/classes.py index cf688e0..ae0b584 100644 --- a/rbms/bernoulli_gaussian/classes.py +++ b/rbms/bernoulli_gaussian/classes.py @@ -19,9 +19,9 
@@ class BGRBM(RBM):
-    """Bernoulli-Gaussian RBM with fixed hidden variance = 1/Nv (precision γ = Nv)."""
+    """Bernoulli-Gaussian RBM with fixed hidden variance = 1/Nv, 0-1 visibles, hidden and visible biases"""
 
-    def __init__(
+    def __init__(
         self,
         weight_matrix: Tensor,
         vbias: Tensor,
@@ -29,16 +29,21 @@ def __init__(
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
     ):
-        if device is None: device = weight_matrix.device
-        if dtype is None: dtype = weight_matrix.dtype
+        if device is None:
+            device = weight_matrix.device
+        if dtype is None:
+            dtype = weight_matrix.dtype
         self.device, self.dtype = device, dtype
 
         self.weight_matrix = weight_matrix.to(device=self.device, dtype=self.dtype)
         self.vbias = vbias.to(device=self.device, dtype=self.dtype)
         self.hbias = hbias.to(device=self.device, dtype=self.dtype)
-
-        Nv = int(self.vbias.numel())
-        Nh = int(self.hbias.numel())
+        log_two_pi = torch.log(torch.tensor(2.0 * torch.pi, dtype=vbias.dtype, device=vbias.device))
+        self.const = (
+            0.5
+            * float(weight_matrix.shape[1])
+            * (torch.log(torch.tensor(float(weight_matrix.shape[0]), dtype=vbias.dtype, device=vbias.device)) - log_two_pi)
+        )
 
         self.name = "BGRBM"
 
@@ -67,8 +72,10 @@ def __mul__(self, other):
     def clone(
         self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
     ):
-        if device is None: device = self.device
-        if dtype is None: dtype = self.dtype
+        if device is None:
+            device = self.device
+        if dtype is None:
+            dtype = self.dtype
         return BGRBM(
             weight_matrix=self.weight_matrix.clone(),
             vbias=self.vbias.clone(),
@@ -79,7 +86,12 @@ def clone(
 
     def compute_energy(self, v: Tensor, h: Tensor) -> Tensor:
         return _compute_energy(
-            v=v, h=h, vbias=self.vbias, hbias=self.hbias, weight_matrix=self.weight_matrix
+            v=v,
+            h=h,
+            vbias=self.vbias,
+            hbias=self.hbias,
+            weight_matrix=self.weight_matrix,
         )
 
     def compute_energy_hiddens(self, h: Tensor) -> Tensor:
@@ -92,16 +104,14 @@ def compute_energy_visibles(self, v: Tensor) -> Tensor:
             v=v, vbias=self.vbias, hbias=self.hbias, weight_matrix=self.weight_matrix
         )
 
-    def compute_gradient(
-        self, data, chains, centered=True, lambda_l1=0.0, lambda_l2=0.0
-    ):
+    def compute_gradient(self, data, chains, centered=True, lambda_l1=0.0, lambda_l2=0.0):
         # backend should ignore grads on eta or treat it as const; we pass it for conditionals
         _compute_gradient(
             v_data=data["visible"],
-            h_data=data["hidden"],
+            h_data=data["hidden_magn"],
             w_data=data["weights"],
             v_chain=chains["visible"],
-            h_chain=chains["hidden"],
+            h_chain=chains["hidden_mag"],
             w_chain=chains["weights"],
             vbias=self.vbias,
             hbias=self.hbias,
@@ -145,12 +155,26 @@ def init_parameters(num_hiddens, dataset, device, dtype, var_init=0.0001):
         if isinstance(data, np.ndarray):
             data = torch.from_numpy(dataset.data).to(device=device, dtype=dtype)
         vbias, hbias, weight_matrix = _init_parameters(
-            num_hiddens=num_hiddens, data=data, device=device, dtype=dtype, var_init=var_init
+            num_hiddens=num_hiddens,
+            data=data,
+            device=device,
+            dtype=dtype,
+            var_init=var_init,
+        )
+        return BGRBM(
+            weight_matrix=weight_matrix,
+            vbias=vbias,
+            hbias=hbias,
+            device=device,
+            dtype=dtype,
         )
-        return BGRBM(weight_matrix=weight_matrix, vbias=vbias, hbias=hbias, device=device, dtype=dtype)
 
     def named_parameters(self):
-        return {"weight_matrix": self.weight_matrix, "vbias": self.vbias, "hbias": self.hbias}
+        return {
+            "weight_matrix": self.weight_matrix,
+            "vbias": self.vbias,
+            "hbias": self.hbias,
+        }
 
     def num_hiddens(self):
         return self.hbias.shape[0]
@@ -175,14 +199,17 @@ def sample_hiddens(self, chains: dict[str, Tensor], beta=1) -> dict[str, Tensor]
         chains["hidden"], chains["hidden_mag"] = _sample_hiddens(
             v=chains["visible"],
             weight_matrix=self.weight_matrix,
-            hbias=self.hbias, 
+            hbias=self.hbias,
             beta=beta,
         )
         return chains
 
     def sample_visibles(self, chains: dict[str, Tensor], beta=1) -> dict[str, Tensor]:
         chains["visible"], chains["visible_mag"] = _sample_visibles(
-            h=chains["hidden"], weight_matrix=self.weight_matrix, vbias=self.vbias, beta=beta
+            h=chains["hidden"],
+            weight_matrix=self.weight_matrix,
+            vbias=self.vbias,
+            beta=beta,
         )
         return chains
 
@@ -200,14 +227,18 @@ def set_named_parameters(named_params: dict[str, Tensor]) -> "BGRBM":
             hbias=named_params.pop("hbias"),
         )
         if len(named_params) > 0:
-            raise ValueError(f"Too many keys in params dictionary. Remaining keys: {named_params.keys()}")
+            raise ValueError(
+                f"Too many keys in params dictionary. Remaining keys: {named_params.keys()}"
+            )
         return params
 
     def to(
         self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None
     ) -> "BGRBM":
-        if device is not None: self.device = device
-        if dtype is not None: self.dtype = dtype
+        if device is not None:
+            self.device = device
+        if dtype is not None:
+            self.dtype = dtype
         self.weight_matrix = self.weight_matrix.to(device=self.device, dtype=self.dtype)
         self.vbias = self.vbias.to(device=self.device, dtype=self.dtype)
         self.hbias = self.hbias.to(device=self.device, dtype=self.dtype)
diff --git a/rbms/bernoulli_gaussian/functional.py b/rbms/bernoulli_gaussian/functional.py
index 0310277..836a560 100644
--- a/rbms/bernoulli_gaussian/functional.py
+++ b/rbms/bernoulli_gaussian/functional.py
@@ -19,7 +19,9 @@
 from rbms.dataset.dataset_class import RBMDataset
 
 
-def sample_hiddens(chains: dict[str, Tensor], params: BGRBM, beta: float = 1.0) -> dict[str, Tensor]:
+def sample_hiddens(
+    chains: dict[str, Tensor], params: BGRBM, beta: float = 1.0
+) -> dict[str, Tensor]:
     """Sample h|v(Gaussian hidden with fixed var = 1/Nv)"""
     chains["hidden"], chains["hidden_mag"] = _sample_hiddens(
         v=chains["visible"],
@@ -30,7 +32,9 @@ def sample_hiddens(chains: dict[str, Tensor], params: BGRBM, beta: float = 1.0)
     return chains
 
 
-def sample_visibles(chains: dict[str, Tensor], params: BGRBM, beta: float = 1.0) -> dict[str, Tensor]:
+def sample_visibles(
+    chains: dict[str, Tensor], params: BGRBM, beta: float = 1.0
+) -> dict[str, Tensor]:
     """Sample v|h Bernoulli"""
     chains["visible"], chains["visible_mag"] = _sample_visibles(
         h=chains["hidden"],
@@ -58,6 +62,7 @@ def compute_energy_visibles(v: Tensor, params: BGRBM) -> Tensor:
         vbias=params.vbias,
         hbias=params.hbias,
         weight_matrix=params.weight_matrix,
+        const=params.const
     )
 
 
@@ -81,10 +86,10 @@ def compute_gradient(
 ) -> None:
     _compute_gradient(
         v_data=data["visible"],
-        mh_data=data["hidden_mag"],  # use conditional mean for positive phase
+        h_data=data["hidden_mag"],  # use conditional mean for positive phase
         w_data=data["weights"],
         v_chain=chains["visible"],
-        h_chain=chains["hidden"],  # negative phase from chain samples
+        h_chain=chains["hidden_mag"],  # negative phase from chain samples
         w_chain=chains["weights"],
         vbias=params.vbias,
         hbias=params.hbias,
@@ -109,7 +114,13 @@ def init_chains(
     )
     if weights is None:
         weights = torch.ones(visible.shape[0], device=visible.device, dtype=visible.dtype)
-    return dict(visible=visible, hidden=hidden, visible_mag=mean_visible, hidden_mag=mean_hidden, weights=weights)
+    return dict(
+        visible=visible,
+        hidden=hidden,
+        visible_mag=mean_visible,
+        hidden_mag=mean_hidden,
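+        # chain state: "visible"/"hidden" hold samples, "visible_mag"/"hidden_mag" the
+        # conditional means, and "weights" per-chain weights used in _compute_gradient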
weights=weights, + ) def init_parameters( @@ -120,8 +131,11 @@ def init_parameters( var_init: float = 1e-4, ) -> BGRBM: data = dataset.data - if isinstance(data, np.ndarray): data = torch.from_numpy(dataset.data).to(device=device, dtype=dtype) + if isinstance(data, np.ndarray): + data = torch.from_numpy(dataset.data).to(device=device, dtype=dtype) vbias, hbias, weight_matrix = _init_parameters( num_hiddens=num_hiddens, data=data, device=device, dtype=dtype, var_init=var_init ) - return BGRBM(weight_matrix=weight_matrix, vbias=vbias, hbias=hbias, device=device, dtype=dtype) + return BGRBM( + weight_matrix=weight_matrix, vbias=vbias, hbias=hbias, device=device, dtype=dtype + ) diff --git a/rbms/bernoulli_gaussian/implement.py b/rbms/bernoulli_gaussian/implement.py index b59e42b..fbf97f4 100644 --- a/rbms/bernoulli_gaussian/implement.py +++ b/rbms/bernoulli_gaussian/implement.py @@ -9,9 +9,9 @@ def _sample_hiddens( v: Tensor, weight_matrix: Tensor, hbias: Tensor, beta: float = 1.0 ) -> Tuple[Tensor, Tensor]: - mh = (hbias + (v @ weight_matrix)) - h = torch.normal(mean=mh, std=torch.tensor(1.0).to(weight_matrix.device)) - return h, mh + mh = hbias + (v @ weight_matrix) + h = torch.randn_like(mh) + mh + return h, mh @torch.jit.script @@ -38,49 +38,42 @@ def _compute_energy( v, torch.tensordot(h, weight_matrix, dims=[[1], [1]]) ).sum(1) Nv = weight_matrix.shape[0] - gamma = float(Nv) - quad = 0.5 * gamma * (h * h).sum(1) + quad = 0.5 * Nv * (h * h).sum(1) return -fields - interaction + quad @torch.jit.script def _compute_energy_visibles( - v: Tensor, vbias: Tensor, hbias: Tensor, weight_matrix: Tensor + v: Tensor, vbias: Tensor, hbias: Tensor, weight_matrix: Tensor, const: Tensor, ) -> Tensor: - field = v @ vbias # (B,) - t = hbias + (v @ weight_matrix) # (B,K) + field = v @ vbias + t = hbias + (v @ weight_matrix) Nv = weight_matrix.shape[0] - K = weight_matrix.shape[1] - inv_gamma = 1.0 / float(Nv) - quad_term = 0.5 * inv_gamma * (t * t).sum(1) # (B,) - dtype = v.dtype - device = v.device - log_two_pi = torch.log(torch.tensor(2.0 * torch.pi, dtype=dtype, device=device)) - const = 0.5 * float(K) * (torch.log(torch.tensor(float(Nv), dtype=dtype, device=device)) - log_two_pi) - - return -field - quad_term + const # (B,) + inv_gamma = 1.0 / Nv + quad_term = 0.5 * inv_gamma * (t * t).sum(1) + return -field - quad_term + const @torch.jit.script def _compute_energy_hiddens( h: Tensor, vbias: Tensor, hbias: Tensor, weight_matrix: Tensor ) -> Tensor: - field = h @ hbias # (B,) - exponent = vbias + (h @ weight_matrix.T) # (B,V) + field = h @ hbias + exponent = vbias + (h @ weight_matrix.T) log_term = torch.where(exponent < 10, torch.log1p(torch.exp(exponent)), exponent) Nv = weight_matrix.shape[0] gamma = float(Nv) - quad = 0.5 * gamma * (h * h).sum(1) # (B,) + quad = 0.5 * gamma * (h * h).sum(1) return -field - log_term.sum(1) + quad @torch.jit.script def _compute_gradient( v_data: Tensor, - h_data: Tensor, + h_data: Tensor, w_data: Tensor, v_chain: Tensor, - h_chain: Tensor, + h_chain: Tensor, w_chain: Tensor, vbias: Tensor, hbias: Tensor, @@ -89,7 +82,6 @@ def _compute_gradient( lambda_l1: float = 0.0, lambda_l2: float = 0.0, ) -> None: - w_data = w_data.view(-1, 1) w_chain = w_chain.view(-1, 1) chain_weights = softmax(-w_chain, dim=0) @@ -132,7 +124,7 @@ def _compute_gradient( grad_weight_matrix -= lambda_l1 * torch.sign(weight_matrix) grad_vbias -= lambda_l1 * torch.sign(vbias) grad_hbias -= lambda_l1 * torch.sign(hbias) - + if lambda_l2 > 0: grad_weight_matrix -= 2 * lambda_l2 * weight_matrix 
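        # (the penalties are subtracted because .grad here stores log-likelihood ascent
        # directions; -2 * lambda_l2 * param is the usual weight-decay pull toward zero)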
grad_vbias -= 2 * lambda_l2 * vbias @@ -161,9 +153,7 @@ def _init_chains( raise ValueError(f"Got negative num_samples arg: {num_samples}") if start_v is None: - mv = ( - torch.ones(size=(num_samples, num_visibles), device=device, dtype=dtype) / 2 - ) + mv = torch.ones(size=(num_samples, num_visibles), device=device, dtype=dtype) / 2 v = torch.bernoulli(mv) else: mv = torch.zeros_like(start_v, device=device, dtype=dtype) From 2836feeca9d20ec3c59f3052819cec596ce19117 Mon Sep 17 00:00:00 2001 From: Francill66 <129966239+Francill66@users.noreply.github.com> Date: Fri, 19 Dec 2025 09:34:33 +0100 Subject: [PATCH 3/7] Notation corrected --- rbms/bernoulli_gaussian/classes.py | 4 ++-- rbms/bernoulli_gaussian/implement.py | 14 ++++++-------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/rbms/bernoulli_gaussian/classes.py b/rbms/bernoulli_gaussian/classes.py index ae0b584..3f58b03 100644 --- a/rbms/bernoulli_gaussian/classes.py +++ b/rbms/bernoulli_gaussian/classes.py @@ -101,14 +101,14 @@ def compute_energy_hiddens(self, h: Tensor) -> Tensor: def compute_energy_visibles(self, v: Tensor) -> Tensor: return _compute_energy_visibles( - v=v, vbias=self.vbias, hbias=self.hbias, weight_matrix=self.weight_matrix + v=v, vbias=self.vbias, hbias=self.hbias, weight_matrix=self.weight_matrix, const=self.const ) def compute_gradient(self, data, chains, centered=True, lambda_l1=0.0, lambda_l2=0.0): # backend should ignore grads on eta or treat it as const; we pass it for conditionals _compute_gradient( v_data=data["visible"], - h_data=data["hidden_magn"], + h_data=data["hidden_mag"], w_data=data["weights"], v_chain=chains["visible"], h_chain=chains["hidden_mag"], diff --git a/rbms/bernoulli_gaussian/implement.py b/rbms/bernoulli_gaussian/implement.py index fbf97f4..d9afb99 100644 --- a/rbms/bernoulli_gaussian/implement.py +++ b/rbms/bernoulli_gaussian/implement.py @@ -37,8 +37,8 @@ def _compute_energy( interaction = torch.multiply( v, torch.tensordot(h, weight_matrix, dims=[[1], [1]]) ).sum(1) - Nv = weight_matrix.shape[0] - quad = 0.5 * Nv * (h * h).sum(1) + num_visibles = weight_matrix.shape[0] + quad = 0.5 * num_visibles * (h * h).sum(1) return -fields - interaction + quad @@ -48,9 +48,8 @@ def _compute_energy_visibles( ) -> Tensor: field = v @ vbias t = hbias + (v @ weight_matrix) - Nv = weight_matrix.shape[0] - inv_gamma = 1.0 / Nv - quad_term = 0.5 * inv_gamma * (t * t).sum(1) + num_visibles = weight_matrix.shape[0] + quad_term = 0.5 * (t * t).sum(1) / num_visibles return -field - quad_term + const @@ -61,9 +60,8 @@ def _compute_energy_hiddens( field = h @ hbias exponent = vbias + (h @ weight_matrix.T) log_term = torch.where(exponent < 10, torch.log1p(torch.exp(exponent)), exponent) - Nv = weight_matrix.shape[0] - gamma = float(Nv) - quad = 0.5 * gamma * (h * h).sum(1) + num_visibles = weight_matrix.shape[0] + quad = 0.5 * (h * h).sum(1) * num_visibles return -field - log_term.sum(1) + quad From 3b820235f50a67f75efcf0605258850ed8f7fc43 Mon Sep 17 00:00:00 2001 From: Francill66 <129966239+Francill66@users.noreply.github.com> Date: Fri, 19 Dec 2025 09:44:19 +0100 Subject: [PATCH 4/7] Added contributor name --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index ba608a2..3974149 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,7 @@ version = "0.5.0" authors = [ {name="Nicolas Béreux", email="nicolas.bereux@gmail.com"}, {name="Aurélien Decelle"}, + {name="Francesco D'Amico"}, {name="Cyril Furtlehner"}, 
{name="Alfonso Navas"}, {name="Lorenzo Rosset"}, @@ -84,4 +85,4 @@ docstring-code-format = false [dependency-groups] dev = [ "pytest>=8.4.1", -] +] \ No newline at end of file From 3488dbfe53c1d2a74ddd43360bc38d6a78c91170 Mon Sep 17 00:00:00 2001 From: Francill66 <129966239+Francill66@users.noreply.github.com> Date: Fri, 30 Jan 2026 17:55:14 +0100 Subject: [PATCH 5/7] Added Ising Gaussian RBM No biases for default --- rbms/ising_gaussian/classes.py | 239 ++++++++++++++++++++++++++++++ rbms/ising_gaussian/functional.py | 135 +++++++++++++++++ rbms/ising_gaussian/implement.py | 178 ++++++++++++++++++++++ rbms/map_model.py | 4 +- 4 files changed, 555 insertions(+), 1 deletion(-) create mode 100644 rbms/ising_gaussian/classes.py create mode 100644 rbms/ising_gaussian/functional.py create mode 100644 rbms/ising_gaussian/implement.py diff --git a/rbms/ising_gaussian/classes.py b/rbms/ising_gaussian/classes.py new file mode 100644 index 0000000..c95a917 --- /dev/null +++ b/rbms/ising_gaussian/classes.py @@ -0,0 +1,239 @@ +from typing import List, Optional + +import numpy as np +import torch +from torch import Tensor + +from rbms.ising_gaussian.implement import ( + _compute_energy, + _compute_energy_hiddens, + _compute_energy_visibles, + _compute_gradient, + _init_chains, + _init_parameters, + _sample_hiddens, + _sample_visibles, +) +from rbms.classes import RBM + + +class IGRBM(RBM): + """Ising-Gaussian RBM with fixed hidden variance = 1/Nv, \pm 1 visibles, without any bias""" + + def __init__( + self, + weight_matrix: Tensor, + vbias: Tensor, + hbias: Tensor, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + if device is None: + device = weight_matrix.device + if dtype is None: + dtype = weight_matrix.dtype + self.device, self.dtype = device, dtype + + self.weight_matrix = weight_matrix.to(device=self.device, dtype=self.dtype) + self.vbias = vbias.to(device=self.device, dtype=self.dtype) + self.hbias = hbias.to(device=self.device, dtype=self.dtype) + + log_two_pi = torch.log(torch.tensor(2.0 * torch.pi, dtype=dtype, device=device)) + const = ( + 0.5 + * float(self.weight_matrix[1]) + * ( + torch.log( + torch.tensor(float(self.weight_matrix[0]), dtype=dtype, device=device) + ) + - log_two_pi + ) + ) + self.const = const + self.name = "IGRBM" + + def __add__(self, other): + out = IGRBM( + weight_matrix=self.weight_matrix + other.weight_matrix, + vbias=self.vbias + other.vbias, + hbias=self.hbias + other.hbias, + device=self.device, + dtype=self.dtype, + ) + return out + + def __mul__(self, other): + out = IGRBM( + weight_matrix=self.weight_matrix * other, + vbias=self.vbias * other, + hbias=self.hbias * other, + device=self.device, + dtype=self.dtype, + ) + return out + + def clone( + self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None + ): + if device is None: + device = self.device + if dtype is None: + dtype = self.dtype + return IGRBM( + weight_matrix=self.weight_matrix.clone(), + vbias=self.vbias.clone(), + hbias=self.hbias.clone(), + device=device, + dtype=dtype, + ) + + def compute_energy(self, v: Tensor, h: Tensor) -> Tensor: + return _compute_energy( + v=v, h=h, vbias=self.vbias, hbias=self.hbias, weight_matrix=self.weight_matrix + ) + + def compute_energy_hiddens(self, h: Tensor) -> Tensor: + return _compute_energy_hiddens( + h=h, vbias=self.vbias, hbias=self.hbias, weight_matrix=self.weight_matrix + ) + + def compute_energy_visibles(self, v: Tensor) -> Tensor: + return _compute_energy_visibles( + v=v, vbias=self.vbias, 
hbias=self.hbias, weight_matrix=self.weight_matrix + ) + + def compute_gradient(self, data, chains, centered=True, lambda_l1=0.0, lambda_l2=0.0): + _compute_gradient( + v_data=data["visible"], + h_data=data["hidden"], + w_data=data["weights"], + v_chain=chains["visible"], + h_chain=chains["hidden"], + w_chain=chains["weights"], + vbias=self.vbias, + hbias=self.hbias, + weight_matrix=self.weight_matrix, + centered=centered, + lambda_l1=lambda_l1, + lambda_l2=lambda_l2, + ) + + def independent_model(self): + return IGRBM( + weight_matrix=torch.zeros_like(self.weight_matrix), + vbias=self.vbias, + hbias=torch.zeros_like(self.hbias), + device=self.device, + dtype=self.dtype, + ) + + def init_chains(self, num_samples, weights=None, start_v=None): + visible, hidden, mean_visible, mean_hidden = _init_chains( + num_samples=num_samples, + weight_matrix=self.weight_matrix, + hbias=self.hbias, + start_v=start_v, + ) + if weights is None: + weights = torch.ones( + visible.shape[0], device=visible.device, dtype=visible.dtype + ) + return dict( + visible=visible, + hidden=hidden, + visible_mag=mean_visible, + hidden_mag=mean_hidden, + weights=weights, + ) + + @staticmethod + def init_parameters(num_hiddens, dataset, device, dtype, var_init=0.0001): + data = dataset.data + if isinstance(data, np.ndarray): + data = torch.from_numpy(dataset.data).to(device=device, dtype=dtype) + vbias, hbias, weight_matrix = _init_parameters( + num_hiddens=num_hiddens, + data=data, + device=device, + dtype=dtype, + var_init=var_init, + ) + return IGRBM( + weight_matrix=weight_matrix, + vbias=vbias, + hbias=hbias, + device=device, + dtype=dtype, + ) + + def named_parameters(self): + return { + "weight_matrix": self.weight_matrix, + "vbias": self.vbias, + "hbias": self.hbias, + } + + def num_hiddens(self): + return self.hbias.shape[0] + + def num_visibles(self): + return self.vbias.shape[0] + + def parameters(self) -> List[Tensor]: + return [self.weight_matrix, self.vbias, self.hbias] + + def ref_log_z(self): + K = self.num_hiddens() + logZ_v = torch.log1p(torch.exp(self.vbias)).sum() + quad = 0.5 * torch.dot(self.hbias, self.hbias) / float(self.num_visibles()) + log_norm = 0.5 * K * np.log(2.0 * np.pi) - 0.5 * K * np.log(float(self.num_visibles())) + return (logZ_v + quad + log_norm).item() + + def sample_hiddens(self, chains: dict[str, Tensor], beta=1) -> dict[str, Tensor]: + chains["hidden"], chains["hidden_mag"] = _sample_hiddens( + v=chains["visible"], + weight_matrix=self.weight_matrix, + hbias=self.hbias, + beta=beta, + ) + return chains + + def sample_visibles(self, chains: dict[str, Tensor], beta=1) -> dict[str, Tensor]: + chains["visible"], chains["visible_mag"] = _sample_visibles( + h=chains["hidden"], + weight_matrix=self.weight_matrix, + vbias=self.vbias, + beta=beta, + ) + return chains + + @staticmethod + def set_named_parameters(named_params: dict[str, Tensor]) -> "IGRBM": + names = ["vbias", "hbias", "weight_matrix"] + for k in names: + if k not in named_params: + raise ValueError( + f"""Dictionary params missing key '{k}'\n Provided keys : {named_params.keys()}\n Expected keys: {names}""" + ) + params = IGRBM( + weight_matrix=named_params.pop("weight_matrix"), + vbias=named_params.pop("vbias"), + hbias=named_params.pop("hbias"), + ) + if len(named_params) > 0: + raise ValueError( + f"Too many keys in params dictionary. 
Remaining keys: {named_params.keys()}" + ) + return params + + def to( + self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None + ) -> "IGRBM": + if device is not None: + self.device = device + if dtype is not None: + self.dtype = dtype + self.weight_matrix = self.weight_matrix.to(device=self.device, dtype=self.dtype) + self.vbias = self.vbias.to(device=self.device, dtype=self.dtype) + self.hbias = self.hbias.to(device=self.device, dtype=self.dtype) + return self diff --git a/rbms/ising_gaussian/functional.py b/rbms/ising_gaussian/functional.py new file mode 100644 index 0000000..5d837ff --- /dev/null +++ b/rbms/ising_gaussian/functional.py @@ -0,0 +1,135 @@ +from typing import Optional + +import numpy as np +import torch +from torch import Tensor + +from rbms.ising_gaussian.classes import IGRBM +from rbms.ising_gaussian.implement import ( + _compute_energy, + _compute_energy_hiddens, + _compute_energy_visibles, + _compute_gradient, + _init_chains, + _init_parameters, + _sample_hiddens, + _sample_visibles, +) +from rbms.dataset.dataset_class import RBMDataset + + +def sample_hiddens( + chains: dict[str, Tensor], params: IGRBM, beta: float = 1.0 +) -> dict[str, Tensor]: + chains["hidden"], chains["hidden_mag"] = _sample_hiddens( + v=chains["visible"], + weight_matrix=params.weight_matrix, + hbias=params.hbias, + beta=beta, + ) + return chains + + +def sample_visibles( + chains: dict[str, Tensor], params: IGRBM, beta: float = 1.0 +) -> dict[str, Tensor]: + chains["visible"], chains["visible_mag"] = _sample_visibles( + h=chains["hidden"], + weight_matrix=params.weight_matrix, + vbias=params.vbias, + beta=beta, + ) + return chains + + +def compute_energy(v: Tensor, h: Tensor, params: IGRBM) -> Tensor: + return _compute_energy( + v=v, + h=h, + vbias=params.vbias, + hbias=params.hbias, + weight_matrix=params.weight_matrix, + ) + + +def compute_energy_visibles(v: Tensor, params: IGRBM) -> Tensor: + return _compute_energy_visibles( + v=v, + vbias=params.vbias, + hbias=params.hbias, + weight_matrix=params.weight_matrix, + ) + + +def compute_energy_hiddens(h: Tensor, params: IGRBM) -> Tensor: + return _compute_energy_hiddens( + h=h, + vbias=params.vbias, + hbias=params.hbias, + weight_matrix=params.weight_matrix, + ) + + +def compute_gradient( + data: dict[str, Tensor], + chains: dict[str, Tensor], + params: IGRBM, + centered: bool, + lambda_l1: float = 0.0, + lambda_l2: float = 0.0, +) -> None: + _compute_gradient( + v_data=data["visible"], + mh_data=data["hidden_mag"], + w_data=data["weights"], + v_chain=chains["visible"], + h_chain=chains["hidden"], + w_chain=chains["weights"], + vbias=params.vbias, + hbias=params.hbias, + weight_matrix=params.weight_matrix, + centered=centered, + lambda_l1=lambda_l1, + lambda_l2=lambda_l2, + ) + + +def init_chains( + num_samples: int, + params: IGRBM, + weights: Optional[Tensor] = None, + start_v: Optional[Tensor] = None, +) -> dict[str, Tensor]: + visible, hidden, mean_visible, mean_hidden = _init_chains( + num_samples=num_samples, + weight_matrix=params.weight_matrix, + hbias=params.hbias, + start_v=start_v, + ) + if weights is None: + weights = torch.ones(visible.shape[0], device=visible.device, dtype=visible.dtype) + return dict( + visible=visible, + hidden=hidden, + visible_mag=mean_visible, + hidden_mag=mean_hidden, + weights=weights, + ) + + +def init_parameters( + num_hiddens: int, + dataset: RBMDataset, + device: torch.device, + dtype: torch.dtype, + var_init: float = 1e-4, +) -> IGRBM: + data = dataset.data + if 
+    if isinstance(data, np.ndarray):
+        data = torch.from_numpy(dataset.data).to(device=device, dtype=dtype)
+    vbias, hbias, weight_matrix = _init_parameters(
+        num_hiddens=num_hiddens, data=data, device=device, dtype=dtype, var_init=var_init
+    )
+    return IGRBM(
+        weight_matrix=weight_matrix, vbias=vbias, hbias=hbias, device=device, dtype=dtype
+    )
diff --git a/rbms/ising_gaussian/implement.py b/rbms/ising_gaussian/implement.py
new file mode 100644
index 0000000..6c8905d
--- /dev/null
+++ b/rbms/ising_gaussian/implement.py
@@ -0,0 +1,178 @@
+from typing import Optional, Tuple
+
+import torch
+from torch import Tensor
+from torch.nn.functional import softmax
+
+
+@torch.jit.script
+def _sample_hiddens(
+    v: Tensor, weight_matrix: Tensor, hbias: Tensor, beta: float = 1.0
+) -> Tuple[Tensor, Tensor]:
+    mh = hbias + (v @ weight_matrix)
+    h = torch.randn_like(mh) + mh
+    return h, mh
+
+
+@torch.jit.script
+def _sample_visibles(
+    h: Tensor, weight_matrix: Tensor, vbias: Tensor, beta: float = 1.0
+) -> Tuple[Tensor, Tensor]:
+    mv = torch.tanh(vbias + h @ weight_matrix.T)
+    v = torch.bernoulli(0.5 * (1 + mv)) * 2 - 1
+    return v, mv
+
+
+@torch.jit.script
+def _compute_energy(
+    v: Tensor,
+    h: Tensor,
+    vbias: Tensor,
+    hbias: Tensor,
+    weight_matrix: Tensor,
+) -> Tensor:
+    fields = torch.tensordot(vbias, v, dims=[[0], [1]]) + torch.tensordot(
+        hbias, h, dims=[[0], [1]]
+    )
+    interaction = torch.multiply(
+        v, torch.tensordot(h, weight_matrix, dims=[[1], [1]])
+    ).sum(1)
+    quad = 0.5 * float(weight_matrix.shape[0]) * (h * h).sum(1)
+    return -fields - interaction + quad
+
+
+@torch.jit.script
+def _compute_energy_visibles(
+    v: Tensor, vbias: Tensor, hbias: Tensor, weight_matrix: Tensor, const: Tensor
+) -> Tensor:
+    field = v @ vbias
+    t = hbias + (v @ weight_matrix)
+    quad_term = 0.5 * (t * t).sum(1) / float(weight_matrix.shape[0])
+    return -field - quad_term + const
+
+
+@torch.jit.script
+def _compute_energy_hiddens(
+    h: Tensor, vbias: Tensor, hbias: Tensor, weight_matrix: Tensor
+) -> Tensor:
+    field = h @ hbias
+    exponent = vbias + (h @ weight_matrix.T)
+    # marginalize over v = ±1: log(2*cosh(x)) = |x| + log1p(exp(-2|x|))
+    log_term = torch.abs(exponent) + torch.log1p(torch.exp(-2.0 * torch.abs(exponent)))
+    quad = 0.5 * float(weight_matrix.shape[0]) * (h * h).sum(1)
+    return -field - log_term.sum(1) + quad
+
+
+@torch.jit.script
+def _compute_gradient(
+    v_data: Tensor,
+    h_data: Tensor,
+    w_data: Tensor,
+    v_chain: Tensor,
+    h_chain: Tensor,
+    w_chain: Tensor,
+    vbias: Tensor,
+    hbias: Tensor,
+    weight_matrix: Tensor,
+    centered: bool,
+    lambda_l1: float = 0.0,
+    lambda_l2: float = 0.0,
+) -> None:
+    w_data = w_data.view(-1, 1)
+    w_chain = w_chain.view(-1, 1)
+    chain_weights = softmax(-w_chain, dim=0)
+    w_data_norm = w_data.sum()
+
+    v_data_mean = (v_data * w_data).sum(0) / w_data_norm
+    # means of ±1 visibles live in [-1, 1]
+    torch.clamp_(v_data_mean, min=-(1.0 - 1e-4), max=(1.0 - 1e-4))
+    h_data_mean = (h_data * w_data).sum(0) / w_data_norm
+    v_gen_mean = v_chain.mean(0)
+    torch.clamp_(v_gen_mean, min=-(1.0 - 1e-4), max=(1.0 - 1e-4))
+
+    if centered:
+        v_data_centered = v_data - v_data_mean
+        h_data_centered = h_data - h_data_mean
+        v_gen_centered = v_chain - v_data_mean
+        h_gen_centered = h_chain - h_data_mean
+
+        grad_weight_matrix = (
+            (v_data_centered * w_data).T @ h_data_centered
+        ) / w_data_norm - ((v_gen_centered * chain_weights).T @ h_gen_centered)
+        grad_vbias = torch.zeros(
+            vbias.shape[0], device=vbias.device, dtype=vbias.dtype
+        )  # No training on biases
+        grad_hbias = torch.zeros(
+            hbias.shape[0], device=hbias.device, dtype=hbias.dtype
+        )  # No training on biases
+    else:
+        v_data_centered = v_data
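+        # (these aliases mirror the centered branch but are unused below; the
+        # uncentered update is built directly from v_data/h_data and the chain)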
h_data_centered = h_data + v_gen_centered = v_chain + h_gen_centered = h_chain + + grad_weight_matrix = ((v_data * w_data).T @ h_data) / w_data_norm - ( + (v_chain * chain_weights).T @ h_chain + ) + + grad_vbias = torch.zeros( + vbias.shape[0], device=vbias.device, dtype=vbias.dtype + ) # No training on biases + grad_hbias = torch.zeros( + hbias.shape[0], device=hbias.device, dtype=hbias.dtype + ) # No training on biases + + if lambda_l1 > 0: + grad_weight_matrix -= lambda_l1 * torch.sign(weight_matrix) + grad_vbias -= lambda_l1 * torch.sign(vbias) + grad_hbias -= lambda_l1 * torch.sign(hbias) + + if lambda_l2 > 0: + grad_weight_matrix -= 2 * lambda_l2 * weight_matrix + grad_vbias -= 2 * lambda_l2 * vbias + grad_hbias -= 2 * lambda_l2 * hbias + + weight_matrix.grad.set_(grad_weight_matrix) + vbias.grad.set_(grad_vbias) + hbias.grad.set_(grad_hbias) + + +@torch.jit.script +def _init_chains( + num_samples: int, + weight_matrix: Tensor, + hbias: Tensor, + start_v: Optional[Tensor] = None, +): + device = weight_matrix.device + dtype = weight_matrix.dtype + if num_samples <= 0: + if start_v is not None: + num_samples = start_v.shape[0] + else: + raise ValueError(f"Got negative num_samples arg: {num_samples}") + + if start_v is None: + mv = torch.ones(size=(num_samples, weight_matrix.shape[0]), device=device, dtype=dtype) / 2 + v = torch.bernoulli(mv) * 2 - 1 + else: + mv = torch.zeros_like(start_v, device=device, dtype=dtype) + v = start_v.to(device=device, dtype=dtype) + + h, mh = _sample_hiddens(v=v, weight_matrix=weight_matrix, hbias=hbias) + return v, h, mv, mh + + +def _init_parameters( + num_hiddens: int, + data: Tensor, + device: torch.device, + dtype: torch.dtype, + var_init: float = 1e-6, +): + _, num_visibles = data.shape + weight_matrix = ( + torch.randn(size=(num_visibles, num_hiddens), device=device, dtype=dtype) + * var_init + ) + vbias = torch.zeros(num_visibles, device=device, dtype=dtype) + hbias = torch.zeros(num_hiddens, device=device, dtype=dtype) + return vbias, hbias, weight_matrix diff --git a/rbms/map_model.py b/rbms/map_model.py index f9d5d2f..c2a236b 100644 --- a/rbms/map_model.py +++ b/rbms/map_model.py @@ -1,5 +1,7 @@ from rbms.bernoulli_bernoulli.classes import BBRBM from rbms.classes import EBM from rbms.potts_bernoulli.classes import PBRBM +from rbms.ising_gaussian.classes import IGRBM +from rbms.bernoulli_gaussian.classes import BGRBM -map_model: dict[str, EBM] = {"BBRBM": BBRBM, "PBRBM": PBRBM} +map_model: dict[str, EBM] = {"BBRBM": BBRBM, "PBRBM": PBRBM, "BGRBM": BGRBM, "IGRBM": IGRBM} From 032db12f4ea7958fdada840776cb662cb90966f0 Mon Sep 17 00:00:00 2001 From: Francill66 <129966239+Francill66@users.noreply.github.com> Date: Mon, 2 Feb 2026 14:36:33 +0100 Subject: [PATCH 6/7] Create __init__.py --- rbms/ising_gaussian/__init__.py | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 rbms/ising_gaussian/__init__.py diff --git a/rbms/ising_gaussian/__init__.py b/rbms/ising_gaussian/__init__.py new file mode 100644 index 0000000..b20d9a5 --- /dev/null +++ b/rbms/ising_gaussian/__init__.py @@ -0,0 +1,3 @@ +# ruff: noqa +from rbms.ising_gaussian.classes import IGRBM +from rbms.ising_gaussian.functional import * \ No newline at end of file From f0c394c8c883d7c8d5a5517fadc4b573b3de0585 Mon Sep 17 00:00:00 2001 From: Francill66 <129966239+Francill66@users.noreply.github.com> Date: Tue, 10 Feb 2026 17:17:25 +0100 Subject: [PATCH 7/7] Update requested --- rbms/ising_gaussian/classes.py | 4 ++-- rbms/ising_gaussian/implement.py | 5 +++-- 2 files changed, 5 
insertions(+), 4 deletions(-) diff --git a/rbms/ising_gaussian/classes.py b/rbms/ising_gaussian/classes.py index c95a917..3ba5042 100644 --- a/rbms/ising_gaussian/classes.py +++ b/rbms/ising_gaussian/classes.py @@ -105,10 +105,10 @@ def compute_energy_visibles(self, v: Tensor) -> Tensor: def compute_gradient(self, data, chains, centered=True, lambda_l1=0.0, lambda_l2=0.0): _compute_gradient( v_data=data["visible"], - h_data=data["hidden"], + h_data=data["hidden_mag"], w_data=data["weights"], v_chain=chains["visible"], - h_chain=chains["hidden"], + h_chain=chains["hidden_mag"], w_chain=chains["weights"], vbias=self.vbias, hbias=self.hbias, diff --git a/rbms/ising_gaussian/implement.py b/rbms/ising_gaussian/implement.py index 6c8905d..0d440b1 100644 --- a/rbms/ising_gaussian/implement.py +++ b/rbms/ising_gaussian/implement.py @@ -18,8 +18,9 @@ def _sample_hiddens( def _sample_visibles( h: Tensor, weight_matrix: Tensor, vbias: Tensor, beta: float = 1.0 ) -> Tuple[Tensor, Tensor]: - mv = torch.tanh(vbias + h @ weight_matrix.T) - v = torch.bernoulli(0.5 * (1 + mv)) * 2 - 1 + effective_field = beta * (vbias + (h @ weight_matrix.T)) + mv = torch.tanh(effective_field) + v = 2 * torch.bernoulli(torch.sigmoid(2 * effective_field)) - 1 return v, mv
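
For readers following the series, a minimal self-contained sketch of the energy convention shared by both models (a hedged restatement for illustration, not part of the patches): it re-implements the formulas of _compute_energy and _compute_energy_visibles from rbms/bernoulli_gaussian/implement.py with hidden precision gamma = Nv, and checks, for a single hidden unit, that the marginalized energy agrees with direct numerical integration of exp(-E(v, h)) over h. Under this energy, h | v is Gaussian with mean t/gamma and variance 1/gamma, where t = hbias + v @ W. The sizes, the seed, and the quadrature grid are arbitrary choices made for the check.

import math

import torch

torch.manual_seed(0)

Nv, K, B = 8, 1, 4  # one hidden unit so the integral over h is one-dimensional
W = 0.1 * torch.randn(Nv, K, dtype=torch.float64)
vbias = torch.randn(Nv, dtype=torch.float64)
hbias = torch.randn(K, dtype=torch.float64)
v = torch.bernoulli(torch.full((B, Nv), 0.5, dtype=torch.float64))  # 0-1 visibles

gamma = float(Nv)  # fixed hidden precision, as in the series
const = 0.5 * K * (math.log(gamma) - math.log(2.0 * math.pi))


def energy(v: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
    # E(v, h) = -v.a - h.b - v W h + (gamma / 2) sum h^2, as in _compute_energy
    fields = v @ vbias + h @ hbias
    interaction = (v * (h @ W.T)).sum(1)
    quad = 0.5 * gamma * (h * h).sum(1)
    return -fields - interaction + quad


def energy_visibles(v: torch.Tensor) -> torch.Tensor:
    # F(v) = -v.a - t.t / (2 gamma) + const with t = b + v W, as in _compute_energy_visibles
    t = hbias + v @ W
    return -(v @ vbias) - 0.5 * (t * t).sum(1) / gamma + const


# exp(-F(v)) should equal the integral of exp(-E(v, h)) dh; trapezoidal quadrature
# on a fixed grid is accurate here because gamma = 8 keeps the Gaussian mass narrow.
hs = torch.linspace(-3.0, 3.0, 20001, dtype=torch.float64)
for b in range(B):
    vb = v[b : b + 1].expand(hs.numel(), Nv)
    integrand = torch.exp(-energy(vb, hs.unsqueeze(1)))
    free_energy = -torch.trapezoid(integrand, hs).log()
    assert torch.allclose(free_energy, energy_visibles(v[b : b + 1])[0], atol=1e-6)
print("marginalized energy matches 1-D quadrature over the hidden unit")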