Deep neural networks are increasingly used in real-world applications thanks to their impressive accuracy. Nevertheless, many of these applications involve human users, which necessitates a high level of transparency and trust in addition to accuracy. A crucial ingredient for enabling trust is to associate each automatic decision with a calibrated uncertainty.
Different techniques have been developed to obtain uncertainty estimation from deep networks including Bayesian modeling using variational approximation (Graves, 2011), expectation propagation (Hernández-Lobato & Adams, 2015), sampling (Gong et al., 2018) and non-Bayesian methods such as bootstrapping (Lakshminarayanan et al., 2017) and classification margin (Geifman & El-Yaniv, 2017).
While many of those models achieve acceptable uncertainty estimates (Lakshminarayanan et al., 2017), they are notoriously slow to train and evaluate, which makes them non-viable for real-world implementation. Methods have been developed to expedite the training process: (Gal & Ghahramani, 2016; Teye et al., 2018) cast standard deep networks as approximate Bayesian inference, (Welling & Teh, 2011) use Langevin dynamics in tandem with stochastic gradient descent, and (Huang et al., 2017) employ the training snapshots to avoid sampling and/or multiple training runs.
All these techniques effectively reduce the computational cost of training, but remain prohibitively expensive at evaluation, essentially due to the multiple inference passes required for uncertainty estimates at test time. This is despite the fact that computation and memory footprint are of high concern for real-world applications, where the evaluation model needs to be deployed in products with limited computational and memory capacity.
The focus of this work is to address this issue. We devise an algorithm to efficiently obtain uncertainty estimates at evaluation time irrespective of the modelling choice. Common deep networks for epistemic uncertainty produce either samples of the posterior $p(\theta \mid \mathcal{D})$ or a parametric approximation of it $q_\phi(\theta)$, where $\theta$ denotes the model parameters (i.e. weights and biases of the network), $\mathcal{D}$ the training data, and $\phi$ the parameters of the approximating distribution. To marginalize over the parameter uncertainty, usually, $M$ samples of the parameters are obtained, e.g. through bootstrapping (Lakshminarayanan et al., 2017) or sampling of $q_\phi(\theta)$ (Gal & Ghahramani, 2016):
$$p(y \mid x, \mathcal{D}) \approx \frac{1}{M} \sum_{i=1}^{M} p(y \mid x, \theta^{(i)}) \qquad (1)$$
This summation is the source of evaluation memory/time complexity. In this work, we aim to train a single deep network with parameters $\omega$, potentially from a different parameter space, that minimizes a divergence from the mean distribution of Eq 1 for all $x$. The teacher-student setup of (Hinton et al., 2015) is suitable for this purpose.
Our goal is to optimize the parameters of a single (student) network to produce a class-posterior similar to our target (teacher) distribution for both in- and out-of-distribution samples. One measure is the KL-divergence:
$$\omega^{*} = \arg\min_{\omega} \; \mathbb{E}_{x}\left[ \mathrm{KL}\left( p(y \mid x, \mathcal{D}) \,\|\, p(y \mid x, \omega) \right) \right] \qquad (2)$$
Note that this objective excludes the hyperparameters of standard distillation (Hinton et al., 2015): the temperature $T$ and the mixing parameter $\alpha$.
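As a minimal sketch of the two steps described above (averaging the teacher's predictions, then matching a single student to that average under a KL divergence), the following NumPy code may help. It assumes the teacher is available as M sets of logits and that the student also outputs logits; the function names `teacher_mean` and `kl_distillation_loss` are illustrative and are not taken from the paper.

```python
import numpy as np

def softmax(logits, axis=-1):
    # Numerically stable softmax along the class axis.
    z = logits - logits.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def teacher_mean(member_logits):
    """Average the class posteriors of M teacher samples.

    member_logits has shape (M, N, K): M networks, N inputs, K classes.
    Returns an (N, K) array, the mean predictive distribution.
    """
    return softmax(member_logits, axis=-1).mean(axis=0)

def kl_distillation_loss(target_probs, student_logits, eps=1e-12):
    """Batch mean of KL(target || student), with no temperature or mixing."""
    log_q = np.log(softmax(student_logits) + eps)
    log_p = np.log(target_probs + eps)
    return float(np.mean(np.sum(target_probs * (log_p - log_q), axis=-1)))

# Toy data standing in for real network outputs (an assumption for illustration).
rng = np.random.default_rng(0)
member_logits = rng.normal(size=(5, 4, 3))   # M=5 teachers, N=4 inputs, K=3 classes
target = teacher_mean(member_logits)
student_logits = rng.normal(size=(4, 3))
loss = kl_distillation_loss(target, student_logits)
```

At evaluation time only the student's single forward pass is needed, which is where the saving over the M-term average comes from.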
2.1 Target distribution as regularization
The student needs to learn a dispersed distribution as its target, in contrast to the teacher’s hard labels (i.e. a Kronecker delta pmf), which we assume is a “more difficult” task that requires higher model capacity. Indeed, we have empirically observed that students with the same architecture as the teacher tend to converge more slowly and are less prone to overfitting. Furthermore, (Hinton et al., 2015) and (Balan et al., 2015) noticed that a lower regularization weight is needed for the student. Based on these observations, we propose the following two modifications to standard distillation.
Higher capacity students. One way to address this phenomenon is to increase the student’s capacity (w.r.t. the teacher) to account for the additional complexity. That is, we assume the student has a higher capacity than a single teacher network. This can be done, for instance, by increasing the depth or width of the student network. (Although this results in a less efficient student, it will still be far more efficient than evaluating multiple teachers. Also, the student, after being fully trained, can be compressed using various existing approaches, e.g. (Zhuang et al., 2018).)
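Sharper target distribution. Alternatively, for each sample, we can decrease the target’s entropy $H(\cdot)$ by sharpening the teacher’s distribution $p(y \mid x, \mathcal{D})$ to a new target distribution $\tilde{p}(y \mid x)$ through interpolation with the delta distribution $\delta_{y^{*}}(y)$ on the true label $y^{*}$:
$$\tilde{p}(y \mid x) = (1 - \alpha)\, p(y \mid x, \mathcal{D}) + \alpha\, \delta_{y^{*}}(y), \qquad \alpha \in [0, 1] \qquad (3)$$
This gives the formulation of knowledge distillation (Hinton et al., 2015) without the temperature $T$. It can explain the improvement that (Hinton et al., 2015) observed by adding the true labels to the distribution for the correctly classified examples.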
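The sharpening step can be illustrated with a short NumPy sketch. It assumes the sharper target is a convex combination of the teacher's distribution and the one-hot true label, with a mixing weight called `alpha` here; the function names and the toy probabilities are illustrative only.

```python
import numpy as np

def sharpen(teacher_probs, labels, alpha):
    """Interpolate teacher probabilities (N, K) with one-hot true labels (N,)."""
    num_classes = teacher_probs.shape[-1]
    one_hot = np.eye(num_classes)[labels]
    return (1.0 - alpha) * teacher_probs + alpha * one_hot

def entropy(probs, eps=1e-12):
    # Shannon entropy per row, in nats.
    return -np.sum(probs * np.log(probs + eps), axis=-1)

# Two correctly classified examples (the teacher's argmax equals the label).
teacher_probs = np.array([[0.6, 0.3, 0.1],
                          [0.2, 0.7, 0.1]])
labels = np.array([0, 1])
sharp = sharpen(teacher_probs, labels, alpha=0.1)
```

For correctly classified examples, moving mass towards the true label can only lower the entropy, so `entropy(sharp)` is below `entropy(teacher_probs)` row by row; for misclassified examples this need not hold.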
2.2 Proper class-posterior distribution
Here we pose the question of whether to follow the teacher even when it makes wrong decisions (i.e. when the class it predicts differs from the true label). From the perspective of predictive uncertainty, we argue that it is only reasonable for a target distribution to be at most as “faulty” as a uniform distribution. However, a uniform distribution loses the dark knowledge in the wrong prediction, so for each misclassified sample we propose an alternative target distribution (Eq 4) that is a weighted average of the true label and the teacher’s prediction. The weight on the true label, at its minimum, makes a new target distribution with maximum mass on the correct class while still retaining maximal mass on the non-correct classes, and thereby the dark knowledge. This complements the previous argument by explaining the observed improvement in (Hinton et al., 2015) of mixing in the true labels for the wrongly classified examples.
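The idea of correcting a misclassified target with as little change as possible can be sketched as follows. This is an illustration under an assumption, not the paper's exact formula: the target is taken to be the weighted average (1 - beta) * p + beta * one_hot(label), where `beta` is a hypothetical name for the weight on the true label. Under that assumption, the correct class becomes the argmax once beta exceeds (p_max - p_true) / (1 + p_max - p_true), where p_max is the teacher's largest probability and p_true its probability for the true class. A small `margin` is added because the two classes are exactly tied at the bound itself.

```python
import numpy as np

def min_correction_weight(probs, label):
    """Smallest weight at which the true class ties with the teacher's top class."""
    p_max = probs.max()
    p_true = probs[label]
    return (p_max - p_true) / (1.0 + p_max - p_true)

def corrected_target(probs, label, margin=1e-3):
    """Mix in just enough true-label mass to make the true class the argmax."""
    beta = min(1.0, min_correction_weight(probs, label) + margin)
    one_hot = np.zeros_like(probs)
    one_hot[label] = 1.0
    return (1.0 - beta) * probs + beta * one_hot, beta

# The teacher predicts class 0, but the true label is class 1.
probs = np.array([0.5, 0.3, 0.2])
target, beta = corrected_target(probs, label=1)
```

The wrong classes keep their relative ordering, so the information about which mistakes the teacher considers plausible is preserved in the corrected target.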
2.3 Robustness to out-of-distribution (OOD) samples
Ideally we want the model to have high uncertainty when it is presented with samples that are out of the training distribution. (Note that marginalizing the joint distribution over the labels does not give the distribution of all inputs, since the labels are limited to the current task, excluding a “negative” class.) We posit that this set includes two important subsets.
Natural set. OOD samples can come from the support of the natural data distribution while lying outside the training distribution. For instance, an image of a car is a natural OOD sample for a cat vs dog classification task. (Li & Hoiem, 2018) use a large unlabeled student training dataset for this purpose.
Unnatural set. Unnatural OOD samples come from the rest of the input space and are important for defending against adversarial attacks. (Lakshminarayanan et al., 2017) use adversarial training to become robust to this set of OOD samples.
The contributions of this work can be summarized as:
- we recognize the regularization effect of dispersed target distributions and accordingly suggest techniques to improve the distillation process;
- we provide justification for the particular target distribution of standard distillation in (Hinton et al., 2015);
- we propose a simple and yet effective technique for distillation of out-of-distribution predictive uncertainty;
- we conduct a comprehensive set of experiments and evaluations to study the aforementioned aspects.
4 Related Work
Here we briefly present the recent works that address the computational efficiency of evaluating predictive uncertainty and delineate our work with respect to them.
(Hinton et al., 2015) coined knowledge distillation to summarize an ensemble. They focused on the accuracy of the student and not the uncertainty estimates. Our work sheds light on their design choices and is more elaborately designed for the purpose of uncertainty distillation.
(Li & Hoiem, 2018; Gurau et al., 2018; Balan et al., 2015) are the closest to our work. They use (Hinton et al., 2015) to distill ensemble networks, Monte Carlo samples of dropout networks (Gal & Ghahramani, 2016), and approximate posterior samples of SGLD (Welling & Teh, 2011), respectively. They use distillation in its standard form; thus our observations and proposed modifications are complementary to those works. (Li & Hoiem, 2018) address the problem of OOD prediction with an unlabeled dataset, whereas we propose different and potentially complementary procedures.
(Anil et al., 2018) designs a technique to distill ensembles in an online fashion, focused on a distributed training scenario. Their goal is to match and improve the accuracy as opposed to predictive uncertainty.
Finally, (Wu et al., 2019) proposes a method to deterministically propagate uncertainty of model parameters and activations to the output layer. While this elegant approach circumvents the computational burden of sampling the parameter posterior, it achieves inferior results compared to the ensemble model of (Lakshminarayanan et al., 2017).
We use the state-of-the-art ensemble technique proposed in (Lakshminarayanan et al., 2017) as the teacher. We measure calibration of in-distribution predictive uncertainty through negative log-likelihood (NLL) and Brier score. We evaluate the robustness to OOD samples via entropy histograms. The experimental results of the student are reported as the mean and standard deviation of 5 runs unless stated otherwise. We use CIFAR10 as the main dataset, and report some results on MNIST and CIFAR100. See appendix Sec A for details.
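For concreteness, the three evaluation quantities mentioned above can be computed along the following lines. This is a sketch under stated assumptions: the Brier score is averaged over both samples and classes (other conventions sum over classes), entropy is measured in nats, and the function names and toy inputs are illustrative.

```python
import numpy as np

def nll(probs, labels, eps=1e-12):
    """Mean negative log-likelihood of the true labels."""
    true_probs = probs[np.arange(len(labels)), labels]
    return float(-np.mean(np.log(true_probs + eps)))

def brier(probs, labels):
    """Mean squared error between probabilities and one-hot labels."""
    one_hot = np.eye(probs.shape[1])[labels]
    return float(np.mean((probs - one_hot) ** 2))

def predictive_entropy(probs, eps=1e-12):
    """Entropy of each predictive distribution; high values mean high uncertainty."""
    return -np.sum(probs * np.log(probs + eps), axis=1)

# Toy predictions: one confident and correct, one maximally uncertain.
probs = np.array([[0.9, 0.05, 0.05],
                  [1 / 3, 1 / 3, 1 / 3]])
labels = np.array([0, 2])

# Entropy histogram over the range [0, log K], as one would plot for OOD inputs.
counts, edges = np.histogram(predictive_entropy(probs), bins=5,
                             range=(0.0, np.log(probs.shape[1])))
```

On OOD inputs, a well-behaved model should place most of its histogram mass near the upper end of this range.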
5.1 Vanilla distillation
First, we show that vanilla distillation produces decent predictive uncertainty for the in-distribution samples, while it is significantly worse on the OOD samples; see Fig 1. Results for other network depths are in Fig 4 in the appendix. In Fig 1(b), we can see that the student is more over-confident in its predictions than the teacher. It is still interesting that this simple baseline without hyperparameters performs on par with the ensemble teacher on in-distribution samples. We now further improve upon these results using our proposed techniques.
5.2 Target distribution as regularization
We have observed that the teachers quickly overfit to NLL after the first drop in learning rate, while accuracy still improves (Guo et al., 2017); this behavior, however, is not observed for the students. Furthermore, the convergence time of the students is far longer than that of the teachers (2500 vs 85 epochs). These observations hint that the student learning process is more regularized than its teacher’s counterpart. In the following, we take measures based on this observation.
Higher capacity students. Fig 2(a) serves as a baseline for how the teacher performs for a varying number of networks and network depths. In Fig 2(b) we consistently observe better NLL as the student’s depth is increased. The results for other teacher depths are shown in Fig 5 in the appendix. More interestingly, increasing the student’s capacity is more effective than increasing the teacher’s; see Fig 2(c). This can be due to the same regularization effect. This was also observed for students of depth 5 and 18; see appendix Fig 6.
Finally, all the figures crucially indicate that the improvement in student performance with increased depth is not merely because the original task demanded larger networks. This can be seen by comparing the improvements in the ensemble’s performance to those in the student’s as the depth increases.
Sharpening the target distribution. Another way to address the regularization of dispersed target distributions is to lower the entropy as in (Hinton et al., 2015). Interestingly, we empirically observed that the effect of sharpening diminishes as the capacity of the student is increased. Appendix Tab 1 shows that a student of depth 18 trained using a teacher of depth 5 does not significantly benefit from an increase in the sharpening parameter. For results of sharpened targets on MNIST and CIFAR100, see Appendix Fig 7 and Tab 3, respectively.
Figure 3: Entropy histograms of the predictions of models trained on CIFAR datasets and evaluated on the OOD SVHN dataset. Student corresponds to training with sharper targets through interpolation with the true delta distribution; Student-Aug uses transformations to traverse the natural manifold. The teacher uses 15 networks of depth 5 (a) or depth 18 (b), and the students use the same depth as their teacher.
5.3 Proper class-posterior distribution
As we discussed, a way to improve the distillation process is to correct for the wrong predictions of the teacher ensemble. We proposed another interpretation of the weighted average between the true label and the teacher predictions. We vary the weight in Eq 4 over a range of values. We observed no significant difference when using the approach on CIFAR10; however, it gives small improvements on CIFAR100 (Tab 2 in the appendix). We hypothesize that the number of misclassified examples on CIFAR10 was too low for the change to show a significant improvement, and thus that moving to a more challenging task such as ImageNet would make the benefits more apparent.
5.4 Robustness to out-of-distribution samples
In Fig 2, we saw that the uncertainty estimates for in-distribution samples are on par with the ensemble, especially with increased student capacity, while Fig 1 shows that the robustness to OOD samples is far from ideal. We propose a simple approach for the natural OOD samples: we perturb the samples of the natural manifold by applying image transformations that do not violate the manifold, including cropping and mirroring. In the standard case, the label for an augmented image is the teacher’s prediction for the corresponding unperturbed image. We instead propose to use the teacher’s prediction on the augmented image as the label, providing more information about the teacher during training. Fig 3 shows the improvement this simple technique brings, which highlights the promise of pursuing this direction further. Interestingly, we have found that more aggressive transformations, which are usually harmful for standard training, help the teacher-student learning.
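The difference between the two labelling schemes can be made concrete with a small NumPy sketch. Everything here is illustrative: `augment` implements zero-padding, random cropping and mirroring for images of shape (H, W, C); `toy_teacher` is a stand-in for a real teacher ensemble and exists only so that the example runs; `relabel_augmented` is a hypothetical flag that switches between the baseline and the proposed scheme.

```python
import numpy as np

def augment(image, rng, pad=4):
    """Zero-pad, randomly crop back to the original size, and randomly mirror."""
    h, w, _ = image.shape
    padded = np.pad(image, ((pad, pad), (pad, pad), (0, 0)), mode="constant")
    top = rng.integers(0, 2 * pad + 1)
    left = rng.integers(0, 2 * pad + 1)
    crop = padded[top:top + h, left:left + w]
    if rng.random() < 0.5:
        crop = crop[:, ::-1]          # horizontal flip along the width axis
    return crop

def toy_teacher(batch):
    """Stand-in teacher: softmax over per-channel mean intensity (not a real model)."""
    logits = batch.mean(axis=(1, 2))
    e = np.exp(logits - logits.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def make_batch(images, teacher_predict, rng, relabel_augmented=True):
    """Return augmented inputs and the soft targets the student is trained on."""
    augmented = np.stack([augment(img, rng) for img in images])
    # Baseline: targets come from the original images.
    # Proposed: targets come from the augmented images themselves.
    source = augmented if relabel_augmented else images
    return augmented, teacher_predict(source)

rng = np.random.default_rng(0)
images = rng.random((2, 32, 32, 3))
inputs, targets = make_batch(images, toy_teacher, rng, relabel_augmented=True)
```

With `relabel_augmented=True` the teacher has to be evaluated on every augmented batch rather than once per training image, which trades extra teacher computation during training for targets that reflect how the teacher responds to the perturbation.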
6 Final remarks
In this work we closely analyzed the distillation process of (Hinton et al., 2015) from an uncertainty estimation perspective. We shed light on their design choices, which led us to suggest additional improvements. In the experimental part of this work we empirically studied the suggested aspects, which led to many interesting observations.
Throughout all the experiments we tried to maintain a high experimental standard by using cross-validated hyperparameter optimisation for the baseline students. We also reported values as the result of 5 different runs. Important future directions include a theoretical analysis of our observations regarding the effects of distillation and its design choices, and applying the techniques to larger datasets such as ImageNet to further highlight their effectiveness.
This work was supported by the Wallenberg AI, Autonomous Systems and Software Program (WASP-AI/MLX) funded by the Knut and Alice Wallenberg Foundation.
- Anil et al. (2018) Anil, R., Pereyra, G., Passos, A., Ormándi, R., Dahl, G. E., and Hinton, G. E. Large scale distributed neural network training through online distillation. In 6th International Conference on Learning Representations, ICLR 2018, Vancouver, BC, Canada, April 30 - May 3, 2018, Conference Track Proceedings, 2018. URL https://openreview.net/forum?id=rkr1UDeC-.
- Balan et al. (2015) Balan, A. K., Rathod, V., Murphy, K. P., and Welling, M. Bayesian dark knowledge. In Advances in Neural Information Processing Systems, pp. 3438–3446, 2015.
- Gal & Ghahramani (2016) Gal, Y. and Ghahramani, Z. Dropout as a bayesian approximation: Representing model uncertainty in deep learning. In International Conference on Machine Learning, pp. 1050–1059, 2016.
- Geifman & El-Yaniv (2017) Geifman, Y. and El-Yaniv, R. Selective classification for deep neural networks. In Advances in neural information processing systems, pp. 4878–4887, 2017.
- Gong et al. (2018) Gong, W., Li, Y., and Hernández-Lobato, J. M. Meta-learning for stochastic gradient mcmc. arXiv preprint arXiv:1806.04522, 2018.
- Graves (2011) Graves, A. Practical variational inference for neural networks. In Advances in neural information processing systems, pp. 2348–2356, 2011.
- Guo et al. (2017) Guo, C., Pleiss, G., Sun, Y., and Weinberger, K. Q. On calibration of modern neural networks. In Proceedings of the 34th International Conference on Machine Learning-Volume 70, pp. 1321–1330. JMLR. org, 2017.
- Gurau et al. (2018) Gurau, C., Bewley, A., and Posner, I. Dropout distillation for efficiently estimating model confidence. arXiv preprint arXiv:1809.10562, 2018.
- He et al. (2016) He, K., Zhang, X., Ren, S., and Sun, J. Identity mappings in deep residual networks. CoRR, abs/1603.05027, 2016. URL http://arxiv.org/abs/1603.05027.
- Hernández-Lobato & Adams (2015) Hernández-Lobato, J. M. and Adams, R. Probabilistic backpropagation for scalable learning of bayesian neural networks. In International Conference on Machine Learning, pp. 1861–1869, 2015.
- Hinton et al. (2015) Hinton, G., Vinyals, O., and Dean, J. Distilling the knowledge in a neural network. In NIPS Deep Learning and Representation Learning Workshop, 2015. URL http://arxiv.org/abs/1503.02531.
- Huang et al. (2017) Huang, G., Li, Y., Pleiss, G., Liu, Z., Hopcroft, J. E., and Weinberger, K. Q. Snapshot ensembles: Train 1, get M for free. In 5th International Conference on Learning Representations, ICLR 2017, Toulon, France, April 24-26, 2017, Conference Track Proceedings, 2017. URL https://openreview.net/forum?id=BJYwwY9ll.
- Lakshminarayanan et al. (2017) Lakshminarayanan, B., Pritzel, A., and Blundell, C. Simple and scalable predictive uncertainty estimation using deep ensembles. In Advances in Neural Information Processing Systems, pp. 6402–6413, 2017.
- Langley (2000) Langley, P. Crafting papers on machine learning. In Langley, P. (ed.), Proceedings of the 17th International Conference on Machine Learning (ICML 2000), pp. 1207–1216, Stanford, CA, 2000. Morgan Kaufmann.
- Li & Hoiem (2018) Li, Z. and Hoiem, D. Reducing overconfident errors outside the known distribution. 2018.
- Teye et al. (2018) Teye, M., Azizpour, H., and Smith, K. Bayesian uncertainty estimation for batch normalized deep networks. In International Conference on Machine Learning, pp. 4914–4923, 2018.
- Welling & Teh (2011) Welling, M. and Teh, Y. W. Bayesian learning via stochastic gradient langevin dynamics. In Proceedings of the 28th international conference on machine learning (ICML-11), pp. 681–688, 2011.
- Wu et al. (2019) Wu, A., Nowozin, S., Meeds, E., Turner, R. E., Hernández-Lobato, J. M., and Gaunt, A. L. Fixing variational bayes: Deterministic variational inference for bayesian neural networks. In 7th International Conference on Learning Representations, ICLR 2019, 2019.
- Zhuang et al. (2018) Zhuang, Z., Tan, M., Zhuang, B., Liu, J., Guo, Y., Wu, Q., Huang, J., and Zhu, J. Discrimination-aware channel pruning for deep neural networks. In Advances in Neural Information Processing Systems, pp. 875–886, 2018.
Appendix A Experimental Details
We evaluate our method on MNIST, CIFAR-10 and CIFAR-100, training a dense neural network on MNIST and ResNet variants on CIFAR. To simplify things, we only consider ensembles of networks where each network has the same capacity.
All models are trained on a train-validation split, where hyperparameters are optimized on the validation set based on NLL. We do not retrain on the entire training set before evaluating on the test set.
MNIST. For all MNIST experiments, we use the same dense neural network architecture as proposed in (Lakshminarayanan et al., 2017), that is, three hidden layers with 200 units per layer, ReLU activations and batch normalization. Both the student and the teacher are trained using the Adam optimizer. Each network of the teacher is trained for 10 epochs with a fixed learning rate of 0.001 and a batch size of 1000. The students are trained for 600 epochs with a fixed learning rate of 0.002 and a batch size of 64.
CIFAR. For all CIFAR-10 and CIFAR-100 experiments, we use the ResNet version proposed by (He et al., 2016). We use ResNet models of varying depth: 5 (ResNet32), 9 (ResNet56), 18 (ResNet110), etc. We train these models using the Momentum optimizer with a batch size of 128 and a learning rate of 0.1. The teacher networks overfit quickly to NLL after the first drop in learning rate. We drop the learning rate at epoch 82 and stop early at epoch 85, before the validation NLL degrades. The students can be trained for longer without overfitting to NLL. We use 2500 epochs, and the learning rate is reduced by a factor of 10 at epochs 2000, 2100 and 2300. For data augmentation we use padding, random cropping and horizontal flips. In the baseline, the label for each augmented image is the prediction of the teacher on the corresponding original image. For the improved augmentation technique, the label for each augmented image is the prediction of the teacher for that particular augmented image.
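The student's learning-rate schedule described above can be written as a small piecewise-constant function. This is a sketch, assuming that each reduction takes effect from the listed epoch onwards (the text does not say whether the boundary epoch itself is included); the function and argument names are illustrative.

```python
def student_learning_rate(epoch, base_lr=0.1, drops=(2000, 2100, 2300), factor=0.1):
    """Learning rate for a given epoch, reduced by `factor` at each boundary."""
    lr = base_lr
    for boundary in drops:
        if epoch >= boundary:
            lr *= factor
    return lr

# Learning rates at a few points of the 2500-epoch student run.
schedule = {epoch: student_learning_rate(epoch) for epoch in (0, 1999, 2000, 2150, 2499)}
```

The same function could describe the teacher's much shorter run by passing `drops=(82,)` and stopping at epoch 85.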
Appendix B Additional Figures and Tables
“Ref” refers to the baseline case in which the mixing parameter is not used.
| Student | 18 | 0.0 | 22.92 ± 0.15 | 0.8187 ± 0.0060 | 0.00320 ± 0.00002 |
| Student | 18 | 0.1 | 23.24 ± 0.29 | 0.8257 ± 0.0049 | 0.00322 ± 0.00002 |
| Student | 27 | 0.0 | 22.16 ± 0.15 | 0.7856 ± 0.0070 | 0.00310 ± 0.00003 |
| Student | 27 | 0.1 | 22.26 ± 0.15 | 0.7923 ± 0.0044 | 0.00311 ± 0.00002 |