Density Fixing: Simple yet Effective Regularization Method based on the Class Prior

July 8, 2020 · by Masanari Kimura et al.

Machine learning models suffer from overfitting, which is often caused by a lack of labeled data. To tackle this problem, we propose a framework of regularization methods, called density-fixing, that can be used commonly for supervised and semi-supervised learning. Our proposed regularization method improves generalization performance by forcing the model to approximate the prior distribution of the classes, i.e., their frequency of occurrence. This regularization term is naturally derived from the formula of maximum likelihood estimation and is theoretically justified. We further investigate the asymptotic behavior of the proposed method and how the regularization term behaves for several choices of the class prior in practice. Experimental results on multiple benchmark datasets are sufficient to support our argument, and we suggest that this simple and effective regularization method is useful in real-world machine learning problems.


I Introduction

Machine learning has achieved great success in many areas. However, such machine learning models suffer from an over-fitting problem caused by a lack of data [1, 2]. To tackle such problems, research on semi-supervised learning [3, 4] or regularization [5, 6] has been very active. The main idea of semi-supervised learning is to solve supervised learning problems with few labels by utilizing unlabeled data. In real-world machine learning problems, labeled data is often scarce, but unlabeled data is abundant. Therefore, semi-supervised learning methods that make good use of unlabeled data are essential.

We focus on leveraging the class density of the entire dataset as prior knowledge about labeled and unlabeled data; that is, we assume that the density of each class is available as prior knowledge. This assumption is a natural one in many practical machine learning problems. Based on this idea, we propose a framework of regularization methods, called density-fixing, that can be used commonly in both supervised and semi-supervised settings. Our proposed density-fixing regularization improves generalization performance by forcing the model to approximate the prior distribution of the classes, i.e., their frequency of occurrence. The density-fixing regularization term is naturally derived from the formula for maximum likelihood estimation and is theoretically justified. We further investigate the asymptotic behavior of density-fixing and how the regularization term behaves for several choices of the class prior in practice. Experimental results on multiple benchmark datasets are sufficient to support our argument, and we suggest that this simple and effective regularization method is useful in real-world problems.

Contribution: We propose the density-fixing regularization, which has the following properties:

  • simplicity: density-fixing is very simple to implement and has almost no computational overhead.

  • naturalness: density-fixing is derived naturally from the formula for maximum likelihood estimation and has a theoretical guarantee.

  • versatility: density-fixing is generally applicable to many problem settings.

In a nutshell, density-fixing forces the model's predicted class density to match the class prior:

\mathcal{L}(\theta) = \ell\bigl(f_\theta(x), y\bigr) + \gamma\, D_{\mathrm{KL}}\bigl(p_\theta(y) \,\|\, p(y)\bigr),   (1)

where \ell is some loss function (e.g., the cross-entropy loss), \gamma is the weight parameter of the regularization term, f_\theta is the model, p_\theta(y) is the class density predicted by the model, and p(y) is the true class distribution. If the true class distribution is given as prior knowledge, we can use it directly; otherwise, we can average the frequency of occurrence of the labels in the training sample and use it as an estimator \hat{p}:

\hat{p}(y = k) = \frac{1}{n} \sum_{i=1}^{n} \mathbb{1}[y_i = k].   (2)

The sample mean is an unbiased and consistent estimator of the frequency of class occurrence, so it is sufficient for our purposes.
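For illustration, the estimator in Eq. (2) can be computed in a few lines (a minimal sketch; the variable names are illustrative and not part of our released implementation):

import numpy as np

def estimate_class_prior(labels, n_classes):
    # Empirical class prior p_hat(y = k): relative frequency of each label, as in Eq. (2)
    counts = np.bincount(np.asarray(labels), minlength=n_classes)
    return counts / counts.sum()

# e.g., labels = [0, 2, 1, 0, 0, 2] with n_classes = 3 gives [0.5, 0.1667, 0.3333]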

The source code necessary to replicate our CIFAR-10 experiments is available on GitHub: https://github.com/nocotan/density_fixing

II Related Works

In this section, we introduce prior work relevant to ours.

II-A Over-fitting and Regularization

Machine learning models suffer from an over-fitting problem caused by a lack of data. In order to avoid over-fitting, various regularization methods have been proposed. For example, Dropout [6] is a powerful regularization method that introduces ensemble-learning-like behavior by randomly removing connections between neurons of a deep neural network. Another recently proposed simple regularization method is mixup and its variants [5, 7, 8], which take linear combinations of training examples as new inputs. There are also many regularization methods for specific models (e.g., for Generative Adversarial Networks [9, 10]).

II-B Semi-Supervised Learning

There are many studies on semi-supervised learning. The method of assigning pseudo-labels to unlabeled data as new training data is very popular [11]. Another approach to semi-supervised learning is the use of Generative Adversarial Networks, which are famous for their expressive power [12].

Fig. 1: Difference in the behavior of the asymptotic variance of the maximum likelihood estimator with and without density-fixing regularization. We can see that our regularization reduces the error of the estimator more quickly.

III Notations and Problem Formulation

Let X be the input space, Y = {1, …, K} be the output space, K be the number of classes, and C be a set of concepts X → Y we may wish to learn. We assume that each input vector x ∈ X is of dimension d. We also assume that examples are independently and identically distributed (i.i.d.) according to some fixed but unknown distribution D.

The learning problem is then formulated as follows: we consider a fixed set of possible concepts H, called the hypothesis set. We receive a sample S = (x_1, …, x_m) drawn i.i.d. according to D, as well as the labels (c(x_1), …, c(x_m)), which are based on a specific target concept c ∈ C. In the semi-supervised learning problem, we additionally have access to an unlabeled sample U drawn i.i.d. according to D. Our task is to use the labeled sample S and the unlabeled sample U to find a hypothesis h ∈ H that has a small generalization error with respect to the concept c. The generalization error is defined as follows.

Definition 1.

(Generalization error) Given a hypothesis h ∈ H, a target concept c ∈ C, and an unknown distribution D, the generalization error of h is defined by

R(h) = \Pr_{x \sim D}\bigl[h(x) \neq c(x)\bigr] = \mathbb{E}_{x \sim D}\bigl[\mathbb{1}_{h(x) \neq c(x)}\bigr],   (3)

where \mathbb{1}_{\omega} is the indicator function of the event \omega.

The generalization error of a hypothesis is not directly accessible since both the underlying distribution D and the target concept c are unknown. Therefore, we measure the empirical error of the hypothesis h on the observable labeled sample S. The empirical error is defined as follows.

Definition 2.

(Empirical error) Given a hypothesis h ∈ H, a target concept c ∈ C, and a sample S = (x_1, …, x_m), the empirical error of h is defined by

\hat{R}_S(h) = \frac{1}{m} \sum_{i=1}^{m} \mathbb{1}_{h(x_i) \neq c(x_i)}.   (4)

In learning problems, we are interested in how much difference there is between the empirical and generalization errors. Therefore, in general, we consider the relative generalization error R(h) − \hat{R}_S(h).
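As a minimal illustration of Definitions 1 and 2 (a sketch only; the hypothesis h and the sample arrays are placeholders):

import numpy as np

def empirical_error(h, X, y):
    # Empirical 0-1 error of hypothesis h on a labeled sample (X, y), as in Eq. (4)
    predictions = np.array([h(x) for x in X])
    return float(np.mean(predictions != np.asarray(y)))

# The relative generalization error R(h) - R_hat_S(h) can be approximated by
# evaluating the same quantity on a large held-out sample drawn from D:
# gap = empirical_error(h, X_heldout, y_heldout) - empirical_error(h, X_train, y_train)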

IV Density-Fixing Regularization

In this paper, we assume that H is a class of functions mapping input vectors to class densities:

H \subseteq \bigl\{ h : X \to \Delta_{K-1} \bigr\},   (5)

where \Delta_{K-1} denotes the probability simplex over the K classes. Therefore, we can replace the learning problem with the problem of approximating the true distribution p with the estimated distribution p_\theta.

We assume that the class-conditional probability p(x \mid y) for labeled data and that for unlabeled data (or test data) are the same:

p_{l}(x \mid y) = p_{u}(x \mid y).   (6)

Then, our goal is to estimate p_\theta from labeled data drawn i.i.d. from the labeled-data distribution and from unlabeled data drawn i.i.d. from the unlabeled-data distribution.

Theorem 1.

Let p_\theta be the estimated distribution parameterized by \theta, and let p be the true distribution. Then, we can write the sum of the log-likelihood functions as follows:

(7)

where D_{\mathrm{KL}}(p \,\|\, q) is the Kullback–Leibler divergence [13] between two distributions p and q:

D_{\mathrm{KL}}(p \,\|\, q) = \sum_{y} p(y) \log \frac{p(y)}{q(y)}   (8)
                           = \mathbb{E}_{y \sim p}\bigl[\log p(y) - \log q(y)\bigr].   (9)

This means that when we consider maximum likelihood estimation, we can decompose the objective function into two terms: a data-fitting term and a term that depends only on the class densities p(y) and p_\theta(y).

Proof.

From Bayes’ theorem, we can obtain

(10)
(11)

Then, combining Eq (6), (10) and (11),

(12)

Considering maximum likelihood estimation, we obtain the log-likelihood function as follows:

(13)

Finally, computing the sum of the log-likelihood functions,

(14)

we obtain Eq. (7). ∎

Considering that we maximize Eq. (7), it is clear that p_\theta(y) should be driven closer to p(y). The Kullback–Leibler divergence D_{\mathrm{KL}}(p \,\|\, q) is defined only if q(y) = 0 implies p(y) = 0 for all y; this property is called absolute continuity.

From the above theorem, if the probability of class occurrence is known in advance, it can be used to perform regularization. We call this term the density-fixing regularization. Regularization is performed so that the density of each class in the inference result for the unlabeled sample approximates the prior p(y). In addition, the KL divergence has the following property: the best approximation satisfies

p_\theta(y) = 0   (15)

for every y at which p(y) = 0. This property is called zero-forcing, and we can see that our regularization behaves as if the probabilities of classes we do not know about remain zero.
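As a quick numerical illustration of the zero-forcing behavior (a sketch only, assuming the D_KL(p_theta || p) form of the regularizer used in the implementation of Fig. 6):

import torch

def kl(q, p, eps=1e-12):
    # D_KL(q || p) for discrete distributions represented as tensors
    return torch.sum(q * (torch.log(q + eps) - torch.log(p + eps)))

prior = torch.tensor([0.5, 0.5, 0.0])   # the third class never occurs under the prior
good = torch.tensor([0.6, 0.4, 0.0])    # prediction that respects the zero of the prior
bad = torch.tensor([0.5, 0.4, 0.1])     # prediction that puts mass on the impossible class

print(kl(good, prior))  # small, finite penalty
print(kl(bad, prior))   # much larger penalty, so minimization drives p_theta(y) toward 0 there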

V Asymptotic Normality

In this section, we discuss how the density-fixing regularization behaves asymptotically.

Theorem 2.

Let n be the number of samples and let \theta^{*} denote the true parameter. The asymptotic variance of the maximum likelihood estimator obtained with the density-fixing regularization is given by \bigl(n I(\theta^{*}) + \gamma h(\theta^{*})\bigr)^{-1}. Here, h(\theta) is a function that always takes a positive value, parameterized by \theta.

Proof.

For the maximum likelihood estimator with n samples, we can obtain an expansion of the derivative of the log-likelihood by a Taylor expansion around the true parameter \theta^{*}; here, we assume that the log-likelihood has a bounded third-order derivative with respect to the parameter \theta. From this expansion and the central limit theorem, we can obtain

\hat{\theta}_n \sim \mathcal{N}\bigl(\theta^{*},\, (n I(\theta^{*}))^{-1}\bigr)   (17)

when n is sufficiently large. Here, I(\theta) is the Fisher information matrix:

I(\theta) = \mathbb{E}\bigl[\nabla_\theta \log p(x; \theta)\, \nabla_\theta \log p(x; \theta)^{\top}\bigr].   (18)

Then, letting L_n(\theta) denote the original (unregularized) log-likelihood function, we can obtain

(19)
(20)
(21)
(22)

Therefore, the maximum likelihood estimator obtained with the density-fixing regularization satisfies the following:

(23)

Since the logarithm is a monotonically increasing concave function, the negative of the corresponding second derivative is always positive. Therefore, we obtain the statement of Theorem 2 with h(\theta) > 0. ∎

This theorem implies that the convergence rate of the asymptotic variance of the maximum likelihood estimator becomes faster, by the additional positive term \gamma h(\theta), when the density-fixing regularization is applied. Figure 1 illustrates the asymptotic behavior of the estimator under our regularization.
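For illustration only, the following toy Monte-Carlo sketch (not one of our benchmark experiments; all values are illustrative) compares the ordinary maximum likelihood estimator of a Bernoulli parameter with the estimator that also maximizes minus the density-fixing penalty \gamma D_KL(p_\theta \| p) when the class prior is known; the penalized estimator exhibits a visibly smaller variance, in line with Figure 1:

import numpy as np

rng = np.random.default_rng(0)
theta_true, n, gamma, trials = 0.5, 50, 5.0, 2000
prior = 0.5                                   # assumed known class prior
grid = np.linspace(1e-3, 1 - 1e-3, 999)       # candidate values of theta

def kl_bernoulli(q, p):
    return q * np.log(q / p) + (1 - q) * np.log((1 - q) / (1 - p))

plain, fixed = [], []
for _ in range(trials):
    y = rng.binomial(1, theta_true, size=n)
    loglik = y.sum() * np.log(grid) + (n - y.sum()) * np.log(1 - grid)
    plain.append(grid[np.argmax(loglik)])                                      # ordinary MLE
    fixed.append(grid[np.argmax(loglik - gamma * kl_bernoulli(grid, prior))])  # density-fixing
print(np.var(plain), np.var(fixed))  # the regularized estimator has the smaller variance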

Fig. 2: Behavior of the regularization term for each parameter. Left: the regularization term when the class prior is a discrete uniform distribution. Right: the regularization term when the class prior is a Bernoulli distribution.

VI Some Examples

In this section, we investigate the behavior of our proposed method by assuming some class distributions as examples. To summarize our results:

  • For a discrete uniform distribution, the effect of regularization becomes weaker as the number of classes increases;

  • For a Bernoulli distribution, our regularization becomes stronger when there is a class imbalance.

Figure 2 shows the behavior of the regularization terms under each distribution.

VI-A Discrete Uniform Distribution

We assume that the probability distribution of the classes is as follows:

p(y = k) = \frac{1}{K}, \quad k = 1, \dots, K,   (24)

where K is the number of classes; this is the discrete uniform distribution. Then, our regularization term is

(25)

Thus, we can see that when the classes follow a discrete uniform distribution, the effect of regularization becomes weaker as the number of classes increases.

VI-B Bernoulli Distribution

We assume that K = 2 and that the probability distribution of the classes is as follows:

p(y) = \mu^{y} (1 - \mu)^{1 - y}, \quad y \in \{0, 1\},   (26)

where \mu \in [0, 1]; this is the Bernoulli distribution. Then, our regularization term is

(27)

Thus, we can see that the regularization is stronger when \mu is away from 1/2. This means that in binary classification, it gives strong regularization when there is a class imbalance.
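As a rough numerical check of this behavior (a sketch only; we assume the D_KL(p_theta || p) form of the regularizer used in the implementation of Fig. 6, with a fixed model prediction p_theta = 0.5):

import numpy as np

def kl_bernoulli(q, p):
    # D_KL(Bernoulli(q) || Bernoulli(p))
    return q * np.log(q / p) + (1 - q) * np.log((1 - q) / (1 - p))

p_theta = 0.5  # fixed model prediction for the positive class
for mu in (0.5, 0.3, 0.1, 0.01):
    print(f"mu = {mu}: penalty = {kl_bernoulli(p_theta, mu):.3f}")
# The penalty grows as mu moves away from 1/2, i.e., under class imbalance.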

Fig. 3: Left: Test error evolution for the best baseline model and density-fixing. Right: Test loss and train-test differences for each value of \gamma in the supervised setting.

VII Experimental Results

In this section, we introduce our experimental results. We implement the density-fixing regularization as follows:

\mathcal{L}(\theta) = \ell_{\mathrm{CE}}\bigl(f_\theta(x), y\bigr) + \gamma\, D_{\mathrm{KL}}\bigl(p_\theta(y \mid x) \,\|\, p(y)\bigr),   (28)

where \ell_{\mathrm{CE}} is the cross-entropy loss and \gamma is the weight parameter for the regularization term. The implementation of density-fixing regularization is straightforward; Figure 6 shows the few lines of code necessary to implement density-fixing regularization in PyTorch [14].

The datasets we use are CIFAR-10 [15], CIFAR-100 [15], STL-10 [16] and SVHN [17]. We determined the prior distribution of the classes based on the number of examples belonging to each class of the dataset, and we used ResNet-18 [18] as the baseline model.

VII-A Supervised Classification

In this experiment, we assumed a discrete uniform distribution for the class distribution.

Figure 3 shows the experimental results for CIFAR-10 with density-fixing regularization. As seen in the left panel of this figure, the baseline model and density-fixing converge at a similar speed to their best test errors. In the later epochs, a second loss reduction, known as deep double descent [19], can be observed, but this phenomenon is not disturbed by density-fixing. From the right panel, we can see that by increasing the parameter \gamma, we can reduce the generalization gap.

Also, Table I shows the contribution of density-fixing to the reduction of test errors.

VII-B Semi-Supervised Classification

In our experiments, we assumed a discrete uniform distribution for the class distribution and treated one part of the training data as labeled and the remaining part as unlabeled.

Figure 4 shows the test loss and the train-test differences for each value of \gamma in the semi-supervised setting. We can see that increasing the parameter \gamma reduces the generalization gap. In addition, CIFAR-10 and CIFAR-100, which consist of images from the same domain, have 10 and 100 classes, respectively, but the experimental results show that CIFAR-10 enjoys a more significant regularization effect than CIFAR-100. This result supports our example in Eq. (25).

Table II shows a comparison of classification errors for each value of \gamma. These experimental results show that our regularization leads to improved error on the test data.

Fig. 4: Test loss and train-test differences for each value of \gamma in the semi-supervised setting. We can see that the generalization gap tends to become smaller as we increase the value of \gamma.
Dataset     Model                               Top-1 Error   Top-5 Error
CIFAR-10    ResNet-18                           12.720%       0.812%
CIFAR-10    ResNet-18 + density-fixing ()       12.230%       0.779%
CIFAR-10    ResNet-18 + density-fixing ()       –             –
CIFAR-100   ResNet-18                           25.562%       6.710%
CIFAR-100   ResNet-18 + density-fixing ()       –             –
CIFAR-100   ResNet-18 + density-fixing ()       25.965%       6.887%
TABLE I: Top 1 and Top 5 test error comparison in the supervised setting.

VII-C Stabilization of Generative Adversarial Networks

Generative Adversarial Networks (GANs) [20] are one of the powerful generative modeling paradigms currently successful in various tasks. However, GANs have the problem that their training is very unstable. We suggest that regularization by density-fixing contributes to improving the stability of GANs. The density-fixing formulation of GANs is:

(29)

where D is the discriminator, G is the generator, and \ell is the binary cross-entropy loss.

Figure 5 illustrates the stabilizing effect of density-fixing on GAN training when modeling a toy dataset (blue samples). The neural networks in these experiments are fully connected and have three hidden layers of ReLU units. We can see that density-fixing contributes to stabilizing the training of GANs.
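Eq. (29) is not reproduced here; as an illustrative sketch only, a density-fixing-style term could be attached to a standard discriminator update as follows (the Bernoulli(1/2) prior over the discriminator's real/fake outputs and the placement of the term are illustrative assumptions rather than the exact objective of Eq. (29)):

import torch
import torch.nn as nn

bce = nn.BCELoss()
gamma = 0.1  # illustrative regularization weight

def discriminator_loss(D, real, fake):
    # D is assumed to output probabilities (sigmoid activation)
    d_real, d_fake = D(real), D(fake.detach())
    loss = bce(d_real, torch.ones_like(d_real)) + bce(d_fake, torch.zeros_like(d_fake))

    # density-fixing-style term: push the discriminator's average output
    # toward the assumed prior over real/fake labels (0.5 / 0.5)
    q = torch.cat([d_real, d_fake]).mean().clamp(1e-6, 1 - 1e-6)
    p = torch.tensor(0.5)
    kl = q * torch.log(q / p) + (1 - q) * torch.log((1 - q) / (1 - p))
    return loss + gamma * kl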

Fig. 5: Effect of density-fixing on stabilizing GAN training.
import numpy as np
import torch
import torch.nn as nn

for i, (inputs, targets) in enumerate(train_loader):
    outputs = model(inputs)
    preds = torch.softmax(outputs, 1)

    # density-fixing regularization: build the class prior p(y)
    # (here an approximately uniform prior via np.random.uniform + softmax)
    p_y = np.random.uniform(0, 1, (batch_size, n_classes))
    p_y = torch.Tensor(p_y)
    p_y = torch.softmax(p_y, 1)
    R = nn.KLDivLoss()(p_y.log(), preds)

    # add the regularization term to the task loss
    loss = criterion(outputs, targets) + gamma * R
    loss.backward()
Fig. 6: The few lines of code necessary to implement density-fixing regularization in PyTorch (model, train_loader, criterion, gamma, batch_size and n_classes are defined elsewhere).
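If the class prior is instead taken from the label frequencies of Eq. (2), the prior tensor in Fig. 6 can be built once before the training loop; a possible variant (a sketch only; train_labels denotes the list of integer training labels and is illustrative):

import torch

counts = torch.bincount(torch.as_tensor(train_labels), minlength=n_classes).float()
p_y = (counts / counts.sum()).expand(batch_size, n_classes)  # empirical prior, reused for every batch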
Dataset     Top-1 test error (columns correspond to different values of \gamma)
CIFAR-10    28.235   28.510   29.086   30.964   30.892
CIFAR-100   66.622   66.723   66.861   66.895   67.007
STL-10      59.770   60.110   60.124   60.405   60.897
SVHN        27.937   28.028   30.110   32.025   32.879
TABLE II: Top 1 test error comparison for each dataset in the semi-supervised setting. The datasets we use are CIFAR-10, CIFAR-100, STL-10 and SVHN.

VIII Conclusion and Discussion

In this paper, we proposed a framework of regularization methods that can be used commonly for both supervised and semi-supervised learning. Our proposed regularization method improves the generalization performance by forcing the model to approximate the prior distribution of the class. We proved that this regularization term is naturally derived from the formula of maximum likelihood estimation. We further investigated the asymptotic behavior of the proposed method and how the regularization terms behave when assuming a prior distribution of several classes in practice. Our experimental results have sufficiently demonstrated the effectiveness of our proposed method.

References

  • [1] D. M. Hawkins, “The problem of overfitting,” Journal of chemical information and computer sciences, vol. 44, no. 1, pp. 1–12, 2004.
  • [2] S. Lawrence, C. L. Giles, and A. C. Tsoi, “Lessons in neural network training: Overfitting may be harder than expected,” in AAAI/IAAI.   Citeseer, 1997, pp. 540–545.
  • [3] X. Zhu and A. B. Goldberg, “Introduction to semi-supervised learning,” Synthesis Lectures on Artificial Intelligence and Machine Learning, vol. 3, no. 1, pp. 1–130, 2009.
  • [4] D. P. Kingma, S. Mohamed, D. J. Rezende, and M. Welling, “Semi-supervised learning with deep generative models,” in Advances in neural information processing systems, 2014, pp. 3581–3589.
  • [5] H. Zhang, M. Cisse, Y. N. Dauphin, and D. Lopez-Paz, “mixup: Beyond empirical risk minimization,” in International Conference on Learning Representations, 2018. [Online]. Available: https://openreview.net/forum?id=r1Ddp1-Rb
  • [6] N. Srivastava, G. Hinton, A. Krizhevsky, I. Sutskever, and R. Salakhutdinov, “Dropout: a simple way to prevent neural networks from overfitting,” The journal of machine learning research, vol. 15, no. 1, pp. 1929–1958, 2014.
  • [7] M. Kimura, “Mixup training as the complexity reduction,” arXiv preprint arXiv:2006.06231, 2020.
  • [8] S. Yun, D. Han, S. J. Oh, S. Chun, J. Choe, and Y. Yoo, “Cutmix: Regularization strategy to train strong classifiers with localizable features,” in Proceedings of the IEEE International Conference on Computer Vision, 2019, pp. 6023–6032.
  • [9] K. Roth, A. Lucchi, S. Nowozin, and T. Hofmann, “Stabilizing training of generative adversarial networks through regularization,” in Advances in neural information processing systems, 2017, pp. 2018–2028.
  • [10] M. Kimura and T. Yanagihara, “Anomaly detection using gans for visual inspection in noisy training data,” in Asian Conference on Computer Vision.   Springer, 2018, pp. 373–385.
  • [11] D.-H. Lee, “Pseudo-label: The simple and efficient semi-supervised learning method for deep neural networks,” in Workshop on challenges in representation learning, ICML, vol. 3, no. 2, 2013.
  • [12] A. Kumar, P. Sattigeri, and T. Fletcher, “Semi-supervised learning with gans: Manifold invariance with improved inference,” in Advances in Neural Information Processing Systems, 2017, pp. 5534–5544.
  • [13] S. Kullback and R. A. Leibler, “On information and sufficiency,” Ann. Math. Statist., vol. 22, no. 1, pp. 79–86, 1951.
  • [14] A. Paszke, S. Gross, F. Massa, A. Lerer, J. Bradbury, G. Chanan, T. Killeen, Z. Lin, N. Gimelshein, L. Antiga et al., “Pytorch: An imperative style, high-performance deep learning library,” in Advances in Neural Information Processing Systems, 2019, pp. 8026–8037.
  • [15] A. Krizhevsky, G. Hinton et al., “Learning multiple layers of features from tiny images,” 2009.
  • [16] A. Coates, A. Ng, and H. Lee, “An analysis of single-layer networks in unsupervised feature learning,” in Proceedings of the fourteenth international conference on artificial intelligence and statistics, 2011, pp. 215–223.
  • [17] Y. Netzer, T. Wang, A. Coates, A. Bissacco, B. Wu, and A. Y. Ng, “Reading digits in natural images with unsupervised feature learning,” 2011.
  • [18] K. He, X. Zhang, S. Ren, and J. Sun, “Deep residual learning for image recognition,” in Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, 2016, pp. 770–778.
  • [19] P. Nakkiran, G. Kaplun, Y. Bansal, T. Yang, B. Barak, and I. Sutskever, “Deep double descent: Where bigger models and more data hurt,” in International Conference on Learning Representations, 2020. [Online]. Available: https://openreview.net/forum?id=B1g5sA4twr
  • [20] I. Goodfellow, J. Pouget-Abadie, M. Mirza, B. Xu, D. Warde-Farley, S. Ozair, A. Courville, and Y. Bengio, “Generative adversarial nets,” in Advances in neural information processing systems, 2014, pp. 2672–2680.