Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
6f42ce7
save code
Oct 17, 2025
36d2aa1
remove warmup
Oct 17, 2025
87c8e69
compile normalize grad + figsize factor on plot
Nov 4, 2025
d86808f
make remove duplicates optional
Nov 17, 2025
11b952b
fix prev_size dataset in split_data script
Dec 1, 2025
e0c9be4
format code
Dec 15, 2025
455c2d5
add map model as option
Dec 16, 2025
fa9baa4
Merge branch 'develop' of github.com:DsysDML/rbms into papier_ptt_train
Dec 19, 2025
1ee0b5c
Merge branch 'develop' of github.com:DsysDML/rbms into papier_ptt_train
Dec 19, 2025
4605d21
Merge branch 'develop' of github.com:DsysDML/rbms into papier_ptt_train
Dec 19, 2025
2cbd4e4
fix merge
Dec 19, 2025
9d9a915
remove 3.14 as torch compile is not supported yet
Dec 19, 2025
716da6f
fix merge
Dec 19, 2025
b6dcefd
add missing keys to args dict in tests
Dec 19, 2025
0bcc1cc
batch function dataset
Jan 9, 2026
1e06197
Merge branch 'develop' of github.com:DsysDML/rbms into papier_ptt_train
Jan 9, 2026
efe31de
use batch method
Jan 9, 2026
4695ab1
add visible_type to EBM class
Jan 13, 2026
0cae04c
change variable_type after conversion
Jan 13, 2026
10af4f2
change variable_type from binary to bernoulli
Jan 13, 2026
7607287
add visible_type
Jan 13, 2026
152803e
add categorical_to_bernoulli implementation
Jan 13, 2026
878f14e
fix variable_type
Jan 13, 2026
3520807
match dataset variable type with model visible type
Jan 13, 2026
1c75c62
sample bernoulli when variable_type is bernoulli
Jan 14, 2026
fc19aef
add log_scale option to PCA plot
Jan 14, 2026
2ca6df6
removed unused variable in non centered gradient
Jan 27, 2026
1f4d543
add conversion print + astype to dataset class
Jan 27, 2026
60c42f6
add __eq__ to class for easier comparison
Jan 27, 2026
852725d
add IIRBM and BGRBM to map_model
Jan 27, 2026
1cfd16e
add model_type and normalize_grad option to parser
Jan 27, 2026
8c69c83
add dataset weights arg
Jan 27, 2026
da1035c
fix binary to bernoulli and add ising to model match
Jan 27, 2026
c654f8f
make normalize_grad optional
Jan 27, 2026
8365a70
save result from get_eigenvalues_history in file to avoid repeating c…
Jan 27, 2026
1a3682b
change version number
Jan 27, 2026
ecbfb08
simplify imports
Jan 27, 2026
9da903c
fix: add __init__ to bernoulli_gaussian
Jan 29, 2026
bdf553e
clip grad
Feb 4, 2026
dc54a16
rework the main loop and add rbms restore script allowing to change m…
Feb 4, 2026
be952e4
new parser, keep the old functions for compatibility
Feb 4, 2026
8559d4e
save learning rate during training and remove the hyperparameters loa…
Feb 4, 2026
feb4e0a
util to handle optimizer declaration
Feb 4, 2026
9544f24
remove test for removed function
Feb 4, 2026
34ee17a
remove weights from init_parameters
Feb 4, 2026
b363fff
add learning_rate
Feb 4, 2026
7ed3c2b
margaret update
Feb 11, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/codecov.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: 3.14
python-version: 3.13

