Improve SGD Training via Aligning Mini-batches

02/23/2020 ∙ by Xiangrui Li, et al. ∙ Wayne State University

Deep neural networks (DNNs) for supervised learning can be viewed as a pipeline of a feature extractor (i.e. last hidden layer) and a linear classifier (i.e. output layer) that is trained jointly with stochastic gradient descent (SGD). In each iteration of SGD, a mini-batch from the training data is sampled and the true gradient of the loss function is estimated as the noisy gradient calculated on this mini-batch. From the feature learning perspective, the feature extractor should be updated to learn meaningful features with respect to the entire data, and reduce the accommodation to noise in the mini-batch. With this motivation, we propose In-Training Distribution Matching (ITDM) to improve DNN training and reduce overfitting. Specifically, along with the loss function, ITDM regularizes the feature extractor by matching the moments of distributions of different mini-batches in each iteration of SGD, which is fulfilled by minimizing the maximum mean discrepancy. As such, ITDM does not assume any explicit parametric form of data distribution in the latent feature space. Extensive experiments are conducted to demonstrate the effectiveness of our proposed strategy.




1 Introduction

Recently, deep neural networks (DNNs) have achieved remarkable performance improvements in a wide range of challenging tasks in computer vision Krizhevsky et al. (2012); He et al. (2016); Huang et al. (2019), natural language processing Sutskever et al. (2014); Chorowski et al. (2015) and healthcare informatics Miotto et al. (2018). Modern architectures of DNNs usually have an extremely large number of model parameters, often outnumbering the available training data. Recent studies in theoretical deep learning have shown that DNNs can achieve good generalization even with such over-parameterization Neyshabur et al. (2017); Olson et al. (2018). Although over-parameterization may not be very damaging to a DNN's overall generalizability, DNNs can still overfit the noise within the training data (e.g. sampling noise in data collection) due to their highly expressive power. This makes DNNs sensitive to small perturbations in testing data, for example, adversarial samples Goodfellow et al. (2014). Many methods have been proposed to alleviate overfitting in DNNs. These include classic ones such as early stopping and norm-based regularization Goodfellow et al. (2016), and more recent ones such as dropout Srivastava et al. (2014), batch normalization Ioffe and Szegedy (2015) and data-augmentation types of regularization (e.g. cutout DeVries and Taylor (2017), shake-shake Gastaldi (2017)). There are also other machine learning regimes that can achieve a regularization effect, such as transfer learning Pan and Yang (2009) and multi-task learning Caruana (1997); Ruder (2017).

For supervised learning, a DNN can be viewed as a feature extractor followed by a linear classifier on the latent feature space, jointly trained using stochastic gradient descent (SGD). When a DNN overfits the training data and a large gap between training and testing loss (e.g. cross-entropy loss for classification) is observed, from the feature learning perspective this implies a mismatch between the latent feature distributions of the training and testing data produced by the feature extractor. The regularization methods mentioned above can reduce such mismatch and hence improve DNN performance, as the linear classifier can then accommodate itself to the latent features to achieve good performance.

In this paper, we propose a different regularization method, called In-Training Distribution Matching (ITDM), that specifically aims at reducing the fitting of noise for feature extraction during SGD training. The idea behind ITDM is motivated by a simple interpretation of the mini-batch update (in addition to the approximation perspective).

Specifically, in each iteration t of SGD, a mini-batch B_t of m samples is sampled from the training data. The gradient of the loss function L is calculated on the mini-batch, and the network parameter θ is updated via one step of gradient descent with learning rate η:

θ_{t+1} = θ_t − η ∇_θ L(B_t; θ_t).   (1)
This update (Eq. (1)) can be interpreted from two perspectives. (1) From the conventional approximation perspective, the true gradient of the loss function (i.e. the gradient on the entire training data) is approximated by the mini-batch gradient. As each mini-batch contains useful information for the learning task and its gradient computation is cheap, large DNNs can be efficiently and effectively trained with modern computing infrastructures. Moreover, theoretical studies in deep learning have shown that the noisiness of the gradient estimated from a randomly sampled mini-batch plays a crucial role in DNN generalizability Ge et al. (2015); Daneshmand et al. (2018). (2) Eq. (1) can also be interpreted as an exact gradient descent update on the mini-batch. In other words, SGD updates the network parameter to achieve maximum improvement in fitting the mini-batch. As each mini-batch is noisy, such an exact update inevitably introduces undesirable mini-batch-dependent noise. In terms of feature learning, the DNN feature extractor can encode the mini-batch noise into the feature representations.
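The mini-batch update of Eq. (1) can be sketched on a toy problem. This is an illustrative example, not the paper's setup: a one-parameter linear model, learning rate 0.1, and batch size 32 are our own choices.

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy data: y = 2x + noise; theta is the single parameter we fit.
X = rng.normal(size=200)
y = 2.0 * X + 0.1 * rng.normal(size=200)
theta, lr = 0.0, 0.1

for _ in range(50):
    idx = rng.choice(len(X), size=32, replace=False)  # sample a mini-batch B_t
    xb, yb = X[idx], y[idx]
    # Exact gradient of the mean-squared loss on this mini-batch only.
    grad = 2.0 * (xb * theta - yb) @ xb / len(xb)
    theta -= lr * grad                                # Eq. (1): one descent step
```

Each step is an exact descent step for the sampled mini-batch, so `theta` drifts toward the population optimum (2 here) while fluctuating with the per-batch noise, which is precisely the mini-batch-dependent noise the paper targets.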

These two perspectives enable us to decompose the SGD update intuitively as:

mini-batch SGD update = learning of features useful for the entire training data + overfitting of the noise in the current mini-batch.
A natural question to ask is then: can we reduce the mini-batch overfitting, i.e. the mini-batch dependence of the SGD update in Eq. (1)? One solution to this problem is batch normalization (BN) Ioffe and Szegedy (2015). In their seminal paper, internal covariate shift is observed due to the distribution shift of the activations of each hidden layer from mini-batch to mini-batch. Under our decomposition, this phenomenon is closely related to mini-batch overfitting, as networks have to adjust their parameters to fit the mini-batches. To reduce the distribution shift, Ioffe and Szegedy (2015) introduce the BN layer that fixes the means and variances of the activations of hidden layers.

Different from BN, the proposed ITDM directly reduces mini-batch overfitting by matching the latent feature distribution of one mini-batch with that of another. In this paper, we only consider the feature representation from the last hidden layer. Ideally, if the distribution P_z of the latent feature were known a priori, we could explicitly match the mini-batch features with P_z via maximum likelihood. However, in practice, P_z is not known or does not even have an analytic form. To tackle this problem, we utilize the maximum mean discrepancy (MMD) Gretton et al. (2012) from statistical hypothesis testing for the two-sample problem. Our motivation for using MMD as the matching criterion is: if the SGD update using one mini-batch A is helpful for the DNN to learn good feature representations with respect to the entire data, then for another mini-batch B, the mismatch of latent feature distributions between A and B should not be significant. In this way, we can reduce mini-batch overfitting by forcing the SGD update to accommodate B, thereby reducing the dependence of the network on A. In terms of model training, MMD has two advantages: (1) it allows us to avoid a parametric assumption on P_z, and (2) the MMD objective is differentiable and can be jointly trained with the loss function by backpropagation. Note that ITDM is not a replacement for BN. In fact, ITDM can benefit from BN when BN helps improve feature learning for DNNs.

We summarize our contributions as follows. (1) We propose a training strategy, ITDM, for training DNNs. ITDM augments conventional SGD with a regularization effect by additionally forcing feature matching of different mini-batches to reduce mini-batch overfitting. ITDM can be combined with existing regularization approaches and applied to a broad range of network architectures and loss functions. (2) We conduct extensive experiments to evaluate ITDM. Results on different benchmark datasets demonstrate that training with ITDM can significantly improve DNN performance compared with the conventional training strategy (i.e. performing SGD only on the loss function).

2 Related Work

In this section, we first review regularization methods in deep learning. Our work utilizes MMD and hence is also related to the topic of distribution matching, so we also review related work, which is widely studied in the contexts of domain adaptation and generative modeling.

With a limited amount of training data, training DNNs with a large number of parameters usually requires regularization to reduce overfitting. These regularization methods include classic ones such as norm penalties and early stopping Hastie et al. (2009); Goodfellow et al. (2016). For deep learning, many new approaches have been proposed, motivated by the SGD training dynamics. For example, dropout Srivastava et al. (2014) and its variants Gao et al. (2019); Ghiasi et al. (2018) achieve a regularization effect by reducing the co-adaptation of hidden neurons of DNNs. In the training process, dropout randomly sets some hidden neurons' activations to zero, resulting in an averaging effect over a number of sub-networks.

Ioffe and Szegedy (2015) propose batch normalization (BN) to reduce the internal covariate shift caused by SGD. By maintaining the mean and variance of mini-batches, BN regularizes DNNs by discouraging the adaptation to the mini-batches. Label smoothing Szegedy et al. (2016) is another regularization technique that discourages DNNs' over-confident predictions on training data. Our proposed ITDM is partially motivated by the covariate shift observed by Ioffe and Szegedy (2015). ITDM achieves regularization by reducing the DNN's accommodation to each mini-batch in SGD, since each SGD iteration is an exact update for that mini-batch.

To match the distributions of different mini-batches, ITDM uses MMD as its learning objective. MMD Gretton et al. (2007, 2012) is a probability metric for testing whether two finite sets of samples are generated from the same distribution. With the kernel trick, minimizing MMD encourages matching all moments of the empirical data distributions. MMD has been widely applied in many machine learning tasks. For example, Li et al. (2015) and Li et al. (2017) use MMD to train unsupervised generative models by matching the generated distribution with the data distribution. Another application of MMD is domain adaptation: to learn domain-invariant feature representations, Long et al. (2015) uses MMD to explicitly match feature representations from different domains. Our goal is different from those applications. In ITDM, we do not seek exact distribution matching. Instead, we use MMD as a regularizer to improve SGD training.

3 In-Training Distribution Matching

In this section, we first provide an introduction to maximum mean discrepancy for the two-sample problem in statistical hypothesis testing. Then we present our proposed ITDM for training DNNs using SGD, along with some implementation details.

Figure 1: Training and testing cross-entropy for QMNIST, KMNIST and FMNIST.
Figure 2: T-SNE Maaten and Hinton (2008) plot for training CNN with and without ITDM-j using cross-entropy as loss function on the standard testing data of QMNIST, KMNIST and FMNIST. Compared with the conventional training that only performs SGD on the loss function, ITDM has slightly clearer margin and boundaries in the visualization.

3.1 Maximum Mean Discrepancy

Given two finite sets of samples X = {x_1, …, x_m} and Y = {y_1, …, y_n}, MMD Gretton et al. (2007, 2012) is constructed to test whether X and Y are generated from the same distribution. MMD compares the sample statistics of X and Y; if the discrepancy is small, X and Y are likely to follow the same distribution.

Using the kernel trick, the empirical estimate of MMD Gretton et al. (2007) w.r.t. X and Y can be written as:

MMD²(X, Y) = (1/m²) Σ_{i,j} k(x_i, x_j) − (2/(mn)) Σ_{i,j} k(x_i, y_j) + (1/n²) Σ_{i,j} k(y_i, y_j),

where k(·, ·) is a kernel function. Gretton et al. (2007) show that if k is a characteristic kernel, then asymptotically MMD = 0 if and only if X and Y are generated from the same distribution. A typical choice of k is the Gaussian kernel with bandwidth parameter σ:

k(x, y) = exp(−‖x − y‖² / σ).

The computational cost of MMD is O((m + n)²). In ITDM, this is not problematic as typically only a small number of samples in each mini-batch (e.g. 100) is used in SGD.
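The estimator above can be sketched in a few lines of NumPy. This is a minimal sketch of the biased empirical estimate with a single Gaussian kernel; the function names are ours.

```python
import numpy as np

def gaussian_kernel(X, Y, sigma):
    """k(x, y) = exp(-||x - y||^2 / sigma); returns the |X| x |Y| Gram matrix."""
    sq = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(-1)
    return np.exp(-sq / sigma)

def mmd2(X, Y, sigma=1.0):
    """Biased empirical estimate of squared MMD between sample sets X and Y."""
    return (gaussian_kernel(X, X, sigma).mean()
            - 2.0 * gaussian_kernel(X, Y, sigma).mean()
            + gaussian_kernel(Y, Y, sigma).mean())
```

For two mini-batches drawn from the same distribution the estimate stays near zero, while a distribution shift between them drives it up, which is what makes it usable as a mismatch penalty.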

3.2 Proposed ITDM

The idea of ITDM, as explained in Section 1, is to reduce the DNN's adaptation to each mini-batch, viewing the SGD iteration as an exact update for that mini-batch. In terms of feature learning, we attempt to train the feature extractor to encode less mini-batch-dependent noise into the feature representation. From the distribution point of view, the latent feature distribution of the mini-batch should approximately match, or more loosely, should not deviate much from, that of the entire data. However, matching with the entire data has some disadvantages. If MMD is used as the matching criterion and the training data size N is large, the time complexity of MMD is undesirable (on the order of O(N²)). For computational efficiency, an analytic form of the latent feature distribution could be assumed, but at the risk of misspecification. As such, we propose to use a different mini-batch only for latent feature matching (and not for the classification loss). As seen in the experiments, this strategy can significantly improve performance in terms of loss values on independent testing data.

More formally, let f(x; θ) be a convolutional neural network model for classification, parameterized by θ. It consists of a feature extractor h and a linear classifier g, parameterized by θ_h and θ_g respectively. Namely, f(x; θ) = g(h(x; θ_h); θ_g) and θ = (θ_h, θ_g). Without ambiguity, we drop θ in h and g for notational simplicity.

In each iteration of SGD, let B₁ = {(x_i, y_i)}_{i=1}^m be the mini-batch of m samples. Then the loss function using cross-entropy (CE) on B₁ can be written as

L_cls(B₁; θ) = −(1/m) Σ_{i=1}^m log p(y_i | x_i; θ),   (2)

where p(y_i | x_i; θ) is the predicted probability for x_i's true label y_i. SGD performs one gradient descent step on L_cls w.r.t. θ using Eq. (1).

To reduce θ's dependence on B₁ in this exact gradient descent update, we sample from the training data another mini-batch B₂ and match the latent feature distributions of B₁ and B₂ using MMD:

Match(B₁, B₂; θ) = MMD²(h(B₁), h(B₂)),   (3)

where h(B) denotes the set of latent features extracted from mini-batch B.
Our proposed ITDM modifies the conventional gradient descent step in SGD by augmenting the cross-entropy loss (Eq. (2)) with the matching loss, which justifies the name of ITDM:

θ_{t+1} = θ_t − η ∇_θ [L_cls(B₁; θ_t) + λ · Match(B₁, B₂; θ_t)],   (4)

where λ is the tuning parameter controlling the contribution of Match. We call the update in Eq. (4) ITDM-j since we match the joint distribution of the latent features without differentiating each class in the mini-batches B₁ and B₂. Note that in ITDM-j, mini-batch B₂ is not used in the calculation of the cross-entropy loss L_cls.

              QMNIST   KMNIST   FMNIST
Acc  w/ ITDM  98.94    94.42    90.79
     w/o ITDM 98.86    94.19    90.52
CE   w/ ITDM  0.037    0.196    0.257
     w/o ITDM 0.037    0.227    0.269
Table 1: Classification accuracy (in %) and CE loss trained with and without ITDM-j, on the testing data of QMNIST, KMNIST and FMNIST. Results are reported as the average of the last 10 iterations.

Initial results using ITDM-j To test the effectiveness of ITDM, we performed initial experiments using three MNIST-type datasets: QMNIST Yadav and Bottou (2019), KMNIST Clanuwat et al. (2018) and FMNIST Xiao et al. (2017). We trained a simple CNN of two convolutional layers with and without ITDM under exactly the same settings. Experiment details are provided in the supplemental material. Table 1 shows the classification accuracy and cross-entropy loss (i.e. negative log-likelihood) on the standard testing data. Figure 1 plots the cross-entropy loss curves for both training and testing data. It can be seen that ITDM-j achieves better results compared with conventional training. Notably, ITDM-j has an overall smaller loss on the testing data, which implies that the model makes correct predictions with larger probability. To verify this, we plot the T-SNE embedding of the latent features in Figure 2. We see that with ITDM, the latent features for each class have slightly clearer margins and boundaries. This implies that ITDM can help SGD converge to a better local minimum.

Class-conditional ITDM For classification tasks, we can utilize the label information and further refine the match loss as a sum of class-conditional match losses, termed ITDM-c (using the notation in Eq. (3)):

Match_c(B₁, B₂; θ) = Σ_{c=1}^{C} MMD²(h(B₁^c), h(B₂^c)),   (5)

where C is the total number of classes and B^c denotes the subset of mini-batch B whose samples have true label c. The ITDM-c update in SGD is

θ_{t+1} = θ_t − η ∇_θ [L_cls(B₁; θ_t) + λ · Match_c(B₁, B₂; θ_t)].   (6)
The overall training procedure of ITDM is summarized in Algorithm 1.

Compared with ITDM-j, ITDM-c has two advantages. (1) By implicitly utilizing the label information of mini-batch B₂, ITDM-c can help the DNN learn better feature representations by focusing on in-class distribution matching. (2) With in-class matching, the computational cost of calculating MMD is reduced from O((m + n)²) to Σ_c O((m_c + n_c)²), where m_c and n_c are the per-class sample counts in the two mini-batches.
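The class-conditional match loss of Eq. (5) can be sketched on pre-extracted features as follows. Skipping classes that are absent from either mini-batch is our own implementation choice, not specified in the text.

```python
import numpy as np

def gaussian_mmd2(A, B, sigma=1.0):
    """Biased squared-MMD estimate with a Gaussian kernel."""
    k = lambda U, V: np.exp(-((U[:, None] - V[None, :]) ** 2).sum(-1) / sigma)
    return k(A, A).mean() - 2.0 * k(A, B).mean() + k(B, B).mean()

def match_c(Z1, y1, Z2, y2, num_classes):
    """Class-conditional match loss (Eq. (5)): sum of per-class MMD^2 between
    the latent features of two mini-batches, skipping classes missing from
    either batch."""
    total = 0.0
    for c in range(num_classes):
        A, B = Z1[y1 == c], Z2[y2 == c]
        if len(A) and len(B):
            total += gaussian_mmd2(A, B)
    return total
```

Each per-class Gram matrix is much smaller than the full one, which is the source of the computational saving noted above.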

1:  Input: training data
2:  Initialization: model parameter θ
3:  for each epoch do
4:     for each mini-batch B₁ do
5:        Sample another mini-batch B₂
6:        Calculate L_cls using B₁ (Eq. (2))
7:        Calculate latent features h(B₁) and h(B₂) for B₁ and B₂ (Eq. (3))
8:        Calculate Match (Eq. (3)) or Match_c (Eq. (5))
9:        Perform ITDM update using Eq. (4) or (6)
10:     end for
11:  end for
Algorithm 1 In-Training Distribution Matching

3.3 Implementation Considerations

Bandwidth parameter in the Gaussian kernel The performance of MMD as a metric for matching two samples is sensitive to the choice of the bandwidth parameter σ when the Gaussian kernel is used. Since we generally do not have prior knowledge about the latent feature distribution, we follow the practice in Gretton et al. (2007, 2012); Long et al. (2015) and take the heuristic of setting σ as the median squared distance between two samples. In ITDM, σ is not prefixed but rather estimated in each iteration of SGD w.r.t. the two mini-batches.

We check the gradient of the Gaussian kernel to justify this choice of σ:

∇_x k(x, y) = −(2/σ)(x − y) exp(−‖x − y‖²/σ).   (7)

For a fixed σ, if x and y are either very close to or very far from each other, ‖∇_x k(x, y)‖ is small and hence provides little information in backpropagation. By setting σ as the running median squared distance between random mini-batches, the MMD loss can automatically adapt itself and maintain useful gradient information.

It is worth mentioning that MMD with a Gaussian kernel may not effectively carry gradient information for hard samples, as the latter are usually close to the decision boundary and far away from the majority of samples. The reason is the small ‖∇_x k(x, y)‖ when x and y are far from each other. To remedy this, we use a mixture of Gaussian kernels with different ranges of σ's:

k(x, y) = Σ_{q=1}^{Q} k_{σ_q}(x, y).
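The median heuristic and the kernel mixture can be sketched together as below. The specific scale factors applied to the median squared distance are illustrative assumptions; the paper does not list its σ values here.

```python
import numpy as np

def median_sq_dist(X, Y):
    """Median heuristic: bandwidth base = median squared pairwise distance."""
    sq = ((X[:, None] - Y[None, :]) ** 2).sum(-1)
    return np.median(sq)

def mixture_mmd2(X, Y, scales=(0.25, 0.5, 1.0, 2.0, 4.0)):
    """Biased MMD^2 under a mixture of Gaussian kernels whose bandwidths are
    the median squared distance rescaled by several factors, so that both
    near and far sample pairs receive usable gradient signal."""
    base = median_sq_dist(X, Y)
    total = 0.0
    for s in scales:
        sigma = max(s * base, 1e-12)  # guard against a degenerate bandwidth
        k = lambda U, V: np.exp(-((U[:, None] - V[None, :]) ** 2).sum(-1) / sigma)
        total += k(X, X).mean() - 2.0 * k(X, Y).mean() + k(Y, Y).mean()
    return total
```

Recomputing `base` from the current pair of mini-batches mirrors the running estimation of σ in each SGD iteration described above.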

Mini-batch size When used as the training objective for distribution matching, MMD usually requires a large batch size for effective learning. For example, Li et al. (2015) set the batch size to 1000. However, the goal of ITDM is not exact distribution matching, but rather regularization to reduce mini-batch overfitting in the SGD update. In our experiments, we set the batch size following common practice (e.g. 150), and it works well without introducing much computational burden.

4 Experiments

In this section, we evaluate the ITDM strategy and compare its performance with vanilla SGD training on several benchmark image classification datasets, i.e. w/ ITDM vs. w/o ITDM. Specifically, ITDM-c is tested as it provides implicit label information with better supervision in the training process. We implement our code in PyTorch Paszke et al. (2019) and utilize an Nvidia RTX 2080Ti GPU for computation acceleration.

4.1 Datasets

We test ITDM on four benchmark datasets: Kuzushiji-MNIST (KMNIST) Clanuwat et al. (2018), Fashion-MNIST (FMNIST) Xiao et al. (2017), CIFAR10 Krizhevsky et al. (2009) and STL10 Coates et al. (2011). KMNIST and FMNIST are two gray-scale image datasets intended as alternatives to MNIST. Both consist of 70000 (28×28) images from 10 classes of Japanese characters and clothing items respectively, of which 60000 are used for training and the remaining 10000 for testing. CIFAR10 is a colored image dataset of 32×32 resolution. It consists of 50000 training and 10000 testing images from 10 classes. STL10 is another colored image dataset where each image is of size 96×96. The original STL10 has 100000 unlabeled images, 5000 labeled for training and 8000 labeled for testing. In our experiments, we only use the labeled subset for evaluation.

4.2 Implementation Details

Throughout all experiments, the optimization algorithm is standard stochastic gradient descent with momentum and the loss function is the cross-entropy (CE) loss. In ITDM-c, the CE loss is further combined with the matching loss (Eq. (6)) in each iteration.

On KMNIST and FMNIST, we build a 5-layer convolutional neural network (CNN) with batch normalization. The detailed architecture is provided in the supplemental material. Momentum is set to 0.5, batch size 150, number of epochs 50, initial learning rate 0.01, multiplied by 0.2 at the 20th and 40th epochs. No data augmentation is applied. For CIFAR10 and STL10, we use publicly available implementations of VGG13 Simonyan and Zisserman (2014), Resnet18 He et al. (2016) and Mobilenet Howard et al. (2017). All models are trained for 150 epochs; SGD momentum is set to 0.5, the initial learning rate is 0.5 and multiplied by 0.1 every 50 epochs, batch size 150. We resize STL10 images to 32×32. For the colored image datasets, we use random crop and horizontal flip for data augmentation.

In all experiments, networks are trained with vanilla SGD and ITDM SGD under exactly the same settings (learning rate, batch size, etc.). For the tuning parameter λ in ITDM, we test values in {0.2, 0.4, 0.6, 0.8, 1} to check ITDM's sensitivity to it. Note that when λ = 0, ITDM is equivalent to vanilla SGD training. For the bandwidth parameter, we use a mixture of 5 Gaussian kernels. We utilize the standard train/test split given in the benchmark datasets and train all models once on the training data. Performances are evaluated on the testing data and reported as the average of the last 10 iterations.

4.3 Results

(Complete results are provided in the supplemental material.)

For predictive performance in the tables, we report the best (B) and worst (W) Top-1 accuracy trained with ITDM, and their corresponding λ and cross-entropy (CE) loss values (CE is equivalent to negative log-likelihood). For better comparison, the performance difference Δ between training with and without ITDM is also reported, with Δ indicating the magnitude of improvement from ITDM.

              λ    Acc    Δ      CE     Δ
w/o ITDM      -    95.57  -      0.183  -
w/ ITDM (B)   0.8  95.79  0.22   0.170  0.013
w/ ITDM (W)   0.4  95.59  0.02   0.162  0.031
Table 2: Accuracy (in %, larger is better) and CE (smaller is better) on KMNIST testing data.
Figure 3: Accuracy and CE loss (with one standard deviation error bar) w.r.t. different λ values on KMNIST. λ = 0 is equivalent to vanilla SGD training.

KMNIST Table 2 shows the predictive performance for KMNIST. From the table, we see that training with ITDM achieves better results in terms of accuracy and CE. Even in the worst case, ITDM has accuracy comparable to that of vanilla SGD training. In Figure 3, we also plot the accuracy, training and testing loss (after optimization converges) against λ. From the figure, we have the following observations. (1) On KMNIST, training with ITDM is not very sensitive to λ: it has at least comparable performance in terms of accuracy, and always has a smaller CE loss. As CE is equivalent to negative log-likelihood, a smaller CE value implies the network makes predictions on testing data with higher confidence on average. (2) As λ increases, the training CE loss also increases. This is expected, as in each iteration of ITDM there is a tradeoff between the CE and match losses. Since a larger CE implies a smaller likelihood, ITDM has a regularization effect, alleviating over-confident predictions on training data.

FMNIST Table 3 shows the predictive performance for FMNIST. We plot in Figure 4 the accuracy, training and testing loss. As can be seen from the table and figure, ITDM generally does not damage the predictive accuracy. Similar to KMNIST, ITDM always has smaller CE values; however, this does not necessarily lead to accuracy gains. A possible reason is that FMNIST has a significant number of hard samples (e.g. those from the pullover, coat and shirt classes). Though ITDM always leads to predictions with stronger confidence, it still misses those hard samples, as MMD may not effectively capture their information in the training process (Eq. (7)).

              λ    Acc    Δ      CE     Δ
w/o ITDM      -    92.43  -      0.294  -
w/ ITDM (B)   0.6  92.57  0.14   0.224  0.070
w/ ITDM (W)   0.2  92.42  0.01   0.248  0.046
Table 3: Accuracy (in %, larger is better) and CE loss (smaller is better) on FMNIST testing data.
Figure 4: Accuracy and CE (with error bar) w.r.t. different λ values on FMNIST. λ = 0 is equivalent to vanilla SGD training.

CIFAR10 In Table 4, we present the performance of Resnet18, VGG13 and Mobilenet on CIFAR10. For Resnet and Mobilenet, the overall performances of training with and without ITDM are comparable in terms of accuracy across all λ values. In particular, when λ is set to a relatively large value of 0.8 or 1, ITDM can further improve accuracy by a margin of 0.71% for Resnet and 0.82% for Mobilenet. For VGG13, training with ITDM gives higher accuracy for some λ values and slightly worse accuracy for others. We also plot the CE loss for different λ's in Figure 5 for Resnet (for illustration purposes). Compared with vanilla SGD training, we see that training with ITDM results in a significant gain in CE regardless of network architecture, and this pattern holds even when ITDM does not outperform vanilla SGD training in terms of accuracy. On the other hand, as λ increases, the training loss also increases. A smaller gap between training and testing loss usually implies better generalization, as it indicates a closer distribution match between training and testing data. From this perspective, ITDM can regularize DNNs to learn feature representations with better generalizability.

Resnet18:
              λ    Acc    Δ      CE     Δ
w/o ITDM      -    92.99  -      0.396  -
w/ ITDM (B)   0.8  93.70  0.71   0.267  0.129
w/ ITDM (W)   0.6  92.91  0.08   0.299  0.097

VGG13:
              λ    Acc    Δ      CE     Δ
w/o ITDM      -    92.49  -      0.473  -
w/ ITDM (B)   0.8  92.72  0.23   0.334  0.139
w/ ITDM (W)   0.2  92.34  0.15   0.351  0.122

Mobilenet:
              λ    Acc    Δ      CE     Δ
w/o ITDM      -    88.55  -      0.615  -
w/ ITDM (B)   1.0  89.37  0.82   0.427  0.188
w/ ITDM (W)   0.2  88.76  0.21   0.507  0.108

Table 4: Accuracy (in %, larger is better) and CE loss (smaller is better) of Resnet18, VGG13 and Mobilenet on CIFAR10.
Figure 5: Resnet18 accuracy and CE (with error bar) w.r.t. different λ values on CIFAR10.
Figure 6: T-SNE plot for CIFAR10 testing data. Networks are trained with the λ that achieves the best accuracy in Table 4.

STL10 The results on STL10 are shown in Table 5. Similar to CIFAR10, a larger value of λ results in higher accuracy with a significant margin for ITDM (1.90% for Resnet, 1.40% for VGG13 and 2.93% for Mobilenet), whereas other λ values can lead to a performance drop. In terms of CE loss, ITDM always outperforms vanilla SGD training (Figure 7).

Resnet18:
              λ    Acc    Δ      CE     Δ
w/o ITDM      -    70.88  -      1.630  -
w/ ITDM (B)   0.6  72.78  1.90   1.049  0.581
w/ ITDM (W)   0.8  71.29  0.41   1.048  0.582

VGG13:
              λ    Acc    Δ      CE     Δ
w/o ITDM      -    74.40  -      1.545  -
w/ ITDM (B)   0.8  75.80  1.40   0.934  0.611
w/ ITDM (W)   0.4  74.46  0.06   1.110  0.435

Mobilenet:
              λ    Acc    Δ      CE     Δ
w/o ITDM      -    59.09  -      2.144  -
w/ ITDM (B)   0.6  62.02  2.93   1.603  0.541
w/ ITDM (W)   0.8  58.93  0.16   1.638  0.506

Table 5: Accuracy (in %, larger is better) and CE loss (smaller is better) of Resnet18, VGG13 and Mobilenet on STL10 testing data.
Figure 7: Resnet18 accuracy and CE loss (with error bar) w.r.t. different λ values on STL10.

4.4 Analysis

Through the extensive experimental results across a broad range of datasets, we observe that ITDM with larger λ values tends to perform better than with smaller values, and outperforms vanilla SGD training. Since we use ITDM-c in the experiments, a plausible reason for this phenomenon is that ITDM-c provides implicit supervision in the learning process by matching two random, noisy mini-batches from the same class. With larger λ's, ITDM can benefit from this stronger implicit supervision and hence improve network performance.

Another phenomenon is that ITDM can reduce the testing CE loss significantly, in particular on the CIFAR10 and STL10 datasets. Given a sample (x, y), its CE loss is calculated as −log p(y | x), where p(y | x) is the predicted probability for x's true class label y. A smaller CE value implies a larger probability p(y | x). From the geometric perspective, samples from the same class should stay close, and those from different classes should stay far apart, in the feature space (so that the probability output by softmax is large). To confirm this, we visualize the distribution of CIFAR10 testing samples with T-SNE Maaten and Hinton (2008) in Figure 6. From the figure, ITDM learns a feature representation that is much tighter, with clearer inter-class margins, than that learned by vanilla SGD training. We can also gain some insight into why ITDM achieves impressive improvement in CE loss but not as much in accuracy: for each class, ITDM effectively captures the "typical pattern" of the class, so the majority of samples are clustered closely, but ITDM still misses some hard samples that overlap with other classes. Overall, ITDM outperforms vanilla SGD training and can be used as a promising training prototype capable of learning more discriminative features.

5 Conclusion

In this paper, we propose a new training strategy with a regularization effect, ITDM, as an alternative to vanilla SGD training. ITDM augments vanilla SGD with a matching loss that uses MMD as the objective function. By forcing the matching of two different mini-batches, ITDM reduces possible mini-batch overfitting in vanilla SGD. Experimental results demonstrate its excellent performance on classification tasks, as well as its impressive feature learning capacity. There are two possible directions for future study. The first is to improve ITDM so that it can learn hard samples more effectively. The second is the potential application of ITDM to learning from poisoned datasets, as ITDM tends to capture the major pattern in the dataset.


  • R. Caruana (1997) Multitask learning. Machine learning 28 (1), pp. 41–75. Cited by: §1.
  • J. K. Chorowski, D. Bahdanau, D. Serdyuk, K. Cho, and Y. Bengio (2015) Attention-based models for speech recognition. In Advances in neural information processing systems, pp. 577–585. Cited by: §1.
  • T. Clanuwat, M. Bober-Irizar, A. Kitamoto, A. Lamb, K. Yamamoto, and D. Ha (2018) Deep learning for classical Japanese literature. arXiv preprint arXiv:1812.01718. Cited by: §3.2, §4.1.
  • A. Coates, A. Ng, and H. Lee (2011) An analysis of single-layer networks in unsupervised feature learning. In Proceedings of the fourteenth international conference on artificial intelligence and statistics, pp. 215–223. Cited by: §4.1.
  • H. Daneshmand, J. Kohler, A. Lucchi, and T. Hofmann (2018) Escaping saddles with stochastic gradients. arXiv preprint arXiv:1803.05999. Cited by: §1.
  • T. DeVries and G. W. Taylor (2017) Improved regularization of convolutional neural networks with cutout. arXiv preprint arXiv:1708.04552. Cited by: §1.
  • H. Gao, J. Pei, and H. Huang (2019) Demystifying dropout. In The 36th International Conference on Machine Learning (ICML 2019), Cited by: §2.
  • X. Gastaldi (2017) Shake-shake regularization. arXiv preprint arXiv:1705.07485. Cited by: §1.
  • R. Ge, F. Huang, C. Jin, and Y. Yuan (2015) Escaping from saddle points—online stochastic gradient for tensor decomposition. In Conference on Learning Theory, pp. 797–842. Cited by: §1.
  • G. Ghiasi, T. Lin, and Q. V. Le (2018) Dropblock: a regularization method for convolutional networks. In Advances in Neural Information Processing Systems, pp. 10727–10737. Cited by: §2.
  • I. Goodfellow, Y. Bengio, and A. Courville (2016) Deep learning. MIT press. Cited by: §1, §2.
  • I. J. Goodfellow, J. Shlens, and C. Szegedy (2014) Explaining and harnessing adversarial examples. arXiv preprint arXiv:1412.6572. Cited by: §1.
  • A. Gretton, K. M. Borgwardt, M. J. Rasch, B. Schölkopf, and A. Smola (2012) A kernel two-sample test. Journal of Machine Learning Research 13 (Mar), pp. 723–773. Cited by: §1, §2, §3.1, §3.3.
  • A. Gretton, K. Borgwardt, M. Rasch, B. Schölkopf, and A. J. Smola (2007) A kernel method for the two-sample-problem. In Advances in neural information processing systems, pp. 513–520. Cited by: §2, §3.1, §3.1, §3.3.
  • T. Hastie, R. Tibshirani, and J. Friedman (2009) The elements of statistical learning: data mining, inference, and prediction. Springer Science & Business Media. Cited by: §2.
  • K. He, X. Zhang, S. Ren, and J. Sun (2016) Deep residual learning for image recognition. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 770–778. Cited by: §1, §4.2.
  • A. G. Howard, M. Zhu, B. Chen, D. Kalenichenko, W. Wang, T. Weyand, M. Andreetto, and H. Adam (2017) Mobilenets: efficient convolutional neural networks for mobile vision applications. arXiv preprint arXiv:1704.04861. Cited by: §4.2.
  • G. Huang, Z. Liu, G. Pleiss, L. Van Der Maaten, and K. Weinberger (2019) Convolutional networks with dense connectivity. IEEE transactions on pattern analysis and machine intelligence. Cited by: §1.
  • S. Ioffe and C. Szegedy (2015) Batch normalization: accelerating deep network training by reducing internal covariate shift. arXiv preprint arXiv:1502.03167. Cited by: §1, §1, §2.
  • A. Krizhevsky, G. Hinton, et al. (2009) Learning multiple layers of features from tiny images. Cited by: §4.1.
  • A. Krizhevsky, I. Sutskever, and G. E. Hinton (2012) Imagenet classification with deep convolutional neural networks. In Advances in neural information processing systems, pp. 1097–1105. Cited by: §1.
  • C. Li, W. Chang, Y. Cheng, Y. Yang, and B. Póczos (2017) MMD GAN: towards deeper understanding of moment matching network. In Advances in Neural Information Processing Systems, pp. 2203–2213. Cited by: §2.
  • Y. Li, K. Swersky, and R. Zemel (2015) Generative moment matching networks. In International Conference on Machine Learning, pp. 1718–1727. Cited by: §2, §3.3.
  • M. Long, Y. Cao, J. Wang, and M. I. Jordan (2015) Learning transferable features with deep adaptation networks. arXiv preprint arXiv:1502.02791. Cited by: §2, §3.3.
  • L. van der Maaten and G. Hinton (2008) Visualizing data using t-SNE. Journal of machine learning research 9 (Nov), pp. 2579–2605. Cited by: Figure 2, §4.4.
  • R. Miotto, F. Wang, S. Wang, X. Jiang, and J. T. Dudley (2018) Deep learning for healthcare: review, opportunities and challenges. Briefings in bioinformatics 19 (6), pp. 1236–1246. Cited by: §1.
  • B. Neyshabur, S. Bhojanapalli, D. McAllester, and N. Srebro (2017) Exploring generalization in deep learning. In Advances in Neural Information Processing Systems, pp. 5947–5956. Cited by: §1.
  • M. Olson, A. Wyner, and R. Berk (2018) Modern neural networks generalize on small data sets. In Advances in Neural Information Processing Systems, pp. 3619–3628. Cited by: §1.
  • S. J. Pan and Q. Yang (2009) A survey on transfer learning. IEEE Transactions on knowledge and data engineering 22 (10), pp. 1345–1359. Cited by: §1.
  • A. Paszke, S. Gross, F. Massa, A. Lerer, J. Bradbury, G. Chanan, T. Killeen, Z. Lin, N. Gimelshein, L. Antiga, et al. (2019) PyTorch: an imperative style, high-performance deep learning library. In Advances in Neural Information Processing Systems, pp. 8024–8035. Cited by: §4.
  • S. Ruder (2017) An overview of multi-task learning in deep neural networks. arXiv preprint arXiv:1706.05098. Cited by: §1.
  • K. Simonyan and A. Zisserman (2014) Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556. Cited by: §4.2.
  • N. Srivastava, G. Hinton, A. Krizhevsky, I. Sutskever, and R. Salakhutdinov (2014) Dropout: a simple way to prevent neural networks from overfitting. The journal of machine learning research 15 (1), pp. 1929–1958. Cited by: §1, §2.
  • I. Sutskever, O. Vinyals, and Q. V. Le (2014) Sequence to sequence learning with neural networks. In Advances in neural information processing systems, pp. 3104–3112. Cited by: §1.
  • C. Szegedy, V. Vanhoucke, S. Ioffe, J. Shlens, and Z. Wojna (2016) Rethinking the inception architecture for computer vision. In Proceedings of the IEEE conference on computer vision and pattern recognition, pp. 2818–2826. Cited by: §2.
  • H. Xiao, K. Rasul, and R. Vollgraf (2017) Fashion-MNIST: a novel image dataset for benchmarking machine learning algorithms. arXiv preprint arXiv:1708.07747. Cited by: §3.2, §4.1.
  • C. Yadav and L. Bottou (2019) Cold case: the lost mnist digits. In Advances in Neural Information Processing Systems, pp. 13443–13452. Cited by: §3.2.