From 6f42ce715d06e0ae5a2d45fc19ef3f2bdde55c4d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nicolas=20B=C3=A9reux?=
Date: Fri, 17 Oct 2025 18:17:51 +0200
Subject: [PATCH 01/43] save code

---
 rbms/correlations.py          |  8 +++-----
 rbms/dataset/__init__.py      |  2 ++
 rbms/dataset/dataset_class.py |  4 ++--
 rbms/plot.py                  |  2 ++
 rbms/scripts/train_rbm.py     |  5 +++--
 rbms/training/pcd.py          | 30 ++++++++++++++++++++++++++----
 rbms/training/utils.py        |  8 +++++++-
 rbms/utils.py                 |  7 +++++--
 8 files changed, 50 insertions(+), 16 deletions(-)

diff --git a/rbms/correlations.py b/rbms/correlations.py
index 98ab4f2..3229ea2 100644
--- a/rbms/correlations.py
+++ b/rbms/correlations.py
@@ -50,10 +50,8 @@ def compute_2b_correlations(
     )
     if full_mat:
         res = torch.triu(res, 1) + torch.tril(res).T
-        return res / torch.sqrt(
-            torch.diag(res).unsqueeze(1) @ torch.diag(res).unsqueeze(0)
-        )
-    return torch.corrcoef(data)
+        return res #/ torch.sqrt(torch.diag(res).unsqueeze(1) @ torch.diag(res).unsqueeze(0))
+    return torch.corrcoef(data.T)
 
 
 @torch.jit.script
