focal_calibration
Code for the paper "Calibrating Deep Neural Networks using Focal Loss"
Miscalibration – a mismatch between a model's confidence and its correctness – of Deep Neural Networks (DNNs) makes their predictions hard to rely on. Ideally, we want networks to be accurate, calibrated and confident. We show that, as opposed to the standard cross-entropy loss, focal loss (Lin et al., 2017) allows us to learn models that are already very well calibrated. When combined with temperature scaling, whilst preserving accuracy, it yields state-of-the-art calibrated models. We provide a thorough analysis of the factors causing miscalibration, and use the insights we glean from this to justify the empirically excellent performance of focal loss. To facilitate the use of focal loss in practice, we also provide a principled approach to automatically select the hyperparameter involved in the loss function. We perform extensive experiments on a variety of computer vision and NLP datasets, and with a wide variety of network architectures, and show that our approach achieves state-of-the-art accuracy and calibration in almost all cases.
Deep neural networks have dominated computer vision and machine learning in recent years, and this has led to their widespread deployment in real-world systems (Cao et al., 2018; Chen et al., 2018; Kamilaris and Prenafeta-Boldú, 2018; Ker et al., 2018; Wang et al., 2018). However, many current multi-class classification networks in particular are poorly calibrated, in the sense that the probability values that they associate with the class labels they predict overestimate the likelihoods of those class labels being correct in the real world. This is a major problem, since if networks are routinely overconfident, then downstream components cannot trust their predictions. The underlying cause is hypothesised to be that these networks’ high capacity leaves them vulnerable to overfitting on the negative log-likelihood (NLL) loss they conventionally use during training (Guo et al., 2017).
Given the importance of this problem, numerous suggestions for how to address it have been proposed. Much work has been inspired by approaches that were not originally formulated in a deep learning context, such as Platt scaling (Platt, 1999), histogram binning (Zadrozny and Elkan, 2001), isotonic regression (Zadrozny and Elkan, 2002), and Bayesian binning and averaging (Naeini et al., 2015; Naeini and Cooper, 2016). As deep learning has become more dominant, however, various works have begun to directly target the calibration of deep networks. For example, Guo et al. (2017) popularised a modern variant of Platt scaling known as temperature scaling, which works by dividing a network’s logits by a scalar (learnt on a validation subset) prior to performing softmax. Temperature scaling has the desirable property that it can improve the calibration of a network without in any way affecting its accuracy. However, whilst its simplicity and effectiveness have made it a popular network calibration method, it does have downsides. For example, whilst it scales the logits to reduce the network’s confidence in incorrect predictions, this also slightly reduces the network’s confidence in predictions that were correct.
By contrast, Kumar et al. (2018) initially eschew temperature scaling in favour of minimising a differentiable proxy for calibration error at training time, called Maximum Mean Calibration Error (MMCE), although they do later also use temperature scaling as a post-processing step to obtain better results than cross-entropy followed by temperature scaling (Guo et al., 2017). Separately, Müller et al. (2019) propose training models on cross-entropy loss with label smoothing instead of one-hot labels, and show that label smoothing has a very favourable effect on model calibration.
In this paper, we propose a technique for improving network calibration that works by replacing the cross-entropy loss conventionally used when training classification networks with the focal loss proposed by Lin et al. (2017). We observe that unlike cross-entropy, which minimises the KL divergence between the predicted (softmax) distribution and the target distribution (one-hot encoding in classification tasks) over classes, focal loss minimises a regularised KL divergence between these two distributions, which ensures minimisation of the KL divergence whilst increasing the entropy of the predicted distribution, thereby preventing the model from becoming overconfident. Since focal loss, as shown in §4, is dependent on a hyperparameter, $\gamma$, that needs to be cross-validated, we also provide a method of choosing $\gamma$ automatically for each sample, and show that it outperforms all the baseline models.
The intuition behind using focal loss is to direct the network’s attention during training towards samples for which it is currently predicting a low probability for the correct class, since trying to reduce the NLL on samples for which it is already predicting a high probability for the correct class is liable to lead to NLL overfitting, and thereby miscalibration (Guo et al., 2017). More formally, we show in §4 that focal loss can be seen as implicitly regularising the weights of the network during training by causing the gradient norms for confident samples to be lower than they would have been with cross-entropy, which we would expect to reduce overfitting and improve the network’s calibration.
Overall, we make the following contributions:
In §3, we study in detail the link that Guo et al. (2017) observed between miscalibration and NLL overfitting, and show that the overfitting is associated with the predicted distributions for misclassified test samples becoming peakier as the optimiser tries to increase the magnitude of the network’s weights to reduce the training NLL.
In §4, we propose the use of focal loss for training better-calibrated networks, and provide both theoretical and empirical justifications for this approach. In addition, we provide a principled method of automatically choosing $\gamma$ for each sample during training.
Finally, in §5.2, we show that a network trained using focal loss is able to improve on both the BLEU and ECE scores of baseline models trained using cross-entropy (with one-hot labels and label smoothing) on the WMT 2014 English-to-German translation dataset, thereby showing the practical impact of the calibration improvements we are able to achieve for performance on downstream tasks.
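Let $D = \langle (x_i, y_i) \rangle_{i=1}^{N}$ denote a dataset consisting of $N$ samples from a joint distribution $\mathcal{D}(\mathcal{X}, \mathcal{Y})$, where for each sample $i$, $x_i \in \mathcal{X}$ is the input and $y_i \in \mathcal{Y} = \{1, 2, \ldots, K\}$ is the ground-truth class label. Let $\hat{p}_{i,y} = f_\theta(y \mid x_i)$ be the probability that a neural network $f$ with model parameters $\theta$ predicts for a class $y$ on a given input $x_i$. The class that $f$ predicts for $x_i$ is computed as $\hat{y}_i = \arg\max_{y \in \mathcal{Y}} \hat{p}_{i,y}$, and the predicted confidence as $\hat{p}_i = \max_{y \in \mathcal{Y}} \hat{p}_{i,y}$. The network is said to be perfectly calibrated when, for each sample $(x, y) \in D$, the confidence $\hat{p}$ is equal to the model accuracy $\mathbb{P}(\hat{y} = y \mid \hat{p})$, i.e. the probability that the predicted class is correct. For instance, of all the samples to which a perfectly calibrated neural network assigns a confidence of 0.8, 80% should be correctly predicted.
A popular metric used to measure model calibration is the expected calibration error (ECE) (Naeini et al., 2015), defined as the expected absolute difference between the model’s confidence and its accuracy, i.e. $\mathbb{E}_{\hat{p}}\left[\left|\mathbb{P}(\hat{y} = y \mid \hat{p}) - \hat{p}\right|\right]$. Since we only have finite samples, the ECE cannot in practice be computed using this definition. Instead, we divide the interval $[0, 1]$ into $M$ equispaced bins, where the $i^{\text{th}}$ bin is the interval $\left(\frac{i-1}{M}, \frac{i}{M}\right]$. Let $B_i$ denote the set of samples with confidences belonging to the $i^{\text{th}}$ bin. The accuracy $A_i$ of this bin is computed as $A_i = \frac{1}{|B_i|} \sum_{j \in B_i} \mathbb{1}(\hat{y}_j = y_j)$, where $\mathbb{1}$ is the indicator function, and $\hat{y}_j$ and $y_j$ are the predicted and ground-truth labels for the $j^{\text{th}}$ sample. Similarly, the confidence $C_i$ of the $i^{\text{th}}$ bin is computed as $C_i = \frac{1}{|B_i|} \sum_{j \in B_i} \hat{p}_j$, i.e. $C_i$ is the average confidence of all samples in the bin. The ECE can be approximated as a weighted average of the absolute difference between the accuracy and confidence of each bin:
$$\mathrm{ECE} = \sum_{i=1}^{M} \frac{|B_i|}{N} \left|A_i - C_i\right| \tag{1}$$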
A similar metric, the maximum calibration error (MCE) (Naeini et al., 2015), is defined as the maximum absolute difference between the confidence $C_i$ and accuracy $A_i$ of each bin $i$:
$$\mathrm{MCE} = \max_{i \in \{1, \ldots, M\}} \left|A_i - C_i\right| \tag{2}$$
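To make Equations (1) and (2) concrete, the following is a minimal pure-Python sketch of equal-width binned ECE and MCE. It is our own illustration rather than the code released with the paper; the helper name `confidence_bins` and the toy inputs are hypothetical, and confidences are taken as the maximum softmax probability, as defined above.

```python
import math
from typing import List, Sequence, Tuple


def confidence_bins(probs: Sequence[Sequence[float]], labels: Sequence[int],
                    n_bins: int = 15) -> List[Tuple[int, float, float]]:
    """Assign samples to n_bins equal-width bins ((i-1)/M, i/M] by confidence.

    Returns (|B_i|, A_i, C_i) for every non-empty bin.
    """
    bins: List[List[Tuple[float, bool]]] = [[] for _ in range(n_bins)]
    for p, y in zip(probs, labels):
        conf = max(p)                   # predicted confidence p_hat_i
        pred = list(p).index(conf)      # predicted class y_hat_i (argmax)
        # Index of the interval ((i-1)/M, i/M]; clamped so conf == 0 or 1 stays in range.
        idx = min(max(math.ceil(conf * n_bins) - 1, 0), n_bins - 1)
        bins[idx].append((conf, pred == y))
    stats = []
    for members in bins:
        if members:
            acc = sum(correct for _, correct in members) / len(members)  # A_i
            avg_conf = sum(c for c, _ in members) / len(members)         # C_i
            stats.append((len(members), acc, avg_conf))
    return stats


def ece(probs, labels, n_bins: int = 15) -> float:
    """Equation (1): bin-size-weighted mean of |A_i - C_i|."""
    n = len(labels)
    return sum(size / n * abs(acc - conf)
               for size, acc, conf in confidence_bins(probs, labels, n_bins))


def mce(probs, labels, n_bins: int = 15) -> float:
    """Equation (2): largest |A_i - C_i| over the non-empty bins."""
    return max(abs(acc - conf)
               for _, acc, conf in confidence_bins(probs, labels, n_bins))


# Hypothetical toy predictions for four samples over three classes.
probs = [[0.10, 0.85, 0.05], [0.10, 0.85, 0.05],
         [0.65, 0.20, 0.15], [0.20, 0.15, 0.65]]
labels = [1, 0, 0, 0]
print(ece(probs, labels, n_bins=10), mce(probs, labels, n_bins=10))
```

Here two samples fall into the bin with confidence 0.85 and two into the bin with confidence 0.65, each bin being 50% accurate, so the per-bin gaps are 0.35 and 0.15; ECE averages them by bin size while MCE keeps only the larger.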
Figure 1: The confidence values for training samples at different epochs during the NLL training of a ResNet-50 on CIFAR-10 (see §3). Top row: reliability plots using confidence bins; bottom row: % of samples in each bin. As training progresses, the model gradually shifts all training samples to the highest confidence bin. Notably, it continues to do so even after it has achieved 100% training accuracy.
A common way of visually exploring the calibration of a model is to use a reliability plot (Niculescu-Mizil and Caruana, 2005), which plots the accuracies of the various confidence bins as a bar chart (e.g. see Figure 1). Reliability plots also show whether a model is, in general, under-confident or over-confident. For a perfectly calibrated model, the accuracy for each bin will match the confidence, and hence all of the bars will lie on the diagonal. By contrast, if most of the bars lie above the diagonal, meaning that the model is more accurate than it expects, then it is said to be under-confident, and if most of the bars lie below the diagonal, then it is said to be over-confident.
AdaECE: One disadvantage of ECE is the uniform bin width. Once a model is trained, most of the samples lie within the highest confidence bins, and hence these bins dominate the value of the ECE, as the contribution of each bin is weighted by the number of samples it contains. To mitigate this, we thus also consider another metric, which we call AdaECE (Adaptive ECE), for which bin sizes are calculated so as to evenly distribute samples between bins:
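$$\mathrm{AdaECE} = \sum_{i=1}^{M} \frac{|B_i|}{N} \left|A_i - C_i\right| \quad \text{s.t.} \quad \forall i, j : |B_i| = |B_j| \tag{3}$$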
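A minimal sketch of Equation (3), assuming that "evenly distributing samples between bins" is implemented by sorting samples by confidence and cutting the sorted list into $M$ contiguous chunks of (nearly) equal size. This is our own illustration; the function name `adaptive_ece` is hypothetical, and the released implementation may place bin edges or handle ties differently.

```python
from typing import Sequence


def adaptive_ece(probs: Sequence[Sequence[float]], labels: Sequence[int],
                 n_bins: int = 15) -> float:
    """Equation (3): ECE over bins that each hold (roughly) N / M samples."""
    # (confidence, is_correct) pairs, sorted by confidence
    samples = sorted((max(p), list(p).index(max(p)) == y)
                     for p, y in zip(probs, labels))
    n = len(samples)
    total = 0.0
    for i in range(n_bins):
        # Contiguous slices of the sorted list give equal-count bins
        chunk = samples[(i * n) // n_bins:((i + 1) * n) // n_bins]
        if not chunk:
            continue
        acc = sum(correct for _, correct in chunk) / len(chunk)  # A_i
        conf = sum(c for c, _ in chunk) / len(chunk)             # C_i
        total += len(chunk) / n * abs(acc - conf)
    return total
```

When $N$ is not divisible by $M$, chunk sizes differ by at most one sample, so the equal-size constraint in Equation (3) only holds approximately.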
We now discuss why high-capacity neural networks, despite achieving low classification errors on well-known datasets, tend to be miscalibrated. A key empirical observation made by (Guo et al., 2017) was that poor calibration of such networks appears to be linked to overfitting on the negative log-likelihood (NLL) during training. In this section, we further inspect this observation to provide new insights.
For the analysis, we train a high-capacity ResNet-50 network on CIFAR-10 with state-of-the-art performance settings (PyTorch-CIFAR). We use Stochastic Gradient Descent (SGD) with a mini-batch size of 128, momentum of 0.9, and a piecewise-constant learning rate schedule whose three values are used for the first 150, next 100, and last 100 epochs, respectively. We minimise the cross-entropy loss (a.k.a. NLL) $\mathcal{L}_c$, which, in a standard classification context, is $-\log \hat{p}_{i,y_i}$, where $\hat{p}_{i,y_i}$ is the probability assigned by the network to the correct class $y_i$ for the $i^{\text{th}}$ sample. Note that the NLL is minimised when, for each training sample $i$, $\hat{p}_{i,y_i} = 1$, whereas the classification error is minimised when $\hat{p}_{i,y_i} > \hat{p}_{i,y}$ for all $y \neq y_i$. This indicates that even when the classification error is 0, the NLL can be positive, and the optimisation algorithm can still try to reduce it to 0 by further increasing the value of $\hat{p}_{i,y_i}$ for each sample. This can be empirically observed in Figure 1, where we plot reliability diagrams and percentages of samples in each confidence bin for different training epochs.
To study how miscalibration occurs during training, we plot the average NLL for the train and test sets at each training epoch in Figures 2(a) and 2(b). We also plot the average NLL and the entropy of the softmax distribution produced by the network for the correctly and incorrectly classified samples. In Figure 2(c), we plot the classification errors on the train and test sets, along with the test set ECE.
Curse of misclassified samples: Figures 2(a) and 2(b) show that although the average train NLL (for both correctly and incorrectly classified training samples) broadly decreases throughout training, after the 150th epoch (where the learning rate first drops), there is a marked rise in the average test NLL, indicating that the network starts to overfit on average NLL. Interestingly, the increase in average test NLL is caused only by the incorrectly classified samples, as the average NLL for the correctly classified samples continues to decrease even after the 150th epoch. We also observe that after epoch 150, the test set ECE rises, indicating that the network is becoming miscalibrated. This corroborates the observation in Guo et al. (2017) that miscalibration and NLL overfitting are linked.
Peak at the wrong place: We further observe that the entropies of the softmax distributions for both the correctly and incorrectly classified test samples decrease throughout training (in other words, the distributions get peakier). This observation, coupled with the one we made above, indicates that for the wrongly classified test samples, the network gradually becomes more and more confident about its incorrect predictions.
Weight magnification: The increase in confidence of the network’s predictions can happen if the network increases the norm of its weights $\mathbf{w}$ to increase the magnitudes of the logits. In fact, cross-entropy loss is minimised when, for each training sample $i$, $\hat{p}_{i,y_i} = 1$, which is possible only when $\|\mathbf{w}\| \to \infty$. Cross-entropy loss thus inherently induces this tendency of weight magnification in neural network optimisation. This may explain the promising effect of weight decay (which regulates the norm of the weights) on the calibration of neural networks. We explore this further in §4. This increase in the network’s confidence during training is one of the key causes of miscalibration.
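The following toy sketch (our own, with hypothetical logits) illustrates the weight-magnification argument. For a bias-free last linear layer, multiplying its weights by a factor multiplies the logits by the same factor, so the argmax, and hence training accuracy, is unchanged, while the NLL of an already correctly classified sample keeps shrinking towards, but never reaches, zero.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax via the log-sum-exp trick."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - lse for z in logits]


def scaled_nll(logits: Sequence[float], label: int, scale: float):
    """NLL and confidence after multiplying every logit by `scale`.

    With a bias-free last linear layer, scaling its weights by `scale` scales
    the logits identically, so this mimics weight magnification.
    """
    log_p = log_softmax([scale * z for z in logits])
    return -log_p[label], math.exp(max(log_p))


# Hypothetical logits of one training sample that is already classified correctly.
logits, label = [2.0, 1.0, 0.5], 0
for scale in [1, 2, 4, 8]:
    loss, conf = scaled_nll(logits, label, scale)
    print(f"scale={scale}: NLL={loss:.4f}, confidence={conf:.4f}")
```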
As discussed in §3, overfitting on NLL, which is observed as the network grows more confident on all of its predictions irrespective of their correctness, is strongly related to poor calibration. One cause of this is that the cross-entropy objective minimises the difference between the softmax distribution and the ground-truth one-hot encoding over an entire mini-batch, irrespective of how well a network classifies individual samples in the mini-batch. In this work, we study an alternative loss function, popularly known as focal loss (Lin et al., 2017), that tackles this by weighting loss components generated from individual samples in a mini-batch by how well the model classifies them. For classification tasks where the target distribution is a one-hot encoding, it is defined as $\mathcal{L}_f = -(1 - \hat{p}_{i,y_i})^{\gamma} \log \hat{p}_{i,y_i}$, where $\gamma$ is a user-defined hyperparameter.¹
¹ We note in passing that unlike cross-entropy loss, focal loss in its general form is not a proper loss function, as minimising it does not always lead to the predicted distribution $\hat{p}$ being equal to the target distribution $q$ (see Appendix A for the relevant definition and a longer discussion). However, when $q$ is a one-hot encoding (as in our case, and for most classification tasks), minimising focal loss does lead to $\hat{p}$ being equal to $q$.
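As a minimal sketch of the definition above (our own illustration, not the paper's PyTorch code), focal loss for a single sample with a one-hot target can be computed from its logits as follows. Setting $\gamma = 0$ recovers cross-entropy, and for $\gamma > 0$ the loss is the cross-entropy down-weighted by the factor $(1 - \hat{p}_{i,y_i})^{\gamma}$, which is smallest for confidently correct samples.

```python
import math
from typing import Sequence


def log_prob_of_label(logits: Sequence[float], label: int) -> float:
    """log p_hat_{i,y_i} computed via log-sum-exp for numerical stability."""
    m = max(logits)
    return logits[label] - (m + math.log(sum(math.exp(z - m) for z in logits)))


def focal_loss(logits: Sequence[float], label: int, gamma: float) -> float:
    """L_f = -(1 - p)^gamma * log(p), with p the softmax probability of the true class."""
    log_p = log_prob_of_label(logits, label)
    p = math.exp(log_p)
    return -((1.0 - p) ** gamma) * log_p


def cross_entropy(logits: Sequence[float], label: int) -> float:
    """L_c = -log(p) for the true class."""
    return -log_prob_of_label(logits, label)
```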
Why might focal loss improve calibration? We know that cross-entropy forms an upper bound on the KL-divergence between the target distribution $q$ and the predicted distribution $\hat{p}$, i.e. $\mathcal{L}_c \geq \mathrm{KL}(q \,\|\, \hat{p})$, so minimising cross-entropy results in minimising $\mathrm{KL}(q \,\|\, \hat{p})$. Interestingly, a general form of focal loss can be shown to be an upper bound on the regularised KL-divergence, where the regulariser is the negative entropy of the predicted distribution $\hat{p}$, and the regularisation parameter is $\gamma$, the hyperparameter of focal loss (a proof of this can be found in Appendix A):
$$\mathcal{L}_f \geq \mathrm{KL}(q \,\|\, \hat{p}) - \gamma \mathbb{H}[\hat{p}] \tag{4}$$
The most interesting property of this upper bound is that it shows that replacing cross-entropy with focal loss has the effect of adding a maximum-entropy regulariser (Pereyra et al., 2017) to the implicit minimisation that was previously being performed. In other words, trying to minimise it means trying to minimise the KL divergence between and , whilst simultaneously trying to increase the entropy of the predicted distribution . Encouraging the predicted distribution to have higher entropy can help avoid the overconfident predictions produced by modern neural networks (see the ‘Peak at the wrong place’ paragraph of Section 3), and thereby improve calibration.
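Equation (4) can be sanity-checked numerically. The sketch below, our own and not taken from Appendix A, draws random predicted distributions and one-hot targets and records the smallest observed value of $\mathcal{L}_f - (\mathrm{KL}(q \,\|\, \hat{p}) - \gamma \mathbb{H}[\hat{p}])$. We only sample $\gamma \geq 1$, where a short argument of our own suffices: Bernoulli's inequality gives $(1-p)^{\gamma} \geq 1 - \gamma p$, so $\mathcal{L}_f \geq -\log \hat{p}_{y} + \gamma \hat{p}_{y} \log \hat{p}_{y} \geq \mathrm{KL}(q \,\|\, \hat{p}) - \gamma \mathbb{H}[\hat{p}]$ for one-hot $q$. The sampling scheme and function names are assumptions for illustration.

```python
import math
import random


def entropy(p):
    """Shannon entropy H[p] in nats."""
    return -sum(v * math.log(v) for v in p if v > 0)


def check_focal_bound(num_trials=10_000, num_classes=5, seed=0):
    """Smallest observed value of L_f - (KL(q || p_hat) - gamma * H[p_hat]) for one-hot q."""
    rng = random.Random(seed)
    worst_gap = math.inf
    for _ in range(num_trials):
        raw = [math.exp(rng.gauss(0.0, 2.0)) for _ in range(num_classes)]
        total = sum(raw)
        p = [v / total for v in raw]       # random predicted distribution p_hat
        y = rng.randrange(num_classes)     # position of the 1 in the one-hot target q
        gamma = rng.uniform(1.0, 5.0)      # only gamma >= 1 is exercised here
        focal = -((1.0 - p[y]) ** gamma) * math.log(p[y])
        kl = -math.log(p[y])               # KL(q || p_hat) when q is one-hot
        worst_gap = min(worst_gap, focal - (kl - gamma * entropy(p)))
    return worst_gap


print(check_focal_bound())  # non-negative if the bound held on every trial
```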
Empirical observations: To analyse the behaviour of neural networks trained on focal loss, we use the same framework as mentioned above, and train four ResNet-50 networks on CIFAR-10, one using cross-entropy loss, and three using focal loss with different fixed values of $\gamma$. Figure 3(a) shows that the test NLL for the cross-entropy model significantly increases towards the end of training (before saturating), whereas the NLLs for the focal loss models remain low. To better understand this, we analyse the behaviour of these models for correctly and incorrectly classified samples. Figure 3(b) shows that even though the NLLs for the correctly classified samples broadly speaking decrease over the course of training for all the models, the NLLs for the focal loss models remain consistently higher than that for the cross-entropy model throughout training, implying that the focal loss models are relatively less confident than the cross-entropy model for samples that they predict correctly. This is important, as we have already discussed that it is overconfidence that normally makes deep neural networks miscalibrated. Figure 3(c) shows that in contrast to the cross-entropy model, for which the NLL for misclassified test samples increases significantly in the later stages of training, the rise in this value for the focal loss models is much less severe. Additionally, in Figure 3(d), we notice that the entropy of the softmax distribution for misclassified test samples is consistently (if marginally) higher for focal loss than for cross-entropy (consistent with Equation 4).
As per §3, an increase in the test NLL and a decrease in the test entropy for misclassified samples, taken together with no corresponding increase in the test NLL for the correctly classified samples, can be interpreted as the network starting to predict softmax distributions for the misclassified samples that are ever more peaky in the wrong place. Notably, our results in Figures 3(b), 3(c) and 3(d) clearly show that this effect is significantly reduced when training with focal loss rather than cross-entropy, leading to a network that is less peaky in the wrong place and better calibrated.
Theoretical justification: As mentioned previously, once a model trained using cross-entropy reaches high accuracy on the training set, the optimiser may try to further reduce the training NLL by increasing the confidences for the correctly classified samples. One way it could achieve this would be to increase the weights of the network to increase the magnitudes of the logits. In fact, this hypothesis would help to explain the observation of Guo et al. (2017) that models trained using some form of weight decay are relatively better calibrated. To verify this, we plot the norm of the weights of the last linear layer for all four networks as a function of the training epoch (see Figure 3(e)). Notably, although the norms of the weights for the models trained on focal loss are initially higher than that for the cross-entropy model, a complete reversal in the ordering of the weight norms occurs partway through training. In other words, as the networks start to become miscalibrated, the weight norm for the cross-entropy model also starts to become greater than those for the focal loss models. In practice, this is because focal loss, by design, starts to act as a regulariser on the network’s weights once the model has gained a certain amount of confidence in its predictions. This behaviour of focal loss can be observed even in a much simpler setup such as a linear model (see Appendix B). To better understand this, we start by considering the following proposition:
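Proposition 1. For focal loss $\mathcal{L}_f$ and cross-entropy $\mathcal{L}_c$, the gradients satisfy $\frac{\partial \mathcal{L}_f}{\partial \mathbf{w}} = g(\hat{p}_{i,y_i}, \gamma) \frac{\partial \mathcal{L}_c}{\partial \mathbf{w}}$, where $g(p, \gamma) = (1-p)^{\gamma} - \gamma p (1-p)^{\gamma - 1} \log(p)$, $\gamma \in \mathbb{R}^{+}$ is the focal loss hyperparameter, and $\mathbf{w}$ denotes the parameters of the last linear layer. Thus $\left\|\frac{\partial \mathcal{L}_f}{\partial \mathbf{w}}\right\| \leq \left\|\frac{\partial \mathcal{L}_c}{\partial \mathbf{w}}\right\|$ if $g(\hat{p}_{i,y_i}, \gamma) \in [0, 1]$.
Proof. See Appendix C. ∎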
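Proposition 1 can be checked numerically on a single linear softmax layer, with logits $z = W h$. The sketch below, our own illustration with hypothetical weights and features, differentiates both losses with respect to every entry of $W$ by central finite differences and confirms that the focal loss gradient equals $g(p, \gamma)$ times the cross-entropy gradient, where $g(p, \gamma) = (1-p)^{\gamma} - \gamma p (1-p)^{\gamma-1} \log p$ and $p$ is the true-class probability.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def true_class_prob_and_log(W: Matrix, h: Sequence[float], y: int):
    """Softmax probability (and its log) of class y for logits z = W h."""
    z = [sum(w * x for w, x in zip(row, h)) for row in W]
    m = max(z)
    log_p = z[y] - (m + math.log(sum(math.exp(v - m) for v in z)))
    return math.exp(log_p), log_p


def loss(W: Matrix, h: Sequence[float], y: int, gamma: float) -> float:
    """Focal loss; gamma = 0 reduces it to cross-entropy."""
    p, log_p = true_class_prob_and_log(W, h, y)
    return -((1.0 - p) ** gamma) * log_p


def numerical_grad(W: Matrix, h, y, gamma, eps=1e-6) -> Matrix:
    """Central finite differences w.r.t. every entry of the last-layer weights W."""
    grad = []
    for k in range(len(W)):
        row = []
        for j in range(len(W[k])):
            W_plus = [r[:] for r in W]
            W_minus = [r[:] for r in W]
            W_plus[k][j] += eps
            W_minus[k][j] -= eps
            row.append((loss(W_plus, h, y, gamma) - loss(W_minus, h, y, gamma)) / (2 * eps))
        grad.append(row)
    return grad


def g(p: float, gamma: float) -> float:
    """Gradient ratio g(p, gamma) from Proposition 1."""
    return (1 - p) ** gamma - gamma * p * (1 - p) ** (gamma - 1) * math.log(p)


# Hypothetical last-layer weights (K = 3 classes, D = 3 features), feature vector and label.
W = [[0.5, -0.2, 0.1], [0.3, 0.8, -0.5], [-0.4, 0.1, 0.9]]
h, y, gamma = [1.0, 0.5, -0.3], 1, 2.0
p, _ = true_class_prob_and_log(W, h, y)
grad_focal = numerical_grad(W, h, y, gamma)
grad_ce = numerical_grad(W, h, y, 0.0)
ratio = g(p, gamma)
max_err = max(abs(f - ratio * c)
              for rf, rc in zip(grad_focal, grad_ce) for f, c in zip(rf, rc))
print(f"p={p:.4f}, g={ratio:.4f}, max |dLf - g*dLc| = {max_err:.2e}")
```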
Proposition 1 shows the relationship between the norms of the gradients of the last linear layer for focal loss and cross-entropy loss, for the same network architecture. Note that this relation depends on a function $g(p, \gamma)$, which we plot in Figure 4(a) to understand its behaviour. It is clear that for every $\gamma$, there exists a (different) threshold $p_0$ such that for all $p \in [0, p_0]$, $g(p, \gamma) \geq 1$, and for all $p \in (p_0, 1]$, $g(p, \gamma) < 1$. (For example, for $\gamma = 1$, $g(p, 1) = 1 - p(1 + \log p)$, so $p_0 = 1/e \approx 0.37$.) We use this insight to further explain why focal loss provides implicit weight regularisation.
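The threshold $p_0$ for a given $\gamma$ can be located by bisection on $g(p, \gamma) = 1$. This minimal sketch is our own; it assumes the single-crossing behaviour described above (g above 1 below $p_0$, below 1 above it). For $\gamma = 1$, $g(p, 1) = 1 - p(1 + \log p)$ equals 1 exactly at $p = 1/e$, which the search recovers, and larger $\gamma$ values give smaller thresholds.

```python
import math


def g(p: float, gamma: float) -> float:
    """Gradient ratio g(p, gamma) = (1-p)^gamma - gamma p (1-p)^(gamma-1) log p."""
    return (1 - p) ** gamma - gamma * p * (1 - p) ** (gamma - 1) * math.log(p)


def threshold_p0(gamma: float, tol: float = 1e-12) -> float:
    """Bisection for p0 with g(p0, gamma) = 1, assuming g > 1 below p0 and g < 1 above."""
    lo, hi = 1e-9, 1 - 1e-9
    while hi - lo > tol:
        mid = (lo + hi) / 2
        if g(mid, gamma) > 1:
            lo = mid   # still in the region where focal loss amplifies gradients
        else:
            hi = mid
    return (lo + hi) / 2


for gamma in [1, 3, 5]:
    print(gamma, round(threshold_p0(gamma), 4))
```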
Implicit weight regularisation: For a network trained using focal loss with a fixed $\gamma$, during the initial stages of the training, when $\hat{p}_{i,y_i} \in (0, p_0)$, $g(\hat{p}_{i,y_i}, \gamma) > 1$. This implies that the confidences of the focal loss model’s predictions will initially increase faster than they would for cross-entropy. However, as soon as $\hat{p}_{i,y_i}$ crosses the threshold $p_0$, $g(\hat{p}_{i,y_i}, \gamma)$ falls below 1 and reduces the size of the gradient updates made to the network weights, thereby having a regularising effect on the weights. This is why, in Figure 3(e), we find that the weight norms of the models trained with focal loss are initially higher than that for the model trained using cross-entropy. However, as training progresses, we find that the ordering of the weight norms reverses, as focal loss starts regularising the network weights. Moreover, we can draw similar insights from Figures 4(b), 4(c) and 4(d), in which we plot histograms of the gradient norms of the last linear layer (over all samples in the training set) at three increasingly late training epochs. At the first of these epochs, the gradient norms for cross-entropy and focal loss are similar, but as training progresses, those for cross-entropy decrease less rapidly than those for focal loss, indicating that the gradient norms for focal loss become, and then remain, lower than those for cross-entropy.
Finally, observe in Figure 4(a) that for higher $\gamma$ values, the fall in $g(p, \gamma)$ is steeper. We would thus expect a greater weight regularisation effect for models that use higher values of $\gamma$. This explains why, of the three models that we trained using focal loss, the one with the largest $\gamma$ outperforms (in terms of calibration) the one with the intermediate value, which in turn outperforms the model with the smallest $\gamma$. Based on this observation, one might think that, in general, a higher value of $\gamma$ would lead to a more calibrated model. However, this is not the case, as we notice from Figure 4(a) that for sufficiently large $\gamma$, $g(p, \gamma)$ reduces to nearly 0 for a relatively low value of $p$. As a result, using values of $\gamma$ that are too high will cause the gradients to die (i.e. reduce to nearly 0) early, at a point at which the network’s predictions remain ambiguous, thereby causing the training process to fail.
How to choose $\gamma$: As discussed, focal loss provides implicit entropy and weight regularisation, both of which heavily depend on the value of $\gamma$. Finding an appropriate $\gamma$ is normally done using cross-validation. Also, traditionally, $\gamma$ is fixed for all samples in the dataset. However, as shown, the regularisation effect for a sample $i$ depends on $\hat{p}_{i,y_i}$, i.e. the predicted probability for the ground-truth label of the sample. It thus makes sense to choose $\gamma$ for each sample based on the value of $\hat{p}_{i,y_i}$. To this end, we provide Proposition 2, which we use to find a solution to this problem:
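Proposition 2. Given a $p_0$, for $1 \geq p \geq p_0 > 0$, $g(p, \gamma) \leq 1$ for all $\gamma \geq \gamma^* = \frac{a}{b} + \frac{1}{\log a} W_{-1}\left(-\frac{a^{(1 - a/b)}}{b} \log a\right)$, where $a = 1 - p_0$, $b = p_0 \log p_0$, and $W_{-1}$ is the lower branch of the Lambert-W function (Corless et al., 1996). Moreover, for $p \geq p_0 > 0$ and $\gamma \geq \gamma^*$, the equality $g(p, \gamma) = 1$ holds only for $p = p_0$ and $\gamma = \gamma^*$.
Proof. See Appendix C. Note that there exist multiple values of $\gamma$ for which $g(p, \gamma) \leq 1$ for all $p \geq p_0$. ∎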
For a given $p_0$, this allows us to compute $\gamma$ s.t. (i) $g(p_0, \gamma) = 1$; (ii) $g(p, \gamma) \geq 1$ for $p \in [0, p_0)$; and (iii) $g(p, \gamma) < 1$ for $p \in (p_0, 1]$. This allows us to control the magnitude of the gradients for a particular sample based on the current value of $\hat{p}_{i,y_i}$, and gives us a way of obtaining an informed value of $\gamma$ for each sample. For instance, a reasonable policy might be to choose $\gamma$ s.t. $g(\hat{p}_{i,y_i}, \gamma) > 1$ if $\hat{p}_{i,y_i}$ is small, and $g(\hat{p}_{i,y_i}, \gamma) < 1$ otherwise. Such a policy will have the effect of making the weight updates larger for samples having a low predicted probability for the correct class and smaller for samples with a relatively higher predicted probability for the correct class.
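The closed form in Proposition 2 can be evaluated directly. Writing $a = 1 - p_0$ and $b = p_0 \log p_0$, the condition $g(p_0, \gamma) = 1$ becomes $a^{\gamma - 1}(a - \gamma b) = 1$, whose nontrivial root is $\gamma^* = \frac{a}{b} + \frac{1}{\log a} W_{-1}\left(-\frac{a^{(1 - a/b)}}{b} \log a\right)$ (the principal branch $W_0$ yields the trivial root $\gamma = 0$). The sketch below is our own: to stay dependency-free it computes $W_{-1}$ by bisection instead of calling a library routine, and the function names are hypothetical.

```python
import math


def lambert_w_minus1(x: float) -> float:
    """Lower branch W_{-1}(x) for -1/e < x < 0, by bisection on w * exp(w) = x.

    On w <= -1, w * exp(w) decreases monotonically from 0- (w -> -inf) to -1/e (w = -1).
    """
    if not -1 / math.e < x < 0:
        raise ValueError("this sketch only handles -1/e < x < 0")
    lo, hi = -700.0, -1.0
    for _ in range(200):
        mid = (lo + hi) / 2
        if mid * math.exp(mid) > x:   # value still too close to 0: root lies to the right
            lo = mid
        else:
            hi = mid
    return (lo + hi) / 2


def gamma_star(p0: float) -> float:
    """Smallest gamma with g(p, gamma) <= 1 for all p >= p0 (Proposition 2)."""
    a = 1 - p0
    b = p0 * math.log(p0)
    z = -(a ** (1 - a / b) / b) * math.log(a)
    return a / b + lambert_w_minus1(z) / math.log(a)


def g(p: float, gamma: float) -> float:
    return (1 - p) ** gamma - gamma * p * (1 - p) ** (gamma - 1) * math.log(p)


gs = gamma_star(0.2)
print(round(gs, 3), g(0.2, gs))  # g(p0, gamma*) should be 1 up to rounding
```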
Following the aforementioned arguments, we choose a threshold on $\hat{p}_{i,y_i}$ and use Proposition 2 to obtain a $\gamma$ policy such that $g(\hat{p}_{i,y_i}, \gamma)$ is observably greater than 1 below the threshold and less than 1 above it. In particular, we use the following schedule: if $\hat{p}_{i,y_i}$ lies below the threshold, then $\gamma = 5$, otherwise $\gamma = 3$ (refer to Figure 4(a) for where $g(p, 5)$ and $g(p, 3)$ cross 1). We find it to perform consistently well across multiple classification datasets and network architectures. Having said that, one can calculate multiple such schedules for $\gamma$ following Proposition 2, using the intuition of having a relatively high $\gamma$ for low values of $\hat{p}_{i,y_i}$ and a relatively low $\gamma$ for high values of $\hat{p}_{i,y_i}$.
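A minimal sketch of a sample-dependent focal loss in the spirit of FLSD-53. We assume $\gamma = 5$ below the threshold and $\gamma = 3$ above it (matching the "53" in the schedule's name), and we assume a threshold of 0.2 purely for illustration. We also treat each sample's $\gamma$ as a constant computed from the current prediction rather than something to differentiate through; that is our design assumption, and the helper names are hypothetical.

```python
import math
from typing import Sequence


def flsd_gamma(p_true: float, threshold: float = 0.2,
               gamma_low_p: float = 5.0, gamma_high_p: float = 3.0) -> float:
    """Choose gamma from the current true-class probability (assumed FLSD-53-style values)."""
    return gamma_low_p if p_true < threshold else gamma_high_p


def flsd_loss(batch_logits: Sequence[Sequence[float]], labels: Sequence[int],
              **schedule) -> float:
    """Mean focal loss where each sample gets its own gamma.

    gamma is computed from the current prediction and treated as a constant.
    """
    total = 0.0
    for z, y in zip(batch_logits, labels):
        m = max(z)
        log_p = z[y] - (m + math.log(sum(math.exp(v - m) for v in z)))  # log softmax
        p = math.exp(log_p)
        gamma = flsd_gamma(p, **schedule)
        total += -((1.0 - p) ** gamma) * log_p
    return total / len(labels)
```

Hard samples (low true-class probability) get the larger $\gamma$, for which $g > 1$ in that range, so their gradients are amplified, whereas easy samples get the smaller $\gamma$ and are down-weighted.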
We use multiple image and document classification datasets to verify the effectiveness of focal loss for training calibrated models. For our image classification experiments, we use CIFAR-10/100 (Krizhevsky, 2009) and Tiny-ImageNet (Deng et al., 2009), and for document classification, we use 20 Newsgroups (Lang, 1995) and the Stanford Sentiment Treebank (SST) (Socher et al., 2013). We provide details regarding these datasets and their train/validation/test splits in Appendix D. On CIFAR-10/100, we train ResNet-50, ResNet-110 (He et al., 2016), Wide-ResNet-26-10 (Zagoruyko and Komodakis, 2016) and DenseNet-121 (Huang et al., 2017). On Tiny-ImageNet, we train a ResNet-50 network. We train a Global Pooling Convolutional Network (Lin et al., 2014) on 20 Newsgroups (Lang, 1995) and a Tree-LSTM (Tai et al., 2015) on the SST Binary dataset. We provide implementation details for training each of these models in Appendix D. We use the following loss functions for training the above models.
Baselines: Along with cross-entropy loss, we compare our method against the following strong baselines (with and without temperature scaling):
MMCE (Maximum Mean Calibration Error) (Kumar et al., 2018): a continuous and differentiable proxy for calibration error that is normally used as a regulariser alongside cross-entropy.
Label smoothing (Müller et al., 2019) (LS): given a one-hot ground-truth label distribution $q$ and a smoothing factor $\alpha$ (hyperparameter), the smoothed vector $s$ is obtained as $s_i = (1 - \alpha) q_i + \alpha \frac{1 - q_i}{K - 1}$, where $s_i$ and $q_i$ denote the $i^{\text{th}}$ elements of vectors $s$ and $q$ respectively, and $K$ is the number of classes. Instead of $q$, the smoothed vector $s$ is now treated as the ground truth. In our experiments, we train models using two different smoothing factors but find $\alpha = 0.05$ to perform better. We thus report the results obtained from LS-0.05 with $\alpha = 0.05$.
Brier Score (Brier, 1950): computed as the squared error between the predicted softmax vector and the one-hot ground truth encoding. Brier score is a particularly relevant baseline for calibration as it can be decomposed into calibration and refinement terms (DeGroot and Fienberg, 1983; Snoek et al., 2019). Moreover, it explicitly penalises the probabilities assigned to incorrect classes.
Focal Loss (Sample-Dependent $\gamma$): As mentioned in §4, we use the sample-dependent schedule FLSD-53 ($\gamma = 5$ for samples whose $\hat{p}_{i,y_i}$ lies below the chosen threshold, and $\gamma = 3$ otherwise), which we find to consistently perform well across all the classification datasets and network architectures we experiment on.
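For reference, the two non-focal training targets above can be sketched as follows. This is our own illustration with hypothetical function names. We assume the label-smoothing variant written above, which moves mass $\alpha$ from the true class and spreads it evenly over the other $K - 1$ classes (other implementations spread $\alpha / K$ over all classes), and we take the Brier score as the per-sample sum of squared errors averaged over samples (some definitions additionally divide by $K$).

```python
from typing import List, Sequence


def smooth_labels(label: int, num_classes: int, alpha: float) -> List[float]:
    """s_i = (1 - alpha) q_i + alpha (1 - q_i) / (K - 1) for a one-hot q."""
    return [(1 - alpha) if i == label else alpha / (num_classes - 1)
            for i in range(num_classes)]


def brier_score(probs: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Mean over samples of sum_k (p_k - q_k)^2, with one-hot targets q."""
    total = 0.0
    for p, y in zip(probs, labels):
        total += sum((pk - (1.0 if k == y else 0.0)) ** 2 for k, pk in enumerate(p))
    return total / len(labels)


print(smooth_labels(0, 4, 0.05))
print(brier_score([[0.5, 0.5], [1.0, 0.0]], [0, 0]))
```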
In addition to the aforementioned sample-dependent approach, we also train models using focal loss with several fixed values of $\gamma$. As a simplification of the sample-dependent approach, we also try a $\gamma$ schedule that depends on the training epoch. We describe these in more detail and report the results in Appendix E.
Table 1: ECE (%) computed using 15 bins, before (Pre T) and after (Post T) temperature scaling; the optimal temperature is given in brackets.
Dataset | Model | Cross-Entropy Pre T | Cross-Entropy Post T | Brier Loss Pre T | Brier Loss Post T | MMCE Pre T | MMCE Post T | LS-0.05 Pre T | LS-0.05 Post T | FLSD-53 (Ours) Pre T | FLSD-53 (Ours) Post T |
---|---|---|---|---|---|---|---|---|---|---|---|
CIFAR-10 | ResNet 50 | 4.35 | 1.35 (2.5) | 1.82 | 1.08 (1.1) | 4.56 | 1.19 (2.6) | 2.96 | 1.67 (0.9) | 1.55 | 0.95 (1.1) |
CIFAR-10 | ResNet 110 | 4.41 | 1.09 (2.8) | 2.56 | 1.25 (1.2) | 5.08 | 1.42 (2.8) | 2.09 | 2.09 (1) | 1.87 | 1.07 (1.1) |
CIFAR-10 | Wide ResNet 26-10 | 3.23 | 0.92 (2.2) | 1.25 | 1.25 (1) | 3.29 | 0.86 (2.2) | 4.26 | 1.84 (0.8) | 1.56 | 0.84 (0.9) |
CIFAR-10 | DenseNet 121 | 4.52 | 1.31 (2.4) | 1.53 | 1.53 (1) | 5.1 | 1.61 (2.5) | 1.88 | 1.82 (0.9) | 1.22 | 1.22 (1) |
CIFAR-100 | ResNet 50 | 17.52 | 3.42 (2.1) | 6.52 | 3.64 (1.1) | 15.32 | 2.38 (1.8) | 7.81 | 4.01 (1.1) | 4.5 | 2.0 (1.1) |
CIFAR-100 | ResNet 110 | 19.05 | 4.43 (2.3) | 7.88 | 4.65 (1.2) | 19.14 | 3.86 (2.3) | 11.02 | 5.89 (1.1) | 8.56 | 4.12 (1.2) |
CIFAR-100 | Wide ResNet 26-10 | 15.33 | 2.88 (2.2) | 4.31 | 2.7 (1.1) | 13.17 | 4.37 (1.9) | 4.84 | 4.84 (1) | 3.03 | 1.64 (1.1) |
CIFAR-100 | DenseNet 121 | 20.98 | 4.27 (2.3) | 5.17 | 2.29 (1.1) | 19.13 | 3.06 (2.1) | 12.89 | 7.52 (1.2) | 3.73 | 1.31 (1.1) |
Tiny-ImageNet | ResNet 50 | 15.32 | 5.48 (1.4) | 4.44 | 4.13 (0.9) | 13.01 | 5.55 (1.3) | 15.23 | 6.51 (0.7) | 1.76 | 1.76 (1) |
20 Newsgroups | Global Pooling CNN | 17.92 | 2.39 (3.4) | 13.58 | 3.22 (2.3) | 15.48 | 6.78 (2.2) | 4.79 | 2.54 (1.1) | 6.92 | 2.19 (1.5) |
SST Binary | Tree LSTM | 7.37 | 2.62 (1.8) | 9.01 | 2.79 (2.5) | 5.03 | 4.02 (1.5) | 4.84 | 4.11 (1.2) | 9.19 | 1.83 (0.7) |
Table 2: AdaECE (%) computed using 15 bins, before (Pre T) and after (Post T) temperature scaling; the optimal temperature is given in brackets.
Dataset | Model | Cross-Entropy Pre T | Cross-Entropy Post T | Brier Loss Pre T | Brier Loss Post T | MMCE Pre T | MMCE Post T | LS-0.05 Pre T | LS-0.05 Post T | FLSD-53 (Ours) Pre T | FLSD-53 (Ours) Post T |
---|---|---|---|---|---|---|---|---|---|---|---|
CIFAR-10 | ResNet 50 | 4.33 | 2.14 (2.5) | 1.74 | 1.23 (1.1) | 4.55 | 2.16 (2.6) | 3.89 | 2.92 (0.9) | 1.56 | 1.26 (1.1) |
CIFAR-10 | ResNet 110 | 4.4 | 1.99 (2.8) | 2.6 | 1.7 (1.2) | 5.06 | 2.52 (2.8) | 4.44 | 4.44 (1) | 2.07 | 1.67 (1.1) |
CIFAR-10 | Wide ResNet 26-10 | 3.23 | 1.69 (2.2) | 1.7 | 1.7 (1) | 3.29 | 1.6 (2.2) | 4.27 | 2.44 (0.8) | 1.52 | 1.38 (0.9) |
CIFAR-10 | DenseNet 121 | 4.51 | 2.13 (2.4) | 2.03 | 2.03 (1) | 5.1 | 2.29 (2.5) | 4.42 | 3.33 (0.9) | 1.42 | 1.42 (1) |
CIFAR-100 | ResNet 50 | 17.52 | 3.42 (2.1) | 6.52 | 3.64 (1.1) | 15.32 | 2.38 (1.8) | 7.81 | 4.01 (1.1) | 4.5 | 2.0 (1.1) |
CIFAR-100 | ResNet 110 | 19.05 | 5.86 (2.3) | 7.73 | 4.53 (1.2) | 19.14 | 4.85 (2.3) | 11.12 | 8.59 (1.1) | 8.55 | 3.96 (1.2) |
CIFAR-100 | Wide ResNet 26-10 | 15.33 | 2.89 (2.2) | 4.22 | 2.81 (1.1) | 13.16 | 4.25 (1.9) | 5.1 | 5.1 (1) | 2.75 | 1.63 (1.1) |
CIFAR-100 | DenseNet 121 | 20.98 | 5.09 (2.3) | 5.04 | 2.56 (1.1) | 19.13 | 3.07 (2.1) | 12.83 | 8.92 (1.2) | 3.55 | 1.24 (1.1) |
Tiny-ImageNet | ResNet 50 | 15.23 | 5.41 (1.4) | 4.37 | 4.07 (0.9) | 13.0 | 5.56 (1.3) | 15.28 | 6.29 (0.7) | 1.42 | 1.42 (1) |
20 Newsgroups | Global Pooling CNN | 17.91 | 2.23 (3.4) | 13.57 | 3.11 (2.3) | 15.21 | 6.47 (2.2) | 4.39 | 2.63 (1.1) | 6.92 | 2.35 (1.5) |
SST Binary | Tree LSTM | 7.27 | 3.39 (1.8) | 8.12 | 2.84 (2.5) | 5.01 | 4.32 (1.5) | 5.14 | 4.23 (1.2) | 9.15 | 1.92 (0.7) |
Temperature Scaling: We compute the optimal temperature using two different approaches: (a) learning the optimal temperature by optimising NLL over a validation set, and (b) performing grid search over temperature values between 0 and 10, with a step of 0.1, and choosing the temperature that minimises the validation set ECE. We find the second approach to produce stronger baselines. Since we report ECE and AdaECE as the performance metrics and grid search does not require a differentiable objective function, we directly minimise ECE over the validation set during grid search.
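Approach (b) can be sketched as follows. This is our own minimal version, not the released code: the names are hypothetical, the ECE helper re-implements Equation (1) with 15 equal-width bins, and the grid starts at 0.1 rather than 0 because $T = 0$ would divide the logits by zero. Since dividing logits by a positive $T$ never changes the argmax, accuracy is unaffected.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Temperature-scaled softmax: softmax(z / T)."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def ece(probs: Sequence[List[float]], labels: Sequence[int], n_bins: int = 15) -> float:
    """Equal-width-bin ECE, as in Equation (1)."""
    bins = [[] for _ in range(n_bins)]
    for p, y in zip(probs, labels):
        conf = max(p)
        idx = min(max(math.ceil(conf * n_bins) - 1, 0), n_bins - 1)
        bins[idx].append((conf, p.index(conf) == y))
    n = len(labels)
    return sum(len(b) / n * abs(sum(ok for _, ok in b) / len(b) - sum(c for c, _ in b) / len(b))
               for b in bins if b)


def grid_search_temperature(val_logits, val_labels, n_bins: int = 15) -> float:
    """Return the T in {0.1, 0.2, ..., 10.0} that minimises validation-set ECE."""
    grid = [round(0.1 * k, 1) for k in range(1, 101)]  # T = 0 excluded (division by zero)
    return min(grid, key=lambda t: ece([softmax(z, t) for z in val_logits],
                                       val_labels, n_bins))
```

As a usage example, a hypothetical validation set on which every prediction has logits `[4.0, 0.0]` but only half are correct is heavily overconfident, and the search pushes $T$ to the top of the grid.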
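Table 3: Test set ECE (%) and BLEU scores of the three Transformer models on WMT 2014 English-to-German, before and after temperature scaling.
Metrics | CE | LS | FL |
---|---|---|---|
ECE Pre T / Post T / T | | | |
BLEU Pre T / Post T | | | |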
Performance Gains: We report the optimal temperatures and their corresponding ECE and AdaECE (computed using 15 bins) in Tables 1 and 2. Full results (ECE, AdaECE, MCE, NLL and test error) for all approaches are reported in Appendix E.
Firstly, for all dataset-network pairs, we obtain state-of-the-art classification accuracies (shown in Table 4 in the appendix). It is clear from Tables 1 and 2 that, in most cases, focal loss with sample-dependent $\gamma$ outperforms all the baselines: cross-entropy, label smoothing, Brier score and MMCE. It broadly produces the lowest ECE and AdaECE values both before and after temperature scaling. This observation is particularly encouraging, as it indicates that a principled method of obtaining values of $\gamma$ for focal loss can work well. Furthermore, Tables 5 and 6 in the appendix show that other focal loss based approaches are also very competitive. Finally, we observe that there are cases where ECE might be low, implying that the model is well calibrated, whereas AdaECE evaluated on the same model might be high. For example, in the case of Wide ResNet 26-10 on CIFAR-10 for cross-entropy, the best ECE obtained after temperature scaling is 0.92, whereas AdaECE on the same model at the same temperature is 1.69.
Confident and Calibrated Models: It is interesting to note that for focal loss with sample-dependent γ (see Tables 1 and 2), the optimal temperatures are very close to 1, mostly lying between 0.9 and 1.1. The Brier score and label smoothing models show this property as well. By contrast, the optimal temperatures for the baselines (cross-entropy with hard targets and MMCE) are significantly higher, lying between 2.0 and 2.8. An optimal temperature close to 1 indicates that the model is innately calibrated and cannot be made significantly more calibrated by temperature scaling. Furthermore, an optimal temperature much greater than 1 can make the network underconfident in general, as its outputs are temperature-scaled irrespective of their correctness. We provide additional experimental and qualitative results to support this claim in Appendix F.
To evaluate focal loss on a downstream task in which a calibrated model can potentially improve task performance, we conduct an experiment on machine translation with beam search. Following the setup of Müller et al. (2019), we train the Transformer architecture (Vaswani et al., 2017) on the standard WMT 2014 English-to-German translation dataset. The training settings (optimiser, learning rate schedule, number of training iterations, etc.) are exactly the same as those in Vaswani et al. (2017). We choose machine translation as the downstream task because the softmax vectors produced by the Transformer are fed directly into the beam search algorithm; softmax vectors from a better-calibrated model should therefore, ideally, produce better translations and a higher BLEU score.
We train three transformer models, one on cross-entropy with hard target labels, the second on cross-entropy with label smoothing (with smoothing factor ) and the third on focal loss with . In order to compare these models in terms of calibration, we report the test set ECE (%) both before and after temperature scaling in the first row of Table 3. Furthermore, to evaluate their performance on the English-to-German translation task, we also report the test set BLEU score of these models in the second row of Table 3. Finally, to study the variation of test set BLEU score and validation set ECE with temperature, we plot them against temperature for all three models in Figure 5.
Table 3 shows that the model trained with focal loss outperforms its competitors on both ECE and BLEU score. Like the model trained with label-smoothed cross-entropy, the focal loss model has an optimal temperature of 1. Figure 5 shows that the models obtain their highest BLEU scores at around the same temperatures at which they obtain low ECEs, supporting our intuition that a better-calibrated model provides better translations. However, since the optimal temperatures are tuned on the validation set, they often do not correspond to the best BLEU scores on the test set. On the test set, the highest BLEU scores we observe are 26.33 for cross-entropy, 26.36 for cross-entropy with label smoothing, and 26.39 for focal loss. Thus, focal loss obtains both the lowest ECE and the highest BLEU.
In this paper, we have shown that training using focal loss can yield multi-class classification networks that are more naturally calibrated than those trained using the more conventional cross-entropy loss. There are sound theoretical reasons to expect this: in particular, as we show in §4, focal loss implicitly maximises entropy while minimising the KL divergence between the predicted and the target distributions. We also show that, because of its design, it naturally regularises the weights of a network during training, reducing NLL overfitting and thereby improving calibration. Extensive experiments on a variety of computer vision (CIFAR-10/100/Tiny-ImageNet) and NLP (20 Newsgroups/SST) datasets, with a wide variety of different network architectures, show that this expectation is also borne out in practice. Our results show that in almost all cases, networks trained with focal loss are more calibrated than those trained with cross-entropy loss, label smoothing, Brier score and MMCE, whilst having similar levels of accuracy, making their predictions much easier for downstream components to trust. Finally, we verify this by showing the superior performance of focal loss on the downstream task of English-to-German translation.
This work was started whilst J Mukhoti was at FiveAI, and completed after he moved to the University of Oxford. V Kulharia is wholly funded by a Toyota Research Institute grant. A Sanyal acknowledges support from The Alan Turing Institute under the Turing Doctoral Studentship grant TU/C/000023. This work was also supported by ERC grant ERC-2012-AdG 321162-HELIOS, EPSRC grant Seebibyte EP/M013774/1, EPSRC/MURI grant EP/N019474/1, and the Royal Academy of Engineering.
2009 IEEE Conference on Computer Vision and Pattern Recognition, pages 248–255. IEEE, 2009.
Predicting Good Probabilities With Supervised Learning. In ICML, 2005.
Empirical Methods in Natural Language Processing (EMNLP), 2014.
Probabilistic Outputs for Support Vector Machines and Comparisons to Regularized Likelihood Methods. Advances in Large Margin Classifiers, 10(3):61–74, 1999.
Improved semantic representations from tree-structured long short-term memory networks. In Association for Computational Linguistics (ACL), 2015.
Obtaining calibrated probability estimates from decision trees and naive Bayesian classifiers. In ICML, 2001.
In § A, we discuss the relation between focal loss and a regularised KL divergence where the regulariser is the entropy of the predicted distribution. In § B, we discuss the regularisation effect of focal loss on a simple setup, i.e., a generalised linear model trained on a simple data distribution. In § C, we show the proofs of the two propositions formulated in the main text. We then describe all the datasets and implementation details for our experiments in § D. In § E, we discuss additional approaches for training using focal loss and also the results we get from these approaches. We also provide Top-5 accuracies of several models to possibly hint at their calibration. We further provide results on evaluating our models using various metrics other than ECE and Ada-ECE (like MCE and NLL). In § F, we provide empirical and qualitative results to show that models trained using focal loss are calibrated while maintaining their confidence on correct predictions. Finally, in § G, we provide a brief extension of our discussion of Figure 3(e) in the main paper, with a plot of the norms of features obtained from the last ResNet block during training.
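Here we show why focal loss favours accurate but relatively less confident solutions. We show that it inherently provides a trade-off between minimising the KL divergence and maximising the entropy, depending on the strength of γ. We use $\mathcal{L}_f^\gamma(p, q)$ and $\mathcal{L}_{CE}(p, q)$ to denote the focal loss with parameter γ and the cross-entropy between $p$ and $q$, respectively. $K$ denotes the number of classes and $q_y$ denotes the ground-truth probability assigned to the $y$-th class (similarly $p_y$ for the predicted distribution). We consider the following simple extension of focal loss:
$$
\begin{aligned}
\mathcal{L}_f^\gamma(p, q) &= -\sum_{y=1}^{K} (1 - p_y)^\gamma \, q_y \log p_y \\
&\ge -\sum_{y=1}^{K} (1 - \gamma p_y) \, q_y \log p_y && \text{by Bernoulli's inequality, } \forall \gamma \ge 1 \text{ (note that } p_y \in [0, 1]\text{)} \\
&= -\sum_{y=1}^{K} q_y \log p_y - \gamma \left| \sum_{y=1}^{K} q_y \, p_y \log p_y \right| && \forall \gamma \ge 1 \\
&\ge \mathcal{L}_{CE}(p, q) - \gamma \left( \max_j q_j \right) \sum_{y=1}^{K} \left| p_y \log p_y \right| && \text{by Hölder's inequality } \|fg\|_1 \le \|f\|_\infty \|g\|_1 \\
&\ge \mathcal{L}_{CE}(p, q) - \gamma \sum_{y=1}^{K} \left| p_y \log p_y \right| && \because q_j \in [0, 1] \\
&= \mathcal{L}_{CE}(p, q) - \gamma \, \mathbb{H}[p] && \because p_y \log p_y \le 0
\end{aligned}
$$
We know that $\mathcal{L}_{CE}(p, q) = \mathrm{KL}(q \,\|\, p) + \mathbb{H}[q]$; combining this equality with the above inequality leads to:
$$
\mathcal{L}_f^\gamma(p, q) \ge \mathrm{KL}(q \,\|\, p) + \mathbb{H}[q] - \gamma \, \mathbb{H}[p] \qquad (5)
$$
In the case of one-hot encoding (a delta distribution for $q$), focal loss would maximise $-p_y \log p_y$ (let $y$ be the ground-truth class index), the component of the entropy of $p$ corresponding to the ground-truth index. Thus, focal loss would prefer learning $p$ such that $p_y$ is assigned a higher value (because of the KL term) but not too high (because of the entropy term), which ultimately avoids favouring overconfident models (as opposed to the cross-entropy loss). Experimentally, we found the solution of the cross-entropy and focal loss equations, i.e. the value of the predicted probability $p$ which minimises the loss, for various values of $q$ in a binary classification problem (i.e. $K = 2$), and plotted it in Figure 6. As expected, focal loss favours a more entropic solution $p$ that is closer to 0.5. In other words, as Figure 6 shows, solutions to focal loss (Eqn 5) have higher entropy than those of cross-entropy, with the difference depending on the value of γ.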
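The binary experiment behind Figure 6 can be approximated with a dense grid search. In this sketch, which is an illustration rather than the authors' code, the loss for K = 2 is −q(1 − p)^γ log p − (1 − q) p^γ log(1 − p), with prediction (p, 1 − p) and target (q, 1 − q); setting γ = 0 recovers cross-entropy. The choice γ = 3, the grid resolution and the function names are assumptions.

```python
import numpy as np

def binary_soft_target_loss(p, q, gamma):
    """Focal-loss extension for K = 2: prediction (p, 1 - p), target (q, 1 - q).

    gamma = 0 recovers cross-entropy.
    """
    return -(q * (1.0 - p) ** gamma * np.log(p) + (1.0 - q) * p ** gamma * np.log(1.0 - p))

# Dense grid of candidate predictions, avoiding log(0) at the endpoints.
P_GRID = np.linspace(1e-4, 1.0 - 1e-4, 100_001)

def loss_minimiser(q, gamma):
    """Grid-search the predicted probability p that minimises the loss for target q."""
    return float(P_GRID[np.argmin(binary_soft_target_loss(P_GRID, q, gamma))])

for q in (0.6, 0.8, 0.95):
    print(f"q={q}: CE minimiser={loss_minimiser(q, 0.0):.3f}, "
          f"FL (gamma=3) minimiser={loss_minimiser(q, 3.0):.3f}")
```

Cross-entropy is minimised at p = q, while the focal-loss minimiser is pulled towards 0.5, i.e. towards a higher-entropy prediction.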
The behaviour of deep neural networks is generally quite different from that of linear models, and the problem of calibration is more pronounced for deep neural networks; hence, we focus on analysing the calibration of deep networks in the paper. However, to understand the effect of focal loss in a simpler setup, we also conducted experiments on a generalised linear model using a simple data distribution.
We consider a binary classification problem. The data matrix is created by assigning to each class two normally distributed clusters such that the cluster means are linearly separable. The cluster means are situated on the vertices of a two-dimensional hypercube of side length 4. The standard deviation for each cluster is and the samples are randomly linearly combined within each cluster in order to add covariance. Further, for of the data points, the labels were flipped. samples are used for training and samples are used for testing. The model is a simple 2-parameter logistic regression model. We train this model using both cross-entropy and focal loss with .
We have argued that focal loss implicitly regularises the weights of the model by providing smaller gradients than cross-entropy. This helps calibration: if all the weights are large, the logits are large and thus the confidence of the network is high on all test points, even on misclassified ones, so when the model misclassifies, it does so with high confidence. Figure 7 shows, for a generalised linear model, that the norms of the logits and of the weights blow up for cross-entropy compared with focal loss.
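The toy experiment can be sketched in NumPy as follows. This is a simplified stand-in, not the authors' setup. The cluster layout, the standard deviation of 0.4, the 100 points per cluster, the absence of label noise, γ = 3, the learning rate and the number of steps are all assumptions, since those values are not recoverable here. The helper names and the notation p_t (probability of the true class) are hypothetical. The gradient with respect to the logit z is that of −(1 − p_t)^γ log p_t, which reduces to the familiar p − y for γ = 0.

```python
import numpy as np

def make_data(n_per_cluster=100, std=0.4, seed=0):
    """Two Gaussian clusters per class on the corners of a square of side 4.

    Assumed layout: class 1 on the right-hand corners, class 0 on the left-hand ones,
    so a line through the origin separates the classes (no label noise here).
    """
    rng = np.random.default_rng(seed)
    centres = {1: [(2.0, 2.0), (2.0, -2.0)], 0: [(-2.0, 2.0), (-2.0, -2.0)]}
    xs, ys = [], []
    for label, label_centres in centres.items():
        for c in label_centres:
            xs.append(rng.normal(loc=c, scale=std, size=(n_per_cluster, 2)))
            ys.append(np.full(n_per_cluster, label))
    return np.vstack(xs), np.concatenate(ys)

def loss_grad_wrt_logit(z, y, gamma):
    """d/dz of the binary focal loss -(1 - p_t)^gamma * log(p_t); gamma = 0 is cross-entropy."""
    p = 1.0 / (1.0 + np.exp(-z))                               # P(y = 1 | x)
    p_t = np.clip(np.where(y == 1, p, 1.0 - p), 1e-12, 1.0)    # probability of the true class
    sign = np.where(y == 1, 1.0, -1.0)                         # dp_t/dz = sign * p_t * (1 - p_t)
    return sign * (gamma * p_t * (1.0 - p_t) ** gamma * np.log(p_t) - (1.0 - p_t) ** (gamma + 1))

def train(X, y, gamma, lr=0.1, steps=2000):
    """Full-batch gradient descent on a 2-parameter (bias-free) logistic regression model."""
    w = np.zeros(X.shape[1])
    for _ in range(steps):
        w -= lr * X.T @ loss_grad_wrt_logit(X @ w, y, gamma) / len(y)
    return w

X, y = make_data()
w_ce = train(X, y, gamma=0.0)
w_fl = train(X, y, gamma=3.0)
print("||w|| CE:", np.linalg.norm(w_ce), " FL:", np.linalg.norm(w_fl))
```

On separable data like this, cross-entropy keeps increasing the weight norm, while the (1 − p_t)^γ factor makes the focal-loss gradients on well-classified points vanish much faster. Both weight vectors point in essentially the same direction, but the focal-loss one should be noticeably shorter after the same number of steps.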
Figures 8(b) and (c) show that running gradient descent with cross-entropy (CE) and focal loss (FL) gives the same decision regions, i.e. the weight vector points in the same direction for both FL and CE. However, since the norm of the weights is much larger for CE than for FL, we would expect misclassified test points to receive higher confidence under CE than under FL. Figure 8(a) plots a histogram of the confidence of the misclassified points; it shows that CE almost always misclassifies with greater than confidence, whereas FL misclassifies with much lower confidence.
Here we provide the proofs of both propositions presented in the main text. While Proposition 1 helps us understand the regularization effect of focal loss, Proposition 2 provides a principled, sample-dependent way of choosing γ. Implementing the sample-dependent γ is straightforward, as implementations of the Lambert-W function (Corless et al., 1996) are available in standard libraries (e.g. scipy.special.lambertw in Python's SciPy).
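For focal loss $\mathcal{L}_f$ and cross-entropy $\mathcal{L}_c$, the gradients $\frac{\partial \mathcal{L}_f}{\partial \mathbf{w}} = \frac{\partial \mathcal{L}_c}{\partial \mathbf{w}} \, g(p, \gamma)$, where $g(p, \gamma) = (1-p)^\gamma - \gamma p (1-p)^{\gamma-1} \log p$, $\gamma \in \mathbb{R}^+$ is the focal loss hyperparameter, $p$ is the predicted probability of the ground-truth class, and $\mathbf{w}$ denotes the parameters of the last linear layer. Thus $\left\| \frac{\partial \mathcal{L}_f}{\partial \mathbf{w}} \right\| \le \left\| \frac{\partial \mathcal{L}_c}{\partial \mathbf{w}} \right\|$ if $g(p, \gamma) \in [0, 1]$.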
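Let $\mathbf{w}_k$ be the linear layer parameters connecting the feature map $\mathbf{f}$ to the logit $s_k = \mathbf{w}_k^\top \mathbf{f}$. Then, using the chain rule,
$$\frac{\partial \mathcal{L}_f}{\partial \mathbf{w}_k} = \frac{\partial \mathcal{L}_f}{\partial p} \, \frac{\partial p}{\partial s_k} \, \frac{\partial s_k}{\partial \mathbf{w}_k} = \frac{\partial \mathcal{L}_f}{\partial p} \, \frac{\partial p}{\partial s_k} \, \mathbf{f}.$$
Similarly, $\frac{\partial \mathcal{L}_c}{\partial \mathbf{w}_k} = \frac{\partial \mathcal{L}_c}{\partial p} \, \frac{\partial p}{\partial s_k} \, \mathbf{f}$. The derivative of the focal loss $\mathcal{L}_f = -(1-p)^\gamma \log p$ with respect to $p$, the softmax output of the network for the true class $y$, takes the form
$$\frac{\partial \mathcal{L}_f}{\partial p} = -\frac{1}{p} \, g(p, \gamma) = g(p, \gamma) \, \frac{\partial \mathcal{L}_c}{\partial p},$$
in which $g(p, \gamma) = (1-p)^\gamma - \gamma p (1-p)^{\gamma-1} \log p$ and $\frac{\partial \mathcal{L}_c}{\partial p} = -\frac{1}{p}$. It is thus straightforward to verify that if $g(p, \gamma) \in [0, 1]$, then $\left| \frac{\partial \mathcal{L}_f}{\partial p} \right| \le \left| \frac{\partial \mathcal{L}_c}{\partial p} \right|$, which itself implies that $\left\| \frac{\partial \mathcal{L}_f}{\partial \mathbf{w}_k} \right\| \le \left\| \frac{\partial \mathcal{L}_c}{\partial \mathbf{w}_k} \right\|$. ∎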
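Proposition 1 can be sanity-checked numerically. By the chain rule, the ratio of the weight gradients equals the ratio of the derivatives with respect to p, so it is enough to compare derivatives in p. The sketch below assumes γ = 3 and uses central finite differences (the step size h is an arbitrary choice; the function names are illustrative). It compares the ratio of the focal-loss and cross-entropy derivatives against g(p, γ). It also shows that g(p, γ) exceeds 1 for small p, which is why Proposition 2 restricts attention to p ≥ p0.

```python
import math

def focal_loss(p, gamma):
    """Focal loss as a function of the ground-truth class probability p."""
    return -((1.0 - p) ** gamma) * math.log(p)

def cross_entropy(p):
    return -math.log(p)

def g(p, gamma):
    """Gradient ratio of Proposition 1: dL_f/dp = g(p, gamma) * dL_c/dp."""
    return (1.0 - p) ** gamma - gamma * p * (1.0 - p) ** (gamma - 1) * math.log(p)

def central_diff(f, x, h=1e-6):
    """Second-order accurate numerical derivative of f at x."""
    return (f(x + h) - f(x - h)) / (2.0 * h)

gamma = 3.0
for p in (0.1, 0.3, 0.5, 0.9):
    ratio = central_diff(lambda q: focal_loss(q, gamma), p) / central_diff(cross_entropy, p)
    print(f"p={p:.1f}  finite-difference ratio={ratio:.6f}  g(p, gamma)={g(p, gamma):.6f}")
```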
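Given a $p_0$, for $1 \ge p \ge p_0 > 0$, $g(p, \gamma) \le 1$ for all $\gamma \ge \gamma^* \ge 0$, where $\gamma^* = \frac{a}{b} + \frac{1}{\log a} W_{-1}\!\left( -\frac{a^{(1 - a/b)}}{b} \log a \right)$, $a = 1 - p_0$, $b = p_0 \log p_0$, and $W_{-1}$ is the Lambert-W function (Corless et al., 1996). Moreover, for $p \ge p_0 > 0$ and $\gamma \ge \gamma^*$, the equality $g(p, \gamma) = 1$ holds only for $p = p_0$ and $\gamma = \gamma^*$.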
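We derive the value of $\gamma > 0$ for which $g(p, \gamma) \le 1$ for a given $p \in [p_0, 1]$. From Proposition 4.1, we already know that
$$\frac{\partial \mathcal{L}_f}{\partial \mathbf{w}} = \frac{\partial \mathcal{L}_c}{\partial \mathbf{w}} \, g(p, \gamma) \qquad (6)$$
where $\mathcal{L}_f$ is focal loss, $\mathcal{L}_c$ is cross-entropy loss, $p$ is the probability assigned by the model to the ground-truth correct class for the sample, and
$$g(p, \gamma) = (1-p)^\gamma - \gamma p (1-p)^{\gamma-1} \log p \qquad (7)$$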
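For $\gamma > 0$, if we look at the function $g(p, \gamma)$, we can clearly see that $g(p, \gamma) \to 1$ as $p \to 0$, and that $g(p, \gamma) \to 0$ as $p \to 1$. To observe the behaviour of $g(p, \gamma)$ for intermediate values of $p$, we first take its derivative with respect to $p$:
$$\frac{\partial g(p, \gamma)}{\partial p} = \gamma (1-p)^{\gamma-2} \left[ (\gamma p - 1) \log p - 2(1-p) \right] \qquad (8)$$
In Equation 8, $\gamma (1-p)^{\gamma-2} > 0$ for all $p \in (0, 1)$, i.e. everywhere except at the endpoint $p = 1$. Thus, to observe the sign of the gradient $\frac{\partial g(p, \gamma)}{\partial p}$, we focus on the term
$$(\gamma p - 1) \log p - 2(1-p) \qquad (9)$$
Dividing Equation 9 by $(1-p) > 0$, the sign remains unchanged and we get
$$\frac{(\gamma p - 1) \log p}{1-p} - 2 \qquad (10)$$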
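We can see that the expression in Equation 10 tends to $+\infty$ as $p \to 0$ and to $-(\gamma + 1)$ as $p \to 1$ (using l'Hôpital's rule, $\log p / (1-p) \to -1$). Furthermore, it is monotonically decreasing for $p \in (0, 1)$. Thus, since the sign of the gradient $\frac{\partial g(p, \gamma)}{\partial p}$ changes only once, from positive near $p = 0$ to negative near $p = 1$, $g(p, \gamma)$ first monotonically increases starting from 1 (as $p \to 0$) and then monotonically decreases down to 0 (as $p \to 1$). Thus, if $g(p_0, \gamma) \le 1$ for some threshold $p_0 \in (0, 1)$ and for some $\gamma > 0$, then $g(p, \gamma) \le 1$, $\forall p \in [p_0, 1]$. We now want to find a $\gamma^*$ such that $g(p_0, \gamma) \le 1$, $\forall \gamma \ge \gamma^*$. First, let $a = 1 - p_0$ and $b = p_0 \log p_0$, so that $a \in (0, 1)$, $\log a < 0$ and $b < 0$. Then:
$$
\begin{aligned}
g(p_0, \gamma) = a^{\gamma} - \gamma \, b \, a^{\gamma-1} \le 1
&\iff \left( \gamma - \tfrac{a}{b} \right) a^{\left( \gamma - \frac{a}{b} \right)} \le -\frac{a^{(1 - a/b)}}{b} \\
&\iff \left( \gamma - \tfrac{a}{b} \right) \log a \; e^{\left( \gamma - \frac{a}{b} \right) \log a} \ge -\frac{a^{(1 - a/b)}}{b} \log a
\end{aligned}
\qquad (11)
$$
where the left-hand side has the form $x e^{x}$ with $x = \left( \gamma - \frac{a}{b} \right) \log a$. We know that the inverse of $f(x) = x e^{x}$ is given by $f^{-1}(y) = W(y)$, where $W$ is the Lambert-W function (Corless et al., 1996). Furthermore, the r.h.s. of the inequality in Equation 11 is always negative, with a minimum possible value of $-1/e$ that occurs at $p_0 = 0.5$. Therefore, applying the Lambert-W function to the r.h.s. will yield two real solutions (corresponding to a principal branch denoted by $W_0$ and a negative branch denoted by $W_{-1}$). We first consider the solution corresponding to the negative branch (which is the smaller of the two solutions):
$$\gamma^* = \frac{a}{b} + \frac{1}{\log a} W_{-1}\!\left( -\frac{a^{(1 - a/b)}}{b} \log a \right) \qquad (12)$$
If we consider the principal branch, the solution is
$$\gamma = \frac{a}{b} + \frac{1}{\log a} W_{0}\!\left( -\frac{a^{(1 - a/b)}}{b} \log a \right) \qquad (13)$$
which yields a non-positive value for $\gamma$ that we discard. Thus Equation 12 gives the threshold $\gamma^*$ such that if $\gamma \ge \gamma^*$, then $g(p_0, \gamma) \le 1$. In other words, $\forall \gamma \ge \gamma^*$, $g(p_0, \gamma) \le 1$, and hence $g(p, \gamma) \le 1$ for any $p \ge p_0$. ∎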
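Equation 12 is straightforward to evaluate with SciPy's scipy.special.lambertw, whose k=-1 argument selects the negative branch W_{-1}. The sketch below is a minimal illustration under the assumption 0 < p0 < 0.5. The function names are hypothetical, and the real part is taken because lambertw returns a complex number.

```python
import numpy as np
from scipy.special import lambertw

def g(p, gamma):
    """g(p, gamma) from Equation 7."""
    return (1.0 - p) ** gamma - gamma * p * (1.0 - p) ** (gamma - 1) * np.log(p)

def gamma_star(p0):
    """gamma* from Equation 12 (negative branch W_{-1}); intended for 0 < p0 < 0.5."""
    a = 1.0 - p0
    b = p0 * np.log(p0)
    arg = -(a ** (1.0 - a / b)) / b * np.log(a)   # r.h.s. of Equation 11, in [-1/e, 0)
    w = lambertw(arg, k=-1).real                  # lower real branch of Lambert W
    return float(a / b + w / np.log(a))

for p0 in (0.1, 0.2, 0.25):
    gs = gamma_star(p0)
    print(f"p0={p0}: gamma*={gs:.3f}, g(p0, gamma*)={g(p0, gs):.6f}")
```

For p0 = 0.2 the threshold should come out a little below 5 (about 4.85), and g(p0, γ*) should evaluate to 1 up to the solver's tolerance, consistent with the equality case of the proposition.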
We use the following image and document classification datasets in our experiments:
CIFAR-10 (Krizhevsky, 2009): This dataset has 60,000 colour images of size 32 × 32, divided equally into 10 classes. We use a train/validation/test split of 45,000/5,000/10,000 images.
CIFAR-100 (Krizhevsky, 2009): This dataset has 60,000 colour images of size 32 × 32, divided equally into 100 classes. (Note that the images in this dataset are not the same images as in CIFAR-10.) We again use a train/validation/test split of 45,000/5,000/10,000 images.
Tiny-ImageNet (Deng et al., 2009): Tiny-ImageNet is a subset of ImageNet with 64 × 64 images, 200 classes, 500 images per class in the training set and 50 images per class in the validation set. The image dimensions of Tiny-ImageNet are twice those of CIFAR-10/100 images.
20 Newsgroups (Lang, 1995): This dataset contains 20,000 news articles, categorised evenly into 20 different newsgroups based on their content. It is a popular dataset for text classification. Whilst some of the newsgroups are very related (e.g. rec.motorcycles and rec.autos), others are quite unrelated (e.g. sci.space and misc.forsale). We use a train/validation/test split of 15,098/900/3,999 documents.
Stanford Sentiment Treebank (SST) (Socher et al., 2013): This dataset contains movie reviews in the form of sentence parse trees, where each node is annotated by sentiment. We use the dataset version with binary labels, for which 6,920/872/1,821 documents are used as the training/validation/test split. In the training set, each node of a parse tree is annotated as positive, neutral or negative. At test time, the evaluation is done based on the model classification at the root node, i.e. considering the whole sentence, which contains only positive or negative sentiment.
For training networks on CIFAR-10 and CIFAR-100, we use SGD with a momentum of 0.9 as our optimiser and train the networks for 350 epochs, with a learning rate of 0.1 for the first 150 epochs, 0.01 for the next 100 epochs, and 0.001 for the last 100 epochs. We use a training batch size of 128. Furthermore, we augment the training images by applying random crops and random horizontal flips. For Tiny-ImageNet, we train for 100 epochs with a learning rate of 0.1 for the first 40 epochs, 0.01 for the next 20 epochs and 0.001 for the last 40 epochs. We use a training batch size of 64. Note that for Tiny-ImageNet, we held out 50 samples per class (i.e., a total of 10,000 samples) from the training set as our own validation set for tuning the temperature parameter (hence, we trained on 90,000 images), and we use the Tiny-ImageNet validation set as our test set.
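The two step schedules above can be written as a small, framework-agnostic helper. The function name and the (boundaries, rates) encoding are illustrative assumptions, not taken from the authors' code.

```python
def step_lr(epoch, boundaries, rates):
    """Piecewise-constant learning rate: rates[i] is used while epoch < boundaries[i]."""
    for boundary, rate in zip(boundaries, rates):
        if epoch < boundary:
            return rate
    return rates[-1]  # past the last boundary, keep the final rate

# Schedules described above, as (epoch boundaries, learning rates).
CIFAR_SCHEDULE = ([150, 250, 350], [0.1, 0.01, 0.001])
TINY_IMAGENET_SCHEDULE = ([40, 60, 100], [0.1, 0.01, 0.001])

for epoch in (0, 149, 150, 249, 250, 349):
    print(epoch, step_lr(epoch, *CIFAR_SCHEDULE))
```

In PyTorch, the CIFAR schedule should correspond to torch.optim.SGD(params, lr=0.1, momentum=0.9) combined with torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[150, 250], gamma=0.1), stepped once per epoch (up to floating-point rounding). The scheduler's gamma is a decay factor and is unrelated to the focal loss γ.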
For 20 Newsgroups, we train the Global Pooling Convolutional Network (Lin et al., 2014) using the Adam optimiser, with learning rate , and betas and . The code is a PyTorch adaptation of Ng . We used GloVe word embeddings (Pennington et al., 2014) to train the network. We trained all the models for 50 epochs and used the models with the best validation accuracy.
For the SST Binary dataset, we train the Tree-LSTM (Tai et al., 2015) using the AdaGrad optimiser with a learning rate of and a weight decay of , as suggested by the authors. We used the constituency model, which considers binary parse trees of the data and trains a binary Tree-LSTM on them. The GloVe word embeddings (Pennington et al., 2014) were also tuned for best results. The code framework we used is inspired by TreeLSTM . We trained these models for 25 epochs and used the models with the best validation accuracy.
For all our models, we use the PyTorch framework, setting any hyperparameters not explicitly mentioned to the default values used in the standard models. For MMCE, we used for all the image-classification tasks, whilst we found to perform better for document classification. A calibrated model which does not generalise well to an unseen test set is not very useful. Hence, for all the experiments, we set the training parameters in a way such that we get state-of-the-art test set accuracies on all datasets for each model.
In addition to the sample-dependent γ approach, we try the following focal loss approaches as well:
Focal Loss (Fixed γ): We trained models on focal loss with γ fixed to 1, 2 and 3. We found γ = 3 to produce the best ECE among models trained using a fixed γ. This corroborates the observation we made in §4 of the main paper that γ = 3 should produce better results than γ = 1 or γ = 2, as the regularising effect for γ = 3 is higher.
Focal Loss (Scheduled γ): As a simplification of the sample-dependent γ approach, we also tried using a schedule for γ during training, as we expect the predicted probability of the ground-truth class to increase in general for all samples over time. In particular, we report results for two different schedules: (a) Focal Loss (scheduled 5,3,2): γ = 5 for the first 100 epochs, γ = 3 for the next 150 epochs, and γ = 2 for the last 100 epochs, and (b) Focal Loss (scheduled 5,3,1): γ = 5 for the first 100 epochs, γ = 3 for the next 150 epochs, and γ = 1 for the last 100 epochs. We also tried various other schedules, but found these two to produce the best results on the validation sets.
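A minimal NumPy sketch of the scheduled-γ variant follows: a multi-class focal loss plus a piecewise-constant γ schedule. The defaults encode Focal Loss (scheduled 5,3,2) using the epoch boundaries 100 and 250 that follow from the 100/150/100-epoch phases described above; passing gammas=(5.0, 3.0, 1.0) gives the 5,3,1 variant. The function names are hypothetical, and the clipping constant is an arbitrary guard against log(0).

```python
import numpy as np

def focal_loss(probs, labels, gamma):
    """Mean multi-class focal loss -(1 - p_y)^gamma * log(p_y); gamma = 0 gives cross-entropy."""
    p_y = np.clip(probs[np.arange(len(labels)), labels], 1e-12, 1.0)  # true-class probabilities
    return float(np.mean(-((1.0 - p_y) ** gamma) * np.log(p_y)))

def scheduled_gamma(epoch, gammas=(5.0, 3.0, 2.0), boundaries=(100, 250)):
    """Piecewise-constant gamma: gammas[i] is used while epoch < boundaries[i], then gammas[-1]."""
    for boundary, gamma in zip(boundaries, gammas):
        if epoch < boundary:
            return gamma
    return gammas[-1]

# Toy batch of softmax outputs and labels, only to show how the pieces fit together.
probs = np.array([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]])
labels = np.array([0, 1])
for epoch in (0, 100, 250):
    g = scheduled_gamma(epoch)
    print(epoch, g, focal_loss(probs, labels, g))
```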
Finally, for the sample-dependent γ approach, we also found the policy Focal Loss (sample-dependent 5,3,2), with γ = 5 for , γ = 3 for and γ = 2 for , to produce competitive results.
In Table 4 we present the classification errors on the test datasets for all the major loss functions we considered. Moreover, we also report the classification errors for the different focal loss approaches in Table 7. We also report the ECE and Ada-ECE for all the focal loss approaches in Table 5 and Table 6.
Finally, calibrated models should have a higher logit score (or softmax probability) on the correct class even when they misclassify, as compared to models which are less calibrated. Thus, intuitively, such models should have a higher Top-5 accuracy. In Table 8, we report the Top-5 accuracies for all our models on datasets where the number of classes is relatively high (i.e., on CIFAR-100 with 100 classes and Tiny-ImageNet with 200 classes). We observe that focal loss with sample-dependent γ produces the highest Top-5 accuracies on all models trained on CIFAR-100 and the second-best Top-5 accuracy (only marginally below the highest) on Tiny-ImageNet.
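Top-5 accuracy checks whether the true class is among the five highest-scoring classes. A minimal NumPy sketch follows; the function name is illustrative, and the tiny three-class example uses k = 2 only so the result is easy to verify by hand.

```python
import numpy as np

def top_k_accuracy(probs, labels, k=5):
    """Fraction of samples whose true label is among the k highest-scoring classes."""
    top_k = np.argsort(probs, axis=1)[:, -k:]  # indices of the k largest scores in each row
    return float(np.mean(np.any(top_k == labels[:, None], axis=1)))

probs = np.array([[0.1, 0.6, 0.3],
                  [0.5, 0.2, 0.3],
                  [0.2, 0.3, 0.5]])
labels = np.array([2, 1, 0])
print(top_k_accuracy(probs, labels, k=1), top_k_accuracy(probs, labels, k=2))
```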
Dataset | Model | Cross-Entropy | Brier Loss | MMCE | LS-0.05 | FL-3 | FLSc-532 | FLSD-53 |
---|---|---|---|---|---|---|---|---|
CIFAR-10 | ResNet 50 | 4.95 | 5.0 | 4.99 | 5.29 | 5.25 | 5.63 | 4.98 |
CIFAR-10 | ResNet 110 | 4.89 | 5.48 | 5.4 | 5.52 | 5.08 | 5.71 | 5.42 |
CIFAR-10 | Wide ResNet 26-10 | 3.86 | 4.08 | 3.91 | 4.2 | 4.13 | 4.46 | 4.01 |
CIFAR-10 | DenseNet 121 | 5.0 | 5.11 | 5.41 | 5.09 | 5.33 | 5.65 | 5.46 |
CIFAR-100 | ResNet 50 | 23.3 | 23.39 | 23.2 | 23.43 | 22.75 | 23.24 | 23.22 |
CIFAR-100 | ResNet 110 | 22.73 | 25.1 | 23.07 | 23.43 | 22.92 | 22.96 | 22.51 |
CIFAR-100 | Wide ResNet 26-10 | 20.7 | 20.59 | 20.73 | 21.19 | 19.69 | 20.13 | 20.11 |
CIFAR-100 | DenseNet 121 | 24.52 | 23.75 | 24.0 | 24.05 | 23.25 | 23.72 | 22.67 |
Tiny-ImageNet | ResNet 50 | 49.81 | 53.2 | 51.31 | 47.12 | 49.69 | 49.83 | 49.06 |
20 Newsgroups | Global Pooling CNN | 26.68 | 27.06 | 27.23 | 26.03 | 29.26 | 28.16 | 27.98 |
SST Binary | Tree LSTM | 12.85 | 12.85 | 11.86 | 13.23 | 12.19 | 13.07 | 12.8 |
In addition to ECE and Ada-ECE, we use various other metrics to compare the proposed methods with the baselines (i.e. cross-entropy, Brier score, MMCE and Label Smoothing). We present the test NLL % before and after temperature scaling in Tables 9 and 10, respectively. We report the test set MCE % before and after temperature scaling in Tables 11 and 12, respectively.
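For reference, MCE replaces the weighted average over bins in ECE with the maximum bin gap, and NLL is the mean negative log-probability of the true class. A minimal NumPy sketch follows. It assumes 15 equal-width bins, as used for ECE in the main text, and hypothetical function names; the clipping constant only guards against log(0).

```python
import numpy as np

def mce(probs, labels, n_bins=15):
    """Maximum calibration error: the largest |accuracy - confidence| over non-empty bins."""
    conf = probs.max(axis=1)
    correct = (probs.argmax(axis=1) == labels).astype(float)
    bin_ids = np.minimum((conf * n_bins).astype(int), n_bins - 1)  # equal-width bins
    gaps = [abs(correct[bin_ids == b].mean() - conf[bin_ids == b].mean())
            for b in range(n_bins) if np.any(bin_ids == b)]
    return float(max(gaps))

def nll(probs, labels):
    """Mean negative log-likelihood of the true labels."""
    p_y = np.clip(probs[np.arange(len(labels)), labels], 1e-12, 1.0)
    return float(-np.mean(np.log(p_y)))

probs = np.array([[0.9, 0.1], [0.9, 0.1], [0.6, 0.4], [0.6, 0.4]])
labels = np.array([0, 1, 0, 1])
print("MCE:", mce(probs, labels), "NLL:", nll(probs, labels))
```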
We use the following abbreviation to report results on different varieties of Focal Loss. FL-1 refers to Focal Loss (fixed 1), FL-2 refers to Focal Loss (fixed 2), FL-3 refers to Focal Loss (fixed 3), FLSc-531 refers to Focal Loss (scheduled 5,3,1), FLSc-532 refers to Focal Loss (scheduled 5,3,2), FLSD-532 refers to Focal Loss (sample-dependent 5,3,2) and FLSD-53 refers to Focal Loss (sample-dependent 5,3).
Dataset | Model | FL-1 Pre T | FL-1 Post T | FL-2 Pre T | FL-2 Post T | FL-3 Pre T | FL-3 Post T | FLSc-531 Pre T | FLSc-531 Post T | FLSc-532 Pre T | FLSc-532 Post T | FLSD-532 Pre T | FLSD-532 Post T | FLSD-53 Pre T | FLSD-53 Post T |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
CIFAR-10 | ResNet 50 | 3.42 | 1.08(1.6) | 2.36 | 0.91(1.2) | 1.48 | 1.42(1.1) | 4.06 | 1.53(1.6) | 2.97 | 1.53(1.2) | 2.52 | 0.88(1.3) | 1.55 | 0.95(1.1) |
CIFAR-10 | ResNet 110 | 3.46 | 1.2(1.6) | 2.7 | 0.89(1.3) | 1.55 | 1.02(1.1) | 4.92 | 1.5(1.7) | 3.33 | 1.36(1.3) | 2.82 | 0.97(1.3) | 1.87 | 1.07(1.1) |
CIFAR-10 | Wide ResNet 26-10 | 2.69 | 1.46(1.3) | 1.42 | 1.03(1.1) | 1.69 | 0.97(0.9) | 2.81 | 0.96(1.4) | 1.82 | 1.45(1.1) | 1.31 | 0.87(1.1) | 1.56 | 0.84(0.9) |
CIFAR-10 | DenseNet 121 | 3.44 | 1.63(1.4) | 1.93 | 1.04(1.1) | 1.32 | 1.26(0.9) | 4.12 | 1.65(1.5) | 2.22 | 1.34(1.1) | 2.45 | 1.31(1.2) | 1.22 | 1.22(1) |
CIFAR-100 | ResNet 50 | 12.86 | 2.3(1.5) | 8.61 | 2.24(1.3) | 5.13 | 1.97(1.1) | 11.63 | 2.09(1.4) | 8.47 | 2.13(1.3) | 9.09 | 1.61(1.3) | 4.5 | 2.(1.1) |
CIFAR-100 | ResNet 110 | 15.08 | 4.55(1.5) | 11.57 | 3.73(1.3) | 8.64 | 3.95(1.2) | 14.99 | 4.56(1.5) | 11.2 | 3.43(1.3) | 11.74 | 3.64(1.3) | 8.56 | 4.12(1.2) |
CIFAR-100 | Wide ResNet 26-10 | 8.93 | 2.53(1.4) | 4.64 | 2.93(1.2) | 2.13 | 2.13(1) | 9.36 | 2.48(1.4) | 4.98 | 1.94(1.2) | 4.98 | 2.55(1.2) | 3.03 | 1.64(1.1) |
CIFAR-100 | DenseNet 121 | 14.24 | 2.8(1.5) | 7.9 | 2.33(1.2) | 4.15 | 1.25(1.1) | 13.05 | 2.08(1.5) | 7.63 | 1.96(1.2) | 8.14 | 2.35(1.3) | 3.73 | 1.31(1.1) |
Tiny-ImageNet | ResNet 50 | 7.61 | 3.29(1.2) | 3.02 | 3.02(1) | 1.87 | 1.87(1) | 7.77 | 3.07(1.2) | 3.62 | 2.54(1.1) | 2.81 | 2.57(1.1) | 1.76 | 1.76(1) |
20 Newsgroups | Global Pooling CNN | 15.06 | 2.14(2.6) | 12.1 | 3.22(1.6) | 8.67 | 3.51(1.5) | 13.55 | 4.32(1.7) | 12.13 | 2.47(1.8) | 12.2 | 2.39(2) | 6.92 | 2.19(1.5) |
SST Binary | Tree LSTM | 6.78 | 3.29(1.6) | 3.05 | 3.05(1) | 16.05 | 1.78(0.5) | 4.66 | 3.36(1.4) | 3.91 | 2.64(0.9) | 4.47 | 2.77(0.9) | 9.19 | 1.83(0.7) |
Dataset | Model | FL-1 Pre T | FL-1 Post T | FL-2 Pre T | FL-2 Post T | FL-3 Pre T | FL-3 Post T | FLSc-531 Pre T | FLSc-531 Post T | FLSc-532 Pre T | FLSc-532 Post T | FLSD-532 Pre T | FLSD-532 Post T | FLSD-53 Pre T | FLSD-53 Post T |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
CIFAR-10 | ResNet 50 | 3.42 | 1.51(1.6) | 2.37 | 1.69(1.2) | 1.95 | 1.83(1.1) | 4.06 | 2.43(1.6) | 2.95 | 2.18(1.2) | 2.5 | 1.23(1.3) | 1.56 | 1.26(1.1) |
CIFAR-10 | ResNet 110 | 3.42 | 1.57(1.6) | 2.69 | 1.29(1.3) | 1.62 | 1.44(1.1) | 4.91 | 2.61(1.7) | 3.32 | 1.92(1.3) | 2.78 | 1.58(1.3) | 2.07 | 1.67(1.1) |
CIFAR-10 | Wide ResNet 26-10 | 2.7 | 1.71(1.3) | 1.64 | 1.47(1.1) | 1.84 | 1.54(0.9) | 2.75 | 1.85(1.4) | 2.04 | 1.9(1.1) | 1.68 | 1.49(1.1) | 1.52 | 1.38(0.9) |
CIFAR-10 | DenseNet 121 | 3.44 | 1.85(1.4) | 1.8 | 1.39(1.1) | 1.22 | 1.48(0.9) | 4.11 | 2.2(1.5) | 2.19 | 1.64(1.1) | 2.44 | 1.6(1.2) | 1.42 | 1.42(1) |
CIFAR-100 | ResNet 50 | 12.86 | 2.54(1.5) | 8.55 | 2.44(1.3) | 5.08 | 2.02(1.1) | 11.58 | 2.01(1.4) | 8.41 | 2.25(1.3) | 9.08 | 1.94(1.3) | 4.39 | 2.33(1.1) |
CIFAR-100 | ResNet 110 | 15.08 | 5.16(1.5) | 11.57 | 4.46(1.3) | 8.64 | 4.14(1.2) | 14.98 | 4.97(1.5) | 11.18 | 3.68(1.3) | 11.74 | 4.21(1.3) | 8.55 | 3.96(1.2) |
CIFAR-100 | Wide ResNet 26-10 | 8.93 | 2.74(1.4) | 4.65 | 2.96(1.2) | 2.08 | 2.08(1) | 9.2 | 2.52(1.4) | 5 | 2.11(1.2) | 5 | 2.58(1.2) | 2.75 | 1.63(1.1) |
CIFAR-100 | DenseNet 121 | 14.24 | 2.71(1.5) | 7.9 | 2.36(1.2) | 4.15 | 1.23(1.1) | 13.01 | 2.18(1.5) | 7.61 | 2.04(1.2) | 8.04 | 2.1(1.3) | 3.55 | 1.24(1.1) |
Tiny-ImageNet | ResNet 50 | 7.56 | 2.95(1.2) | 3.15 | 3.15(1) | 1.88 | 1.88(1) | 7.7 | 2.9(1.2) | 3.76 | 2.4(1.1) | 2.81 | 2.6(1.1) | 1.42 | 1.42(1) |
20 Newsgroups | Global Pooling CNN | 15.06 | 2.22(2.6) | 12.1 | 3.33(1.6) | 8.65 | 3.78(1.5) | 13.55 | 4.58(1.7) | 12.13 | 2.49(1.8) | 12.19 | 2.37(2) | 6.92 | 2.35(1.5) |
SST Binary | Tree LSTM | 6.27 | 4.59(1.6) | 3.69 | 3.69(1) | 16.01 | 2.16(0.5) | 4.43 | 3.57(1.4) | 3.37 | 2.46(0.9) | 4.42 | 2.96(0.9) | 9.15 | 1.92(0.7) |
Dataset | Model | FL-1 | FL-2 | FL-3 | FLSc-531 | FLSc-532 | FLSD-532 | FLSD-53 |
---|---|---|---|---|---|---|---|---|
CIFAR-10 | ResNet 50 | 4.93 | 4.98 | 5.25 | 5.66 | 5.63 | 5.24 | 4.98 |
CIFAR-10 | ResNet 110 | 4.78 | 5.06 | 5.08 | 6.13 | 5.71 | 5.19 | 5.42 |