This project consists of an utility wrapper around PyTorch and Brevitas to specify the network, dataset and parameters to train with.
You can run the training of LeNet on the MNIST dataset as follows:
$ PYTORCH_JIT=1 python nn_benchmark/main.py --network LeNet --dataset MNIST --epochs 3Or its quantized version QuantLeNet on the CIFAR-10 dataset with bit-width of (4,4,8) (corresponding to activation, weight and input) :
$ PYTORCH_JIT=1 python nn_benchmark/main.py --network QuantLeNet --dataset CIFAR10 --epochs 3 \
$ --acq 4 --weq 4 --inq 8The results can be observed under the experiments folder.
You can then evaluate your network with the following command:
$ python nn_benchmark/main.py --network LeNet --dataset MNIST --evaluate --resume ./experiments/<your_folder>/checkpoints/best.tarThe following networks are supported:
- LeNet5
- VGG (11/13/16/19)
- MobilenetV1
Their quantized counterparts are available as well:
- QuantLeNet5
- QuantVGG (11/13/16/19)
- QuantMobilenetV1
The following datasets can be used:
- MNIST
- FASHION-MNIST
- CIFAR10
- GTSRB
I worked on the project through a virtual environment with virtualenvwrapper
and I highly recommend to do so as well. However, whether or not you are in a
virtual environment, the installation proceeds as follows:
-
For downloading and installing the source code of the project:
$ cd <directory you want to install to> $ git clone https://github.com/QDucasse/nn_benchmark $ python setup.py install
-
For downloading and installing the source code of the project in a new virtual environment:
Download of the source code & Creation of the virtual environment
$ cd <directory you want to install to> $ git clone https://github.com/QDucasse/nn_benchmark $ cd nn_benchmark $ mkvirtualenv -a . -r requirements.txt VIRTUALENV_NAME
Launch of the environment & installation of the project
$ workon VIRTUALENV_NAME $ pip install -e .
Finally, whether you chose the first or second option, you will need brevitas if you want to use quantized networks. The installation is better performed from source and can be done as follows (in your native or virtual environment):
$ git clone https://github.com/Xilinx/brevitas.git
$ cd brevitas
$ pip install -e .Quick presentation of the different modules of the project:
- Core: Core functionalities of the project such as the
CLI,Logger,PlotterandTrainer. - Extensions: Extended functionalities of the
PyTorchmodules. This package contains specific modules (e.g.TensorNorm), dataset (e.g.GTSRB) or loss functions (e.g.SquaredHinge) that can be used as drop-in replacements for theirPyTorchhomologues. - Networks: Network architectures implemented in the project that can be used with the CLI
--networkflag. If the suffixQuantis present, this means the network is quantized and the three following precisions can be specified:weight_bit_width, precision of the weights ;act_bit_widthprecision of the activation functions ;in_bit_width, input precision (this is useful to keep a higher precision at the beginning).
This project uses the following external libraries:
If installed as specified above, the requirements are stated in the requirements.txt file
and therefore automatically installed.
However, you can install each of them separately with the command (except for brevitas, please follow the installation from source provided at the end of the installation paragraph):
$ pip install <library>- Basic project structure
-
Trainer/Loggerlogic -
LeNetnetwork training onMNIST -
LeNet5,MobilenetV1,VGG11,VGG13,VGG16,VGG19network architectures - Quantized counterparts for the networks
-
PyTorchextensions withGTSRBdataset and another loss function (SqrHinge) - Extend trainer with
Plotter
All tests are written to work with nose and/or pytest. Just type pytest or
nosetests as a command line in the project. Every test file can still be launched
by executing the testfile itself.
$ python nn_benchmark/tests/chosentest.py
$ pytest