Deep neural networks have achieved impressive performance; however, they tend to make over-confident predictions and poorly quantify uncertainty (Lakshminarayanan et al., 2017). It has been demonstrated that ensembles of models improve predictive performance and offer higher-quality uncertainty quantification (Dietterich, 2000; Lakshminarayanan et al., 2017; Ovadia et al., 2019). A fundamental limitation of ensembles is the cost of computation and memory at evaluation time. A popular solution is to distill an ensemble of models into a single compact network by attempting to match the average predictions of the original ensemble. This idea goes back to the foundational work of Hinton et al. (2015), itself inspired by earlier ideas developed by Buciluǎ et al. (2006). While this process has led to simple and well-performing algorithms, it fails to take into account the intrinsic diversity of the predictions of the ensemble, as represented by the individual predictions of each of its members. This diversity is all the more important in tasks that hinge on the uncertainty output of the ensemble, e.g., in out-of-distribution scenarios (Lakshminarayanan et al., 2017; Ovadia et al., 2019). Similarly, by losing the diversity of the ensemble, this simple form of distillation makes it impossible to estimate measures of uncertainty such as model uncertainty (Depeweg et al., 2017; Malinin et al., 2019). Proper uncertainty quantification is especially crucial for safety-related tasks and applications. To overcome this limitation, Malinin et al. (2019) proposed to model the entire distribution of an ensemble using a Dirichlet distribution parametrized by a neural network, referred to as a prior network (Malinin and Gales, 2018). However, this imposes a strong parametric assumption on the distillation process.
Inspired by multi-headed architectures already widely applied in various applications (Szegedy et al., 2015; Sercu et al., 2016; Osband et al., 2016; Song and Chai, 2018), we propose a multi-headed model to distill ensembles. Our multi-headed approach, which we name Hydra, can be seen as an interpolation between the full ensemble of models and the knowledge distillation proposed by Hinton et al. (2015). Our distillation model is composed of (1) a single body and (2) as many heads as there are members in the original ensemble. Each head is assigned to an ensemble member and tries to mimic the individual predictions of that ensemble member, as illustrated in Figure 1. The heads share the same body network, whose role is to provide a common feature representation. The design of the body and the heads makes it possible to trade off computational and memory efficiency against the fidelity with which the diversity of the ensemble is retained. An illustration of common knowledge distillation and ensemble distillation as well as Hydra is shown in Figure 1, and a detailed description of the methodology is found in Section 2. While the choice of body and head architectures may appear to introduce complex new hyperparameters, we will see in the experiments that we get good results by simply taking all but the last layer of the original ensemble members for the body and duplicating the last layer for the heads.
Summary of contributions. Firstly, we present a multi-headed approach for ensemble knowledge distillation. The shared component keeps the model computationally and memory efficient, while diversity is captured through the heads matching the individual ensemble members. Secondly, we show through experimental evaluation that Hydra outperforms existing distillation methods for both classification and regression tasks with respect to predictive test performance. Lastly, we investigate Hydra's behaviour on in-domain and out-of-distribution data and demonstrate that Hydra comes closest to the ensemble behaviour in comparison to existing distillation methods.
Novelty and significance. Ensembles of models have successfully improved predictive performance and yielded robust measures of uncertainty. However, existing distillation methods either do not retain the diversity of the ensemble (beyond its average predictive behavior) or make strong parametric assumptions that are not applicable in regression settings. To the best of our knowledge, our approach is the first to employ a multi-headed architecture in the context of ensemble distillation. It is simple to implement, does not make strong parametric assumptions, requires few modifications to the distilled ensemble model and works well in practice, making it attractive for a wide range of ensemble models and tasks.
2 Hydra: A Multi-Headed Approach
With a focus on offline distillation, our goal is to train a student network to match the predictive distribution of the teacher model, which is an ensemble of (deep) neural networks. Formally, given a dataset $\mathcal{D} = \{(x_n, y_n)\}_{n=1}^{N}$, we consider an ensemble of $M$ models with parameters $\{\theta_m\}_{m=1}^{M}$ and prediction outputs $\{p(y \mid x, \theta_m)\}_{m=1}^{M}$. For simplicity, a single data instance pair will be referred to as $(x, y)$.
In (Hinton et al., 2015; Balan et al., 2015), distilling an ensemble of models into a single neural network is achieved by minimizing the Kullback-Leibler (KL) divergence between the student's predictive distribution and the expected predictive distribution of the ensemble:

$$\min_{\theta}\; \mathbb{E}_{(x,y) \sim \mathcal{D}}\left[\mathrm{KL}\left(\frac{1}{M} \sum_{m=1}^{M} p(y \mid x, \theta_m)\;\Big\Vert\; p(y \mid x, \theta)\right)\right]. \tag{1}$$
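This averaging-then-matching step can be sketched numerically. The following minimal Python sketch (the function names `kl` and `kd_loss` are ours, for illustration only) averages the ensemble members' categorical predictions and penalizes the student by the KL divergence to that average:

```python
import math

def kl(p, q):
    # KL(p || q) for two categorical distributions given as probability lists.
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def kd_loss(ensemble_probs, student_probs):
    # Knowledge-distillation objective: average the M members' predictive
    # distributions, then measure KL(ensemble mean || student).
    M = len(ensemble_probs)
    K = len(student_probs)
    mean = [sum(p[k] for p in ensemble_probs) / M for k in range(K)]
    return kl(mean, student_probs)
```

A student that reproduces the ensemble mean exactly attains zero loss, regardless of how much the individual members disagree, which is precisely the diversity loss discussed above.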
Hydra builds upon the approach of knowledge distillation and extends it to a multi-headed student model. Hydra is defined as a (deep) neural network with a single body and $M$ heads. For distillation, Hydra has as many heads as there are ensemble members. The distillation model is parametrized by $\theta = \{\theta_{\text{body}}, \theta_{\text{head}_1}, \ldots, \theta_{\text{head}_M}\}$, i.e., the body $\theta_{\text{body}}$ is shared among all heads $\theta_{\text{head}_m}$.
In terms of the number of parameters, we assume the heads to be much lighter than the shared part, so that the distillation is still meaningful. In practice, we use all but the final layer(s) of the original ensemble member architecture as the body and the original final layer(s) as each head. The objective is to minimize the average KL divergence between each head $m$ and the corresponding ensemble member $m$. We differentiate between two tasks, classification and regression.
For classification tasks, the ensemble of models has access to $\mathcal{D}$ during training, with each $x$ belonging to one of $K$ classes, i.e., $y \in \{1, \ldots, K\}$. Given logits $z_k(x)$, the categorical distribution over class labels for a sample $x$ and a class $k$ is computed as:

$$p_T(y = k \mid x, \theta) = \frac{\exp(z_k(x)/T)}{\sum_{j=1}^{K} \exp(z_j(x)/T)}, \tag{2}$$
where $T$ is a temperature re-scaling the logits. As discussed in (Hinton et al., 2015; Malinin et al., 2019), the distribution of the teacher network is often "sharp", which can limit the common support between the output distribution of the model and the target empirical distribution. Minimizing the KL divergence between distributions with limited non-zero common support is known to be particularly difficult. To alleviate this issue, we follow the common practice (Hinton et al., 2015; Song and Chai, 2018; Lan et al., 2018) of using the temperature to "heat up" both distributions and increase common support during training. At evaluation, $T$ is set to $1$. The soft probability distributions at a temperature of $T$ are used to match the teacher ensemble of models by minimizing the average KL divergence between each head $m$ and ensemble member $m$:

$$\min_{\theta}\; \frac{1}{M} \sum_{m=1}^{M} \mathrm{KL}\big(p_T(y \mid x, \theta_m)\;\Vert\; p_T(y \mid x, \theta_{\text{body}}, \theta_{\text{head}_m})\big). \tag{3}$$
Compared to the objective of knowledge distillation (1), we can observe that the average over the ensemble members is pulled out of the KL divergence. Ignoring the constant entropy terms, this objective reduces to a standard cross-entropy loss:

$$\min_{\theta}\; -\frac{1}{M} \sum_{m=1}^{M} \sum_{k=1}^{K} p_T(y = k \mid x, \theta_m) \log p_T(y = k \mid x, \theta_{\text{body}}, \theta_{\text{head}_m}).$$
We scale our objective by $T^2$, as the gradient magnitudes produced by the soft targets are scaled by $1/T^2$. By multiplying the loss term by a factor of $T^2$, we ensure that the relative contributions of additional regularization losses remain roughly unchanged (Song and Chai, 2018; Lan et al., 2018).
We focus on heteroscedastic regression tasks where each ensemble member $m$ outputs a mean $\mu_m(x)$ and a variance $\sigma_m^2(x)$ given an input $x$. (In our concrete implementation, our neural network outputs the mean and the log standard deviation, which we thereafter exponentiate.) The output is modeled as $\mathcal{N}(y \mid \mu_m(x), \sigma_m^2(x))$ for a given head, and the ensemble of models is trained by minimizing the negative log-likelihood. Traditional knowledge distillation matches a single Gaussian ("student") outputting $\mu(x)$ and $\sigma^2(x)$ to a mixture of Gaussians (a "teacher" ensemble):

$$\min_{\theta}\; \mathrm{KL}\left(\frac{1}{M} \sum_{m=1}^{M} \mathcal{N}\big(y \mid \mu_m(x), \sigma_m^2(x)\big)\;\Big\Vert\; \mathcal{N}\big(y \mid \mu(x), \sigma^2(x)\big)\right).$$
With Hydra, each head $m$ outputs a mean $\mu_{\text{head}_m}(x)$ and a variance $\sigma_{\text{head}_m}^2(x)$, and we optimize the KL divergence between each head output and the corresponding ensemble member output:

$$\min_{\theta}\; \frac{1}{M} \sum_{m=1}^{M} \mathrm{KL}\big(\mathcal{N}(y \mid \mu_m(x), \sigma_m^2(x))\;\Vert\; \mathcal{N}(y \mid \mu_{\text{head}_m}(x), \sigma_{\text{head}_m}^2(x))\big)$$
$$=\; \min_{\theta}\; \frac{1}{M} \sum_{m=1}^{M} \left[\log \frac{\sigma_{\text{head}_m}(x)}{\sigma_m(x)} + \frac{\sigma_m^2(x) + \big(\mu_m(x) - \mu_{\text{head}_m}(x)\big)^2}{2\,\sigma_{\text{head}_m}^2(x)} - \frac{1}{2}\right],$$

where the final line uses the fact that each KL term between two Gaussians has an analytical solution.
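The analytical Gaussian KL term and the resulting regression objective can be written compactly (a self-contained sketch; function names are ours, and distributions are passed as plain `(mean, variance)` pairs):

```python
import math

def kl_gauss(mu1, var1, mu2, var2):
    # Closed-form KL( N(mu1, var1) || N(mu2, var2) ):
    # log(s2/s1) + (var1 + (mu1 - mu2)^2) / (2 var2) - 1/2.
    return 0.5 * (math.log(var2 / var1) + (var1 + (mu1 - mu2) ** 2) / var2 - 1.0)

def hydra_regression_loss(teacher, heads):
    # teacher, heads: lists of (mean, variance) pairs, one per
    # ensemble member and its assigned head.
    return sum(kl_gauss(mt, vt, mh, vh)
               for (mt, vt), (mh, vh) in zip(teacher, heads)) / len(teacher)
```

The loss is zero exactly when every head reproduces its member's mean and variance, so the individual predictive distributions of the ensemble are preserved rather than averaged away.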
Training with multi-head growth.
Hydra is trained in two phases. In the first phase, Hydra mimics knowledge distillation in that it is trained until convergence with a single head, the "Hinton head", to match the average predictions of the ensemble. Hydra is then extended to $M$ heads, all of which are initialized with the parameter values of the "Hinton head". The resulting heads are finally trained further to match the individual predictions of the ensemble members (according to objective (3)). In practice, we sometimes experienced difficulties getting Hydra to converge in the absence of this initialization scheme, and in the cases where different initializations worked, this two-phase training scheme typically led to overall quicker convergence.
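The head-growth step of the second phase amounts to replicating the converged single head once per ensemble member. A schematic sketch (names are illustrative; real heads would be network layers, represented here as plain parameter dicts):

```python
import copy

def grow_heads(shared_body, hinton_head, num_members):
    # Phase 2: clone the converged "Hinton head" once per ensemble member.
    # Deep copies ensure each head can then be fine-tuned independently
    # against its assigned ensemble member, while the body stays shared.
    return {"body": shared_body,
            "heads": [copy.deepcopy(hinton_head) for _ in range(num_members)]}
```

Because every head starts from the same average-matching solution, subsequent per-member fine-tuning only has to learn each member's deviation from the ensemble mean.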
3 Experiments

In this section, we demonstrate that Hydra not only best matches the behavior of the teacher ensemble in terms of uncertainty quantification but also improves predictive performance compared to existing distillation approaches, over both classification and regression tasks.
Datasets. For visualizing and explaining model uncertainty in Subsection 3.1, we used a spiral toy dataset. For classification, we used two datasets: MNIST and CIFAR-10. For evaluating on MNIST, we use its test set as well as increasingly shifted data (increasingly rotated or horizontally translated images) and Fashion-MNIST. For CIFAR-10, we report performance on the test set, the cyclically translated test set, and 80 different corrupted test sets, as well as SVHN. The pre-processing for MNIST and CIFAR-10, as well as the generation schemes of the corrupted images, are taken from Ovadia et al. (2019). For regression, we conducted experiments on the standard regression datasets from the UCI repository (Asuncion and Newman, 2007), following the protocol of Bui et al. (2016).
For the toy dataset, we trained an ensemble of 10 models, each of which is a multi-layer perceptron (MLP) with two hidden layers of 100 units each. For MNIST and CIFAR-10, we trained ensembles of 50 models each, using an MLP architecture for MNIST and the ResNet-20 V1 architecture for CIFAR-10. For additional details on the MNIST and CIFAR-10 models, we refer to (Ovadia et al., 2019) and their open-source code (https://github.com/google-research/google-research/tree/master/uq_benchmark_2019). For all regression tasks, the same model is optimized: an MLP with a single hidden layer of 50 units with softplus activation for each dataset, except for the larger protein structure dataset (prot), where 100 units were used (following Bui et al. (2016)).
Distillation setup. We compare our work with two core distillation approaches, knowledge distillation (Hinton et al., 2015) and prior networks (Malinin et al., 2019; Malinin and Gales, 2018). All baseline models have the same architecture as the ensemble members. For MNIST, Hydra uses the original ensemble member architecture as the body and adds an MLP with two hidden layers of 100 units each as each head. For CIFAR-10, the original ResNet-20 V1 model without the last residual block was used as the body. The final distillation model reported here has one residual block per head.
We report all evaluation metrics on the test set based on the best validation loss from training. For both classification and regression, we evaluate the negative log-likelihood (NLL), which depends on the predictive uncertainty, as well as model uncertainty (MU). NLL is a proper scoring rule and a popular metric for evaluating predictive uncertainty (Ovadia et al., 2019). Model uncertainty, introduced by Depeweg et al. (2017); Malinin et al. (2019), is a measure of the spread or disagreement of an ensemble based on mutual information. A detailed description of MU and an exemplary visualization can be found in Subsection 3.1. For classification, we additionally measure classification accuracy and the Brier score. NLL has the disadvantage of over-emphasizing tail probabilities (Quinonero-Candela et al., 2005). In contrast, the Brier score, a proper scoring rule that takes into account both calibration and accuracy, is not as strongly skewed as NLL (Gneiting and Raftery, 2007). For a given input-output pair $(x, y)$, the Brier score is defined as

$$\mathrm{BS}(x, y) = \sum_{k=1}^{K} p(y = k \mid x)^2 - 2\, p(y \mid x),$$

which equals the classical squared-error form $\sum_{k} \big(p(y = k \mid x) - \mathbb{1}[y = k]\big)^2$ up to an additive constant of $-1$, so that a perfect prediction attains the minimum value of $-1$.
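The Brier score is straightforward to compute from the predicted class probabilities. A minimal sketch (the function name is ours; we assume the constant-shifted convention under which a perfect prediction scores $-1$, consistent with the negative values reported in the tables):

```python
def brier_score(probs, label):
    # Shifted multi-class Brier score: sum_k p_k^2 - 2 * p_label.
    # Equals the classical sum_k (p_k - 1[y = k])^2 minus 1, so lower is
    # better and a perfect one-hot prediction scores exactly -1.
    return sum(p * p for p in probs) - 2.0 * probs[label]
```

The constant shift does not affect model comparisons, since it is the same for every prediction.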
3.1 Uncertainty Quantification
We assess Hydra's ability to distill uncertainty metrics from an ensemble on classification tasks with MNIST and CIFAR-10. One way to quantify uncertainty is through model uncertainty (Depeweg et al., 2017; Malinin et al., 2019), which measures the spread or disagreement of an ensemble. Model uncertainty estimates the mutual information between the categorical label $y$ and the model parameters $\theta$. It can be expressed as the difference of the total uncertainty and the expected data uncertainty, where total uncertainty is the entropy of the expected predictive distribution and expected data uncertainty is the expected entropy of the individual predictive distributions:

$$\mathrm{MU} = \underbrace{\mathcal{H}\big[\mathbb{E}_{\theta}\, p(y \mid x, \theta)\big]}_{\text{total uncertainty}} \;-\; \underbrace{\mathbb{E}_{\theta}\,\mathcal{H}\big[p(y \mid x, \theta)\big]}_{\text{expected data uncertainty}}.$$
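This decomposition can be computed directly from the per-member predictive distributions, whether they come from the original ensemble or from Hydra's heads. A minimal Python sketch (function names ours):

```python
import math

def entropy(p):
    # Shannon entropy of a categorical distribution, in nats.
    return -sum(pi * math.log(pi) for pi in p if pi > 0)

def model_uncertainty(member_probs):
    # MU = H[ mean_m p_m ]  -  mean_m H[ p_m ]
    #    = total uncertainty - expected data uncertainty.
    M = len(member_probs)
    K = len(member_probs[0])
    mean = [sum(p[k] for p in member_probs) / M for k in range(K)]
    total = entropy(mean)
    expected_data = sum(entropy(p) for p in member_probs) / M
    return total - expected_data
```

When all members agree, MU is zero even if each member is individually uncertain; MU is largest when confident members disagree, e.g., two members predicting opposite classes with certainty yield MU = log 2.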
Table 1 (excerpt): average absolute difference in model uncertainty between each distillation method and the ensemble.

| Dataset | Knowledge distillation (Hinton et al., 2015) | Prior Networks (Malinin et al., 2019) | Hydra |
| --- | --- | --- | --- |
| CIFAR-10 (cyclic translation) | N/A | 0.1330 | 0.1646 |
The total uncertainty will be high whenever the model is uncertain, both in regions of severe class overlap and out-of-domain. However, for out-of-distribution data the estimates of expected data uncertainty are poor, resulting in high model uncertainty. Figure 2 visualizes the model uncertainty and its decomposition for the spiral toy dataset, with Subfigures 2(a) and 2(b) showing results for the ensemble and Hydra, respectively. The results show that Hydra successfully distills model uncertainty and its decomposition, though with a slight decrease in scale. As expected, we observe low model uncertainty where classes overlap, due to both high total uncertainty and high expected data uncertainty, and high model uncertainty at the border between in-domain and out-of-distribution data.
Further, we evaluate Hydra with respect to in-domain and out-of-distribution behaviour for models trained on MNIST and CIFAR-10. We report the average absolute difference in model uncertainty as a measure of how closely each distillation method matches the ensemble in Table 1. Hydra outperforms the baselines on the MNIST test set and shifted versions of MNIST; however, Hydra performs worse than Prior Networks on several CIFAR-10 datasets. Looking closer at one of the worse results, CIFAR-10 (cyclic translation), we plotted all evaluation metrics against the intensity of the skew. In order to rule out that the worse performance is related to model capacity, we added results from Hydra trained with a larger head of three residual blocks. As visualized in Figure 3, Hydra matches the behaviour of the ensemble best in terms of accuracy, Brier score and NLL. For model uncertainty, Hydra with the larger head configuration seems to improve overall performance compared to Prior Networks. Inspecting the decomposition of model uncertainty into total uncertainty and expected data uncertainty, Hydra appears more capable of mimicking the behaviour of the ensemble. However, the uncertainty scales of Prior Networks are larger than those of both the ensemble and Hydra, leading to an overall better model uncertainty match.
3.2 Test Performance on Common Classification and Regression Benchmarks
Classification Performance on MNIST and CIFAR-10.
We investigate Hydra on two real image datasets, MNIST and CIFAR-10, using the setup described in Section 3. We report all metrics for MNIST in Table 2 and for CIFAR-10 in Table 3, as well as model capacity and efficiency in Table 5. For MNIST, we can see that all distillation methods match the accuracy of the target ensemble, but Hydra outperforms both knowledge distillation and prior networks in terms of capturing the ensemble uncertainty, almost matching the ensemble predictive NLL (Hydra NLL 0.0465 versus ensemble NLL 0.0439) and Brier score (Hydra -0.9776 versus ensemble -0.9780). For the more challenging CIFAR-10 dataset, all distillation methods approach but do not quite match the high accuracy of the target ensemble (0.9226 ensemble accuracy). Among the distillation methods, Hydra has the smallest gap in terms of accuracy. All distillation methods retain a gap in NLL performance compared to the ensemble, but Hydra again has a significantly smaller NLL (0.3179) compared to Prior Networks (0.4392). In-distribution model uncertainty (MU) is comparable for Prior Networks (0.0280) and Hydra (0.0074) but quite a bit smaller than the target ensemble MU of 0.1055, meaning there remains room to improve uncertainty quantification in all distillation methods tested.
Regression Performance on UCI Regression Datasets. We trained both knowledge distillation (Hinton et al., 2015) and Hydra on the standard UCI regression datasets shown in Table 4. Here, Prior Networks are not applicable because, in the case of probabilistic regression, we cannot take averages of distributions. (Formally, the average of two Gaussian densities is, in general, no longer Gaussian.)
For regression, Hydra outperforms knowledge distillation in terms of predictive performance (NLL) because Hydra produces a more flexible output in the form of a Gaussian mixture model with one Gaussian component per head, whereas knowledge distillation can produce only a single Gaussian component.
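At evaluation time, the heads' Gaussian outputs form an equally weighted mixture, and the predictive NLL is the negative log of the mixture density. A minimal sketch of this evaluation (the function name is ours):

```python
import math

def mixture_nll(y, heads):
    # Negative log-likelihood of y under the equally weighted Gaussian
    # mixture formed by the heads' (mean, variance) outputs.
    M = len(heads)
    dens = sum(math.exp(-0.5 * (y - mu) ** 2 / var) / math.sqrt(2 * math.pi * var)
               for mu, var in heads) / M
    return -math.log(dens)
```

With heads that disagree, the mixture can be multi-modal or heavy-tailed, which a single Gaussian student cannot represent.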
4 Related Work
We now review related work in both distillation and multi-headed neural architectures.
Distillation. In (Hinton et al., 2015), a "student" network, i.e., the outcome of the distillation process, is trained to match the average predictions of the "teacher" network(s). This methodology was later successfully applied to the distillation of Bayesian ensembles (Balan et al., 2015), where ensemble members correspond to a Monte Carlo approximation of the posterior distribution. (Note that, in this paper, we focus primarily on instances of distillation where the supervision occurs at the level of the predictive probabilities, although other strategies, e.g., based on the feature representations in the networks (Ba and Caruana, 2014; Romero et al., 2014), have been proposed.) A parallel line of research has focused on co-distillation, also sometimes referred to as online distillation, to further reduce the overall training cost (Zhang et al., 2018; Anil et al., 2018; Lan et al., 2018). In this setting, both the student and teacher networks are learned simultaneously, and this form of mutual training acts as a regularization mechanism. Distillation has also recently been the topic of theoretical analyses that aim to better explain its empirical success (Lopez-Paz et al., 2015; Phuong and Lampert, 2019).
Closest to our approach is the work of Lan et al. (2018). Their method consists of training multiple student models whose combined predictions induce an ensemble teacher model. While we share conceptual similarities with their work, we depart from their formulation in several ways. First, we focus on the offline ensemble setting (Hinton et al., 2015), where we start from a pre-defined ensemble whose training may be difficult to replicate inside a co-distillation process. Second, our approach follows a different goal: we consider multiple branches, or heads in our terminology, to individually match the behavior of each teacher model. We do so to preserve the diversity of the teacher ensemble, which is, for instance, essential in out-of-distribution tasks (Lakshminarayanan et al., 2017; Ovadia et al., 2019). Third, our methodology has a conceivably simpler design, as reflected by our single-component objective function (the average KL divergence between each head and the corresponding teacher model) and the absence of a learned gating mechanism to linearly combine the logits of the student models (Lan et al., 2018).
Multi-headed architectures. This type of architecture can be motivated by several benefits, e.g., a reduction in memory footprint due to the sharing of parameters, a speed-up of the training process, a regularization effect due to the introduction of auxiliary training objectives, or the transfer of information across different tasks. As discussed in detail by Song and Chai (2018), there exist multiple strategies to define a body of shared parameters together with heads, e.g., hierarchical sharing patterns. In this work, we concentrate on simple multi-headed architectures, where the heads correspond to either the last layer of the original network or small extensions thereof. While multi-headed architectures were used for online distillation in (Lan et al., 2018), we reiterate that our goal differs from this previous work in that we exploit the multiple heads to match and mirror the individual members of a pre-defined teacher ensemble.
We presented Hydra, a simple and effective approach to distillation for ensemble models. Hydra preserves the diversity of ensemble member predictions, and we have demonstrated on standard models that capturing this information translates into improved performance and better uncertainty quantification. While Hydra improves on previous approaches, we believe that distillation performance can be further improved by leveraging techniques from fields that study sets of related learning tasks, such as meta-learning and domain adaptation.
References

- Anil, R., Pereyra, G., Passos, A., Ormandi, R., Dahl, G. E., and Hinton, G. E. (2018). Large scale distributed neural network training through online distillation. arXiv preprint arXiv:1804.03235.
- Asuncion, A. and Newman, D. (2007). UCI machine learning repository. University of California, Irvine, School of Information and Computer Sciences.
- Ba, J. and Caruana, R. (2014). Do deep nets really need to be deep? In Advances in Neural Information Processing Systems.
- Balan, A. K., Rathod, V., Murphy, K. P., and Welling, M. (2015). Bayesian dark knowledge. In Advances in Neural Information Processing Systems, pp. 3438-3446.
- Buciluǎ, C., Caruana, R., and Niculescu-Mizil, A. (2006). Model compression. In Proceedings of the 12th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pp. 535-541.
- Bui, T., Hernández-Lobato, D., Hernández-Lobato, J. M., Li, Y., and Turner, R. (2016). Deep Gaussian processes for regression using approximate expectation propagation. In International Conference on Machine Learning, pp. 1472-1481.
- Depeweg, S., Hernández-Lobato, J. M., Doshi-Velez, F., and Udluft, S. (2017). Uncertainty decomposition in Bayesian neural networks with latent variables. stat, 1050, p. 11.
- Dietterich, T. G. (2000). Ensemble methods in machine learning. In International Workshop on Multiple Classifier Systems, pp. 1-15.
- Gneiting, T. and Raftery, A. E. (2007). Strictly proper scoring rules, prediction, and estimation. Journal of the American Statistical Association, 102(477), pp. 359-378.
- Hinton, G., Vinyals, O., and Dean, J. (2015). Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531.
- Lakshminarayanan, B., Pritzel, A., and Blundell, C. (2017). Simple and scalable predictive uncertainty estimation using deep ensembles. In Advances in Neural Information Processing Systems, pp. 6402-6413.
- Lan, X., Zhu, X., and Gong, S. (2018). Knowledge distillation by on-the-fly native ensemble. In Proceedings of the 32nd International Conference on Neural Information Processing Systems, pp. 7528-7538.
- Lopez-Paz, D., Bottou, L., Schölkopf, B., and Vapnik, V. (2015). Unifying distillation and privileged information. arXiv preprint arXiv:1511.03643.
- Malinin, A. and Gales, M. (2018). Predictive uncertainty estimation via prior networks. In Advances in Neural Information Processing Systems, pp. 7047-7058.
- Malinin, A., Mlodozeniec, B., and Gales, M. (2019). Ensemble distribution distillation. arXiv preprint arXiv:1905.00076.
- Osband, I., Blundell, C., Pritzel, A., and Van Roy, B. (2016). Deep exploration via bootstrapped DQN. In Advances in Neural Information Processing Systems, pp. 4026-4034.
- Ovadia, Y., Fertig, E., Ren, J., Nado, Z., Sculley, D., Nowozin, S., Dillon, J. V., Lakshminarayanan, B., and Snoek, J. (2019). Can you trust your model's uncertainty? Evaluating predictive uncertainty under dataset shift. arXiv preprint arXiv:1906.02530.
- Phuong, M. and Lampert, C. (2019). Towards understanding knowledge distillation. In International Conference on Machine Learning, pp. 5142-5151.
- Quinonero-Candela, J., Rasmussen, C. E., Sinz, F., Bousquet, O., and Schölkopf, B. (2005). Evaluating predictive uncertainty challenge. In Machine Learning Challenges Workshop, pp. 1-27.
- Romero, A., Ballas, N., Kahou, S. E., Chassang, A., Gatta, C., and Bengio, Y. (2014). FitNets: Hints for thin deep nets. arXiv preprint arXiv:1412.6550.
- Sercu, T., Puhrsch, C., Kingsbury, B., and LeCun, Y. (2016). Very deep multilingual convolutional neural networks for LVCSR. In IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), pp. 4955-4959.
- Song, G. and Chai, W. (2018). Collaborative learning for deep neural networks. In Advances in Neural Information Processing Systems, pp. 1832-1841.
- Szegedy, C., Liu, W., Jia, Y., Sermanet, P., Reed, S., Anguelov, D., Erhan, D., Vanhoucke, V., and Rabinovich, A. (2015). Going deeper with convolutions. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 1-9.
- Zhang, Y., Xiang, T., Hospedales, T. M., and Lu, H. (2018). Deep mutual learning. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 4320-4328.