cjdsj/CNN-Model-Compression

CNN Model Compression with Pruning & Knowledge Distillation

In this project we combine weight pruning and knowledge distillation to compress a ResNet50 model. The experimental task is image classification on the CIFAR10 dataset. Our final model achieves a compression rate of 12.82 with 77.44% validation accuracy.

Pruning

In the first step, we prune the pretrained ResNet50 through weight pruning. We run iterative pruning for 15 iterations. In each iteration we first prune the convolutional layers based on their APoZ (Average Percentage of Zeros) scores, then fine-tune the pruned model to convergence. The results of iterative pruning are shown below:
[Figure: Iterative Pruning Result]

Number of neurons pruned in each layer over iterations:
[Figure: Number of Neurons Pruned]
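
The APoZ scoring step can be illustrated with a minimal NumPy sketch. It is not the code in prune.py: the function names, the channels-last activation layout, and the `prune_fraction` parameter are all assumptions made for illustration. The idea is that a channel whose post-ReLU activations are zero for most inputs contributes little and is a candidate for removal.

```python
import numpy as np


def apoz_per_channel(activations: np.ndarray) -> np.ndarray:
    """Return the APoZ score of each output channel.

    `activations` holds post-ReLU feature maps with the assumed shape
    (batch, height, width, channels). The score of a channel is the
    fraction of its activations that are exactly zero.
    """
    zero_mask = activations == 0
    return zero_mask.mean(axis=(0, 1, 2))


def channels_to_prune(apoz: np.ndarray, prune_fraction: float) -> np.ndarray:
    """Return sorted indices of the channels with the highest APoZ."""
    n_prune = int(len(apoz) * prune_fraction)
    if n_prune == 0:
        return np.array([], dtype=int)
    order = np.argsort(apoz)[::-1]  # highest APoZ first
    return np.sort(order[:n_prune])


# Toy feature maps (hypothetical data): channel 2 is always zero and
# channel 5 is zero for roughly 90% or more of its positions.
rng = np.random.default_rng(0)
feats = np.maximum(rng.normal(size=(8, 4, 4, 6)), 0.0)
feats[..., 2] = 0.0
keep_mask = rng.random((8, 4, 4)) < 0.1
feats[..., 5] = feats[..., 5] * keep_mask

scores = apoz_per_channel(feats)
pruned = channels_to_prune(scores, prune_fraction=0.5)
print("APoZ per channel:", np.round(scores, 2))
print("Channels selected for pruning:", pruned)
```

In an iterative scheme like the one described above, scores of this kind would be recomputed on the fine-tuned model at the start of each iteration, since removing channels changes the activation statistics of the remaining ones.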

Knowledge Distillation

To further compress the model, we train a small student ResNet50 network through knowledge distillation to learn from the teacher network, which is the pruned model from step 1. The training loss is composed of three parts: backbone loss, intermediate layer loss, and adversarial training loss. The loss function combines these three parts using the hyperparameters lambda1, lambda2, and lambda3, which can be tuned. The experimental results are shown below:
[Figure: Knowledge Distillation Result]
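
As a rough sketch of how three loss terms and three tunable weights might fit together, the snippet below forms a weighted sum. Both the weighted-sum form and the pairing of lambda1, lambda2, and lambda3 with the backbone, intermediate layer, and adversarial losses (in that order) are assumptions; the actual definition is in knowledge_distillation.py. The function name is hypothetical, and the default weights are the values from the example command in this README.

```python
def total_distillation_loss(
    backbone_loss: float,
    intermediate_loss: float,
    adversarial_loss: float,
    lambda1: float = 0.7,
    lambda2: float = 0.3,
    lambda3: float = 0.2,
) -> float:
    """Combine the three loss terms as a weighted sum.

    Assumption: lambda1 weights the backbone loss, lambda2 the
    intermediate layer loss, and lambda3 the adversarial loss.
    """
    return (
        lambda1 * backbone_loss
        + lambda2 * intermediate_loss
        + lambda3 * adversarial_loss
    )


# Hypothetical per-batch loss values, used only to exercise the function.
loss = total_distillation_loss(
    backbone_loss=1.0,
    intermediate_loss=2.0,
    adversarial_loss=0.5,
)
print(f"total loss: {loss:.4f}")
```

Under this assumed form, raising one weight relative to the others shifts the student's training signal toward that term, which is why the weights are exposed as command-line options.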

Finally, the model compression result is shown below:
[Figure: Final Model Compression Result]

Repository Description

  • model.py: defines the ResNet model architecture
  • train_base_model.py: trains the baseline ResNet50 model
  • prune.py: contains the pruning methods
  • prune_model.py: training procedure for iterative pruning
  • knowledge_distillation.py: contains the KD methods and training procedures
  • model: directory containing the best-accuracy models for pruning and knowledge distillation

Example Commands

  • python train_base_model.py --save_folder=./model/base_model --model_path=./model/pretrained_resnet50.h5
  • python prune_model.py --prune_iter=15
  • python knowledge_distillation.py --root_folder=./model/pruned_model/iter15 --lambda1=0.7 --lambda2=0.3 --lambda3=0.2 --regressor_name=conv1x1
  • python evaluate_model.py --model_path=./model/base_model/model.h5
  • python plot_train_history.py --history_path=./model/pruned_model/iter1/history.json
