Created by Yongheng Zhao, Tolga Birdal, Haowen Deng, Federico Tombari from TUM.
See this link for the original README documentation.
Custom functions:
- generate a capsule dataset for transfer learning,
- train a beta-VAE on capsules,
- decode and visualize capsules using the default capsule network checkpoint.
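The beta-VAE training listed above is not documented further here. As a reminder of the objective such a model typically optimizes, here is a minimal sketch of the standard beta-VAE loss: a reconstruction term plus the KL divergence between a diagonal Gaussian posterior and a standard normal prior, weighted by beta. It assumes the latent code (for example, a flattened capsule vector) is parameterized by `mu` and `logvar`; the function names and the default `beta=4.0` are hypothetical and not taken from this repository's training script.

```python
import math
from typing import Sequence


def gaussian_kl(mu: Sequence[float], logvar: Sequence[float]) -> float:
    """KL( N(mu, diag(exp(logvar))) || N(0, I) ), summed over latent dimensions."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))


def beta_vae_loss(recon_loss: float, mu: Sequence[float], logvar: Sequence[float],
                  beta: float = 4.0) -> float:
    """Hypothetical helper: reconstruction loss plus the KL term weighted by beta.

    beta > 1 puts extra pressure on the posterior to match the prior, which is
    what distinguishes a beta-VAE from a plain VAE (beta = 1).
    """
    return recon_loss + beta * gaussian_kl(mu, logvar)


if __name__ == "__main__":
    # A posterior identical to the prior has zero KL, so only the reconstruction term remains.
    print(beta_vae_loss(0.5, mu=[0.0, 0.0], logvar=[0.0, 0.0]))
```

In a real training loop, `recon_loss` would come from the point-cloud reconstruction (for example, a Chamfer distance) and the KL term would be computed on tensors, but the arithmetic is the same.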
Because the default Chamfer distance (CD) package is extremely buggy, we switched to the CD package provided by chrdiller: https://github.com/chrdiller/pyTorchChamferDistance
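For reference, the Chamfer distance as it is commonly used as a point-cloud reconstruction loss can be sketched in pure Python: for every point, take the squared distance to its nearest neighbour in the other cloud, then average in each direction and add the two averages. This brute-force version is O(N·M) and is only meant to clarify the definition or sanity-check small inputs; the CUDA package linked above is what should be used for training. Averaging and summing the two directions is an assumed convention, so check how the training script combines the two per-point distance arrays.

```python
from typing import List, Sequence, Tuple

Point = Tuple[float, float, float]


def _sq_dist(a: Sequence[float], b: Sequence[float]) -> float:
    # Squared Euclidean distance between two 3D points.
    return sum((x - y) ** 2 for x, y in zip(a, b))


def nearest_sq_dists(src: Sequence[Point], dst: Sequence[Point]) -> List[float]:
    # For every point in src, the squared distance to its nearest neighbour in dst.
    return [min(_sq_dist(p, q) for q in dst) for p in src]


def chamfer_distance(pc1: Sequence[Point], pc2: Sequence[Point]) -> float:
    """Symmetric Chamfer distance; both clouds must be non-empty."""
    dist1 = nearest_sq_dists(pc1, pc2)  # pc1 -> pc2
    dist2 = nearest_sq_dists(pc2, pc1)  # pc2 -> pc1
    return sum(dist1) / len(dist1) + sum(dist2) / len(dist2)


if __name__ == "__main__":
    a = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)]
    b = [(0.0, 0.0, 0.0), (1.0, 1.0, 0.0)]
    print(chamfer_distance(a, b))  # each direction contributes a mean of 0.5
```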
The code is based on PyTorch. It has been tested with Python 3.8, PyTorch 1.6.0 and CUDA 11.0 (or higher) on Ubuntu 20.04.
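To compare a local setup against the tested versions, a small check along these lines can help. It is a sketch, not a script shipped with the repository: `TESTED`, `version_tuple` and `detect_versions` are hypothetical names, and the script only reads `torch.__version__` and `torch.version.cuda` when PyTorch is importable.

```python
import importlib
import platform
from typing import Dict, Optional, Tuple

# Versions this README reports as tested; newer versions may also work.
TESTED = {"python": "3.8", "torch": "1.6.0", "cuda": "11.0"}


def version_tuple(version: str) -> Tuple[int, ...]:
    # "1.6.0+cu110" -> (1, 6, 0); the local build suffix after "+" is ignored,
    # as is any non-numeric component such as "0rc1".
    core = version.split("+")[0]
    return tuple(int(part) for part in core.split(".") if part.isdigit())


def detect_versions() -> Dict[str, Optional[str]]:
    found: Dict[str, Optional[str]] = {
        "python": platform.python_version(),
        "torch": None,
        "cuda": None,
    }
    try:
        torch = importlib.import_module("torch")
    except ImportError:
        return found  # PyTorch is not installed
    found["torch"] = str(torch.__version__)
    found["cuda"] = torch.version.cuda  # None for CPU-only builds
    return found


if __name__ == "__main__":
    for name, version in detect_versions().items():
        print(f"{name}: {version or 'not found'} (tested: {TESTED[name]})")
```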
Install h5py for Python:
sudo apt-get install libhdf5-dev
sudo pip install h5py
To visualize the training process in PyTorch, consider installing TensorBoard.
If you have a GUI enabled, consider installing Open3D to visualize the reconstructed point clouds:
pip3 install open3d
ShapeNetPart with 16 categories:
cd dataset
bash download_shapenet_part16_catagories.sh
ShapeNet Core with 13 categories (from AtlasNet):
cd dataset
bash download_shapenet_core13_catagories.sh
ShapeNet Core with 55 categories (from FoldingNet):
cd dataset
bash download_shapenet_core55_catagories.sh
You can download the pre-trained models here.
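The example commands in this README load checkpoints named `shapenet_part_dataset_ae_200.pth` and `shapenet_core13_dataset_ae_230.pth`, which suggests the pattern `<dataset>_dataset_ae_<epoch>.pth` inside `../../checkpoints`. That pattern is inferred from those two filenames only. The helpers below (`checkpoint_path`, `parse_checkpoint_name`) are hypothetical conveniences for building or parsing such paths, not part of the repository.

```python
import re
from pathlib import Path
from typing import Tuple

# Dataset names accepted by --dataset in the commands of this README.
DATASETS = ("shapenet_part", "shapenet_core13", "shapenet_core55")

# Assumed naming pattern, inferred from the example checkpoint names.
_CKPT_RE = re.compile(r"^(?P<dataset>\w+?)_dataset_ae_(?P<epoch>\d+)\.pth$")


def checkpoint_path(dataset: str, epoch: int, root: str = "../../checkpoints") -> Path:
    # Hypothetical helper: build the expected checkpoint path for a dataset and epoch.
    if dataset not in DATASETS:
        raise ValueError(f"unknown dataset {dataset!r}; expected one of {DATASETS}")
    return Path(root) / f"{dataset}_dataset_ae_{epoch}.pth"


def parse_checkpoint_name(name: str) -> Tuple[str, int]:
    # Hypothetical helper: recover (dataset, epoch) from a checkpoint file name or path.
    match = _CKPT_RE.match(Path(name).name)
    if match is None:
        raise ValueError(f"unexpected checkpoint name: {name!r}")
    return match["dataset"], int(match["epoch"])


if __name__ == "__main__":
    print(checkpoint_path("shapenet_core13", 230))
    print(parse_checkpoint_name("shapenet_part_dataset_ae_200.pth"))
```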
We provide an example demonstrating the basic usage in the folder 'mini_example'.
To visualize the reconstruction from latent capsules with our pre-trained model:
cd mini_example/AE
python viz_reconstruction.py --model ../../checkpoints/shapenet_part_dataset_ae_200.pth
To train a point capsule autoencoder on the ShapeNetPart dataset yourself:
cd mini_example/AE
python train_ae.py
To train a point capsule autoencoder on another dataset:
cd apps/AE
python train_ae.py --dataset <shapenet_part | shapenet_core13 | shapenet_core55>
To monitor the training process, run TensorBoard and point it at the log directory:
tensorboard --logdir log
To test the reconstruction accuracy:
python test_ae.py --dataset <dataset> --model <checkpoint>
For example:
python test_ae.py --dataset shapenet_core13 --model ../../checkpoints/shapenet_core13_dataset_ae_230.pth
To visualize the reconstructed points:
python viz_reconstruction.py --dataset <dataset> --model <checkpoint>
For example:
python viz_reconstruction.py --dataset shapenet_core13 --model ../../checkpoints/shapenet_core13_dataset_ae_230.pth
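If no GUI is available (for example, on a remote training machine), one option is to write the reconstructed points to an ASCII PLY file and open it on another machine. The writer below is a stdlib-only sketch; `write_ply` is a hypothetical helper, and how you obtain the points from the visualization script (for instance, calling `.tolist()` on an (N, 3) tensor) is an assumption that depends on the code.

```python
from pathlib import Path
from typing import Iterable, Sequence


def write_ply(path: str, points: Iterable[Sequence[float]]) -> int:
    """Hypothetical helper: write XYZ points to an ASCII PLY file and return the point count."""
    rows = [f"{x:.6f} {y:.6f} {z:.6f}" for x, y, z in points]
    header = [
        "ply",
        "format ascii 1.0",
        f"element vertex {len(rows)}",
        "property float x",
        "property float y",
        "property float z",
        "end_header",
    ]
    Path(path).write_text("\n".join(header + rows) + "\n")
    return len(rows)


if __name__ == "__main__":
    n = write_ply("reconstruction.ply", [(0.0, 0.0, 0.0), (1.0, 0.5, -0.25)])
    print(f"wrote {n} points")
```

With Open3D installed, the file can then be loaded with `open3d.io.read_point_cloud("reconstruction.ply")` and displayed with `open3d.visualization.draw_geometries([pcd])`.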