@@ -104,7 +102,7 @@ def compute_3b_correlations(
         res = _3b_batched(
             centered_data=centered_data,
             weights=weights.unsqueeze(1),
-            batcu_size=batch_size,
+            batch_size=batch_size,
         )
         if full_mat:
             res = _3b_full_mat(res)
diff --git a/rbms/dataset/__init__.py b/rbms/dataset/__init__.py
index d469469..4e11394 100644
--- a/rbms/dataset/__init__.py
+++ b/rbms/dataset/__init__.py
@@ -72,6 +72,8 @@ def load_dataset(
 
     unique_ind = get_unique_indices(torch.from_numpy(data)).cpu().numpy()
     idx = torch.randperm(unique_ind.shape[0])
+    if unique_ind.shape[0] < data.shape[0]:
+        print(f"N_samples: {data.shape[0]} -> {unique_ind.shape[0]}")
     data = data[unique_ind[idx]]
     labels = labels[unique_ind[idx]]
     weights = weights[unique_ind[idx]]
diff --git a/rbms/dataset/dataset_class.py b/rbms/dataset/dataset_class.py
index 003ffe9..78b5c05 100644
--- a/rbms/dataset/dataset_class.py
+++ b/rbms/dataset/dataset_class.py
@@ -1,11 +1,11 @@
 import gzip
 import textwrap
-from typing import Dict, Union, Self, Tuple, Optional
+from typing import Dict, Optional, Self, Tuple, Union
 
 import numpy as np
 import torch
 from torch.utils.data import Dataset
-from tqdm import tqdm
+from tqdm.autonotebook import tqdm
 
 
 class RBMDataset(Dataset):
diff --git a/rbms/plot.py b/rbms/plot.py
index 663920c..877739a 100644
--- a/rbms/plot.py
+++ b/rbms/plot.py
@@ -183,6 +183,7 @@ def plot_one_PCA(
         s=size_scat,
         zorder=0,
         alpha=0.3,
+        rasterized=True,
     )
     _, bins_x, _ = ax_hist_x.hist(
         data1[:, dir1],
@@ -216,6 +217,7 @@ def plot_one_PCA(
         marker="o",
         alpha=1,
         linewidth=0.4,
+        rasterized=True
     )
     ax_hist_x.hist(
         data2[:, dir1],
diff --git a/rbms/scripts/train_rbm.py b/rbms/scripts/train_rbm.py
index 4c0ee04..bb0c49f 100644
--- a/rbms/scripts/train_rbm.py
+++ b/rbms/scripts/train_rbm.py
@@ -8,11 +8,11 @@
 from rbms.parser import (
     add_args_pytorch,
     add_args_rbm,
-    add_args_saves,
     add_args_regularization,
+    add_args_saves,
+    default_args,
     match_args_dtype,
     remove_argument,
-    default_args,
 )
 from rbms.training.pcd import train
 from rbms.training.utils import get_checkpoints
@@ -25,6 +25,7 @@ def create_parser():
     parser = add_args_regularization(parser)
     parser = add_args_saves(parser)
     parser = add_args_pytorch(parser)
+    parser.add_argument("--optim", default="sgd")
     remove_argument(parser, "use_torch")
     return parser
 
diff --git a/rbms/training/pcd.py b/rbms/training/pcd.py
index 05c0617..9ca1df7 100644
--- a/rbms/training/pcd.py
+++ b/rbms/training/pcd.py
@@ -118,19 +118,38 @@ def train(
     optimizer = optim(params.parameters(), lr=args["learning_rate"], maximize=True)
 
+    update_lr = False
+    warmup = True
+    from rbms.classes import RBM
     # Continue the training
     with torch.no_grad():
         for idx in range(num_updates + 1, args["num_updates"] + 1):
             rand_idx = torch.randperm(len(train_dataset))[: args["batch_size"]]
             batch = (train_dataset.data[rand_idx], train_dataset.weights[rand_idx])
             if args["training_type"] == "rdm":
-                parallel_chains = params.init_chains(parallel_chains["visible"].shape[0])
             elif args["training_type"] == "cd":
-                parallel_chains = params.init_chains(batch[0].shape[0],weights=batch[1], start_v=batch[0])
+                parallel_chains = params.init_chains(
+                    batch[0].shape[0], weights=batch[1], start_v=batch[0]
+                )
+
+            if warmup and isinstance(params, RBM):
+                if params.weight_matrix.norm() > 10 and args["optim"] == "nag":
+                    # optimizer = SGD_cossim(
+                    #     params.updated_params.parameters(),
+                    #     lr=args["learning_rate"],
+                    #     maximize=True,
+                    # )
+                    optimizer = SGD(
+                        params.parameters(),
+                        lr=args["learning_rate"],
+                        maximize=True,
+                        momentum=0.9,
+                        nesterov=True
+                    )
+                    warmup = False
             optimizer.zero_grad(set_to_none=False)
-
             parallel_chains, logs = fit_batch_pcd(
                 batch=batch,
                 parallel_chains=parallel_chains,
@@ -141,7 +160,10 @@ def train(
                 lambda_l1=args["L1"],
                 lambda_l2=args["L2"],
             )
-            optimizer.step()
+            if update_lr:
+                optimizer.step(update_lr=update_lr)
+            else:
+                optimizer.step()
 
             if isinstance(params, PBRBM):
                 ensure_zero_sum_gauge(params)
diff --git a/rbms/training/utils.py b/rbms/training/utils.py
index 9780497..86edf47 100644
--- a/rbms/training/utils.py
+++ b/rbms/training/utils.py
@@ -6,7 +6,7 @@
 import numpy as np
 import torch
 from torch import Tensor
-from tqdm import tqdm
+from tqdm.autonotebook import tqdm
 
 from rbms.classes import EBM
 from rbms.const import LOG_FILE_HEADER
@@ -59,11 +59,17 @@ def setup_training(
             args[k] = v
 
     if test_dataset is None:
+        print("Splitting dataset")
         train_dataset, test_dataset = train_dataset.split_train_test(
             rng=np.random.default_rng(args["seed"]),
             train_size=args["train_size"],
             test_size=args["test_size"],
         )
+    print("Train dataset:")
+    print(train_dataset)
+    print("Test dataset:")
+    print(test_dataset)
+
 
     # Open the log file if it exists
     log_filename = pathlib.Path(args["filename"]).parent / pathlib.Path(
diff --git a/rbms/utils.py b/rbms/utils.py
index e519727..b001b77 100644
--- a/rbms/utils.py
+++ b/rbms/utils.py
@@ -12,7 +12,7 @@
 from rbms.const import LOG_FILE_HEADER
 
 
-def get_eigenvalues_history(filename: str):
+def get_eigenvalues_history(filename: str, backend="cpu"):
     """
     Extracts the history of eigenvalues of the RBM's weight matrix.
@@ -31,7 +31,10 @@ def get_eigenvalues_history(filename: str):
             if "update_" in key:
                 weight_matrix = f[key]["params"]["weight_matrix"][()]
                 weight_matrix = weight_matrix.reshape(-1, weight_matrix.shape[-1])
-                eig = np.linalg.svd(weight_matrix, compute_uv=False)
+                if backend == "gpu":
+                    eig = torch.svd(torch.from_numpy(weight_matrix).to(device='cuda'), compute_uv=False).S.cpu().numpy()
+                else:
+                    eig = np.linalg.svd(weight_matrix, compute_uv=False)
                 eigenvalues.append(eig.reshape(*eig.shape, 1))
                 gradient_updates.append(int(key.split("_")[1]))
 

From 36d2aa16eeddf993622bf830e3573b3be39a58a7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nicolas=20B=C3=A9reux?=
Date: Fri, 17 Oct 2025 23:53:12 +0200
Subject: [PATCH 02/43] remove warmup

---
 rbms/training/pcd.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/rbms/training/pcd.py b/rbms/training/pcd.py
index 9ca1df7..788f3bd 100644
--- a/rbms/training/pcd.py
+++ b/rbms/training/pcd.py
@@ -117,7 +117,15 @@ def train(
     args = set_args_default(args=args, default_args=default_args)
 
     optimizer = optim(params.parameters(), lr=args["learning_rate"], maximize=True)
-    
+    if args["optim"] == "nag":
+        optimizer = SGD(
+            params.parameters(),
+            lr=args["learning_rate"],
+            maximize=True,
+            momentum=0.9,
+            nesterov=True
+        )
     update_lr = False
     warmup = True
     from rbms.classes import RBM

From 87c8e699e19d486312789a2094b4e699ef8fc2fe Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nicolas=20B=C3=A9reux?=
Date: Tue, 4 Nov 2025 16:44:15 +0100
Subject: [PATCH 03/43] compile normalize grad + figsize factor on plot

---
 rbms/classes.py                 |  2 ++
 rbms/dataset/__init__.py        |  2 +-
 rbms/plot.py                    |  3 ++-
 rbms/potts_bernoulli/classes.py | 18 ++++++++++++++----
 4 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/rbms/classes.py b/rbms/classes.py
index 447b12c..ae5abef 100644
--- a/rbms/classes.py
+++ b/rbms/classes.py
@@ -1,4 +1,5 @@
 from __future__ import annotations
+
 from abc import ABC, abstractmethod
 from typing import List, Optional, Self
 
@@ -208,6 +209,7 @@ def init_grad(self) -> None:
         for p in self.parameters():
             p.grad = torch.zeros_like(p)
 
+    @torch.compile
     def normalize_grad(self) -> None:
         norm_grad = torch.sqrt(
             torch.sum(torch.tensor([p.grad.square().sum() for p in self.parameters()]))
diff --git a/rbms/dataset/__init__.py b/rbms/dataset/__init__.py
index 4e11394..2d85b25 100644
--- a/rbms/dataset/__init__.py
+++ b/rbms/dataset/__init__.py
@@ -85,7 +85,7 @@ def load_dataset(
         labels=labels,
         weights=weights,
         names=names,
-        dataset_name=dataset_name,
+        dataset_name=dset_name,
         is_binary=is_binary,
         device=device,
         dtype=dtype,
diff --git a/rbms/plot.py b/rbms/plot.py
index 877739a..e82ba03 100644
--- a/rbms/plot.py
+++ b/rbms/plot.py
@@ -249,6 +249,7 @@ def plot_mult_PCA(
     data2: Optional[np.ndarray] = None,
     labels: Optional[List[str]] = None,
     n_dir: int = 2,
+    figsize_factor=4
 ):
     if data2 is not None:
         if data2.shape[1] < data1.shape[1]:
@@ -271,7 +272,7 @@ def plot_mult_PCA(
         else ((data1.shape[1] // 2) // max_cols) + 1
     )
 
-    fig, ax = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows))
+    fig, ax = plt.subplots(n_rows, n_cols, figsize=(figsize_factor * n_cols, figsize_factor * n_rows))
 
     for i in range(n_rows):
         for j in range(n_cols):
diff --git a/rbms/potts_bernoulli/classes.py b/rbms/potts_bernoulli/classes.py
index edb691c..583bcbb 100644
--- a/rbms/potts_bernoulli/classes.py
+++ b/rbms/potts_bernoulli/classes.py
@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import List, Optional, override
 
 import numpy as np
 import torch
@@ -105,9
+105,7 @@ def compute_energy_visibles(self, v): weight_matrix=self.weight_matrix, ) - def compute_gradient( - self, data, chains, centered=True, lambda_l1=0.0, lambda_l2=0.0 - ): + def compute_gradient(self, data, chains, centered=True, lambda_l1=0.0, lambda_l2=0.0): _compute_gradient( v_data=data["visible"], mh_data=data["hidden_mag"], @@ -233,3 +231,15 @@ def to( self.vbias = self.vbias.to(device=self.device, dtype=self.dtype) self.hbias = self.hbias.to(device=self.device, dtype=self.dtype) return self + + @override + @torch.compile + def normalize_grad(self) -> None: + norm_factor = torch.sqrt( + self.weight_matrix.square().sum() + + self.vbias.square().sum() + + self.hbias.square().sum() + ) + self.weight_matrix.grad /= norm_factor + self.vbias.grad /= norm_factor + self.hbias.grad /= norm_factor From d86808fe156c66196cee6c66d964d613611545db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Mon, 17 Nov 2025 12:37:55 +0100 Subject: [PATCH 04/43] make remove duplicates optional --- rbms/dataset/__init__.py | 8 ++++++-- rbms/dataset/parser.py | 6 ++++++ rbms/scripts/split_data.py | 36 +++++++++++++++++++++++++----------- rbms/scripts/train_rbm.py | 1 + 4 files changed, 38 insertions(+), 13 deletions(-) diff --git a/rbms/dataset/__init__.py b/rbms/dataset/__init__.py index 2d85b25..0289f79 100644 --- a/rbms/dataset/__init__.py +++ b/rbms/dataset/__init__.py @@ -17,6 +17,7 @@ def load_dataset( use_weights: bool = False, binarize: bool = False, alphabet="protein", + remove_duplicates: bool = False, device: str = "cpu", dtype: torch.dtype = torch.float32, ) -> Tuple[RBMDataset, RBMDataset | None]: @@ -68,8 +69,11 @@ def load_dataset( if labels is None: labels = -np.ones(data.shape[0]) - # Remove duplicates and internally shuffle the dataset - unique_ind = get_unique_indices(torch.from_numpy(data)).cpu().numpy() + if remove_duplicates: + # Remove duplicates and internally shuffle the dataset + unique_ind = get_unique_indices(torch.from_numpy(data)).cpu().numpy() + else: + unique_ind = np.arange(data.shape[0]) idx = torch.randperm(unique_ind.shape[0]) if unique_ind.shape[0] < data.shape[0]: diff --git a/rbms/dataset/parser.py b/rbms/dataset/parser.py index 168eaea..c828ac6 100644 --- a/rbms/dataset/parser.py +++ b/rbms/dataset/parser.py @@ -57,6 +57,12 @@ def add_args_dataset(parser: argparse.ArgumentParser) -> argparse.ArgumentParser action="store_true", help="(Defaults to False). Binarize the dataset.", ) + dataset_args.add_argument( + "--remove_duplicates", + default=False, + action="store_true", + help="Remove duplicates from the dataset before splitting.", + ) dataset_args.add_argument( "--seed", default=None, diff --git a/rbms/scripts/split_data.py b/rbms/scripts/split_data.py index aa2ad81..637482c 100644 --- a/rbms/scripts/split_data.py +++ b/rbms/scripts/split_data.py @@ -43,6 +43,12 @@ def create_parser(): default="protein", help="(Defaults to protein). Type of encoding for the sequences. 
Choose among ['protein', 'rna', 'dna'] or a user-defined string of tokens.", ) + parser.add_argument( + "--remove_duplicates", + action="store_true", + default=False, + help="Remove duplicates from the dataset before splitting.", + ) return parser @@ -51,6 +57,7 @@ def split_data_train_test( output_train_file: Optional[str] = None, output_test_file: Optional[str] = None, train_size=0.6, + remove_duplicates: bool = False, seed: int = None, alphabet: str = "protein", ): @@ -59,13 +66,18 @@ def split_data_train_test( dataset, _ = load_dataset(input_file, None, alphabet=alphabet) - print("Removing duplicates...") - prev_size = dataset.data.shape[0] - unique_ind = get_unique_indices(dataset.data) - data = dataset.data[unique_ind] - names = dataset.names[unique_ind] - labels = dataset.labels[unique_ind] - + if remove_duplicates: + print("Removing duplicates...") + prev_size = dataset.data.shape[0] + unique_ind = get_unique_indices(dataset.data) + data = dataset.data[unique_ind] + names = dataset.names[unique_ind] + labels = dataset.labels[unique_ind] + else: + data = dataset.data + names = dataset.names + labels = dataset.labels + curr_size = data.shape[0] print(f" Dataset size: {prev_size} -> {curr_size} samples") print(f" Removed {prev_size - curr_size} samples.") @@ -87,7 +99,6 @@ def split_data_train_test( names_test = names[permutation_index[n_sample_train:]] labels_test = labels[permutation_index[n_sample_train:]].int().cpu().numpy() - print( f" train_size = {data_train.shape[0]} ({100 * data_train.shape[0] / data.shape[0]}%)" ) @@ -100,11 +111,13 @@ def split_data_train_test( if output_train_file is None: output_train_file = ( - ".".join(str(dset_name).split(".")[:-1]) + f"_train={train_size}.{file_format}" + ".".join(str(dset_name).split(".")[:-1]) + + f"_train={train_size}.{file_format}" ) if output_test_file is None: output_test_file = ( - ".".join(str(dset_name).split(".")[:-1]) + f"_test={1 - train_size}.{file_format}" + ".".join(str(dset_name).split(".")[:-1]) + + f"_test={1 - train_size}.{file_format}" ) match file_format: @@ -118,7 +131,7 @@ def split_data_train_test( with h5py.File(output_test_file, "w") as f: f["samples"] = data_test f["labels"] = labels_test - print(" Done") + print(" Done") case "fasta": print(f"Writing train dataset to '{output_train_file}'...") @@ -141,6 +154,7 @@ def main(): output_train_file=args["out_train"], output_test_file=args["out_test"], train_size=args["train_size"], + remove_duplicates=args["remove_duplicates"], seed=args["seed"], alphabet=args["alphabet"], ) diff --git a/rbms/scripts/train_rbm.py b/rbms/scripts/train_rbm.py index bb0c49f..2085508 100644 --- a/rbms/scripts/train_rbm.py +++ b/rbms/scripts/train_rbm.py @@ -43,6 +43,7 @@ def train_rbm(args: dict): use_weights=args["use_weights"], alphabet=args["alphabet"], binarize=args["binarize"], + remove_duplicates=args["remove_duplicates"], device=args["device"], dtype=args["dtype"], ) From 11b952b8e2e6520149892a874a8c5e6deff30c15 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Mon, 1 Dec 2025 01:24:27 +0100 Subject: [PATCH 05/43] fix prev_size dataset in split_data script --- rbms/scripts/split_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rbms/scripts/split_data.py b/rbms/scripts/split_data.py index 637482c..1399967 100644 --- a/rbms/scripts/split_data.py +++ b/rbms/scripts/split_data.py @@ -66,9 +66,9 @@ def split_data_train_test( dataset, _ = load_dataset(input_file, None, alphabet=alphabet) + prev_size = dataset.data.shape[0] if 
remove_duplicates: print("Removing duplicates...") - prev_size = dataset.data.shape[0] unique_ind = get_unique_indices(dataset.data) data = dataset.data[unique_ind] names = dataset.names[unique_ind] From e0c9be4a9c140ca7ce919e8beef60fcf704e4f74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Mon, 15 Dec 2025 17:35:31 +0100 Subject: [PATCH 06/43] format code --- rbms/utils.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/rbms/utils.py b/rbms/utils.py index b001b77..ba34ac7 100644 --- a/rbms/utils.py +++ b/rbms/utils.py @@ -32,7 +32,14 @@ def get_eigenvalues_history(filename: str, backend="cpu"): weight_matrix = f[key]["params"]["weight_matrix"][()] weight_matrix = weight_matrix.reshape(-1, weight_matrix.shape[-1]) if backend == "gpu": - eig = torch.svd(torch.from_numpy(weight_matrix).to(device='cuda'), compute_uv=False).S.cpu().numpy() + eig = ( + torch.svd( + torch.from_numpy(weight_matrix).to(device="cuda"), + compute_uv=False, + ) + .S.cpu() + .numpy() + ) else: eig = np.linalg.svd(weight_matrix, compute_uv=False) eigenvalues.append(eig.reshape(*eig.shape, 1)) @@ -127,7 +134,7 @@ def query_yes_no(question: str, default: str = "yes") -> bool: elif choice in valid: return valid[choice] else: - sys.stdout.write("Please respond with 'yes' or 'no' " "(or 'y' or 'n').\n") + sys.stdout.write("Please respond with 'yes' or 'no' (or 'y' or 'n').\n") def check_file_existence(filename: str): @@ -257,12 +264,8 @@ def swap_chains( idx_vis_mean = idx_vis idx_hid = idx.unsqueeze(1).repeat(1, chain_1["hidden"].shape[1]) - new_chain_1["visible"] = torch.where( - idx_vis, chain_2["visible"], chain_1["visible"] - ) - new_chain_2["visible"] = torch.where( - idx_vis, chain_1["visible"], chain_2["visible"] - ) + new_chain_1["visible"] = torch.where(idx_vis, chain_2["visible"], chain_1["visible"]) + new_chain_2["visible"] = torch.where(idx_vis, chain_1["visible"], chain_2["visible"]) new_chain_1["visible_mag"] = torch.where( idx_vis_mean, chain_2["visible_mag"], chain_1["visible_mag"] @@ -303,8 +306,5 @@ def get_flagged_updates(filename: str, flag: str) -> np.ndarray: if flag in f[key]["flags"]: if f[key]["flags"][flag][()]: flagged_updates.append(update) - flagged_updates = np.sort(np.array(flagged_updates)) + flagged_updates = np.sort(np.array(flagged_updates, dtype=int)) return flagged_updates - - - From 455c2d5aebc18c82b65c63fefecf06f99e5d23ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Tue, 16 Dec 2025 18:15:26 +0100 Subject: [PATCH 07/43] add map model as option --- rbms/training/pcd.py | 1 + rbms/training/utils.py | 1 + 2 files changed, 2 insertions(+) diff --git a/rbms/training/pcd.py b/rbms/training/pcd.py index 788f3bd..4d7c961 100644 --- a/rbms/training/pcd.py +++ b/rbms/training/pcd.py @@ -96,6 +96,7 @@ def train( train_dataset=train_dataset, test_dataset=test_dataset, dtype=dtype, + map_model=map_model, ) ( params, diff --git a/rbms/training/utils.py b/rbms/training/utils.py index 86edf47..0179722 100644 --- a/rbms/training/utils.py +++ b/rbms/training/utils.py @@ -206,6 +206,7 @@ def initialize_model_archive( test_dataset: Optional[RBMDataset], dtype: torch.dtype, flags: List[str] = ["checkpoint"], + map_model: dict[str, EBM] = map_model, ): num_visibles = train_dataset.get_num_visibles() args = set_args_default(args=args, default_args=default_args) From 2cbd4e4647bb20d38a73cc3e9a807240eb593cdf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Fri, 19 Dec 2025 12:58:25 +0100 
Subject: [PATCH 08/43] fix merge --- rbms/dataset/dataset_class.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/rbms/dataset/dataset_class.py b/rbms/dataset/dataset_class.py index 574119d..832e3f4 100644 --- a/rbms/dataset/dataset_class.py +++ b/rbms/dataset/dataset_class.py @@ -1,10 +1,6 @@ import gzip import textwrap -<<<<<<< HEAD -from typing import Dict, Optional, Self, Tuple, Union -======= from typing import Self, Union ->>>>>>> 0dfca34d6e47839074bbadc53719500dff7a19c7 import numpy as np import torch From 9d9a915c07fe1577ac5303c15421b81c6e8a8e98 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Fri, 19 Dec 2025 13:07:09 +0100 Subject: [PATCH 09/43] remove 3.14 as torch compile is not supported yet --- .github/workflows/codecov.yaml | 2 +- .github/workflows/test.yaml | 2 +- pyproject.toml | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/codecov.yaml b/.github/workflows/codecov.yaml index 445267a..4477207 100644 --- a/.github/workflows/codecov.yaml +++ b/.github/workflows/codecov.yaml @@ -15,7 +15,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: 3.14 + python-version: 3.13 - name: Install test dependencies run: pip install pytest pytest-cov diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 2128877..756ea63 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -8,7 +8,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.12, 3.13, 3.14] + python-version: [3.12, 3.13] steps: - name: Checkout uses: actions/checkout@v4 diff --git a/pyproject.toml b/pyproject.toml index 391cb99..1dffc4c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,12 +19,12 @@ maintainers = [ ] description = "Training and analyzing Restricted Boltzmann Machines in PyTorch" readme = "README.md" -requires-python = ">=3.12" +requires-python = ">=3.12, <3.14" dependencies = [ "h5py>=3.12.0", "numpy>=2.0.0", "matplotlib>=3.8.0", - "torch>=2.5.0", + "torch>=2.6.0", "tqdm>=4.65.0", ] From 716da6f16a3eabf60db191fb4710444f06220e24 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Fri, 19 Dec 2025 13:14:36 +0100 Subject: [PATCH 10/43] fix merge --- rbms/training/pcd.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/rbms/training/pcd.py b/rbms/training/pcd.py index bc52708..3511e0a 100644 --- a/rbms/training/pcd.py +++ b/rbms/training/pcd.py @@ -117,18 +117,19 @@ def train( args = set_args_default(args=args, default_args=default_args) optimizer = optim(params.parameters(), lr=args["learning_rate"], maximize=True) - + if args["optim"] == "nag": optimizer = SGD( - params.parameters(), - lr=args["learning_rate"], - maximize=True, + params.parameters(), + lr=args["learning_rate"], + maximize=True, momentum=0.9, - nesterov=True + nesterov=True, ) update_lr = False warmup = True from rbms.classes import RBM + # Continue the training with torch.no_grad(): for idx in range(num_updates + 1, args["num_updates"] + 1): @@ -140,7 +141,6 @@ def train( parallel_chains = params.init_chains( batch[0].shape[0], weights=batch[1], start_v=batch[0] ) -<<<<<<< HEAD if warmup and isinstance(params, RBM): if params.weight_matrix.norm() > 10 and args["optim"] == "nag": @@ -150,15 +150,13 @@ def train( # maximize=True, # ) optimizer = SGD( - params.parameters(), - lr=args["learning_rate"], - maximize=True, + params.parameters(), + lr=args["learning_rate"], + maximize=True, momentum=0.9, - nesterov=True + 
nesterov=True, ) warmup = False -======= ->>>>>>> 0dfca34d6e47839074bbadc53719500dff7a19c7 optimizer.zero_grad(set_to_none=False) parallel_chains, logs = fit_batch_pcd( From b6dcefd2ec46541c103fa17ce386e4ec394291d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Fri, 19 Dec 2025 13:21:33 +0100 Subject: [PATCH 11/43] add missing keys to args dict in tests --- tests/conftest.py | 2 ++ tests/use_cases/test_bbrbm.py | 2 ++ tests/use_cases/test_pbrbm.py | 4 ++++ 3 files changed, 8 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index 5d0710b..4c5def2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -156,6 +156,8 @@ def sample_args(tmp_path): "L1": 0.0, "L2": 1.0, "training_type": "pcd", + "optim": "sgd", + "remove_duplicates": False, } diff --git a/tests/use_cases/test_bbrbm.py b/tests/use_cases/test_bbrbm.py index ca4687d..b6d28d6 100644 --- a/tests/use_cases/test_bbrbm.py +++ b/tests/use_cases/test_bbrbm.py @@ -62,6 +62,8 @@ def test_use_case_train_bbrbm(): "L1": 0.0, "L2": 1.0, "training_type": "pcd", + "optim": "sgd", + "remove_duplicates": False, } train_rbm(args) diff --git a/tests/use_cases/test_pbrbm.py b/tests/use_cases/test_pbrbm.py index 1507f53..0d4be2b 100644 --- a/tests/use_cases/test_pbrbm.py +++ b/tests/use_cases/test_pbrbm.py @@ -62,6 +62,8 @@ def test_use_case_train_pbrbm_no_weights(): "L1": 0.0, "L2": 1.0, "training_type": "pcd", + "optim": "sgd", + "remove_duplicates": False, } train_rbm(args) @@ -140,6 +142,8 @@ def test_use_case_train_pbrbm_weights(): "L1": 1.0, "L2": 0.0, "training_type": "pcd", + "optim": "sgd", + "remove_duplicates": False, } train_rbm(args) From 0bcc1cc5f772a23d9d1e941a84933bc1c63f1438 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Fri, 9 Jan 2026 17:01:17 +0100 Subject: [PATCH 12/43] batch function dataset --- rbms/dataset/dataset_class.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/rbms/dataset/dataset_class.py b/rbms/dataset/dataset_class.py index 832e3f4..2768d5a 100644 --- a/rbms/dataset/dataset_class.py +++ b/rbms/dataset/dataset_class.py @@ -28,6 +28,10 @@ def __init__( self.device = device self.dtype = dtype self.is_binary = is_binary + if self.is_binary: + self.visible_type = "binary" + else: + self.visible_type = "categorical" self.data = torch.from_numpy(data).to(device=self.device, dtype=self.dtype) # Weights should have shape n_visibles self.weights = ( @@ -169,3 +173,7 @@ def split_train_test( dtype=self.dtype, ) return train_dataset, test_dataset + + def batch(self, batch_size: int) -> dict[str, Union[np.ndarray, torch.Tensor]]: + rand_idx = torch.randperm(len(self)) + return self[rand_idx[:batch_size]] \ No newline at end of file From efe31de66dc162a20bfde5d6bc1aa62a58c5fe1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Fri, 9 Jan 2026 18:01:51 +0100 Subject: [PATCH 13/43] use batch method --- rbms/training/pcd.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/rbms/training/pcd.py b/rbms/training/pcd.py index 3511e0a..a88a3ea 100644 --- a/rbms/training/pcd.py +++ b/rbms/training/pcd.py @@ -133,22 +133,19 @@ def train( # Continue the training with torch.no_grad(): for idx in range(num_updates + 1, args["num_updates"] + 1): - rand_idx = torch.randperm(len(train_dataset))[: args["batch_size"]] - batch = (train_dataset.data[rand_idx], train_dataset.weights[rand_idx]) + # rand_idx = torch.randperm(len(train_dataset))[: args["batch_size"]] + # batch = (train_dataset.data[rand_idx], 
train_dataset.weights[rand_idx]) + batch = train_dataset.batch(args["batch_size"]) + if args["training_type"] == "rdm": parallel_chains = params.init_chains(parallel_chains["visible"].shape[0]) elif args["training_type"] == "cd": parallel_chains = params.init_chains( - batch[0].shape[0], weights=batch[1], start_v=batch[0] + batch["data"].shape[0], weights=batch["weights"], start_v=batch["data"] ) if warmup and isinstance(params, RBM): if params.weight_matrix.norm() > 10 and args["optim"] == "nag": - # optimizer = SGD_cossim( - # params.updated_params.parameters(), - # lr=args["learning_rate"], - # maximize=True, - # ) optimizer = SGD( params.parameters(), lr=args["learning_rate"], @@ -160,7 +157,7 @@ def train( optimizer.zero_grad(set_to_none=False) parallel_chains, logs = fit_batch_pcd( - batch=batch, + batch=(batch["data"], batch["weights"]), parallel_chains=parallel_chains, params=params, gibbs_steps=args["gibbs_steps"], From 4695ab1e8f3cf04572c69f57830cd38bc467fba3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Tue, 13 Jan 2026 17:58:06 +0100 Subject: [PATCH 14/43] add visible_type to EBM class --- rbms/bernoulli_bernoulli/classes.py | 2 ++ rbms/bernoulli_gaussian/classes.py | 2 ++ rbms/classes.py | 1 + rbms/potts_bernoulli/classes.py | 2 ++ 4 files changed, 7 insertions(+) diff --git a/rbms/bernoulli_bernoulli/classes.py b/rbms/bernoulli_bernoulli/classes.py index 5ce9c42..0e1f76a 100644 --- a/rbms/bernoulli_bernoulli/classes.py +++ b/rbms/bernoulli_bernoulli/classes.py @@ -19,6 +19,8 @@ class BBRBM(RBM): """Parameters of the Bernoulli-Bernoulli RBM""" + + visible_type: str = "bernoulli" def __init__( self, diff --git a/rbms/bernoulli_gaussian/classes.py b/rbms/bernoulli_gaussian/classes.py index 1c4a0b3..88c6848 100644 --- a/rbms/bernoulli_gaussian/classes.py +++ b/rbms/bernoulli_gaussian/classes.py @@ -19,6 +19,8 @@ class BGRBM(RBM): """Bernoulli-Gaussian RBM with fixed hidden variance = 1/Nv, 0-1 visibles, hidden and visible biases""" + visible_type: str = "bernoulli" + def __init__( self, weight_matrix: Tensor, diff --git a/rbms/classes.py b/rbms/classes.py index 04b79bb..0e1776f 100644 --- a/rbms/classes.py +++ b/rbms/classes.py @@ -14,6 +14,7 @@ class EBM(ABC): name: str device: torch.device + visible_type: str @abstractmethod def __init__(self): ... 
diff --git a/rbms/potts_bernoulli/classes.py b/rbms/potts_bernoulli/classes.py index 6a8d795..a767dae 100644 --- a/rbms/potts_bernoulli/classes.py +++ b/rbms/potts_bernoulli/classes.py @@ -20,6 +20,8 @@ class PBRBM(RBM): """Parameters of the Potts-Bernoulli RBM""" + visible_type: str = "categorical" + def __init__( self, weight_matrix: Tensor, From 0cae04c9447b5c4b3ffd1a5a84df4629fed4a71d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Tue, 13 Jan 2026 17:58:58 +0100 Subject: [PATCH 15/43] change variable_type after conversion --- rbms/dataset/dataset_class.py | 1 + 1 file changed, 1 insertion(+) diff --git a/rbms/dataset/dataset_class.py b/rbms/dataset/dataset_class.py index 4728c27..d4fb9d0 100644 --- a/rbms/dataset/dataset_class.py +++ b/rbms/dataset/dataset_class.py @@ -126,6 +126,7 @@ def get_gzip_entropy(self, mean_size: int = 50, num_samples: int = 100): def match_model_variable_type(self, visible_type: str): self.data = convert_data[self.variable_type][visible_type](self.data) + self.variable_type = visible_type def split_train_test( self, From 10af4f246e1354a839ce98b87247a9644544a446 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Tue, 13 Jan 2026 17:59:27 +0100 Subject: [PATCH 16/43] change variable_type from binary to bernoulli --- rbms/dataset/load_h5.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/rbms/dataset/load_h5.py b/rbms/dataset/load_h5.py index 954edd6..09606f5 100644 --- a/rbms/dataset/load_h5.py +++ b/rbms/dataset/load_h5.py @@ -19,7 +19,7 @@ def load_HDF5( Tuple[np.ndarray, np.ndarray]: The dataset and labels. """ labels = None - variable_type = "binary" + variable_type = "bernoulli" with h5py.File(filename, "r") as f: if "samples" not in f.keys(): raise ValueError( @@ -28,10 +28,10 @@ def load_HDF5( dataset = np.array(f["samples"][()]) if "variable_type" not in f.keys(): print( - f"No variable_type found in the hdf5 file keys: {f.keys()}. Assuming 'binary'." + f"No variable_type found in the hdf5 file keys: {f.keys()}. Assuming 'bernoulli'." 
) print( - "Set a 'variable_type' with value 'binary', 'categorical' or 'continuous' in the hdf5 archive to remove this message" + "Set a 'variable_type' with value 'bernoulli', 'ising', 'categorical' or 'continuous' in the hdf5 archive to remove this message" ) else: variable_type = f["variable_type"][()].decode() From 7607287163c5edd8173b7fad1f2d4708aa8ee7ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Tue, 13 Jan 2026 17:59:43 +0100 Subject: [PATCH 17/43] add visible_type --- rbms/ising_ising/classes.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/rbms/ising_ising/classes.py b/rbms/ising_ising/classes.py index 068ceb0..f2d6419 100644 --- a/rbms/ising_ising/classes.py +++ b/rbms/ising_ising/classes.py @@ -21,6 +21,8 @@ class IIRBM(RBM): """Parameters of the Ising-Ising RBM""" + visible_type: str = "ising" + def __init__( self, weight_matrix: Tensor, From 152803e727c7bfb7b05ea6de3415b92f631edc51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Tue, 13 Jan 2026 18:00:28 +0100 Subject: [PATCH 18/43] add categorical_to_bernoulli implementation --- rbms/dataset/utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/rbms/dataset/utils.py b/rbms/dataset/utils.py index e8d7f64..ea7463b 100644 --- a/rbms/dataset/utils.py +++ b/rbms/dataset/utils.py @@ -1,6 +1,7 @@ import numpy as np import torch from torch import Tensor +from rbms.custom_fn import one_hot def get_subset_labels( @@ -47,25 +48,24 @@ def ising_to_bernoulli(x): return (x + 1) / 2 -def bernoulli_to_categorical(x): - pass def categorical_to_bernoulli(x): - pass + return one_hot(x.long()).reshape(x.shape[0], -1) + convert_data = { "bernoulli": { "bernoulli": (lambda x: x), "ising": (lambda x: bernoulli_to_ising(x)), - "categorical": (lambda x: bernoulli_to_categorical(x)), + "categorical": (lambda x: x), # "continuous": lambda x: raise ValueError("Cannot convert from 'bernoulli' to 'continuous' data.") }, "ising": { "bernoulli": (lambda x: ising_to_bernoulli(x)), "ising": (lambda x: x), - "categorical": (lambda x: bernoulli_to_categorical(ising_to_bernoulli(x))), + "categorical": (lambda x: ising_to_bernoulli(x)), }, "categorical": { "bernoulli": (lambda x: categorical_to_bernoulli(x)), From 878f14e4dab3467f01061d486187c1a374386cf2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Tue, 13 Jan 2026 18:01:01 +0100 Subject: [PATCH 19/43] fix variable_type --- rbms/scripts/train_rbm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rbms/scripts/train_rbm.py b/rbms/scripts/train_rbm.py index f2ac590..01449c9 100644 --- a/rbms/scripts/train_rbm.py +++ b/rbms/scripts/train_rbm.py @@ -55,7 +55,7 @@ def train_rbm(args: dict): else: model_type = args["model_type"] if model_type is None: - match train_dataset.visible_type: + match train_dataset.variable_type: case "binary": model_type = "BBRBM" case "categorical": From 3520807279a13bedf0b6aa49905fe85e549f5e3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Tue, 13 Jan 2026 18:01:50 +0100 Subject: [PATCH 20/43] match dataset variable type with model visible type --- rbms/training/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/rbms/training/utils.py b/rbms/training/utils.py index 1b459b2..beb61ca 100644 --- a/rbms/training/utils.py +++ b/rbms/training/utils.py @@ -92,6 +92,8 @@ def setup_training( # Start recording training time start = time.time() + train_dataset.match_model_variable_type(params.visible_type) + 
test_dataset.match_model_variable_type(params.visible_type) return ( params, parallel_chains, @@ -212,6 +214,7 @@ def initialize_model_archive( train_dataset, _ = train_dataset.split_train_test( rng, args["train_size"], args["test_size"] ) + train_dataset.match_model_variable_type(visible_type=map_model[model_type].visible_type) params = map_model[model_type].init_parameters( num_hiddens=args["num_hiddens"], dataset=train_dataset, From 1c75c62c9dce0097b21d50d650ab1dad74833cea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Wed, 14 Jan 2026 14:59:58 +0100 Subject: [PATCH 21/43] sample bernoulli when variable_type is bernoulli --- rbms/dataset/dataset_class.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/rbms/dataset/dataset_class.py b/rbms/dataset/dataset_class.py index d4fb9d0..c46bda8 100644 --- a/rbms/dataset/dataset_class.py +++ b/rbms/dataset/dataset_class.py @@ -177,4 +177,10 @@ def split_train_test( def batch(self, batch_size: int) -> dict[str, Union[np.ndarray, torch.Tensor]]: rand_idx = torch.randperm(len(self)) - return self[rand_idx[:batch_size]] \ No newline at end of file + sampled_batch = self[rand_idx[:batch_size]] + match self.variable_type: + case "bernoulli": + sampled_batch["data"] = torch.bernoulli(sampled_batch["data"]) + case _: + pass + return sampled_batch \ No newline at end of file From fc19aef30b0844c6f7262a07584d41244113ca7a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Wed, 14 Jan 2026 15:00:12 +0100 Subject: [PATCH 22/43] add log_scale option to PCA plot --- rbms/plot.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/rbms/plot.py b/rbms/plot.py index f6d4c0c..4d1402e 100644 --- a/rbms/plot.py +++ b/rbms/plot.py @@ -151,6 +151,7 @@ def plot_one_PCA( labels: list[str] | None = None, dir1: int = 0, dir2: int = 1, + log_scale: bool = False, ): label_1 = None label_2 = None @@ -238,6 +239,9 @@ def plot_one_PCA( orientation="horizontal", lw=1, ) + if log_scale: + ax_hist_x.semilogy() + ax_hist_y.semilogx() if labels is not None: ax_hist_x.legend(fontsize=12, bbox_to_anchor=(1, 1)) @@ -247,7 +251,8 @@ def plot_mult_PCA( data2: np.ndarray | None = None, labels: list[str] | None = None, n_dir: int = 2, - figsize_factor=4 + figsize_factor=4, + log_scale: bool = False, ): if data2 is not None: if data2.shape[1] < data1.shape[1]: @@ -284,6 +289,7 @@ def plot_mult_PCA( labels=labels if curr_plot_idx == 0 else None, dir1=curr_plot_idx * 2, dir2=curr_plot_idx * 2 + 1, + log_scale=log_scale, ) else: ax[*indexes].set_axis_off() From 2ca6df6fbabd97fbfece5597e43822d6323aac91 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Tue, 27 Jan 2026 16:41:34 +0100 Subject: [PATCH 23/43] removed unused variable in non centered gradient --- rbms/bernoulli_bernoulli/implement.py | 8 +------ rbms/bernoulli_gaussian/implement.py | 8 +------ rbms/ising_ising/implement.py | 8 +------ rbms/potts_bernoulli/implement.py | 34 +++++++++++++++++++-------- 4 files changed, 27 insertions(+), 31 deletions(-) diff --git a/rbms/bernoulli_bernoulli/implement.py b/rbms/bernoulli_bernoulli/implement.py index a2f4d2a..f9b5ca1 100644 --- a/rbms/bernoulli_bernoulli/implement.py +++ b/rbms/bernoulli_bernoulli/implement.py @@ -1,6 +1,5 @@ import torch from torch import Tensor -from torch.nn.functional import softmax @torch.jit.script @@ -77,7 +76,7 @@ def _compute_gradient( w_data = w_data.view(-1, 1) w_chain = w_chain.view(-1, 1) # Turn the weights of the chains into normalized weights - 
chain_weights = softmax(-w_chain, dim=0) + chain_weights = w_chain / w_chain.sum() w_data_norm = w_data.sum() # Averages over data and generated samples @@ -102,11 +101,6 @@ def _compute_gradient( grad_vbias = v_data_mean - v_gen_mean - (grad_weight_matrix @ h_data_mean) grad_hbias = h_data_mean - h_gen_mean - (v_data_mean @ grad_weight_matrix) else: - v_data_centered = v_data - h_data_centered = mh_data - v_gen_centered = v_chain - h_gen_centered = h_chain - # Gradient grad_weight_matrix = ((v_data * w_data).T @ mh_data) / w_data_norm - ( (v_chain * chain_weights).T @ h_chain diff --git a/rbms/bernoulli_gaussian/implement.py b/rbms/bernoulli_gaussian/implement.py index a6c81d3..b12e472 100644 --- a/rbms/bernoulli_gaussian/implement.py +++ b/rbms/bernoulli_gaussian/implement.py @@ -1,6 +1,5 @@ import torch from torch import Tensor -from torch.nn.functional import softmax @torch.jit.script @@ -84,7 +83,7 @@ def _compute_gradient( ) -> None: w_data = w_data.view(-1, 1) w_chain = w_chain.view(-1, 1) - chain_weights = softmax(-w_chain, dim=0) + chain_weights = w_chain / w_chain.sum() w_data_norm = w_data.sum() v_data_mean = (v_data * w_data).sum(0) / w_data_norm @@ -108,11 +107,6 @@ def _compute_gradient( grad_vbias = v_data_mean - v_gen_mean - (grad_weight_matrix @ h_data_mean) grad_hbias = h_data_mean - h_gen_mean - (v_data_mean @ grad_weight_matrix) else: - v_data_centered = v_data - h_data_centered = h_data - v_gen_centered = v_chain - h_gen_centered = h_chain - # Gradient: h_data instead of mh_data grad_weight_matrix = ((v_data * w_data).T @ h_data) / w_data_norm - ( (v_chain * chain_weights).T @ h_chain diff --git a/rbms/ising_ising/implement.py b/rbms/ising_ising/implement.py index 9e0c5e4..3bc3187 100644 --- a/rbms/ising_ising/implement.py +++ b/rbms/ising_ising/implement.py @@ -1,6 +1,5 @@ import torch from torch import Tensor -from torch.nn.functional import softmax from rbms.custom_fn import log2cosh @@ -81,7 +80,7 @@ def _compute_gradient( w_data = w_data.view(-1, 1) w_chain = w_chain.view(-1, 1) # Turn the weights of the chains into normalized weights - chain_weights = softmax(-w_chain, dim=0) + chain_weights = w_chain / w_chain.sum() w_data_norm = w_data.sum() # Averages over data and generated samples @@ -106,11 +105,6 @@ def _compute_gradient( grad_vbias = v_data_mean - v_gen_mean - (grad_weight_matrix @ h_data_mean) grad_hbias = h_data_mean - h_gen_mean - (v_data_mean @ grad_weight_matrix) else: - v_data_centered = v_data - h_data_centered = mh_data - v_gen_centered = v_chain - h_gen_centered = h_chain - # Gradient grad_weight_matrix = ((v_data * w_data).T @ mh_data) / w_data_norm - ( (v_chain * chain_weights).T @ h_chain diff --git a/rbms/potts_bernoulli/implement.py b/rbms/potts_bernoulli/implement.py index 6b641aa..7191f8f 100644 --- a/rbms/potts_bernoulli/implement.py +++ b/rbms/potts_bernoulli/implement.py @@ -150,22 +150,17 @@ def _compute_gradient( - torch.tensordot(v_data_mean, grad_weight_matrix, dims=[[0, 1], [0, 1]]) ) else: - v_data_centered = v_data_one_hot - h_data_centered = mh_data - v_gen_centered = v_gen_one_hot - h_gen_centered = h_chain - # Gradient grad_weight_matrix = ( torch.tensordot( - v_data_centered, - h_data_centered, + v_data_one_hot, + mh_data, dims=[[0], [0]], ) / v_data.shape[0] - torch.tensordot( - v_gen_centered, - h_gen_centered, + v_gen_one_hot, + h_chain, dims=[[0], [0]], ) / v_chain.shape[0] @@ -218,6 +213,7 @@ def _init_chains( def _init_parameters( num_hiddens: int, data: Tensor, + weights: Tensor, device: torch.device, dtype: 
torch.dtype, var_init: float = 1e-4, @@ -240,5 +236,23 @@ def _init_parameters( ) * var_init ) + U, S, V = torch.svd(weight_matrix.reshape(num_visibles * num_states, num_hiddens)) + # print(S.shape) + from rbms.potts_bernoulli.tools import get_covariance_matrix + + data_oh = ( + torch.eye(num_states, device=device)[data.long()] + .float() + .reshape(-1, num_states * num_visibles) + ) + cov_data = torch.tensor( + get_covariance_matrix(data_oh, weights, device=device), device=device + ).float() + U_data, S_data, V_data = torch.svd(cov_data) + weight_matrix = ( + V_data.T[:, : min(num_hiddens, num_visibles * num_states)] @ torch.diag(S) @ V + ).reshape(num_visibles, num_states, num_hiddens) + # print(torch.svd(weight_matrix.reshape(-1, weight_matrix.shape[-1])).S) - return vbias, hbias, weight_matrix + beta = 1.0 + return beta * vbias, beta * hbias, beta * weight_matrix From 1f4d5438469902e8a8c2b4bab9e207f02c245659 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Tue, 27 Jan 2026 16:42:22 +0100 Subject: [PATCH 24/43] add conversion print + astype to dataset class --- rbms/dataset/dataset_class.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/rbms/dataset/dataset_class.py b/rbms/dataset/dataset_class.py index c46bda8..2b87843 100644 --- a/rbms/dataset/dataset_class.py +++ b/rbms/dataset/dataset_class.py @@ -126,8 +126,14 @@ def get_gzip_entropy(self, mean_size: int = 50, num_samples: int = 100): def match_model_variable_type(self, visible_type: str): self.data = convert_data[self.variable_type][visible_type](self.data) + if self.variable_type != visible_type: + print(f"Converting from '{self.variable_type}' to '{visible_type}'") + print(self.data) self.variable_type = visible_type + def astype(self, target_variable_type: str): + return convert_data[self.variable_type][target_variable_type](self.data) + def split_train_test( self, rng: np.random.Generator, From 60c42f62104977f91d994443e05f775b4bfac7a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Tue, 27 Jan 2026 16:43:06 +0100 Subject: [PATCH 25/43] add __eq__ to class for easier comparison --- rbms/classes.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/rbms/classes.py b/rbms/classes.py index 0e1776f..59ae374 100644 --- a/rbms/classes.py +++ b/rbms/classes.py @@ -29,6 +29,14 @@ def __mul__(self, other: float) -> EBM: """Multiplies the parameters of the RBM by a float.""" ... 
+ def __eq__(self, other: EBM): + other_params = other.named_parameters() + for k, v in self.named_parameters().items(): + if not torch.equal(other_params[k], v): + return False + return True + + @abstractmethod def sample_visibles( self, chains: dict[str, Tensor], beta: float = 1.0 @@ -217,7 +225,8 @@ def normalize_grad(self) -> None: ) for p in self.parameters(): p.grad /= norm_grad - + # for p in self.parameters(): + # p.grad /= p.grad.norm() class RBM(EBM): """An abstract class representing the parameters of a RBM.""" From 852725d2242cf18867aeb62c944096936385fbf5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Tue, 27 Jan 2026 16:43:28 +0100 Subject: [PATCH 26/43] add IIRBM and BGRBM to map_model --- rbms/map_model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/rbms/map_model.py b/rbms/map_model.py index f9d5d2f..d2f9f8f 100644 --- a/rbms/map_model.py +++ b/rbms/map_model.py @@ -1,5 +1,7 @@ from rbms.bernoulli_bernoulli.classes import BBRBM from rbms.classes import EBM from rbms.potts_bernoulli.classes import PBRBM +from rbms.ising_ising.classes import IIRBM +from rbms.bernoulli_gaussian.classes import BGRBM -map_model: dict[str, EBM] = {"BBRBM": BBRBM, "PBRBM": PBRBM} +map_model: dict[str, EBM] = {"BBRBM": BBRBM, "PBRBM": PBRBM, "IIRBM": IIRBM, "BGRBM": BGRBM} From 1cfd16e8a75fead99a64308a74fff541d0b50c87 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Tue, 27 Jan 2026 16:43:55 +0100 Subject: [PATCH 27/43] add model_type and normalize_grad option to parser --- rbms/parser.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/rbms/parser.py b/rbms/parser.py index 320df5d..101df5b 100644 --- a/rbms/parser.py +++ b/rbms/parser.py @@ -143,10 +143,21 @@ def add_args_rbm(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: rbm_args.add_argument( "--training_type", type=str, - default = "pcd", - help="(Defaults to 'pcd'). Type of the training, should be one of {'pcd', 'cd', 'rdm'}." + default="pcd", + help="(Defaults to 'pcd'). Type of the training, should be one of {'pcd', 'cd', 'rdm'}.", + ) + rbm_args.add_argument( + "--normalize_grad", + default=False, + action="store_true", + help="(Defaults to False). Normalize the gradient before update.", + ) + rbm_args.add_argument( + "--model_type", + type=str, + default=None, + help="(Defaults to None). Model to use. If None is provided, will be a RBM with the same visible type as the dataset and binary hiddens. If restore, this argument is ignored.", ) - rbm_args.add_argument("--model_type", type=str, default=None, help="(Defaults to None). Model to use. If None is provided, will be a RBM with the same visible type as the dataset and binary hiddens. 
If restore, this argument is ignored.") return parser From 8c69c837a9264175bb4e6f70b0b9c1705703e7f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Tue, 27 Jan 2026 16:44:40 +0100 Subject: [PATCH 28/43] add dataset weights arg --- rbms/potts_bernoulli/classes.py | 1 + 1 file changed, 1 insertion(+) diff --git a/rbms/potts_bernoulli/classes.py b/rbms/potts_bernoulli/classes.py index a767dae..36dbaf9 100644 --- a/rbms/potts_bernoulli/classes.py +++ b/rbms/potts_bernoulli/classes.py @@ -156,6 +156,7 @@ def init_parameters(num_hiddens, dataset, device, dtype, var_init=0.0001): vbias, hbias, weight_matrix = _init_parameters( num_hiddens=num_hiddens, data=data, + weights=dataset.weights, device=device, dtype=dtype, var_init=var_init, From da1035cbc4b5fe58fbdc3b11670910b552250e9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Tue, 27 Jan 2026 16:45:30 +0100 Subject: [PATCH 29/43] fix binary to bernoulli and add ising to model match --- rbms/scripts/train_rbm.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/rbms/scripts/train_rbm.py b/rbms/scripts/train_rbm.py index 01449c9..43e4d5a 100644 --- a/rbms/scripts/train_rbm.py +++ b/rbms/scripts/train_rbm.py @@ -56,10 +56,12 @@ def train_rbm(args: dict): model_type = args["model_type"] if model_type is None: match train_dataset.variable_type: - case "binary": + case "bernoulli": model_type = "BBRBM" case "categorical": model_type = "PBRBM" + case "ising": + model_type = "IIRBM" case _: raise NotImplementedError() print(model_type) From c654f8f06aff49fd2b94e3c0ae6f5b03a3970e92 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Tue, 27 Jan 2026 16:45:56 +0100 Subject: [PATCH 30/43] make normalize_grad optional --- rbms/training/pcd.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/rbms/training/pcd.py b/rbms/training/pcd.py index a88a3ea..c23c8ec 100644 --- a/rbms/training/pcd.py +++ b/rbms/training/pcd.py @@ -6,7 +6,7 @@ from torch import Tensor from torch.optim import SGD, Optimizer -from rbms.classes import EBM +from rbms.classes import EBM, RBM from rbms.dataset.dataset_class import RBMDataset from rbms.io import save_model from rbms.map_model import map_model @@ -26,6 +26,7 @@ def fit_batch_pcd( centered: bool = True, lambda_l1: float = 0.0, lambda_l2: float = 0.0, + normalize_grad: bool = True, ) -> tuple[dict[str, Tensor], dict]: """Sample the EBM and compute the gradient. 
@@ -57,7 +58,8 @@ def fit_batch_pcd( lambda_l1=lambda_l1, lambda_l2=lambda_l2, ) - params.normalize_grad() + if normalize_grad: + params.normalize_grad() logs = {} return parallel_chains, logs @@ -128,7 +130,6 @@ def train( ) update_lr = False warmup = True - from rbms.classes import RBM # Continue the training with torch.no_grad(): @@ -141,7 +142,9 @@ def train( parallel_chains = params.init_chains(parallel_chains["visible"].shape[0]) elif args["training_type"] == "cd": parallel_chains = params.init_chains( - batch["data"].shape[0], weights=batch["weights"], start_v=batch["data"] + batch["data"].shape[0], + weights=batch["weights"], + start_v=batch["data"], ) if warmup and isinstance(params, RBM): @@ -165,6 +168,7 @@ def train( centered=not (args["no_center"]), lambda_l1=args["L1"], lambda_l2=args["L2"], + normalize_grad=args["normalize_grad"], ) if update_lr: optimizer.step(update_lr=update_lr) From 8365a700cefc1b86a9815c3a80079dd797b5e5df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Tue, 27 Jan 2026 16:46:38 +0100 Subject: [PATCH 31/43] save result from get_eigenvalues_history in file to avoid repeating computations --- rbms/utils.py | 55 +++++++++++++++++++++++++++------------------------ 1 file changed, 29 insertions(+), 26 deletions(-) diff --git a/rbms/utils.py b/rbms/utils.py index 58cc912..01e5cad 100644 --- a/rbms/utils.py +++ b/rbms/utils.py @@ -25,33 +25,36 @@ def get_eigenvalues_history(filename: str, backend="cpu"): - gradient_updates (np.ndarray): Array of gradient update steps. - eigenvalues (np.ndarray): Eigenvalues along training. """ - with h5py.File(filename, "r") as f: - gradient_updates = [] - eigenvalues = [] - for key in f.keys(): - if "update_" in key: - weight_matrix = f[key]["params"]["weight_matrix"][()] - weight_matrix = weight_matrix.reshape(-1, weight_matrix.shape[-1]) - if backend == "gpu": - eig = ( - torch.svd( - torch.from_numpy(weight_matrix).to(device="cuda"), - compute_uv=False, - ) - .S.cpu() - .numpy() + saved_updates = get_saved_updates(filename) + eigenvalues = [] + for upd in saved_updates: + compute = False + with h5py.File(filename, "a") as f: + if "singular_values" not in f[f"update_{upd}"]: + compute = True + weight_matrix = f[f"update_{upd}"]["params"]["weight_matrix"][()] + + if compute: + weight_matrix = weight_matrix.reshape(-1, weight_matrix.shape[-1]) + if backend == "gpu": + eig = ( + torch.svd( + torch.from_numpy(weight_matrix).to(device="cuda"), + compute_uv=False, ) - else: - eig = np.linalg.svd(weight_matrix, compute_uv=False) - eigenvalues.append(eig.reshape(*eig.shape, 1)) - gradient_updates.append(int(key.split("_")[1])) - - # Sort the results - sorting = np.argsort(gradient_updates) - gradient_updates = np.array(gradient_updates)[sorting] - eigenvalues = np.array(np.hstack(eigenvalues).T)[sorting] - - return gradient_updates, eigenvalues + .S.cpu() + .numpy() + ) + else: + eig = np.linalg.svd(weight_matrix, compute_uv=False) + with h5py.File(filename, "a") as f: + f[f"update_{upd}"]["singular_values"] = eig + + with h5py.File(filename, "a") as f: + eig = f[f"update_{upd}"]["singular_values"][()] + eigenvalues.append(eig.reshape(*eig.shape, 1)) + eigenvalues = np.array(np.hstack(eigenvalues).T) + return saved_updates, eigenvalues def get_saved_updates(filename: str) -> np.ndarray: From 1a3682ba2faae7497b0e4f858120621078b8c81c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Tue, 27 Jan 2026 16:46:49 +0100 Subject: [PATCH 32/43] change version number --- pyproject.toml | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 1dffc4c..d0df24e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "rbms" -version = "0.5.0" +version = "0.5.1" authors = [ {name="Nicolas Béreux", email="nicolas.bereux@gmail.com"}, {name="Aurélien Decelle"}, From ecbfb089d05f9da83d766408034659ea7e6234b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Tue, 27 Jan 2026 17:47:20 +0100 Subject: [PATCH 33/43] simplify imports --- rbms/__init__.py | 42 ++++++++++++++++++++++ rbms/bernoulli_bernoulli/__init__.py | 11 +++++- rbms/dataset/utils.py | 51 ++++++++++++++++++++++++-- rbms/ising_ising/__init__.py | 11 +++++- rbms/potts_bernoulli/__init__.py | 13 +++++++ rbms/potts_bernoulli/implement.py | 30 ++++++++-------- rbms/potts_bernoulli/tools.py | 53 ---------------------------- 7 files changed, 138 insertions(+), 73 deletions(-) delete mode 100644 rbms/potts_bernoulli/tools.py diff --git a/rbms/__init__.py b/rbms/__init__.py index e69de29..0169025 100644 --- a/rbms/__init__.py +++ b/rbms/__init__.py @@ -0,0 +1,42 @@ +from rbms.bernoulli_bernoulli.classes import BBRBM +from rbms.bernoulli_gaussian.classes import BGRBM +from rbms.dataset import load_dataset +from rbms.dataset.utils import convert_data +from rbms.io import load_model, load_params +from rbms.ising_ising.classes import IIRBM +from rbms.map_model import map_model +from rbms.plot import plot_image, plot_mult_PCA +from rbms.potts_bernoulli.classes import PBRBM +from rbms.utils import ( + bernoulli_to_ising, + compute_log_likelihood, + get_categorical_configurations, + get_eigenvalues_history, + get_flagged_updates, + get_saved_updates, + ising_to_bernoulli, +) + +__all__ = [ + BBRBM, + BGRBM, + IIRBM, + PBRBM, + map_model, + bernoulli_to_ising, + ising_to_bernoulli, + compute_log_likelihood, + get_eigenvalues_history, + get_saved_updates, + get_flagged_updates, + get_categorical_configurations, + plot_mult_PCA, + plot_image, + load_params, + load_model, + load_dataset, + convert_data, +] + + +__version__ = "0.5.1" diff --git a/rbms/bernoulli_bernoulli/__init__.py b/rbms/bernoulli_bernoulli/__init__.py index 25d1da7..b2248e0 100644 --- a/rbms/bernoulli_bernoulli/__init__.py +++ b/rbms/bernoulli_bernoulli/__init__.py @@ -1,3 +1,12 @@ # ruff: noqa from rbms.bernoulli_bernoulli.classes import BBRBM -from rbms.bernoulli_bernoulli.functional import * +from rbms.bernoulli_bernoulli.functional import ( + compute_energy, + compute_energy_hiddens, + compute_energy_visibles, + compute_gradient, + init_chains, + init_parameters, + sample_hiddens, + sample_visibles, +) diff --git a/rbms/dataset/utils.py b/rbms/dataset/utils.py index ea7463b..d8f8bf8 100644 --- a/rbms/dataset/utils.py +++ b/rbms/dataset/utils.py @@ -1,6 +1,7 @@ import numpy as np import torch from torch import Tensor + from rbms.custom_fn import one_hot @@ -48,13 +49,10 @@ def ising_to_bernoulli(x): return (x + 1) / 2 - - def categorical_to_bernoulli(x): return one_hot(x.long()).reshape(x.shape[0], -1) - convert_data = { "bernoulli": { "bernoulli": (lambda x: x), @@ -73,3 +71,50 @@ def categorical_to_bernoulli(x): "categorical": (lambda x: x), }, } + + +def get_covariance_matrix( + data: Tensor, + weights: Tensor | None = None, + num_extract: int | None = None, + center: bool = True, + device: torch.device = torch.device("cpu"), +) -> Tensor: + """Returns the covariance matrix of the data. 
If weights is specified, the weighted covariance matrix is computed. + + Args: + data (Tensor): Data. + weights (Tensor, optional): Weights of the data. Defaults to None. + num_extract (int, optional): Number of data to extract to compute the covariance matrix. Defaults to None. + center (bool): Center the data. Defaults to True. + device (torch.device): Device. Defaults to 'cpu'. + dtype (torch.dtype): DType. Defaults to torch.float32. + + Returns: + Tensor: Covariance matrix of the dataset. + """ + num_data = len(data) + num_classes = int(data.max().item() + 1) + + if weights is None: + weights = torch.ones(num_data) + weights = weights.to(device=device, dtype=torch.float32) + + if num_extract is not None: + idxs = np.random.choice(a=np.arange(num_data), size=(num_extract,), replace=False) + data = data[idxs] + weights = weights[idxs] + num_data = num_extract + + if num_classes != 2: + data = data.to(device=device, dtype=torch.int32) + data_oh = one_hot(data, num_classes=num_classes).reshape(num_data, -1) + else: + data_oh = data.to(device=device, dtype=torch.float32) + + norm_weights = weights.reshape(-1, 1) / weights.sum() + data_mean = (data_oh * norm_weights).sum(0, keepdim=True) + cov_matrix = ((data_oh * norm_weights).mT @ data_oh) - int(center) * ( + data_mean.mT @ data_mean + ) + return cov_matrix diff --git a/rbms/ising_ising/__init__.py b/rbms/ising_ising/__init__.py index d467f56..39fa13e 100644 --- a/rbms/ising_ising/__init__.py +++ b/rbms/ising_ising/__init__.py @@ -1,3 +1,12 @@ # ruff: noqa from rbms.ising_ising.classes import IIRBM -from rbms.ising_ising.functional import * +from rbms.ising_ising.functional import ( + compute_energy, + compute_energy_hiddens, + compute_energy_visibles, + compute_gradient, + init_chains, + init_parameters, + sample_hiddens, + sample_visibles, +) diff --git a/rbms/potts_bernoulli/__init__.py b/rbms/potts_bernoulli/__init__.py index e69de29..3fd8ebb 100644 --- a/rbms/potts_bernoulli/__init__.py +++ b/rbms/potts_bernoulli/__init__.py @@ -0,0 +1,13 @@ +# ruff: noqa +from rbms.potts_bernoulli.classes import PBRBM +from rbms.potts_bernoulli.functional import ( + compute_energy, + compute_energy_hiddens, + compute_energy_visibles, + compute_gradient, + init_chains, + init_parameters, + sample_hiddens, + sample_visibles, +) +from rbms.potts_bernoulli.utils import ensure_zero_sum_gauge diff --git a/rbms/potts_bernoulli/implement.py b/rbms/potts_bernoulli/implement.py index 7191f8f..9c491e1 100644 --- a/rbms/potts_bernoulli/implement.py +++ b/rbms/potts_bernoulli/implement.py @@ -236,22 +236,22 @@ def _init_parameters( ) * var_init ) - U, S, V = torch.svd(weight_matrix.reshape(num_visibles * num_states, num_hiddens)) - # print(S.shape) - from rbms.potts_bernoulli.tools import get_covariance_matrix + # U, S, V = torch.svd(weight_matrix.reshape(num_visibles * num_states, num_hiddens)) + # # print(S.shape) + # from rbms.potts_bernoulli.tools import get_covariance_matrix - data_oh = ( - torch.eye(num_states, device=device)[data.long()] - .float() - .reshape(-1, num_states * num_visibles) - ) - cov_data = torch.tensor( - get_covariance_matrix(data_oh, weights, device=device), device=device - ).float() - U_data, S_data, V_data = torch.svd(cov_data) - weight_matrix = ( - V_data.T[:, : min(num_hiddens, num_visibles * num_states)] @ torch.diag(S) @ V - ).reshape(num_visibles, num_states, num_hiddens) + # data_oh = ( + # torch.eye(num_states, device=device)[data.long()] + # .float() + # .reshape(-1, num_states * num_visibles) + # ) + # cov_data = torch.tensor( 
+ # get_covariance_matrix(data_oh, weights, device=device), device=device + # ).float() + # U_data, S_data, V_data = torch.svd(cov_data) + # weight_matrix = ( + # V_data.T[:, : min(num_hiddens, num_visibles * num_states)] @ torch.diag(S) @ V + # ).reshape(num_visibles, num_states, num_hiddens) # print(torch.svd(weight_matrix.reshape(-1, weight_matrix.shape[-1])).S) beta = 1.0 diff --git a/rbms/potts_bernoulli/tools.py b/rbms/potts_bernoulli/tools.py deleted file mode 100644 index 482da43..0000000 --- a/rbms/potts_bernoulli/tools.py +++ /dev/null @@ -1,53 +0,0 @@ -import numpy as np -import torch -from torch import Tensor - -from rbms.custom_fn import one_hot - - -def get_covariance_matrix( - data: Tensor, - weights: Tensor | None = None, - num_extract: int | None = None, - center: bool = True, - device: torch.device = torch.device("cpu"), - dtype: torch.dtype = torch.float32, -) -> Tensor: - """Returns the covariance matrix of the data. If weights is specified, the weighted covariance matrix is computed. - - Args: - data (Tensor): Data. - weights (Tensor, optional): Weights of the data. Defaults to None. - num_extract (int, optional): Number of data to extract to compute the covariance matrix. Defaults to None. - center (bool): Center the data. Defaults to True. - device (torch.device): Device. Defaults to 'cpu'. - dtype (torch.dtype): DType. Defaults to torch.float32. - - Returns: - Tensor: Covariance matrix of the dataset. - """ - num_data = len(data) - num_classes = int(data.max().item() + 1) - - if weights is None: - weights = torch.ones(num_data) - weights = weights.to(device=device, dtype=torch.float32) - - if num_extract is not None: - idxs = np.random.choice(a=np.arange(num_data), size=(num_extract,), replace=False) - data = data[idxs] - weights = weights[idxs] - num_data = num_extract - - if num_classes != 2: - data = data.to(device=device, dtype=torch.int32) - data_oh = one_hot(data, num_classes=num_classes).reshape(num_data, -1) - else: - data_oh = data.to(device=device, dtype=torch.float32) - - norm_weights = weights.reshape(-1, 1) / weights.sum() - data_mean = (data_oh * norm_weights).sum(0, keepdim=True) - cov_matrix = ((data_oh * norm_weights).mT @ data_oh) - int(center) * ( - data_mean.mT @ data_mean - ) - return cov_matrix From 9da903c546b677036e30979646313b9d2d17816b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Thu, 29 Jan 2026 13:09:43 +0100 Subject: [PATCH 34/43] fix: add __init__ to bernoulli_gaussian --- rbms/bernoulli_gaussian/__init__.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 rbms/bernoulli_gaussian/__init__.py diff --git a/rbms/bernoulli_gaussian/__init__.py b/rbms/bernoulli_gaussian/__init__.py new file mode 100644 index 0000000..3743484 --- /dev/null +++ b/rbms/bernoulli_gaussian/__init__.py @@ -0,0 +1,12 @@ +# ruff: noqa +from rbms.bernoulli_gaussian.classes import BGRBM +from rbms.bernoulli_gaussian.functional import ( + compute_energy, + compute_energy_hiddens, + compute_energy_visibles, + compute_gradient, + init_chains, + init_parameters, + sample_hiddens, + sample_visibles, +) From bdf553e5574a053af358a02d675be01efd31ca06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Wed, 4 Feb 2026 23:53:43 +0100 Subject: [PATCH 35/43] clip grad --- rbms/classes.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/rbms/classes.py b/rbms/classes.py index 59ae374..ec44d83 100644 --- a/rbms/classes.py +++ b/rbms/classes.py @@ -34,8 +34,7 @@ def __eq__(self, other: 
EBM): for k, v in self.named_parameters().items(): if not torch.equal(other_params[k], v): return False - return True - + return True @abstractmethod def sample_visibles( @@ -228,6 +227,14 @@ def normalize_grad(self) -> None: # for p in self.parameters(): # p.grad /= p.grad.norm() + def clip_grad(self, max_norm=5): + for p in self.parameters(): + grad_norm = p.grad.norm() + if grad_norm > max_norm: + p.grad /= grad_norm + p.grad *= max_norm + + class RBM(EBM): """An abstract class representing the parameters of a RBM.""" From dc54a162cba766c95dcead4cc08d650f54d2d6a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Thu, 5 Feb 2026 00:28:13 +0100 Subject: [PATCH 36/43] rework the main loop and add rbms restore script allowing to change many hyperparameters --- rbms/scripts/entrypoint.py | 11 +- rbms/scripts/restore.py | 211 +++++++++++++++++++ rbms/scripts/train_rbm.py | 173 +++++++++++----- rbms/training/pcd.py | 232 +++++++++------------ rbms/training/utils.py | 402 ++++++++++++++++++++----------------- 5 files changed, 651 insertions(+), 378 deletions(-) create mode 100644 rbms/scripts/restore.py diff --git a/rbms/scripts/entrypoint.py b/rbms/scripts/entrypoint.py index e3fea49..df685cb 100644 --- a/rbms/scripts/entrypoint.py +++ b/rbms/scripts/entrypoint.py @@ -4,13 +4,12 @@ def main(): - # Get the directory of the current script SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) # Check if the first positional argument is provided if len(sys.argv) < 2: - print("Error: No command provided. Use 'train' or 'pt_sampling'.") + print("Error: No command provided. Use 'train', 'restore', 'split' or 'pt_sampling'.") sys.exit(1) # Assign the first positional argument to a variable @@ -23,9 +22,13 @@ def main(): case "pt_sampling": SCRIPT = "pt_sampling.py" case "split": - SCRIPT = "split_data.py" + SCRIPT = "split_data.py" + case "restore": + SCRIPT = "restore.py" case _: - print(f"Error: Invalid command '{COMMAND}'. Use 'train', 'split' or 'pt_sampling'.") + print( + f"Error: Invalid command '{COMMAND}'. Use 'train', 'restore', 'split' or 'pt_sampling'." + ) sys.exit(1) # Run the corresponding Python script with the remaining optional arguments diff --git a/rbms/scripts/restore.py b/rbms/scripts/restore.py new file mode 100644 index 0000000..7cfd1ec --- /dev/null +++ b/rbms/scripts/restore.py @@ -0,0 +1,211 @@ +import argparse + +import h5py +import torch + +from rbms import get_saved_updates +from rbms.dataset import load_dataset +from rbms.map_model import map_model +from rbms.optim import setup_optim +from rbms.parser import ( + add_args_pytorch, + add_args_saves, + add_args_train, + add_grad_args, + add_sampling_args, + match_args_dtype, + remove_argument, +) +from rbms.training.pcd import train +from rbms.training.utils import get_checkpoints, restore_training + + +def create_parser_restore(): + parser = argparse.ArgumentParser( + description="Restore the training of a Restricted Boltzmann Machine" + ) + dataset_args = parser.add_argument_group("Dataset") + dataset_args.add_argument( + "-d", + "--dataset", + type=str, + required=True, + help="Path to a data file (type should be .h5 or .fasta)", + ) + dataset_args.add_argument( + "--test_dataset", + type=str, + required=False, + default=None, + help="Path to test dataset file (type should be .h5 or .fasta)", + ) + parser = add_args_train(parser) + parser = add_grad_args(parser) + parser.add_argument( + "--update", + default=None, + type=int, + help="(Defaults to None). Which update to restore from. 
If None, the last update is used.", + ) + remove_argument(parser, "no_center") + remove_argument(parser, "normalize_grad") + + parser = add_sampling_args(parser) + parser = add_args_saves(parser) + parser = add_args_pytorch(parser) + remove_argument(parser, "use_torch") + return parser + + +def recover_args( + args: dict, +) -> tuple[ + dict[str, str], + dict[str, int | float], + dict[str, bool | float], + dict[str, int | float], + dict[str, str | torch.dtype], +]: + with h5py.File(args["filename"], "r") as f: + # dataset + args_dataset = { + "dataset_name": args["dataset"], + "test_dataset_name": args["test_dataset"], + } + dataset = f["dataset_args"] + if "subset_labels" in dataset.keys(): + args_dataset["subset_labels"] = dataset["subset_labels"][()] + else: + args_dataset["subset_labels"] = None + args_dataset["train_size"] = dataset["train_size"][()].item() + args_dataset["test_size"] = dataset["test_size"][()].item() + + args_dataset["use_weights"] = dataset["use_weights"][()].item() + args_dataset["alphabet"] = dataset["alphabet"][()].decode() + args_dataset["remove_duplicates"] = dataset["remove_duplicates"][()].item() + args_dataset["seed"] = dataset["seed"][()].item() + + # grad + args_grad = {} + grad = f["grad_args"] + ## Default args + args_grad["no_center"] = grad["no_center"][()].item() + args_grad["normalize_grad"] = grad["normalize_grad"][()].item() + ## Can be overriden + args_grad["max_norm_grad"] = args["max_norm_grad"] + if args_grad["max_norm_grad"] is None: + args_grad["max_norm_grad"] = grad["max_norm_grad"][()].item() + args_grad["L1"] = args["L1"] + if args_grad["L1"] is None: + args_grad["L1"] = grad["L1"][()].item() + args_grad["L2"] = args["L2"] + if args_grad["L2"] is None: + args_grad["L2"] = grad["L2"][()].item() + + # sampling + args_sampling = {} + sampling = f["sampling_args"] + args_sampling["gibbs_steps"] = args["gibbs_steps"] + if args_sampling["gibbs_steps"] is None: + args_sampling["gibbs_steps"] = sampling["gibbs_steps"][()].item() + args_sampling["beta"] = args["beta"] + if args_sampling["beta"] is None: + args_sampling["beta"] = sampling["beta"][()].item() + + # train + args_train = {} + train_args = f["train_args"] + args_train["optim"] = args["optim"] + args_train["num_updates"] = args["num_updates"] + if args_train["optim"] is None: + args_train["optim"] = train_args["optim"][()].decode() + args_train["learning_rate"] = args["learning_rate"] + if args_train["learning_rate"] is None: + args_train["learning_rate"] = train_args["learning_rate"][()] + args_train["batch_size"] = args["batch_size"] + if args_train["batch_size"] is None: + args_train["batch_size"] = train_args["batch_size"][()].item() + args_train["update"] = args["update"] + if args_train["update"] is None: + args_train["update"] = get_saved_updates(args["filename"])[-1] + args_train["mult_optim"] = args["mult_optim"] + args_train["training_type"] = args["training_type"] + if args_train["training_type"] is None: + args_train["training_type"] = train_args["training_type"][()].decode() + + # Torch + args_torch = {} + args_torch["device"] = args["device"] + args_torch["dtype"] = args["dtype"] + + # save + args_save = {} + args_save["filename"] = args["filename"] + save = f["save_args"] + args_save["n_save"] = args["n_save"] + if args_save["n_save"] is None: + args_save["n_save"] = save["n_save"][()].item() + args_save["spacing"] = args["spacing"] + if args_save["spacing"] is None: + args_save["spacing"] = save["spacing"][()] + return (args_dataset, args_save, args_train, args_grad, 
args_sampling, args_torch) + + +def main(): + torch.backends.cudnn.benchmark = True + parser = create_parser_restore() + args = parser.parse_args() + args = vars(args) + args = match_args_dtype(args) + args_dataset, args_save, args_train, args_grad, args_sampling, args_torch = ( + recover_args(args) + ) + checkpoints = get_checkpoints( + num_updates=args_train["num_updates"], + n_save=args_save["n_save"], + spacing=args_save["spacing"], + ) + train_dataset, test_dataset = load_dataset( + dataset_name=args_dataset["dataset_name"], + test_dataset_name=args_dataset["test_dataset_name"], + subset_labels=args_dataset["subset_labels"], + use_weights=args_dataset["use_weights"], + alphabet=args_dataset["alphabet"], + remove_duplicates=args_dataset["remove_duplicates"], + **args_torch, + ) + ( + params, + parallel_chains, + target_update, + elapsed_time, + train_dataset, + test_dataset, + ) = restore_training( + train_dataset=train_dataset, + test_dataset=test_dataset, + args_save=args_save, + args_train=args_train, + args_dataset=args_dataset, + args_torch=args_torch, + map_model=map_model, + ) + optimizer = setup_optim(args_train["optim"], args_train, params) + train( + train_dataset=train_dataset, + test_dataset=test_dataset, + params=params, + parallel_chains=parallel_chains, + optimizer=optimizer, + curr_update=target_update, + elapsed_time=elapsed_time, + checkpoints=checkpoints, + args_save=args_save, + args_train=args_train, + args_grad=args_grad, + args_sampling=args_sampling, + ) + + +if __name__ == "__main__": + main() diff --git a/rbms/scripts/train_rbm.py b/rbms/scripts/train_rbm.py index 43e4d5a..023465a 100644 --- a/rbms/scripts/train_rbm.py +++ b/rbms/scripts/train_rbm.py @@ -1,90 +1,163 @@ import argparse - -import h5py import torch - from rbms.dataset import load_dataset from rbms.dataset.parser import add_args_dataset from rbms.map_model import map_model +from rbms.optim import setup_optim from rbms.parser import ( + add_args_init_rbm, add_args_pytorch, - add_args_rbm, - add_args_regularization, add_args_saves, + add_args_train, + add_grad_args, + add_sampling_args, default_args, match_args_dtype, remove_argument, + set_args_default, ) from rbms.training.pcd import train -from rbms.training.utils import get_checkpoints +from rbms.training.utils import get_checkpoints, init_training, restore_training def create_parser(): parser = argparse.ArgumentParser(description="Train a Restricted Boltzmann Machine") parser = add_args_dataset(parser) - parser = add_args_rbm(parser) - parser = add_args_regularization(parser) + parser = add_args_init_rbm(parser) + parser = add_args_train(parser) + parser = add_sampling_args(parser) + parser = add_grad_args(parser) parser = add_args_saves(parser) parser = add_args_pytorch(parser) - parser.add_argument("--optim", default="sgd") remove_argument(parser, "use_torch") return parser -def train_rbm(args: dict): - if args["num_updates"] is None: - args["num_updates"] = default_args["num_updates"] +def process_args(args: dict): + args_torch = {"device": args["device"], "dtype": args["dtype"]} + args_dataset = { + "dataset_name": args["dataset"], + "test_dataset_name": args["test_dataset"], + "train_size": args["train_size"], + "test_size": args["test_size"], + "subset_labels": args["subset_labels"], + "use_weights": args["use_weights"], + "alphabet": args["alphabet"], + "remove_duplicates": args["remove_duplicates"], + "seed": args["seed"], + } + args_grad = { + "no_center": args["no_center"], + "normalize_grad": args["normalize_grad"], + 
"max_norm_grad": args["max_norm_grad"], + "L1": args["L1"], + "L2": args["L2"], + } + args_sampling = {"gibbs_steps": args["gibbs_steps"], "beta": args["beta"]} + args_train = { + "optim": args["optim"], + "learning_rate": args["learning_rate"], + "batch_size": args["batch_size"], + "num_updates": args["num_updates"], + "mult_optim": args["mult_optim"], + "training_type": args["training_type"], + } + args_save = { + "filename": args["filename"], + "n_save": args["n_save"], + "spacing": args["spacing"], + } + args_init = { + "num_chains": args["num_chains"], + "num_hiddens": args["num_hiddens"], + "model_type": args["model_type"], + } + return ( + args_dataset, + args_save, + args_train, + args_grad, + args_sampling, + args_torch, + args_init, + ) + + +def main(): + torch.backends.cudnn.benchmark = True + parser = create_parser() + args = parser.parse_args() + args = vars(args) + args = set_args_default(args, default_args=default_args) + args = match_args_dtype(args) + ( + args_dataset, + args_save, + args_train, + args_grad, + args_sampling, + args_torch, + args_init, + ) = process_args(args) checkpoints = get_checkpoints( - num_updates=args["num_updates"], n_save=args["n_save"], spacing=args["spacing"] + num_updates=args_train["num_updates"], + n_save=args_save["n_save"], + spacing=args_save["spacing"], ) train_dataset, test_dataset = load_dataset( - dataset_name=args["dataset"], - test_dataset_name=args["test_dataset"], - subset_labels=args["subset_labels"], + dataset_name=args_dataset["dataset_name"], + test_dataset_name=args_dataset["test_dataset_name"], + subset_labels=args_dataset["subset_labels"], use_weights=args["use_weights"], alphabet=args["alphabet"], remove_duplicates=args["remove_duplicates"], - device=args["device"], - dtype=args["dtype"], - + **args_torch, + ) + flags = ["checkpoint"] + init_training( + args_save, + args_train, + args_grad, + args_sampling, + args_init, + args_dataset, + args_torch, + train_dataset, + flags, ) - print(train_dataset) - if args["restore"]: - with h5py.File(args["filename"], "r") as f: - model_type = f["model_type"][()].decode() - else: - model_type = args["model_type"] - if model_type is None: - match train_dataset.variable_type: - case "bernoulli": - model_type = "BBRBM" - case "categorical": - model_type = "PBRBM" - case "ising": - model_type = "IIRBM" - case _: - raise NotImplementedError() - print(model_type) + args_train["update"] = 1 + ( + params, + parallel_chains, + target_update, + elapsed_time, + train_dataset, + test_dataset, + ) = restore_training( + train_dataset=train_dataset, + test_dataset=test_dataset, + args_save=args_save, + args_train=args_train, + args_dataset=args_dataset, + args_torch=args_torch, + map_model=map_model, + ) + optimizer = setup_optim(args_train["optim"], args_train, params) train( train_dataset=train_dataset, test_dataset=test_dataset, - model_type=model_type, - args=args, - dtype=args["dtype"], + params=params, + parallel_chains=parallel_chains, + optimizer=optimizer, + curr_update=target_update, + elapsed_time=elapsed_time, checkpoints=checkpoints, - map_model=map_model, - default_args=default_args, + args_save=args_save, + args_train=args_train, + args_grad=args_grad, + args_sampling=args_sampling, ) -def main(): - torch.backends.cudnn.benchmark = True - parser = create_parser() - args = parser.parse_args() - args = vars(args) - args = match_args_dtype(args) - train_rbm(args=args) - - if __name__ == "__main__": main() diff --git a/rbms/training/pcd.py b/rbms/training/pcd.py index c23c8ec..0f1ff31 
100644 --- a/rbms/training/pcd.py +++ b/rbms/training/pcd.py @@ -1,20 +1,15 @@ import time -import h5py import numpy as np import torch from torch import Tensor -from torch.optim import SGD, Optimizer +from tqdm.autonotebook import tqdm -from rbms.classes import EBM, RBM +from rbms.classes import EBM from rbms.dataset.dataset_class import RBMDataset from rbms.io import save_model -from rbms.map_model import map_model -from rbms.parser import default_args, set_args_default from rbms.potts_bernoulli.classes import PBRBM from rbms.potts_bernoulli.utils import ensure_zero_sum_gauge -from rbms.training.utils import initialize_model_archive, setup_training -from rbms.utils import check_file_existence, log_to_csv def fit_batch_pcd( @@ -27,6 +22,7 @@ def fit_batch_pcd( lambda_l1: float = 0.0, lambda_l2: float = 0.0, normalize_grad: bool = True, + max_norm_grad: float = -1, ) -> tuple[dict[str, Tensor], dict]: """Sample the EBM and compute the gradient. @@ -60,21 +56,26 @@ def fit_batch_pcd( ) if normalize_grad: params.normalize_grad() - logs = {} - return parallel_chains, logs + if max_norm_grad > 0: + params.clip_grad(max_norm=max_norm_grad) + return parallel_chains +@torch.no_grad def train( train_dataset: RBMDataset, - test_dataset: RBMDataset | None, - model_type: str, - args: dict, - dtype: torch.dtype, + test_dataset: RBMDataset, + params: EBM, + parallel_chains: dict[str, Tensor], + optimizer: torch.optim.Optimizer, + curr_update: int, + elapsed_time: float, checkpoints: np.ndarray, - optim: Optimizer = SGD, - map_model: dict[str, EBM] = map_model, - default_args: dict = default_args, -) -> None: + args_save: dict[str, str], + args_train: dict[str, int | float], + args_grad: dict[str, bool | float], + args_sampling: dict[str, int | float], +): """Train an EBM. Args: @@ -85,128 +86,83 @@ def train( dtype (torch.dtype): The data type for the parameters. checkpoints (np.ndarray): An array of checkpoints for saving model states. 
""" - - if not (args["overwrite"]): - check_file_existence(args["filename"]) - - # Create a first archive with the initialized model - if not (args["restore"]): - initialize_model_archive( - args=args, - model_type=model_type, - train_dataset=train_dataset, - test_dataset=test_dataset, - dtype=dtype, - map_model=map_model, - ) - ( - params, - parallel_chains, - args, - num_updates, - start, - elapsed_time, - log_filename, - pbar, - train_dataset, - test_dataset, - ) = setup_training( - args, - map_model=map_model, - train_dataset=train_dataset, - test_dataset=test_dataset, + # Sampling + gibbs_steps: int = args_sampling["gibbs_steps"] + beta: float = args_sampling["beta"] + + # Grad + centered: bool = not (args_grad["no_center"]) + L1: float = args_grad["L1"] + L2: float = args_grad["L2"] + normalize_grad: bool = args_grad["normalize_grad"] + max_norm_grad: float = args_grad["max_norm_grad"] + + # train + batch_size: int = args_train["batch_size"] + num_updates: int = args_train["num_updates"] + training_type: str = args_train["training_type"] + + # save + filename: str = args_save["filename"] + + # pbar + pbar = tqdm( + initial=curr_update, + total=num_updates, + colour="red", + dynamic_ncols=True, + ascii="-#", ) - args = set_args_default(args=args, default_args=default_args) - - optimizer = optim(params.parameters(), lr=args["learning_rate"], maximize=True) - - if args["optim"] == "nag": - optimizer = SGD( - params.parameters(), - lr=args["learning_rate"], - maximize=True, - momentum=0.9, - nesterov=True, + pbar.set_description(f"Training {params.name}") + + start = time.time() + + for idx in range(curr_update + 1, num_updates + 1): + batch = train_dataset.batch(batch_size) + data, weights = batch["data"], batch["weights"] + if training_type == "rdm": + parallel_chains = params.init_chains(parallel_chains["visible"].shape[0]) + elif training_type == "cd": + parallel_chains = params.init_chains( + data.shape[0], + weights=weights, + start_v=data, + ) + for opt in optimizer: + opt.zero_grad(set_to_none=False) + + parallel_chains = fit_batch_pcd( + batch=(data, weights), + parallel_chains=parallel_chains, + params=params, + gibbs_steps=gibbs_steps, + beta=beta, + centered=centered, + lambda_l1=L1, + lambda_l2=L2, + normalize_grad=normalize_grad, + max_norm_grad=max_norm_grad, ) - update_lr = False - warmup = True - - # Continue the training - with torch.no_grad(): - for idx in range(num_updates + 1, args["num_updates"] + 1): - # rand_idx = torch.randperm(len(train_dataset))[: args["batch_size"]] - # batch = (train_dataset.data[rand_idx], train_dataset.weights[rand_idx]) - batch = train_dataset.batch(args["batch_size"]) - - if args["training_type"] == "rdm": - parallel_chains = params.init_chains(parallel_chains["visible"].shape[0]) - elif args["training_type"] == "cd": - parallel_chains = params.init_chains( - batch["data"].shape[0], - weights=batch["weights"], - start_v=batch["data"], - ) - - if warmup and isinstance(params, RBM): - if params.weight_matrix.norm() > 10 and args["optim"] == "nag": - optimizer = SGD( - params.parameters(), - lr=args["learning_rate"], - maximize=True, - momentum=0.9, - nesterov=True, - ) - warmup = False - optimizer.zero_grad(set_to_none=False) - - parallel_chains, logs = fit_batch_pcd( - batch=(batch["data"], batch["weights"]), - parallel_chains=parallel_chains, + for opt in optimizer: + opt.step() + + if isinstance(params, PBRBM): + ensure_zero_sum_gauge(params) + + # Save current model if necessary + if idx in checkpoints or idx == num_updates: + curr_time 
= time.time() - start + learning_rate = torch.tensor([opt.param_groups[0]["lr"] for opt in optimizer]) + save_model( + filename=filename, params=params, - gibbs_steps=args["gibbs_steps"], - beta=args["beta"], - centered=not (args["no_center"]), - lambda_l1=args["L1"], - lambda_l2=args["L2"], - normalize_grad=args["normalize_grad"], + chains=parallel_chains, + num_updates=idx, + time=curr_time + elapsed_time, + learning_rate=learning_rate, + flags=["checkpoint"], ) - if update_lr: - optimizer.step(update_lr=update_lr) - else: - optimizer.step() - if isinstance(params, PBRBM): - ensure_zero_sum_gauge(params) - - # Save current model if necessary - if idx in checkpoints: - curr_time = time.time() - start - save_model( - filename=args["filename"], - params=params, - chains=parallel_chains, - num_updates=idx, - time=curr_time + elapsed_time, - flags=["checkpoint"], - ) - - # Save some logs - learning_rates = np.array([optimizer.param_groups[0]["lr"]]) - with h5py.File(args["filename"], "a") as f: - if "learning_rate" in f.keys(): - learning_rates = np.append(f["learning_rate"][()], learning_rates) - del f["learning_rate"] - f["learning_rate"] = learning_rates - if hasattr(optimizer, "cosine_similarity"): - if "cosine_similarities" in f.keys(): - cosine_similarities = np.append( - f["cosine_similarities"][()], - optimizer.cosine_similarity, - ) - del f["cosine_similarities"] - f["cosine_similarities"] = cosine_similarities - - if args["log"]: - log_to_csv(logs, log_file=log_filename) - pbar.set_postfix_str(f"lr: {optimizer.param_groups[0]['lr']:.6f}") - # Update progress bar - pbar.update(1) + + pbar.set_postfix_str(f"lr: {optimizer[0].param_groups[0]['lr']:.6f}") + # Update progress bar + pbar.update(1) diff --git a/rbms/training/utils.py b/rbms/training/utils.py index beb61ca..c16232d 100644 --- a/rbms/training/utils.py +++ b/rbms/training/utils.py @@ -1,7 +1,3 @@ -import pathlib -import time -from typing import Any - import h5py import numpy as np import torch @@ -9,234 +5,268 @@ from tqdm.autonotebook import tqdm from rbms.classes import EBM -from rbms.const import LOG_FILE_HEADER from rbms.dataset.dataset_class import RBMDataset from rbms.io import load_model, save_model from rbms.map_model import map_model -from rbms.parser import default_args, set_args_default from rbms.potts_bernoulli.classes import PBRBM from rbms.potts_bernoulli.utils import ensure_zero_sum_gauge from rbms.utils import get_saved_updates -def setup_training( - args: dict, +def get_checkpoints(num_updates: int, n_save: int, spacing: str = "exp") -> np.ndarray: + """Select the list of training times (ages) at which to save the model. + + Args: + num_updates (int): Number of gradient updates to perform during training. + n_save (int): Number of models to save. + spacing (str, optional): Spacing method, either "linear" ("lin") or "exponential" ("exp"). Defaults to "exp". + + Returns: + np.ndarray: Array of checkpoint indices. 
+ """ + match spacing: + case "exp": + checkpoints = [] + xi = num_updates + for _ in range(n_save): + checkpoints.append(xi) + xi = xi / num_updates ** (1 / n_save) + checkpoints = np.unique(np.array(checkpoints, dtype=np.int32)) + case "linear": + checkpoints = np.linspace(1, num_updates, n_save).astype(np.int32) + case _: + raise ValueError(f"spacing should be one of ('exp', 'linear'), got {spacing}") + checkpoints = np.unique(np.append(checkpoints, num_updates)) + return checkpoints + + +def init_training( + args_save: dict[str, str | int], + args_train: dict, + args_grad: dict, + args_sampling: dict, + args_init: dict, + args_dataset: dict, + args_torch: dict[str, str | torch.dtype], + train_dataset: RBMDataset, + flags: list[str] = ["checkpoint"], +): + # Torch + device: str = args_torch["device"] + dtype: torch.dtype = args_torch["dtype"] + + # Sampling + gibbs_steps: int = args_sampling["gibbs_steps"] + beta: float = args_sampling["beta"] + + # Grad + centered: bool = not (args_grad["no_center"]) + L1: float = args_grad["L1"] + L2: float = args_grad["L2"] + normalize_grad: bool = args_grad["normalize_grad"] + max_norm_grad: float = args_grad["max_norm_grad"] + + # train + batch_size: int = args_train["batch_size"] + num_updates: int = args_train["num_updates"] + optim: str = args_train["optim"] + mult_optim: bool = args_train["mult_optim"] + training_type: str = args_train["training_type"] + learning_rate: float = args_train["learning_rate"] + + # save + filename: str = args_save["filename"] + n_save: int = args_save["n_save"] + spacing: str = args_save["spacing"] + + # dataset + seed: int = args_dataset["seed"] + train_size: float = args_dataset["train_size"] + test_size: float = args_dataset["test_size"] + if test_size is None: + test_size = 1 - train_size + subset_labels: list = args_dataset["subset_labels"] + use_weights: bool = args_dataset["use_weights"] + alphabet: str = args_dataset["alphabet"] + remove_duplicates: bool = args_dataset["remove_duplicates"] + + # init + num_hiddens: int = args_init["num_hiddens"] + num_chains: int = args_init["num_chains"] + model_type: str = args_init["model_type"] + if model_type is None: + match train_dataset.variable_type: + case "bernoulli": + model_type = "BBRBM" + case "categorical": + model_type = "PBRBM" + case "ising": + model_type = "IIRBM" + case _: + raise NotImplementedError() + + # Setup dataset + num_visibles = train_dataset.get_num_visibles() + + # Setup RBM + params = map_model[model_type].init_parameters( + num_hiddens=num_hiddens, + dataset=train_dataset, + device=device, + dtype=dtype, + ) + if isinstance(params, PBRBM): + ensure_zero_sum_gauge(params) + + # Permanent chains + parallel_chains = params.init_chains(num_samples=num_chains) + parallel_chains = params.sample_state(chains=parallel_chains, n_steps=gibbs_steps) + + # Save hyperparameters + if mult_optim: + learning_rate = torch.tensor([learning_rate] * len(params.parameters())) + else: + learning_rate = torch.tensor([learning_rate]) + + with h5py.File(filename, "w") as file_model: + hyperparameters = file_model.create_group("hyperparameters") + hyperparameters["num_visibles"] = num_visibles + hyperparameters["num_hiddens"] = num_hiddens + hyperparameters["num_chains"] = num_chains + hyperparameters["filename"] = str(filename) + + save_model( + filename=filename, + params=params, + chains=parallel_chains, + num_updates=1, + time=0.0, + flags=flags, + learning_rate=learning_rate, + ) + + with h5py.File(filename, "a") as f: + dataset = f.create_group("dataset_args") 
+ if subset_labels is not None: + dataset["subset_labels"] = subset_labels + dataset["use_weights"] = use_weights + dataset["train_size"] = train_size + dataset["test_size"] = test_size + dataset["alphabet"] = alphabet + dataset["remove_duplicates"] = remove_duplicates + dataset["seed"] = seed + + grad = f.create_group("grad_args") + grad["no_center"] = not (centered) + grad["normalize_grad"] = normalize_grad + grad["max_norm_grad"] = max_norm_grad + grad["L1"] = L1 + grad["L2"] = L2 + + sampling = f.create_group("sampling_args") + sampling["gibbs_steps"] = gibbs_steps + sampling["beta"] = beta + + train_args = f.create_group("train_args") + train_args["optim"] = optim + train_args["batch_size"] = batch_size + train_args["learning_rate"] = learning_rate + train_args["training_type"] = training_type + + save_args = f.create_group("save_args") + save_args["n_save"] = n_save + save_args["spacing"] = spacing + + +def restore_training( train_dataset: RBMDataset, - test_dataset: RBMDataset | None = None, - map_model: dict[str, EBM] = map_model, + test_dataset: RBMDataset, + args_save: dict[str, str], + args_train: dict[str, int | float], + args_dataset, + args_torch: dict[str, str | torch.dtype], + map_model: dict[str, EBM], ) -> tuple[ EBM, dict[str, Tensor], - dict[str, Any], int, float, - float, - pathlib.Path, tqdm, RBMDataset, RBMDataset, ]: + target_update = args_train["update"] + filename = args_save["filename"] + num_updates: int = args_train["num_updates"] + + # Torch + device: str = args_torch["device"] + dtype: torch.dtype = args_torch["dtype"] + + # dataset + seed: int = args_dataset["seed"] + train_size: float = args_dataset["train_size"] + test_size: float = args_dataset["test_size"] + # Retrieve the the number of training updates already performed on the model - updates = get_saved_updates(filename=args["filename"]) - num_updates = updates[-1] - if args["num_updates"] <= num_updates: + print(f"Restoring training from update {target_update}") + + if num_updates <= target_update: raise RuntimeError( - f"The parameter /'num_updates/' ({args['num_updates']}) must be greater than the previous number of updates ({num_updates})." + f"The parameter /'num_updates/' ({num_updates}) must be greater than the previous number of updates ({target_update})." 
) - params, parallel_chains, elapsed_time, hyperparameters = load_model( - args["filename"], - num_updates, - device=args["device"], - dtype=args["dtype"], + params, parallel_chains, elapsed_time = load_model( + filename, + target_update, + device=device, + dtype=dtype, restore=True, map_model=map_model, ) - # Hyperparameters - for k, v in hyperparameters.items(): - if args[k] is None: - args[k] = v + # Delete all updates after the current one + saved_updates = get_saved_updates(filename) + if saved_updates[-1] > target_update: + to_delete = saved_updates[saved_updates > target_update] + with h5py.File(filename, "a") as f: + print("Deleting:") + for upd in to_delete: + print(f" - {upd}") + del f[f"update_{upd}"] if test_dataset is None: print("Splitting dataset") train_dataset, test_dataset = train_dataset.split_train_test( - rng=np.random.default_rng(args["seed"]), - train_size=args["train_size"], - test_size=args["test_size"], + rng=np.random.default_rng(seed), + train_size=train_size, + test_size=test_size, ) print("Train dataset:") print(train_dataset) print("Test dataset:") print(test_dataset) - # Open the log file if it exists - log_filename = pathlib.Path(args["filename"]).parent / pathlib.Path( - f"log-{pathlib.Path(args['filename']).stem}.csv" - ) - args["log"] = log_filename.exists() - - # Progress bar - pbar = tqdm( - initial=num_updates, - total=args["num_updates"], - colour="red", - dynamic_ncols=True, - ascii="-#", - ) - pbar.set_description(f"Training {params.name}") + # # Progress bar + # pbar = tqdm( + # initial=target_update, + # total=num_updates, + # colour="red", + # dynamic_ncols=True, + # ascii="-#", + # ) + # pbar.set_description(f"Training {params.name}") # Initialize gradients for the parameters params.init_grad() - # Start recording training time - start = time.time() - train_dataset.match_model_variable_type(params.visible_type) test_dataset.match_model_variable_type(params.visible_type) return ( params, parallel_chains, - args, - num_updates, - start, + target_update, elapsed_time, - log_filename, - pbar, train_dataset, test_dataset, ) - - -def create_machine( - filename: str, - params: EBM, - num_visibles: int, - num_hiddens: int, - num_chains: int, - batch_size: int, - gibbs_steps: int, - learning_rate: float, - train_size: float, - log: bool, - flags: list[str], - seed: int, - L1: float, - L2: float, -) -> None: - """Create a RBM and save it to a new file. - - Args: - filename (str): The name of the file to save the RBM. - params (RBM): Initialized parameters. - num_visibles (int): Number of visible units. - num_hiddens (int): Number of hidden units. - num_chains (int): Number of parallel chains for gradient computation. - batch_size (int): Size of the data batch. - gibbs_steps (int): Number of Gibbs steps to perform. - learning_rate (float): Learning rate for training. - log (bool): Whether to enable logging. - L1 (float): Lambda parameter for L1 regularization. - L2 (float): Lambda parameter for L2 regularization. 
- """ - # Permanent chains - parallel_chains = params.init_chains(num_samples=num_chains) - parallel_chains = params.sample_state(chains=parallel_chains, n_steps=gibbs_steps) - with h5py.File(filename, "w") as file_model: - hyperparameters = file_model.create_group("hyperparameters") - hyperparameters["num_hiddens"] = num_hiddens - hyperparameters["num_visibles"] = num_visibles - hyperparameters["num_chains"] = num_chains - hyperparameters["batch_size"] = batch_size - hyperparameters["gibbs_steps"] = gibbs_steps - hyperparameters["filename"] = str(filename) - hyperparameters["learning_rate"] = learning_rate - hyperparameters["train_size"] = train_size - hyperparameters["seed"] = seed - hyperparameters["L1"] = L1 - hyperparameters["L2"] = L2 - - save_model( - filename=filename, - params=params, - chains=parallel_chains, - num_updates=1, - time=0.0, - flags=flags, - ) - if log: - filename = pathlib.Path(filename) - log_filename = filename.parent / pathlib.Path(f"log-{filename.stem}.csv") - with open(log_filename, "w", encoding="utf-8") as log_file: - log_file.write(",".join(LOG_FILE_HEADER) + "\n") - - -def get_checkpoints(num_updates: int, n_save: int, spacing: str = "exp") -> np.ndarray: - """Select the list of training times (ages) at which to save the model. - - Args: - num_updates (int): Number of gradient updates to perform during training. - n_save (int): Number of models to save. - spacing (str, optional): Spacing method, either "linear" ("lin") or "exponential" ("exp"). Defaults to "exp". - - Returns: - np.ndarray: Array of checkpoint indices. - """ - match spacing: - case "exp": - checkpoints = [] - xi = num_updates - for _ in range(n_save): - checkpoints.append(xi) - xi = xi / num_updates ** (1 / n_save) - checkpoints = np.unique(np.array(checkpoints, dtype=np.int32)) - case "linear": - checkpoints = np.linspace(1, num_updates, n_save).astype(np.int32) - case _: - raise ValueError(f"spacing should be one of ('exp', 'linear'), got {spacing}") - checkpoints = np.unique(np.append(checkpoints, num_updates)) - return checkpoints - - -def initialize_model_archive( - args: dict, - model_type: str, - train_dataset: RBMDataset, - test_dataset: RBMDataset | None, - dtype: torch.dtype, - flags: list[str] = ["checkpoint"], - map_model: dict[str, EBM] = map_model, -): - num_visibles = train_dataset.get_num_visibles() - args = set_args_default(args=args, default_args=default_args) - rng = np.random.default_rng(args["seed"]) - if test_dataset is None: - train_dataset, _ = train_dataset.split_train_test( - rng, args["train_size"], args["test_size"] - ) - train_dataset.match_model_variable_type(visible_type=map_model[model_type].visible_type) - params = map_model[model_type].init_parameters( - num_hiddens=args["num_hiddens"], - dataset=train_dataset, - device=args["device"], - dtype=dtype, - ) - - if isinstance(params, PBRBM): - ensure_zero_sum_gauge(params) - create_machine( - filename=args["filename"], - params=params, - num_visibles=num_visibles, - num_hiddens=args["num_hiddens"], - num_chains=args["num_chains"], - batch_size=args["batch_size"], - gibbs_steps=args["gibbs_steps"], - learning_rate=args["learning_rate"], - train_size=args["train_size"], - log=args["log"], - flags=flags, - seed=args["seed"], - L1=args["L1"], - L2=args["L2"], - ) From be952e47abd8c2f7021f39bf9db91dede8a2846e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Thu, 5 Feb 2026 00:28:41 +0100 Subject: [PATCH 37/43] new parser, keep the old fucntions for compatibility --- rbms/parser.py | 141 
++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 129 insertions(+), 12 deletions(-) diff --git a/rbms/parser.py b/rbms/parser.py index 101df5b..ca678dc 100644 --- a/rbms/parser.py +++ b/rbms/parser.py @@ -48,18 +48,7 @@ def add_args_saves(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: default=50, help="(Defaults to 50). Number of models to save during the training.", ) - save_args.add_argument( - "--acc_ptt", - type=float, - default=None, - help="(Defaults to 0.25). Minimum PTT acceptance to save configurations for ptt file.", - ) - save_args.add_argument( - "--acc_ll", - type=float, - default=None, - help="(Defaults to 0.7). Minimum PTT acceptance to save configurations for ll file.", - ) + save_args.add_argument( "--spacing", type=str, @@ -79,6 +68,120 @@ def add_args_saves(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: return parser +def add_args_init_rbm(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + rbm_args = parser.add_argument_group("RBM") + rbm_args.add_argument( + "--num_hiddens", + type=int, + default=None, + help="(Defaults to 100). Number of hidden units.", + ) + rbm_args.add_argument( + "--num_chains", + type=int, + default=None, + help="(Defaults to 2000). Number of parallel chains.", + ) + rbm_args.add_argument( + "--model_type", + type=str, + default=None, + help="(Defaults to None). Model to use. If None is provided, will be a RBM with the same visible type as the dataset and binary hiddens. If restore, this argument is ignored.", + ) + return parser + + +def add_sampling_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + sampling_args = parser.add_argument_group("Sampling") + sampling_args.add_argument( + "--gibbs_steps", + type=int, + default=None, + help="(Defaults to 100). Number of gibbs steps to perform for each gradient update.", + ) + sampling_args.add_argument( + "--beta", + default=None, + type=float, + help="(Defaults to 1.0). The inverse temperature of the RBM", + ) + return parser + + +def add_grad_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + grad_args = parser.add_argument_group("Gradient") + grad_args.add_argument( + "--L1", + default=None, + type=float, + help="(Defaults to 0.0). Lambda parameter for the L1 regularization.", + ) + grad_args.add_argument( + "--L2", + default=None, + type=float, + help="(Defaults to 0.0). Lambda parameter for the L2 regularization.", + ) + grad_args.add_argument( + "--no_center", + default=False, + action="store_true", + help="(Defaults to False). Use the non-centered gradient.", + ) + grad_args.add_argument( + "--max_norm_grad", + default=None, + type=float, + help="(Defaults to None). Maximum norm of the gradient before update.", + ) + grad_args.add_argument( + "--normalize_grad", + default=False, + action="store_true", + help="(Defaults to False). Normalize the gradient before update.", + ) + return parser + + +def add_args_train(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + train_args = parser.add_argument_group("Train") + train_args.add_argument( + "--batch_size", + type=int, + default=None, + help="(Defaults to 2000). Minibatch size.", + ) + train_args.add_argument( + "--learning_rate", + type=float, + default=None, + help="(Defaults to 0.01). Learning rate.", + ) + train_args.add_argument( + "--num_updates", + default=None, + type=int, + help="(Defaults to 10 000). Number of gradient updates to perform.", + ) + train_args.add_argument( + "--optim", default=None, type=str, help="(Defaults to sgd). 
Optimizer to use." + ) + train_args.add_argument( + "--mult_optim", + action="store_true", + default=False, + help="(Defaults to False). Use a different optimizer for each param group.", + ) + train_args.add_argument( + "--training_type", + type=str, + default="pcd", + help="(Defaults to 'pcd'). Type of the training, should be one of {'pcd', 'cd', 'rdm'}.", + ) + + return parser + + def add_args_rbm(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: """Add an argument group to the parser for the general hyperparameters of a RBM @@ -158,6 +261,18 @@ def add_args_rbm(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: default=None, help="(Defaults to None). Model to use. If None is provided, will be a RBM with the same visible type as the dataset and binary hiddens. If restore, this argument is ignored.", ) + rbm_args.add_argument( + "--update", + type=int, + default=None, + help="(Defaults to None). The update to restore from. If set to None or to an update not in the archive, the last one will be selected instead.", + ) + rbm_args.add_argument( + "--max_norm_grad", + default=None, + type=float, + help="(Defaults to None). Maximum norm of the gradient before update.", + ) return parser @@ -230,6 +345,8 @@ def match_args_dtype(args: dict[str, Any]) -> dict[str, Any]: "no_center": False, "L1": 0.0, "L2": 0.0, + "max_norm_grad": -1, + "optim": "sgd", } From 8559d4e02158b1166972cfa0c967353c1a0c339e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Thu, 5 Feb 2026 00:29:50 +0100 Subject: [PATCH 38/43] save learning rate during training and remove the hyperparameters loading from load_model --- rbms/io.py | 25 +++++-------------------- 1 file changed, 5 insertions(+), 20 deletions(-) diff --git a/rbms/io.py b/rbms/io.py index b92957f..2dba454 100644 --- a/rbms/io.py +++ b/rbms/io.py @@ -14,6 +14,7 @@ def save_model( chains: dict[str, Tensor], num_updates: int, time: float, + learning_rate: Tensor, flags: list[str] = [], ) -> None: """Save the current state of the model. @@ -46,7 +47,7 @@ def save_model( checkpoint["numpy_rng_arg3"] = np.random.get_state()[3] checkpoint["numpy_rng_arg4"] = np.random.get_state()[4] checkpoint["time"] = time - + checkpoint["learning_rate"] = learning_rate.cpu().numpy() # Update the parallel chains to resume training if "parallel_chains" in f.keys(): f["parallel_chains"][...] = chains["visible"].cpu().numpy() @@ -98,7 +99,7 @@ def load_model( dtype: torch.dtype, restore: bool = False, map_model: dict[str, EBM] = map_model, -) -> tuple[EBM, dict[str, Tensor], float, dict]: +) -> tuple[EBM, dict[str, Tensor], float]: """Load a RBM from a h5 archive. Args: @@ -111,10 +112,9 @@ def load_model( Returns: Tuple[EBM, dict[str, Tensor], float, dict]: A tuple containing the loaded RBM parameters, - the parallel chains, the time taken, and the model's hyperparameters. 
+ the parallel chains and the time taken """ last_file_key = f"update_{index}" - hyperparameters = dict() with h5py.File(filename, "r") as f: visible = torch.from_numpy(f["parallel_chains"][()]).to( device=device, dtype=dtype @@ -122,21 +122,6 @@ def load_model( # Elapsed time start = np.array(f[last_file_key]["time"]).item() - # Hyperparameters - if "hyperparameters" in f.keys(): - hyperparameters["batch_size"] = int(f["hyperparameters"]["batch_size"][()]) - hyperparameters["gibbs_steps"] = int(f["hyperparameters"]["gibbs_steps"][()]) - hyperparameters["learning_rate"] = float( - f["hyperparameters"]["learning_rate"][()] - ) - hyperparameters["L1"] = float(f["hyperparameters"]["L1"][()]) - hyperparameters["L2"] = float(f["hyperparameters"]["L2"][()]) - if "seed" in f["hyperparameters"].keys(): - hyperparameters["seed"] = int(f["hyperparameters"]["seed"][()]) - if "train_size" in f["hyperparameters"].keys(): - hyperparameters["train_size"] = float( - f["hyperparameters"]["train_size"][()] - ) params = load_params( filename=filename, index=index, device=device, dtype=dtype, map_model=map_model ) @@ -144,4 +129,4 @@ def load_model( if restore: restore_rng_state(filename=filename, index=index) - return (params, perm_chains, start, hyperparameters) + return (params, perm_chains, start) From feb4e0a3f475f1f365b2fecfb86c67ce05f29101 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Thu, 5 Feb 2026 00:30:08 +0100 Subject: [PATCH 39/43] util to handle optimizer declaration --- rbms/optim.py | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 rbms/optim.py diff --git a/rbms/optim.py b/rbms/optim.py new file mode 100644 index 0000000..b1ad59e --- /dev/null +++ b/rbms/optim.py @@ -0,0 +1,48 @@ +from rbms.classes import EBM +from torch.optim import SGD, Optimizer +from torch import Tensor +import torch + +def setup_optim(optim: str, args: dict, params: EBM) -> list[Optimizer]: + match args["optim"]: + case "sgd": + optim = SGD + case _: + print(f"Unrecognized optimizer {args['optim']}, falling back to SGD.") + optim = SGD + learning_rate = args["learning_rate"] + if args["mult_optim"]: + if not isinstance(learning_rate, Tensor): + learning_rate = torch.tensor([learning_rate] * len(params.parameters())) + optimizer = [ + optim( + [p], + lr=learning_rate[i], + maximize=True, + ) + for i, p in enumerate(params.parameters()) + ] + else: + if not isinstance(learning_rate, Tensor): + learning_rate = torch.tensor([learning_rate]) + optimizer = [ + optim( + params.parameters(), + lr=learning_rate[0], + maximize=True, + ) + ] + + if args["optim"] == "nag": + optimizer = [ + SGD( + opt.param_groups[0]["params"], + lr=opt.param_groups[0]["lr"], + maximize=True, + momentum=0.9, + nesterov=True, + ) + for opt in optimizer + ] + + return optimizer From 9544f24b6d9bafe990a28e9c5cfd1a9584d163e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Thu, 5 Feb 2026 00:34:37 +0100 Subject: [PATCH 40/43] remove test for removed function --- .../test_bernoulli_utils_bbrbm.py | 124 ------------------ .../potts_bernoulli/test_utils_pbrbm.py | 118 ----------------- 2 files changed, 242 deletions(-) diff --git a/tests/unit_test/bernoulli_bernoulli/test_bernoulli_utils_bbrbm.py b/tests/unit_test/bernoulli_bernoulli/test_bernoulli_utils_bbrbm.py index dd580f6..e69de29 100644 --- a/tests/unit_test/bernoulli_bernoulli/test_bernoulli_utils_bbrbm.py +++ b/tests/unit_test/bernoulli_bernoulli/test_bernoulli_utils_bbrbm.py @@ -1,124 +0,0 @@ 
-from pathlib import Path - -import pytest -import torch - -from rbms.bernoulli_bernoulli.classes import BBRBM -from rbms.const import LOG_FILE_HEADER -from rbms.io import load_model -from rbms.training.utils import create_machine - - -# Helper function to create a temporary HDF5 file for testing -def create_temp_hdf5_file(tmp_path, sample_params_class_bbrbm, sample_chains_bbrbm): - filename = tmp_path / "test_model.h5" - create_machine( - filename, - params=sample_params_class_bbrbm, - chains=sample_chains_bbrbm, - num_updates=1, - time=0.0, - ) - return filename - - -def test_create_load_machine(tmp_path, sample_params_class_bbrbm): - filename = tmp_path / "test_model.h5" - device = torch.device("cpu") - dtype = torch.float32 - create_machine( - filename=str(filename), - params=sample_params_class_bbrbm, - num_visibles=pytest.NUM_VISIBLES, - num_hiddens=pytest.NUM_HIDDENS, - num_chains=pytest.NUM_CHAINS, - batch_size=pytest.BATCH_SIZE, - gibbs_steps=pytest.GIBBS_STEPS, - learning_rate=pytest.LEARNING_RATE, - train_size=pytest.TRAIN_SIZE, - log=True, - flags=["test"], - seed=pytest.SEED, - L1=1.0, - L2=0.0, - ) - - # Check if the file was created - assert filename.exists() - - # Check if the log file was created - log_filename = filename.parent / Path(f"log-{filename.stem}.csv") - assert log_filename.exists() - - # Check the contents of the log file - with open(log_filename, "r", encoding="utf-8") as log_file: - header = log_file.readline().strip() - assert header == ",".join(LOG_FILE_HEADER) - - params, chains, start, hyperparameters = load_model( - filename=str(filename), - index=1, - device=device, - dtype=dtype, - restore=False, - ) - assert isinstance(params, BBRBM) - assert isinstance(chains, dict) - assert isinstance(start, float) - assert isinstance(hyperparameters, dict) - assert hyperparameters["batch_size"] == pytest.BATCH_SIZE - assert hyperparameters["gibbs_steps"] == pytest.GIBBS_STEPS - assert hyperparameters["learning_rate"] == pytest.LEARNING_RATE - assert hyperparameters["seed"] == pytest.SEED - - -def test_create_load_machine_dtype(tmp_path, sample_params_class_bbrbm): - filename = tmp_path / "test_model.h5" - device = torch.device("cpu") - dtype = torch.float64 - create_machine( - filename=str(filename), - params=sample_params_class_bbrbm, - num_visibles=pytest.NUM_VISIBLES, - num_hiddens=pytest.NUM_HIDDENS, - num_chains=pytest.NUM_CHAINS, - batch_size=pytest.BATCH_SIZE, - gibbs_steps=pytest.GIBBS_STEPS, - learning_rate=pytest.LEARNING_RATE, - train_size=pytest.TRAIN_SIZE, - log=True, - flags=["test"], - seed=pytest.SEED, - L1=0.0, - L2=1.0, - ) - - # Check if the file was created - assert filename.exists() - - # Check if the log file was created - log_filename = filename.parent / Path(f"log-{filename.stem}.csv") - assert log_filename.exists() - - # Check the contents of the log file - with open(log_filename, "r", encoding="utf-8") as log_file: - header = log_file.readline().strip() - assert header == ",".join(LOG_FILE_HEADER) - - params, chains, start, hyperparameters = load_model( - filename=str(filename), - index=1, - device=device, - dtype=dtype, - restore=False, - ) - assert isinstance(params, BBRBM) - assert isinstance(chains, dict) - assert isinstance(start, float) - assert isinstance(hyperparameters, dict) - assert hyperparameters["batch_size"] == pytest.BATCH_SIZE - assert hyperparameters["gibbs_steps"] == pytest.GIBBS_STEPS - assert hyperparameters["learning_rate"] == pytest.LEARNING_RATE - assert hyperparameters["seed"] == pytest.SEED - assert 
chains["weights"].shape == (pytest.NUM_CHAINS,) - assert chains["visible"].shape == (pytest.NUM_CHAINS, pytest.NUM_VISIBLES) diff --git a/tests/unit_test/potts_bernoulli/test_utils_pbrbm.py b/tests/unit_test/potts_bernoulli/test_utils_pbrbm.py index 77ae445..e69de29 100644 --- a/tests/unit_test/potts_bernoulli/test_utils_pbrbm.py +++ b/tests/unit_test/potts_bernoulli/test_utils_pbrbm.py @@ -1,118 +0,0 @@ -from pathlib import Path - -import pytest -import torch - -from rbms.const import LOG_FILE_HEADER -from rbms.io import load_model -from rbms.potts_bernoulli.classes import PBRBM -from rbms.training.utils import create_machine - - -# Helper function to create a temporary HDF5 file for testing -def create_temp_hdf5_file(tmp_path, sample_params, sample_chains): - filename = tmp_path / "test_model.h5" - create_machine( - filename, params=sample_params, chains=sample_chains, num_updates=1, time=0 - ) - return filename - - -def test_create_load_machine(tmp_path, sample_params_class_pbrbm): - filename = tmp_path / "test_model.h5" - device = torch.device("cpu") - dtype = torch.float32 - create_machine( - filename=str(filename), - params=sample_params_class_pbrbm, - num_visibles=pytest.NUM_VISIBLES, - num_hiddens=pytest.NUM_HIDDENS, - num_chains=pytest.NUM_CHAINS, - batch_size=pytest.BATCH_SIZE, - gibbs_steps=pytest.GIBBS_STEPS, - learning_rate=pytest.LEARNING_RATE, - train_size=pytest.TRAIN_SIZE, - log=True, - flags=["test"], - seed=pytest.SEED, - L1=1.0, - L2=0.0, - ) - - # Check if the file was created - assert filename.exists() - - # Check if the log file was created - log_filename = filename.parent / Path(f"log-{filename.stem}.csv") - assert log_filename.exists() - - # Check the contents of the log file - with open(log_filename, "r", encoding="utf-8") as log_file: - header = log_file.readline().strip() - assert header == ",".join(LOG_FILE_HEADER) - - params, chains, start, hyperparameters = load_model( - filename=str(filename), - index=1, - device=device, - dtype=dtype, - restore=False, - ) - assert isinstance(params, PBRBM) - assert isinstance(chains, dict) - assert isinstance(start, float) - assert isinstance(hyperparameters, dict) - assert hyperparameters["batch_size"] == pytest.BATCH_SIZE - assert hyperparameters["gibbs_steps"] == pytest.GIBBS_STEPS - assert hyperparameters["learning_rate"] == pytest.LEARNING_RATE - assert hyperparameters["seed"] == pytest.SEED - - -def test_create_load_machine_dtype(tmp_path, sample_params_class_pbrbm): - filename = tmp_path / "test_model.h5" - device = torch.device("cpu") - dtype = torch.float64 - create_machine( - filename=str(filename), - params=sample_params_class_pbrbm, - num_visibles=pytest.NUM_VISIBLES, - num_hiddens=pytest.NUM_HIDDENS, - num_chains=pytest.NUM_CHAINS, - batch_size=pytest.BATCH_SIZE, - gibbs_steps=pytest.GIBBS_STEPS, - learning_rate=pytest.LEARNING_RATE, - train_size=pytest.TRAIN_SIZE, - log=True, - flags=["test"], - seed=pytest.SEED, - L1=0.0, - L2=1.0, - ) - - # Check if the file was created - assert filename.exists() - - # Check if the log file was created - log_filename = filename.parent / Path(f"log-{filename.stem}.csv") - assert log_filename.exists() - - # Check the contents of the log file - with open(log_filename, "r", encoding="utf-8") as log_file: - header = log_file.readline().strip() - assert header == ",".join(LOG_FILE_HEADER) - - params, chains, start, hyperparameters = load_model( - filename=str(filename), - index=1, - device=device, - dtype=dtype, - restore=False, - ) - assert isinstance(params, PBRBM) - assert 
isinstance(chains, dict) - assert isinstance(start, float) - assert isinstance(hyperparameters, dict) - assert hyperparameters["batch_size"] == pytest.BATCH_SIZE - assert hyperparameters["gibbs_steps"] == pytest.GIBBS_STEPS - assert hyperparameters["learning_rate"] == pytest.LEARNING_RATE - assert hyperparameters["seed"] == pytest.SEED From 34ee17a333b8b4fb7687102889d3e77916efac50 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Thu, 5 Feb 2026 00:40:24 +0100 Subject: [PATCH 41/43] remove weights from init_parameters --- rbms/potts_bernoulli/classes.py | 1 - rbms/potts_bernoulli/implement.py | 1 - 2 files changed, 2 deletions(-) diff --git a/rbms/potts_bernoulli/classes.py b/rbms/potts_bernoulli/classes.py index 36dbaf9..a767dae 100644 --- a/rbms/potts_bernoulli/classes.py +++ b/rbms/potts_bernoulli/classes.py @@ -156,7 +156,6 @@ def init_parameters(num_hiddens, dataset, device, dtype, var_init=0.0001): vbias, hbias, weight_matrix = _init_parameters( num_hiddens=num_hiddens, data=data, - weights=dataset.weights, device=device, dtype=dtype, var_init=var_init, diff --git a/rbms/potts_bernoulli/implement.py b/rbms/potts_bernoulli/implement.py index 9c491e1..77cfc44 100644 --- a/rbms/potts_bernoulli/implement.py +++ b/rbms/potts_bernoulli/implement.py @@ -213,7 +213,6 @@ def _init_chains( def _init_parameters( num_hiddens: int, data: Tensor, - weights: Tensor, device: torch.device, dtype: torch.dtype, var_init: float = 1e-4, From b363fff379648d4f65fe01895cd10f0d741ad664 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Thu, 5 Feb 2026 00:40:45 +0100 Subject: [PATCH 42/43] add learning_rate --- tests/unit_test/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit_test/test_utils.py b/tests/unit_test/test_utils.py index d4fc94c..b042302 100644 --- a/tests/unit_test/test_utils.py +++ b/tests/unit_test/test_utils.py @@ -229,7 +229,7 @@ def test_save_model(tmp_path, sample_params_class_bbrbm, sample_chains_bbrbm): num_updates = 1 time = 0.0 - save_model(str(filename), params, chains, num_updates, time, ["flag_1", "flag_2"]) + save_model(str(filename), params, chains, num_updates, time, torch.tensor([0.01, 0.01, 0.01]), ["flag_1", "flag_2"]) with h5py.File(filename, "r") as f: assert "update_1" in f.keys() From 7ed3c2bcc1d4ae0bbecc95968ca5cfecd939bc49 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicolas=20B=C3=A9reux?= Date: Wed, 11 Feb 2026 16:45:27 +0100 Subject: [PATCH 43/43] margaret update --- pyproject.toml | 2 +- rbms/io.py | 1 + rbms/ising_ising/implement.py | 1 + rbms/optim.py | 20 ++- rbms/parser.py | 14 +- rbms/scripts/restore.py | 4 + rbms/scripts/train_rbm.py | 79 ++++++--- rbms/training/implement.py | 284 ++++++++++++++++++++++++++++++ rbms/training/pcd.py | 25 ++- rbms/training/utils.py | 317 +++++++++++++++++++--------------- 10 files changed, 573 insertions(+), 174 deletions(-) create mode 100644 rbms/training/implement.py diff --git a/pyproject.toml b/pyproject.toml index d0df24e..0bd9315 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "rbms" -version = "0.5.1" +version = "0.6.0" authors = [ {name="Nicolas Béreux", email="nicolas.bereux@gmail.com"}, {name="Aurélien Decelle"}, diff --git a/rbms/io.py b/rbms/io.py index 2dba454..06e5780 100644 --- a/rbms/io.py +++ b/rbms/io.py @@ -8,6 +8,7 @@ from rbms.utils import restore_rng_state +@torch.compiler.disable def save_model( filename: str, params: EBM, diff --git a/rbms/ising_ising/implement.py 
b/rbms/ising_ising/implement.py index 3bc3187..6c5b60c 100644 --- a/rbms/ising_ising/implement.py +++ b/rbms/ising_ising/implement.py @@ -173,6 +173,7 @@ def _init_parameters( weight_matrix = ( torch.randn(size=(num_visibles, num_hiddens), device=device, dtype=dtype) * var_init + * 0.1 ) frequencies = data.mean(0) frequencies = torch.clamp(frequencies, min=-(1.0 - eps), max=(1.0 - eps)) diff --git a/rbms/optim.py b/rbms/optim.py index b1ad59e..48768fc 100644 --- a/rbms/optim.py +++ b/rbms/optim.py @@ -1,16 +1,27 @@ -from rbms.classes import EBM -from torch.optim import SGD, Optimizer -from torch import Tensor +import numpy as np import torch +from ptt.optim.cossim import SGD_cossim +from torch import Tensor +from torch.optim import SGD, Optimizer + +from rbms.classes import EBM + def setup_optim(optim: str, args: dict, params: EBM) -> list[Optimizer]: match args["optim"]: case "sgd": optim = SGD + case "cossim": + optim = SGD_cossim case _: print(f"Unrecognized optimizer {args['optim']}, falling back to SGD.") optim = SGD learning_rate = args["learning_rate"] + max_lr = args["max_lr"] + if args["scale_lr"]: + learning_rate /= np.sqrt(np.sqrt(params.num_visibles() * params.num_hiddens())) + max_lr /= np.sqrt(np.sqrt(params.num_visibles() * params.num_hiddens())) + if args["mult_optim"]: if not isinstance(learning_rate, Tensor): learning_rate = torch.tensor([learning_rate] * len(params.parameters())) @@ -32,6 +43,9 @@ def setup_optim(optim: str, args: dict, params: EBM) -> list[Optimizer]: maximize=True, ) ] + for opt in optimizer: + if isinstance(opt, SGD_cossim): + opt.max_lr = max_lr if args["optim"] == "nag": optimizer = [ diff --git a/rbms/parser.py b/rbms/parser.py index ca678dc..672095e 100644 --- a/rbms/parser.py +++ b/rbms/parser.py @@ -178,7 +178,18 @@ def add_args_train(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: default="pcd", help="(Defaults to 'pcd'). Type of the training, should be one of {'pcd', 'cd', 'rdm'}.", ) - + train_args.add_argument( + "--max_lr", + type=float, + default=None, + help="(Defaults to 10). 
Maximum learning rate when adaptative learning rate is used.", + ) + train_args.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Set it to scale learning rate with the number of variables of the system", + ) return parser @@ -347,6 +358,7 @@ def match_args_dtype(args: dict[str, Any]) -> dict[str, Any]: "L2": 0.0, "max_norm_grad": -1, "optim": "sgd", + "max_lr": 10, } diff --git a/rbms/scripts/restore.py b/rbms/scripts/restore.py index 7cfd1ec..b455c0e 100644 --- a/rbms/scripts/restore.py +++ b/rbms/scripts/restore.py @@ -132,6 +132,9 @@ def recover_args( args_train["training_type"] = args["training_type"] if args_train["training_type"] is None: args_train["training_type"] = train_args["training_type"][()].decode() + if args_train["max_lr"] is None: + args_train["max_lr"] = train_args["max_lr"][()].item() + args_train["scale_lr"] = args["scale_lr"] # Torch args_torch = {} @@ -152,6 +155,7 @@ def recover_args( def main(): + torch.set_float32_matmul_precision("high") torch.backends.cudnn.benchmark = True parser = create_parser_restore() args = parser.parse_args() diff --git a/rbms/scripts/train_rbm.py b/rbms/scripts/train_rbm.py index 023465a..db75254 100644 --- a/rbms/scripts/train_rbm.py +++ b/rbms/scripts/train_rbm.py @@ -1,5 +1,7 @@ import argparse + import torch + from rbms.dataset import load_dataset from rbms.dataset.parser import add_args_dataset from rbms.map_model import map_model @@ -46,32 +48,56 @@ def process_args(args: dict): "remove_duplicates": args["remove_duplicates"], "seed": args["seed"], } - args_grad = { - "no_center": args["no_center"], - "normalize_grad": args["normalize_grad"], - "max_norm_grad": args["max_norm_grad"], - "L1": args["L1"], - "L2": args["L2"], - } - args_sampling = {"gibbs_steps": args["gibbs_steps"], "beta": args["beta"]} - args_train = { - "optim": args["optim"], - "learning_rate": args["learning_rate"], - "batch_size": args["batch_size"], - "num_updates": args["num_updates"], - "mult_optim": args["mult_optim"], - "training_type": args["training_type"], - } - args_save = { - "filename": args["filename"], - "n_save": args["n_save"], - "spacing": args["spacing"], - } - args_init = { - "num_chains": args["num_chains"], - "num_hiddens": args["num_hiddens"], - "model_type": args["model_type"], - } + key_args_grad = [ + "no_center", + "normalize_grad", + "max_norm_grad", + "L1", + "L2", + ] + key_args_sampling = [ + "gibbs_steps", + "beta", + ] + key_args_train = [ + "optim", + "learning_rate", + "batch_size", + "num_updates", + "mult_optim", + "training_type", + "max_lr", + "scale_lr", + ] + key_args_save = [ + "filename", + "n_save", + "spacing", + "overwrite", + ] + key_args_init = [ + "num_chains", + "num_hiddens", + "model_type", + ] + + args_grad = {} + args_sampling = {} + args_train = {} + args_save = {} + args_init = {} + all_target = [args_grad, args_sampling, args_train, args_save, args_init] + all_keys = [ + key_args_grad, + key_args_sampling, + key_args_train, + key_args_save, + key_args_init, + ] + for target, keys in zip(all_target, all_keys): + for k in keys: + target[k] = args[k] + return ( args_dataset, args_save, @@ -84,6 +110,7 @@ def process_args(args: dict): def main(): + torch.set_float32_matmul_precision("high") torch.backends.cudnn.benchmark = True parser = create_parser() args = parser.parse_args() diff --git a/rbms/training/implement.py b/rbms/training/implement.py new file mode 100644 index 0000000..a87f850 --- /dev/null +++ b/rbms/training/implement.py @@ -0,0 +1,284 @@ +import time + +import h5py 
+import numpy as np +import torch +from torch import Tensor +from torch.optim import Optimizer +from tqdm.autonotebook import tqdm + +from rbms.classes import EBM +from rbms.dataset.dataset_class import RBMDataset +from rbms.io import load_model, save_model +from rbms.map_model import map_model +from rbms.potts_bernoulli.classes import PBRBM +from rbms.potts_bernoulli.utils import ensure_zero_sum_gauge +from rbms.training.pcd import fit_batch_pcd +from rbms.utils import get_saved_updates + + +def _init_training( + train_dataset: RBMDataset, + seed: int, + train_size: float, + test_size: float, + num_hiddens: int, + num_chains: int, + model_type: str, + filename: str, + n_save: int, + spacing: str, + batch_size: int, + optim: str, + mult_optim: bool, + training_type: str, + learning_rate: float, + max_lr: float, + gibbs_steps: int, + beta: float, + centered: bool, + L1: float, + L2: float, + normalize_grad: bool, + max_norm_grad: float, + subset_labels: list, + use_weights: bool, + alphabet: str, + remove_duplicates: bool, + dtype: torch.dtype, + device: str, + flags: list[str], + map_model: dict[str, EBM] = map_model, +): + if model_type is None: + match train_dataset.variable_type: + case "bernoulli": + model_type = "BBRBM" + case "categorical": + model_type = "PBRBM" + case "ising": + model_type = "IIRBM" + case _: + raise NotImplementedError() + + # Setup dataset + num_visibles = train_dataset.get_num_visibles() + + # Setup RBM + params = map_model[model_type].init_parameters( + num_hiddens=num_hiddens, + dataset=train_dataset, + device=device, + dtype=dtype, + ) + if isinstance(params, PBRBM): + ensure_zero_sum_gauge(params) + + # Permanent chains + parallel_chains = params.init_chains(num_samples=num_chains) + parallel_chains = params.sample_state(chains=parallel_chains, n_steps=gibbs_steps) + + # Save hyperparameters + if mult_optim: + learning_rate = torch.tensor([learning_rate] * len(params.parameters())) + else: + learning_rate = torch.tensor([learning_rate]) + + with h5py.File(filename, "w") as file_model: + hyperparameters = file_model.create_group("hyperparameters") + hyperparameters["num_visibles"] = num_visibles + hyperparameters["num_hiddens"] = num_hiddens + hyperparameters["num_chains"] = num_chains + hyperparameters["filename"] = str(filename) + + save_model( + filename=filename, + params=params, + chains=parallel_chains, + num_updates=1, + time=0.0, + flags=flags, + learning_rate=learning_rate, + ) + + with h5py.File(filename, "a") as f: + dataset = f.create_group("dataset_args") + if subset_labels is not None: + dataset["subset_labels"] = subset_labels + dataset["use_weights"] = use_weights + dataset["train_size"] = train_size + dataset["test_size"] = test_size + dataset["alphabet"] = alphabet + dataset["remove_duplicates"] = remove_duplicates + dataset["seed"] = seed + + grad = f.create_group("grad_args") + grad["no_center"] = not (centered) + grad["normalize_grad"] = normalize_grad + grad["max_norm_grad"] = max_norm_grad + grad["L1"] = L1 + grad["L2"] = L2 + + sampling = f.create_group("sampling_args") + sampling["gibbs_steps"] = gibbs_steps + sampling["beta"] = beta + + train_args = f.create_group("train_args") + train_args["optim"] = optim + train_args["batch_size"] = batch_size + train_args["learning_rate"] = learning_rate + train_args["training_type"] = training_type + train_args["max_lr"] = max_lr + + save_args = f.create_group("save_args") + save_args["n_save"] = n_save + save_args["spacing"] = spacing + + +def _restore_training( + filename: str, + 
train_dataset: RBMDataset, + test_dataset: RBMDataset | None, + num_updates: int, + target_update: int, + seed: int, + train_size: float, + test_size: float, + device: str, + dtype: torch.dtype, +): + # Retrieve the the number of training updates already performed on the model + print(f"Restoring training from update {target_update}") + + if num_updates <= target_update: + raise RuntimeError( + f"The parameter /'num_updates/' ({num_updates}) must be greater than the previous number of updates ({target_update})." + ) + + params, parallel_chains, elapsed_time = load_model( + filename, + target_update, + device=device, + dtype=dtype, + restore=True, + map_model=map_model, + ) + + # Delete all updates after the current one + saved_updates = get_saved_updates(filename) + if saved_updates[-1] > target_update: + to_delete = saved_updates[saved_updates > target_update] + with h5py.File(filename, "a") as f: + print("Deleting:") + for upd in to_delete: + print(f" - {upd}") + del f[f"update_{upd}"] + + if test_dataset is None: + print("Splitting dataset") + train_dataset, test_dataset = train_dataset.split_train_test( + rng=np.random.default_rng(seed), + train_size=train_size, + test_size=test_size, + ) + print("Train dataset:") + print(train_dataset) + print("Test dataset:") + print(test_dataset) + + # Initialize gradients for the parameters + params.init_grad() + + train_dataset.match_model_variable_type(params.visible_type) + test_dataset.match_model_variable_type(params.visible_type) + return ( + params, + parallel_chains, + target_update, + elapsed_time, + train_dataset, + test_dataset, + ) + + +def _train( + params: EBM, + parallel_chains: dict[str, Tensor], + optimizer: Optimizer, + train_dataset: RBMDataset, + checkpoints: np.ndarray, + curr_update: int, + num_updates: int, + batch_size: int, + training_type: str, + gibbs_steps: int, + beta: float, + centered: bool, + L1: float, + L2: float, + normalize_grad: bool, + max_norm_grad: float, + filename: str, + elapsed_time: float, +): + # pbar + pbar = tqdm( + initial=curr_update, + total=num_updates, + colour="red", + dynamic_ncols=True, + ascii="-#", + ) + pbar.set_description(f"Training {params.name}") + + start = time.perf_counter() + + for idx in range(curr_update + 1, num_updates + 1): + batch = train_dataset.batch(batch_size) + data, weights = batch["data"], batch["weights"] + if training_type == "rdm": + parallel_chains = params.init_chains(parallel_chains["visible"].shape[0]) + elif training_type == "cd": + parallel_chains = params.init_chains( + data.shape[0], + weights=weights, + start_v=data, + ) + for opt in optimizer: + opt.zero_grad(set_to_none=False) + + parallel_chains = fit_batch_pcd( + batch=(data, weights), + parallel_chains=parallel_chains, + params=params, + gibbs_steps=gibbs_steps, + beta=beta, + centered=centered, + lambda_l1=L1, + lambda_l2=L2, + normalize_grad=normalize_grad, + max_norm_grad=max_norm_grad, + ) + for opt in optimizer: + opt.step() + + if isinstance(params, PBRBM): + ensure_zero_sum_gauge(params) + + # Save current model if necessary + if idx in checkpoints or idx == num_updates: + curr_time = time.perf_counter() - start + learning_rate = torch.tensor([opt.param_groups[0]["lr"] for opt in optimizer]) + save_model( + filename=filename, + params=params, + chains=parallel_chains, + num_updates=idx, + time=curr_time + elapsed_time, + learning_rate=learning_rate, + flags=["checkpoint"], + ) + + pbar.set_postfix_str(f"lr: {optimizer[0].param_groups[0]['lr']:.6f}") + # Update progress bar + pbar.update(1) diff 
--git a/rbms/training/pcd.py b/rbms/training/pcd.py index 0f1ff31..c57319c 100644 --- a/rbms/training/pcd.py +++ b/rbms/training/pcd.py @@ -61,6 +61,7 @@ def fit_batch_pcd( return parallel_chains +@torch.compile @torch.no_grad def train( train_dataset: RBMDataset, @@ -105,6 +106,26 @@ def train( # save filename: str = args_save["filename"] + # _train( + # params=params, + # parallel_chains=parallel_chains, + # optimizer=optimizer, + # train_dataset=train_dataset, + # checkpoints=checkpoints, + # curr_update=curr_update, + # num_updates=num_updates, + # batch_size=batch_size, + # training_type=training_type, + # gibbs_steps=gibbs_steps, + # beta=beta, + # centered=centered, + # L1=L1, + # L2=L2, + # normalize_grad=normalize_grad, + # max_norm_grad=max_norm_grad, + # filename=filename, + # elapsed_time=elapsed_time, + # ) # pbar pbar = tqdm( initial=curr_update, @@ -115,7 +136,7 @@ def train( ) pbar.set_description(f"Training {params.name}") - start = time.time() + start = time.perf_counter() for idx in range(curr_update + 1, num_updates + 1): batch = train_dataset.batch(batch_size) @@ -151,7 +172,7 @@ def train( # Save current model if necessary if idx in checkpoints or idx == num_updates: - curr_time = time.time() - start + curr_time = time.perf_counter() - start learning_rate = torch.tensor([opt.param_groups[0]["lr"] for opt in optimizer]) save_model( filename=filename, diff --git a/rbms/training/utils.py b/rbms/training/utils.py index c16232d..854021c 100644 --- a/rbms/training/utils.py +++ b/rbms/training/utils.py @@ -1,16 +1,11 @@ -import h5py import numpy as np import torch from torch import Tensor -from tqdm.autonotebook import tqdm from rbms.classes import EBM from rbms.dataset.dataset_class import RBMDataset -from rbms.io import load_model, save_model from rbms.map_model import map_model -from rbms.potts_bernoulli.classes import PBRBM -from rbms.potts_bernoulli.utils import ensure_zero_sum_gauge -from rbms.utils import get_saved_updates +from rbms.training.implement import _init_training, _restore_training def get_checkpoints(num_updates: int, n_save: int, spacing: str = "exp") -> np.ndarray: @@ -50,6 +45,7 @@ def init_training( args_torch: dict[str, str | torch.dtype], train_dataset: RBMDataset, flags: list[str] = ["checkpoint"], + map_model: dict[str, EBM] = map_model, ): # Torch device: str = args_torch["device"] @@ -73,6 +69,7 @@ def init_training( mult_optim: bool = args_train["mult_optim"] training_type: str = args_train["training_type"] learning_rate: float = args_train["learning_rate"] + max_lr: float = args_train["max_lr"] # save filename: str = args_save["filename"] @@ -94,88 +91,124 @@ def init_training( num_hiddens: int = args_init["num_hiddens"] num_chains: int = args_init["num_chains"] model_type: str = args_init["model_type"] - if model_type is None: - match train_dataset.variable_type: - case "bernoulli": - model_type = "BBRBM" - case "categorical": - model_type = "PBRBM" - case "ising": - model_type = "IIRBM" - case _: - raise NotImplementedError() - - # Setup dataset - num_visibles = train_dataset.get_num_visibles() - - # Setup RBM - params = map_model[model_type].init_parameters( + + _init_training( + train_dataset=train_dataset, + seed=seed, + train_size=train_size, + test_size=test_size, num_hiddens=num_hiddens, - dataset=train_dataset, - device=device, - dtype=dtype, - ) - if isinstance(params, PBRBM): - ensure_zero_sum_gauge(params) - - # Permanent chains - parallel_chains = params.init_chains(num_samples=num_chains) - parallel_chains = 
params.sample_state(chains=parallel_chains, n_steps=gibbs_steps) - - # Save hyperparameters - if mult_optim: - learning_rate = torch.tensor([learning_rate] * len(params.parameters())) - else: - learning_rate = torch.tensor([learning_rate]) - - with h5py.File(filename, "w") as file_model: - hyperparameters = file_model.create_group("hyperparameters") - hyperparameters["num_visibles"] = num_visibles - hyperparameters["num_hiddens"] = num_hiddens - hyperparameters["num_chains"] = num_chains - hyperparameters["filename"] = str(filename) - - save_model( + num_chains=num_chains, + model_type=model_type, filename=filename, - params=params, - chains=parallel_chains, - num_updates=1, - time=0.0, - flags=flags, + n_save=n_save, + spacing=spacing, + batch_size=batch_size, + optim=optim, + mult_optim=mult_optim, + training_type=training_type, learning_rate=learning_rate, + max_lr=max_lr, + gibbs_steps=gibbs_steps, + beta=beta, + centered=centered, + L1=L1, + L2=L2, + normalize_grad=normalize_grad, + max_norm_grad=max_norm_grad, + subset_labels=subset_labels, + use_weights=use_weights, + alphabet=alphabet, + remove_duplicates=remove_duplicates, + dtype=dtype, + device=device, + flags=flags, + map_model=map_model, ) - with h5py.File(filename, "a") as f: - dataset = f.create_group("dataset_args") - if subset_labels is not None: - dataset["subset_labels"] = subset_labels - dataset["use_weights"] = use_weights - dataset["train_size"] = train_size - dataset["test_size"] = test_size - dataset["alphabet"] = alphabet - dataset["remove_duplicates"] = remove_duplicates - dataset["seed"] = seed - - grad = f.create_group("grad_args") - grad["no_center"] = not (centered) - grad["normalize_grad"] = normalize_grad - grad["max_norm_grad"] = max_norm_grad - grad["L1"] = L1 - grad["L2"] = L2 - - sampling = f.create_group("sampling_args") - sampling["gibbs_steps"] = gibbs_steps - sampling["beta"] = beta - - train_args = f.create_group("train_args") - train_args["optim"] = optim - train_args["batch_size"] = batch_size - train_args["learning_rate"] = learning_rate - train_args["training_type"] = training_type - - save_args = f.create_group("save_args") - save_args["n_save"] = n_save - save_args["spacing"] = spacing + # if model_type is None: + # match train_dataset.variable_type: + # case "bernoulli": + # model_type = "BBRBM" + # case "categorical": + # model_type = "PBRBM" + # case "ising": + # model_type = "IIRBM" + # case _: + # raise NotImplementedError() + + # # Setup dataset + # num_visibles = train_dataset.get_num_visibles() + + # # Setup RBM + # params = map_model[model_type].init_parameters( + # num_hiddens=num_hiddens, + # dataset=train_dataset, + # device=device, + # dtype=dtype, + # ) + # if isinstance(params, PBRBM): + # ensure_zero_sum_gauge(params) + + # # Permanent chains + # parallel_chains = params.init_chains(num_samples=num_chains) + # parallel_chains = params.sample_state(chains=parallel_chains, n_steps=gibbs_steps) + + # # Save hyperparameters + # if mult_optim: + # learning_rate = torch.tensor([learning_rate] * len(params.parameters())) + # else: + # learning_rate = torch.tensor([learning_rate]) + + # with h5py.File(filename, "w") as file_model: + # hyperparameters = file_model.create_group("hyperparameters") + # hyperparameters["num_visibles"] = num_visibles + # hyperparameters["num_hiddens"] = num_hiddens + # hyperparameters["num_chains"] = num_chains + # hyperparameters["filename"] = str(filename) + + # save_model( + # filename=filename, + # params=params, + # chains=parallel_chains, + # 
num_updates=1, + # time=0.0, + # flags=flags, + # learning_rate=learning_rate, + # ) + + # with h5py.File(filename, "a") as f: + # dataset = f.create_group("dataset_args") + # if subset_labels is not None: + # dataset["subset_labels"] = subset_labels + # dataset["use_weights"] = use_weights + # dataset["train_size"] = train_size + # dataset["test_size"] = test_size + # dataset["alphabet"] = alphabet + # dataset["remove_duplicates"] = remove_duplicates + # dataset["seed"] = seed + + # grad = f.create_group("grad_args") + # grad["no_center"] = not (centered) + # grad["normalize_grad"] = normalize_grad + # grad["max_norm_grad"] = max_norm_grad + # grad["L1"] = L1 + # grad["L2"] = L2 + + # sampling = f.create_group("sampling_args") + # sampling["gibbs_steps"] = gibbs_steps + # sampling["beta"] = beta + + # train_args = f.create_group("train_args") + # train_args["optim"] = optim + # train_args["batch_size"] = batch_size + # train_args["learning_rate"] = learning_rate + # train_args["training_type"] = training_type + # train_args["max_lr"] = max_lr + + # save_args = f.create_group("save_args") + # save_args["n_save"] = n_save + # save_args["spacing"] = spacing def restore_training( @@ -191,7 +224,6 @@ def restore_training( dict[str, Tensor], int, float, - tqdm, RBMDataset, RBMDataset, ]: @@ -208,65 +240,68 @@ def restore_training( train_size: float = args_dataset["train_size"] test_size: float = args_dataset["test_size"] - # Retrieve the the number of training updates already performed on the model - print(f"Restoring training from update {target_update}") - - if num_updates <= target_update: - raise RuntimeError( - f"The parameter /'num_updates/' ({num_updates}) must be greater than the previous number of updates ({target_update})." - ) - - params, parallel_chains, elapsed_time = load_model( - filename, - target_update, + return _restore_training( + filename=filename, + train_dataset=train_dataset, + test_dataset=test_dataset, + num_updates=num_updates, + target_update=target_update, + seed=seed, + train_size=train_size, + test_size=test_size, device=device, dtype=dtype, - restore=True, - map_model=map_model, ) - # Delete all updates after the current one - saved_updates = get_saved_updates(filename) - if saved_updates[-1] > target_update: - to_delete = saved_updates[saved_updates > target_update] - with h5py.File(filename, "a") as f: - print("Deleting:") - for upd in to_delete: - print(f" - {upd}") - del f[f"update_{upd}"] - - if test_dataset is None: - print("Splitting dataset") - train_dataset, test_dataset = train_dataset.split_train_test( - rng=np.random.default_rng(seed), - train_size=train_size, - test_size=test_size, - ) - print("Train dataset:") - print(train_dataset) - print("Test dataset:") - print(test_dataset) - - # # Progress bar - # pbar = tqdm( - # initial=target_update, - # total=num_updates, - # colour="red", - # dynamic_ncols=True, - # ascii="-#", + # # Retrieve the the number of training updates already performed on the model + # print(f"Restoring training from update {target_update}") + + # if num_updates <= target_update: + # raise RuntimeError( + # f"The parameter /'num_updates/' ({num_updates}) must be greater than the previous number of updates ({target_update})." 
+ # ) + + # params, parallel_chains, elapsed_time = load_model( + # filename, + # target_update, + # device=device, + # dtype=dtype, + # restore=True, + # map_model=map_model, + # ) + + # # Delete all updates after the current one + # saved_updates = get_saved_updates(filename) + # if saved_updates[-1] > target_update: + # to_delete = saved_updates[saved_updates > target_update] + # with h5py.File(filename, "a") as f: + # print("Deleting:") + # for upd in to_delete: + # print(f" - {upd}") + # del f[f"update_{upd}"] + + # if test_dataset is None: + # print("Splitting dataset") + # train_dataset, test_dataset = train_dataset.split_train_test( + # rng=np.random.default_rng(seed), + # train_size=train_size, + # test_size=test_size, + # ) + # print("Train dataset:") + # print(train_dataset) + # print("Test dataset:") + # print(test_dataset) + + # # Initialize gradients for the parameters + # params.init_grad() + + # train_dataset.match_model_variable_type(params.visible_type) + # test_dataset.match_model_variable_type(params.visible_type) + # return ( + # params, + # parallel_chains, + # target_update, + # elapsed_time, + # train_dataset, + # test_dataset, # ) - # pbar.set_description(f"Training {params.name}") - - # Initialize gradients for the parameters - params.init_grad() - - train_dataset.match_model_variable_type(params.visible_type) - test_dataset.match_model_variable_type(params.visible_type) - return ( - params, - parallel_chains, - target_update, - elapsed_time, - train_dataset, - test_dataset, - )