It seems like the procedure is the following:
- Train the teacher network (usually big and slow, for example a CNN) on a training dataset. Let the number of classes be `N`.
- Select a subset of the training examples (the transfer dataset), or use the full training dataset, and run it through the teacher model. Save its logits (outputs before softmax), `logits_t^i`, for each example `i`. Subscript `_t` stands for teacher, and `dim(logits_t^i) = N`.
- Modify the transfer dataset labels so that `y_d = [y, softmax(logits_t/T)]`, where subscript `_d` stands for distilled.
- Define the student model (usually small and fast, for example an MLP). The number of outputs should equal the number of classes of the teacher model, `dim(logits_s) = N`, where subscript `_s` means student.
- Modify the student model by adding one more layer that generates an additional output to match the logits of the teacher model. The output of the student model is now `output_d = [softmax(logits_1), softmax(logits_2/T)]`, where `T` is a free parameter called temperature. Note that `dim(output_d) = 2N`, and `logits_2` corresponds to `logits_t`.
- Define the modified loss function as `L_d = lambda * l(y_true, y_pred) + l(y_soft, y_pred_soft)`, where `l()` is the cross-entropy function, `y_true` and `y_pred` are the first halves of `y_d` and `output_d`, and `y_soft` and `y_pred_soft` are their second halves.
- Train the distilled model on the modified transfer dataset.
- Predictions made by the student model are extracted as the first half of its outputs.
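The target construction, loss and prediction steps above can be sketched with plain NumPy arrays. This is a minimal illustration, not this repository's `train.py`: the function names (`softmax`, `make_distilled_targets`, `student_output`, `distillation_loss`, `predict_classes`) are hypothetical, and the student's two sets of logits are passed in directly instead of coming from a trained Keras model.

```python
import numpy as np


def softmax(logits, T=1.0):
    """Row-wise softmax of logits / T, shifted by the row max for numerical stability."""
    z = np.asarray(logits, dtype=np.float64) / T
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def cross_entropy(p_true, p_pred, eps=1e-12):
    """Mean cross-entropy l(p_true, p_pred) over the batch; eps avoids log(0)."""
    return float(-np.mean(np.sum(p_true * np.log(p_pred + eps), axis=-1)))


def make_distilled_targets(y_onehot, teacher_logits, T):
    """y_d = [y, softmax(logits_t / T)], shape (batch, 2N)."""
    return np.concatenate([y_onehot, softmax(teacher_logits, T)], axis=-1)


def student_output(logits_1, logits_2, T):
    """output_d = [softmax(logits_1), softmax(logits_2 / T)], shape (batch, 2N)."""
    return np.concatenate([softmax(logits_1), softmax(logits_2, T)], axis=-1)


def distillation_loss(y_d, output_d, lam):
    """L_d = lambda * l(y_true, y_pred) + l(y_soft, y_pred_soft)."""
    n = y_d.shape[-1] // 2
    hard = cross_entropy(y_d[:, :n], output_d[:, :n])  # first halves: hard labels
    soft = cross_entropy(y_d[:, n:], output_d[:, n:])  # second halves: soft targets
    return lam * hard + soft


def predict_classes(output_d):
    """Class predictions come from the first half of the student's outputs."""
    n = output_d.shape[-1] // 2
    return np.argmax(output_d[:, :n], axis=-1)


# Usage with random stand-in logits (N = 26 letter classes, T = 10, lambda = 0.5)
rng = np.random.default_rng(0)
N, batch, T, lam = 26, 4, 10.0, 0.5
y = np.eye(N)[rng.integers(0, N, size=batch)]  # one-hot true labels
teacher_logits = rng.normal(size=(batch, N))
y_d = make_distilled_targets(y, teacher_logits, T)
out = student_output(rng.normal(size=(batch, N)), rng.normal(size=(batch, N)), T)
print(y_d.shape, out.shape)  # (4, 52) (4, 52)
print(distillation_loss(y_d, out, lam))
print(predict_classes(out))
```

The sketch follows `L_d` exactly as written above. Hinton et al.'s original formulation additionally multiplies the soft term by `T**2` to keep its gradient scale comparable across temperatures, which is a variant worth trying if the soft term seems to have too little influence.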
Dependencies: Keras, TensorFlow, NumPy
- Train the teacher model, either a CNN:
  `python train.py --file data/matlab/emnist-letters.mat --model cnn`
  or a perceptron:
  `python train.py --file data/matlab/emnist-letters.mat --model mlp`
- Train the student network with knowledge distillation:
  `python train.py --file data/matlab/emnist-letters.mat --model student --teacher bin/10cnn_32_128_1model.h5`
The EMNIST-letters dataset was used for the experiments (26 classes of handwritten letters of the English alphabet).
As the teacher network, a simple CNN with 3,378,970 parameters (2 convolutional layers with 64 and 128 filters respectively, and a fully connected layer of 1024 neurons) was trained for 26 epochs, stopping early on a plateau. Its validation accuracy was 94.4%.
As the student network, a perceptron with one hidden layer of 512 units and 415,258 parameters in total was used (about 8 times smaller than the teacher network). Trained on its own for 50 epochs, it reached a validation accuracy of 91.6%.
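The reported parameter counts can be checked with a little arithmetic. The student count follows directly from the description (784 inputs for a flattened 28x28 EMNIST image, 512 hidden units, 26 outputs). For the teacher, kernel size, padding and pooling are not stated here; the layout below (3x3 "valid" convolutions, each followed by 2x2 max pooling) is an assumption that happens to reproduce 3,378,970 exactly, not necessarily the architecture actually used in this repository.

```python
def dense_params(n_in, n_out):
    """Fully connected layer: weight matrix plus one bias per output unit."""
    return n_in * n_out + n_out


def conv2d_params(k, c_in, c_out):
    """k x k convolution: one kernel per (input, output) channel pair plus biases."""
    return k * k * c_in * c_out + c_out


# Student: 28*28 = 784 inputs -> 512 hidden units -> 26 classes
student = dense_params(28 * 28, 512) + dense_params(512, 26)

# Teacher (ASSUMED layout): 3x3 valid convs, each followed by 2x2 max pooling
# 28x28x1 -> conv 26x26x64 -> pool 13x13x64 -> conv 11x11x128 -> pool 5x5x128
teacher = (
    conv2d_params(3, 1, 64)
    + conv2d_params(3, 64, 128)
    + dense_params(5 * 5 * 128, 1024)  # flattened feature map -> 1024 neurons
    + dense_params(1024, 26)           # 1024 neurons -> 26 classes
)

print(student, teacher, round(teacher / student, 2))  # 415258 3378970 8.14
```

Almost all of the teacher's parameters sit in the dense layer after flattening (3,277,824 of them under this layout), which is why the much shallower perceptron ends up only about 8 times smaller.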
Knowledge distillation was run with different combinations of the temperature and lambda parameters. The best performance was achieved with temp=10, lambda=0.5: the student network trained this way for 50 epochs reached a validation accuracy of 92.2%.
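To illustrate why a high temperature such as `T=10` can be useful, the sketch below applies `softmax(logits / T)` to a single teacher logit vector at several temperatures. The logit values are made up for illustration. As `T` grows, the top probability shrinks and the entropy rises, so the soft targets carry more information about how the teacher ranks the incorrect classes, while the predicted class (the argmax) stays the same.

```python
import math


def softmax_t(logits, T):
    """Softmax of logits / T, shifted by the max logit for numerical stability."""
    m = max(logits)
    exps = [math.exp((z - m) / T) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]


def entropy(p):
    """Shannon entropy in nats; higher means a softer distribution."""
    return -sum(q * math.log(q) for q in p if q > 0)


# Hypothetical teacher logits for one image over five classes
teacher_logits = [9.0, 4.0, 2.0, 0.0, -1.0]

for T in (1, 2, 5, 10):
    p = softmax_t(teacher_logits, T)
    print(f"T={T:>2}  p_max={max(p):.3f}  entropy={entropy(p):.3f}")
```

At `T=1` nearly all the mass is on the top class, so the soft targets add little beyond the hard label; higher temperatures expose the relative similarities between classes that the student can learn from.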
So the accuracy gain over the classically trained perceptron is less than 1% (0.6 percentage points, from 91.6% to 92.2%), but it is still an improvement. Other reported experiments on different tasks show similar results, a 1-2% quality increase, so we may say that the reported results were reproduced on the EMNIST-letters dataset.
The knowledge distillation parameters (temperature and lambda) must be tuned for each specific task. To get a larger accuracy gain, related techniques such as deep mutual learning or FitNets may be tested.
