Stochastic Batch Augmentation with An Effective Distilled Dynamic Soft Label Regularizer

06/27/2020 ∙ by Qian Li, et al. ∙ Xi'an Jiaotong University, NetEase, Inc., The Hong Kong University of Science and Technology

Data augmentation has been used intensively in training deep neural networks to improve generalization, whether in the original space (e.g., image space) or a representation space. Although successful, these approaches largely ignore the connection between the synthesized data and the original data during training: they do not exploit the distributional fact that the synthesized samples surround the original sample, so the behavior of the network is not optimized for it. Yet that behavior is crucial for generalization, and even, in the adversarial setting, for the safety of the deep learning system. In this work, we propose a framework called Stochastic Batch Augmentation (SBA) to address these problems. SBA stochastically decides whether to augment at each iteration, controlled by a batch scheduler, and introduces a "distilled" dynamic soft label regularization that incorporates the similarity within the vicinity distribution with respect to raw samples. The proposed regularization provides direct supervision via the KL-divergence between the output softmax distributions of the original and virtual data. Our experiments on CIFAR-10, CIFAR-100, and ImageNet show that SBA can improve the generalization of neural networks and speed up the convergence of network training.

1 Introduction

For quite a few years, deep learning systems have persistently enabled significant improvements in many application domains, such as object recognition in vision, speech, and language, and are now widely used in both research and industry. However, these systems perform well only when evaluated on instances very similar to those from the training set. When evaluated on slightly different distributions, neural networks often produce incorrect predictions with strikingly high confidence. This is a worrying prospect, since deep learning systems are increasingly being deployed in settings where the environment is noisy, subject to domain shift, or even adversarial.

Recent research to address these issues has focused on data augmentation [19, 15, 16], which produces new valid instances in the input or latent space. Although successful, these approaches have one limitation: they largely ignore the connection between the synthesized data and the original data, since the two are used for training without interaction. Consequently, the model's behavior is not explicitly optimized for the fact that the synthesized samples surround the original sample.

To this end, we propose a dynamic soft label regularization on the augmented batch that achieves this optimization: it speeds up the convergence of network training by leveraging the intermediate model before the final iteration, and it improves the generalization ability of the ultimate model. Further, to reduce the redundant and high computational cost of Batch Augmentation (BA) [7] and to allow the model to reach strong generalization in a short time, we introduce stochasticity into batch augmentation, inspired by epsilon-greedy exploration in reinforcement learning. Finally, we propose a framework dubbed Stochastic Batch Augmentation (SBA), in which batch augmentation is performed in latent space at randomly selected iterations instead of at every iteration. Moreover, to incorporate the prior knowledge that the derived virtual samples are strongly similar to the original samples, the similarity is characterized by the KL-divergence between their predicted distributions under the neural network.

The main contributions of the paper include the following:

  • It proposes a general framework named Stochastic Batch Augmentation (SBA), which is composed of two major ingredients: a stochastic batch scheduler and a distilled dynamic soft label regularization.

  • For the first time, it proposes to employ the output distribution of the reference data as a soft label to guide the neural network to fit and recognize the vicinity of the reference data. In this way, we can speed up the learning process and achieve better generalization than other methods.

  • Experimental results demonstrate that SBA can significantly improve the generalization of the neural network. It can often achieve even better prediction performance thanks to a majority vote among the predictions on the virtual and original data at test time.

2 Related Work

Figure 1: Overview of our framework, Stochastic Batch Augmentation. (1) The stochastic batch scheduler is a Bernoulli process whose observed value decides whether to perform augmentation at each iteration; (2) the distilled dynamic soft label is used to guide the model to fit the vicinity of the raw sample, so that the model can achieve better generalization quickly.

A common practice in training modern neural networks is to use data augmentation: multiple instances of input samples, each with a different transformation applied to it. Common forms of augmentation include random crops, horizontal flipping, and color augmentation for images [1], enhancing generalization to translation, reflection, and illumination, respectively. Closely related to the topic of our work is batch augmentation (BA) [7]: replicating instances of samples within the same batch with different data augmentations.

Data augmentation has repeatedly been found to provide efficient and useful regularization, often accounting for a significant portion of the final generalization performance. Common regularization approaches include dropout [13] and the information bottleneck [17], which regularize deep networks by perturbing their hidden representations. Another related regularization technique, Mixup, was introduced by [19]: it trains on mixed inputs built from two separate samples of different classes and uses their labels, mixed in the same proportion, as the target.

3 Methodology

In this section, we present the framework named Stochastic Batch Augmentation (SBA), where the mini-batch is augmented at randomly selected epochs. The framework is illustrated in Figure 1. The network is equipped with a scheduler whose status is stochastically updated at the beginning of each epoch; that scheduler decides whether to perform batch augmentation (BA) in the new epoch. During training, the intermediate output distribution of the raw sample is used to construct a supervision signal via the Kullback-Leibler divergence with respect to the predicted distribution of the virtual point, which takes advantage of the informative prediction of the raw point.

3.1 Stochastic Batch Augmentation

Suppose we have a dataset $\mathcal{D} = \{(x_i, y_i)\}_{i=1}^{n}$ with features $x_i$ and labels $y_i \in \{1, \dots, C\}$ for a $C$-class classification problem. Denote a classifier by $f_\theta$, with $\theta$ being the shared parameters. The standard training goal is to obtain the classifier $f_\theta$ by minimizing the average of a loss function $\ell$ over the underlying data distribution $P(x, y)$, following the principle of expected risk minimization:

$$R(f_\theta) = \mathbb{E}_{(x, y)\sim P}\big[\ell(f_\theta(x), y)\big]. \tag{1}$$

Since the distribution $P$ is unknown, Equation (1) is typically approximated by empirical risk minimization (ERM) on the collected samples, as follows:

$$R_{\mathrm{erm}}(f_\theta) = \frac{1}{n}\sum_{i=1}^{n} \ell\big(f_\theta(x_i), y_i\big). \tag{2}$$

While efficient to compute, the empirical risk monitors the behavior of $f_\theta$ only at a finite set of examples, leading to over-fitting and sample memorization in large neural networks [14]. Hence, we follow the principle of vicinal risk minimization (VRM) [2] during training, aiming to improve adversarial robustness, i.e., to preserve label consistency in a small perturbation neighborhood. The VRM principle minimizes the empirical risk on virtual data pairs $(\tilde{x}, \tilde{y})$ sampled from a vicinity distribution $\mathcal{V}$ generated from the original training set distribution; consequently, the VRM-based training objective can be described as:

$$R_{\mathrm{vrm}}(f_\theta) = \frac{1}{n}\sum_{i=1}^{n} \mathbb{E}_{(\tilde{x}_i, \tilde{y}_i)\sim \mathcal{V}(\cdot \mid x_i, y_i)}\big[\ell(f_\theta(\tilde{x}_i), \tilde{y}_i)\big]. \tag{3}$$

However, Equation (3) cannot directly take effect in our case because of the batch scheduler; in general, the role of the scheduler should be included in our optimization objective. Before that, we describe the workflow of the framework. First, select a random layer $k$ from a set of eligible layers in the neural network, then process the input batch until reaching that layer. Whether to perform batch augmentation at layer $k$ depends on the status of the internal scheduler, whose evolution is described by a Bernoulli discrete-time stochastic process $\{S_t\}$ with fixed probability $p$; $S_t$ takes only two values, canonically $0$ and $1$, corresponding to non-augmentation and augmentation respectively. Finally, processing continues from the hidden state to the output.

More formally, we can rewrite our neural network as $f_\theta = g_k \circ h_k$, where $h_k$ denotes the mapping from an input sample to its internal DNN representation at layer $k$, and $g_k$ denotes the part mapping such a hidden representation at layer $k$ to the output. Let $B$ denote the input batch and $H_k = h_k(B)$ the activations of the $k$-th layer. The augmented batch of $H_k$ is $\hat{H}_k$. Assuming that each sample in $H_k$ and $\hat{H}_k$ is i.i.d., our optimization objective can be expressed as:

(4)

where $S_t$ represents the status of the scheduler at iteration $t$. It is worth noting that the stochastic optimization of Equation (4) is biased towards either the vicinity distribution or the original distribution, as determined by the probability $p$.
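For concreteness, a minimal PyTorch sketch of this $f_\theta = g_k \circ h_k$ split is given below; the class name, the fixed layer index $k$, and the `augment_fn` hook are illustrative assumptions rather than the authors' implementation.

```python
import torch
import torch.nn as nn

class SBANet(nn.Module):
    """Wrap a list of layers so augmentation can be injected at the k-th hidden layer."""
    def __init__(self, layers, k: int):
        super().__init__()
        self.h_k = nn.Sequential(*layers[:k])    # input  -> hidden representation at layer k
        self.g_k = nn.Sequential(*layers[k:])    # hidden -> output logits

    def forward(self, x, augment_fn=None):
        h = self.h_k(x)                               # H_k: activations of the k-th layer
        if augment_fn is not None:                    # scheduler decided S_t = 1
            h = torch.cat([h, augment_fn(h)], dim=0)  # append virtual points \hat{H}_k
        return self.g_k(h)
```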

3.2 Stochastic Batch Scheduler

The essence of the batch scheduler is the generation of the random variable $S_t$, which decides whether to perform augmentation at each iteration. For simplicity, we adopt a Bernoulli process, $S_t \sim \mathrm{Bernoulli}(p)$, $S_t \in \{0, 1\}$, where $p$ controls the skewness of the distribution. For the Bernoulli process, the probability of performing augmentation at each iteration is fixed and predefined, which is the simplest choice compared with constructing a delicately varying probability over iterations in the manner of a learning rate scheduler.
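As an illustration, such a scheduler takes only a couple of lines; the function name and the use of `torch.bernoulli` below are our own choices for this sketch.

```python
import torch

def stochastic_batch_scheduler(p: float) -> bool:
    """Sample S_t ~ Bernoulli(p); True means 'perform augmentation at this iteration'."""
    return bool(torch.bernoulli(torch.tensor(float(p))).item())
```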

The idea is motivated by epsilon-greedy exploration, a simple heuristic in reinforcement learning (RL). In RL, exploration versus exploitation is a well-known trade-off. Exploitation only uses the learned (possibly sub-optimal) policy to take an action at the $t$-th step; in neural network training, that corresponds to the weights and biases learned up to the $t$-th epoch. Exploration, in contrast, encourages the agent to explore the environment it faces. In our case, we model the data manifold at the $k$-th layer as an environment with uncertainty. Thus, the currently trained neural network model is the agent, and the exploration strategy is our stochastic batch scheduler.

The stochastic scheduler makes our scheme crucially different from [7, 15, 11], in which augmentation is deterministically performed at every iteration, ignoring the generalization capability the network has already learned, i.e., the intermediate parameter values: for example, the weights after 50 of 200 total epochs already generalize better than randomly initialized weights, though worse than the final converged weights. Consequently, those methods incur a relatively high computational cost despite the performance improvement; in particular, when the selected $k$-th layer is close to the input layer, the extra computational overhead is especially pronounced. Note that $p = 0.5$ yields a fair schedule with respect to the original and augmented mini-batch: $H_k$ and $\hat{H}_k$ have the same probability of being transformed by the latter part of the network, $g_k$.

3.3 Dynamic Soft Label Regularization

The $k$-th layer input batch $H_k \in \mathbb{R}^{m \times d_k}$ is the reference of the virtual batch $\hat{H}_k$, which is sampled from the vicinity distribution $\mathcal{V}$; here $m$ is the batch size and $d_k$ is the width of the $k$-th layer. In order to robustly cover the underlying local regime of each instance in $H_k$, we combine truncated Gaussian noise [4] and Dropout [13] to construct a mixed vicinity distribution $\mathcal{V}$. Given a reference point $z \in H_k$, the corresponding sets of virtual points, $\hat{Z}_g$ (Gaussian) and $\hat{Z}_d$ (Dropout), can be represented as below:

$$\hat{Z}_g = \big\{\hat{z}^{(g)}_1, \dots, \hat{z}^{(g)}_{M_1}\big\}, \tag{5}$$
$$\hat{Z}_d = \big\{\hat{z}^{(d)}_1, \dots, \hat{z}^{(d)}_{M_2}\big\}, \tag{6}$$

where $M_1$ and $M_2$ are the numbers of augmented instances of each kind, respectively, and $M = M_1 + M_2$ is the total augmentation fold. Specifically, the two kinds of virtual points can be written as:

$$\hat{z}^{(g)} = z + W\,\mathrm{clip}(\epsilon, -\delta, \delta), \tag{7}$$
$$\hat{z}^{(d)} = z \odot b, \tag{8}$$

where $b$ is a binary vector whose zero entries denote the deletion of a feature and $\odot$ denotes the element-wise product. $\epsilon$ is a vector of zero-mean independent Gaussian noise, clipped to the range $[-\delta, \delta]$, where $\delta$ is the maximum scale of each component of the generated noise. Note that we adopt a random normalized basis matrix $W$ for random-direction projection, in order to diversify the derived points; that basis matrix is regenerated only when the random variable $S_t$ is set to one, which avoids unnecessary basis-matrix updates in iterations without augmentation.
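A minimal sketch of the two kinds of virtual points is given below, assuming the hidden activations form a 2-D tensor; the noise scale `sigma`, the drop probability, and the exact way the clipped noise is projected onto the random normalized basis are illustrative assumptions, not the authors' exact parameterization.

```python
import torch

def random_normalized_basis(d: int, device=None) -> torch.Tensor:
    """Random basis with unit-norm columns, regenerated whenever S_t = 1."""
    w = torch.randn(d, d, device=device)
    return w / w.norm(dim=0, keepdim=True)

def gaussian_virtual(z: torch.Tensor, basis: torch.Tensor,
                     sigma: float = 0.1, delta: float = 0.3) -> torch.Tensor:
    """Eq. (7)-style point: reference plus clipped Gaussian noise projected on random directions."""
    eps = torch.clamp(sigma * torch.randn_like(z), -delta, delta)  # truncated Gaussian noise
    return z + eps @ basis.t()

def dropout_virtual(z: torch.Tensor, drop_prob: float = 0.1) -> torch.Tensor:
    """Eq. (8)-style point: element-wise product with a binary deletion mask."""
    mask = (torch.rand_like(z) > drop_prob).float()
    return z * mask
```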

This contrasts with [18], which treats the virtual points as a uniform distribution within an $r$-radius ball around the reference point and then uses the one-hot encoding of the ground-truth label $y$ of the reference point to compute a hard loss for the derived virtual points with respect to their output distribution. More concretely, it can be expressed as

$$\mathcal{L}_{\mathrm{hard}} = \mathcal{H}\big(y, \, Q_{\hat{z}}\big), \tag{9}$$

which is the standard cross-entropy classification loss for a virtual point $\hat{z}$ derived from the reference point, where $Q_{\hat{z}}$ is the output distribution of the virtual point and $\mathcal{H}$ is the cross-entropy function.

We are motivated by the paradigm of knowledge distillation, i.e., teacher-student training [6], and by adversarial training, in which a network should behave regularly in the neighborhood of training data, whether in the input space or the latent space. Hence, we propose to treat the softmax output of the reference point as a "soft" label (the standard notion of a soft label is the output prediction of another trained, larger model with high generalization). This soft label carries more information for modeling the relation between the reference and derived data, and it acts as a hint for the network itself, as shown in Figure 1.

Unlike conventional knowledge distillation, in which the soft label for a weaker student model comes from a wiser teacher, the soft label in our method is provided by the student itself, which also plays the role of the teacher owing to its increasing generalization during training, as indicated by the learning curve. In this sense, we can view the continually updated parameters as the maturing recognition ability of the classifier accumulated over previous training iterations.

1:  for t = 1 to T do
2:     Sample S_t from the Bernoulli process with probability p
3:     Update parameters θ by the primary objective
4:     if S_t = 1 then
5:        Update parameters θ by the conservative constraint
6:     end if
7:  end for
Algorithm 1: Training procedure of SBA

We exploit this soft label to guide the model to fit and recognize the virtual points more quickly, leading to faster convergence; by doing so, the eventually trained model also becomes less sensitive to adversarial perturbations. Note that the soft label changes dynamically during training, since the learning parameters are continuously updated across iterations. From another perspective, it explicitly incorporates the similarity between the reference and virtual points in the local vicinity, where the relevant notion of similarity is based on the output distribution of the neural network. Thus, the soft loss of the virtual points can be represented by:

$$\mathcal{L}_{\mathrm{soft}} = D_{\mathrm{KL}}\big(P_z \,\|\, Q_{\hat{z}}\big), \tag{10}$$

where $D_{\mathrm{KL}}$ is the Kullback–Leibler divergence, a measure of how one probability distribution differs from a reference probability distribution, $P_z$ is the output distribution of the reference point $z$, and $Q_{\hat{z}}$ is the output distribution of the virtual point $\hat{z}$. Note that this method can be regarded as a conservative constraint on the predicted distributions of the neural network in the local area of each point in the transformed space.
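A possible PyTorch sketch of this soft loss is shown below; pairing each virtual point with its (possibly repeated) reference logits and detaching the reference distribution so that it acts purely as a soft label are our own assumptions.

```python
import torch
import torch.nn.functional as F

def distilled_soft_label_loss(logits_ref: torch.Tensor,
                              logits_virtual: torch.Tensor) -> torch.Tensor:
    """KL(P_ref || Q_virtual) between softmax outputs of paired reference/virtual points."""
    p_ref = F.softmax(logits_ref, dim=1).detach()         # dynamic "soft" label from the network itself
    log_q = F.log_softmax(logits_virtual, dim=1)
    return F.kl_div(log_q, p_ref, reduction="batchmean")  # D_KL(P || Q), averaged over the batch
```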

3.4 Learning and Inference

Training minimizes Equation (4). Because of the induced run-time stochasticity of mini-batch augmentation, the per-iteration objective and parameter update differ in form and meaning depending on the observed value of $S_t$ at iteration $t$. With the help of the indicator function $\mathbb{1}[\cdot]$, the optimization problem at iteration $t$ can be written as follows:

$$\min_\theta \; \frac{1}{m}\sum_{i=1}^{m} \ell\big(f_\theta(x_i), y_i\big) \;+\; \lambda\,\mathbb{1}[S_t = 1]\,\frac{1}{mM}\sum_{i=1}^{m}\sum_{j=1}^{M} D_{\mathrm{KL}}\big(P_{z_i} \,\|\, Q_{\hat{z}_{i,j}}\big), \tag{11}$$

where $\lambda$ is the constant balancing the two terms in Equation (11), which can easily be solved by any SGD algorithm; $P_{z_i}$ is the output distribution of the reference point and $Q_{\hat{z}_{i,j}}$ is the output distribution of its $j$-th virtual point. The equation consists of two components: the first minimizes the objective on the raw mini-batch, giving the raw mini-batch the larger weight and hence full utilization; the second is a similarity constraint that limits dramatic changes of the predictions in the local vicinity of the reference point and is dropped when no augmentation is performed (i.e., $S_t = 0$).
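Putting the pieces together, one SBA training iteration could be sketched as below; the split model with `h_k`/`g_k` attributes, the hyper-parameter values `p` and `lam`, and the `augment_fn` hook are assumptions carried over from the earlier sketches, not the authors' released code.

```python
import torch
import torch.nn.functional as F

def sba_training_step(model, x, y, optimizer, augment_fn, p=0.5, lam=0.1):
    """One training iteration sketching Equation (11)."""
    s_t = bool(torch.bernoulli(torch.tensor(p)).item())   # stochastic batch scheduler
    h = model.h_k(x)                                       # hidden representation at layer k
    logits_ref = model.g_k(h)
    loss = F.cross_entropy(logits_ref, y)                  # primary objective on the raw mini-batch
    if s_t:                                                # S_t = 1: add the conservative constraint
        logits_virtual = model.g_k(augment_fn(h))
        p_ref = F.softmax(logits_ref, dim=1).detach()      # dynamic soft label
        loss = loss + lam * F.kl_div(F.log_softmax(logits_virtual, dim=1),
                                     p_ref, reduction="batchmean")
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```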

Figure 2: Performance under different values of the hyper-parameter on CIFAR-10 (above) and CIFAR-100 (below).
Figure 3: Performance comparison of SBA and BA on the CIFAR-10 (above) and CIFAR-100 (below) datasets using VGGNet-16. X-axis: the selected $k$-th layer; Y-axis: computed cost ratio (left), test accuracy (right).

To further understand the effect of the second term, consider its gradient with respect to the learning parameters $\theta$:

(12)

where the superscript indexes a component of the output vector and the small quantity is the difference between the virtual point's and the reference point's predicted probabilities, which must be nearly identical, conforming to the principle of adversarial training mentioned in the previous section. Then, applying a linear (first-order) approximation:

(13)

Equation (12) can then be simplified as:

(14)

In Equation (14), the inner term depicts the inconsistency of the network's behaviour in the neighbourhood of a sample, so in this sense the whole expression is the average inconsistency over the raw mini-batch. This supervisory signal leverages the complementary information in the predicted distribution [3] to minimize the inconsistency, in contrast to the primary objective, which exploits only the information from the ground-truth class for supervision:

$$\mathcal{L}_{\mathrm{primary}} = -\frac{1}{m}\sum_{i=1}^{m} \log P_{z_i}^{y_i}, \tag{15}$$

where the superscript $y_i$ denotes the true class of sample $x_i$. Algorithm 1 describes the training mechanism.

At the inference stage, we use a simple majority-vote strategy over the predictions on the test sample and its virtual points to decide the final outcome.
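A minimal sketch of such a vote is given below; the number of virtual copies and the tie-breaking behaviour of `torch.mode` are illustrative assumptions, not the authors' exact procedure.

```python
import torch

@torch.no_grad()
def majority_vote_predict(model, x, augment_fn, num_virtual: int = 3) -> torch.Tensor:
    """Predict by majority vote over the raw sample and its virtual copies."""
    h = model.h_k(x)
    views = [h] + [augment_fn(h) for _ in range(num_virtual)]
    preds = torch.stack([model.g_k(v).argmax(dim=1) for v in views])  # (num_views, batch)
    return preds.mode(dim=0).values                                   # per-sample majority class
```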

| Dataset | Network | Baseline | Cutout | Cutout + BA | SBA | Relative Impro. |
|---|---|---|---|---|---|---|
| CIFAR-10 | ResNet44 | 6.92% | 6.28% | 4.56% | 3.23% | 53.32% |
| CIFAR-10 | VGGNet-16 | 16.25% | 6.17% | 4.69% | 2.73% | 83.20% |
| CIFAR-10 | AlexNet | 23.88% | 21.34% | 20.00% | 18.03% | 24.50% |
| CIFAR-100 | ResNet44 | 28.01% | 27.04% | 25.85% | 24.61% | 12.14% |
| CIFAR-100 | VGGNet-16 | 38.67% | 27.00% | 24.69% | 22.19% | 42.62% |
| ImageNet | ResNet50 | 23.73% | – | 23.16% | 22.72% | 4.26% |
| ImageNet | VGGNet-16 | 23.78% | – | 21.29% | 19.37% | 18.54% |
| ImageNet | AlexNet | 41.69% | – | 37.73% | 34.40% | 17.49% |
Table 1: Error rates obtained by the tested methods on CIFAR-10, CIFAR-100, and ImageNet. The relative improvement in error over the baseline is listed as a percentage.

4 Experiments

4.1 Experimental Setup

Our experiments encompass a range of datasets (CIFAR-10, CIFAR-100 [8], and ImageNet [10]) and different kinds of network architectures (VGGNet-16 [12], AlexNet [9], ResNet44, and ResNet50). We compare against the original training baseline as well as the following methods: Cutout [5] and Batch Augmentation (BA) [7].

Throughout our experiments, for each of the models, unless explicitly stated, we tested our approach using the original training regime and data augmentation described by its authors. To support our claim, we did not change the learning rate or the number of epochs. To reduce unnecessary hyper-parameter tuning and to build strong baselines across datasets and networks, we adopt the best hyper-parameters reported in [7], which has shown that batch augmentation improves model performance consistently over a wide range of augmentation folds. As for the regularization parameter $\lambda$, we optimize it through a simple grid-search procedure on held-out validation data. For all results, the reported performance value (accuracy or error rate) is the median of the values obtained over the final 5 epochs.

4.2 Impact of Hyper-parameters

As mentioned above, Stochastic Batch Augmentation (SBA) has one important hyper-parameter. We first run a group of experiments to quantify the effect this parameter has on model classification accuracy; specifically, we vary the parameter in tenfold increments, both at training and testing time.

We present results for SBA using the AlexNet, VGGNet-16, and ResNet44 architectures under different values of the parameter; the results for CIFAR-10 and CIFAR-100 are shown in Figure 2. On CIFAR-10, the AlexNet and ResNet44 models reach their highest test accuracy at one value of the parameter and the VGGNet-16 model at another, while on CIFAR-100 the VGGNet-16 and ResNet44 models peak at the same value. Therefore, we adopt the corresponding best values, used as the constant balancing the two terms in Equation (11), for the experiments that follow. We also observe that the hyper-parameter performs well across a range of values, without a significant effect on the generalization of the networks.

4.3 Efficiency of SBA

Data augmentation is a general method, so SBA and BA can be applied to any set of latent transformed features, although the metric of the latent space matters. In a manner similar to Dropout, our experiments typically apply the augmentation to fully connected layers towards the deep end of the network, because features in the last few layers are more likely to lie in a Euclidean space. To evaluate and compare the efficiency of the representations learned with SBA and BA, we train VGGNet-16 models on CIFAR-10 and CIFAR-100.

Figure 3 shows the experimental results, where the two reported runtimes are those required by BA and SBA, respectively, to reach convergence (i.e., when the accuracy is almost stable). The validation convergence speed of SBA is noticeably improved compared with BA (left), with a significant increase in final validation classification accuracy (right). We observe that SBA reduces redundant computation: with the stochastic batch scheduler eliminating unnecessary computation within iterations, SBA not only converges faster but also increases accuracy. Note that the test accuracy of both BA and SBA can decrease; this may be due to an improper metric in the corresponding transformed space (e.g., the induced noise is undesirable). However, SBA consistently outperforms BA.

Moreover, we managed to reach high validation accuracy much more quickly with SBA: we trained a ResNet44 with SBA on CIFAR-10 for two-thirds of the iterations needed by BA, using a larger learning rate and a faster learning rate decay schedule. This indicates not only an accuracy gain but also a potential runtime improvement on given hardware. The main reason, we infer, is that traditional batch augmentation is performed deterministically at every iteration, ignoring the generalization capability the network has already learned, and therefore incurs a relatively high computational cost. In particular, when the selected $k$-th layer is close to the input layer, the extra computational overhead is more pronounced.

4.4 Performance of SBA

To show the effectiveness of our method, we empirically investigate the performance of SBA on three datasets: CIFAR-10, CIFAR-100, and ImageNet, comparing SBA with the baseline model and other state-of-the-art methods. As can be seen from Table 1, the relative error reduction of SBA over the baseline is at least 4.26%, and much larger in some cases. For example, on the CIFAR-10 dataset, the relative error reduction achieved by SBA is more than 24%. VGGNet-16 trained with SBA achieves an error rate of 2.73% on CIFAR-10, which is 13.52 percentage points lower than the baseline. Notice that this gain is much larger than the previous gains obtained by Cutout + BA over the baseline (+2.23%) and by Cutout over the baseline (+3.44%).

Our proposed SBA achieves an error rate of 34.40% with AlexNet on ImageNet, outperforming Cutout + BA by 3.33 percentage points. Overall, our results show that SBA improves the generalization of neural networks, which originates from the explicit data augmentation in the latent space as well as the conservative constraint on the predicted distributions of the virtual samples.

5 Conclusion

In this paper, we have presented a framework named Stochastic Batch Augmentation (SBA) for improving the generalization of deep neural networks. In SBA, we significantly reduce the computational cost through a randomized batch augmentation scheme. We also introduce a distilled dynamic soft label regularization technique for learning, which explicitly incorporates a conservative constraint on the predicted distributions. Experimental results on three standard datasets with standard network architectures show the superiority of our proposed framework.

Acknowledgments

This work is supported by National Natural Science Foundation of China under Grant No.61672421.

References

  • [1] P. Bartlett, F. C. Pereira, C. J. Burges, L. Bottou, and K. Q. Weinberger (2013) Advances in neural information processing systems 25: 26th annual conference on neural information processing systems 2012. Curran Associates, Incorporated. Cited by: §2.
  • [2] O. Chapelle, J. Weston, L. Bottou, and V. Vapnik (2001) Vicinal risk minimization. In Advances in neural information processing systems, pp. 416–422. Cited by: §3.1.
  • [3] H. Chen, P. Wang, C. Liu, S. Chang, J. Pan, Y. Chen, W. Wei, and D. Juan (2019) Complement objective training. In 7th International Conference on Learning Representations, ICLR 2019, New Orleans, Cited by: §3.4.
  • [4] L. Csató and M. Opper (2001) Sparse representation for gaussian process models. In Advances in neural information processing systems, pp. 444–450. Cited by: §3.3.
  • [5] T. DeVries and G. W. Taylor (2017) Improved regularization of convolutional neural networks with cutout. arXiv preprint arXiv:1708.04552. Cited by: §4.1.
  • [6] G. Hinton, O. Vinyals, and J. Dean (2015) Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531. Cited by: §3.3.
  • [7] E. Hoffer, T. Ben-Nun, I. Hubara, N. Giladi, T. Hoefler, and D. Soudry (2019) Augment your batch: better training with larger batches. arXiv preprint arXiv:1901.09335. Cited by: §1, §2, §3.2, §4.1, §4.1.
  • [8] A. Krizhevsky, G. Hinton, et al. (2009) Learning multiple layers of features from tiny images. Technical report Citeseer. Cited by: §4.1.
  • [9] A. Krizhevsky, I. Sutskever, and G. E. Hinton (2012) Imagenet classification with deep convolutional neural networks. In Advances in neural information processing systems, pp. 1097–1105. Cited by: §4.1.
  • [10] O. Russakovsky, J. Deng, H. Su, J. Krause, S. Satheesh, S. Ma, Z. Huang, A. Karpathy, A. Khosla, M. Bernstein, et al. (2015) ImageNet large scale visual recognition challenge. International Journal of Computer Vision 115 (3), pp. 211–252. Cited by: §4.1.
  • [11] T. Shimada, S. Yamaguchi, K. Hayashi, and S. Kobayashi (2019) Data interpolating prediction: alternative interpretation of mixup. arXiv preprint arXiv:1906.08412. Cited by: §3.2.
  • [12] K. Simonyan and A. Zisserman (2015) Very deep convolutional networks for large-scale image recognition. In 3rd International Conference on Learning Representations, ICLR 2015, San Diego, CA, USA, May 7-9, 2015, Conference Track Proceedings, Cited by: §4.1.
  • [13] N. Srivastava, G. Hinton, A. Krizhevsky, I. Sutskever, and R. Salakhutdinov (2014) Dropout: a simple way to prevent neural networks from overfitting. The Journal of Machine Learning Research 15 (1), pp. 1929–1958. Cited by: §2, §3.3.
  • [14] C. Szegedy, W. Zaremba, I. Sutskever, J. Bruna, D. Erhan, I. J. Goodfellow, and R. Fergus (2014) Intriguing properties of neural networks. In 2nd International Conference on Learning Representations, ICLR, Cited by: §3.1.
  • [15] V. Verma, A. Lamb, C. Beckham, A. Courville, I. Mitliagkas, and Y. Bengio (2018) Manifold mixup: encouraging meaningful on-manifold interpolation as a regularizer. stat 1050, pp. 13. Cited by: §1, §3.2.
  • [16] V. Verma, A. Lamb, C. Beckham, A. Najafi, I. Mitliagkas, D. Lopez-Paz, and Y. Bengio (2019) Manifold mixup: better representations by interpolating hidden states. In Proceedings of the 36th International Conference on Machine Learning, ICML, pp. 6438–6447. Cited by: §1.
  • [17] A. C. Wilson, R. Roelofs, M. Stern, N. Srebro, and B. Recht (2017) The marginal value of adaptive gradient methods in machine learning. In Advances in Neural Information Processing Systems, pp. 4148–4158. Cited by: §2.
  • [18] H. Yang, J. Zhang, H. Cheng, W. Wang, Y. Chen, and H. Li (2018) Bamboo: ball-shape data augmentation against adversarial attacks from all directions. Cited by: §3.3.
  • [19] H. Zhang, M. Cissé, Y. N. Dauphin, and D. Lopez-Paz (2018) Mixup: beyond empirical risk minimization. In 6th International Conference on Learning Representations, ICLR, Cited by: §1, §2.