# Code for NeurIPS 2024 submission #2003
/!\ To install PyTorch with GPU support, see the official PyTorch installation instructions.
To install the package:

```bash
git clone https://github.com/nbereux/fast-RBM.git
cd fast-RBM
bash download_dataset.sh
pip install -r requirements.txt
pip install -e .
```

All scripts can be called with the `--help` option to get a description of their arguments.
## 1. Compute the mesh on the intrinsic space

```bash
python scripts/compute_mesh.py -d MICKEY --variable_type Ising --dimension 0 1 --border_length 0.04 --n_pts_dim 100 --device cuda -o ./mesh.h5
```

- `-d` is the name of the dataset to load (`MICKEY`, `MNIST` or `GENE`).
- `--variable_type` should be set to `Ising` for the mesh. Can be `Ising`, `Bernoulli`, `Continuous` or `Potts`. Currently, only `Ising` works for the RCM.
- `--dimension` lists the indices of the dimensions of the intrinsic space.
- `--border_length` should be set to 2/50 or less.
- `--n_pts_dim` is the number of mesh points along each dimension. The total number of points will be `n_pts_dim**n_dim`.
- `--device` is the PyTorch device you want to use. In low dimensions, the CPU and GPU have similar performance.
- `-o` is the filename for your mesh.
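Note that the mesh size grows exponentially with the number of intrinsic dimensions; for the command above:

```python
# Total number of mesh points grows as n_pts_dim ** n_dim.
n_pts_dim = 100   # --n_pts_dim 100
dims = [0, 1]     # --dimension 0 1
n_points = n_pts_dim ** len(dims)
print(n_points)   # 10000 points for a 2D intrinsic space
```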
## 2. Train the RCM and compute the corresponding Ising RBM

```bash
python scripts/train_rcm.py -d MICKEY --variable_type Ising --num_hidden 200 --max_iter 100000 --adapt --stop_ll 0.001 --decimation --dimension 0 1 --mesh_file mesh.h5 -o RCM.h5
```

- `--num_hidden` is the maximum number of hidden nodes for the final RBM.
- `--max_iter` is the maximum number of training iterations before stopping.
- `--adapt` enables an adaptive learning-rate strategy.
- `--stop_ll` is the threshold on the exponential moving average of the test log-likelihood fluctuations.
- `--decimation` enables the removal of unimportant features to improve the mapping from RCM to Ising RBM. If not set, the final RBM will have exactly `--num_hidden` hidden nodes.
- `--mesh_file` is the path to the mesh computed at step 1.
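The `--stop_ll` criterion tracks an exponential moving average of the test log-likelihood fluctuations. A minimal illustrative sketch of such a stopping rule (the function name and the smoothing factor `alpha` are assumptions, not the repo's actual implementation):

```python
def ema_stop(ll_history, alpha=0.1, threshold=1e-3):
    """Stop when the EMA of absolute test log-likelihood changes
    falls below `threshold` (cf. --stop_ll)."""
    if len(ll_history) < 2:
        return False
    ema = abs(ll_history[1] - ll_history[0])
    for prev, curr in zip(ll_history[1:], ll_history[2:]):
        ema = alpha * abs(curr - prev) + (1 - alpha) * ema
    return ema < threshold

# Large fluctuations -> keep training; a flat curve -> stop.
print(ema_stop([-50.0, -40.0, -35.0]))  # False
print(ema_stop([-30.0] * 50))           # True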
## 3. Initialize a Bernoulli RBM from the RCM

```bash
python scripts/rcm_to_rbm.py -i RCM.h5 -o RBM.h5 --num_hiddens 100 -d MICKEY --gibbs_steps 100 --batch_size 2000 --num_chains 2000 --min_eps 0.7 --dtype float --device cuda
```

- `-i` is the filename of the RCM obtained at step 2.
- `-o` is the filename for the new RBM initialized from the RCM.
- `--num_hiddens` is the target number of hidden nodes for the RBM. If it is below the final number of hidden nodes of the RCM, this option does nothing; otherwise, the new nodes are initialized with zero bias and random weights.
- `--gibbs_steps` is the number of Gibbs steps performed at each gradient update.
- `--num_chains` is the number of parallel chains used for the gradient estimation.
- `--dtype` is the dtype of the RBM weights.
- `--min_eps` is the minimum effective population size for the Jar-RBM.
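`--min_eps` bounds the effective population size of the weighted chains. As an illustration only, the standard normalized effective-sample-size formula for a set of importance weights (the exact quantity monitored by the Jar-RBM may differ):

```python
def effective_population_size(weights):
    """Normalized effective sample size of importance weights, in (0, 1].
    Values near 1 mean the weights are balanced; small values mean a few
    chains dominate (cf. the --min_eps 0.7 threshold)."""
    s = sum(weights)
    s2 = sum(w * w for w in weights)
    return (s * s) / (len(weights) * s2)

print(effective_population_size([1.0, 1.0, 1.0, 1.0]))  # 1.0: uniform weights
print(effective_population_size([1.0, 0.0, 0.0, 0.0]))  # 0.25: one chain dominates
```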
## 4. Continue training the RBM initialized from the RCM

```bash
python scripts/train_rbm.py -d MICKEY --variable_type Bernoulli --use_torch --model BernoulliBernoulliJarJarRBM --filename RBM.h5 --epochs 1000 --log --dtype float --restore
```

- `--use_torch` loads the dataset entirely on the GPU, allowing faster processing at the cost of a higher VRAM footprint.
- `--model` is the algorithm used to train the RBM. Can be `BernoulliBernoulliJarJarRBM` or `BernoulliBernoulliPCDRBM`.
- `--filename` is the path to the RBM whose training you want to continue.
- `--epochs` is the total number of training epochs for the model.
- `--log` writes metrics to a log file.
- `--restore` must be set to restore training; otherwise a new training run is started.
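The Bernoulli-Bernoulli models above rely on block Gibbs sampling, alternating between hidden and visible units. A generic NumPy sketch of one Gibbs sweep (illustrative only; the repo's implementation uses PyTorch, and the variable names here are not the repo's):

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gibbs_step(v, W, vbias, hbias):
    """One block-Gibbs sweep: sample h | v, then v | h."""
    p_h = sigmoid(v @ W + hbias)                  # P(h_j = 1 | v)
    h = (rng.random(p_h.shape) < p_h).astype(float)
    p_v = sigmoid(h @ W.T + vbias)                # P(v_i = 1 | h)
    v = (rng.random(p_v.shape) < p_v).astype(float)
    return v, h

# Toy shapes: a batch of 4 chains, 6 visible and 3 hidden units.
W = rng.normal(scale=0.1, size=(6, 3))
v = (rng.random((4, 6)) < 0.5).astype(float)
v, h = gibbs_step(v, W, np.zeros(6), np.zeros(3))
print(v.shape, h.shape)  # (4, 6) (4, 3)
```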
## 5. Train an RBM from scratch

```bash
python scripts/train_rbm.py -d MICKEY --variable_type Bernoulli --use_torch --filename RBM.h5 --epochs 1000 --log --dtype float --model BernoulliBernoulliJarJarRBM --learning_rate 0.01 --num_hiddens 100 --gibbs_steps 100 --batch_size 2000 --num_chains 2000 --min_eps 0.7
```

## 6. Sample the RBM with PTT

```bash
python scripts/ptt_sampling.py -i RBM.h5 -o sample_RBM_mickey.h5 --num_samples 2000 --target_acc_rate 0.9 --it_mcmc 1000
```

- `-i` is the filename of the RBM obtained at step 4 or 5.
- `-o` is the file in which to save the samples.
- `--filename_rcm` is the name of the file used to initialize the RBM at step 3. Do not set it if the RBM was trained from scratch.
- `--target_acc_rate` is the target acceptance rate between two consecutive machines.
- `--it_mcmc` is the number of Gibbs steps performed by each machine.
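PTT sampling exchanges configurations between consecutive machines, and the machines are spaced so that the empirical exchange rate matches `--target_acc_rate`. As a rough illustration of the idea only, a generic Metropolis-style swap test (the actual PTT exchange criterion is implemented in `scripts/ptt_sampling.py` and may differ):

```python
import math
import random

def swap_accept(delta, rng=random.random):
    """Generic Metropolis acceptance for exchanging configurations between
    two consecutive machines; `delta` is the change in total log-probability
    induced by the swap. Accept with probability min(1, exp(delta))."""
    return delta >= 0 or rng() < math.exp(delta)

# A swap that increases the joint probability is always accepted.
print(swap_accept(1.5))  # True
```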
See `rcm_analysis.ipynb` for an analysis of the file obtained at step 2, and `rbm_analysis.ipynb` for an analysis of the files obtained at step 4 or 5, as well as the results for the PTT.