Knowledge distillation (KD) is one of the most potent model compression techniques, in which knowledge is transferred from a cumbersome model (teacher) to a single small model (student) [hinton2015distilling]. In general, the objective for training a smaller student network in the KD framework is a linear combination of two losses: the cross-entropy (CE) loss with “hard” targets, which are the one-hot ground-truth vectors of the samples, and the Kullback-Leibler (KL) divergence loss with the teacher’s predictions. Specifically, the KL divergence loss has achieved considerable success by controlling the softness of the “soft” targets via the temperature-scaling hyperparameter $\tau$. A larger $\tau$ makes the softmax vectors smoother over the classes. The output probability vector rescaled by $\tau$ is called the softened probability distribution, or the softened softmax [DBLP:journals/corr/MaddisonMT16, jang2016categorical]. Recent KD has evolved to give more importance to the KL divergence loss when balancing the objective between the CE loss and the KL divergence loss [hinton2015distilling, tian2019contrastive]. Hence, we focus on training a student network based solely on the KL divergence loss.
Recently, there has been an increasing demand for investigating the reasons for the superiority of KD. [yuan2020revisit, tang2020understanding] empirically showed that the benefit of KD is attributable not only to the privileged information on similarities among classes but also to label smoothing regularization. In some cases, the theoretical reasoning for using “soft” targets is clear. For example, in deep linear neural networks, KD not only accelerates training convergence but also helps in reliable training [phuong2019towards]. In the case of self-distillation (SD), where the teacher model and the student model are the same, such an approach progressively restricts the number of basis functions used to represent the solution [mobahi2020self]. However, there is still a lack of understanding of how the degree of softness affects the performance.
In this paper, we first investigate the characteristics of a student trained with the KL divergence loss with various $\tau$, both theoretically and empirically. We find that the student’s logit (i.e., an input of the softened softmax function) more closely resembles the teacher’s logit as $\tau$ increases, but never matches it completely. Therefore, we design a direct logit learning scheme by replacing the KL divergence loss between the softened probability distributions of a teacher network and a student network (Figure 1(a)) with the mean squared error (MSE) loss between the student’s logit and the teacher’s logit (Figure 1(b)). Our contributions are summarized as follows:
We investigate the role of the softening hyperparameter $\tau$ theoretically and empirically. A large $\tau$, that is, strong softening, leads to logit matching, whereas a small $\tau$ results in training label matching. In general, logit matching has a better generalization capacity than label matching.
We propose a direct logit matching scheme with the MSE loss and show that the KL divergence loss with any value of $\tau$ cannot match the logits as completely as the MSE loss does. Direct training results in the best performance in our experiments.
We theoretically show that the KL divergence loss makes the model’s penultimate layer representations more elongated than those of the teacher, while the MSE loss does not. We visualize the representations using the method proposed by [muller2019does].
We show that sequential distillation, that is, the MSE loss after the KL divergence loss, can be a better strategy than direct distillation when the capacity gap between the teacher and the student is large, which contrasts with [cho2019efficacy].
We observe that the KL divergence loss, with a low $\tau$ in particular, is more efficient than the MSE loss when the data have incorrect labels (noisy labels). In this situation, extreme logit matching degrades training, whereas the KL divergence loss mitigates this problem.
2 Related Work
2.1 Knowledge Distillation
KD has been extended to a wide range of methods. One line of work distills not only the softened probabilities of the teacher network but also the hidden feature vectors so that the student can be trained with rich information from the teacher [romero2014fitnets, zagoruyko2016paying, srinivas2018knowledge, kim2018paraphrasing, heo2019knowledge, heo2019comprehensive]. The KD approach can be leveraged to reduce the generalization errors in teacher models (i.e., self-distillation; SD) [Zhang_2019_ICCV, park2019relational] as well as for model compression. In generative models, a generator can be compressed by distilling the latent features from a cumbersome generator [DBLP:journals/corr/abs-1902-00159].
To explain the efficacy of KD, [furlanello2018born] asserted that the maximum value of a teacher’s softmax probability was similar to weighted importance by showing that permuting all of the non-argmax elements could also improve performance. [yuan2020revisit] argued that “soft” targets served as a label smoothing regularizer rather than as a transfer of class similarity by showing that a poorly trained or smaller-size teacher model can boost performance. Recently, [tang2020understanding] modified the conjecture in [furlanello2018born] and showed that the sample was positively re-weighted by the prediction of the teacher’s logit vector.
2.2 Label Smoothing
Smoothing the label is a common method for improving the performance of deep neural networks by preventing overconfident predictions [szegedy2016rethinking]. Label smoothing (LS) is a technique that facilitates generalization by replacing a ground-truth one-hot vector $\boldsymbol{y}$ with a weighted mixture of hard targets $\boldsymbol{y}^{LS}$:

$$y^{LS}_k = \begin{cases} 1-\beta & \text{if } y_k = 1 \\ \frac{\beta}{K-1} & \text{otherwise} \end{cases} \tag{1}$$

where $k$ indicates the index, and $\beta$ is a constant. This implicitly ensures that the model is well-calibrated [muller2019does]. Despite its improvements, [muller2019does] observed that a teacher model trained with LS improved its own performance, whereas it could hurt the student’s performance. [yuan2020revisit] demonstrated that KD might be a category of LS that uses adaptive noise, i.e., KD is a label regularization method.
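As a small illustration of label smoothing, the sketch below builds a smoothed target from a one-hot vector. It assumes the variant in which the true class receives $1-\beta$ and the remaining $K-1$ classes share $\beta$ equally; the function name `smooth_labels` is hypothetical and other variants (for example, spreading $\beta$ over all $K$ classes) are also common.

```python
def smooth_labels(one_hot, beta):
    """Return a label-smoothed copy of a one-hot target vector.

    Assumed form: the true class gets 1 - beta, and each of the
    other K - 1 classes gets beta / (K - 1).
    """
    num_classes = len(one_hot)
    off_value = beta / (num_classes - 1)
    return [1.0 - beta if y == 1 else off_value for y in one_hot]


# Example with K = 4 classes and the third class as ground truth.
smoothed = smooth_labels([0, 0, 1, 0], beta=0.1)
```

The smoothed vector still sums to one, so it can be used as a target distribution in place of the hard label.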
3 Preliminaries: KD
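|$\boldsymbol{y}$|Ground-truth one-hot vector|
|$K$|Number of classes in the dataset|
|$\boldsymbol{z}^f$|Logit vector of a sample $\boldsymbol{x}$ through a network $f$|
|$z^f_k$|Logit value corresponding to the $k$-th class label, i.e., the $k$-th value of $\boldsymbol{z}^f$|
|$\alpha$|Hyperparameter of the linear combination|
|$\boldsymbol{p}^f(\tau)$|Softened probability distribution with $\tau$ of a sample $\boldsymbol{x}$ for a network $f$|
|$p^f_k(\tau)$|The $k$-th value of a softened probability distribution|
|$\mathcal{L}_{KL}$|Kullback-Leibler divergence loss|
|$\mathcal{L}_{MSE}$|Mean squared error loss|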
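We denote the softened probability vector with a temperature-scaling hyperparameter $\tau$ for a network $f$ as $\boldsymbol{p}^f(\boldsymbol{x};\tau)$, given a sample $\boldsymbol{x}$. The $k$-th value of the softened probability vector is denoted by $p^f_k(\boldsymbol{x};\tau) = \frac{\exp(z^f_k(\boldsymbol{x})/\tau)}{\sum_{j=1}^{K}\exp(z^f_j(\boldsymbol{x})/\tau)}$, where $z^f_k(\boldsymbol{x})$ is the $k$-th value of the logit vector $\boldsymbol{z}^f(\boldsymbol{x})$, $K$ is the number of classes, and $\exp(\cdot)$ is the natural exponential function. Then, given a sample $\boldsymbol{x}$, the typical loss $\mathcal{L}$ for a student network is a linear combination of the cross-entropy loss $\mathcal{L}_{CE}$ and the Kullback-Leibler divergence loss $\mathcal{L}_{KL}$:

$$\mathcal{L} = (1-\alpha)\,\mathcal{L}_{CE}(\boldsymbol{p}^s(1), \boldsymbol{y}) + \alpha\,\mathcal{L}_{KL}(\boldsymbol{p}^s(\tau), \boldsymbol{p}^t(\tau)), \tag{2}$$

$$\mathcal{L}_{CE}(\boldsymbol{p}^s(1), \boldsymbol{y}) = -\sum_{j} y_j \log p^s_j(1), \qquad \mathcal{L}_{KL}(\boldsymbol{p}^s(\tau), \boldsymbol{p}^t(\tau)) = \tau^2 \sum_{j} p^t_j(\tau) \log \frac{p^t_j(\tau)}{p^s_j(\tau)},$$

where $s$ indicates the student network, $t$ indicates the teacher network, $\boldsymbol{y}$ is a one-hot label vector of a sample $\boldsymbol{x}$, and $\alpha$ is a hyperparameter of the linear combination. For simplicity of notation, $p^f_k(\boldsymbol{x};\tau)$ and $z^f_k(\boldsymbol{x})$ are denoted by $p^f_k(\tau)$ and $z^f_k$, respectively. The standard choices of $\alpha$ and $\tau$ follow [hinton2015distilling, zagoruyko2016paying].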
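The objective described above can be sketched in a few lines of plain Python. This is a minimal, non-authoritative illustration for a single sample: the helper names are hypothetical, the logits are toy values, and the $\tau^2$ factor on the KL term follows the common convention of [hinton2015distilling] and is an assumption here.

```python
import math


def softened_softmax(logits, tau=1.0):
    """Softmax of logits / tau (shifted by the max for numerical stability)."""
    scaled = [z / tau for z in logits]
    shift = max(scaled)
    exps = [math.exp(v - shift) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(probs, one_hot):
    """Cross-entropy between predicted probabilities and a one-hot target."""
    return -sum(y * math.log(p) for y, p in zip(one_hot, probs) if y > 0)


def kl_loss(student_logits, teacher_logits, tau):
    """tau^2 * KL(p_teacher(tau) || p_student(tau)); the tau^2 factor is assumed."""
    p_s = softened_softmax(student_logits, tau)
    p_t = softened_softmax(teacher_logits, tau)
    return tau ** 2 * sum(t * math.log(t / s) for t, s in zip(p_t, p_s) if t > 0)


def kd_loss(student_logits, teacher_logits, one_hot, alpha, tau):
    """(1 - alpha) * CE on hard targets + alpha * KL on soft targets."""
    ce = cross_entropy(softened_softmax(student_logits, 1.0), one_hot)
    return (1.0 - alpha) * ce + alpha * kl_loss(student_logits, teacher_logits, tau)


# Toy logits for K = 3 classes (illustrative values only).
teacher = [4.0, 1.0, -2.0]
student = [2.5, 0.5, -1.0]
label = [1, 0, 0]
sharp = softened_softmax(teacher, tau=1.0)
soft = softened_softmax(teacher, tau=4.0)  # flatter over classes than `sharp`
loss = kd_loss(student, teacher, label, alpha=0.9, tau=4.0)
```

Setting `alpha=1.0` removes the hard-target term, which corresponds to training on the soft targets alone.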
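In [hinton2015distilling], given a single sample $\boldsymbol{x}$, the gradient of $\mathcal{L}_{KL}$ with respect to $z^s_k$ is as follows:

$$\frac{\partial \mathcal{L}_{KL}}{\partial z^s_k} = \tau\left(p^s_k(\tau) - p^t_k(\tau)\right). \tag{3}$$

When $\tau$ goes to $\infty$, this gradient is simplified with the approximation $\exp(z/\tau) \approx 1 + z/\tau$:

$$\frac{\partial \mathcal{L}_{KL}}{\partial z^s_k} \approx \tau\left(\frac{1 + z^s_k/\tau}{K + \sum_j z^s_j/\tau} - \frac{1 + z^t_k/\tau}{K + \sum_j z^t_j/\tau}\right). \tag{4}$$

Here, the authors assumed zero-mean teacher and student logits, i.e., $\sum_j z^t_j = 0$ and $\sum_j z^s_j = 0$, and hence $\frac{\partial \mathcal{L}_{KL}}{\partial z^s_k} \approx \frac{1}{K}\left(z^s_k - z^t_k\right)$. This indicates that minimizing $\mathcal{L}_{KL}$ is equivalent to minimizing the mean squared error $\mathcal{L}_{MSE}$, that is, $\|\boldsymbol{z}^s - \boldsymbol{z}^t\|_2^2$, under a sufficiently large temperature $\tau$ and the zero-mean logit assumption for both the teacher and the student.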
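The large-temperature approximation can be checked numerically. The sketch below is illustrative only: it uses the gradient form $\tau(p^s_k(\tau) - p^t_k(\tau))$ from [hinton2015distilling], toy logits chosen to sum to zero, and hypothetical helper names, and compares the gradient at a large $\tau$ with $(z^s_k - z^t_k)/K$.

```python
import math


def softened_softmax(logits, tau):
    """Softmax of logits / tau (shifted by the max for numerical stability)."""
    scaled = [z / tau for z in logits]
    shift = max(scaled)
    exps = [math.exp(v - shift) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_grad(student_logits, teacher_logits, tau):
    """Gradient of the tau^2-scaled KL loss w.r.t. the student logits."""
    p_s = softened_softmax(student_logits, tau)
    p_t = softened_softmax(teacher_logits, tau)
    return [tau * (s - t) for s, t in zip(p_s, p_t)]


# Toy zero-mean logits (both vectors sum to zero), K = 4.
z_t = [2.0, -1.0, 0.5, -1.5]
z_s = [1.0, 0.5, -0.5, -1.0]
K = len(z_s)

grad = kl_grad(z_s, z_t, tau=1e4)
mse_like = [(s - t) / K for s, t in zip(z_s, z_t)]
max_gap = max(abs(g - m) for g, m in zip(grad, mse_like))
```

With zero-mean logits `max_gap` is small at this temperature; the agreement is not guaranteed once the logit sums differ from zero.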
However, we observe that this assumption does not seem appropriate and hinders complete understanding by ignoring the hidden term in $\frac{\partial \mathcal{L}_{KL}}{\partial z^s_k}$ when $\tau$ increases. Figure 2 shows the histograms of the magnitude of the logit summations on the training dataset. The logit summation from the teacher network trained with $\mathcal{L}_{CE}$ is almost zero (Figure 2(a)), whereas that from the student network trained with $\mathcal{L}_{KL}$ using the teacher’s knowledge moves far from zero as $\tau$ increases (Figure 2(b)). This is discussed in detail in Section 4.2.
3.1 Experimental Setup
In this paper, we used an experimental setup similar to that in [heo2019comprehensive, cho2019efficacy]: image classification on CIFAR-100 with a family of Wide-ResNet (WRN) [zagoruyko2016wide] and on ImageNet with a family of ResNet (RN) [he2016deep]. We used a standard PyTorch SGD optimizer with a momentum of 0.9 and weight decay, and applied standard data augmentation. Other than those mentioned, the training settings from the original papers [heo2019comprehensive, cho2019efficacy] were used.
4 Relationship between $\mathcal{L}_{KL}$ and $\mathcal{L}_{MSE}$
In this section, we conduct extensive experiments and systematically break down the effects of $\tau$ in $\mathcal{L}_{KL}$ based on theoretical and empirical results. We then highlight the relationship between $\mathcal{L}_{KL}$ and $\mathcal{L}_{MSE}$, and compare the models trained with $\mathcal{L}_{KL}$ and $\mathcal{L}_{MSE}$ in terms of performance and penultimate layer representations. Finally, we investigate the effects of a noisy teacher on the performance according to the objective.
4.1 Hyperparameter $\tau$ in $\mathcal{L}_{KL}$
We investigate the training and test accuracies according to the change in $\alpha$ in $\mathcal{L}$ and $\tau$ in $\mathcal{L}_{KL}$ (Figure 3). First, we empirically observe that the generalization error of a student model decreases as $\alpha$ in $\mathcal{L}$ increases. This means that “soft” targets are more efficient than “hard” targets in training a student if the “soft” targets are extracted from a well-trained teacher. This result is consistent with prior studies that addressed the efficacy of “soft” targets [furlanello2018born, tang2020understanding]. Therefore, in the remainder of this paper, we focus on the situation where only “soft” targets are used to train a student model, that is, $\alpha = 1$.
When $\alpha = 1$, the generalization error of the student model decreases as $\tau$ in $\mathcal{L}_{KL}$ increases. These consistent tendencies with respect to the two hyperparameters, $\alpha$ and $\tau$, are the same across various teacher-student pairs. To explain this phenomenon, we extend the gradient analysis in Section 3 without the assumption that the mean of the logit vector is zero.
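Proposition 1. Let $K$ be the number of classes in the dataset, and $\mathbb{1}[\cdot]$ be the indicator function, which is 1 when the statement inside the brackets is true and 0 otherwise. Then,

$$\lim_{\tau\to\infty} \frac{\partial \mathcal{L}_{KL}}{\partial z^s_k} = \frac{1}{K}\left(z^s_k - z^t_k\right) - \frac{1}{K^2}\sum_{j=1}^{K}\left(z^s_j - z^t_j\right), \tag{5}$$

$$\lim_{\tau\to 0} \frac{1}{\tau}\frac{\partial \mathcal{L}_{KL}}{\partial z^s_k} = \mathbb{1}\left[\arg\max_j z^s_j = k\right] - \mathbb{1}\left[\arg\max_j z^t_j = k\right]. \tag{6}$$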
Proposition 1 explains the consistent trends as follows. In the course of regularizing $\mathcal{L}_{KL}$ with a sufficiently large $\tau$, the student model attempts to imitate the logit distribution of the teacher model. Specifically, a larger $\tau$ makes the logit vector of the student more similar to that of the teacher (i.e., logit matching). Hence, “soft” targets are used more fully as $\tau$ increases. This is implemented using a handcrafted gradient (top row of Figure 3). On the other hand, when $\tau$ is close to 0, the gradient of $\mathcal{L}_{KL}$ does not consider the logit distributions and only identifies whether the student and the teacher share the same output (i.e., label matching), which transfers limited information. In addition, there is a scaling issue when $\tau$ approaches 0. As $\tau$ decreases, $\mathcal{L}_{KL}$ increasingly loses its quality and eventually becomes less involved in learning. The scaling problem can be easily fixed by multiplying $\mathcal{L}_{KL}$ by $1/\tau$ when $\tau$ is close to zero.
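Both regimes can be observed numerically without assuming zero-mean logits. The following sketch is a hedged illustration: it takes the gradient of the $\tau^2$-scaled KL loss to be $\tau(p^s_k(\tau) - p^t_k(\tau))$, uses toy logits, and compares it with two candidate limits, a mean-corrected logit difference for large $\tau$ and a difference of argmax indicators for small $\tau$ after rescaling by $1/\tau$. All names are hypothetical.

```python
import math


def softened_softmax(logits, tau):
    """Softmax of logits / tau (shifted by the max for numerical stability)."""
    scaled = [z / tau for z in logits]
    shift = max(scaled)
    exps = [math.exp(v - shift) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_grad(student_logits, teacher_logits, tau):
    """tau * (p_s(tau) - p_t(tau)), the gradient w.r.t. the student logits."""
    p_s = softened_softmax(student_logits, tau)
    p_t = softened_softmax(teacher_logits, tau)
    return [tau * (s - t) for s, t in zip(p_s, p_t)]


def one_hot_argmax(logits):
    """One-hot vector marking the position of the largest logit."""
    idx = max(range(len(logits)), key=logits.__getitem__)
    return [1.0 if i == idx else 0.0 for i in range(len(logits))]


# Toy logits whose sums are NOT zero and whose argmax classes differ.
z_s = [3.0, 1.0, 0.2]
z_t = [0.5, 2.5, 1.0]
K = len(z_s)
diff = [s - t for s, t in zip(z_s, z_t)]

# Large tau: logit matching with a mean-shift correction term.
large = kl_grad(z_s, z_t, tau=1e5)
logit_limit = [d / K - sum(diff) / K ** 2 for d in diff]

# Small tau, rescaled by 1 / tau: label matching (difference of argmax indicators).
small_tau = 0.01
small = [g / small_tau for g in kl_grad(z_s, z_t, small_tau)]
label_limit = [a - b for a, b in zip(one_hot_argmax(z_s), one_hot_argmax(z_t))]
```

In the small-$\tau$ regime the rescaled gradient carries only the predicted labels of the two networks, which is why it transfers less information than the large-$\tau$ regime.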
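From this proposition, it is recommended to modify the original $\mathcal{L}_{KL}$ in Eq. (2), considering $\tau \in (0, \infty)$, as follows:

$$\mathcal{L}_{KL}' = \max(\tau, \tau^2) \sum_{j} p^t_j(\tau) \log \frac{p^t_j(\tau)}{p^s_j(\tau)}.$$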
4.2 Extensions from $\mathcal{L}_{KL}$ to $\mathcal{L}_{MSE}$
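In this subsection, we focus on Eq. (5) to investigate why the efficacy of KD is observed when $\tau$ is greater than 1 in the KD environment, as shown in Figure 3. Eq. (5) can be understood as a biased regression in the vector expression as follows:

$$\lim_{\tau\to\infty} \nabla_{\boldsymbol{z}^s} \mathcal{L}_{KL} = \frac{1}{K}\left(\boldsymbol{z}^s - \boldsymbol{z}^t\right) - \frac{1}{K^2}\sum_{j=1}^{K}\left(z^s_j - z^t_j\right)\cdot\mathbb{1},$$

where $\mathbb{1}$ is a vector whose elements are equal to one. Furthermore, we can derive the relationship between $\lim_{\tau\to\infty}\mathcal{L}_{KL}$ and $\mathcal{L}_{MSE}$ as follows:

$$\lim_{\tau\to\infty} \mathcal{L}_{KL} = \frac{1}{2K}\left\|\boldsymbol{z}^s - \boldsymbol{z}^t\right\|_2^2 + \delta_\infty = \frac{1}{2K}\mathcal{L}_{MSE} + \delta_\infty, \qquad \delta_\infty = -\frac{1}{2K^2}\Big(\sum_{j=1}^{K} z^s_j - \sum_{j=1}^{K} z^t_j\Big)^2 + \text{Constant}. \tag{8}$$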
In Figure 2(a), the sum of the logit values of the teacher model is almost zero. With the teacher’s logit value, $\delta_\infty$ is approximated as $-\frac{1}{2K^2}\big(\sum_{j=1}^{K} z^s_j\big)^2$. Therefore, $\delta_\infty$ can make the logit mean of the student trained with $\mathcal{L}_{KL}$ depart from zero. From this analysis, it is unreasonable to assume that the student’s logit mean is zero. We empirically find that the student’s logit mean breaks the existing assumption as $\tau$ increases (Figure 2(b)). In summary, $\delta_\infty$ hinders complete logit matching by shifting the mean of the elements in the logit. In other words, as derived from Eq. (8), optimizing $\mathcal{L}_{KL}$ with a sufficiently large $\tau$ is equivalent to optimizing $\mathcal{L}_{MSE}$ with the additional regularization term $\delta_\infty$, which appears to hinder logit matching rather than help it.
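Therefore, we propose the direct logit learning objective for enhanced logit matching as follows:

$$\mathcal{L}' = (1-\alpha)\,\mathcal{L}_{CE}(\boldsymbol{p}^s(1), \boldsymbol{y}) + \alpha\,\mathcal{L}_{MSE}(\boldsymbol{z}^s, \boldsymbol{z}^t), \qquad \mathcal{L}_{MSE}(\boldsymbol{z}^s, \boldsymbol{z}^t) = \left\|\boldsymbol{z}^s - \boldsymbol{z}^t\right\|_2^2.$$

Although direct logit learning was used in [ba2013deep, urban2016deep], those works did not investigate a wide range of temperature scaling or the effects of MSE in the latent space. In this respect, our work differs.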
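The relationship between the large-temperature KL loss and the MSE loss discussed in this subsection can be checked numerically. The sketch below is illustrative and makes several assumptions: the KL loss carries a $\tau^2$ factor, the MSE is the squared L2 distance between logit vectors, the constant in the mean-shift term is taken to be zero, and all names and logits are hypothetical. It also shows that adding a constant to every student logit leaves the KL loss unchanged while the MSE loss grows, which is one way to see why the KL loss alone cannot pin down the student's logit mean.

```python
import math


def softened_softmax(logits, tau):
    """Softmax of logits / tau (shifted by the max for numerical stability)."""
    scaled = [z / tau for z in logits]
    shift = max(scaled)
    exps = [math.exp(v - shift) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_loss(student_logits, teacher_logits, tau):
    """tau^2 * KL(p_teacher(tau) || p_student(tau)); the tau^2 factor is assumed."""
    p_s = softened_softmax(student_logits, tau)
    p_t = softened_softmax(teacher_logits, tau)
    return tau ** 2 * sum(t * math.log(t / s) for t, s in zip(p_t, p_s) if t > 0)


def mse_loss(student_logits, teacher_logits):
    """Squared L2 distance between the two logit vectors."""
    return sum((s - t) ** 2 for s, t in zip(student_logits, teacher_logits))


z_s = [3.0, 1.0, 0.2]
z_t = [0.5, 2.5, 1.0]
K = len(z_s)

# Candidate limit: MSE / (2K) plus the mean-shift term (constant assumed zero).
delta_inf = -((sum(z_s) - sum(z_t)) ** 2) / (2 * K ** 2)
predicted = mse_loss(z_s, z_t) / (2 * K) + delta_inf
kl_large = kl_loss(z_s, z_t, tau=1e4)

# Shifting every student logit by the same constant: KL ignores it, MSE does not.
shifted = [z + 5.0 for z in z_s]
kl_gap = abs(kl_loss(shifted, z_t, 4.0) - kl_loss(z_s, z_t, 4.0))
mse_gap = mse_loss(shifted, z_t) - mse_loss(z_s, z_t)
```

Here `kl_large` and `predicted` agree closely, `kl_gap` is at the level of floating-point noise, and `mse_gap` is clearly positive.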
4.3 Comparison of $\mathcal{L}_{KL}$ and $\mathcal{L}_{MSE}$
|Student|Baseline|SKD [hinton2015distilling]|FitNets [romero2014fitnets]|AT [zagoruyko2016paying]|Jacobian [srinivas2018knowledge]|FT [kim2018paraphrasing]|AB [heo2019knowledge]|Overhaul [heo2019comprehensive]|MSE|
We empirically compared the objectives $\mathcal{L}_{KL}$ and $\mathcal{L}_{MSE}$ in terms of performance gains and measured the distance between the logit distributions. Following the previous analysis, we also focused on “soft” targets in $\mathcal{L}'$. Table 2 presents the top-1 test accuracies on CIFAR-100 according to the student learning scheme for various teacher-student pairs. The students trained with $\mathcal{L}_{CE}$ are vanilla models without a teacher. The students trained with $\mathcal{L}_{KL}$ or $\mathcal{L}_{MSE}$ are trained following the KD framework without using the “hard” targets, i.e., $\alpha = 1$ in $\mathcal{L}$ and $\mathcal{L}'$, respectively. Distillation with $\mathcal{L}_{MSE}$, that is, direct logit distillation without the hindering term $\delta_\infty$, is the best training scheme for various teacher-student pairs. We also found consistent improvements in ensemble distillation [hinton2015distilling]. For ensemble distillation using the MSE loss, the average of the logit predictions of multiple teachers is used as the target. We obtained the test accuracy of WRN16-2 () when the WRN16-4, WRN-28-4, and WRN-40-6 models were used as ensemble teachers in this manner. Moreover, the model trained with $\mathcal{L}_{MSE}$ has similar or better performance when compared to other existing KD methods, as described in Table 3. (Footnote 1: We excluded the additional experiments for the replacement with the MSE loss in feature-based distillation methods. It is difficult to add the MSE loss or replace the KL loss with the MSE loss in the existing works because of the sensitivity to hyperparameter optimization. Their methods include various types of hyperparameters that need to be optimized for their settings.)
Furthermore, to measure the distance between the student’s logit $\boldsymbol{z}^s$ and the teacher’s logit $\boldsymbol{z}^t$ sample by sample, we plot the probability density function (pdf) estimated from the histogram of $\|\boldsymbol{z}^s - \boldsymbol{z}^t\|_2$ on the CIFAR-100 training dataset (Figure 4(a)). When $\mathcal{L}_{KL}$ is used, the logit distribution of the student trained with a large $\tau$ is closer to that of the teacher than that of the student trained with a small $\tau$. Moreover, $\mathcal{L}_{MSE}$ is more efficient in transferring the teacher’s information to a student than $\mathcal{L}_{KL}$. Optimizing $\mathcal{L}_{MSE}$ aligns the student’s logit with the teacher’s logit. On the other hand, when $\tau$ becomes significantly large, $\mathcal{L}_{KL}$ contains the term $\delta_\infty$, and optimizing $\delta_\infty$ makes the student’s logit mean deviate from the teacher’s logit mean.
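The ensemble variant of direct logit distillation mentioned above can be sketched as follows. This is a minimal illustration under stated assumptions: the target is the element-wise average of the teachers' logit vectors for one sample, the loss is the squared L2 distance, and the function names and logit values are hypothetical rather than taken from the experiments.

```python
def average_logits(teacher_logits_list):
    """Element-wise mean of several teachers' logit vectors for one sample."""
    num_teachers = len(teacher_logits_list)
    return [sum(column) / num_teachers for column in zip(*teacher_logits_list)]


def mse_loss(student_logits, target_logits):
    """Squared L2 distance between student logits and the target logits."""
    return sum((s - t) ** 2 for s, t in zip(student_logits, target_logits))


# Three hypothetical teachers, K = 3 classes.
teachers = [
    [4.0, 1.0, -2.0],
    [3.0, 2.0, -1.0],
    [5.0, 0.0, -3.0],
]
target = average_logits(teachers)
loss = mse_loss([3.5, 1.5, -2.0], target)
```

Whether the squared error is summed or averaged over classes only changes the scale of the loss and is a design choice left open here.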
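We further investigate the effect of $\delta_\infty$ on the penultimate layer representations (i.e., pre-logits). Based on $\delta_\infty \approx -\frac{1}{2K^2}\big(\sum_{j=1}^{K} z^s_j\big)^2$, we can reformulate Eq. (8). Let $\boldsymbol{r}^s$ be the penultimate representation of student $s$ from an instance $\boldsymbol{x}$, and $W^s$ be the weight matrix of the student’s fully connected layer. Then,

$$\delta_\infty \approx -\frac{1}{2K^2}\Big(\sum_{j=1}^{K} z^s_j\Big)^2 = -\frac{1}{2K^2}\Big(\sum_{j=1}^{K} W^s_j \boldsymbol{r}^s\Big)^2,$$

where $W^s_j$ denotes the $j$-th row of $W^s$.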
To visualize the representations following [muller2019does], we first find an orthonormal basis constructed from the templates (i.e., the mean of the representations of the samples within the same class) of the three selected classes (apple, aquarium fish, and baby in our experiments). Then, the penultimate layer representations are projected onto the hyperplane spanned by the identified orthonormal basis. WRN-28-4 is used as the teacher, and WRN-16-2 is used as the student on the CIFAR-100 training dataset. As shown in the first row of Figure 5, when WRN-like architectures are trained with $\mathcal{L}_{CE}$ based on ground-truth hard targets, clusters are tightened as the model’s complexity increases. As shown in the second row of Figure 5, when the student is trained with $\mathcal{L}_{KL}$ with infinite $\tau$ or with $\mathcal{L}_{MSE}$, both representations attempt to follow the shape of the teacher’s representations but differ in the degree of cohesion. This is because $\delta_\infty$ makes the pre-logits much more widely clustered. Therefore, $\mathcal{L}_{MSE}$ can shrink the representations, in line with the teacher, more than $\mathcal{L}_{KL}$ can.
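The projection step described above can be sketched with a Gram-Schmidt construction. This is a hedged illustration, not the exact procedure used for Figure 5: it assumes the plane passes through the three class templates, takes the first template as the origin, and uses hypothetical function names and toy 3-dimensional vectors in place of real penultimate representations.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def sub(u, v):
    return [a - b for a, b in zip(u, v)]


def scale(u, c):
    return [a * c for a in u]


def class_template(representations):
    """Mean of the representation vectors belonging to one class."""
    n = len(representations)
    return [sum(column) / n for column in zip(*representations)]


def plane_basis(t1, t2, t3):
    """Orthonormal basis (e1, e2) of the plane through three templates."""
    v1 = sub(t2, t1)
    e1 = scale(v1, 1.0 / math.sqrt(dot(v1, v1)))
    v2 = sub(t3, t1)
    v2 = sub(v2, scale(e1, dot(v2, e1)))  # remove the component along e1
    e2 = scale(v2, 1.0 / math.sqrt(dot(v2, v2)))
    return e1, e2


def project(representation, origin, e1, e2):
    """2-D coordinates of a representation in the plane spanned by (e1, e2)."""
    centered = sub(representation, origin)
    return (dot(centered, e1), dot(centered, e2))


# Toy templates for three classes (real ones would be class means of pre-logits).
t1 = class_template([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
t2 = class_template([[1.0, 0.0, 0.0], [3.0, 0.0, 0.0]])
t3 = class_template([[1.0, 3.0, 0.0]])
e1, e2 = plane_basis(t1, t2, t3)
coords = project([1.0, 1.0, 5.0], t1, e1, e2)
```

The component of a representation orthogonal to the plane (the third coordinate in this toy example) is discarded by the projection.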
4.4 Effects of a Noisy Teacher
We investigate the effects of a noisy teacher (i.e., a model poorly fitted to the training dataset) according to the objective. It is believed that label matching ($\mathcal{L}_{KL}$ with a small $\tau$) is more appropriate than logit matching ($\mathcal{L}_{KL}$ with a large $\tau$, or $\mathcal{L}_{MSE}$) under a noisy teacher. This is because label matching neglects the negative information in the outputs of an untrained teacher. Table 4 describes top-1 test accuracies on CIFAR-100, where the teacher network (WRN-28-4) has a training accuracy of 53.77%, which is achieved in 10 epochs. When poor knowledge is distilled, the students following the label matching scheme performed better than the students following the logit matching scheme, and extreme logit matching through $\mathcal{L}_{MSE}$ has the worst performance. Similarly, it seems that logit matching is not suitable for large-scale tasks. Table 5 presents top-1 test accuracies on ImageNet, where the teacher network (ResNet-152) has a training accuracy of 81.16%, which is provided in PyTorch. Even in this case, extreme logit matching exhibits the worst performance. The utility of negative logits (i.e., a negligible aspect when $\tau$ is small) was discussed in [hinton2015distilling].
5 Sequential Distillation
In [cho2019efficacy], the authors showed that larger teachers are not necessarily better teachers, arguing that the capacity gap between the teacher and the student is a more important factor than the teacher itself. In their results, using a medium-sized network instead of a large-scale network as a teacher can improve the performance of a small network by reducing the capacity gap between the teacher and the student. They also showed that sequential KD (large network → medium network → small network) is not conducive to generalization under the objective in Eq. (2). In other words, the best approach is a direct distillation from the medium model to the small model.
Table 6 describes the test accuracies of sequential KD, where the largest model is WRN-28-4, the intermediate model is WRN-16-4, and the smallest model is WRN-16-2. Similar to the previous study, when $\mathcal{L}_{KL}$ is used to train the small network iteratively, the direct distillation from the intermediate network to the small network is better (i.e., WRN-16-4 → WRN-16-2, 74.84%) than the sequential distillation (i.e., WRN-28-4 → WRN-16-4 → WRN-16-2, 74.52%) and direct distillation from a large network to a small network (i.e., WRN-28-4 → WRN-16-2, 74.24%). The same trend occurs in $\mathcal{L}_{MSE}$ iterations.
On the other hand, we find that the medium-sized teacher can improve the performance of a smaller-scale student when $\mathcal{L}_{KL}$ and $\mathcal{L}_{MSE}$ are used sequentially (the fourth row from the bottom) despite the large capacity gap between the teacher and the student. KD iterations with such a strategy might compress the model size more effectively, and hence should also be considered in future work. Furthermore, our work is the first study on sequential distillation at the objective level, not at the architecture level as in [cho2019efficacy, mirzadeh2020improved].
|(78.88%)||() (78.76%)||()||74.52 %|
6 Robustness to Noisy Labels
In this section, we investigate how noisy labels, samples annotated with incorrect labels in the training dataset, affect the distillation ability when training a teacher network. This setting is related to the capacity for memorization and generalization. Modern deep neural networks even attempt to memorize samples perfectly [zhang2016understanding]; hence, the teacher might transfer corrupted knowledge to the student in this situation. Therefore, it is thought that logit matching might not be the best strategy when the teacher is trained using a noisy label dataset.
From this insight, we simulate the noisy label setting to evaluate the robustness on CIFAR-100 by randomly flipping a certain fraction of the labels in the training dataset following a symmetric uniform distribution. Figure 6 shows the test accuracy graphs as the loss function changes. First, we observe that a small network (WRN-16-2, orange dotted line) has a better generalization performance than a larger network (WRN-28-4, purple dotted line) when the models are trained with $\mathcal{L}_{CE}$. This implies that a complex model can memorize the training dataset better than a simple model, but cannot generalize to the test dataset. Next, WRN-28-4 (purple dotted line) is used as the teacher model. When the noise is less than 50%, extreme logit matching ($\mathcal{L}_{MSE}$, green dotted line) and logit matching with $\mathcal{L}_{KL}$ at a large $\tau$ (blue dotted line) can mitigate the label noise problem compared with the model trained with $\mathcal{L}_{CE}$. However, when the noise is more than 50%, these training schemes cannot mitigate this problem because they follow corrupted knowledge more often than correct knowledge.
Interestingly, the best generalization performance is achieved when we use $\mathcal{L}_{KL}$ with $\tau$ near 1. In Figure 6, the blue solid line represents the test accuracy using the rescaled loss function, in contrast to the black dotted line. As expected, logit matching might transfer the teacher’s overconfidence, even for incorrect predictions. However, a proper objective that combines logit matching and label matching yields effects similar to those of label smoothing, as studied in [pmlr-v119-lukasik20a, yuan2020revisit]. Therefore, $\mathcal{L}_{KL}$ with $\tau$ near 1 appears to significantly mitigate the problem of noisy labels.
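The symmetric label noise used in this section can be simulated along the following lines. This is a sketch under stated assumptions rather than the exact protocol: a fixed fraction of training labels is chosen at random, each chosen label is replaced by a class drawn uniformly from the other classes (so a flipped label always differs from the original), and the function name and seed handling are hypothetical.

```python
import random


def flip_labels_symmetric(labels, num_classes, noise_rate, seed=0):
    """Return a copy of `labels` with a fraction `noise_rate` flipped uniformly.

    Assumption: a flipped label is drawn uniformly from the other
    num_classes - 1 classes, so it never equals the original label.
    """
    rng = random.Random(seed)
    num_flips = int(round(noise_rate * len(labels)))
    flip_indices = rng.sample(range(len(labels)), num_flips)
    noisy = list(labels)
    for i in flip_indices:
        new_label = rng.randrange(num_classes - 1)
        if new_label >= labels[i]:
            new_label += 1  # skip over the original class
        noisy[i] = new_label
    return noisy


# Example: 1000 samples over 100 classes with 40% symmetric noise.
clean = [i % 100 for i in range(1000)]
noisy = flip_labels_symmetric(clean, num_classes=100, noise_rate=0.4)
num_changed = sum(a != b for a, b in zip(clean, noisy))
```

Some implementations allow the resampled label to coincide with the original one, in which case the effective noise rate is slightly lower than the nominal rate.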
In this paper, we first showed the characteristics of a student trained with $\mathcal{L}_{KL}$ according to the temperature-scaling hyperparameter $\tau$. As $\tau$ goes to 0, the trained student has the label matching property. In contrast, as $\tau$ goes to $\infty$, the trained student has the logit matching property. Nevertheless, $\mathcal{L}_{KL}$ with a sufficiently large $\tau$ cannot achieve complete logit matching owing to $\delta_\infty$. To achieve this goal, we proposed a direct logit learning framework using $\mathcal{L}_{MSE}$ and improved the performance based on this loss function. In addition, we showed that the model trained with $\mathcal{L}_{MSE}$ followed the teacher’s penultimate layer representations more closely than that trained with $\mathcal{L}_{KL}$. We observed that sequential distillation can be a better strategy when the capacity gap between the teacher and the student is large. Furthermore, we empirically observed that, in the noisy label setting, using $\mathcal{L}_{KL}$ with $\tau$ near 1 mitigates the performance degradation better than extreme logit matching, such as $\mathcal{L}_{KL}$ with a very large $\tau$ or $\mathcal{L}_{MSE}$.
This work was supported by Institute of Information & communications Technology Planning & Evaluation (IITP) grant funded by the Korea government (MSIT) [No.2019-0-00075, Artificial Intelligence Graduate School Program (KAIST)] and [No. 2021-0-00907, Development of Adaptive and Lightweight Edge-Collaborative Analysis Technology for Enabling Proactively Immediate Response and Rapid Learning].
Appendix A Details of formulas
A.1 Gradient of KD loss
A.2 Proof of Proposition 1
A.3 Proof of Equation 8
To prove Eq. (8), we use the bounded convergence theorem (BCT) to interchange the limit and the integral. Namely, it is sufficient to prove that each partial derivative $\frac{\partial \mathcal{L}_{KL}}{\partial z^s_k}$ is bounded in $\tau$.
Since Eq. (14) is bounded, we can utilize the BCT, i.e., . Thus,
Alternatively, as in the preliminary analysis, the authors showed that minimizing $\mathcal{L}_{KL}$ with a sufficiently large $\tau$ is equivalent to minimizing $\mathcal{L}_{MSE}$ under the zero-mean logit assumption on both the teacher and the student. Therefore, from , it is easily derived that minimizing $\mathcal{L}_{KL}$ with a sufficiently large $\tau$ is equivalent to minimizing , where