Margin-Based Regularization and Selective Sampling in Deep Neural Networks

09/13/2020 · by Berry Weinstein, et al. · IDC Herzliya

We derive a new margin-based regularization formulation, termed multi-margin regularization (MMR), for deep neural networks (DNNs). The MMR is inspired by principles that were applied in margin analysis of shallow linear classifiers, e.g., support vector machine (SVM). Unlike SVM, MMR is continuously scaled by the radius of the bounding sphere (i.e., the maximal norm of the feature vector in the data), which is constantly changing during training. We empirically demonstrate that by a simple supplement to the loss function, our method achieves better results on various classification tasks across domains. Using the same concept, we also derive a selective sampling scheme and demonstrate accelerated training of DNNs by selecting samples according to a minimal margin score (MMS). This score measures the minimal amount of displacement an input should undergo until its predicted classification is switched. We evaluate our proposed methods on three image classification tasks and six language text classification tasks. Specifically, we show improved empirical results on CIFAR10, CIFAR100 and ImageNet using state-of-the-art convolutional neural networks (CNNs) and BERT-BASE architecture for the MNLI, QQP, QNLI, MRPC, SST-2 and RTE benchmarks.


1 Introduction

Over the last decade, deep neural networks (DNNs) have become the machine learning method of choice in a variety of applications, demonstrating outstanding performance, often close to or above the human level. Despite their success, some researchers have shown that neural networks can generalize poorly even under small data transformations azulay2018deep, as well as overfit to arbitrarily corrupted data zhang2017multi. Additionally, problems such as adversarial examples szegedy2013intriguing; goodfellow2014explaining, which cause neural networks to misclassify slightly perturbed input data, can be a source of concern in real-world deployment of models. These challenges raise the question as to whether properties that enabled classical machine learning algorithms to overcome these problems can help DNNs resolve similar problems. Specifically, schapire1998boosting introduced margin theory to explain boosting's resistance to over-fitting. Furthermore, the large margin principle, i.e., maximizing the smallest distance from the instances to the classification boundary in the feature space, has played an important role in the theoretical analysis of generalization and helped to achieve remarkable practical results cortes1995support, as well as robustness to input perturbations bousquet2001algorithmic on unseen data. Can the application of the large margin principle in DNNs lead to similar results?

Although computation of the actual margin in the input space of DNNs is intractable, studies show that the widely used cross-entropy loss function is by itself a proxy for converging to the maximal margin soudry2018implicit. To date, this has only been demonstrated for linear models that, similarly to SVM, have a theoretical guarantee of maximal margin convergence rosset2004margin. No such assurance has been offered for non-linear DNNs, which are highly non-convex.

Recently, jiang2019predicting developed a measure for predicting the generalization gap (the difference in accuracy between training and testing performance) in DNNs that leverages the margin distribution (the distribution of distances to the decision boundaries) garg2002generalization as a more robust assessment of the margin notion in DNNs. In their work, they also point out that this measure can be used as an auxiliary loss function to achieve better generalization.

In the present study, we extend the aforementioned ideas and present a novel regularization term, termed Multi-Margin Regularization (MMR), which can be added to any existing loss function in DNNs. We derive the regularization term starting from the binary case of large margin classification and generalize it to the multi-class case. This regularization term aims at increasing the margin induced by the classifiers attained from the true class and its most competitive class. By summing over the margin distribution, we compensate for class imbalance in the regularization term. Furthermore, due to the dynamic nature of the feature space representation when training neural networks, we scale our formulation by the ever-changing maximal norm of the samples in the feature space.

We empirically show that applying this regularizer to the output layer of various DNNs, on different classification tasks and across different domains, is sufficient to obtain a substantial improvement in accuracy. In particular, we achieve meaningful accuracy improvements on numerous image and text classification tasks, including CIFAR10, CIFAR100, ImageNet, MNLI, QQP and more.

Our contribution is thus twofold. Alongside improving generalization performance using a new regularization scheme, we leverage the large margin principle to improve convergence during training using a selective-sampling scheme assisted by a measure we call the Minimal Margin Score (MMS). Essentially, MMS measures the distance to the decision boundary between the two most competitive predicted labels. This measure, in turn, is used to select, at the back-propagation pass, only those instances that accelerate convergence, thus speeding up the entire training process. It is worth noting that our selection criterion is based on computations that are an integral part of the forward pass, thus taking advantage of the "cheaper" inference computations. Lastly, we empirically show that using the MMS selection scheme with a faster learning-rate regime can improve the convergence process to a large extent.

1.1 Previous Approaches

The large margin principle has proven to be fundamentally important in the history of machine learning. While most of the efforts revolved around binary classification, extensions to multi-class classification were also suggested, e.g., the multi-class perceptron (see Kesler's construction, duda1973pattern), multi-class SVM vapnik1998statistical and multi-class margin distribution zhang2017multi. Margin analysis has also been shown to correlate with better generalization properties schapire1998boosting. Of particular interest to our study is the mistake bound for multi-class linear separability that scales with $(R/\gamma)^2$, where $R$ is the maximal norm of the samples in the feature space and $\gamma$ is the margin crammer2003ultraconservative.

Computing the actual margin in DNNs, though, is intractable. soudry2018implicit proved that cross-entropy loss in linear DNNs, together with stochastic gradient descent (SGD) optimization, converges to a maximal margin solution, but it cannot ensure a maximal margin solution in nonlinear DNNs. Sun2015LargeMD affirmed that cross-entropy alone is not enough to achieve the maximal margin in DNNs and that an additional regularization term is needed.

Several works addressed the large margin principle in DNNs. elsayed2018large presented a multi-class linear approximation of the margin as an alternative loss function. They applied their margin-based loss at each and every layer of the neural network. Moreover, their method required a second-order derivative computation due to the presence of first-order gradients in the loss function itself. Explicit computation of the second-order gradients for each layer of the neural network, however, can be quite expensive, especially as DNNs become wider and deeper. To address this limitation, they used a first-order linear approximation to deploy their loss function more efficiently. Later, jiang2019predicting presented a margin-based measure that strongly correlates with the generalization gap in DNNs. Essentially, they measured the difference between the training and the test performance of a neural network using statistics of the margin distribution garg2002generalization. sokolic2017robust used the input layer to approximate the margin via the Jacobian matrix of the network and showed that maximizing their approximation leads to better generalization. In contrast, we show that applying our margin-based regularization to the output layer alone achieves substantial improvements.

In addition to better generalization, we show that the large margin principle can also be used to accelerate the training of DNNs. Accelerating the training process is a long-standing challenge that has already been addressed by quite a few authors bengio2008adaptive; salimans2016weight; goyal2017accurate. Specifically, we seek to highlight faster convergence via selective sampling. To date, the most notable sample selection approach is probably hard negative mining schroff2015facenet, where samples are selected by their loss values. The underlying assumption is that samples with higher losses have a significant impact on the model. Recent works employ selection schemes that examine the importance of the samples alain2015variance; loshchilov2015online. During training, the samples are selected based on their gradient norm, which in turn leads to a variance reduction in the stochastic gradients; see also katharopoulos2018not. Our selection method, though, utilizes uncertainty sampling, where the selection criterion is the proximity to the decision boundary, and we use the MMS measure to score the examples.

2 Margin Analysis for Binary and Multi-Class Classification

Consider a binary classification problem with two classes $y \in \{-1, +1\}$. We denote by $\mathcal{X} \subseteq \mathbb{R}^d$ the input space. Let $f(x) = w^{\top} x + b$ be a linear classifier, where $w \in \mathbb{R}^d$ and $b \in \mathbb{R}$.

The classifier is trained using a set of examples $S = \{(x_i, y_i)\}_{i=1}^{n}$, where each example is sampled identically and independently from an unknown distribution $\mathcal{D}$ over $\mathcal{X} \times \{-1, +1\}$. The goal is to classify correctly new samples drawn from $\mathcal{D}$.

Denote by $\mathcal{L}$ the (linear) decision boundary of the classifier,

$$\mathcal{L} = \{x \,:\, w^{\top} x + b = 0\}. \qquad (1)$$

The geometric distance of a point $x$ from $\mathcal{L}$ is given by

$$d(x, \mathcal{L}) = \frac{|w^{\top} x + b|}{\|w\|}. \qquad (2)$$

For a linearly separable training set, there exist numerous consistent classifiers, i.e., classifiers that classify all examples correctly. Better generalization, however, is achieved by selecting the classifier that maximizes the margin $\gamma$, the smallest distance of the training examples from the decision boundary:

$$\max_{w, b} \; \gamma \quad \text{s.t.} \quad \frac{y_i (w^{\top} x_i + b)}{\|w\|} \ge \gamma, \quad i = 1, \ldots, n.$$

This optimization is redundant with respect to the scale of $w$ and $b$. Imposing $\gamma \|w\| = 1$ removes this redundancy and results in the following equivalent minimization problem cortes1995support:

$$\min_{w, b} \; \frac{1}{2} \|w\|^2 \quad \text{s.t.} \quad y_i (w^{\top} x_i + b) \ge 1, \quad i = 1, \ldots, n.$$

To handle noisy and linearly inseparable data, the set of linear constraints can be relaxed and substituted by the hinge loss,

$$\min_{w, b} \; \frac{1}{2} \|w\|^2 + C \sum_{i=1}^{n} \max\bigl(0,\, 1 - y_i (w^{\top} x_i + b)\bigr). \qquad (3)$$

The left term in Formula 3 is the regularization component, which promotes an increase of the margin between the data points and the decision boundary. The right term of the formula is the empirical risk component, imposing correct classification of the training samples. The two terms exert complementary forces: the former improves the generalization capability while the latter ensures the classification is carried out correctly.
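For concreteness, the following PyTorch sketch evaluates the objective of Formula 3 for a binary linear classifier; the function name and the default value of C are ours, chosen purely for illustration.

```python
import torch

def svm_objective(w, b, X, y, C=1.0):
    """Soft-margin SVM objective of Formula 3 (a sketch; C=1.0 is illustrative).

    X: (n, d) float tensor, y: (n,) float tensor with labels in {-1, +1}.
    """
    scores = X @ w + b                                     # w^T x_i + b
    hinge = torch.clamp(1.0 - y * scores, min=0.0).sum()   # empirical risk term
    margin_term = 0.5 * w.pow(2).sum()                     # promotes a larger margin
    return margin_term + C * hinge
```

Minimizing this objective with any gradient-based optimizer realizes the trade-off described above between a large margin and few violated constraints.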

Next, we extend the large margin principle to the multi-class case. Let us assume we have a classification problem with $K$ classes, $y \in \{1, \ldots, K\}$, and a set of training samples $S = \{(x_i, y_i)\}_{i=1}^{n}$. We now assign a score $s_j(x)$ to each class $j$. For a linear classification, the score of point $x$ is

$$s_j(x) = w_j^{\top} x + b_j, \quad j = 1, \ldots, K.$$

The predicted class is chosen by the maximal score attained over all classes,

$$\hat{y}(x) = \arg\max_{j} s_j(x).$$

For any two classes, $j \ne k$, the decision boundary between these classes is given by (see Figure 1):

$$\mathcal{L}_{jk} = \{x \,:\, s_j(x) = s_k(x)\}.$$

Denoting $w_{jk} = w_j - w_k$ and $b_{jk} = b_j - b_k$, the decision boundary can be rewritten as

$$\mathcal{L}_{jk} = \{x \,:\, w_{jk}^{\top} x + b_{jk} = 0\},$$

which is similar to the binary case in Equation 1, where $w_{jk}$ replaces $w$ and $b_{jk}$ replaces $b$. Similarly to Equation 2, the geometric distance of a point $x$ from $\mathcal{L}_{jk}$ is

$$d(x, \mathcal{L}_{jk}) = \frac{|w_{jk}^{\top} x + b_{jk}|}{\|w_{jk}\|}. \qquad (4)$$
Figure 1: Illustrative example of a bi-class decision boundary.

For point $x_i$ with true class $y_i$, denote by $s_{y_i}(x_i)$ the score for the true class and by $s_{c_i}(x_i)$ the maximal score attained over the non-true classes, i.e., $c_i = \arg\max_{j \ne y_i} s_j(x_i)$. Class $c_i$ is the competitive class vis-à-vis $y_i$. The decision boundary between $y_i$ and its competitive class is $\mathcal{L}_{y_i c_i}$, whose signed geometric distance to $x_i$ is

$$d_i = \frac{w_{y_i c_i}^{\top} x_i + b_{y_i c_i}}{\|w_{y_i c_i}\|}. \qquad (5)$$

Note that $d_i$ is non-negative if the classification is correct ($s_{y_i}(x_i) \ge s_{c_i}(x_i)$) and negative otherwise.

For the multi-class case, Equation 3 can be generalized to the following optimization problem,

$$\min_{\{w_j, b_j\}} \; \frac{1}{2} \sum_{i=1}^{n} \|w_{y_i c_i}\|^2 + C \sum_{i=1}^{n} \max\bigl(0,\, 1 - \bigl(w_{y_i c_i}^{\top} x_i + b_{y_i c_i}\bigr)\bigr), \qquad (6)$$

where the optimization is over $w_j$ and $b_j$, $j = 1, \ldots, K$. Here too, the left-hand term is the regularization penalty while the right-hand term represents the empirical risk with a hinge loss. The regularization term aims to increase the margin between the true class and its competitive class. Note, though, that the summation is over the margin distribution ($i$ is the instance index). If the instances are evenly distributed over the classes, then this is equivalent to a summation over the classes. Otherwise, this summation compensates for class imbalance in the regularization term.
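To make the quantities in Equations 5 and 6 concrete, the following PyTorch sketch computes, for a batch of samples, the signed distance to the decision boundary between the true class and its most competitive class; all names are illustrative and not taken from the paper's code.

```python
import torch

def competitive_margin_distance(x, W, b, y):
    """Signed distance (Eq. 5) of each sample to the boundary between its
    true class and its most competitive class.

    x: (n, d) inputs, W: (K, d) class weights, b: (K,) biases, y: (n,) labels.
    """
    scores = x @ W.t() + b                                  # (n, K) class scores
    true_score = scores.gather(1, y[:, None]).squeeze(1)
    # mask the true class, then take the best competing class
    masked = scores.clone()
    masked.scatter_(1, y[:, None], float('-inf'))
    comp_score, c = masked.max(dim=1)
    w_diff = W[y] - W[c]                                    # (n, d) weight differences
    return (true_score - comp_score) / w_diff.norm(dim=1)   # negative if misclassified
```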

3 Large Margin in DNNs

Applying the above scheme directly to DNNs poses several problems. First, these networks employ a non-linear mapping from the input space into a representation space, $z = f(x; \theta)$, where $\theta$ are the network parameters. The vector $z$ can be interpreted as a feature vector, based on which the last layer of a DNN calculates the scores for each class via a fully-connected layer, $s_j(z) = w_j^{\top} z + b_j$. Maximizing the margin in the input space, as suggested in sokolic2017robust, requires back-propagating derivatives downstream the network up to the input layer and calculating distances to the boundary up to a first-order approximation. In highly non-linear mappings, this approximation loses accuracy very fast as we move away from the decision boundary. Therefore, we apply the large margin principle at the last layer, where the distances to the decision boundary are Euclidean in the feature space:

$$d_i = \frac{w_{y_i c_i}^{\top} z_i + b_{y_i c_i}}{\|w_{y_i c_i}\|}. \qquad (7)$$

The second problem stems from the fact that in Equation 5 the input space is fixed along the course of training, while the feature space in Equation 7 is constantly changing. Accordingly, maximizing the margins in Equation 7 can be trivially attained by scaling up the feature space. Therefore, the feature space must be constrained. In our scheme, we divide Equation 7 by $R = \max_i \|z_i\|$, the maximal norm of the samples in the feature space of the current batch. This ensures that scaling up the feature space will not increase the distance in a free manner. The proposed formulation is translated, similarly to Equation 6, into the following optimization problem,

$$\min_{\{w_j, b_j\}} \; \underbrace{\frac{R^2}{2} \sum_{i=1}^{n} \|w_{y_i c_i}\|^2}_{\mathcal{L}_{\mathrm{MMR}}} + C \sum_{i=1}^{n} \ell(x_i, y_i), \qquad (8)$$

where $\mathcal{L}_{\mathrm{MMR}}$ denotes the margin regularization term and $\ell(x_i, y_i)$ is the empirical risk term. While for SVM the hinge loss is commonly used, in DNNs the common practice is to use cross-entropy,

$$\ell(x_i, y_i) = -\log p_{y_i}(x_i),$$

where $p_{y_i}(x_i)$ is the probability of the true label obtained from the network after the softmax layer:

$$p_{y_i}(x_i) = \frac{e^{s_{y_i}(z_i)}}{\sum_{j=1}^{K} e^{s_j(z_i)}}.$$

Similarly to hinge loss, cross-entropy will strive for correct classification while the regularization term will maximize the margin. For the rest of this paper we denote $\mathcal{L}_{\mathrm{MMR}}$ as the multi-margin regularization (MMR).

Note that the regularization term in this scheme is different from the weight decay commonly applied in deep networks. First, the minimization here is applied over the weight differences $w_{y_i} - w_{c_i}$. Additionally, the regularization term is multiplied by $R^2$. Lastly, the regularization term is applied only at the last layer.
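The sketch below shows one way the resulting training objective (cross-entropy plus MMR, Equation 8) could be computed for the last fully-connected layer in PyTorch. The function name, the use of the detached per-batch maximal feature norm for $R$, and the default trade-off value are our assumptions for illustration, not the authors' released implementation.

```python
import torch
import torch.nn.functional as F

def mmr_objective(features, logits, weights, targets, lam=0.1):
    """Cross-entropy plus the multi-margin regularization (MMR) of Equation 8,
    in a lambda-weighted form; lam=0.1 is an illustrative trade-off value.

    features: (n, d) inputs to the last FC layer, logits: (n, K) class scores,
    weights:  (K, d) last-layer weight matrix, targets: (n,) true labels.
    """
    # most competitive class: highest-scoring non-true class per sample
    masked = logits.clone()
    masked.scatter_(1, targets[:, None], float('-inf'))
    competitors = masked.argmax(dim=1)

    # sum over the margin distribution of ||w_{y_i} - w_{c_i}||^2
    w_diff = weights[targets] - weights[competitors]        # (n, d)
    reg = 0.5 * w_diff.pow(2).sum()

    # R^2: squared maximal feature norm of the current batch (detached here)
    r2 = features.detach().norm(dim=1).max().pow(2)

    return F.cross_entropy(logits, targets) + lam * r2 * reg
```

In practice this objective would simply replace the task loss at each training step, with the trade-off factor tuned by grid search as described in Section 5.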

4 Accelerating Training Using Minimal Margin Score Selection

We continue to leverage the principle of large margin in neural networks to address the computational limitations in real-world applications, specifically in the selective sampling scheme. We show that by selecting samples that are closer to the margin in the multi-class setting during the forward pass, we achieve better convergence and speed-up during the training process. To this end, we evaluate our suggested method on CIFAR10 and CIFAR100 using ResNet-44 he2016deep and WRN-28-10 zagoruyko2016wide architectures. To demonstrate the effectiveness of our selection more vigorously, we apply a faster learning-rate (LR) regime than those suggested in the original papers.

In principle, our selection method is based on the evaluation of the minimal amount of displacement a training sample should undergo until its predicted classification is switched. We call this measure the minimal margin score (MMS). This measure depends on the highest and the second highest scores achieved per sample. Similarly to our margin-based regularization, we apply our measure only to the output layer, and calculate it linearly with respect to the input of the last layer. Additionally, unlike  jiang2019predicting, we do not take into consideration the true label, i.e., our measure is calculated based solely on the highest and the second highest neural network scores.

As shown in Figure 1, consider a multi-class classification problem composed of three classes, Green, Red, and Blue, along with three linear projections, $w_G$, $w_R$, and $w_B$, respectively. The query point $x$ is marked by an empty black circle. The two highest scores of the query point are $w_G^{\top} x$ and $w_R^{\top} x$ (assuming all biases are zero), while the remaining projections are negative (not marked). Since the two highest scores are for the Green and Red classes, the distance of the query point to the decision boundary between these two classes is $d = \frac{(w_G - w_R)^{\top} x}{\|w_G - w_R\|}$. The magnitude of $d$ is the MMS of this query point.

Formally, let $X = \{x_1, \ldots, x_N\}$ be a large set of samples and $z = f(x; \theta)$ be the input to the last layer of the neural network. Assume we have a classification problem with $K$ classes. At the last layer, the classifier consists of $K$ linear functions, $s_j(z) = w_j^{\top} z + b_j$ for $j = 1, \ldots, K$, where each $s_j$ is a linear mapping $s_j : \mathcal{Z} \rightarrow \mathbb{R}$. Denote the sorted scores of $z$ by $s_{j_1}(z) \ge s_{j_2}(z) \ge \ldots \ge s_{j_K}(z)$. The classifier $j_1$ gives the highest score and $j_2$, the second highest score. The decision boundary between classes $j_1$ and $j_2$ is defined as:

$$\mathcal{L}_{j_1 j_2} = \{z \,:\, s_{j_1}(z) = s_{j_2}(z)\}.$$

Using this definition, the confidence of the predicted label of point $z$ is determined by the distance of $z$ to the decision boundary $\mathcal{L}_{j_1 j_2}$. Following Equation 7, it is easy to show that

$$d(z, \mathcal{L}_{j_1 j_2}) = \frac{(w_{j_1} - w_{j_2})^{\top} z + (b_{j_1} - b_{j_2})}{\|w_{j_1} - w_{j_2}\|}. \qquad (9)$$

The distance $d(z, \mathcal{L}_{j_1 j_2})$ is the MMS of point $z$. The larger the MMS, the more confident we are about the predicted label. Conversely, the smaller the MMS, the less confident we are about the predicted label. Therefore, the MMS can serve as a confidence measure for the predicted labels. Accordingly, the best points to select for the back-propagation step are the points whose MMS are the smallest. Note that, in contrast to the MMR, in this case we do not have to normalize the distance by $R$, because this normalization will not change the ordering of the MMS values, and thus the set of selected points will remain unchanged.
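The MMS of Equation 9 can be computed directly from the last-layer weights and features; the PyTorch sketch below is one possible implementation, with names of our choosing.

```python
import torch

def minimal_margin_score(features, weights, biases):
    """MMS (Eq. 9): distance to the boundary between the two top-scoring classes.

    features: (n, d) last-layer inputs, weights: (K, d), biases: (K,).
    Labels are not used; the score depends only on the network's predictions.
    """
    scores = features @ weights.t() + biases             # (n, K) class scores
    top2 = scores.topk(2, dim=1)
    j1, j2 = top2.indices[:, 0], top2.indices[:, 1]       # top-1 and top-2 classes
    w_diff = weights[j1] - weights[j2]                    # (n, d)
    b_diff = biases[j1] - biases[j2]
    return ((w_diff * features).sum(dim=1) + b_diff) / w_diff.norm(dim=1)
```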

Our implementation consists of a generic, yet simple, online selective sampling method, applied at the beginning of each training step. Specifically, at each training step, we first apply a forward pass on a batch of $B$ points and obtain their respective scores. We then calculate their respective MMS measures and select the $b$ samples ($b < B$) whose MMS measures are the smallest. The resulting batch of size $b$, in turn, is used for training the network. The selection process is repeated every training step, thus potentially selecting a new batch of points for training. The MMS-based training procedure is summarized in Algorithm 1.

0:  Inputs: large batch size B, selected batch size b (b < B)
1:  initialize the model parameters
2:  repeat
3:     forward pass a batch of size B
4:     calculate the MMS of every sample (Equation 9)
5:     store the b smallest scores
6:     form the subset of the batch of size b
7:     back-propagate on the batch of size b
8:     update the model parameters
9:  until reaching final model accuracy
Algorithm 1 MMS-based training
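A minimal PyTorch sketch of one training step following Algorithm 1 is given below. It reuses the minimal_margin_score helper sketched above; the split of the model into a features extractor and a final classifier layer is an assumption made for illustration and would need to be adapted to the actual architecture.

```python
import torch

def mms_training_step(model, optimizer, criterion, inputs, targets, b):
    """One MMS-based training step (Algorithm 1 sketch): forward a large batch,
    keep the b samples closest to their decision boundary, train on those.

    Assumes model.features(...) returns the input to the last FC layer and
    model.classifier is that final linear layer.
    """
    with torch.no_grad():                          # selection uses only the forward pass
        feats = model.features(inputs)
        W, bias = model.classifier.weight, model.classifier.bias
        mms = minimal_margin_score(feats, W, bias)
    selected = mms.argsort()[:b]                   # b smallest MMS values

    optimizer.zero_grad()
    loss = criterion(model(inputs[selected]), targets[selected])
    loss.backward()                                # back-propagate only the selected subset
    optimizer.step()
    return loss.item()
```

Because the selection relies only on forward-pass quantities, the extra cost per step is a single large-batch inference, in line with the "cheaper" computation argument above.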

5 Experiments

In this section, we report on the series of experiments we designed to evaluate the MMR's ability to achieve a higher accuracy score, and the MMS selection method's ability to achieve faster convergence than the original training algorithms (the baseline) and data augmentation. The experiments were conducted on commonly used datasets and neural network models in the vision and natural language processing (NLP) realms. (All experiments were conducted using PyTorch; the code will be released on GitHub upon acceptance of the paper.)

Model                                                 Dataset    Baseline   Our MMR   Change
ResNet-44 he2016deep                                  CIFAR10    93.22%     93.83%    9.00%
VGG simonyan2014very                                  CIFAR10    93.19%     93.34%    2.20%
WRN-28-10 + auto-augment + cutout zagoruyko2016wide   CIFAR100   82.51%     83.52%    5.77%
VGG + auto-augment + cutout                           CIFAR100   73.93%     74.19%    1.00%
MobileNet howard2017mobilenets                        ImageNet   71.17%     71.44%    0.94%
BERTBASE devlin2018bert                               QNLI       91.06%     91.48%    4.70%
BERTBASE                                              SST-2      92.08%     92.43%    4.42%
BERTBASE                                              MRPC       90.68%     91.43%    8.05%
BERTBASE                                              RTE        68.23%     69.67%    4.53%
BERTBASE                                              QQP        87.9%      88.04%    1.16%
BERTBASE                                              MNLI       84.5%      84.70%    1.29%
Table 1: Test accuracy results (Top-1 for the CIFAR10/100 datasets). The relative change in error over the baseline is listed in percent. F1 scores are reported for QQP and MRPC. For MNLI, we report the average of the matched and mismatched results for both the baseline and our MMR.

Our experimental workbench is composed of CIFAR10, CIFAR100 krizhevsky2009learning and ImageNet imagenet_cvpr09 for image classification; Question NLI (QNLI) wang2018glue, MultiNLI (MNLI) williams2017broad and Recognizing Textual Entailment (RTE) bentivogli2009fifth for natural language inference; MSR Paraphrase Corpus (MRPC) dolan2005automatically and Quora Question Pairs (QQP) chen2018quora for sentence similarity; and Stanford Sentiment Treebank-2 (SST-2) socher2013recursive for text classification.

5.1 Image Classification

CIFAR10 and CIFAR100.

These image classification datasets consist of color images from 10 or 100 classes, with 50k training examples and 10k test examples. The last 5k images of the training set are used as a held-out validation set, as is common practice. For our experiments, we used the ResNet-44 he2016deep and WRN-28-10 zagoruyko2016wide architectures. We applied the original hyper-parameters and training regime using a batch size of 64. In addition, we used the original augmentation policy as described in he2016deep for ResNet-44, while adding cutout devries2017improved and auto-augment cubuk2018autoaugment for WRN-28-10. Optimization was performed for 200 epochs, after which baseline accuracy was obtained with no apparent improvement.

ImageNet.

For large-scale evaluation, we used the ImageNet dataset imagenet_cvpr09, containing more than 1.2M images in 1k classes. In our experiments, we used the MobileNet howard2017mobilenets architecture and followed the training regime established by goyal2017accurate, in which an initial LR of 0.1 is decreased by a factor of 10 at epochs 30, 60, and 80, for a total of 90 epochs. We used a base batch size distributed over four devices, regularization over the weights of the convolutional layers, and the standard data augmentation.

Improving accuracy via MMR


The MMR was added to the objective function as an additional regularization term weighted by a trade-off factor $\lambda$ between the cross-entropy loss and the regularization (this formulation is equivalent to Equation 8 with $\lambda = 1/C$; it is preferred because it multiplies the regularization term by a small number while keeping the scaling factor of the cross-entropy loss at 1, thus avoiding gradient enlargement):

$$\mathcal{L} = \sum_{i} \ell_{\mathrm{CE}}(x_i, y_i) + \lambda\, \mathcal{L}_{\mathrm{MMR}}.$$

To find the optimal $\lambda$, we used a grid search and found that a linear scaling of $\lambda$ over training works best for CIFAR10/100, while a static $\lambda$ works best for ImageNet.

Figure 2: Training (dashed) and validation errors on CIFAR100 using the WRN-28-10 network, comparing baseline training and our MMR approach. We use a linearly scaled $\lambda$.

Table 1 summarizes our final results for increasing model accuracy on CIFAR10 and CIFAR100. Specifically, we improved baseline accuracy from 93.22% to 93.83% with ResNet-44 and from 93.19% to 93.34% with VGG, corresponding to relative changes in error of 9.00% and 2.20%, respectively, on the CIFAR10 dataset. Furthermore, we show a substantial decrease of 5.77% in the error for CIFAR100 using the WRN-28-10 model (see Figure 2), raising its absolute accuracy by more than 1%. Altogether, we observed a consistent decrease in error on all datasets.

Convergence speedup via MMS selective sampling


To test our hypothesis that using the MMS selection method we could accelerate training while preserving final model accuracy, we designed a new, more aggressive learning-rate drop regime than the one used by the authors of the original paper. Figure 3 presents empirical evidence supporting our hypothesis. We compared the results of our MMS method against random selection (referred to as the baseline, with and without an early LR drop), and against hard-negative mining that prefers samples with low prediction scores yu2018loss; schroff2015facenet. For the latter, we used the implementation suggested by hofferinfer2train, termed "NM-sample", where the cross-entropy loss is used for the selection.

Figure 3: Validation errors of ResNet-44 on CIFAR10 (top) and WRN-28-10 on CIFAR100 (bottom). We compared the baseline training, NM-sample selection (hard negative mining), and our MMS selection method using a faster regime. We plotted the regular-regime baseline's final errors as a dotted line for perspective. The MMS selection method achieves on-par final test accuracy using fewer training steps.

For CIFAR10 and ResNet-44, we used the original LRs while decreasing them at earlier steps, with a batch of size 64. As depicted in Figure 3 (top), our selection method indeed yields validation accuracy extremely close to the one reached by the baseline training scheme, with considerably fewer training steps and only a minor accuracy drop compared to the baseline. We also applied the early drop regime to the baseline configuration as well as to the NM-samples; both failed to reach the desired model accuracy and suffered a noticeable degradation.

Similarly, we applied the early LR drop scheme for CIFAR100 and WRN-28-10, decreasing the LRs at earlier steps with a batch of size 64. As depicted in Figure 3 (bottom), MMS reached an accuracy only slightly below the baseline while almost halving the number of training steps. On the other hand, the baseline and the NM-sample schemes failed to reach the desired accuracy after we applied a similar early drop regime. The degradation was most significant for the NM-sample approach, while the baseline drop was smaller.

These results are in line with the main theme of selective sampling that strives to focus training on more informative points. Training loss, however, can be a poor proxy for this concept. For example, the NM-sample selection criterion favors high loss scores, which obviously increases the training error, while our MMS approach selects uncertain points, some of which might be correctly classified. Others might be mis-classified by a small margin, but they are all close to the decision boundary, and hence useful for training.

5.2 Natural Language Classification Tasks

To challenge our premise that we could achieve a higher accuracy score, we examined our MMR on an NLP-related model and datasets. In particular, we used the BERTBASE model devlin2018bert with 12 transformer layers, a hidden size of 768 and 12 self-attention heads. Fine-tuning was performed using the Adam optimizer, as in the pre-training, with a dropout probability of 0.1 on all layers. Additionally, we trained with the same LR for three epochs in total on all the tasks. We used the original WordPiece embeddings wu2016google with a 30k token vocabulary. For our method, similarly to the image classification tasks, we also used the trade-off factor $\lambda$ in the objective function and found the optimal value via a grid search (a separate value was applied only when evaluating our method's accuracy on the mismatched MNLI).

We performed experiments on a variety of supervised tasks, specifically by applying downstream task fine-tuning on natural language inference, semantic similarity, and text classification. All these tasks are available as part of the GLUE multi-task benchmark wang2018glue.

Natural Language Inference

In the task of natural language inference (NLI), or recognizing textual entailment, the classifier is given a pair of sentences and must decide whether one entails or contradicts the other. Although there has been a lot of progress, the task remains challenging due to the wide variety of phenomena involved, such as lexical entailment, coreference, and lexical and syntactic ambiguity. We evaluate our scheme on three NLI datasets taken from different sources: transcribed speech, popular fiction, and government reports (MNLI), Wikipedia articles (QNLI), and news articles (RTE).

As shown in Table 1, our scheme using the regularization term outperformed the baseline results on all three tasks. We achieved an absolute improvement of up to 1.44% on RTE, corresponding to a relative change in error of 4.53%. On QNLI and MNLI we also achieved higher scores of 91.48% and 84.70%, outperforming the baseline results by relative error changes of 4.70% and 1.29%, respectively.

Semantic Similarity

This task involves predicting whether two sentences are semantically equivalent by identifying similar concepts in both sentences. It can be challenging for a language model to resolve syntactic and morphological ambiguity, as well as to recognize the same idea expressed in different words, or different ideas expressed in similar words. We evaluated our approach on the QQP and MRPC downstream tasks, outperforming the baseline results, as can be seen in Table 1. On MRPC in particular, we achieved a relative change in error of more than 8%.

Text Classification

Lastly, we evaluated our method on the Stanford Sentiment Treebank (SST-2), a binary single-sentence classification task consisting of sentences extracted from movie reviews with human annotations of their sentiment. Our approach outperformed the baseline by a relative error change of 4.42%.

Overall, applying MMR boosted the accuracy in all the reported tasks. This indicates that our approach works well for different tasks from various domains.

6 Discussion

We studied a multi-class margin analysis for DNNs and used it to devise a novel regularization term, the multi-margin regularization (MMR). Similarly to previous formulations, the MMR aims at increasing the margin induced by the classifiers, and it is derived directly, for each sample, from the true class and its most competitive class. The main difference between the MMR and common regularization terms is that MMR is scaled by $R$, the maximal norm of the samples in the feature space. This ensures a meaningful increase in the margin that is not induced by a simple scaling of the feature space. Additionally, weight differences are minimized rather than the commonly used determinant or other norms of the weight matrix. Lastly, MMR is formulated over the margin distribution to compensate for class imbalance in the regularization term. The MMR can be incorporated with any empirical risk loss and is not restricted to the hinge or cross-entropy losses. Indeed, using MMR, we demonstrated improved accuracy over a set of experiments on images and text.

Additionally, the multi-class margin analysis enables us to propose a selective sampling method designed to accelerate the training of DNNs. Specifically, we utilized uncertainty sampling, where the criterion for selection is the distance to the decision boundary. To this end, we introduced a novel measurement, the minimal margin score (MMS), which measures the minimal amount of displacement an input should undergo until its predicted classification is switched. For multi-class linear classification, the MMS measure is a natural generalization of the margin-based selection criterion.

Our selection criterion was inspired by active learning methods, but our goal, to accelerate training, is different. Active learning is mainly concerned with labeling cost. Hence, it is common to keep training until convergence before turning to select additional examples to label. When the goal is merely acceleration, labeling cost is not a concern, and one can adopt a more aggressive protocol and re-select a new batch of examples at each training step.

The MMS measure does not use the labels. Thus, it can be used to select samples in an active learning setting as well. Similarly to jiang2019predicting, the MMS measure can be implemented at other layers in the deep architecture. This enables the selection of examples that directly impact training at all levels. The additional computation associated with such a framework makes it less appealing for the purpose of acceleration; for active learning, however, it may introduce an additional gain. The design of a novel active learning method is left for further study.

References