Wip #38
30 changes: 30 additions & 0 deletions MNIST.py
@@ -0,0 +1,30 @@
import gzip
import os
import pickle

import h5py
import numpy as np


def preprocess_MNIST(
    filename="mnist.pkl.gz",
    binary_threshold=0.3,
    out_dir="dataset",
):
    with gzip.open(filename, "rb") as f:
        training_data, validation_data, test_data = pickle.load(
            f, encoding="latin1"
        )

    os.makedirs(out_dir, exist_ok=True)
    names = ["MNIST_train.h5", "MNIST_val.h5", "MNIST_test.h5"]
    datasets = [training_data, validation_data, test_data]
    for dataset, name in zip(datasets, names):
        # Binarize the grayscale pixels; keep the integer labels as-is.
        curr_data = np.array(dataset[0])
        curr_data = (curr_data > binary_threshold).astype("float")
        curr_labels = np.array(dataset[1])

        with h5py.File(os.path.join(out_dir, name), "w") as f:
            f["samples"] = curr_data
            f["labels"] = curr_labels


if __name__ == "__main__":
    preprocess_MNIST()
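
For reference, a minimal sketch of reading back the files this script produces (assuming the default out_dir="dataset"; the 50000/10000/10000 split is the standard mnist.pkl.gz layout):

import h5py

with h5py.File("dataset/MNIST_train.h5", "r") as f:
    samples = f["samples"][:]  # binarized images, shape (50000, 784)
    labels = f["labels"][:]    # digit labels, shape (50000,)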
1 change: 1 addition & 0 deletions clusterbm
Submodule clusterbm added at c4375a
Binary file added mnist.pkl.gz
Binary file not shown.
80 changes: 77 additions & 3 deletions rbms/bernoulli_bernoulli/classes.py
@@ -25,6 +25,8 @@ def __init__(
        weight_matrix: Tensor,
        vbias: Tensor,
        hbias: Tensor,
        K1: Tensor,
        K2: Tensor,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
@@ -46,8 +48,13 @@ def __init__(
        self.device = device
        self.dtype = dtype
        self.weight_matrix = weight_matrix.to(device=self.device, dtype=self.dtype)
        self.w_norm_0 = torch.norm(weight_matrix)  # norm of the initial weights
        self.vbias = vbias.to(device=self.device, dtype=self.dtype)
        self.v_norm_0 = torch.norm(vbias)  # norm of the initial visible bias
        self.hbias = hbias.to(device=self.device, dtype=self.dtype)
        self.K1 = K1.to(device=self.device, dtype=self.dtype)
        self.K2 = K2.to(device=self.device, dtype=self.dtype)
        self.K2_norm_0 = torch.norm(K2)  # norm of the initial K2 coupling
        self.name = "BBRBM"

    def __add__(self, other):
@@ -104,7 +111,7 @@ def compute_energy_visibles(self, v: Tensor) -> Tensor:
            weight_matrix=self.weight_matrix,
        )

    def compute_gradient(self, data, chains, centered=True):
    def compute_gradient(self, data, chains, use_fields, centered=True):
        _compute_gradient(
            v_data=data["visible"],
            mh_data=data["hidden_mag"],
@@ -145,7 +152,7 @@ def init_chains(self, num_samples, weights=None, start_v=None):
        )

    @staticmethod
    def init_parameters(num_hiddens, dataset, device, dtype, var_init=0.0001):
    def init_parameters(num_hiddens, dataset, device, dtype, beta, use_fields, var_init=0.0001):
        data = dataset.data
        # Convert to torch Tensor if necessary
        if isinstance(data, np.ndarray):
@@ -156,14 +163,21 @@ def init_parameters(num_hiddens, dataset, device, dtype, var_init=0.0001):
            device=device,
            dtype=dtype,
            var_init=var_init,
            beta=beta,
        )
        return BBRBM(weight_matrix=weight_matrix, vbias=vbias, hbias=hbias)
        num_visible = len(data[0, :])
        # Random couplings scaled by 1/sqrt(num_hiddens) so the induced fields stay O(1)
        K1 = torch.randn_like(weight_matrix, device=device, dtype=dtype) / np.sqrt(float(num_hiddens))
        K2 = torch.randn_like(weight_matrix, device=device, dtype=dtype) / np.sqrt(float(num_hiddens))

        return BBRBM(weight_matrix=weight_matrix, vbias=vbias, hbias=hbias, K1=K1, K2=K2)
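
A quick numerical check of that 1/√(num_hiddens) scaling (a standalone sketch with illustrative sizes):

import torch

num_visible, num_hiddens = 784, 256
K = torch.randn(num_visible, num_hiddens) / num_hiddens ** 0.5
x = (torch.rand(32, num_visible) > 0.5).float()  # random binary batch
field = x @ K  # pre-activation on the hiddens, shape [32, 256]
print(field.std())  # ≈ sqrt(num_visible / (2 * num_hiddens)) ≈ 1.24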

    def named_parameters(self):
        return {
            "weight_matrix": self.weight_matrix,
            "vbias": self.vbias,
            "hbias": self.hbias,
            "K1": self.K1,
            "K2": self.K2,
        }

    def num_hiddens(self):
@@ -210,6 +224,8 @@ def set_named_parameters(named_params: dict[str, Tensor]) -> Self:
            weight_matrix=named_params.pop("weight_matrix"),
            vbias=named_params.pop("vbias"),
            hbias=named_params.pop("hbias"),
            K1=named_params.pop("K1"),
            K2=named_params.pop("K2"),
        )
        if len(named_params.keys()) > 0:
            raise ValueError(
@@ -227,4 +243,62 @@ def to(
        self.weight_matrix = self.weight_matrix.to(device=self.device, dtype=self.dtype)
        self.vbias = self.vbias.to(device=self.device, dtype=self.dtype)
        self.hbias = self.hbias.to(device=self.device, dtype=self.dtype)
        self.K1 = self.K1.to(device=self.device, dtype=self.dtype)
        self.K2 = self.K2.to(device=self.device, dtype=self.dtype)
        return self


    # ───────────────────────── 1st-order PL (single visible site) ─────────────────────────
    def compute_loss_PL1(self, data, l, use_fields=True, use_hfield=True):
        x = data  # [M,N], entries ∈ {0,1}

        # mean-field hidden expectation ⟨h_a⟩ ≃ tanh(λ·pre)
        h_pre = torch.einsum("ia,mi->ma", self.K1, x)  # K1ᵀx
        if use_hfield:
            h_pre = h_pre + self.hbias
        h = torch.tanh(l * h_pre)  # ±1 hiddens → tanh

        F = torch.einsum("ja,ma->mj", self.K1, h)  # local field on each visible
        if use_fields:
            F = F + self.vbias

        logZ = torch.nn.functional.softplus(l * F)  # log(1 + e^{λF})
        e_i = -x * F + (1.0 / l) * logZ  # −log P(x_i|rest) / λ
        return e_i.mean()
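
A short derivation of the objective above: with x_j ∈ {0,1} and effective field F_j, the single-site conditional is P(x_j = 1 | rest) = σ(λF_j), so −log P(x_j | rest) = −λ x_j F_j + log(1 + e^{λF_j}); dividing by λ gives exactly e_i, with the softplus supplying the log-partition term.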


    # ───────────────────────── 2nd-order PL (visible–hidden pair) ─────────────────────────
    def compute_loss_PL2(self, data, l, use_fields=True, use_hfield=True):
        x = data  # [M,N]
        # hidden mean-field
        h_pre = torch.einsum("ia,mi->ma", self.K2, x)
        if use_fields and use_hfield:
            h_pre = h_pre + self.hbias
        h = torch.tanh(l * h_pre)  # [M,A]

        # leave-one-out effective fields
        b = torch.einsum("ja,ma->mj", self.K2, h)  # visible field from h
        c = torch.einsum("ja,mj->ma", self.K2, x)  # hidden field from x
        if use_fields:
            b = b + self.vbias
        if use_hfield:
            c = c + self.hbias

        j_term = torch.einsum("ja,ma->mja", self.K2, h)  # K_ja h_a
        a_term = torch.einsum("ja,mj->mja", self.K2, x)  # K_ja x_j
        b_i_eff = b.unsqueeze(2) - j_term  # b̂_j|¬a, shape [M,N,A]
        c_a_eff = c.unsqueeze(1) - a_term  # ĉ_a|¬j, shape [M,N,A]

        # observed energy E_obs = −(K_ja x_j h_a + b̂_j x_j + ĉ_a h_a)
        w_ai = torch.einsum("ma,ja,mj->mja", h, self.K2, x)  # K_ja x_j h_a
        h_ai = b_i_eff * x.unsqueeze(2) + c_a_eff * h.unsqueeze(1)

        # pair partition function Z_ja over x_j ∈ {0,1}, h_a ∈ {±1}
        # (direct exponentials: assumes λ·|fields| stays moderate, otherwise these can overflow)
        z0 = torch.exp(-l * c_a_eff)                        # (x=0, h=−1)
        z1 = torch.exp(l * c_a_eff)                         # (x=0, h=+1)
        z2 = torch.exp(l * (b_i_eff - self.K2 - c_a_eff))   # (x=1, h=−1)
        z3 = torch.exp(l * (b_i_eff + self.K2 + c_a_eff))   # (x=1, h=+1)
        Z_ai = z0 + z1 + z2 + z3

        e_ij = -w_ai - h_ai + (1.0 / l) * torch.log(Z_ai + 1e-9)  # −log P(x_j,h_a|rest) / λ
        return e_ij.mean()
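
Together the two losses give a pseudo-likelihood objective at inverse temperature λ: PL1 conditions each visible site on the mean-field hiddens, while PL2 conditions each visible–hidden pair (x_j, h_a) on its leave-one-out fields. A minimal, hypothetical training step using them (rbm, batch, and the optimizer setup are illustrative assumptions, not defined in this PR):

import torch

# Hypothetical: rbm is a BBRBM, batch is an [M,N] binary tensor.
params = list(rbm.named_parameters().values())
for p in params:
    p.requires_grad_(True)
optimizer = torch.optim.Adam(params, lr=1e-3)

loss = rbm.compute_loss_PL1(batch, l=1.0) + rbm.compute_loss_PL2(batch, l=1.0)
optimizer.zero_grad()
loss.backward()
optimizer.step()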
3 changes: 2 additions & 1 deletion rbms/bernoulli_bernoulli/implement.py
@@ -159,6 +159,7 @@ def _init_parameters(
    device: torch.device,
    dtype: torch.dtype,
    var_init: float = 1e-4,
    beta: float = 1.0,
) -> Tuple[Tensor, Tensor, Tensor]:
    _, num_visibles = data.shape
    eps = 1e-4
@@ -168,7 +169,7 @@ def _init_parameters(
    )
    frequencies = data.mean(0)
    frequencies = torch.clamp(frequencies, min=eps, max=(1.0 - eps))
    vbias = (torch.log(frequencies) - torch.log(1.0 - frequencies)).to(
    vbias = 1 / beta * (torch.log(frequencies) - torch.log(1.0 - frequencies)).to(
        device=device, dtype=dtype
    )
    hbias = torch.zeros(num_hiddens, device=device, dtype=dtype)
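
On the 1/β factor above: with zero initial weights, the visible conditional at inverse temperature β is P(v_i = 1) = σ(β·vbias_i), so setting vbias_i = (1/β)·(log f_i − log(1 − f_i)) makes the initial model marginals match the clamped empirical frequencies f_i for any β, the same property the unscaled initialization has at β = 1.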