- name: Install test dependencies
run: pip install pytest pytest-cov
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.12, 3.13, 3.14]
python-version: [3.12, 3.13]
steps:
- name: Checkout
uses: actions/checkout@v4
Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "rbms"
version = "0.5.0"
version = "0.6.0"
authors = [
{name="Nicolas Béreux", email="nicolas.bereux@gmail.com"},
{name="Aurélien Decelle"},
Expand All @@ -19,12 +19,12 @@ maintainers = [
]
description = "Training and analyzing Restricted Boltzmann Machines in PyTorch"
readme = "README.md"
requires-python = ">=3.12"
requires-python = ">=3.12, <3.14"
dependencies = [
"h5py>=3.12.0",
"numpy>=2.0.0",
"matplotlib>=3.8.0",
"torch>=2.5.0",
"torch>=2.6.0",
"tqdm>=4.65.0",
]

Expand Down
42 changes: 42 additions & 0 deletions rbms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from rbms.bernoulli_bernoulli.classes import BBRBM
from rbms.bernoulli_gaussian.classes import BGRBM
from rbms.dataset import load_dataset
from rbms.dataset.utils import convert_data
from rbms.io import load_model, load_params
from rbms.ising_ising.classes import IIRBM
from rbms.map_model import map_model
from rbms.plot import plot_image, plot_mult_PCA
from rbms.potts_bernoulli.classes import PBRBM
from rbms.utils import (
bernoulli_to_ising,
compute_log_likelihood,
get_categorical_configurations,
get_eigenvalues_history,
get_flagged_updates,
get_saved_updates,
ising_to_bernoulli,
)

__all__ = [
BBRBM,
BGRBM,
IIRBM,
PBRBM,
map_model,
bernoulli_to_ising,
ising_to_bernoulli,
compute_log_likelihood,
get_eigenvalues_history,
get_saved_updates,
get_flagged_updates,
get_categorical_configurations,
plot_mult_PCA,
plot_image,
load_params,
load_model,
load_dataset,
convert_data,
]


__version__ = "0.5.1"
11 changes: 10 additions & 1 deletion rbms/bernoulli_bernoulli/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
# ruff: noqa
from rbms.bernoulli_bernoulli.classes import BBRBM
from rbms.bernoulli_bernoulli.functional import *
from rbms.bernoulli_bernoulli.functional import (
compute_energy,
compute_energy_hiddens,
compute_energy_visibles,
compute_gradient,
init_chains,
init_parameters,
sample_hiddens,
sample_visibles,
)
2 changes: 2 additions & 0 deletions rbms/bernoulli_bernoulli/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

class BBRBM(RBM):
"""Parameters of the Bernoulli-Bernoulli RBM"""

visible_type: str = "bernoulli"

def __init__(
self,
Expand Down
8 changes: 1 addition & 7 deletions rbms/bernoulli_bernoulli/implement.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import torch
from torch import Tensor
from torch.nn.functional import softmax


@torch.jit.script
Expand Down Expand Up @@ -77,7 +76,7 @@ def _compute_gradient(
w_data = w_data.view(-1, 1)
w_chain = w_chain.view(-1, 1)
# Turn the weights of the chains into normalized weights
chain_weights = softmax(-w_chain, dim=0)
chain_weights = w_chain / w_chain.sum()
w_data_norm = w_data.sum()

# Averages over data and generated samples
Expand All @@ -102,11 +101,6 @@ def _compute_gradient(
grad_vbias = v_data_mean - v_gen_mean - (grad_weight_matrix @ h_data_mean)
grad_hbias = h_data_mean - h_gen_mean - (v_data_mean @ grad_weight_matrix)
else:
v_data_centered = v_data
h_data_centered = mh_data
v_gen_centered = v_chain
h_gen_centered = h_chain

# Gradient
grad_weight_matrix = ((v_data * w_data).T @ mh_data) / w_data_norm - (
(v_chain * chain_weights).T @ h_chain
Expand Down
12 changes: 12 additions & 0 deletions rbms/bernoulli_gaussian/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# ruff: noqa
from rbms.bernoulli_gaussian.classes import BGRBM
from rbms.bernoulli_gaussian.functional import (
compute_energy,
compute_energy_hiddens,
compute_energy_visibles,
compute_gradient,
init_chains,
init_parameters,
sample_hiddens,
sample_visibles,
)
2 changes: 2 additions & 0 deletions rbms/bernoulli_gaussian/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
class BGRBM(RBM):
"""Bernoulli-Gaussian RBM with fixed hidden variance = 1/Nv, 0-1 visibles, hidden and visible biases"""

visible_type: str = "bernoulli"

def __init__(
self,
weight_matrix: Tensor,
Expand Down
8 changes: 1 addition & 7 deletions rbms/bernoulli_gaussian/implement.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import torch
from torch import Tensor
from torch.nn.functional import softmax


@torch.jit.script
Expand Down Expand Up @@ -84,7 +83,7 @@ def _compute_gradient(
) -> None:
w_data = w_data.view(-1, 1)
w_chain = w_chain.view(-1, 1)
chain_weights = softmax(-w_chain, dim=0)
chain_weights = w_chain / w_chain.sum()
w_data_norm = w_data.sum()

v_data_mean = (v_data * w_data).sum(0) / w_data_norm
Expand All @@ -108,11 +107,6 @@ def _compute_gradient(
grad_vbias = v_data_mean - v_gen_mean - (grad_weight_matrix @ h_data_mean)
grad_hbias = h_data_mean - h_gen_mean - (v_data_mean @ grad_weight_matrix)
else:
v_data_centered = v_data
h_data_centered = h_data
v_gen_centered = v_chain
h_gen_centered = h_chain

# Gradient: h_data instead of mh_data
grad_weight_matrix = ((v_data * w_data).T @ h_data) / w_data_norm - (
(v_chain * chain_weights).T @ h_chain
Expand Down
18 changes: 18 additions & 0 deletions rbms/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class EBM(ABC):

name: str
device: torch.device
visible_type: str

@abstractmethod
def __init__(self): ...
Expand All @@ -28,6 +29,13 @@ def __mul__(self, other: float) -> EBM:
"""Multiplies the parameters of the RBM by a float."""
...

def __eq__(self, other: EBM):
other_params = other.named_parameters()
for k, v in self.named_parameters().items():
if not torch.equal(other_params[k], v):
return False
return True

@abstractmethod
def sample_visibles(
self, chains: dict[str, Tensor], beta: float = 1.0
Expand Down Expand Up @@ -209,12 +217,22 @@ def init_grad(self) -> None:
for p in self.parameters():
p.grad = torch.zeros_like(p)

@torch.compile
def normalize_grad(self) -> None:
norm_grad = torch.sqrt(
torch.sum(torch.tensor([p.grad.square().sum() for p in self.parameters()]))
)
for p in self.parameters():
p.grad /= norm_grad
# for p in self.parameters():
# p.grad /= p.grad.norm()

def clip_grad(self, max_norm=5):
for p in self.parameters():
grad_norm = p.grad.norm()
if grad_norm > max_norm:
p.grad /= grad_norm
p.grad *= max_norm


class RBM(EBM):
Expand Down
8 changes: 3 additions & 5 deletions rbms/correlations.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,8 @@ def compute_2b_correlations(
)
if full_mat:
res = torch.triu(res, 1) + torch.tril(res).T
return res / torch.sqrt(
torch.diag(res).unsqueeze(1) @ torch.diag(res).unsqueeze(0)
)
return torch.corrcoef(data)
return res #/ torch.sqrt(torch.diag(res).unsqueeze(1) @ torch.diag(res).unsqueeze(0))
return torch.corrcoef(data.T)


@torch.jit.script
Expand Down Expand Up @@ -102,7 +100,7 @@ def compute_3b_correlations(
res = _3b_batched(
centered_data=centered_data,
weights=weights.unsqueeze(1),
batcu_size=batch_size,
batch_size=batch_size,
)
if full_mat:
res = _3b_full_mat(res)
Expand Down
10 changes: 8 additions & 2 deletions rbms/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def load_dataset(
subset_labels: list[int] | None = None,
use_weights: bool = False,
alphabet="protein",
remove_duplicates: bool = False,
device: str = "cpu",
dtype: torch.dtype = torch.float32,
) -> tuple[RBMDataset, RBMDataset | None]:
Expand Down Expand Up @@ -54,10 +55,15 @@ def load_dataset(
if labels is None:
labels = -np.ones(data.shape[0])

# Remove duplicates and internally shuffle the dataset
unique_ind = get_unique_indices(torch.from_numpy(data)).cpu().numpy()
if remove_duplicates:
# Remove duplicates and internally shuffle the dataset
unique_ind = get_unique_indices(torch.from_numpy(data)).cpu().numpy()
else:
unique_ind = np.arange(data.shape[0])

idx = torch.randperm(unique_ind.shape[0])
if unique_ind.shape[0] < data.shape[0]:
print(f"N_samples: {data.shape[0]} -> {unique_ind.shape[0]}")
data = data[unique_ind[idx]]
labels = labels[unique_ind[idx]]
weights = weights[unique_ind[idx]]
Expand Down
19 changes: 18 additions & 1 deletion rbms/dataset/dataset_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from tqdm.autonotebook import tqdm


class RBMDataset(Dataset):
Expand Down Expand Up @@ -126,6 +126,13 @@ def get_gzip_entropy(self, mean_size: int = 50, num_samples: int = 100):

def match_model_variable_type(self, visible_type: str):
self.data = convert_data[self.variable_type][visible_type](self.data)
if self.variable_type != visible_type:
print(f"Converting from '{self.variable_type}' to '{visible_type}'")
print(self.data)
self.variable_type = visible_type

def astype(self, target_variable_type: str):
return convert_data[self.variable_type][target_variable_type](self.data)

def split_train_test(
self,
Expand Down Expand Up @@ -173,3 +180,13 @@ def split_train_test(
dtype=self.dtype,
)
return train_dataset, test_dataset

def batch(self, batch_size: int) -> dict[str, Union[np.ndarray, torch.Tensor]]:
rand_idx = torch.randperm(len(self))
sampled_batch = self[rand_idx[:batch_size]]
match self.variable_type:
case "bernoulli":
sampled_batch["data"] = torch.bernoulli(sampled_batch["data"])
case _:
pass
return sampled_batch
6 changes: 3 additions & 3 deletions rbms/dataset/load_h5.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def load_HDF5(
Tuple[np.ndarray, np.ndarray]: The dataset and labels.
"""
labels = None
variable_type = "binary"
variable_type = "bernoulli"
with h5py.File(filename, "r") as f:
if "samples" not in f.keys():
raise ValueError(
Expand All @@ -28,10 +28,10 @@ def load_HDF5(
dataset = np.array(f["samples"][()])
if "variable_type" not in f.keys():
print(
f"No variable_type found in the hdf5 file keys: {f.keys()}. Assuming 'binary'."
f"No variable_type found in the hdf5 file keys: {f.keys()}. Assuming 'bernoulli'."
)
print(
"Set a 'variable_type' with value 'binary', 'categorical' or 'continuous' in the hdf5 archive to remove this message"
"Set a 'variable_type' with value 'bernoulli', 'ising', 'categorical' or 'continuous' in the hdf5 archive to remove this message"
)
else:
variable_type = f["variable_type"][()].decode()
Expand Down
6 changes: 6 additions & 0 deletions rbms/dataset/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ def add_args_dataset(parser: argparse.ArgumentParser) -> argparse.ArgumentParser
default="protein",
help="(Defaults to protein). Type of encoding for the sequences. Choose among ['protein', 'rna', 'dna'] or a user-defined string of tokens.",
)
dataset_args.add_argument(
"--remove_duplicates",
default=False,
action="store_true",
help="Remove duplicates from the dataset before splitting.",
)
dataset_args.add_argument(
"--seed",
default=None,
Expand Down
Loading
Loading