Scaling Ensemble Distribution Distillation to Many Classes with Proxy Targets

by Max Ryabinin, et al.

Ensembles of machine learning models yield improved system performance as well as robust and interpretable uncertainty estimates; however, their inference costs can be prohibitively high. Ensemble Distribution Distillation is an approach that allows a single model to efficiently capture both the predictive performance and uncertainty estimates of an ensemble. For classification, this is achieved by training a Dirichlet distribution over the ensemble members' output distributions via the maximum likelihood criterion. Although theoretically principled, this criterion exhibits poor convergence when applied to large-scale tasks where the number of classes is very high. In our work, we analyze this effect and show that, under the Dirichlet log-likelihood criterion, classes with low probability induce larger gradients than high-probability classes. This forces the model to focus on the distribution of the ensemble tail-class probabilities. We propose a new training objective that minimizes the reverse KL-divergence to a Proxy-Dirichlet target derived from the ensemble. This loss resolves the gradient issues of Ensemble Distribution Distillation, as we demonstrate both theoretically and empirically on the ImageNet and WMT17 En-De datasets, containing 1,000 and 40,000 classes, respectively.
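The objective described above can be sketched numerically: build a Dirichlet target from the ensemble's mean predictions and minimize the closed-form reverse KL-divergence between the student's Dirichlet and that target. This is a minimal sketch, not the paper's implementation; the `precision` hyperparameter and the `+1` offset on the concentrations are assumptions standing in for the precision estimate the paper derives from the ensemble.

```python
import numpy as np
from scipy.special import digamma, gammaln


def dirichlet_kl(alpha, beta):
    """Closed-form KL( Dir(alpha) || Dir(beta) ).

    With the student's concentration parameters as `alpha` and the
    proxy target as `beta`, this is the reverse KL objective (reverse
    relative to the forward direction implicit in maximum likelihood).
    """
    a0, b0 = alpha.sum(), beta.sum()
    return (gammaln(a0) - gammaln(alpha).sum()
            - gammaln(b0) + gammaln(beta).sum()
            + ((alpha - beta) * (digamma(alpha) - digamma(a0))).sum())


def proxy_dirichlet_target(ensemble_probs, precision=100.0):
    """Build a Proxy-Dirichlet target from ensemble member outputs.

    ensemble_probs: array of shape (num_members, num_classes) holding
    each member's categorical prediction. `precision` is a hypothetical
    stand-in for the precision estimate used in the paper; the +1
    offset keeps every concentration above 1, so the target density
    stays finite on the simplex boundary even for tail classes.
    """
    mean_probs = ensemble_probs.mean(axis=0)
    return 1.0 + precision * mean_probs
```

In training, `alpha` would be the exponentiated logits of the distilled model and the scalar loss above would be minimized by gradient descent; because the reverse KL weights errors by the student's own concentrations, tail classes no longer dominate the gradient the way they do under the Dirichlet log-likelihood.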


Reverse KL-Divergence Training of Prior Networks: Improved Uncertainty and Adversarial Robustness


Ensemble Distribution Distillation


A general framework for ensemble distribution distillation


Regression Prior Networks


Self-Distribution Distillation: Efficient Uncertainty Estimation


Efficient Evaluation-Time Uncertainty Estimation by Improved Distillation


Using Small Proxy Datasets to Accelerate Hyperparameter Search
